sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a0__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +2 -4
- sleap_nn/architectures/convnext.py +0 -5
- sleap_nn/architectures/encoder_decoder.py +6 -25
- sleap_nn/architectures/swint.py +0 -8
- sleap_nn/cli.py +60 -364
- sleap_nn/config/data_config.py +5 -11
- sleap_nn/config/get_config.py +4 -10
- sleap_nn/config/trainer_config.py +0 -76
- sleap_nn/data/augmentation.py +241 -50
- sleap_nn/data/custom_datasets.py +39 -411
- sleap_nn/data/instance_cropping.py +1 -1
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/utils.py +17 -135
- sleap_nn/evaluation.py +22 -81
- sleap_nn/inference/bottomup.py +20 -86
- sleap_nn/inference/peak_finding.py +19 -88
- sleap_nn/inference/predictors.py +117 -224
- sleap_nn/legacy_models.py +11 -65
- sleap_nn/predict.py +9 -37
- sleap_nn/train.py +4 -74
- sleap_nn/training/callbacks.py +105 -1046
- sleap_nn/training/lightning_modules.py +65 -602
- sleap_nn/training/model_trainer.py +184 -211
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +3 -15
- sleap_nn-0.1.0a0.dist-info/RECORD +65 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +1 -1
- sleap_nn/data/skia_augmentation.py +0 -414
- sleap_nn/export/__init__.py +0 -21
- sleap_nn/export/cli.py +0 -1778
- sleap_nn/export/exporters/__init__.py +0 -51
- sleap_nn/export/exporters/onnx_exporter.py +0 -80
- sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
- sleap_nn/export/metadata.py +0 -225
- sleap_nn/export/predictors/__init__.py +0 -63
- sleap_nn/export/predictors/base.py +0 -22
- sleap_nn/export/predictors/onnx.py +0 -154
- sleap_nn/export/predictors/tensorrt.py +0 -312
- sleap_nn/export/utils.py +0 -307
- sleap_nn/export/wrappers/__init__.py +0 -25
- sleap_nn/export/wrappers/base.py +0 -96
- sleap_nn/export/wrappers/bottomup.py +0 -243
- sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
- sleap_nn/export/wrappers/centered_instance.py +0 -56
- sleap_nn/export/wrappers/centroid.py +0 -58
- sleap_nn/export/wrappers/single_instance.py +0 -83
- sleap_nn/export/wrappers/topdown.py +0 -180
- sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
- sleap_nn/inference/postprocessing.py +0 -284
- sleap_nn/training/schedulers.py +0 -191
- sleap_nn-0.1.0.dist-info/RECORD +0 -88
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
sleap_nn/training/lightning_modules.py

@@ -1,6 +1,6 @@
 """This module has the LightningModule classes for all model types."""

-from typing import Optional, Union, Dict, Any
+from typing import Optional, Union, Dict, Any
 import time
 from torch import nn
 import numpy as np
@@ -51,16 +51,10 @@ matplotlib.use(
 import matplotlib.pyplot as plt
 from sleap_nn.config.utils import get_backbone_type_from_cfg, get_model_type_from_cfg
 from sleap_nn.config.trainer_config import (
-    CosineAnnealingWarmupConfig,
-    LinearWarmupLinearDecayConfig,
     LRSchedulerConfig,
     ReduceLROnPlateauConfig,
     StepLRConfig,
 )
-from sleap_nn.training.schedulers import (
-    LinearWarmupCosineAnnealingLR,
-    LinearWarmupLinearDecayLR,
-)
 from sleap_nn.config.get_config import get_backbone_config
 from sleap_nn.legacy_models import (
     load_legacy_model_weights,
@@ -190,15 +184,6 @@ class LightningModel(L.LightningModule):
         self.val_loss = {}
         self.learning_rate = {}

-        # For epoch-averaged loss tracking
-        self._epoch_loss_sum = 0.0
-        self._epoch_loss_count = 0
-
-        # For epoch-end evaluation
-        self.val_predictions: List[Dict] = []
-        self.val_ground_truth: List[Dict] = []
-        self._collect_val_predictions: bool = False
-
         # Initialization for encoder and decoder stacks.
         if self.init_weights == "xavier":
             self.model.apply(xavier_init_weights)
@@ -235,9 +220,7 @@ class LightningModel(L.LightningModule):
         elif self.pretrained_backbone_weights.endswith(".h5"):
             # load from sleap model weights
             load_legacy_model_weights(
-                self.model.backbone,
-                self.pretrained_backbone_weights,
-                component="backbone",
+                self.model.backbone, self.pretrained_backbone_weights
             )

         else:
@@ -266,9 +249,7 @@ class LightningModel(L.LightningModule):
         elif self.pretrained_head_weights.endswith(".h5"):
             # load from sleap model weights
             load_legacy_model_weights(
-                self.model.head_layers,
-                self.pretrained_head_weights,
-                component="head",
+                self.model.head_layers, self.pretrained_head_weights
             )

         else:
@@ -324,24 +305,17 @@ class LightningModel(L.LightningModule):
     def on_train_epoch_start(self):
         """Configure the train timer at the beginning of each epoch."""
         self.train_start_time = time.time()
-        # Reset epoch loss tracking
-        self._epoch_loss_sum = 0.0
-        self._epoch_loss_count = 0
-
-    def _accumulate_loss(self, loss: torch.Tensor):
-        """Accumulate loss for epoch-averaged logging. Call this in training_step."""
-        self._epoch_loss_sum += loss.detach().item()
-        self._epoch_loss_count += 1

     def on_train_epoch_end(self):
         """Configure the train timer at the end of every epoch."""
         train_time = time.time() - self.train_start_time
         self.log(
-            "
+            "train_time",
             train_time,
             prog_bar=False,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )
         # Log epoch explicitly for custom x-axis support in wandb
@@ -350,56 +324,24 @@ class LightningModel(L.LightningModule):
             float(self.current_epoch),
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )
-        # Log epoch-averaged training loss
-        if self._epoch_loss_count > 0:
-            avg_loss = self._epoch_loss_sum / self._epoch_loss_count
-            self.log(
-                "train/loss",
-                avg_loss,
-                prog_bar=False,
-                on_step=False,
-                on_epoch=True,
-                sync_dist=True,
-            )
-        # Log current learning rate (useful for monitoring LR schedulers)
-        if self.trainer.optimizers:
-            lr = self.trainer.optimizers[0].param_groups[0]["lr"]
-            self.log(
-                "train/lr",
-                lr,
-                prog_bar=False,
-                on_step=False,
-                on_epoch=True,
-                sync_dist=True,
-            )

     def on_validation_epoch_start(self):
         """Configure the val timer at the beginning of each epoch."""
         self.val_start_time = time.time()
-        # Clear accumulated predictions for new epoch
-        self.val_predictions = []
-        self.val_ground_truth = []

     def on_validation_epoch_end(self):
         """Configure the val timer at the end of every epoch."""
         val_time = time.time() - self.val_start_time
         self.log(
-            "
+            "val_time",
             val_time,
             prog_bar=False,
             on_step=False,
             on_epoch=True,
-
-        )
-        # Log epoch explicitly so val/* metrics can use it as x-axis in wandb
-        # (mirrors what on_train_epoch_end does for train/* metrics)
-        self.log(
-            "epoch",
-            float(self.current_epoch),
-            on_step=False,
-            on_epoch=True,
+            logger=True,
             sync_dist=True,
         )

@@ -436,51 +378,13 @@ class LightningModel(L.LightningModule):
            lr_scheduler_cfg.step_lr = StepLRConfig()
        elif self.lr_scheduler == "reduce_lr_on_plateau":
            lr_scheduler_cfg.reduce_lr_on_plateau = ReduceLROnPlateauConfig()
-       elif self.lr_scheduler == "cosine_annealing_warmup":
-           lr_scheduler_cfg.cosine_annealing_warmup = CosineAnnealingWarmupConfig()
-       elif self.lr_scheduler == "linear_warmup_linear_decay":
-           lr_scheduler_cfg.linear_warmup_linear_decay = (
-               LinearWarmupLinearDecayConfig()
-           )

        elif isinstance(self.lr_scheduler, dict):
            lr_scheduler_cfg = self.lr_scheduler

            for k, v in self.lr_scheduler.items():
                if v is not None:
-                   if k == "
-                       cfg = self.lr_scheduler.cosine_annealing_warmup
-                       # Use trainer's max_epochs if not specified in config
-                       max_epochs = (
-                           cfg.max_epochs
-                           if cfg.max_epochs is not None
-                           else self.trainer.max_epochs
-                       )
-                       scheduler = LinearWarmupCosineAnnealingLR(
-                           optimizer=optimizer,
-                           warmup_epochs=cfg.warmup_epochs,
-                           max_epochs=max_epochs,
-                           warmup_start_lr=cfg.warmup_start_lr,
-                           eta_min=cfg.eta_min,
-                       )
-                       break
-                   elif k == "linear_warmup_linear_decay":
-                       cfg = self.lr_scheduler.linear_warmup_linear_decay
-                       # Use trainer's max_epochs if not specified in config
-                       max_epochs = (
-                           cfg.max_epochs
-                           if cfg.max_epochs is not None
-                           else self.trainer.max_epochs
-                       )
-                       scheduler = LinearWarmupLinearDecayLR(
-                           optimizer=optimizer,
-                           warmup_epochs=cfg.warmup_epochs,
-                           max_epochs=max_epochs,
-                           warmup_start_lr=cfg.warmup_start_lr,
-                           end_lr=cfg.end_lr,
-                       )
-                       break
-                   elif k == "step_lr":
+                   if k == "step_lr":
                        scheduler = torch.optim.lr_scheduler.StepLR(
                            optimizer=optimizer,
                            step_size=self.lr_scheduler.step_lr.step_size,
@@ -508,7 +412,7 @@ class LightningModel(L.LightningModule):
             "optimizer": optimizer,
             "lr_scheduler": {
                 "scheduler": scheduler,
-                "monitor": "
+                "monitor": "val_loss",
             },
         }

@@ -664,7 +568,6 @@ class SingleInstanceLightningModule(LightningModel):
             torch.squeeze(batch["image"], dim=1),
             torch.squeeze(batch["confidence_maps"], dim=1),
         )
-        X = normalize_on_gpu(X)

         y_preds = self.model(X)["SingleInstanceConfmapsHead"]

@@ -688,24 +591,23 @@ class SingleInstanceLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"
+                f"{name}",
                 channel_wise_loss[node_idx],
                 prog_bar=False,
                 on_step=False,
                 on_epoch=True,
+                logger=True,
                 sync_dist=True,
             )
-        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "train_loss",
             loss,
             prog_bar=True,
             on_step=True,
             on_epoch=False,
+            logger=True,
             sync_dist=True,
         )
-        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
-        self._accumulate_loss(loss)
         return loss

     def validation_step(self, batch, batch_idx):
@@ -714,7 +616,6 @@ class SingleInstanceLightningModule(LightningModel):
             torch.squeeze(batch["image"], dim=1),
             torch.squeeze(batch["confidence_maps"], dim=1),
         )
-        X = normalize_on_gpu(X)

         y_preds = self.model(X)["SingleInstanceConfmapsHead"]
         val_loss = nn.MSELoss()(y_preds, y)
@@ -729,59 +630,15 @@ class SingleInstanceLightningModule(LightningModel):
             )
             val_loss = val_loss + ohkm_loss
         self.log(
-            "
+            "val_loss",
             val_loss,
             prog_bar=True,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )

-        # Collect predictions for epoch-end evaluation if enabled
-        if self._collect_val_predictions:
-            with torch.no_grad():
-                # Squeeze n_samples dim from image for inference (batch, 1, C, H, W) -> (batch, C, H, W)
-                inference_batch = {k: v for k, v in batch.items()}
-                if inference_batch["image"].ndim == 5:
-                    inference_batch["image"] = inference_batch["image"].squeeze(1)
-                inference_output = self.single_instance_inf_layer(inference_batch)
-                if isinstance(inference_output, list):
-                    inference_output = inference_output[0]
-
-            batch_size = len(batch["frame_idx"])
-            for i in range(batch_size):
-                eff = batch["eff_scale"][i].cpu().numpy()
-
-                # Predictions are already in original image space (inference divides by eff_scale)
-                pred_peaks = inference_output["pred_instance_peaks"][i].cpu().numpy()
-                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
-
-                # Transform GT from preprocessed to original image space
-                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
-                gt_prep = batch["instances"][i].cpu().numpy()
-                if gt_prep.ndim == 4:
-                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
-                gt_orig = gt_prep / eff
-                num_inst = batch["num_instances"][i].item()
-                gt_orig = gt_orig[:num_inst]  # Only valid instances
-
-                self.val_predictions.append(
-                    {
-                        "video_idx": batch["video_idx"][i].item(),
-                        "frame_idx": batch["frame_idx"][i].item(),
-                        "pred_peaks": pred_peaks,
-                        "pred_scores": pred_scores,
-                    }
-                )
-                self.val_ground_truth.append(
-                    {
-                        "video_idx": batch["video_idx"][i].item(),
-                        "frame_idx": batch["frame_idx"][i].item(),
-                        "gt_instances": gt_orig,
-                        "num_instances": num_inst,
-                    }
-                )
-

 class TopDownCenteredInstanceLightningModule(LightningModel):
     """Lightning Module for TopDownCenteredInstance Model.
@@ -927,7 +784,6 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             torch.squeeze(batch["instance_image"], dim=1),
             torch.squeeze(batch["confidence_maps"], dim=1),
         )
-        X = normalize_on_gpu(X)

         y_preds = self.model(X)["CenteredInstanceConfmapsHead"]

@@ -951,25 +807,24 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"
+                f"{name}",
                 channel_wise_loss[node_idx],
                 prog_bar=False,
                 on_step=False,
                 on_epoch=True,
+                logger=True,
                 sync_dist=True,
             )

-        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "train_loss",
             loss,
             prog_bar=True,
             on_step=True,
             on_epoch=False,
+            logger=True,
             sync_dist=True,
         )
-        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
-        self._accumulate_loss(loss)
         return loss

     def validation_step(self, batch, batch_idx):
@@ -978,7 +833,6 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             torch.squeeze(batch["instance_image"], dim=1),
             torch.squeeze(batch["confidence_maps"], dim=1),
         )
-        X = normalize_on_gpu(X)

         y_preds = self.model(X)["CenteredInstanceConfmapsHead"]
         val_loss = nn.MSELoss()(y_preds, y)
@@ -993,70 +847,15 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             )
             val_loss = val_loss + ohkm_loss
         self.log(
-            "
+            "val_loss",
             val_loss,
             prog_bar=True,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )

-        # Collect predictions for epoch-end evaluation if enabled
-        if self._collect_val_predictions:
-            # SAVE bbox BEFORE inference (it modifies in-place!)
-            bbox_prep_saved = batch["instance_bbox"].clone()
-
-            with torch.no_grad():
-                inference_output = self.instance_peaks_inf_layer(batch)
-
-            batch_size = len(batch["frame_idx"])
-            for i in range(batch_size):
-                eff = batch["eff_scale"][i].cpu().numpy()
-
-                # Predictions from inference (crop-relative, original scale)
-                pred_peaks_crop = (
-                    inference_output["pred_instance_peaks"][i].cpu().numpy()
-                )
-                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
-
-                # Compute bbox offset in original space from SAVED prep bbox
-                # bbox has shape (n_samples=1, 4, 2) where 4 corners
-                bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy()  # (4, 2)
-                bbox_top_left_orig = (
-                    bbox_prep[0] / eff
-                )  # Top-left corner in original space
-
-                # Full image coordinates (original space)
-                pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
-
-                # GT transform: crop-relative preprocessed -> full image original
-                gt_crop_prep = (
-                    batch["instance"][i].squeeze(0).cpu().numpy()
-                )  # (n_nodes, 2)
-                gt_crop_orig = gt_crop_prep / eff
-                gt_full_orig = gt_crop_orig + bbox_top_left_orig
-
-                self.val_predictions.append(
-                    {
-                        "video_idx": batch["video_idx"][i].item(),
-                        "frame_idx": batch["frame_idx"][i].item(),
-                        "pred_peaks": pred_peaks_full.reshape(
-                            1, -1, 2
-                        ),  # (1, n_nodes, 2)
-                        "pred_scores": pred_scores.reshape(1, -1),  # (1, n_nodes)
-                    }
-                )
-                self.val_ground_truth.append(
-                    {
-                        "video_idx": batch["video_idx"][i].item(),
-                        "frame_idx": batch["frame_idx"][i].item(),
-                        "gt_instances": gt_full_orig.reshape(
-                            1, -1, 2
-                        ),  # (1, n_nodes, 2)
-                        "num_instances": 1,
-                    }
-                )
-

 class CentroidLightningModule(LightningModel):
     """Lightning Module for Centroid Model.
@@ -1202,21 +1001,18 @@ class CentroidLightningModule(LightningModel):
             torch.squeeze(batch["image"], dim=1),
             torch.squeeze(batch["centroids_confidence_maps"], dim=1),
         )
-        X = normalize_on_gpu(X)

         y_preds = self.model(X)["CentroidConfmapsHead"]
         loss = nn.MSELoss()(y_preds, y)
-        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "train_loss",
             loss,
             prog_bar=True,
             on_step=True,
             on_epoch=False,
+            logger=True,
             sync_dist=True,
         )
-        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
-        self._accumulate_loss(loss)
         return loss

     def validation_step(self, batch, batch_idx):
@@ -1225,74 +1021,19 @@ class CentroidLightningModule(LightningModel):
             torch.squeeze(batch["image"], dim=1),
             torch.squeeze(batch["centroids_confidence_maps"], dim=1),
         )
-        X = normalize_on_gpu(X)

         y_preds = self.model(X)["CentroidConfmapsHead"]
         val_loss = nn.MSELoss()(y_preds, y)
         self.log(
-            "
+            "val_loss",
             val_loss,
             prog_bar=True,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )

-        # Collect predictions for epoch-end evaluation if enabled
-        if self._collect_val_predictions:
-            # Save GT centroids before inference (inference overwrites batch["centroids"])
-            batch["gt_centroids"] = batch["centroids"].clone()
-
-            with torch.no_grad():
-                inference_output = self.centroid_inf_layer(batch)
-
-            batch_size = len(batch["frame_idx"])
-            for i in range(batch_size):
-                eff = batch["eff_scale"][i].cpu().numpy()
-
-                # Predictions are in original image space (inference divides by eff_scale)
-                # centroids shape: (batch, 1, max_instances, 2) - squeeze to (max_instances, 2)
-                pred_centroids = (
-                    inference_output["centroids"][i].squeeze(0).cpu().numpy()
-                )
-                pred_vals = inference_output["centroid_vals"][i].cpu().numpy()
-
-                # Transform GT centroids from preprocessed to original image space
-                # Use "gt_centroids" since inference overwrites "centroids" with predictions
-                gt_centroids_prep = (
-                    batch["gt_centroids"][i].cpu().numpy()
-                )  # (n_samples=1, max_inst, 2)
-                gt_centroids_orig = gt_centroids_prep.squeeze(0) / eff  # (max_inst, 2)
-                num_inst = batch["num_instances"][i].item()
-
-                # Filter to valid instances (non-NaN)
-                valid_pred_mask = ~np.isnan(pred_centroids).any(axis=1)
-                pred_centroids = pred_centroids[valid_pred_mask]
-                pred_vals = pred_vals[valid_pred_mask]
-
-                gt_centroids_valid = gt_centroids_orig[:num_inst]
-
-                self.val_predictions.append(
-                    {
-                        "video_idx": batch["video_idx"][i].item(),
-                        "frame_idx": batch["frame_idx"][i].item(),
-                        "pred_peaks": pred_centroids.reshape(
-                            -1, 1, 2
-                        ),  # (n_inst, 1, 2)
-                        "pred_scores": pred_vals.reshape(-1, 1),  # (n_inst, 1)
-                    }
-                )
-                self.val_ground_truth.append(
-                    {
-                        "video_idx": batch["video_idx"][i].item(),
-                        "frame_idx": batch["frame_idx"][i].item(),
-                        "gt_instances": gt_centroids_valid.reshape(
-                            -1, 1, 2
-                        ),  # (n_inst, 1, 2)
-                        "num_instances": num_inst,
-                    }
-                )
-

 class BottomUpLightningModule(LightningModel):
     """Lightning Module for BottomUp Model.
@@ -1385,13 +1126,12 @@ class BottomUpLightningModule(LightningModel):
         self.bottomup_inf_layer = BottomUpInferenceModel(
             torch_model=self.forward,
             paf_scorer=paf_scorer,
-            peak_threshold=0.
+            peak_threshold=0.2,
             input_scale=1.0,
             return_confmaps=True,
             return_pafs=True,
             cms_output_stride=self.head_configs.bottomup.confmaps.output_stride,
             pafs_output_stride=self.head_configs.bottomup.pafs.output_stride,
-            max_peaks_per_node=100,  # Prevents combinatorial explosion in early training
         )
         self.node_names = list(self.head_configs.bottomup.confmaps.part_names)

@@ -1476,7 +1216,6 @@ class BottomUpLightningModule(LightningModel):
         X = torch.squeeze(batch["image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_paf = batch["part_affinity_fields"]
-        X = normalize_on_gpu(X)
         preds = self.model(X)
         pafs = preds["PartAffinityFieldsHead"]
         confmaps = preds["MultiInstanceConfmapsHead"]
@@ -1509,29 +1248,29 @@ class BottomUpLightningModule(LightningModel):
             "PartAffinityFieldsHead": pafs_loss,
         }
         loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
-        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "train_loss",
             loss,
             prog_bar=True,
             on_step=True,
             on_epoch=False,
+            logger=True,
             sync_dist=True,
         )
-        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
-        self._accumulate_loss(loss)
         self.log(
-            "
+            "train_confmap_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )
         self.log(
-            "
+            "train_paf_loss",
             pafs_loss,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )
         return loss
@@ -1541,7 +1280,6 @@ class BottomUpLightningModule(LightningModel):
         X = torch.squeeze(batch["image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_paf = batch["part_affinity_fields"]
-        X = normalize_on_gpu(X)

         preds = self.model(X)
         pafs = preds["PartAffinityFieldsHead"]
@@ -1577,75 +1315,31 @@ class BottomUpLightningModule(LightningModel):

         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
         self.log(
-            "
+            "val_loss",
             val_loss,
             prog_bar=True,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )
         self.log(
-            "
+            "val_confmap_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )
         self.log(
-            "
+            "val_paf_loss",
             pafs_loss,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )

-        # Collect predictions for epoch-end evaluation if enabled
-        if self._collect_val_predictions:
-            with torch.no_grad():
-                # Note: Do NOT squeeze the image here - the forward() method expects
-                # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
-                inference_output = self.bottomup_inf_layer(batch)
-                if isinstance(inference_output, list):
-                    inference_output = inference_output[0]
-
-            batch_size = len(batch["frame_idx"])
-            for i in range(batch_size):
-                eff = batch["eff_scale"][i].cpu().numpy()
-
-                # Predictions are already in original space (variable number of instances)
-                pred_peaks = inference_output["pred_instance_peaks"][i]
-                pred_scores = inference_output["pred_peak_values"][i]
-                if torch.is_tensor(pred_peaks):
-                    pred_peaks = pred_peaks.cpu().numpy()
-                if torch.is_tensor(pred_scores):
-                    pred_scores = pred_scores.cpu().numpy()
-
-                # Transform GT to original space
-                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
-                gt_prep = batch["instances"][i].cpu().numpy()
-                if gt_prep.ndim == 4:
-                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
-                gt_orig = gt_prep / eff
-                num_inst = batch["num_instances"][i].item()
-                gt_orig = gt_orig[:num_inst]  # Only valid instances
-
-                self.val_predictions.append(
-                    {
-                        "video_idx": batch["video_idx"][i].item(),
-                        "frame_idx": batch["frame_idx"][i].item(),
-                        "pred_peaks": pred_peaks,  # Original space, variable instances
-                        "pred_scores": pred_scores,
-                    }
-                )
-                self.val_ground_truth.append(
-                    {
-                        "video_idx": batch["video_idx"][i].item(),
-                        "frame_idx": batch["frame_idx"][i].item(),
-                        "gt_instances": gt_orig,  # Original space
-                        "num_instances": num_inst,
-                    }
-                )
-

 class BottomUpMultiClassLightningModule(LightningModel):
     """Lightning Module for BottomUp ID Model.
@@ -1824,7 +1518,6 @@ class BottomUpMultiClassLightningModule(LightningModel):
         X = torch.squeeze(batch["image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_classmap = torch.squeeze(batch["class_maps"], dim=1)
-        X = normalize_on_gpu(X)
         preds = self.model(X)
         classmaps = preds["ClassMapsHead"]
         confmaps = preds["MultiInstanceConfmapsHead"]
@@ -1848,84 +1541,31 @@ class BottomUpMultiClassLightningModule(LightningModel):
             "ClassMapsHead": classmaps_loss,
         }
         loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
-        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "train_loss",
             loss,
             prog_bar=True,
             on_step=True,
             on_epoch=False,
+            logger=True,
             sync_dist=True,
         )
-        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
-        self._accumulate_loss(loss)
         self.log(
-            "
+            "train_confmap_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )
         self.log(
-            "
+            "train_classmap_loss",
             classmaps_loss,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )
-
-        # Compute classification accuracy at GT keypoint locations
-        with torch.no_grad():
-            # Get output stride for class maps
-            cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
-
-            # Get GT instances and sample class maps at those locations
-            instances = batch["instances"]  # (batch, n_samples, max_inst, n_nodes, 2)
-            if instances.dim() == 5:
-                instances = instances.squeeze(1)  # (batch, max_inst, n_nodes, 2)
-            num_instances = batch["num_instances"]  # (batch,)
-
-            correct = 0
-            total = 0
-            for b in range(instances.shape[0]):
-                n_inst = num_instances[b].item()
-                for inst_idx in range(n_inst):
-                    for node_idx in range(instances.shape[2]):
-                        # Get keypoint location (in input image space)
-                        kp = instances[b, inst_idx, node_idx]  # (2,) = (x, y)
-                        if torch.isnan(kp).any():
-                            continue
-
-                        # Convert to class map space
-                        x_cm = (
-                            (kp[0] / cms_stride)
-                            .long()
-                            .clamp(0, classmaps.shape[-1] - 1)
-                        )
-                        y_cm = (
-                            (kp[1] / cms_stride)
-                            .long()
-                            .clamp(0, classmaps.shape[-2] - 1)
-                        )
-
-                        # Sample predicted and GT class at this location
-                        pred_class = classmaps[b, :, y_cm, x_cm].argmax()
-                        gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
-
-                        if pred_class == gt_class:
-                            correct += 1
-                        total += 1
-
-            if total > 0:
-                class_accuracy = torch.tensor(correct / total, device=X.device)
-                self.log(
-                    "train/class_accuracy",
-                    class_accuracy,
-                    on_step=False,
-                    on_epoch=True,
-                    sync_dist=True,
-                )
-
         return loss

     def validation_step(self, batch, batch_idx):
@@ -1933,7 +1573,6 @@ class BottomUpMultiClassLightningModule(LightningModel):
         X = torch.squeeze(batch["image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_classmap = torch.squeeze(batch["class_maps"], dim=1)
-        X = normalize_on_gpu(X)

         preds = self.model(X)
         classmaps = preds["ClassMapsHead"]
@@ -1960,127 +1599,31 @@ class BottomUpMultiClassLightningModule(LightningModel):

         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
         self.log(
-            "
+            "val_loss",
             val_loss,
             prog_bar=True,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )
         self.log(
-            "
+            "val_confmap_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )
         self.log(
-            "
+            "val_classmap_loss",
             classmaps_loss,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )

-        # Compute classification accuracy at GT keypoint locations
-        with torch.no_grad():
-            # Get output stride for class maps
-            cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
-
-            # Get GT instances and sample class maps at those locations
-            instances = batch["instances"]  # (batch, n_samples, max_inst, n_nodes, 2)
-            if instances.dim() == 5:
-                instances = instances.squeeze(1)  # (batch, max_inst, n_nodes, 2)
-            num_instances = batch["num_instances"]  # (batch,)
-
-            correct = 0
-            total = 0
-            for b in range(instances.shape[0]):
-                n_inst = num_instances[b].item()
-                for inst_idx in range(n_inst):
-                    for node_idx in range(instances.shape[2]):
-                        # Get keypoint location (in input image space)
-                        kp = instances[b, inst_idx, node_idx]  # (2,) = (x, y)
-                        if torch.isnan(kp).any():
-                            continue
-
-                        # Convert to class map space
-                        x_cm = (
-                            (kp[0] / cms_stride)
-                            .long()
-                            .clamp(0, classmaps.shape[-1] - 1)
-                        )
-                        y_cm = (
-                            (kp[1] / cms_stride)
-                            .long()
-                            .clamp(0, classmaps.shape[-2] - 1)
-                        )
-
-                        # Sample predicted and GT class at this location
-                        pred_class = classmaps[b, :, y_cm, x_cm].argmax()
-                        gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
-
-                        if pred_class == gt_class:
-                            correct += 1
-                        total += 1
-
-            if total > 0:
-                class_accuracy = torch.tensor(correct / total, device=X.device)
-                self.log(
-                    "val/class_accuracy",
-                    class_accuracy,
-                    on_step=False,
-                    on_epoch=True,
-                    sync_dist=True,
-                )
-
-        # Collect predictions for epoch-end evaluation if enabled
-        if self._collect_val_predictions:
-            with torch.no_grad():
-                # Note: Do NOT squeeze the image here - the forward() method expects
-                # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
-                inference_output = self.bottomup_inf_layer(batch)
-                if isinstance(inference_output, list):
-                    inference_output = inference_output[0]
-
-            batch_size = len(batch["frame_idx"])
-            for i in range(batch_size):
-                eff = batch["eff_scale"][i].cpu().numpy()
-
-                # Predictions are already in original space (variable number of instances)
-                pred_peaks = inference_output["pred_instance_peaks"][i]
-                pred_scores = inference_output["pred_peak_values"][i]
-                if torch.is_tensor(pred_peaks):
-                    pred_peaks = pred_peaks.cpu().numpy()
-                if torch.is_tensor(pred_scores):
-                    pred_scores = pred_scores.cpu().numpy()
-
-                # Transform GT to original space
-                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
-                gt_prep = batch["instances"][i].cpu().numpy()
-                if gt_prep.ndim == 4:
-                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
-                gt_orig = gt_prep / eff
-                num_inst = batch["num_instances"][i].item()
-                gt_orig = gt_orig[:num_inst]  # Only valid instances
-
-                self.val_predictions.append(
-                    {
-                        "video_idx": batch["video_idx"][i].item(),
-                        "frame_idx": batch["frame_idx"][i].item(),
-                        "pred_peaks": pred_peaks,  # Original space, variable instances
-                        "pred_scores": pred_scores,
-                    }
-                )
-                self.val_ground_truth.append(
-                    {
-                        "video_idx": batch["video_idx"][i].item(),
-                        "frame_idx": batch["frame_idx"][i].item(),
-                        "gt_instances": gt_orig,  # Original space
-                        "num_instances": num_inst,
-                    }
-                )
-

 class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
     """Lightning Module for TopDownCenteredInstance ID Model.
@@ -2229,7 +1772,6 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         X = torch.squeeze(batch["instance_image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_classvector = batch["class_vectors"]
-        X = normalize_on_gpu(X)
         preds = self.model(X)
         classvector = preds["ClassVectorsHead"]
         confmaps = preds["CenteredInstanceConfmapsHead"]
@@ -2261,50 +1803,38 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"
+                f"{name}",
                 channel_wise_loss[node_idx],
                 prog_bar=False,
                 on_step=False,
                 on_epoch=True,
+                logger=True,
                 sync_dist=True,
             )

-        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "
+            "train_loss",
             loss,
             prog_bar=True,
             on_step=True,
             on_epoch=False,
+            logger=True,
             sync_dist=True,
         )
-        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
-        self._accumulate_loss(loss)
         self.log(
-            "
+            "train_confmap_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )
         self.log(
-            "
+            "train_classvector_loss",
             classvector_loss,
             on_step=False,
             on_epoch=True,
-
-        )
-
-        # Compute classification accuracy
-        with torch.no_grad():
-            pred_classes = torch.argmax(classvector, dim=1)
-            gt_classes = torch.argmax(y_classvector, dim=1)
-            class_accuracy = (pred_classes == gt_classes).float().mean()
-            self.log(
-                "train/class_accuracy",
-                class_accuracy,
-                on_step=False,
-                on_epoch=True,
+            logger=True,
             sync_dist=True,
         )
         return loss
@@ -2314,7 +1844,6 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         X = torch.squeeze(batch["instance_image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_classvector = batch["class_vectors"]
-        X = normalize_on_gpu(X)
         preds = self.model(X)
         classvector = preds["ClassVectorsHead"]
         confmaps = preds["CenteredInstanceConfmapsHead"]
@@ -2339,93 +1868,27 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         }
         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
         self.log(
-            "
+            "val_loss",
             val_loss,
             prog_bar=True,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )
         self.log(
-            "
+            "val_confmap_loss",
             confmap_loss,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )
         self.log(
-            "
+            "val_classvector_loss",
             classvector_loss,
             on_step=False,
             on_epoch=True,
+            logger=True,
             sync_dist=True,
         )
-
-        # Compute classification accuracy
-        with torch.no_grad():
-            pred_classes = torch.argmax(classvector, dim=1)
-            gt_classes = torch.argmax(y_classvector, dim=1)
-            class_accuracy = (pred_classes == gt_classes).float().mean()
-            self.log(
-                "val/class_accuracy",
-                class_accuracy,
-                on_step=False,
-                on_epoch=True,
-                sync_dist=True,
-            )
-
-        # Collect predictions for epoch-end evaluation if enabled
-        if self._collect_val_predictions:
-            # SAVE bbox BEFORE inference (it modifies in-place!)
-            bbox_prep_saved = batch["instance_bbox"].clone()
-
-            with torch.no_grad():
-                inference_output = self.instance_peaks_inf_layer(batch)
-
-            batch_size = len(batch["frame_idx"])
-            for i in range(batch_size):
-                eff = batch["eff_scale"][i].cpu().numpy()
-
-                # Predictions from inference (crop-relative, original scale)
-                pred_peaks_crop = (
-                    inference_output["pred_instance_peaks"][i].cpu().numpy()
-                )
-                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
-
-                # Compute bbox offset in original space from SAVED prep bbox
-                # bbox has shape (n_samples=1, 4, 2) where 4 corners
-                bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy()  # (4, 2)
-                bbox_top_left_orig = (
-                    bbox_prep[0] / eff
-                )  # Top-left corner in original space
-
-                # Full image coordinates (original space)
-                pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
-
-                # GT transform: crop-relative preprocessed -> full image original
-                gt_crop_prep = (
-                    batch["instance"][i].squeeze(0).cpu().numpy()
-                )  # (n_nodes, 2)
-                gt_crop_orig = gt_crop_prep / eff
-                gt_full_orig = gt_crop_orig + bbox_top_left_orig
-
-                self.val_predictions.append(
-                    {
-                        "video_idx": batch["video_idx"][i].item(),
-                        "frame_idx": batch["frame_idx"][i].item(),
-                        "pred_peaks": pred_peaks_full.reshape(
-                            1, -1, 2
-                        ),  # (1, n_nodes, 2)
-                        "pred_scores": pred_scores.reshape(1, -1),  # (1, n_nodes)
-                    }
-                )
-                self.val_ground_truth.append(
-                    {
-                        "video_idx": batch["video_idx"][i].item(),
-                        "frame_idx": batch["frame_idx"][i].item(),
-                        "gt_instances": gt_full_orig.reshape(
-                            1, -1, 2
-                        ),  # (1, n_nodes, 2)
-                        "num_instances": 1,
-                    }
-                )