sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a0__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (53)
  1. sleap_nn/__init__.py +2 -4
  2. sleap_nn/architectures/convnext.py +0 -5
  3. sleap_nn/architectures/encoder_decoder.py +6 -25
  4. sleap_nn/architectures/swint.py +0 -8
  5. sleap_nn/cli.py +60 -364
  6. sleap_nn/config/data_config.py +5 -11
  7. sleap_nn/config/get_config.py +4 -10
  8. sleap_nn/config/trainer_config.py +0 -76
  9. sleap_nn/data/augmentation.py +241 -50
  10. sleap_nn/data/custom_datasets.py +39 -411
  11. sleap_nn/data/instance_cropping.py +1 -1
  12. sleap_nn/data/resizing.py +2 -2
  13. sleap_nn/data/utils.py +17 -135
  14. sleap_nn/evaluation.py +22 -81
  15. sleap_nn/inference/bottomup.py +20 -86
  16. sleap_nn/inference/peak_finding.py +19 -88
  17. sleap_nn/inference/predictors.py +117 -224
  18. sleap_nn/legacy_models.py +11 -65
  19. sleap_nn/predict.py +9 -37
  20. sleap_nn/train.py +4 -74
  21. sleap_nn/training/callbacks.py +105 -1046
  22. sleap_nn/training/lightning_modules.py +65 -602
  23. sleap_nn/training/model_trainer.py +184 -211
  24. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +3 -15
  25. sleap_nn-0.1.0a0.dist-info/RECORD +65 -0
  26. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +1 -1
  27. sleap_nn/data/skia_augmentation.py +0 -414
  28. sleap_nn/export/__init__.py +0 -21
  29. sleap_nn/export/cli.py +0 -1778
  30. sleap_nn/export/exporters/__init__.py +0 -51
  31. sleap_nn/export/exporters/onnx_exporter.py +0 -80
  32. sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
  33. sleap_nn/export/metadata.py +0 -225
  34. sleap_nn/export/predictors/__init__.py +0 -63
  35. sleap_nn/export/predictors/base.py +0 -22
  36. sleap_nn/export/predictors/onnx.py +0 -154
  37. sleap_nn/export/predictors/tensorrt.py +0 -312
  38. sleap_nn/export/utils.py +0 -307
  39. sleap_nn/export/wrappers/__init__.py +0 -25
  40. sleap_nn/export/wrappers/base.py +0 -96
  41. sleap_nn/export/wrappers/bottomup.py +0 -243
  42. sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
  43. sleap_nn/export/wrappers/centered_instance.py +0 -56
  44. sleap_nn/export/wrappers/centroid.py +0 -58
  45. sleap_nn/export/wrappers/single_instance.py +0 -83
  46. sleap_nn/export/wrappers/topdown.py +0 -180
  47. sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
  48. sleap_nn/inference/postprocessing.py +0 -284
  49. sleap_nn/training/schedulers.py +0 -191
  50. sleap_nn-0.1.0.dist-info/RECORD +0 -88
  51. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
  52. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
  53. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  """This module has the LightningModule classes for all model types."""

- from typing import Optional, Union, Dict, Any, List
+ from typing import Optional, Union, Dict, Any
  import time
  from torch import nn
  import numpy as np
@@ -51,16 +51,10 @@ matplotlib.use(
  import matplotlib.pyplot as plt
  from sleap_nn.config.utils import get_backbone_type_from_cfg, get_model_type_from_cfg
  from sleap_nn.config.trainer_config import (
- CosineAnnealingWarmupConfig,
- LinearWarmupLinearDecayConfig,
  LRSchedulerConfig,
  ReduceLROnPlateauConfig,
  StepLRConfig,
  )
- from sleap_nn.training.schedulers import (
- LinearWarmupCosineAnnealingLR,
- LinearWarmupLinearDecayLR,
- )
  from sleap_nn.config.get_config import get_backbone_config
  from sleap_nn.legacy_models import (
  load_legacy_model_weights,
@@ -190,15 +184,6 @@ class LightningModel(L.LightningModule):
  self.val_loss = {}
  self.learning_rate = {}

- # For epoch-averaged loss tracking
- self._epoch_loss_sum = 0.0
- self._epoch_loss_count = 0
-
- # For epoch-end evaluation
- self.val_predictions: List[Dict] = []
- self.val_ground_truth: List[Dict] = []
- self._collect_val_predictions: bool = False
-
  # Initialization for encoder and decoder stacks.
  if self.init_weights == "xavier":
  self.model.apply(xavier_init_weights)
@@ -235,9 +220,7 @@ class LightningModel(L.LightningModule):
  elif self.pretrained_backbone_weights.endswith(".h5"):
  # load from sleap model weights
  load_legacy_model_weights(
- self.model.backbone,
- self.pretrained_backbone_weights,
- component="backbone",
+ self.model.backbone, self.pretrained_backbone_weights
  )

  else:
@@ -266,9 +249,7 @@ class LightningModel(L.LightningModule):
  elif self.pretrained_head_weights.endswith(".h5"):
  # load from sleap model weights
  load_legacy_model_weights(
- self.model.head_layers,
- self.pretrained_head_weights,
- component="head",
+ self.model.head_layers, self.pretrained_head_weights
  )

  else:
@@ -324,24 +305,17 @@ class LightningModel(L.LightningModule):
  def on_train_epoch_start(self):
  """Configure the train timer at the beginning of each epoch."""
  self.train_start_time = time.time()
- # Reset epoch loss tracking
- self._epoch_loss_sum = 0.0
- self._epoch_loss_count = 0
-
- def _accumulate_loss(self, loss: torch.Tensor):
- """Accumulate loss for epoch-averaged logging. Call this in training_step."""
- self._epoch_loss_sum += loss.detach().item()
- self._epoch_loss_count += 1

  def on_train_epoch_end(self):
  """Configure the train timer at the end of every epoch."""
  train_time = time.time() - self.train_start_time
  self.log(
- "train/time",
+ "train_time",
  train_time,
  prog_bar=False,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
  # Log epoch explicitly for custom x-axis support in wandb
@@ -350,56 +324,24 @@ class LightningModel(L.LightningModule):
  float(self.current_epoch),
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
- # Log epoch-averaged training loss
- if self._epoch_loss_count > 0:
- avg_loss = self._epoch_loss_sum / self._epoch_loss_count
- self.log(
- "train/loss",
- avg_loss,
- prog_bar=False,
- on_step=False,
- on_epoch=True,
- sync_dist=True,
- )
- # Log current learning rate (useful for monitoring LR schedulers)
- if self.trainer.optimizers:
- lr = self.trainer.optimizers[0].param_groups[0]["lr"]
- self.log(
- "train/lr",
- lr,
- prog_bar=False,
- on_step=False,
- on_epoch=True,
- sync_dist=True,
- )

  def on_validation_epoch_start(self):
  """Configure the val timer at the beginning of each epoch."""
  self.val_start_time = time.time()
- # Clear accumulated predictions for new epoch
- self.val_predictions = []
- self.val_ground_truth = []

  def on_validation_epoch_end(self):
  """Configure the val timer at the end of every epoch."""
  val_time = time.time() - self.val_start_time
  self.log(
- "val/time",
+ "val_time",
  val_time,
  prog_bar=False,
  on_step=False,
  on_epoch=True,
- sync_dist=True,
- )
- # Log epoch explicitly so val/* metrics can use it as x-axis in wandb
- # (mirrors what on_train_epoch_end does for train/* metrics)
- self.log(
- "epoch",
- float(self.current_epoch),
- on_step=False,
- on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
@@ -436,51 +378,13 @@ class LightningModel(L.LightningModule):
  lr_scheduler_cfg.step_lr = StepLRConfig()
  elif self.lr_scheduler == "reduce_lr_on_plateau":
  lr_scheduler_cfg.reduce_lr_on_plateau = ReduceLROnPlateauConfig()
- elif self.lr_scheduler == "cosine_annealing_warmup":
- lr_scheduler_cfg.cosine_annealing_warmup = CosineAnnealingWarmupConfig()
- elif self.lr_scheduler == "linear_warmup_linear_decay":
- lr_scheduler_cfg.linear_warmup_linear_decay = (
- LinearWarmupLinearDecayConfig()
- )

  elif isinstance(self.lr_scheduler, dict):
  lr_scheduler_cfg = self.lr_scheduler

  for k, v in self.lr_scheduler.items():
  if v is not None:
- if k == "cosine_annealing_warmup":
- cfg = self.lr_scheduler.cosine_annealing_warmup
- # Use trainer's max_epochs if not specified in config
- max_epochs = (
- cfg.max_epochs
- if cfg.max_epochs is not None
- else self.trainer.max_epochs
- )
- scheduler = LinearWarmupCosineAnnealingLR(
- optimizer=optimizer,
- warmup_epochs=cfg.warmup_epochs,
- max_epochs=max_epochs,
- warmup_start_lr=cfg.warmup_start_lr,
- eta_min=cfg.eta_min,
- )
- break
- elif k == "linear_warmup_linear_decay":
- cfg = self.lr_scheduler.linear_warmup_linear_decay
- # Use trainer's max_epochs if not specified in config
- max_epochs = (
- cfg.max_epochs
- if cfg.max_epochs is not None
- else self.trainer.max_epochs
- )
- scheduler = LinearWarmupLinearDecayLR(
- optimizer=optimizer,
- warmup_epochs=cfg.warmup_epochs,
- max_epochs=max_epochs,
- warmup_start_lr=cfg.warmup_start_lr,
- end_lr=cfg.end_lr,
- )
- break
- elif k == "step_lr":
+ if k == "step_lr":
  scheduler = torch.optim.lr_scheduler.StepLR(
  optimizer=optimizer,
  step_size=self.lr_scheduler.step_lr.step_size,
@@ -508,7 +412,7 @@ class LightningModel(L.LightningModule):
  "optimizer": optimizer,
  "lr_scheduler": {
  "scheduler": scheduler,
- "monitor": "val/loss",
+ "monitor": "val_loss",
  },
  }
@@ -664,7 +568,6 @@ class SingleInstanceLightningModule(LightningModel):
  torch.squeeze(batch["image"], dim=1),
  torch.squeeze(batch["confidence_maps"], dim=1),
  )
- X = normalize_on_gpu(X)

  y_preds = self.model(X)["SingleInstanceConfmapsHead"]
@@ -688,24 +591,23 @@ class SingleInstanceLightningModule(LightningModel):
  channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
  for node_idx, name in enumerate(self.node_names):
  self.log(
- f"train/confmaps/{name}",
+ f"{name}",
  channel_wise_loss[node_idx],
  prog_bar=False,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
- # Log step-level loss (every batch, uses global_step x-axis)
  self.log(
- "loss",
+ "train_loss",
  loss,
  prog_bar=True,
  on_step=True,
  on_epoch=False,
+ logger=True,
  sync_dist=True,
  )
- # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
- self._accumulate_loss(loss)
  return loss

  def validation_step(self, batch, batch_idx):
@@ -714,7 +616,6 @@ class SingleInstanceLightningModule(LightningModel):
  torch.squeeze(batch["image"], dim=1),
  torch.squeeze(batch["confidence_maps"], dim=1),
  )
- X = normalize_on_gpu(X)

  y_preds = self.model(X)["SingleInstanceConfmapsHead"]
  val_loss = nn.MSELoss()(y_preds, y)
@@ -729,59 +630,15 @@ class SingleInstanceLightningModule(LightningModel):
  )
  val_loss = val_loss + ohkm_loss
  self.log(
- "val/loss",
+ "val_loss",
  val_loss,
  prog_bar=True,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )

- # Collect predictions for epoch-end evaluation if enabled
- if self._collect_val_predictions:
- with torch.no_grad():
- # Squeeze n_samples dim from image for inference (batch, 1, C, H, W) -> (batch, C, H, W)
- inference_batch = {k: v for k, v in batch.items()}
- if inference_batch["image"].ndim == 5:
- inference_batch["image"] = inference_batch["image"].squeeze(1)
- inference_output = self.single_instance_inf_layer(inference_batch)
- if isinstance(inference_output, list):
- inference_output = inference_output[0]
-
- batch_size = len(batch["frame_idx"])
- for i in range(batch_size):
- eff = batch["eff_scale"][i].cpu().numpy()
-
- # Predictions are already in original image space (inference divides by eff_scale)
- pred_peaks = inference_output["pred_instance_peaks"][i].cpu().numpy()
- pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
-
- # Transform GT from preprocessed to original image space
- # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
- gt_prep = batch["instances"][i].cpu().numpy()
- if gt_prep.ndim == 4:
- gt_prep = gt_prep.squeeze(0) # (max_inst, n_nodes, 2)
- gt_orig = gt_prep / eff
- num_inst = batch["num_instances"][i].item()
- gt_orig = gt_orig[:num_inst] # Only valid instances
-
- self.val_predictions.append(
- {
- "video_idx": batch["video_idx"][i].item(),
- "frame_idx": batch["frame_idx"][i].item(),
- "pred_peaks": pred_peaks,
- "pred_scores": pred_scores,
- }
- )
- self.val_ground_truth.append(
- {
- "video_idx": batch["video_idx"][i].item(),
- "frame_idx": batch["frame_idx"][i].item(),
- "gt_instances": gt_orig,
- "num_instances": num_inst,
- }
- )
-

  class TopDownCenteredInstanceLightningModule(LightningModel):
  """Lightning Module for TopDownCenteredInstance Model.
@@ -927,7 +784,6 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
  torch.squeeze(batch["instance_image"], dim=1),
  torch.squeeze(batch["confidence_maps"], dim=1),
  )
- X = normalize_on_gpu(X)

  y_preds = self.model(X)["CenteredInstanceConfmapsHead"]
@@ -951,25 +807,24 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
  channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
  for node_idx, name in enumerate(self.node_names):
  self.log(
- f"train/confmaps/{name}",
+ f"{name}",
  channel_wise_loss[node_idx],
  prog_bar=False,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )

- # Log step-level loss (every batch, uses global_step x-axis)
  self.log(
- "loss",
+ "train_loss",
  loss,
  prog_bar=True,
  on_step=True,
  on_epoch=False,
+ logger=True,
  sync_dist=True,
  )
- # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
- self._accumulate_loss(loss)
  return loss

  def validation_step(self, batch, batch_idx):
  def validation_step(self, batch, batch_idx):
@@ -978,7 +833,6 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
978
833
  torch.squeeze(batch["instance_image"], dim=1),
979
834
  torch.squeeze(batch["confidence_maps"], dim=1),
980
835
  )
981
- X = normalize_on_gpu(X)
982
836
 
983
837
  y_preds = self.model(X)["CenteredInstanceConfmapsHead"]
984
838
  val_loss = nn.MSELoss()(y_preds, y)
@@ -993,70 +847,15 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
993
847
  )
994
848
  val_loss = val_loss + ohkm_loss
995
849
  self.log(
996
- "val/loss",
850
+ "val_loss",
997
851
  val_loss,
998
852
  prog_bar=True,
999
853
  on_step=False,
1000
854
  on_epoch=True,
855
+ logger=True,
1001
856
  sync_dist=True,
1002
857
  )
1003
858
 
1004
- # Collect predictions for epoch-end evaluation if enabled
1005
- if self._collect_val_predictions:
1006
- # SAVE bbox BEFORE inference (it modifies in-place!)
1007
- bbox_prep_saved = batch["instance_bbox"].clone()
1008
-
1009
- with torch.no_grad():
1010
- inference_output = self.instance_peaks_inf_layer(batch)
1011
-
1012
- batch_size = len(batch["frame_idx"])
1013
- for i in range(batch_size):
1014
- eff = batch["eff_scale"][i].cpu().numpy()
1015
-
1016
- # Predictions from inference (crop-relative, original scale)
1017
- pred_peaks_crop = (
1018
- inference_output["pred_instance_peaks"][i].cpu().numpy()
1019
- )
1020
- pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
1021
-
1022
- # Compute bbox offset in original space from SAVED prep bbox
1023
- # bbox has shape (n_samples=1, 4, 2) where 4 corners
1024
- bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy() # (4, 2)
1025
- bbox_top_left_orig = (
1026
- bbox_prep[0] / eff
1027
- ) # Top-left corner in original space
1028
-
1029
- # Full image coordinates (original space)
1030
- pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
1031
-
1032
- # GT transform: crop-relative preprocessed -> full image original
1033
- gt_crop_prep = (
1034
- batch["instance"][i].squeeze(0).cpu().numpy()
1035
- ) # (n_nodes, 2)
1036
- gt_crop_orig = gt_crop_prep / eff
1037
- gt_full_orig = gt_crop_orig + bbox_top_left_orig
1038
-
1039
- self.val_predictions.append(
1040
- {
1041
- "video_idx": batch["video_idx"][i].item(),
1042
- "frame_idx": batch["frame_idx"][i].item(),
1043
- "pred_peaks": pred_peaks_full.reshape(
1044
- 1, -1, 2
1045
- ), # (1, n_nodes, 2)
1046
- "pred_scores": pred_scores.reshape(1, -1), # (1, n_nodes)
1047
- }
1048
- )
1049
- self.val_ground_truth.append(
1050
- {
1051
- "video_idx": batch["video_idx"][i].item(),
1052
- "frame_idx": batch["frame_idx"][i].item(),
1053
- "gt_instances": gt_full_orig.reshape(
1054
- 1, -1, 2
1055
- ), # (1, n_nodes, 2)
1056
- "num_instances": 1,
1057
- }
1058
- )
1059
-
1060
859
 
1061
860
  class CentroidLightningModule(LightningModel):
1062
861
  """Lightning Module for Centroid Model.
@@ -1202,21 +1001,18 @@ class CentroidLightningModule(LightningModel):
  torch.squeeze(batch["image"], dim=1),
  torch.squeeze(batch["centroids_confidence_maps"], dim=1),
  )
- X = normalize_on_gpu(X)

  y_preds = self.model(X)["CentroidConfmapsHead"]
  loss = nn.MSELoss()(y_preds, y)
- # Log step-level loss (every batch, uses global_step x-axis)
  self.log(
- "loss",
+ "train_loss",
  loss,
  prog_bar=True,
  on_step=True,
  on_epoch=False,
+ logger=True,
  sync_dist=True,
  )
- # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
- self._accumulate_loss(loss)
  return loss

  def validation_step(self, batch, batch_idx):
@@ -1225,74 +1021,19 @@ class CentroidLightningModule(LightningModel):
  torch.squeeze(batch["image"], dim=1),
  torch.squeeze(batch["centroids_confidence_maps"], dim=1),
  )
- X = normalize_on_gpu(X)

  y_preds = self.model(X)["CentroidConfmapsHead"]
  val_loss = nn.MSELoss()(y_preds, y)
  self.log(
- "val/loss",
+ "val_loss",
  val_loss,
  prog_bar=True,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )

- # Collect predictions for epoch-end evaluation if enabled
- if self._collect_val_predictions:
- # Save GT centroids before inference (inference overwrites batch["centroids"])
- batch["gt_centroids"] = batch["centroids"].clone()
-
- with torch.no_grad():
- inference_output = self.centroid_inf_layer(batch)
-
- batch_size = len(batch["frame_idx"])
- for i in range(batch_size):
- eff = batch["eff_scale"][i].cpu().numpy()
-
- # Predictions are in original image space (inference divides by eff_scale)
- # centroids shape: (batch, 1, max_instances, 2) - squeeze to (max_instances, 2)
- pred_centroids = (
- inference_output["centroids"][i].squeeze(0).cpu().numpy()
- )
- pred_vals = inference_output["centroid_vals"][i].cpu().numpy()
-
- # Transform GT centroids from preprocessed to original image space
- # Use "gt_centroids" since inference overwrites "centroids" with predictions
- gt_centroids_prep = (
- batch["gt_centroids"][i].cpu().numpy()
- ) # (n_samples=1, max_inst, 2)
- gt_centroids_orig = gt_centroids_prep.squeeze(0) / eff # (max_inst, 2)
- num_inst = batch["num_instances"][i].item()
-
- # Filter to valid instances (non-NaN)
- valid_pred_mask = ~np.isnan(pred_centroids).any(axis=1)
- pred_centroids = pred_centroids[valid_pred_mask]
- pred_vals = pred_vals[valid_pred_mask]
-
- gt_centroids_valid = gt_centroids_orig[:num_inst]
-
- self.val_predictions.append(
- {
- "video_idx": batch["video_idx"][i].item(),
- "frame_idx": batch["frame_idx"][i].item(),
- "pred_peaks": pred_centroids.reshape(
- -1, 1, 2
- ), # (n_inst, 1, 2)
- "pred_scores": pred_vals.reshape(-1, 1), # (n_inst, 1)
- }
- )
- self.val_ground_truth.append(
- {
- "video_idx": batch["video_idx"][i].item(),
- "frame_idx": batch["frame_idx"][i].item(),
- "gt_instances": gt_centroids_valid.reshape(
- -1, 1, 2
- ), # (n_inst, 1, 2)
- "num_instances": num_inst,
- }
- )
-

  class BottomUpLightningModule(LightningModel):
  """Lightning Module for BottomUp Model.
@@ -1385,13 +1126,12 @@ class BottomUpLightningModule(LightningModel):
  self.bottomup_inf_layer = BottomUpInferenceModel(
  torch_model=self.forward,
  paf_scorer=paf_scorer,
- peak_threshold=0.1, # Lower threshold for epoch-end eval during training
+ peak_threshold=0.2,
  input_scale=1.0,
  return_confmaps=True,
  return_pafs=True,
  cms_output_stride=self.head_configs.bottomup.confmaps.output_stride,
  pafs_output_stride=self.head_configs.bottomup.pafs.output_stride,
- max_peaks_per_node=100, # Prevents combinatorial explosion in early training
  )
  self.node_names = list(self.head_configs.bottomup.confmaps.part_names)
@@ -1476,7 +1216,6 @@ class BottomUpLightningModule(LightningModel):
  X = torch.squeeze(batch["image"], dim=1)
  y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
  y_paf = batch["part_affinity_fields"]
- X = normalize_on_gpu(X)
  preds = self.model(X)
  pafs = preds["PartAffinityFieldsHead"]
  confmaps = preds["MultiInstanceConfmapsHead"]
@@ -1509,29 +1248,29 @@ class BottomUpLightningModule(LightningModel):
  "PartAffinityFieldsHead": pafs_loss,
  }
  loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
- # Log step-level loss (every batch, uses global_step x-axis)
  self.log(
- "loss",
+ "train_loss",
  loss,
  prog_bar=True,
  on_step=True,
  on_epoch=False,
+ logger=True,
  sync_dist=True,
  )
- # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
- self._accumulate_loss(loss)
  self.log(
- "train/confmaps_loss",
+ "train_confmap_loss",
  confmap_loss,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
  self.log(
- "train/paf_loss",
+ "train_paf_loss",
  pafs_loss,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
  return loss
@@ -1541,7 +1280,6 @@ class BottomUpLightningModule(LightningModel):
  X = torch.squeeze(batch["image"], dim=1)
  y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
  y_paf = batch["part_affinity_fields"]
- X = normalize_on_gpu(X)

  preds = self.model(X)
  pafs = preds["PartAffinityFieldsHead"]
@@ -1577,75 +1315,31 @@ class BottomUpLightningModule(LightningModel):

  val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
  self.log(
- "val/loss",
+ "val_loss",
  val_loss,
  prog_bar=True,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
  self.log(
- "val/confmaps_loss",
+ "val_confmap_loss",
  confmap_loss,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
  self.log(
- "val/paf_loss",
+ "val_paf_loss",
  pafs_loss,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )

- # Collect predictions for epoch-end evaluation if enabled
- if self._collect_val_predictions:
- with torch.no_grad():
- # Note: Do NOT squeeze the image here - the forward() method expects
- # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
- inference_output = self.bottomup_inf_layer(batch)
- if isinstance(inference_output, list):
- inference_output = inference_output[0]
-
- batch_size = len(batch["frame_idx"])
- for i in range(batch_size):
- eff = batch["eff_scale"][i].cpu().numpy()
-
- # Predictions are already in original space (variable number of instances)
- pred_peaks = inference_output["pred_instance_peaks"][i]
- pred_scores = inference_output["pred_peak_values"][i]
- if torch.is_tensor(pred_peaks):
- pred_peaks = pred_peaks.cpu().numpy()
- if torch.is_tensor(pred_scores):
- pred_scores = pred_scores.cpu().numpy()
-
- # Transform GT to original space
- # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
- gt_prep = batch["instances"][i].cpu().numpy()
- if gt_prep.ndim == 4:
- gt_prep = gt_prep.squeeze(0) # (max_inst, n_nodes, 2)
- gt_orig = gt_prep / eff
- num_inst = batch["num_instances"][i].item()
- gt_orig = gt_orig[:num_inst] # Only valid instances
-
- self.val_predictions.append(
- {
- "video_idx": batch["video_idx"][i].item(),
- "frame_idx": batch["frame_idx"][i].item(),
- "pred_peaks": pred_peaks, # Original space, variable instances
- "pred_scores": pred_scores,
- }
- )
- self.val_ground_truth.append(
- {
- "video_idx": batch["video_idx"][i].item(),
- "frame_idx": batch["frame_idx"][i].item(),
- "gt_instances": gt_orig, # Original space
- "num_instances": num_inst,
- }
- )
-

  class BottomUpMultiClassLightningModule(LightningModel):
  """Lightning Module for BottomUp ID Model.
@@ -1824,7 +1518,6 @@ class BottomUpMultiClassLightningModule(LightningModel):
  X = torch.squeeze(batch["image"], dim=1)
  y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
  y_classmap = torch.squeeze(batch["class_maps"], dim=1)
- X = normalize_on_gpu(X)
  preds = self.model(X)
  classmaps = preds["ClassMapsHead"]
  confmaps = preds["MultiInstanceConfmapsHead"]
@@ -1848,84 +1541,31 @@ class BottomUpMultiClassLightningModule(LightningModel):
  "ClassMapsHead": classmaps_loss,
  }
  loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
- # Log step-level loss (every batch, uses global_step x-axis)
  self.log(
- "loss",
+ "train_loss",
  loss,
  prog_bar=True,
  on_step=True,
  on_epoch=False,
+ logger=True,
  sync_dist=True,
  )
- # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
- self._accumulate_loss(loss)
  self.log(
- "train/confmaps_loss",
+ "train_confmap_loss",
  confmap_loss,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
  self.log(
- "train/classmap_loss",
+ "train_classmap_loss",
  classmaps_loss,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
-
- # Compute classification accuracy at GT keypoint locations
- with torch.no_grad():
- # Get output stride for class maps
- cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
-
- # Get GT instances and sample class maps at those locations
- instances = batch["instances"] # (batch, n_samples, max_inst, n_nodes, 2)
- if instances.dim() == 5:
- instances = instances.squeeze(1) # (batch, max_inst, n_nodes, 2)
- num_instances = batch["num_instances"] # (batch,)
-
- correct = 0
- total = 0
- for b in range(instances.shape[0]):
- n_inst = num_instances[b].item()
- for inst_idx in range(n_inst):
- for node_idx in range(instances.shape[2]):
- # Get keypoint location (in input image space)
- kp = instances[b, inst_idx, node_idx] # (2,) = (x, y)
- if torch.isnan(kp).any():
- continue
-
- # Convert to class map space
- x_cm = (
- (kp[0] / cms_stride)
- .long()
- .clamp(0, classmaps.shape[-1] - 1)
- )
- y_cm = (
- (kp[1] / cms_stride)
- .long()
- .clamp(0, classmaps.shape[-2] - 1)
- )
-
- # Sample predicted and GT class at this location
- pred_class = classmaps[b, :, y_cm, x_cm].argmax()
- gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
-
- if pred_class == gt_class:
- correct += 1
- total += 1
-
- if total > 0:
- class_accuracy = torch.tensor(correct / total, device=X.device)
- self.log(
- "train/class_accuracy",
- class_accuracy,
- on_step=False,
- on_epoch=True,
- sync_dist=True,
- )
-
  return loss

  def validation_step(self, batch, batch_idx):
@@ -1933,7 +1573,6 @@ class BottomUpMultiClassLightningModule(LightningModel):
  X = torch.squeeze(batch["image"], dim=1)
  y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
  y_classmap = torch.squeeze(batch["class_maps"], dim=1)
- X = normalize_on_gpu(X)

  preds = self.model(X)
  classmaps = preds["ClassMapsHead"]
@@ -1960,127 +1599,31 @@ class BottomUpMultiClassLightningModule(LightningModel):

  val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
  self.log(
- "val/loss",
+ "val_loss",
  val_loss,
  prog_bar=True,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
  self.log(
- "val/confmaps_loss",
+ "val_confmap_loss",
  confmap_loss,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
  self.log(
- "val/classmap_loss",
+ "val_classmap_loss",
  classmaps_loss,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )

- # Compute classification accuracy at GT keypoint locations
- with torch.no_grad():
- # Get output stride for class maps
- cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
-
- # Get GT instances and sample class maps at those locations
- instances = batch["instances"] # (batch, n_samples, max_inst, n_nodes, 2)
- if instances.dim() == 5:
- instances = instances.squeeze(1) # (batch, max_inst, n_nodes, 2)
- num_instances = batch["num_instances"] # (batch,)
-
- correct = 0
- total = 0
- for b in range(instances.shape[0]):
- n_inst = num_instances[b].item()
- for inst_idx in range(n_inst):
- for node_idx in range(instances.shape[2]):
- # Get keypoint location (in input image space)
- kp = instances[b, inst_idx, node_idx] # (2,) = (x, y)
- if torch.isnan(kp).any():
- continue
-
- # Convert to class map space
- x_cm = (
- (kp[0] / cms_stride)
- .long()
- .clamp(0, classmaps.shape[-1] - 1)
- )
- y_cm = (
- (kp[1] / cms_stride)
- .long()
- .clamp(0, classmaps.shape[-2] - 1)
- )
-
- # Sample predicted and GT class at this location
- pred_class = classmaps[b, :, y_cm, x_cm].argmax()
- gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
-
- if pred_class == gt_class:
- correct += 1
- total += 1
-
- if total > 0:
- class_accuracy = torch.tensor(correct / total, device=X.device)
- self.log(
- "val/class_accuracy",
- class_accuracy,
- on_step=False,
- on_epoch=True,
- sync_dist=True,
- )
-
- # Collect predictions for epoch-end evaluation if enabled
- if self._collect_val_predictions:
- with torch.no_grad():
- # Note: Do NOT squeeze the image here - the forward() method expects
- # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
- inference_output = self.bottomup_inf_layer(batch)
- if isinstance(inference_output, list):
- inference_output = inference_output[0]
-
- batch_size = len(batch["frame_idx"])
- for i in range(batch_size):
- eff = batch["eff_scale"][i].cpu().numpy()
-
- # Predictions are already in original space (variable number of instances)
- pred_peaks = inference_output["pred_instance_peaks"][i]
- pred_scores = inference_output["pred_peak_values"][i]
- if torch.is_tensor(pred_peaks):
- pred_peaks = pred_peaks.cpu().numpy()
- if torch.is_tensor(pred_scores):
- pred_scores = pred_scores.cpu().numpy()
-
- # Transform GT to original space
- # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
- gt_prep = batch["instances"][i].cpu().numpy()
- if gt_prep.ndim == 4:
- gt_prep = gt_prep.squeeze(0) # (max_inst, n_nodes, 2)
- gt_orig = gt_prep / eff
- num_inst = batch["num_instances"][i].item()
- gt_orig = gt_orig[:num_inst] # Only valid instances
-
- self.val_predictions.append(
- {
- "video_idx": batch["video_idx"][i].item(),
- "frame_idx": batch["frame_idx"][i].item(),
- "pred_peaks": pred_peaks, # Original space, variable instances
- "pred_scores": pred_scores,
- }
- )
- self.val_ground_truth.append(
- {
- "video_idx": batch["video_idx"][i].item(),
- "frame_idx": batch["frame_idx"][i].item(),
- "gt_instances": gt_orig, # Original space
- "num_instances": num_inst,
- }
- )
-

  class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  """Lightning Module for TopDownCenteredInstance ID Model.
@@ -2229,7 +1772,6 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  X = torch.squeeze(batch["instance_image"], dim=1)
  y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
  y_classvector = batch["class_vectors"]
- X = normalize_on_gpu(X)
  preds = self.model(X)
  classvector = preds["ClassVectorsHead"]
  confmaps = preds["CenteredInstanceConfmapsHead"]
@@ -2261,50 +1803,38 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
  for node_idx, name in enumerate(self.node_names):
  self.log(
- f"train/confmaps/{name}",
+ f"{name}",
  channel_wise_loss[node_idx],
  prog_bar=False,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )

- # Log step-level loss (every batch, uses global_step x-axis)
  self.log(
- "loss",
+ "train_loss",
  loss,
  prog_bar=True,
  on_step=True,
  on_epoch=False,
+ logger=True,
  sync_dist=True,
  )
- # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
- self._accumulate_loss(loss)
  self.log(
- "train/confmaps_loss",
+ "train_confmap_loss",
  confmap_loss,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
  self.log(
- "train/classvector_loss",
+ "train_classvector_loss",
  classvector_loss,
  on_step=False,
  on_epoch=True,
- sync_dist=True,
- )
-
- # Compute classification accuracy
- with torch.no_grad():
- pred_classes = torch.argmax(classvector, dim=1)
- gt_classes = torch.argmax(y_classvector, dim=1)
- class_accuracy = (pred_classes == gt_classes).float().mean()
- self.log(
- "train/class_accuracy",
- class_accuracy,
- on_step=False,
- on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
  return loss
@@ -2314,7 +1844,6 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  X = torch.squeeze(batch["instance_image"], dim=1)
  y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
  y_classvector = batch["class_vectors"]
- X = normalize_on_gpu(X)
  preds = self.model(X)
  classvector = preds["ClassVectorsHead"]
  confmaps = preds["CenteredInstanceConfmapsHead"]
@@ -2339,93 +1868,27 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
  }
  val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
  self.log(
- "val/loss",
+ "val_loss",
  val_loss,
  prog_bar=True,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
  self.log(
- "val/confmaps_loss",
+ "val_confmap_loss",
  confmap_loss,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
  self.log(
- "val/classvector_loss",
+ "val_classvector_loss",
  classvector_loss,
  on_step=False,
  on_epoch=True,
+ logger=True,
  sync_dist=True,
  )
-
- # Compute classification accuracy
- with torch.no_grad():
- pred_classes = torch.argmax(classvector, dim=1)
- gt_classes = torch.argmax(y_classvector, dim=1)
- class_accuracy = (pred_classes == gt_classes).float().mean()
- self.log(
- "val/class_accuracy",
- class_accuracy,
- on_step=False,
- on_epoch=True,
- sync_dist=True,
- )
-
- # Collect predictions for epoch-end evaluation if enabled
- if self._collect_val_predictions:
- # SAVE bbox BEFORE inference (it modifies in-place!)
- bbox_prep_saved = batch["instance_bbox"].clone()
-
- with torch.no_grad():
- inference_output = self.instance_peaks_inf_layer(batch)
-
- batch_size = len(batch["frame_idx"])
- for i in range(batch_size):
- eff = batch["eff_scale"][i].cpu().numpy()
-
- # Predictions from inference (crop-relative, original scale)
- pred_peaks_crop = (
- inference_output["pred_instance_peaks"][i].cpu().numpy()
- )
- pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
-
- # Compute bbox offset in original space from SAVED prep bbox
- # bbox has shape (n_samples=1, 4, 2) where 4 corners
- bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy() # (4, 2)
- bbox_top_left_orig = (
- bbox_prep[0] / eff
- ) # Top-left corner in original space
-
- # Full image coordinates (original space)
- pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
-
- # GT transform: crop-relative preprocessed -> full image original
- gt_crop_prep = (
- batch["instance"][i].squeeze(0).cpu().numpy()
- ) # (n_nodes, 2)
- gt_crop_orig = gt_crop_prep / eff
- gt_full_orig = gt_crop_orig + bbox_top_left_orig
-
- self.val_predictions.append(
- {
- "video_idx": batch["video_idx"][i].item(),
- "frame_idx": batch["frame_idx"][i].item(),
- "pred_peaks": pred_peaks_full.reshape(
- 1, -1, 2
- ), # (1, n_nodes, 2)
- "pred_scores": pred_scores.reshape(1, -1), # (1, n_nodes)
- }
- )
- self.val_ground_truth.append(
- {
- "video_idx": batch["video_idx"][i].item(),
- "frame_idx": batch["frame_idx"][i].item(),
- "gt_instances": gt_full_orig.reshape(
- 1, -1, 2
- ), # (1, n_nodes, 2)
- "num_instances": 1,
- }
- )