sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/architectures/convnext.py +0 -5
  3. sleap_nn/architectures/encoder_decoder.py +6 -25
  4. sleap_nn/architectures/swint.py +0 -8
  5. sleap_nn/cli.py +60 -364
  6. sleap_nn/config/data_config.py +5 -11
  7. sleap_nn/config/get_config.py +4 -5
  8. sleap_nn/config/trainer_config.py +0 -71
  9. sleap_nn/data/augmentation.py +241 -50
  10. sleap_nn/data/custom_datasets.py +34 -364
  11. sleap_nn/data/instance_cropping.py +1 -1
  12. sleap_nn/data/resizing.py +2 -2
  13. sleap_nn/data/utils.py +17 -135
  14. sleap_nn/evaluation.py +22 -81
  15. sleap_nn/inference/bottomup.py +20 -86
  16. sleap_nn/inference/peak_finding.py +19 -88
  17. sleap_nn/inference/predictors.py +117 -224
  18. sleap_nn/legacy_models.py +11 -65
  19. sleap_nn/predict.py +9 -37
  20. sleap_nn/train.py +4 -69
  21. sleap_nn/training/callbacks.py +105 -1046
  22. sleap_nn/training/lightning_modules.py +65 -602
  23. sleap_nn/training/model_trainer.py +204 -201
  24. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA +3 -15
  25. sleap_nn-0.1.0a1.dist-info/RECORD +65 -0
  26. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/WHEEL +1 -1
  27. sleap_nn/data/skia_augmentation.py +0 -414
  28. sleap_nn/export/__init__.py +0 -21
  29. sleap_nn/export/cli.py +0 -1778
  30. sleap_nn/export/exporters/__init__.py +0 -51
  31. sleap_nn/export/exporters/onnx_exporter.py +0 -80
  32. sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
  33. sleap_nn/export/metadata.py +0 -225
  34. sleap_nn/export/predictors/__init__.py +0 -63
  35. sleap_nn/export/predictors/base.py +0 -22
  36. sleap_nn/export/predictors/onnx.py +0 -154
  37. sleap_nn/export/predictors/tensorrt.py +0 -312
  38. sleap_nn/export/utils.py +0 -307
  39. sleap_nn/export/wrappers/__init__.py +0 -25
  40. sleap_nn/export/wrappers/base.py +0 -96
  41. sleap_nn/export/wrappers/bottomup.py +0 -243
  42. sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
  43. sleap_nn/export/wrappers/centered_instance.py +0 -56
  44. sleap_nn/export/wrappers/centroid.py +0 -58
  45. sleap_nn/export/wrappers/single_instance.py +0 -83
  46. sleap_nn/export/wrappers/topdown.py +0 -180
  47. sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
  48. sleap_nn/inference/postprocessing.py +0 -284
  49. sleap_nn/training/schedulers.py +0 -191
  50. sleap_nn-0.1.0.dist-info/RECORD +0 -88
  51. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/entry_points.txt +0 -0
  52. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/licenses/LICENSE +0 -0
  53. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/top_level.txt +0 -0
@@ -85,15 +85,10 @@ class CSVLoggerCallback(Callback):
  if key == "epoch":
  log_data["epoch"] = trainer.current_epoch
  elif key == "learning_rate":
- # Handle multiple formats:
- # 1. Direct "learning_rate" key
- # 2. "train/lr" key (current format from lightning modules)
- # 3. "lr-*" keys from LearningRateMonitor (legacy)
+ # Handle both direct logging and LearningRateMonitor format (lr-*)
  value = metrics.get(key, None)
  if value is None:
- value = metrics.get("train/lr", None)
- if value is None:
- # Look for lr-* keys from LearningRateMonitor (legacy)
+ # Look for lr-* keys from LearningRateMonitor
  for metric_key in metrics.keys():
  if metric_key.startswith("lr-"):
  value = metrics[metric_key]
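The change above drops the intermediate `train/lr` lookup, reducing the learning-rate resolution to a two-step fallback: use a directly logged `learning_rate` value if present, otherwise scan for the `lr-*` keys written by Lightning's `LearningRateMonitor`. A minimal standalone sketch of that order (`resolve_learning_rate` is a hypothetical helper for illustration, not sleap-nn API):

```python
# Hypothetical helper illustrating the post-change lookup order; not sleap-nn API.
def resolve_learning_rate(metrics: dict):
    # 1. Prefer a directly logged "learning_rate" entry.
    value = metrics.get("learning_rate")
    if value is None:
        # 2. Fall back to the "lr-<optimizer>" keys that Lightning's
        #    LearningRateMonitor callback writes into the metrics dict.
        for key, v in metrics.items():
            if key.startswith("lr-"):
                value = v
                break
    return value

assert resolve_learning_rate({"lr-Adam": 1e-4}) == 1e-4
assert resolve_learning_rate({"learning_rate": 3e-4, "lr-Adam": 1e-4}) == 3e-4
```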
@@ -286,49 +281,45 @@ class WandBVizCallback(Callback):
 
  # Get the wandb logger to use its experiment for logging
  wandb_logger = self._get_wandb_logger(trainer)
+ if wandb_logger is None:
+ return # No wandb logger, skip visualization logging
+
+ # Get visualization data
+ train_data = self.train_viz_fn()
+ val_data = self.val_viz_fn()
+
+ # Render and log for each enabled mode
+ # Use the logger's experiment to let Lightning manage step tracking
+ log_dict = {}
+ for mode_name, renderer in self.renderers.items():
+ suffix = "" if mode_name == "direct" else f"_{mode_name}"
+ train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
+ val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
+ log_dict[f"train_predictions{suffix}"] = train_img
+ log_dict[f"val_predictions{suffix}"] = val_img
+
+ if log_dict:
+ # Include epoch so wandb can use it as x-axis (via define_metric)
+ log_dict["epoch"] = epoch
+ # Use commit=False to accumulate with other metrics in this step
+ # Lightning will commit when it logs its own metrics
+ wandb_logger.experiment.log(log_dict, commit=False)
+
+ # Optionally also log to table for backwards compat
+ if self.log_table and "direct" in self.renderers:
+ train_img = self.renderers["direct"].render(
+ train_data, caption=f"Train Epoch {epoch}"
+ )
+ val_img = self.renderers["direct"].render(
+ val_data, caption=f"Val Epoch {epoch}"
+ )
+ table = wandb.Table(
+ columns=["Epoch", "Train", "Validation"],
+ data=[[epoch, train_img, val_img]],
+ )
+ wandb_logger.experiment.log({"predictions_table": table}, commit=False)
 
- # Only do visualization work if wandb logger is available
- if wandb_logger is not None:
- # Get visualization data
- train_data = self.train_viz_fn()
- val_data = self.val_viz_fn()
-
- # Render and log for each enabled mode
- # Use the logger's experiment to let Lightning manage step tracking
- log_dict = {}
- for mode_name, renderer in self.renderers.items():
- suffix = "" if mode_name == "direct" else f"_{mode_name}"
- train_img = renderer.render(
- train_data, caption=f"Train Epoch {epoch}"
- )
- val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
- log_dict[f"viz/train/predictions{suffix}"] = train_img
- log_dict[f"viz/val/predictions{suffix}"] = val_img
-
- if log_dict:
- # Include epoch so wandb can use it as x-axis (via define_metric)
- log_dict["epoch"] = epoch
- # Use commit=False to accumulate with other metrics in this step
- # Lightning will commit when it logs its own metrics
- wandb_logger.experiment.log(log_dict, commit=False)
-
- # Optionally also log to table for backwards compat
- if self.log_table and "direct" in self.renderers:
- train_img = self.renderers["direct"].render(
- train_data, caption=f"Train Epoch {epoch}"
- )
- val_img = self.renderers["direct"].render(
- val_data, caption=f"Val Epoch {epoch}"
- )
- table = wandb.Table(
- columns=["Epoch", "Train", "Validation"],
- data=[[epoch, train_img, val_img]],
- )
- wandb_logger.experiment.log(
- {"predictions_table": table}, commit=False
- )
-
- # Sync all processes - barrier must be reached by ALL ranks
+ # Sync all processes
  trainer.strategy.barrier()
 
 
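Both the old and new code rely on wandb's deferred commit: entries logged with `commit=False` are staged and flushed together when the next committed log call (here, Lightning logging its own metrics) advances the step. A minimal sketch of the pattern outside of Lightning (project name and metric values are illustrative):

```python
import wandb

# Offline mode so the sketch runs without a network connection or API key.
run = wandb.init(project="viz-demo", mode="offline")
# Let "epoch" serve as the x-axis for prediction images, as the callback's
# comment suggests (via define_metric).
run.define_metric("epoch")
run.define_metric("train_predictions*", step_metric="epoch")

for epoch in range(3):
    # Stage entries without advancing wandb's internal step counter...
    run.log({"epoch": epoch, "train_predictions": 0.5}, commit=False)
    # ...then a normal log call commits everything staged for this step.
    run.log({"train/loss": 1.0 / (epoch + 1)})
run.finish()
```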
@@ -387,377 +378,80 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
 
  # Get the wandb logger to use its experiment for logging
  wandb_logger = self._get_wandb_logger(trainer)
-
- # Only do visualization work if wandb logger is available
- if wandb_logger is not None:
- # Get visualization data
- train_data = self.train_viz_fn()
- val_data = self.val_viz_fn()
- train_pafs_data = self.train_pafs_viz_fn()
- val_pafs_data = self.val_pafs_viz_fn()
-
- # Render and log for each enabled mode
- # Use the logger's experiment to let Lightning manage step tracking
- log_dict = {}
- for mode_name, renderer in self.renderers.items():
- suffix = "" if mode_name == "direct" else f"_{mode_name}"
- train_img = renderer.render(
- train_data, caption=f"Train Epoch {epoch}"
- )
- val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
- log_dict[f"viz/train/predictions{suffix}"] = train_img
- log_dict[f"viz/val/predictions{suffix}"] = val_img
-
- # Render PAFs (always use matplotlib/direct for PAFs)
- from io import BytesIO
- import matplotlib.pyplot as plt
- from PIL import Image
-
- train_pafs_fig = self._mpl_renderer.render_pafs(train_pafs_data)
- buf = BytesIO()
- train_pafs_fig.savefig(
- buf, format="png", bbox_inches="tight", pad_inches=0
- )
- buf.seek(0)
- plt.close(train_pafs_fig)
- train_pafs_pil = Image.open(buf)
- log_dict["viz/train/pafs"] = wandb.Image(
- train_pafs_pil, caption=f"Train PAFs Epoch {epoch}"
- )
-
- val_pafs_fig = self._mpl_renderer.render_pafs(val_pafs_data)
- buf = BytesIO()
- val_pafs_fig.savefig(
- buf, format="png", bbox_inches="tight", pad_inches=0
- )
- buf.seek(0)
- plt.close(val_pafs_fig)
- val_pafs_pil = Image.open(buf)
- log_dict["viz/val/pafs"] = wandb.Image(
- val_pafs_pil, caption=f"Val PAFs Epoch {epoch}"
- )
-
- if log_dict:
- # Include epoch so wandb can use it as x-axis (via define_metric)
- log_dict["epoch"] = epoch
- # Use commit=False to accumulate with other metrics in this step
- # Lightning will commit when it logs its own metrics
- wandb_logger.experiment.log(log_dict, commit=False)
-
- # Optionally also log to table
- if self.log_table and "direct" in self.renderers:
- train_img = self.renderers["direct"].render(
- train_data, caption=f"Train Epoch {epoch}"
- )
- val_img = self.renderers["direct"].render(
- val_data, caption=f"Val Epoch {epoch}"
- )
- table = wandb.Table(
- columns=[
- "Epoch",
- "Train",
- "Validation",
- "Train PAFs",
- "Val PAFs",
- ],
- data=[
- [
- epoch,
- train_img,
- val_img,
- log_dict["viz/train/pafs"],
- log_dict["viz/val/pafs"],
- ]
- ],
- )
- wandb_logger.experiment.log(
- {"predictions_table": table}, commit=False
- )
-
- # Sync all processes - barrier must be reached by ALL ranks
- trainer.strategy.barrier()
-
-
- class UnifiedVizCallback(Callback):
- """Unified callback for all visualization outputs during training.
-
- This callback consolidates all visualization functionality into a single callback,
- eliminating redundant dataset copies and inference runs. It handles:
- - Local disk saving (matplotlib figures)
- - WandB logging (multiple modes: direct, boxes, masks)
- - Model-specific visualizations (PAFs for bottomup, class maps for multi_class_bottomup)
-
- Benefits over separate callbacks:
- - Uses ONE sample per epoch for all visualizations (no dataset deepcopy)
- - Runs inference ONCE per sample (vs 4-8x in previous implementation)
- - Outputs to multiple destinations from the same data
- - Simpler code with less duplication
-
- Attributes:
- model_trainer: Reference to the ModelTrainer (for lazy access to lightning_model).
- train_pipeline: Iterator over training visualization dataset.
- val_pipeline: Iterator over validation visualization dataset.
- model_type: Type of model (affects which visualizations are enabled).
- save_local: Whether to save matplotlib figures to disk.
- local_save_dir: Directory for local visualization saves.
- log_wandb: Whether to log visualizations to wandb.
- wandb_modes: List of wandb rendering modes ("direct", "boxes", "masks").
- wandb_box_size: Size of keypoint boxes in pixels (for "boxes" mode).
- wandb_confmap_threshold: Threshold for confmap masks (for "masks" mode).
- log_wandb_table: Whether to also log to a wandb.Table.
- """
-
- def __init__(
- self,
- model_trainer,
- train_dataset,
- val_dataset,
- model_type: str,
- save_local: bool = True,
- local_save_dir: Optional[Path] = None,
- log_wandb: bool = False,
- wandb_modes: Optional[list] = None,
- wandb_box_size: float = 5.0,
- wandb_confmap_threshold: float = 0.1,
- log_wandb_table: bool = False,
- ):
- """Initialize the unified visualization callback.
-
- Args:
- model_trainer: ModelTrainer instance (lightning_model accessed lazily).
- train_dataset: Training visualization dataset (will be cycled).
- val_dataset: Validation visualization dataset (will be cycled).
- model_type: Model type string (e.g., "bottomup", "multi_class_bottomup").
- save_local: If True, save matplotlib figures to local_save_dir.
- local_save_dir: Path to directory for saving visualization images.
- log_wandb: If True, log visualizations to wandb.
- wandb_modes: List of wandb rendering modes. Defaults to ["direct"].
- wandb_box_size: Size of keypoint boxes in pixels.
- wandb_confmap_threshold: Threshold for confidence map masks.
- log_wandb_table: If True, also log to a wandb.Table.
- """
- super().__init__()
- from itertools import cycle
-
- self.model_trainer = model_trainer
- self.train_pipeline = cycle(train_dataset)
- self.val_pipeline = cycle(val_dataset)
- self.model_type = model_type
-
- # Local disk config
- self.save_local = save_local
- self.local_save_dir = local_save_dir
-
- # WandB config
- self.log_wandb = log_wandb
- self.wandb_modes = wandb_modes or ["direct"]
- self.wandb_box_size = wandb_box_size
- self.wandb_confmap_threshold = wandb_confmap_threshold
- self.log_wandb_table = log_wandb_table
-
- # Auto-enable model-specific visualizations
- self.viz_pafs = model_type == "bottomup"
- self.viz_class_maps = model_type == "multi_class_bottomup"
-
- # Initialize renderers
- from sleap_nn.training.utils import MatplotlibRenderer, WandBRenderer
-
- self._mpl_renderer = MatplotlibRenderer()
-
- # Create wandb renderers for each enabled mode
- self._wandb_renderers = {}
- if log_wandb:
- for mode in self.wandb_modes:
- self._wandb_renderers[mode] = WandBRenderer(
- mode=mode,
- box_size=wandb_box_size,
- confmap_threshold=wandb_confmap_threshold,
- )
-
- def _get_wandb_logger(self, trainer):
- """Get the WandbLogger from trainer's loggers."""
- from lightning.pytorch.loggers import WandbLogger
-
- for log in trainer.loggers:
- if isinstance(log, WandbLogger):
- return log
- return None
-
- def _get_viz_data(self, sample):
- """Get visualization data with all needed fields based on model type.
-
- Args:
- sample: A sample from the visualization dataset.
-
- Returns:
- VisualizationData with appropriate fields populated.
- """
- # Build kwargs based on model type
- kwargs = {}
- if self.viz_pafs:
- kwargs["include_pafs"] = True
- if self.viz_class_maps:
- kwargs["include_class_maps"] = True
-
- # Access lightning_model lazily from model_trainer
- return self.model_trainer.lightning_model.get_visualization_data(
- sample, **kwargs
- )
-
- def _save_local_viz(self, data, prefix: str, epoch: int):
- """Save visualization to local disk.
-
- Args:
- data: VisualizationData object.
- prefix: Filename prefix (e.g., "train", "validation").
- epoch: Current epoch number.
- """
- if not self.save_local or self.local_save_dir is None:
- return
-
- # Confmaps visualization
- fig = self._mpl_renderer.render(data)
- fig_path = self.local_save_dir / f"{prefix}.{epoch:04d}.png"
- fig.savefig(fig_path, format="png")
- plt.close(fig)
-
- # PAFs visualization (for bottomup models)
- if self.viz_pafs and data.pred_pafs is not None:
- fig = self._mpl_renderer.render_pafs(data)
- fig_path = self.local_save_dir / f"{prefix}.pafs_magnitude.{epoch:04d}.png"
- fig.savefig(fig_path, format="png")
- plt.close(fig)
-
- # Class maps visualization (for multi_class_bottomup models)
- if self.viz_class_maps and data.pred_class_maps is not None:
- fig = self._render_class_maps(data)
- fig_path = self.local_save_dir / f"{prefix}.class_maps.{epoch:04d}.png"
- fig.savefig(fig_path, format="png")
- plt.close(fig)
-
- def _render_class_maps(self, data):
- """Render class maps visualization.
-
- Args:
- data: VisualizationData with pred_class_maps populated.
-
- Returns:
- A matplotlib Figure object.
- """
- from sleap_nn.training.utils import plot_img, plot_confmaps
-
- img = data.image
- scale = 1.0
- if img.shape[0] < 512:
- scale = 2.0
- if img.shape[0] < 256:
- scale = 4.0
-
- fig = plot_img(img, dpi=72 * scale, scale=scale)
- plot_confmaps(
- data.pred_class_maps,
- output_scale=data.pred_class_maps.shape[0] / img.shape[0],
- )
- return fig
-
- def _log_wandb_viz(self, data, prefix: str, epoch: int, wandb_logger):
- """Log visualization to wandb.
-
- Args:
- data: VisualizationData object.
- prefix: Log prefix (e.g., "train", "val").
- epoch: Current epoch number.
- wandb_logger: WandbLogger instance.
- """
- if not self.log_wandb or wandb_logger is None:
- return
-
- from io import BytesIO
- from PIL import Image as PILImage
-
- log_dict = {}
-
- # Render confmaps for each enabled mode
- for mode_name, renderer in self._wandb_renderers.items():
- suffix = "" if mode_name == "direct" else f"_{mode_name}"
- img = renderer.render(data, caption=f"{prefix.title()} Epoch {epoch}")
- log_dict[f"viz/{prefix}/predictions{suffix}"] = img
-
- # PAFs visualization (for bottomup models)
- if self.viz_pafs and data.pred_pafs is not None:
- pafs_fig = self._mpl_renderer.render_pafs(data)
+ if wandb_logger is None:
+ return # No wandb logger, skip visualization logging
+
+ # Get visualization data
+ train_data = self.train_viz_fn()
+ val_data = self.val_viz_fn()
+ train_pafs_data = self.train_pafs_viz_fn()
+ val_pafs_data = self.val_pafs_viz_fn()
+
+ # Render and log for each enabled mode
+ # Use the logger's experiment to let Lightning manage step tracking
+ log_dict = {}
+ for mode_name, renderer in self.renderers.items():
+ suffix = "" if mode_name == "direct" else f"_{mode_name}"
+ train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
+ val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
+ log_dict[f"train_predictions{suffix}"] = train_img
+ log_dict[f"val_predictions{suffix}"] = val_img
+
+ # Render PAFs (always use matplotlib/direct for PAFs)
+ from io import BytesIO
+ import matplotlib.pyplot as plt
+ from PIL import Image
+
+ train_pafs_fig = self._mpl_renderer.render_pafs(train_pafs_data)
  buf = BytesIO()
- pafs_fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
+ train_pafs_fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
  buf.seek(0)
- plt.close(pafs_fig)
- pafs_pil = PILImage.open(buf)
- log_dict[f"viz/{prefix}/pafs"] = wandb.Image(
- pafs_pil, caption=f"{prefix.title()} PAFs Epoch {epoch}"
+ plt.close(train_pafs_fig)
+ train_pafs_pil = Image.open(buf)
+ log_dict["train_pafs"] = wandb.Image(
+ train_pafs_pil, caption=f"Train PAFs Epoch {epoch}"
  )
 
- # Class maps visualization (for multi_class_bottomup models)
- if self.viz_class_maps and data.pred_class_maps is not None:
- class_fig = self._render_class_maps(data)
+ val_pafs_fig = self._mpl_renderer.render_pafs(val_pafs_data)
  buf = BytesIO()
- class_fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
+ val_pafs_fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
  buf.seek(0)
- plt.close(class_fig)
- class_pil = PILImage.open(buf)
- log_dict[f"viz/{prefix}/class_maps"] = wandb.Image(
- class_pil, caption=f"{prefix.title()} Class Maps Epoch {epoch}"
- )
-
- if log_dict:
- log_dict["epoch"] = epoch
- wandb_logger.experiment.log(log_dict, commit=False)
-
- # Optionally log to table for backwards compatibility
- if self.log_wandb_table and "direct" in self._wandb_renderers:
- train_img = self._wandb_renderers["direct"].render(
- data, caption=f"{prefix.title()} Epoch {epoch}"
- )
- table_data = [[epoch, train_img]]
- columns = ["Epoch", prefix.title()]
-
- if self.viz_pafs and data.pred_pafs is not None:
- columns.append(f"{prefix.title()} PAFs")
- table_data[0].append(log_dict.get(f"viz/{prefix}/pafs"))
-
- if self.viz_class_maps and data.pred_class_maps is not None:
- columns.append(f"{prefix.title()} Class Maps")
- table_data[0].append(log_dict.get(f"viz/{prefix}/class_maps"))
-
- table = wandb.Table(columns=columns, data=table_data)
- wandb_logger.experiment.log(
- {f"predictions_table_{prefix}": table}, commit=False
+ plt.close(val_pafs_fig)
+ val_pafs_pil = Image.open(buf)
+ log_dict["val_pafs"] = wandb.Image(
+ val_pafs_pil, caption=f"Val PAFs Epoch {epoch}"
  )
 
- def on_train_epoch_end(self, trainer, pl_module):
- """Generate and output all visualizations at epoch end.
+ if log_dict:
+ # Include epoch so wandb can use it as x-axis (via define_metric)
+ log_dict["epoch"] = epoch
+ # Use commit=False to accumulate with other metrics in this step
+ # Lightning will commit when it logs its own metrics
+ wandb_logger.experiment.log(log_dict, commit=False)
+
+ # Optionally also log to table
+ if self.log_table and "direct" in self.renderers:
+ train_img = self.renderers["direct"].render(
+ train_data, caption=f"Train Epoch {epoch}"
+ )
+ val_img = self.renderers["direct"].render(
+ val_data, caption=f"Val Epoch {epoch}"
+ )
+ table = wandb.Table(
+ columns=["Epoch", "Train", "Validation", "Train PAFs", "Val PAFs"],
+ data=[
+ [
+ epoch,
+ train_img,
+ val_img,
+ log_dict["train_pafs"],
+ log_dict["val_pafs"],
+ ]
+ ],
+ )
+ wandb_logger.experiment.log({"predictions_table": table}, commit=False)
 
- Args:
- trainer: PyTorch Lightning trainer.
- pl_module: Lightning module (not used, we use self.lightning_module).
- """
- if trainer.is_global_zero:
- epoch = trainer.current_epoch
- wandb_logger = self._get_wandb_logger(trainer) if self.log_wandb else None
-
- # Get ONE sample for train visualization
- train_sample = next(self.train_pipeline)
- # Run inference ONCE with all needed data
- train_data = self._get_viz_data(train_sample)
- # Output to all destinations
- self._save_local_viz(train_data, "train", epoch)
- self._log_wandb_viz(train_data, "train", epoch, wandb_logger)
-
- # Same for validation
- val_sample = next(self.val_pipeline)
- val_data = self._get_viz_data(val_sample)
- self._save_local_viz(val_data, "validation", epoch)
- self._log_wandb_viz(val_data, "val", epoch, wandb_logger)
-
- # Sync all processes - barrier must be reached by ALL ranks
+ # Sync all processes
  trainer.strategy.barrier()
 
 
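The PAF panels go through the same figure-to-`wandb.Image` round trip in both versions of the hunk above: render with matplotlib, serialize to an in-memory PNG, reopen with PIL, and wrap for logging. A self-contained sketch of that conversion (the random heatmap stands in for a rendered PAF magnitude figure):

```python
from io import BytesIO

import matplotlib
matplotlib.use("Agg")  # headless backend, as in a training job
import matplotlib.pyplot as plt
import numpy as np
import wandb
from PIL import Image

# Stand-in for self._mpl_renderer.render_pafs(...): any matplotlib Figure.
fig, ax = plt.subplots()
ax.imshow(np.random.rand(64, 64), cmap="viridis")
ax.set_axis_off()

# Serialize to an in-memory PNG and free the figure.
buf = BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
buf.seek(0)
plt.close(fig)

# Reopen with PIL and wrap as a wandb.Image ready for experiment.log().
pafs_img = wandb.Image(Image.open(buf), caption="Train PAFs Epoch 0")
```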
@@ -968,638 +662,3 @@ class ProgressReporterZMQ(Callback):
  return {
  k: float(v.item()) if hasattr(v, "item") else v for k, v in logs.items()
  }
-
-
- class EpochEndEvaluationCallback(Callback):
- """Callback to run full evaluation metrics at end of validation epochs.
-
- This callback collects predictions and ground truth during validation,
- then runs the full evaluation pipeline (OKS, mAP, PCK, etc.) and logs
- metrics to WandB.
-
- Attributes:
- skeleton: sio.Skeleton for creating instances.
- videos: List of sio.Video objects.
- eval_frequency: Run evaluation every N epochs (default: 1).
- oks_stddev: OKS standard deviation (default: 0.025).
- oks_scale: Optional OKS scale override.
- metrics_to_log: List of metric keys to log.
- """
-
- def __init__(
- self,
- skeleton: "sio.Skeleton",
- videos: list,
- eval_frequency: int = 1,
- oks_stddev: float = 0.025,
- oks_scale: Optional[float] = None,
- metrics_to_log: Optional[list] = None,
- ):
- """Initialize the callback.
-
- Args:
- skeleton: sio.Skeleton for creating instances.
- videos: List of sio.Video objects.
- eval_frequency: Run evaluation every N epochs (default: 1).
- oks_stddev: OKS standard deviation (default: 0.025).
- oks_scale: Optional OKS scale override.
- metrics_to_log: List of metric keys to log. If None, logs all available.
- """
- super().__init__()
- self.skeleton = skeleton
- self.videos = videos
- self.eval_frequency = eval_frequency
- self.oks_stddev = oks_stddev
- self.oks_scale = oks_scale
- self.metrics_to_log = metrics_to_log or [
- "mOKS",
- "oks_voc.mAP",
- "oks_voc.mAR",
- "distance/avg",
- "distance/p50",
- "distance/p95",
- "distance/p99",
- "mPCK",
- "PCK@5",
- "PCK@10",
- "visibility_precision",
- "visibility_recall",
- ]
-
- def on_validation_epoch_start(self, trainer, pl_module):
- """Enable prediction collection at the start of validation.
-
- Skip during sanity check to avoid inference issues.
- """
- if trainer.sanity_checking:
- return
- pl_module._collect_val_predictions = True
-
- def on_validation_epoch_end(self, trainer, pl_module):
- """Run evaluation and log metrics at end of validation epoch."""
- import sleap_io as sio
- import numpy as np
- from lightning.pytorch.loggers import WandbLogger
- from sleap_nn.evaluation import Evaluator
-
- # Determine if we should run evaluation this epoch (only on rank 0)
- should_evaluate = (
- trainer.current_epoch + 1
- ) % self.eval_frequency == 0 and trainer.is_global_zero
-
- if should_evaluate:
- # Check if we have predictions
- if not pl_module.val_predictions or not pl_module.val_ground_truth:
- logger.warning("No predictions collected for epoch-end evaluation")
- else:
- try:
- # Build sio.Labels from accumulated predictions and ground truth
- pred_labels = self._build_pred_labels(
- pl_module.val_predictions, sio, np
- )
- gt_labels = self._build_gt_labels(
- pl_module.val_ground_truth, sio, np
- )
-
- # Check if we have valid frames to evaluate
- if len(pred_labels) == 0:
- logger.warning(
- "No valid predictions for epoch-end evaluation "
- "(all predictions may be empty or NaN)"
- )
- else:
- # Run evaluation
- evaluator = Evaluator(
- ground_truth_instances=gt_labels,
- predicted_instances=pred_labels,
- oks_stddev=self.oks_stddev,
- oks_scale=self.oks_scale,
- user_labels_only=False, # All validation frames are "user" frames
- )
- metrics = evaluator.evaluate()
-
- # Log to WandB
- self._log_metrics(trainer, metrics, trainer.current_epoch)
-
- logger.info(
- f"Epoch {trainer.current_epoch} evaluation: "
- f"PCK@5={metrics['pck_metrics']['PCK@5']:.4f}, "
- f"mOKS={metrics['mOKS']['mOKS']:.4f}, "
- f"mAP={metrics['voc_metrics']['oks_voc.mAP']:.4f}"
- )
-
- except Exception as e:
- logger.warning(f"Epoch-end evaluation failed: {e}")
-
- # Cleanup - all ranks reset the flag, rank 0 clears the lists
- pl_module._collect_val_predictions = False
- if trainer.is_global_zero:
- pl_module.val_predictions = []
- pl_module.val_ground_truth = []
-
- # Sync all processes - barrier must be reached by ALL ranks
- trainer.strategy.barrier()
-
- def _build_pred_labels(self, predictions: list, sio, np) -> "sio.Labels":
- """Convert prediction dicts to sio.Labels."""
- labeled_frames = []
- for pred in predictions:
- pred_peaks = pred["pred_peaks"]
- pred_scores = pred["pred_scores"]
-
- # Handle NaN/missing predictions
- if pred_peaks is None or (
- isinstance(pred_peaks, np.ndarray) and np.isnan(pred_peaks).all()
- ):
- continue
-
- # Handle multi-instance predictions (bottomup)
- if len(pred_peaks.shape) == 2:
- # Single instance: (n_nodes, 2) -> (1, n_nodes, 2)
- pred_peaks = pred_peaks.reshape(1, -1, 2)
- pred_scores = pred_scores.reshape(1, -1)
-
- instances = []
- for inst_idx in range(len(pred_peaks)):
- inst_points = pred_peaks[inst_idx]
- inst_scores = pred_scores[inst_idx] if pred_scores is not None else None
-
- # Skip if all NaN
- if np.isnan(inst_points).all():
- continue
-
- inst = sio.PredictedInstance.from_numpy(
- points_data=inst_points,
- skeleton=self.skeleton,
- point_scores=(
- inst_scores
- if inst_scores is not None
- else np.ones(len(inst_points))
- ),
- score=(
- float(np.nanmean(inst_scores))
- if inst_scores is not None
- else 1.0
- ),
- )
- instances.append(inst)
-
- if instances:
- lf = sio.LabeledFrame(
- video=self.videos[pred["video_idx"]],
- frame_idx=pred["frame_idx"],
- instances=instances,
- )
- labeled_frames.append(lf)
-
- return sio.Labels(
- videos=self.videos,
- skeletons=[self.skeleton],
- labeled_frames=labeled_frames,
- )
-
- def _build_gt_labels(self, ground_truth: list, sio, np) -> "sio.Labels":
- """Convert ground truth dicts to sio.Labels."""
- labeled_frames = []
- for gt in ground_truth:
- instances = []
- gt_instances = gt["gt_instances"]
-
- # Handle shape variations
- if len(gt_instances.shape) == 2:
- # (n_nodes, 2) -> (1, n_nodes, 2)
- gt_instances = gt_instances.reshape(1, -1, 2)
-
- for i in range(min(gt["num_instances"], len(gt_instances))):
- inst_data = gt_instances[i]
- if np.isnan(inst_data).all():
- continue
- inst = sio.Instance.from_numpy(
- points_data=inst_data,
- skeleton=self.skeleton,
- )
- instances.append(inst)
-
- if instances:
- lf = sio.LabeledFrame(
- video=self.videos[gt["video_idx"]],
- frame_idx=gt["frame_idx"],
- instances=instances,
- )
- labeled_frames.append(lf)
-
- return sio.Labels(
- videos=self.videos,
- skeletons=[self.skeleton],
- labeled_frames=labeled_frames,
- )
-
- def _log_metrics(self, trainer, metrics: dict, epoch: int):
- """Log evaluation metrics to WandB."""
- import numpy as np
- from lightning.pytorch.loggers import WandbLogger
-
- # Get WandB logger
- wandb_logger = None
- for log in trainer.loggers:
- if isinstance(log, WandbLogger):
- wandb_logger = log
- break
-
- if wandb_logger is None:
- return
-
- log_dict = {"epoch": epoch}
-
- # Extract key metrics with consistent naming
- # All eval metrics use eval/val/ prefix since they're computed on validation data
- if "mOKS" in self.metrics_to_log:
- log_dict["eval/val/mOKS"] = metrics["mOKS"]["mOKS"]
-
- if "oks_voc.mAP" in self.metrics_to_log:
- log_dict["eval/val/oks_voc_mAP"] = metrics["voc_metrics"]["oks_voc.mAP"]
-
- if "oks_voc.mAR" in self.metrics_to_log:
- log_dict["eval/val/oks_voc_mAR"] = metrics["voc_metrics"]["oks_voc.mAR"]
-
- # Distance metrics grouped under eval/val/distance/
- if "distance/avg" in self.metrics_to_log:
- val = metrics["distance_metrics"]["avg"]
- if not np.isnan(val):
- log_dict["eval/val/distance/avg"] = val
-
- if "distance/p50" in self.metrics_to_log:
- val = metrics["distance_metrics"]["p50"]
- if not np.isnan(val):
- log_dict["eval/val/distance/p50"] = val
-
- if "distance/p95" in self.metrics_to_log:
- val = metrics["distance_metrics"]["p95"]
- if not np.isnan(val):
- log_dict["eval/val/distance/p95"] = val
-
- if "distance/p99" in self.metrics_to_log:
- val = metrics["distance_metrics"]["p99"]
- if not np.isnan(val):
- log_dict["eval/val/distance/p99"] = val
-
- # PCK metrics
- if "mPCK" in self.metrics_to_log:
- log_dict["eval/val/mPCK"] = metrics["pck_metrics"]["mPCK"]
-
- # PCK at specific thresholds (precomputed in evaluation.py)
- if "PCK@5" in self.metrics_to_log:
- log_dict["eval/val/PCK_5"] = metrics["pck_metrics"]["PCK@5"]
-
- if "PCK@10" in self.metrics_to_log:
- log_dict["eval/val/PCK_10"] = metrics["pck_metrics"]["PCK@10"]
-
- # Visibility metrics
- if "visibility_precision" in self.metrics_to_log:
- val = metrics["visibility_metrics"]["precision"]
- if not np.isnan(val):
- log_dict["eval/val/visibility_precision"] = val
-
- if "visibility_recall" in self.metrics_to_log:
- val = metrics["visibility_metrics"]["recall"]
- if not np.isnan(val):
- log_dict["eval/val/visibility_recall"] = val
-
- wandb_logger.experiment.log(log_dict, commit=False)
-
- # Update best metrics in summary (excluding epoch)
- for key, value in log_dict.items():
- if key == "epoch":
- continue
- # Create summary key like "best/eval/val/mOKS"
- summary_key = f"best/{key}"
- current_best = wandb_logger.experiment.summary.get(summary_key)
- # For distance metrics, lower is better; for others, higher is better
- is_distance = "distance" in key
- if current_best is None:
- wandb_logger.experiment.summary[summary_key] = value
- elif is_distance and value < current_best:
- wandb_logger.experiment.summary[summary_key] = value
- elif not is_distance and value > current_best:
- wandb_logger.experiment.summary[summary_key] = value
-
-
- def match_centroids(
- pred_centroids: "np.ndarray",
- gt_centroids: "np.ndarray",
- max_distance: float = 50.0,
- ) -> tuple:
- """Match predicted centroids to ground truth using Hungarian algorithm.
-
- Args:
- pred_centroids: Predicted centroid locations, shape (n_pred, 2).
- gt_centroids: Ground truth centroid locations, shape (n_gt, 2).
- max_distance: Maximum distance threshold for valid matches (in pixels).
-
- Returns:
- Tuple of:
- - matched_pred_indices: Indices of matched predictions
- - matched_gt_indices: Indices of matched ground truth
- - unmatched_pred_indices: Indices of unmatched predictions (false positives)
- - unmatched_gt_indices: Indices of unmatched ground truth (false negatives)
- """
- import numpy as np
- from scipy.optimize import linear_sum_assignment
- from scipy.spatial.distance import cdist
-
- n_pred = len(pred_centroids)
- n_gt = len(gt_centroids)
-
- # Handle edge cases
- if n_pred == 0 and n_gt == 0:
- return np.array([]), np.array([]), np.array([]), np.array([])
- if n_pred == 0:
- return np.array([]), np.array([]), np.array([]), np.arange(n_gt)
- if n_gt == 0:
- return np.array([]), np.array([]), np.arange(n_pred), np.array([])
-
- # Compute pairwise distances
- cost_matrix = cdist(pred_centroids, gt_centroids)
-
- # Run Hungarian algorithm for optimal matching
- pred_indices, gt_indices = linear_sum_assignment(cost_matrix)
-
- # Filter matches that exceed max_distance
- matched_pred = []
- matched_gt = []
- for p_idx, g_idx in zip(pred_indices, gt_indices):
- if cost_matrix[p_idx, g_idx] <= max_distance:
- matched_pred.append(p_idx)
- matched_gt.append(g_idx)
-
- matched_pred = np.array(matched_pred)
- matched_gt = np.array(matched_gt)
-
- # Find unmatched indices
- all_pred = set(range(n_pred))
- all_gt = set(range(n_gt))
- unmatched_pred = np.array(list(all_pred - set(matched_pred)))
- unmatched_gt = np.array(list(all_gt - set(matched_gt)))
-
- return matched_pred, matched_gt, unmatched_pred, unmatched_gt
-
-
- class CentroidEvaluationCallback(Callback):
- """Callback to run centroid-specific evaluation metrics at end of validation epochs.
-
- This callback is designed specifically for centroid models, which predict a single
- point (centroid) per instance rather than full pose skeletons. It computes
- distance-based metrics and detection metrics that are more appropriate for
- point detection tasks than OKS/PCK metrics.
-
- Metrics computed:
- - Distance metrics: mean, median, p90, p95, max Euclidean distance
- - Detection metrics: precision, recall, F1 score
- - Counts: true positives, false positives, false negatives
-
- Attributes:
- videos: List of sio.Video objects.
- eval_frequency: Run evaluation every N epochs (default: 1).
- match_threshold: Maximum distance (pixels) for matching pred to GT (default: 50.0).
- """
-
- def __init__(
- self,
- videos: list,
- eval_frequency: int = 1,
- match_threshold: float = 50.0,
- ):
- """Initialize the callback.
-
- Args:
- videos: List of sio.Video objects.
- eval_frequency: Run evaluation every N epochs (default: 1).
- match_threshold: Maximum distance in pixels for a prediction to be
- considered a match to a ground truth centroid (default: 50.0).
- """
- super().__init__()
- self.videos = videos
- self.eval_frequency = eval_frequency
- self.match_threshold = match_threshold
-
- def on_validation_epoch_start(self, trainer, pl_module):
- """Enable prediction collection at the start of validation.
-
- Skip during sanity check to avoid inference issues.
- """
- if trainer.sanity_checking:
- return
- pl_module._collect_val_predictions = True
-
- def on_validation_epoch_end(self, trainer, pl_module):
- """Run centroid evaluation and log metrics at end of validation epoch."""
- import numpy as np
- from lightning.pytorch.loggers import WandbLogger
-
- # Determine if we should run evaluation this epoch (only on rank 0)
- should_evaluate = (
- trainer.current_epoch + 1
- ) % self.eval_frequency == 0 and trainer.is_global_zero
-
- if should_evaluate:
- # Check if we have predictions
- if not pl_module.val_predictions or not pl_module.val_ground_truth:
- logger.warning(
- "No predictions collected for centroid epoch-end evaluation"
- )
- else:
- try:
- metrics = self._compute_metrics(
- pl_module.val_predictions, pl_module.val_ground_truth, np
- )
-
- # Log to WandB
- self._log_metrics(trainer, metrics, trainer.current_epoch)
-
- logger.info(
- f"Epoch {trainer.current_epoch} centroid evaluation: "
- f"precision={metrics['precision']:.4f}, "
- f"recall={metrics['recall']:.4f}, "
- f"dist_avg={metrics['dist_avg']:.2f}px"
- )
-
- except Exception as e:
- logger.warning(f"Centroid epoch-end evaluation failed: {e}")
-
- # Cleanup - all ranks reset the flag, rank 0 clears the lists
- pl_module._collect_val_predictions = False
- if trainer.is_global_zero:
- pl_module.val_predictions = []
- pl_module.val_ground_truth = []
-
- # Sync all processes - barrier must be reached by ALL ranks
- trainer.strategy.barrier()
-
- def _compute_metrics(self, predictions: list, ground_truth: list, np) -> dict:
- """Compute centroid-specific metrics.
-
- Args:
- predictions: List of prediction dicts with "pred_peaks" key.
- ground_truth: List of ground truth dicts with "gt_instances" key.
- np: NumPy module.
-
- Returns:
- Dictionary of computed metrics.
- """
- all_distances = []
- total_tp = 0
- total_fp = 0
- total_fn = 0
-
- # Group predictions and GT by frame
- pred_by_frame = {}
- for pred in predictions:
- key = (pred["video_idx"], pred["frame_idx"])
- if key not in pred_by_frame:
- pred_by_frame[key] = []
- # pred_peaks shape: (n_inst, 1, 2) -> extract centroids as (n_inst, 2)
- centroids = pred["pred_peaks"].reshape(-1, 2)
- # Filter out NaN centroids
- valid_mask = ~np.isnan(centroids).any(axis=1)
- pred_by_frame[key].append(centroids[valid_mask])
-
- gt_by_frame = {}
- for gt in ground_truth:
- key = (gt["video_idx"], gt["frame_idx"])
- if key not in gt_by_frame:
- gt_by_frame[key] = []
- # gt_instances shape: (n_inst, 1, 2) -> extract centroids as (n_inst, 2)
- centroids = gt["gt_instances"].reshape(-1, 2)
- # Filter out NaN centroids
- valid_mask = ~np.isnan(centroids).any(axis=1)
- gt_by_frame[key].append(centroids[valid_mask])
-
- # Process each frame
- all_frames = set(pred_by_frame.keys()) | set(gt_by_frame.keys())
- for frame_key in all_frames:
- # Concatenate all predictions for this frame
- if frame_key in pred_by_frame:
- frame_preds = np.concatenate(pred_by_frame[frame_key], axis=0)
- else:
- frame_preds = np.zeros((0, 2))
-
- # Concatenate all GT for this frame
- if frame_key in gt_by_frame:
- frame_gt = np.concatenate(gt_by_frame[frame_key], axis=0)
- else:
- frame_gt = np.zeros((0, 2))
-
- # Match predictions to ground truth
- matched_pred, matched_gt, unmatched_pred, unmatched_gt = match_centroids(
- frame_preds, frame_gt, max_distance=self.match_threshold
- )
-
- # Compute distances for matched pairs
- if len(matched_pred) > 0:
- matched_pred_points = frame_preds[matched_pred]
- matched_gt_points = frame_gt[matched_gt]
- distances = np.linalg.norm(
- matched_pred_points - matched_gt_points, axis=1
- )
- all_distances.extend(distances.tolist())
-
- # Update counts
- total_tp += len(matched_pred)
- total_fp += len(unmatched_pred)
- total_fn += len(unmatched_gt)
-
- # Compute aggregate metrics
- all_distances = np.array(all_distances)
-
- # Distance metrics (only if we have matches)
- if len(all_distances) > 0:
- dist_avg = float(np.mean(all_distances))
- dist_median = float(np.median(all_distances))
- dist_p90 = float(np.percentile(all_distances, 90))
- dist_p95 = float(np.percentile(all_distances, 95))
- dist_max = float(np.max(all_distances))
- else:
- dist_avg = dist_median = dist_p90 = dist_p95 = dist_max = float("nan")
-
- # Detection metrics
- precision = (
- total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
- )
- recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
- f1 = (
- 2 * precision * recall / (precision + recall)
- if (precision + recall) > 0
- else 0.0
- )
-
- return {
- "dist_avg": dist_avg,
- "dist_median": dist_median,
- "dist_p90": dist_p90,
- "dist_p95": dist_p95,
- "dist_max": dist_max,
- "precision": precision,
- "recall": recall,
- "f1": f1,
- "n_true_positives": total_tp,
- "n_false_positives": total_fp,
- "n_false_negatives": total_fn,
- "n_total_predictions": total_tp + total_fp,
- "n_total_ground_truth": total_tp + total_fn,
- }
-
- def _log_metrics(self, trainer, metrics: dict, epoch: int):
- """Log centroid evaluation metrics to WandB."""
- import numpy as np
- from lightning.pytorch.loggers import WandbLogger
-
- # Get WandB logger
- wandb_logger = None
- for log in trainer.loggers:
- if isinstance(log, WandbLogger):
- wandb_logger = log
- break
-
- if wandb_logger is None:
- return
-
- log_dict = {"epoch": epoch}
-
- # Distance metrics (with NaN handling)
- if not np.isnan(metrics["dist_avg"]):
- log_dict["eval/val/centroid_dist_avg"] = metrics["dist_avg"]
- if not np.isnan(metrics["dist_median"]):
- log_dict["eval/val/centroid_dist_median"] = metrics["dist_median"]
- if not np.isnan(metrics["dist_p90"]):
- log_dict["eval/val/centroid_dist_p90"] = metrics["dist_p90"]
- if not np.isnan(metrics["dist_p95"]):
- log_dict["eval/val/centroid_dist_p95"] = metrics["dist_p95"]
- if not np.isnan(metrics["dist_max"]):
- log_dict["eval/val/centroid_dist_max"] = metrics["dist_max"]
-
- # Detection metrics
- log_dict["eval/val/centroid_precision"] = metrics["precision"]
- log_dict["eval/val/centroid_recall"] = metrics["recall"]
- log_dict["eval/val/centroid_f1"] = metrics["f1"]
-
- # Counts
- log_dict["eval/val/centroid_n_tp"] = metrics["n_true_positives"]
- log_dict["eval/val/centroid_n_fp"] = metrics["n_false_positives"]
- log_dict["eval/val/centroid_n_fn"] = metrics["n_false_negatives"]
-
- wandb_logger.experiment.log(log_dict, commit=False)
-
- # Update best metrics in summary
- for key, value in log_dict.items():
- if key == "epoch":
- continue
- summary_key = f"best/{key}"
- current_best = wandb_logger.experiment.summary.get(summary_key)
- # For distance metrics, lower is better; for others, higher is better
- is_distance = "dist" in key
- if current_best is None:
- wandb_logger.experiment.summary[summary_key] = value
- elif is_distance and value < current_best:
- wandb_logger.experiment.summary[summary_key] = value
- elif not is_distance and value > current_best:
- wandb_logger.experiment.summary[summary_key] = value
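The removed `match_centroids` helper is a standard Hungarian assignment gated by a distance threshold, and the callback's precision/recall counts follow directly from the match sets. A worked example with invented coordinates, using the same SciPy calls (`cdist` plus `linear_sum_assignment`) as the removed code:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

pred = np.array([[10.0, 10.0], [80.0, 82.0], [300.0, 300.0]])  # 3 predictions
gt = np.array([[12.0, 11.0], [78.0, 80.0]])                    # 2 ground truth

cost = cdist(pred, gt)                    # (3, 2) pairwise distances
rows, cols = linear_sum_assignment(cost)  # optimal one-to-one assignment
keep = cost[rows, cols] <= 50.0           # gate matches at 50 px
matched_pred, matched_gt = rows[keep], cols[keep]

tp = len(matched_pred)  # 2: both gt centroids matched within threshold
fp = len(pred) - tp     # 1: the far-away prediction is a false positive
fn = len(gt) - tp       # 0: no ground truth left unmatched
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0  # 0.667
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0     # 1.0
```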