sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/training/lightning_modules.py
@@ -1,6 +1,6 @@
 """This module has the LightningModule classes for all model types."""
 
-from typing import Optional, Union, Dict, Any
+from typing import Optional, Union, Dict, Any, List
 import time
 from torch import nn
 import numpy as np
@@ -33,6 +33,7 @@ from sleap_nn.inference.bottomup import (
 )
 from sleap_nn.inference.paf_grouping import PAFScorer
 from sleap_nn.architectures.model import Model
+from sleap_nn.data.normalization import normalize_on_gpu
 from sleap_nn.training.losses import compute_ohkm_loss
 from loguru import logger
 from sleap_nn.training.utils import (
@@ -40,14 +41,26 @@ from sleap_nn.training.utils import (
     plot_confmaps,
     plot_img,
     plot_peaks,
+    VisualizationData,
 )
+import matplotlib
+
+matplotlib.use(
+    "Agg"
+)  # Use non-interactive backend to avoid tkinter issues on Windows CI
 import matplotlib.pyplot as plt
 from sleap_nn.config.utils import get_backbone_type_from_cfg, get_model_type_from_cfg
 from sleap_nn.config.trainer_config import (
+    CosineAnnealingWarmupConfig,
+    LinearWarmupLinearDecayConfig,
     LRSchedulerConfig,
     ReduceLROnPlateauConfig,
     StepLRConfig,
 )
+from sleap_nn.training.schedulers import (
+    LinearWarmupCosineAnnealingLR,
+    LinearWarmupLinearDecayLR,
+)
 from sleap_nn.config.get_config import get_backbone_config
 from sleap_nn.legacy_models import (
     load_legacy_model_weights,
@@ -177,6 +190,15 @@ class LightningModel(L.LightningModule):
         self.val_loss = {}
         self.learning_rate = {}
 
+        # For epoch-averaged loss tracking
+        self._epoch_loss_sum = 0.0
+        self._epoch_loss_count = 0
+
+        # For epoch-end evaluation
+        self.val_predictions: List[Dict] = []
+        self.val_ground_truth: List[Dict] = []
+        self._collect_val_predictions: bool = False
+
         # Initialization for encoder and decoder stacks.
         if self.init_weights == "xavier":
             self.model.apply(xavier_init_weights)
@@ -213,7 +235,9 @@ class LightningModel(L.LightningModule):
         elif self.pretrained_backbone_weights.endswith(".h5"):
             # load from sleap model weights
             load_legacy_model_weights(
-                self.model.backbone, self.pretrained_backbone_weights
+                self.model.backbone,
+                self.pretrained_backbone_weights,
+                component="backbone",
             )
 
         else:
@@ -242,7 +266,9 @@ class LightningModel(L.LightningModule):
         elif self.pretrained_head_weights.endswith(".h5"):
             # load from sleap model weights
             load_legacy_model_weights(
-                self.model.head_layers, self.pretrained_head_weights
+                self.model.head_layers,
+                self.pretrained_head_weights,
+                component="head",
             )
 
         else:
@@ -298,34 +324,82 @@ class LightningModel(L.LightningModule):
     def on_train_epoch_start(self):
         """Configure the train timer at the beginning of each epoch."""
         self.train_start_time = time.time()
+        # Reset epoch loss tracking
+        self._epoch_loss_sum = 0.0
+        self._epoch_loss_count = 0
+
+    def _accumulate_loss(self, loss: torch.Tensor):
+        """Accumulate loss for epoch-averaged logging. Call this in training_step."""
+        self._epoch_loss_sum += loss.detach().item()
+        self._epoch_loss_count += 1
 
     def on_train_epoch_end(self):
         """Configure the train timer at the end of every epoch."""
         train_time = time.time() - self.train_start_time
         self.log(
-            "train_time",
+            "train/time",
             train_time,
             prog_bar=False,
             on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
+        # Log epoch explicitly for custom x-axis support in wandb
+        self.log(
+            "epoch",
+            float(self.current_epoch),
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+        # Log epoch-averaged training loss
+        if self._epoch_loss_count > 0:
+            avg_loss = self._epoch_loss_sum / self._epoch_loss_count
+            self.log(
+                "train/loss",
+                avg_loss,
+                prog_bar=False,
+                on_step=False,
+                on_epoch=True,
+                sync_dist=True,
+            )
+        # Log current learning rate (useful for monitoring LR schedulers)
+        if self.trainer.optimizers:
+            lr = self.trainer.optimizers[0].param_groups[0]["lr"]
+            self.log(
+                "train/lr",
+                lr,
+                prog_bar=False,
+                on_step=False,
+                on_epoch=True,
+                sync_dist=True,
+            )
 
     def on_validation_epoch_start(self):
         """Configure the val timer at the beginning of each epoch."""
         self.val_start_time = time.time()
+        # Clear accumulated predictions for new epoch
+        self.val_predictions = []
+        self.val_ground_truth = []
 
     def on_validation_epoch_end(self):
         """Configure the val timer at the end of every epoch."""
         val_time = time.time() - self.val_start_time
         self.log(
-            "val_time",
+            "val/time",
             val_time,
             prog_bar=False,
             on_step=False,
             on_epoch=True,
-            logger=True,
+            sync_dist=True,
+        )
+        # Log epoch explicitly so val/* metrics can use it as x-axis in wandb
+        # (mirrors what on_train_epoch_end does for train/* metrics)
+        self.log(
+            "epoch",
+            float(self.current_epoch),
+            on_step=False,
+            on_epoch=True,
             sync_dist=True,
         )
 
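Note: training loss is now logged twice on purpose — "loss" per step (global_step x-axis) and "train/loss" as an epoch average accumulated via _accumulate_loss. A minimal sketch of the same pattern in a generic LightningModule (hypothetical compute_loss helper, standard Lightning API):

    import lightning as L

    class MyModule(L.LightningModule):
        def on_train_epoch_start(self):
            self._sum, self._count = 0.0, 0

        def training_step(self, batch, batch_idx):
            loss = self.compute_loss(batch)  # hypothetical helper
            self.log("loss", loss, on_step=True, on_epoch=False)  # step-level
            self._sum += loss.detach().item()  # detach: no graph kept alive
            self._count += 1
            return loss

        def on_train_epoch_end(self):
            if self._count > 0:
                self.log("train/loss", self._sum / self._count)  # epoch average

Averaging on the Python side keeps only two scalars per epoch and pairs naturally with the explicit "epoch" metric logged above.
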
@@ -362,13 +436,51 @@ class LightningModel(L.LightningModule):
             lr_scheduler_cfg.step_lr = StepLRConfig()
         elif self.lr_scheduler == "reduce_lr_on_plateau":
             lr_scheduler_cfg.reduce_lr_on_plateau = ReduceLROnPlateauConfig()
+        elif self.lr_scheduler == "cosine_annealing_warmup":
+            lr_scheduler_cfg.cosine_annealing_warmup = CosineAnnealingWarmupConfig()
+        elif self.lr_scheduler == "linear_warmup_linear_decay":
+            lr_scheduler_cfg.linear_warmup_linear_decay = (
+                LinearWarmupLinearDecayConfig()
+            )
 
         elif isinstance(self.lr_scheduler, dict):
             lr_scheduler_cfg = self.lr_scheduler
 
             for k, v in self.lr_scheduler.items():
                 if v is not None:
-                    if k == "step_lr":
+                    if k == "cosine_annealing_warmup":
+                        cfg = self.lr_scheduler.cosine_annealing_warmup
+                        # Use trainer's max_epochs if not specified in config
+                        max_epochs = (
+                            cfg.max_epochs
+                            if cfg.max_epochs is not None
+                            else self.trainer.max_epochs
+                        )
+                        scheduler = LinearWarmupCosineAnnealingLR(
+                            optimizer=optimizer,
+                            warmup_epochs=cfg.warmup_epochs,
+                            max_epochs=max_epochs,
+                            warmup_start_lr=cfg.warmup_start_lr,
+                            eta_min=cfg.eta_min,
+                        )
+                        break
+                    elif k == "linear_warmup_linear_decay":
+                        cfg = self.lr_scheduler.linear_warmup_linear_decay
+                        # Use trainer's max_epochs if not specified in config
+                        max_epochs = (
+                            cfg.max_epochs
+                            if cfg.max_epochs is not None
+                            else self.trainer.max_epochs
+                        )
+                        scheduler = LinearWarmupLinearDecayLR(
+                            optimizer=optimizer,
+                            warmup_epochs=cfg.warmup_epochs,
+                            max_epochs=max_epochs,
+                            warmup_start_lr=cfg.warmup_start_lr,
+                            end_lr=cfg.end_lr,
+                        )
+                        break
+                    elif k == "step_lr":
                         scheduler = torch.optim.lr_scheduler.StepLR(
                             optimizer=optimizer,
                             step_size=self.lr_scheduler.step_lr.step_size,
@@ -396,7 +508,7 @@ class LightningModel(L.LightningModule):
             "optimizer": optimizer,
             "lr_scheduler": {
                 "scheduler": scheduler,
-                "monitor": "val_loss",
+                "monitor": "val/loss",
             },
         }
 
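Note: LinearWarmupCosineAnnealingLR and LinearWarmupLinearDecayLR live in the new sleap_nn/training/schedulers.py (+191 lines, not shown in this file's diff). A rough sketch of the warmup-plus-cosine schedule expressed with torch.optim.lr_scheduler.LambdaLR, assuming per-epoch stepping; parameter names mirror the config fields above, but the real implementation may differ:

    import math
    from torch.optim.lr_scheduler import LambdaLR

    def warmup_cosine(optimizer, warmup_epochs, max_epochs, base_lr,
                      warmup_start_lr=1e-8, eta_min=0.0):
        def factor(epoch):
            if epoch < warmup_epochs:
                # linear ramp: warmup_start_lr -> base_lr
                lr = warmup_start_lr + (base_lr - warmup_start_lr) * epoch / max(1, warmup_epochs)
            else:
                # cosine decay: base_lr -> eta_min over the remaining epochs
                t = (epoch - warmup_epochs) / max(1, max_epochs - warmup_epochs)
                lr = eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t))
            return lr / base_lr  # LambdaLR multiplies the optimizer's base lr
        return LambdaLR(optimizer, factor)

Falling back to self.trainer.max_epochs when the config leaves max_epochs unset keeps the decay horizon in sync with the actual run length.
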
@@ -493,8 +605,15 @@ class SingleInstanceLightningModule(LightningModel):
         )
         self.node_names = self.head_configs.single_instance.confmaps.part_names
 
-    def visualize_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
+    def get_visualization_data(self, sample) -> VisualizationData:
+        """Extract visualization data from a sample.
+
+        Args:
+            sample: A sample dictionary from the data pipeline.
+
+        Returns:
+            VisualizationData containing image, confmaps, peaks, etc.
+        """
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -502,27 +621,41 @@ class SingleInstanceLightningModule(LightningModel):
             ex[k] = v.to(device=self.device)
         ex["image"] = ex["image"].unsqueeze(dim=0)
         output = self.single_instance_inf_layer(ex)[0]
+
         peaks = output["pred_instance_peaks"].cpu().numpy()
-        img = (
-            output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        peak_values = output["pred_peak_values"].cpu().numpy()
+        img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
         gt_instances = ex["instances"][0].cpu().numpy()
-        confmaps = (
-            output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=peak_values,
+            gt_instances=gt_instances,
+            node_names=list(self.node_names) if self.node_names else [],
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=True,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
-        plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
-        plot_peaks(gt_instances, peaks, paired=True)
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig
 
     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         return self.model(img)["SingleInstanceConfmapsHead"]
 
     def training_step(self, batch, batch_idx):
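Note: VisualizationData is defined in sleap_nn/training/utils.py, whose hunks are not shown on this page. Its fields can be read off the call sites in this diff; a plausible sketch, not the actual definition:

    from dataclasses import dataclass
    from typing import List, Optional
    import numpy as np

    @dataclass
    class VisualizationData:
        image: np.ndarray              # (H, W, C), from the inference output
        pred_confmaps: np.ndarray      # (h, w, n_nodes)
        pred_peaks: np.ndarray         # predicted keypoints
        pred_peak_values: np.ndarray   # peak confidences
        gt_instances: np.ndarray       # ground-truth keypoints
        node_names: List[str]
        output_scale: float            # confmap height / image height
        is_paired: bool                # True when GT and predictions align 1:1
        pred_pafs: Optional[np.ndarray] = None        # bottom-up only
        pred_class_maps: Optional[np.ndarray] = None  # multi-class bottom-up only

Splitting extraction (get_visualization_data) from rendering (visualize_example) lets callbacks reuse the same arrays for wandb tables or custom plots without rerunning inference.
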
@@ -531,6 +664,7 @@ class SingleInstanceLightningModule(LightningModel):
             torch.squeeze(batch["image"], dim=1),
             torch.squeeze(batch["confidence_maps"], dim=1),
         )
+        X = normalize_on_gpu(X)
 
         y_preds = self.model(X)["SingleInstanceConfmapsHead"]
 
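Note: normalization now happens on-device inside forward/training_step/validation_step rather than in the data pipeline, so frames can reach the GPU before conversion. normalize_on_gpu comes from sleap_nn/data/normalization.py (+45 -2 in this release); a minimal sketch of the assumed behavior:

    import torch

    def normalize_on_gpu(img: torch.Tensor) -> torch.Tensor:
        # Convert uint8 [0, 255] frames to float32 [0, 1] without leaving the device.
        if img.dtype == torch.uint8:
            img = img.to(torch.float32) / 255.0
        return img

Shipping uint8 tensors through the dataloader cuts host-to-device transfer to a quarter of the float32 size.
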
@@ -554,23 +688,24 @@ class SingleInstanceLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"{name}",
+                f"train/confmaps/{name}",
                 channel_wise_loss[node_idx],
-                prog_bar=True,
-                on_step=True,
+                prog_bar=False,
+                on_step=False,
                 on_epoch=True,
-                logger=True,
                 sync_dist=True,
             )
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "train_loss",
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
-            on_epoch=True,
-            logger=True,
+            on_epoch=False,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -579,6 +714,7 @@ class SingleInstanceLightningModule(LightningModel):
             torch.squeeze(batch["image"], dim=1),
             torch.squeeze(batch["confidence_maps"], dim=1),
         )
+        X = normalize_on_gpu(X)
 
         y_preds = self.model(X)["SingleInstanceConfmapsHead"]
         val_loss = nn.MSELoss()(y_preds, y)
@@ -592,26 +728,60 @@ class SingleInstanceLightningModule(LightningModel):
             loss_scale=self.loss_scale,
         )
         val_loss = val_loss + ohkm_loss
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
         self.log(
-            "learning_rate",
-            lr,
-            prog_bar=True,
-            on_step=True,
-            on_epoch=True,
-            logger=True,
-            sync_dist=True,
-        )
-        self.log(
-            "val_loss",
+            "val/loss",
             val_loss,
             prog_bar=True,
-            on_step=True,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
 
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                # Squeeze n_samples dim from image for inference (batch, 1, C, H, W) -> (batch, C, H, W)
+                inference_batch = {k: v for k, v in batch.items()}
+                if inference_batch["image"].ndim == 5:
+                    inference_batch["image"] = inference_batch["image"].squeeze(1)
+                inference_output = self.single_instance_inf_layer(inference_batch)
+                if isinstance(inference_output, list):
+                    inference_output = inference_output[0]
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are already in original image space (inference divides by eff_scale)
+                pred_peaks = inference_output["pred_instance_peaks"][i].cpu().numpy()
+                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+                # Transform GT from preprocessed to original image space
+                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+                gt_prep = batch["instances"][i].cpu().numpy()
+                if gt_prep.ndim == 4:
+                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
+                gt_orig = gt_prep / eff
+                num_inst = batch["num_instances"][i].item()
+                gt_orig = gt_orig[:num_inst]  # Only valid instances
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks,
+                        "pred_scores": pred_scores,
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_orig,
+                        "num_instances": num_inst,
+                    }
+                )
+
 
 class TopDownCenteredInstanceLightningModule(LightningModel):
     """Lightning Module for TopDownCenteredInstance Model.
@@ -705,8 +875,8 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
 
         self.node_names = self.head_configs.centered_instance.confmaps.part_names
 
-    def visualize_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
+    def get_visualization_data(self, sample) -> VisualizationData:
+        """Extract visualization data from a sample."""
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -714,27 +884,41 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             ex[k] = v.to(device=self.device)
         ex["instance_image"] = ex["instance_image"].unsqueeze(dim=0)
         output = self.instance_peaks_inf_layer(ex)
+
         peaks = output["pred_instance_peaks"].cpu().numpy()
-        img = (
-            output["instance_image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        peak_values = output["pred_peak_values"].cpu().numpy()
+        img = output["instance_image"][0, 0].cpu().numpy().transpose(1, 2, 0)
         gt_instances = ex["instance"].cpu().numpy()
-        confmaps = (
-            output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=peak_values,
+            gt_instances=gt_instances,
+            node_names=list(self.node_names) if self.node_names else [],
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=True,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
-        plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
-        plot_peaks(gt_instances, peaks, paired=True)
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig
 
     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         return self.model(img)["CenteredInstanceConfmapsHead"]
 
     def training_step(self, batch, batch_idx):
@@ -743,6 +927,7 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             torch.squeeze(batch["instance_image"], dim=1),
             torch.squeeze(batch["confidence_maps"], dim=1),
         )
+        X = normalize_on_gpu(X)
 
         y_preds = self.model(X)["CenteredInstanceConfmapsHead"]
 
@@ -766,24 +951,25 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"{name}",
+                f"train/confmaps/{name}",
                 channel_wise_loss[node_idx],
-                prog_bar=True,
-                on_step=True,
+                prog_bar=False,
+                on_step=False,
                 on_epoch=True,
-                logger=True,
                 sync_dist=True,
             )
 
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "train_loss",
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
-            on_epoch=True,
-            logger=True,
+            on_epoch=False,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -792,6 +978,7 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             torch.squeeze(batch["instance_image"], dim=1),
             torch.squeeze(batch["confidence_maps"], dim=1),
         )
+        X = normalize_on_gpu(X)
 
         y_preds = self.model(X)["CenteredInstanceConfmapsHead"]
         val_loss = nn.MSELoss()(y_preds, y)
@@ -805,26 +992,71 @@ class TopDownCenteredInstanceLightningModule(LightningModel):
             loss_scale=self.loss_scale,
         )
         val_loss = val_loss + ohkm_loss
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
-        self.log(
-            "learning_rate",
-            lr,
-            prog_bar=True,
-            on_step=True,
-            on_epoch=True,
-            logger=True,
-            sync_dist=True,
-        )
         self.log(
-            "val_loss",
+            "val/loss",
             val_loss,
             prog_bar=True,
-            on_step=True,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
 
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            # SAVE bbox BEFORE inference (it modifies in-place!)
+            bbox_prep_saved = batch["instance_bbox"].clone()
+
+            with torch.no_grad():
+                inference_output = self.instance_peaks_inf_layer(batch)
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions from inference (crop-relative, original scale)
+                pred_peaks_crop = (
+                    inference_output["pred_instance_peaks"][i].cpu().numpy()
+                )
+                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+                # Compute bbox offset in original space from SAVED prep bbox
+                # bbox has shape (n_samples=1, 4, 2) where 4 corners
+                bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy()  # (4, 2)
+                bbox_top_left_orig = (
+                    bbox_prep[0] / eff
+                )  # Top-left corner in original space
+
+                # Full image coordinates (original space)
+                pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
+
+                # GT transform: crop-relative preprocessed -> full image original
+                gt_crop_prep = (
+                    batch["instance"][i].squeeze(0).cpu().numpy()
+                )  # (n_nodes, 2)
+                gt_crop_orig = gt_crop_prep / eff
+                gt_full_orig = gt_crop_orig + bbox_top_left_orig
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks_full.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "pred_scores": pred_scores.reshape(1, -1),  # (1, n_nodes)
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_full_orig.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "num_instances": 1,
+                    }
+                )
+
 
 class CentroidLightningModule(LightningModel):
     """Lightning Module for Centroid Model.
@@ -916,9 +1148,10 @@ class CentroidLightningModule(LightningModel):
             output_stride=self.head_configs.centroid.confmaps.output_stride,
             input_scale=1.0,
         )
+        self.node_names = ["centroid"]
 
-    def visualize_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
+    def get_visualization_data(self, sample) -> VisualizationData:
+        """Extract visualization data from a sample."""
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -927,26 +1160,40 @@ class CentroidLightningModule(LightningModel):
         ex["image"] = ex["image"].unsqueeze(dim=0)
         gt_centroids = ex["centroids"].cpu().numpy()
         output = self.centroid_inf_layer(ex)
+
         peaks = output["centroids"][0].cpu().numpy()
-        img = (
-            output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
-        confmaps = (
-            output["pred_centroid_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        centroid_vals = output["centroid_vals"][0].cpu().numpy()
+        img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
+        confmaps = output["pred_centroid_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=centroid_vals,
+            gt_instances=gt_centroids,
+            node_names=self.node_names,
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=False,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
-        plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
-        plot_peaks(gt_centroids, peaks, paired=False)
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig
 
     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         return self.model(img)["CentroidConfmapsHead"]
 
     def training_step(self, batch, batch_idx):
@@ -955,18 +1202,21 @@ class CentroidLightningModule(LightningModel):
             torch.squeeze(batch["image"], dim=1),
             torch.squeeze(batch["centroids_confidence_maps"], dim=1),
         )
+        X = normalize_on_gpu(X)
 
         y_preds = self.model(X)["CentroidConfmapsHead"]
         loss = nn.MSELoss()(y_preds, y)
+        # Log step-level loss (every batch, uses global_step x-axis)
        self.log(
-            "train_loss",
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
-            on_epoch=True,
-            logger=True,
+            on_epoch=False,
             sync_dist=True,
         )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -975,29 +1225,74 @@ class CentroidLightningModule(LightningModel):
             torch.squeeze(batch["image"], dim=1),
             torch.squeeze(batch["centroids_confidence_maps"], dim=1),
         )
+        X = normalize_on_gpu(X)
 
         y_preds = self.model(X)["CentroidConfmapsHead"]
         val_loss = nn.MSELoss()(y_preds, y)
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
         self.log(
-            "learning_rate",
-            lr,
-            prog_bar=True,
-            on_step=True,
-            on_epoch=True,
-            logger=True,
-            sync_dist=True,
-        )
-        self.log(
-            "val_loss",
+            "val/loss",
             val_loss,
             prog_bar=True,
-            on_step=True,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
 
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            # Save GT centroids before inference (inference overwrites batch["centroids"])
+            batch["gt_centroids"] = batch["centroids"].clone()
+
+            with torch.no_grad():
+                inference_output = self.centroid_inf_layer(batch)
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are in original image space (inference divides by eff_scale)
+                # centroids shape: (batch, 1, max_instances, 2) - squeeze to (max_instances, 2)
+                pred_centroids = (
+                    inference_output["centroids"][i].squeeze(0).cpu().numpy()
+                )
+                pred_vals = inference_output["centroid_vals"][i].cpu().numpy()
+
+                # Transform GT centroids from preprocessed to original image space
+                # Use "gt_centroids" since inference overwrites "centroids" with predictions
+                gt_centroids_prep = (
+                    batch["gt_centroids"][i].cpu().numpy()
+                )  # (n_samples=1, max_inst, 2)
+                gt_centroids_orig = gt_centroids_prep.squeeze(0) / eff  # (max_inst, 2)
+                num_inst = batch["num_instances"][i].item()
+
+                # Filter to valid instances (non-NaN)
+                valid_pred_mask = ~np.isnan(pred_centroids).any(axis=1)
+                pred_centroids = pred_centroids[valid_pred_mask]
+                pred_vals = pred_vals[valid_pred_mask]
+
+                gt_centroids_valid = gt_centroids_orig[:num_inst]
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_centroids.reshape(
+                            -1, 1, 2
+                        ),  # (n_inst, 1, 2)
+                        "pred_scores": pred_vals.reshape(-1, 1),  # (n_inst, 1)
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_centroids_valid.reshape(
+                            -1, 1, 2
+                        ),  # (n_inst, 1, 2)
+                        "num_instances": num_inst,
+                    }
+                )
+
 
 class BottomUpLightningModule(LightningModel):
     """Lightning Module for BottomUp Model.
@@ -1090,16 +1385,20 @@ class BottomUpLightningModule(LightningModel):
         self.bottomup_inf_layer = BottomUpInferenceModel(
             torch_model=self.forward,
             paf_scorer=paf_scorer,
-            peak_threshold=0.2,
+            peak_threshold=0.1,  # Lower threshold for epoch-end eval during training
             input_scale=1.0,
             return_confmaps=True,
             return_pafs=True,
             cms_output_stride=self.head_configs.bottomup.confmaps.output_stride,
             pafs_output_stride=self.head_configs.bottomup.pafs.output_stride,
+            max_peaks_per_node=100,  # Prevents combinatorial explosion in early training
         )
+        self.node_names = list(self.head_configs.bottomup.confmaps.part_names)
 
-    def visualize_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
+    def get_visualization_data(
+        self, sample, include_pafs: bool = False
+    ) -> VisualizationData:
+        """Extract visualization data from a sample."""
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -1107,54 +1406,65 @@ class BottomUpLightningModule(LightningModel):
             ex[k] = v.to(device=self.device)
         ex["image"] = ex["image"].unsqueeze(dim=0)
         output = self.bottomup_inf_layer(ex)[0]
+
         peaks = output["pred_instance_peaks"][0].cpu().numpy()
-        img = (
-            output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        peak_values = output["pred_peak_values"][0].cpu().numpy()
+        img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
         gt_instances = ex["instances"][0].cpu().numpy()
-        confmaps = (
-            output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        pred_pafs = None
+        if include_pafs:
+            pafs = output["pred_part_affinity_fields"].cpu().numpy()[0]
+            pred_pafs = pafs  # (h, w, 2*edges)
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=peak_values,
+            gt_instances=gt_instances,
+            node_names=self.node_names,
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=False,
+            pred_pafs=pred_pafs,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
-        plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
         plt.xlim(plt.xlim())
         plt.ylim(plt.ylim())
-        plot_peaks(gt_instances, peaks, paired=False)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig
 
     def visualize_pafs_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
-        ex = sample.copy()
-        ex["eff_scale"] = torch.tensor([1.0])
-        for k, v in ex.items():
-            if isinstance(v, torch.Tensor):
-                ex[k] = v.to(device=self.device)
-        ex["image"] = ex["image"].unsqueeze(dim=0)
-        output = self.bottomup_inf_layer(ex)[0]
-        img = (
-            output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
-        pafs = output["pred_part_affinity_fields"].cpu().numpy()[0]  # (h, w, 2*edges)
+        """Visualize PAF predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample, include_pafs=True)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
 
+        pafs = data.pred_pafs
         pafs = pafs.reshape((pafs.shape[0], pafs.shape[1], -1, 2))
         pafs_mag = np.sqrt(pafs[..., 0] ** 2 + pafs[..., 1] ** 2)
-        plot_confmaps(pafs_mag, output_scale=pafs_mag.shape[0] / img.shape[0])
+        plot_confmaps(pafs_mag, output_scale=pafs_mag.shape[0] / data.image.shape[0])
         return fig
 
     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         output = self.model(img)
         return {
             "MultiInstanceConfmapsHead": output["MultiInstanceConfmapsHead"],
@@ -1166,6 +1476,7 @@ class BottomUpLightningModule(LightningModel):
         X = torch.squeeze(batch["image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_paf = batch["part_affinity_fields"]
+        X = normalize_on_gpu(X)
         preds = self.model(X)
         pafs = preds["PartAffinityFieldsHead"]
         confmaps = preds["MultiInstanceConfmapsHead"]
@@ -1198,13 +1509,29 @@ class BottomUpLightningModule(LightningModel):
             "PartAffinityFieldsHead": pafs_loss,
         }
         loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "train_loss",
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
+            on_epoch=False,
+            sync_dist=True,
+        )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
+        self.log(
+            "train/confmaps_loss",
+            confmap_loss,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train/paf_loss",
+            pafs_loss,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
         return loss
@@ -1214,6 +1541,7 @@ class BottomUpLightningModule(LightningModel):
         X = torch.squeeze(batch["image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_paf = batch["part_affinity_fields"]
+        X = normalize_on_gpu(X)
 
         preds = self.model(X)
         pafs = preds["PartAffinityFieldsHead"]
@@ -1248,25 +1576,75 @@ class BottomUpLightningModule(LightningModel):
         }
 
         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
         self.log(
-            "learning_rate",
-            lr,
+            "val/loss",
+            val_loss,
             prog_bar=True,
-            on_step=True,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
         self.log(
-            "val_loss",
-            val_loss,
-            prog_bar=True,
-            on_step=True,
+            "val/confmaps_loss",
+            confmap_loss,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
+        self.log(
+            "val/paf_loss",
+            pafs_loss,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                # Note: Do NOT squeeze the image here - the forward() method expects
+                # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
+                inference_output = self.bottomup_inf_layer(batch)
+                if isinstance(inference_output, list):
+                    inference_output = inference_output[0]
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are already in original space (variable number of instances)
+                pred_peaks = inference_output["pred_instance_peaks"][i]
+                pred_scores = inference_output["pred_peak_values"][i]
+                if torch.is_tensor(pred_peaks):
+                    pred_peaks = pred_peaks.cpu().numpy()
+                if torch.is_tensor(pred_scores):
+                    pred_scores = pred_scores.cpu().numpy()
+
+                # Transform GT to original space
+                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+                gt_prep = batch["instances"][i].cpu().numpy()
+                if gt_prep.ndim == 4:
+                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
+                gt_orig = gt_prep / eff
+                num_inst = batch["num_instances"][i].item()
+                gt_orig = gt_orig[:num_inst]  # Only valid instances
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks,  # Original space, variable instances
+                        "pred_scores": pred_scores,
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_orig,  # Original space
+                        "num_instances": num_inst,
+                    }
+                )
 
 
 class BottomUpMultiClassLightningModule(LightningModel):
@@ -1361,9 +1739,14 @@ class BottomUpMultiClassLightningModule(LightningModel):
             cms_output_stride=self.head_configs.multi_class_bottomup.confmaps.output_stride,
             class_maps_output_stride=self.head_configs.multi_class_bottomup.class_maps.output_stride,
         )
+        self.node_names = list(
+            self.head_configs.multi_class_bottomup.confmaps.part_names
+        )
 
-    def visualize_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
+    def get_visualization_data(
+        self, sample, include_class_maps: bool = False
+    ) -> VisualizationData:
+        """Extract visualization data from a sample."""
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -1371,54 +1754,65 @@ class BottomUpMultiClassLightningModule(LightningModel):
             ex[k] = v.to(device=self.device)
         ex["image"] = ex["image"].unsqueeze(dim=0)
         output = self.bottomup_inf_layer(ex)[0]
+
         peaks = output["pred_instance_peaks"][0].cpu().numpy()
-        img = (
-            output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        peak_values = output["pred_peak_values"][0].cpu().numpy()
+        img = output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
         gt_instances = ex["instances"][0].cpu().numpy()
-        confmaps = (
-            output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        pred_class_maps = None
+        if include_class_maps:
+            pred_class_maps = (
+                output["pred_class_maps"].cpu().numpy()[0].transpose(1, 2, 0)
+            )
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=peak_values,
+            gt_instances=gt_instances,
+            node_names=self.node_names,
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=False,
+            pred_class_maps=pred_class_maps,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
-        plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
         plt.xlim(plt.xlim())
         plt.ylim(plt.ylim())
-        plot_peaks(gt_instances, peaks, paired=False)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig
 
     def visualize_class_maps_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
-        ex = sample.copy()
-        ex["eff_scale"] = torch.tensor([1.0])
-        for k, v in ex.items():
-            if isinstance(v, torch.Tensor):
-                ex[k] = v.to(device=self.device)
-        ex["image"] = ex["image"].unsqueeze(dim=0)
-        output = self.bottomup_inf_layer(ex)[0]
-        img = (
-            output["image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
-        classmaps = (
-            output["pred_class_maps"].cpu().numpy()[0].transpose(1, 2, 0)
-        )  # (n_classes, h, w)
+        """Visualize class map predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample, include_class_maps=True)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
-
-        plot_confmaps(classmaps, output_scale=classmaps.shape[0] / img.shape[0])
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(
+            data.pred_class_maps,
+            output_scale=data.pred_class_maps.shape[0] / data.image.shape[0],
+        )
         return fig
 
     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         output = self.model(img)
         return {
             "MultiInstanceConfmapsHead": output["MultiInstanceConfmapsHead"],
@@ -1430,6 +1824,7 @@ class BottomUpMultiClassLightningModule(LightningModel):
         X = torch.squeeze(batch["image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_classmap = torch.squeeze(batch["class_maps"], dim=1)
+        X = normalize_on_gpu(X)
         preds = self.model(X)
         classmaps = preds["ClassMapsHead"]
         confmaps = preds["MultiInstanceConfmapsHead"]
@@ -1453,15 +1848,84 @@ class BottomUpMultiClassLightningModule(LightningModel):
             "ClassMapsHead": classmaps_loss,
         }
         loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "train_loss",
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
+            on_epoch=False,
+            sync_dist=True,
+        )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
+        self.log(
+            "train/confmaps_loss",
+            confmap_loss,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train/classmap_loss",
+            classmaps_loss,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
+
+        # Compute classification accuracy at GT keypoint locations
+        with torch.no_grad():
+            # Get output stride for class maps
+            cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
+
+            # Get GT instances and sample class maps at those locations
+            instances = batch["instances"]  # (batch, n_samples, max_inst, n_nodes, 2)
+            if instances.dim() == 5:
+                instances = instances.squeeze(1)  # (batch, max_inst, n_nodes, 2)
+            num_instances = batch["num_instances"]  # (batch,)
+
+            correct = 0
+            total = 0
+            for b in range(instances.shape[0]):
+                n_inst = num_instances[b].item()
+                for inst_idx in range(n_inst):
+                    for node_idx in range(instances.shape[2]):
+                        # Get keypoint location (in input image space)
+                        kp = instances[b, inst_idx, node_idx]  # (2,) = (x, y)
+                        if torch.isnan(kp).any():
+                            continue
+
+                        # Convert to class map space
+                        x_cm = (
+                            (kp[0] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-1] - 1)
+                        )
+                        y_cm = (
+                            (kp[1] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-2] - 1)
+                        )
+
+                        # Sample predicted and GT class at this location
+                        pred_class = classmaps[b, :, y_cm, x_cm].argmax()
+                        gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
+
+                        if pred_class == gt_class:
+                            correct += 1
+                        total += 1
+
+            if total > 0:
+                class_accuracy = torch.tensor(correct / total, device=X.device)
+                self.log(
+                    "train/class_accuracy",
+                    class_accuracy,
+                    on_step=False,
+                    on_epoch=True,
+                    sync_dist=True,
+                )
+
         return loss
 
     def validation_step(self, batch, batch_idx):
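The accuracy check above samples class maps one keypoint at a time in nested Python loops. An equivalent vectorized form for a single sample — a sketch with the same semantics for valid keypoints; tensor shapes in the comments are assumptions drawn from this diff:

    import torch

    def class_accuracy_at_keypoints(classmaps, gt_classmaps, kps, stride):
        # classmaps, gt_classmaps: (n_classes, H, W); kps: (n_kp, 2) as (x, y)
        valid = ~torch.isnan(kps).any(dim=1)
        kps = kps[valid]
        x = (kps[:, 0] / stride).long().clamp(0, classmaps.shape[-1] - 1)
        y = (kps[:, 1] / stride).long().clamp(0, classmaps.shape[-2] - 1)
        pred = classmaps[:, y, x].argmax(dim=0)   # gather all locations at once
        gt = gt_classmaps[:, y, x].argmax(dim=0)
        return (pred == gt).float().mean()

On GPU this avoids a host sync per keypoint from the per-element checks in the inner loop.
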
@@ -1469,6 +1933,7 @@ class BottomUpMultiClassLightningModule(LightningModel):
         X = torch.squeeze(batch["image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_classmap = torch.squeeze(batch["class_maps"], dim=1)
+        X = normalize_on_gpu(X)
 
         preds = self.model(X)
         classmaps = preds["ClassMapsHead"]
@@ -1494,26 +1959,128 @@ class BottomUpMultiClassLightningModule(LightningModel):
         }
 
         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
         self.log(
-            "learning_rate",
-            lr,
+            "val/loss",
+            val_loss,
             prog_bar=True,
-            on_step=True,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
         self.log(
-            "val_loss",
-            val_loss,
-            prog_bar=True,
-            on_step=True,
+            "val/confmaps_loss",
+            confmap_loss,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+        self.log(
+            "val/classmap_loss",
+            classmaps_loss,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
 
+        # Compute classification accuracy at GT keypoint locations
+        with torch.no_grad():
+            # Get output stride for class maps
+            cms_stride = self.head_configs.multi_class_bottomup.class_maps.output_stride
+
+            # Get GT instances and sample class maps at those locations
+            instances = batch["instances"]  # (batch, n_samples, max_inst, n_nodes, 2)
+            if instances.dim() == 5:
+                instances = instances.squeeze(1)  # (batch, max_inst, n_nodes, 2)
+            num_instances = batch["num_instances"]  # (batch,)
+
+            correct = 0
+            total = 0
+            for b in range(instances.shape[0]):
+                n_inst = num_instances[b].item()
+                for inst_idx in range(n_inst):
+                    for node_idx in range(instances.shape[2]):
+                        # Get keypoint location (in input image space)
+                        kp = instances[b, inst_idx, node_idx]  # (2,) = (x, y)
+                        if torch.isnan(kp).any():
+                            continue
+
+                        # Convert to class map space
+                        x_cm = (
+                            (kp[0] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-1] - 1)
+                        )
+                        y_cm = (
+                            (kp[1] / cms_stride)
+                            .long()
+                            .clamp(0, classmaps.shape[-2] - 1)
+                        )
+
+                        # Sample predicted and GT class at this location
+                        pred_class = classmaps[b, :, y_cm, x_cm].argmax()
+                        gt_class = y_classmap[b, :, y_cm, x_cm].argmax()
+
+                        if pred_class == gt_class:
+                            correct += 1
+                        total += 1
+
+            if total > 0:
+                class_accuracy = torch.tensor(correct / total, device=X.device)
+                self.log(
+                    "val/class_accuracy",
+                    class_accuracy,
+                    on_step=False,
+                    on_epoch=True,
+                    sync_dist=True,
+                )
+
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            with torch.no_grad():
+                # Note: Do NOT squeeze the image here - the forward() method expects
+                # (batch, n_samples, C, H, W) and handles the n_samples squeeze internally
+                inference_output = self.bottomup_inf_layer(batch)
+                if isinstance(inference_output, list):
+                    inference_output = inference_output[0]
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions are already in original space (variable number of instances)
+                pred_peaks = inference_output["pred_instance_peaks"][i]
+                pred_scores = inference_output["pred_peak_values"][i]
+                if torch.is_tensor(pred_peaks):
+                    pred_peaks = pred_peaks.cpu().numpy()
+                if torch.is_tensor(pred_scores):
+                    pred_scores = pred_scores.cpu().numpy()
+
+                # Transform GT to original space
+                # Note: instances have shape (1, max_inst, n_nodes, 2) - squeeze n_samples dim
+                gt_prep = batch["instances"][i].cpu().numpy()
+                if gt_prep.ndim == 4:
+                    gt_prep = gt_prep.squeeze(0)  # (max_inst, n_nodes, 2)
+                gt_orig = gt_prep / eff
+                num_inst = batch["num_instances"][i].item()
+                gt_orig = gt_orig[:num_inst]  # Only valid instances
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks,  # Original space, variable instances
+                        "pred_scores": pred_scores,
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_orig,  # Original space
+                        "num_instances": num_inst,
+                    }
+                )
+
 
 class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
     """Lightning Module for TopDownCenteredInstance ID Model.
@@ -1607,8 +2174,8 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
 
         self.node_names = self.head_configs.multi_class_topdown.confmaps.part_names
 
-    def visualize_example(self, sample):
-        """Visualize predictions during training (used with callbacks)."""
+    def get_visualization_data(self, sample) -> VisualizationData:
+        """Extract visualization data from a sample."""
         ex = sample.copy()
         ex["eff_scale"] = torch.tensor([1.0])
         for k, v in ex.items():
@@ -1616,27 +2183,41 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             ex[k] = v.to(device=self.device)
         ex["instance_image"] = ex["instance_image"].unsqueeze(dim=0)
         output = self.instance_peaks_inf_layer(ex)
+
         peaks = output["pred_instance_peaks"].cpu().numpy()
-        img = (
-            output["instance_image"][0, 0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        peak_values = output["pred_peak_values"].cpu().numpy()
+        img = output["instance_image"][0, 0].cpu().numpy().transpose(1, 2, 0)
         gt_instances = ex["instance"].cpu().numpy()
-        confmaps = (
-            output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
-        )  # convert from (C, H, W) to (H, W, C)
+        confmaps = output["pred_confmaps"][0].cpu().numpy().transpose(1, 2, 0)
+
+        return VisualizationData(
+            image=img,
+            pred_confmaps=confmaps,
+            pred_peaks=peaks,
+            pred_peak_values=peak_values,
+            gt_instances=gt_instances,
+            node_names=list(self.node_names) if self.node_names else [],
+            output_scale=confmaps.shape[0] / img.shape[0],
+            is_paired=True,
+        )
+
+    def visualize_example(self, sample):
+        """Visualize predictions during training (used with callbacks)."""
+        data = self.get_visualization_data(sample)
         scale = 1.0
-        if img.shape[0] < 512:
+        if data.image.shape[0] < 512:
             scale = 2.0
-        if img.shape[0] < 256:
+        if data.image.shape[0] < 256:
             scale = 4.0
-        fig = plot_img(img, dpi=72 * scale, scale=scale)
-        plot_confmaps(confmaps, output_scale=confmaps.shape[0] / img.shape[0])
-        plot_peaks(gt_instances, peaks, paired=True)
+        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
+        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
+        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
         return fig
 
     def forward(self, img):
         """Forward pass of the model."""
         img = torch.squeeze(img, dim=1).to(self.device)
+        img = normalize_on_gpu(img)
         output = self.model(img)
         return {
             "CenteredInstanceConfmapsHead": output["CenteredInstanceConfmapsHead"],
@@ -1648,6 +2229,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         X = torch.squeeze(batch["instance_image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_classvector = batch["class_vectors"]
+        X = normalize_on_gpu(X)
         preds = self.model(X)
         classvector = preds["ClassVectorsHead"]
         confmaps = preds["CenteredInstanceConfmapsHead"]
@@ -1679,22 +2261,50 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         channel_wise_loss = torch.sum(mse, dim=(0, 2, 3)) / (batch_size * h * w)
         for node_idx, name in enumerate(self.node_names):
             self.log(
-                f"{name}",
+                f"train/confmaps/{name}",
                 channel_wise_loss[node_idx],
-                prog_bar=True,
-                on_step=True,
+                prog_bar=False,
+                on_step=False,
                 on_epoch=True,
-                logger=True,
                 sync_dist=True,
             )
 
+        # Log step-level loss (every batch, uses global_step x-axis)
         self.log(
-            "train_loss",
+            "loss",
             loss,
             prog_bar=True,
             on_step=True,
+            on_epoch=False,
+            sync_dist=True,
+        )
+        # Accumulate for epoch-averaged loss (logged in on_train_epoch_end)
+        self._accumulate_loss(loss)
+        self.log(
+            "train/confmaps_loss",
+            confmap_loss,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+        self.log(
+            "train/classvector_loss",
+            classvector_loss,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+
+        # Compute classification accuracy
+        with torch.no_grad():
+            pred_classes = torch.argmax(classvector, dim=1)
+            gt_classes = torch.argmax(y_classvector, dim=1)
+            class_accuracy = (pred_classes == gt_classes).float().mean()
+        self.log(
+            "train/class_accuracy",
+            class_accuracy,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
         return loss
@@ -1704,6 +2314,7 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
         X = torch.squeeze(batch["instance_image"], dim=1)
         y_confmap = torch.squeeze(batch["confidence_maps"], dim=1)
         y_classvector = batch["class_vectors"]
+        X = normalize_on_gpu(X)
         preds = self.model(X)
         classvector = preds["ClassVectorsHead"]
         confmaps = preds["CenteredInstanceConfmapsHead"]
@@ -1727,23 +2338,94 @@ class TopDownCenteredInstanceMultiClassLightningModule(LightningModel):
             "ClassVectorsHead": classvector_loss,
         }
         val_loss = sum([s * losses[t] for s, t in zip(self.loss_weights, losses)])
-
-        lr = self.optimizers().optimizer.param_groups[0]["lr"]
         self.log(
-            "learning_rate",
-            lr,
+            "val/loss",
+            val_loss,
             prog_bar=True,
-            on_step=True,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
         self.log(
-            "val_loss",
-            val_loss,
-            prog_bar=True,
-            on_step=True,
+            "val/confmaps_loss",
+            confmap_loss,
+            on_step=False,
             on_epoch=True,
-            logger=True,
             sync_dist=True,
         )
+        self.log(
+            "val/classvector_loss",
+            classvector_loss,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+
+        # Compute classification accuracy
+        with torch.no_grad():
+            pred_classes = torch.argmax(classvector, dim=1)
+            gt_classes = torch.argmax(y_classvector, dim=1)
+            class_accuracy = (pred_classes == gt_classes).float().mean()
+        self.log(
+            "val/class_accuracy",
+            class_accuracy,
+            on_step=False,
+            on_epoch=True,
+            sync_dist=True,
+        )
+
+        # Collect predictions for epoch-end evaluation if enabled
+        if self._collect_val_predictions:
+            # SAVE bbox BEFORE inference (it modifies in-place!)
+            bbox_prep_saved = batch["instance_bbox"].clone()
+
+            with torch.no_grad():
+                inference_output = self.instance_peaks_inf_layer(batch)
+
+            batch_size = len(batch["frame_idx"])
+            for i in range(batch_size):
+                eff = batch["eff_scale"][i].cpu().numpy()
+
+                # Predictions from inference (crop-relative, original scale)
+                pred_peaks_crop = (
+                    inference_output["pred_instance_peaks"][i].cpu().numpy()
+                )
+                pred_scores = inference_output["pred_peak_values"][i].cpu().numpy()
+
+                # Compute bbox offset in original space from SAVED prep bbox
+                # bbox has shape (n_samples=1, 4, 2) where 4 corners
+                bbox_prep = bbox_prep_saved[i].squeeze(0).cpu().numpy()  # (4, 2)
+                bbox_top_left_orig = (
+                    bbox_prep[0] / eff
+                )  # Top-left corner in original space
+
+                # Full image coordinates (original space)
+                pred_peaks_full = pred_peaks_crop + bbox_top_left_orig
+
+                # GT transform: crop-relative preprocessed -> full image original
+                gt_crop_prep = (
+                    batch["instance"][i].squeeze(0).cpu().numpy()
+                )  # (n_nodes, 2)
+                gt_crop_orig = gt_crop_prep / eff
+                gt_full_orig = gt_crop_orig + bbox_top_left_orig
+
+                self.val_predictions.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "pred_peaks": pred_peaks_full.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "pred_scores": pred_scores.reshape(1, -1),  # (1, n_nodes)
+                    }
+                )
+                self.val_ground_truth.append(
+                    {
+                        "video_idx": batch["video_idx"][i].item(),
+                        "frame_idx": batch["frame_idx"][i].item(),
+                        "gt_instances": gt_full_orig.reshape(
+                            1, -1, 2
+                        ),  # (1, n_nodes, 2)
+                        "num_instances": 1,
+                    }
+                )