sleap-nn 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
- sleap_nn/__init__.py +1 -1
- sleap_nn/cli.py +36 -0
- sleap_nn/config/trainer_config.py +18 -0
- sleap_nn/evaluation.py +81 -22
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/predict.py +29 -0
- sleap_nn/train.py +64 -0
- sleap_nn/training/callbacks.py +324 -8
- sleap_nn/training/lightning_modules.py +542 -32
- sleap_nn/training/model_trainer.py +48 -57
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +13 -2
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +37 -16
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
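The bulk of this release is the diff to `sleap_nn/training/lightning_modules.py` shown below: metric names move under `train/*` and `val/*` namespaces, training loss is additionally averaged over each epoch via a new `_accumulate_loss()` helper, and validation steps can optionally collect predictions for epoch-end evaluation. As a reading aid, here is a minimal, hypothetical sketch of just the epoch-averaging pattern; the class name and the driver loop at the bottom are invented for illustration, and only the `_epoch_loss_*` arithmetic mirrors the diff:

```python
# Hypothetical, minimal sketch of the epoch-averaged loss pattern added in this
# release (names mirror the diff below; this is not the actual sleap-nn class).
import torch


class EpochLossTracker:
    """Accumulates per-step losses and reports the epoch average."""

    def __init__(self):
        self._epoch_loss_sum = 0.0
        self._epoch_loss_count = 0

    def on_train_epoch_start(self):
        # Reset at the start of every epoch, as on_train_epoch_start does in the diff.
        self._epoch_loss_sum = 0.0
        self._epoch_loss_count = 0

    def accumulate(self, loss: torch.Tensor):
        # Mirrors _accumulate_loss: detach so no autograd graph is retained.
        self._epoch_loss_sum += loss.detach().item()
        self._epoch_loss_count += 1

    def epoch_average(self):
        # Mirrors the "train/loss" computation in on_train_epoch_end.
        if self._epoch_loss_count == 0:
            return None
        return self._epoch_loss_sum / self._epoch_loss_count


tracker = EpochLossTracker()
tracker.on_train_epoch_start()
for step_loss in [torch.tensor(0.9), torch.tensor(0.7), torch.tensor(0.5)]:
    tracker.accumulate(step_loss)
print(tracker.epoch_average())  # ≈ 0.7
```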
--- sleap_nn/training/lightning_modules.py (0.1.0a1)
+++ sleap_nn/training/lightning_modules.py (0.1.0a3)
@@ -1,6 +1,6 @@
 """This module has the LightningModule classes for all model types."""
 
-from typing import Optional, Union, Dict, Any
+from typing import Optional, Union, Dict, Any, List
 import time
 from torch import nn
 import numpy as np
@@ -184,6 +184,15 @@ class LightningModel(L.LightningModule):
         self.val_loss = {}
         self.learning_rate = {}
 
+        # For epoch-averaged loss tracking
+        self._epoch_loss_sum = 0.0
+        self._epoch_loss_count = 0
+
+        # For epoch-end evaluation
+        self.val_predictions: List[Dict] = []
+        self.val_ground_truth: List[Dict] = []
+        self._collect_val_predictions: bool = False
+
         # Initialization for encoder and decoder stacks.
         if self.init_weights == "xavier":
             self.model.apply(xavier_init_weights)
@@ -305,12 +314,20 @@ class LightningModel(L.LightningModule):
     def on_train_epoch_start(self):
         """Configure the train timer at the beginning of each epoch."""
         self.train_start_time = time.time()
+        # Reset epoch loss tracking
+        self._epoch_loss_sum = 0.0
+        self._epoch_loss_count = 0
+
+    def _accumulate_loss(self, loss: torch.Tensor):
+        """Accumulate loss for epoch-averaged logging. Call this in training_step."""
+        self._epoch_loss_sum += loss.detach().item()
+        self._epoch_loss_count += 1
 
     def on_train_epoch_end(self):
         """Configure the train timer at the end of every epoch."""
         train_time = time.time() - self.train_start_time
         self.log(
-            "
+            "train/time",
             train_time,
             prog_bar=False,
             on_step=False,
@@ -327,16 +344,43 @@ class LightningModel(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
+        # Log epoch-averaged training loss
+        if self._epoch_loss_count > 0:
+            avg_loss = self._epoch_loss_sum / self._epoch_loss_count
+            self.log(
+                "train/loss",
+                avg_loss,
+                prog_bar=False,
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                sync_dist=True,
+            )
+        # Log current learning rate (useful for monitoring LR schedulers)
+        if self.trainer.optimizers:
+            lr = self.trainer.optimizers[0].param_groups[0]["lr"]
+            self.log(
+                "train/lr",
+                lr,
+                prog_bar=False,
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                sync_dist=True,
+            )
 
     def on_validation_epoch_start(self):
         """Configure the val timer at the beginning of each epoch."""
         self.val_start_time = time.time()
+        # Clear accumulated predictions for new epoch
+        self.val_predictions = []
+        self.val_ground_truth = []
 
     def on_validation_epoch_end(self):
         """Configure the val timer at the end of every epoch."""
         val_time = time.time() - self.val_start_time
         self.log(
-            "
+            "val/time",
             val_time,
             prog_bar=False,
             on_step=False,
@@ -344,6 +388,16 @@ class LightningModel(L.LightningModule):
             logger=True,
             sync_dist=True,
         )
+        # Log epoch explicitly so val/* metrics can use it as x-axis in wandb
+        # (mirrors what on_train_epoch_end does for train/* metrics)
+        self.log(
+            "epoch",
+            float(self.current_epoch),
+            on_step=False,
+            on_epoch=True,
+            logger=True,
+            sync_dist=True,
+        )
 
     def training_step(self, batch, batch_idx):
         """Training step."""
@@ -412,7 +466,7 @@ class LightningModel(L.LightningModule):
             "optimizer": optimizer,
             "lr_scheduler": {
                 "scheduler": scheduler,
-                "monitor": "
+                "monitor": "val/loss",
             },
         }
 
@@ -591,7 +645,7 @@ class SingleInstanceLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"{name}",
+                f"train/confmaps/{name}",
                 channel_wise_loss[node_idx],
                 prog_bar=False,
                 on_step=False,
@@ -599,8 +653,9 @@ class SingleInstanceLightningModule(LightningModel):
                 logger=True,
                 sync_dist=True,
             )
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -608,6 +663,8 @@ class SingleInstanceLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -630,7 +687,7 @@ class SingleInstanceLightningModule(LightningModel):
             )
             val_loss = val_loss + ohkm_loss
         self.log(
-            "
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -639,6 +696,51 @@ class SingleInstanceLightningModule(LightningModel):
             sync_dist=True,
         )
 
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                # Squeeze n_samples dim from image for inference (batch, 1, C, H, W) -> (batch, C, H, W)
+                inference_batch = {k: v for k, v in batch.items()}
+                if inference_batch["image"].ndim == 5:
+                    inference_batch["image"] = inference_batch["image"].squeeze(1)
+                inference_output = self.single_instance_inf_layer(inference_batch)
+                if isinstance(inference_output, list):
+                    inference_output = inference_output[0]
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are already in original image space (inference divides by eff_scale)
+                pred_peaks = inference_output["pred_instance_peaks"][i].cpu().numpy()
+                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+                # Transform GT from preprocessed to original image space
+                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+                gt_prep = batch["instances"][i].cpu().numpy()
+                if gt_prep.ndim == 4:
+                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
+                gt_orig = gt_prep / eff
+                num_inst = batch["num_instances"][i].item()
+                gt_orig = gt_orig[:num_inst]  # Only valid instances
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks,
+                        "pred_scores": pred_scores,
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_orig,
+                        "num_instances": num_inst,
+                    }
+                )
+
 
 class TopDownCenteredInstanceLightningModule(LightningModel):
     """Lightning Module for TopDownCenteredInstance Model.
@@ -807,7 +909,7 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"{name}",
+                f"train/confmaps/{name}",
                 channel_wise_loss[node_idx],
                 prog_bar=False,
                 on_step=False,
@@ -816,8 +918,9 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
                 sync_dist=True,
             )
 
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -825,6 +928,8 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -847,7 +952,7 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             )
             val_loss = val_loss + ohkm_loss
         self.log(
-            "
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -856,6 +961,62 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             sync_dist=True,
         )
 
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            # SAVE bbox BEFORE inference (it modifies in-place!)
+            bbox_prep_saved = batch["instance_bbox"].clone()
+
+            with torch.no_grad():
+                inference_output = self.instance_peaks_inf_layer(batch)
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions from inference (crop-relative, original scale)
+                pred_peaks_crop = (
+                    inference_output["pred_instance_peaks"][i].cpu().numpy()
+                )
+                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+                # Compute bbox offset in original space from SAVED prep bbox
+                # bbox has shape (n_samples=1, 4, 2) where 4 corners
+                bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy()  # (4, 2)
+                bbox_top_left_orig = (
+                    bbox_prep[0] / eff
+                )  # Top-left corner in original space
+
+                # Full image coordinates (original space)
+                pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
+
+                # GT transform: crop-relative preprocessed -> full image original
+                gt_crop_prep = (
+                    batch["instance"][i].squeeze(0).cpu().numpy()
+                )  # (n_nodes, 2)
+                gt_crop_orig = gt_crop_prep / eff
+                gt_full_orig = gt_crop_orig + bbox_top_left_orig
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks_full.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "pred_scores": pred_scores.reshape(1, -1),  # (1, n_nodes)
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_full_orig.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "num_instances": 1,
+                    }
+                )
+
 
 class CentroidLightningModule(LightningModel):
     """Lightning Module for Centroid Model.
@@ -1004,8 +1165,9 @@ class CentroidLightningModule(LightningModel):
 
         y_preds = self.model(X)["CentroidConfmapsHead"]
         loss = nn.MSELoss()(y_preds, y)
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -1013,6 +1175,8 @@ class CentroidLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
        return loss
 
     def validation_step(self, batch, batch_idx):
@@ -1025,7 +1189,7 @@ class CentroidLightningModule(LightningModel):
         y_preds = self.model(X)["CentroidConfmapsHead"]
         val_loss = nn.MSELoss()(y_preds, y)
         self.log(
-            "
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -1034,6 +1198,57 @@ class CentroidLightningModule(LightningModel):
             sync_dist=True,
         )
 
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                inference_output = self.centroid_inf_layer(batch)
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are in original image space (inference divides by eff_scale)
+                # centroids shape: (batch, 1, max_instances, 2) - squeeze to (max_instances, 2)
+                pred_centroids = (
+                    inference_output["centroids"][i].squeeze(0).cpu().numpy()
+                )
+                pred_vals = inference_output["centroid_vals"][i].cpu().numpy()
+
+                # Transform GT centroids from preprocessed to original image space
+                gt_centroids_prep = (
+                    batch["centroids"][i].cpu().numpy()
+                )  # (n_samples=1, max_inst, 2)
+                gt_centroids_orig = gt_centroids_prep.squeeze(0) / eff  # (max_inst, 2)
+                num_inst = batch["num_instances"][i].item()
+
+                # Filter to valid instances (non-NaN)
+                valid_pred_mask = ~np.isnan(pred_centroids).any(axis=1)
+                pred_centroids = pred_centroids[valid_pred_mask]
+                pred_vals = pred_vals[valid_pred_mask]
+
+                gt_centroids_valid = gt_centroids_orig[:num_inst]
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_centroids.reshape(
+                            -1, 1, 2
+                        ),  # (n_inst, 1, 2)
+                        "pred_scores": pred_vals.reshape(-1, 1),  # (n_inst, 1)
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_centroids_valid.reshape(
+                            -1, 1, 2
+                        ),  # (n_inst, 1, 2)
+                        "num_instances": num_inst,
+                    }
+                )
+
 
 class BottomUpLightningModule(LightningModel):
     """Lightning Module for BottomUp Model.
@@ -1126,12 +1341,13 @@ class BottomUpLightningModule(LightningModel):
         self.bottomup_inf_layer = BottomUpInferenceModel(
             torch_model=self.forward,
             paf_scorer=paf_scorer,
-            peak_threshold=0.
+            peak_threshold=0.1,  # Lower threshold for epoch-end eval during training
             input_scale=1.0,
             return_confmaps=True,
             return_pafs=True,
             cms_output_stride=self.head_configs.bottomup.confmaps.output_stride,
             pafs_output_stride=self.head_configs.bottomup.pafs.output_stride,
+            max_peaks_per_node=100,  # Prevents combinatorial explosion in early training
         )
         self.node_names = list(self.head_configs.bottomup.confmaps.part_names)
 
@@ -1248,8 +1464,9 @@ class BottomUpLightningModule(LightningModel):
             "PartAffinityFieldsHead": pafs_loss,
         }
         loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -1257,8 +1474,10 @@ class BottomUpLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         self.log(
-            "
+            "train/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -1266,7 +1485,7 @@ class BottomUpLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "train/paf_loss",
             pafs_loss,
             on_step=False,
             on_epoch=True,
@@ -1315,7 +1534,7 @@ class BottomUpLightningModule(LightningModel):
 
         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
         self.log(
-            "
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -1324,7 +1543,7 @@ class BottomUpLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "val/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -1332,7 +1551,7 @@ class BottomUpLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "val/paf_loss",
             pafs_loss,
             on_step=False,
             on_epoch=True,
@@ -1340,6 +1559,53 @@ class BottomUpLightningModule(LightningModel):
             sync_dist=True,
         )
 
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                # Note: Do NOT squeeze the image here - the forward() method expects
+                # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
+                inference_output = self.bottomup_inf_layer(batch)
+                if isinstance(inference_output, list):
+                    inference_output = inference_output[0]
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are already in original space (variable number of instances)
+                pred_peaks = inference_output["pred_instance_peaks"][i]
+                pred_scores = inference_output["pred_peak_values"][i]
+                if torch.is_tensor(pred_peaks):
+                    pred_peaks = pred_peaks.cpu().numpy()
+                if torch.is_tensor(pred_scores):
+                    pred_scores = pred_scores.cpu().numpy()
+
+                # Transform GT to original space
+                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+                gt_prep = batch["instances"][i].cpu().numpy()
+                if gt_prep.ndim == 4:
+                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
+                gt_orig = gt_prep / eff
+                num_inst = batch["num_instances"][i].item()
+                gt_orig = gt_orig[:num_inst]  # Only valid instances
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks,  # Original space, variable instances
+                        "pred_scores": pred_scores,
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_orig,  # Original space
+                        "num_instances": num_inst,
+                    }
+                )
+
 
 class BottomUpMultiClassLightningModule(LightningModel):
     """Lightning Module for BottomUp ID Model.
@@ -1541,8 +1807,9 @@ class BottomUpMultiClassLightningModule(LightningModel):
             "ClassMapsHead": classmaps_loss,
         }
         loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -1550,8 +1817,10 @@ class BottomUpMultiClassLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         self.log(
-            "
+            "train/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -1559,13 +1828,67 @@ class BottomUpMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "train/classmap_loss",
             classmaps_loss,
             on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
         )
+
+        # Compute classification accuracy at GT keypoint locations
+        with torch.no_grad():
+            # Get output stride for class maps
+            cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
+
+            # Get GT instances and sample class maps at those locations
+            instances = batch["instances"]  # (batch, n_samples, max_inst, n_nodes, 2)
+            if instances.dim() == 5:
+                instances = instances.squeeze(1)  # (batch, max_inst, n_nodes, 2)
+            num_instances = batch["num_instances"]  # (batch,)
+
+            correct = 0
+            total = 0
+            for b in range(instances.shape[0]):
+                n_inst = num_instances[b].item()
+                for inst_idx in range(n_inst):
+                    for node_idx in range(instances.shape[2]):
+                        # Get keypoint location (in input image space)
+                        kp = instances[b, inst_idx, node_idx]  # (2,) = (x, y)
+                        if torch.isnan(kp).any():
+                            continue
+
+                        # Convert to class map space
+                        x_cm = (
+                            (kp[0] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-1] - 1)
+                        )
+                        y_cm = (
+                            (kp[1] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-2] - 1)
+                        )
+
+                        # Sample predicted and GT class at this location
+                        pred_class = classmaps[b, :, y_cm, x_cm].argmax()
+                        gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
+
+                        if pred_class == gt_class:
+                            correct += 1
+                        total += 1
+
+            if total > 0:
+                class_accuracy = torch.tensor(correct / total, device=X.device)
+                self.log(
+                    "train/class_accuracy",
+                    class_accuracy,
+                    on_step=False,
+                    on_epoch=True,
+                    logger=True,
+                    sync_dist=True,
+                )
+
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -1599,7 +1922,7 @@ class BottomUpMultiClassLightningModule(LightningModel):
 
         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
         self.log(
-            "
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -1608,7 +1931,7 @@ class BottomUpMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "val/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -1616,7 +1939,7 @@ class BottomUpMultiClassLightningModule(LightningModel):
             sync_dist=True,
        )
         self.log(
-            "
+            "val/classmap_loss",
             classmaps_loss,
             on_step=False,
             on_epoch=True,
@@ -1624,6 +1947,106 @@ class BottomUpMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
 
+        # Compute classification accuracy at GT keypoint locations
+        with torch.no_grad():
+            # Get output stride for class maps
+            cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
+
+            # Get GT instances and sample class maps at those locations
+            instances = batch["instances"]  # (batch, n_samples, max_inst, n_nodes, 2)
+            if instances.dim() == 5:
+                instances = instances.squeeze(1)  # (batch, max_inst, n_nodes, 2)
+            num_instances = batch["num_instances"]  # (batch,)
+
+            correct = 0
+            total = 0
+            for b in range(instances.shape[0]):
+                n_inst = num_instances[b].item()
+                for inst_idx in range(n_inst):
+                    for node_idx in range(instances.shape[2]):
+                        # Get keypoint location (in input image space)
+                        kp = instances[b, inst_idx, node_idx]  # (2,) = (x, y)
+                        if torch.isnan(kp).any():
+                            continue
+
+                        # Convert to class map space
+                        x_cm = (
+                            (kp[0] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-1] - 1)
+                        )
+                        y_cm = (
+                            (kp[1] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-2] - 1)
+                        )
+
+                        # Sample predicted and GT class at this location
+                        pred_class = classmaps[b, :, y_cm, x_cm].argmax()
+                        gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
+
+                        if pred_class == gt_class:
+                            correct += 1
+                        total += 1
+
+            if total > 0:
+                class_accuracy = torch.tensor(correct / total, device=X.device)
+                self.log(
+                    "val/class_accuracy",
+                    class_accuracy,
+                    on_step=False,
+                    on_epoch=True,
+                    logger=True,
+                    sync_dist=True,
+                )
+
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                # Note: Do NOT squeeze the image here - the forward() method expects
+                # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
+                inference_output = self.bottomup_inf_layer(batch)
+                if isinstance(inference_output, list):
+                    inference_output = inference_output[0]
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are already in original space (variable number of instances)
+                pred_peaks = inference_output["pred_instance_peaks"][i]
+                pred_scores = inference_output["pred_peak_values"][i]
+                if torch.is_tensor(pred_peaks):
+                    pred_peaks = pred_peaks.cpu().numpy()
+                if torch.is_tensor(pred_scores):
+                    pred_scores = pred_scores.cpu().numpy()
+
+                # Transform GT to original space
+                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+                gt_prep = batch["instances"][i].cpu().numpy()
+                if gt_prep.ndim == 4:
+                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
+                gt_orig = gt_prep / eff
+                num_inst = batch["num_instances"][i].item()
+                gt_orig = gt_orig[:num_inst]  # Only valid instances
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks,  # Original space, variable instances
+                        "pred_scores": pred_scores,
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_orig,  # Original space
+                        "num_instances": num_inst,
+                    }
+                )
+
 
 class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
     """Lightning Module for TopDownCenteredInstance ID Model.
@@ -1803,7 +2226,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"{name}",
+                f"train/confmaps/{name}",
                 channel_wise_loss[node_idx],
                 prog_bar=False,
                 on_step=False,
@@ -1812,8 +2235,9 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
                 sync_dist=True,
             )
 
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
@@ -1821,8 +2245,10 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             logger=True,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         self.log(
-            "
+            "train/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -1830,13 +2256,27 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "train/classvector_loss",
             classvector_loss,
             on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
         )
+
+        # Compute classification accuracy
+        with torch.no_grad():
+            pred_classes = torch.argmax(classvector, dim=1)
+            gt_classes = torch.argmax(y_classvector, dim=1)
+            class_accuracy = (pred_classes == gt_classes).float().mean()
+            self.log(
+                "train/class_accuracy",
+                class_accuracy,
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                sync_dist=True,
+            )
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -1868,7 +2308,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         }
         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
         self.log(
-            "
+            "val/loss",
             val_loss,
             prog_bar=True,
             on_step=False,
@@ -1877,7 +2317,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             sync_dist=True,
        )
         self.log(
-            "
+            "val/confmaps_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
@@ -1885,10 +2325,80 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             sync_dist=True,
         )
         self.log(
-            "
+            "val/classvector_loss",
             classvector_loss,
             on_step=False,
             on_epoch=True,
             logger=True,
             sync_dist=True,
         )
+
+        # Compute classification accuracy
+        with torch.no_grad():
+            pred_classes = torch.argmax(classvector, dim=1)
+            gt_classes = torch.argmax(y_classvector, dim=1)
+            class_accuracy = (pred_classes == gt_classes).float().mean()
+            self.log(
+                "val/class_accuracy",
+                class_accuracy,
+                on_step=False,
+                on_epoch=True,
+                logger=True,
+                sync_dist=True,
+            )
+
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            # SAVE bbox BEFORE inference (it modifies in-place!)
+            bbox_prep_saved = batch["instance_bbox"].clone()
+
+            with torch.no_grad():
+                inference_output = self.instance_peaks_inf_layer(batch)
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions from inference (crop-relative, original scale)
+                pred_peaks_crop = (
+                    inference_output["pred_instance_peaks"][i].cpu().numpy()
+                )
+                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+                # Compute bbox offset in original space from SAVED prep bbox
+                # bbox has shape (n_samples=1, 4, 2) where 4 corners
+                bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy()  # (4, 2)
+                bbox_top_left_orig = (
+                    bbox_prep[0] / eff
+                )  # Top-left corner in original space
+
+                # Full image coordinates (original space)
+                pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
+
+                # GT transform: crop-relative preprocessed -> full image original
+                gt_crop_prep = (
+                    batch["instance"][i].squeeze(0).cpu().numpy()
+                )  # (n_nodes, 2)
+                gt_crop_orig = gt_crop_prep / eff
+                gt_full_orig = gt_crop_orig + bbox_top_left_orig
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks_full.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "pred_scores": pred_scores.reshape(1, -1),  # (1, n_nodes)
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_full_orig.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "num_instances": 1,
+                    }
+                )