sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
@@ -2,10 +2,15 @@
 
 import zmq
 import jsonpickle
-from typing import Callable, Optional
+from typing import Callable, Optional, Union
 from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.callbacks.progress import TQDMProgressBar
 from loguru import logger
 import matplotlib
+
+matplotlib.use(
+    "Agg"
+)  # Use non-interactive backend to avoid tkinter issues on Windows CI
 import matplotlib.pyplot as plt
 from PIL import Image
 from pathlib import Path
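
For orientation, the hunk above pins matplotlib to a headless backend. The ordering matters: `matplotlib.use()` only takes effect reliably when it runs before the first `import matplotlib.pyplot`, which is why the call sits between the two imports. A minimal standalone sketch of the same pattern (illustrative, not part of the package):

```python
import matplotlib

matplotlib.use("Agg")  # select a non-interactive backend before pyplot loads
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])
fig.savefig("demo.png")  # renders straight to file, no display or tkinter needed
plt.close(fig)
```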
@@ -14,6 +19,32 @@ import csv
 from sleap_nn import RANK
 
 
+class SleapProgressBar(TQDMProgressBar):
+    """Custom progress bar with better formatting for small metric values.
+
+    The default TQDMProgressBar truncates small floats like 1e-5 to "0.000".
+    This subclass formats metrics using scientific notation when appropriate.
+    """
+
+    def get_metrics(
+        self, trainer, pl_module
+    ) -> dict[str, Union[int, str, float, dict[str, float]]]:
+        """Override to format metrics with scientific notation for small values."""
+        items = super().get_metrics(trainer, pl_module)
+        formatted = {}
+        for k, v in items.items():
+            if isinstance(v, float):
+                # Use scientific notation for very small values
+                if v != 0 and abs(v) < 0.001:
+                    formatted[k] = f"{v:.2e}"
+                else:
+                    # Use 4 decimal places for normal values
+                    formatted[k] = f"{v:.4f}"
+            else:
+                formatted[k] = v
+        return formatted
+
+
 class CSVLoggerCallback(Callback):
     """Callback for logging metrics to csv.
 
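
The new `SleapProgressBar` is a drop-in replacement for Lightning's default bar. A hypothetical usage sketch (the `Trainer` arguments here are illustrative):

```python
from lightning.pytorch import Trainer

from sleap_nn.training.callbacks import SleapProgressBar

# With the custom bar, a loss of 1e-5 renders as "1.00e-05" rather than
# being truncated to "0.000" by the stock TQDMProgressBar formatting.
trainer = Trainer(max_epochs=10, callbacks=[SleapProgressBar()])
```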
@@ -53,6 +84,21 @@ class CSVLoggerCallback(Callback):
         for key in self.keys:
             if key == "epoch":
                 log_data["epoch"] = trainer.current_epoch
+            elif key == "learning_rate":
+                # Handle multiple formats:
+                # 1. Direct "learning_rate" key
+                # 2. "train/lr" key (current format from lightning modules)
+                # 3. "lr-*" keys from LearningRateMonitor (legacy)
+                value = metrics.get(key, None)
+                if value is None:
+                    value = metrics.get("train/lr", None)
+                if value is None:
+                    # Look for lr-* keys from LearningRateMonitor (legacy)
+                    for metric_key in metrics.keys():
+                        if metric_key.startswith("lr-"):
+                            value = metrics[metric_key]
+                            break
+                log_data[key] = value.item() if value is not None else None
             else:
                 value = metrics.get(key, None)
                 log_data[key] = value.item() if value is not None else None
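
The fallback chain added above is easiest to see as a small pure function. This standalone sketch (function name hypothetical) mirrors the three lookups in order:

```python
def resolve_learning_rate(metrics: dict):
    """Return the LR from 'learning_rate', then 'train/lr', then any 'lr-*' key."""
    value = metrics.get("learning_rate")
    if value is None:
        value = metrics.get("train/lr")
    if value is None:
        # LearningRateMonitor's legacy keys look like "lr-AdamW" or "lr-SGD/pg1"
        value = next((v for k, v in metrics.items() if k.startswith("lr-")), None)
    return value


assert resolve_learning_rate({"train/lr": 1e-3}) == 1e-3
assert resolve_learning_rate({"lr-AdamW": 1e-4}) == 1e-4
assert resolve_learning_rate({"val/loss": 0.2}) is None
```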
@@ -66,7 +112,11 @@ class CSVLoggerCallback(Callback):
 
 
 class WandBPredImageLogger(Callback):
-    """Callback for writing image predictions to wandb.
+    """Callback for writing image predictions to wandb as a Table.
+
+    .. deprecated::
+        This callback logs images to a wandb.Table which doesn't support
+        step sliders. Use WandBVizCallback instead for better UX.
 
     Attributes:
         viz_folder: Path to viz directory.
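
Both this hunk and the new callbacks below lean on the same wandb behavior: `log(..., commit=False)` buffers values into the current step, and the next `log()` call with the default `commit=True` flushes everything as one step. A self-contained sketch (offline mode, illustrative keys):

```python
import wandb

run = wandb.init(project="demo", mode="offline")
wandb.log({"viz/train/predictions": 1.0}, commit=False)  # buffered, not yet a step
wandb.log({"train/loss": 0.5})  # commit=True by default; flushes both together
run.finish()
```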
@@ -141,12 +191,576 @@ class WandBPredImageLogger(Callback):
             ]
         ]
         table = wandb.Table(columns=column_names, data=data)
-        wandb.log({f"{self.wandb_run_name}": table})
+        # Use commit=False to accumulate with other metrics in this step
+        wandb.log({f"{self.wandb_run_name}": table}, commit=False)
 
         # Sync all processes after wandb logging
         trainer.strategy.barrier()
 
 
+class WandBVizCallback(Callback):
+    """Callback for logging visualization images directly to wandb with slider support.
+
+    This callback logs images using wandb.log() which enables step slider navigation
+    in the wandb UI. Multiple visualization modes can be enabled simultaneously:
+    - viz_enabled: Pre-render with matplotlib (same as disk viz)
+    - viz_boxes: Interactive keypoint boxes with filtering
+    - viz_masks: Confidence map overlay with per-node toggling
+
+    Attributes:
+        train_viz_fn: Function that returns VisualizationData for training sample.
+        val_viz_fn: Function that returns VisualizationData for validation sample.
+        viz_enabled: Whether to log pre-rendered matplotlib images.
+        viz_boxes: Whether to log interactive keypoint boxes.
+        viz_masks: Whether to log confidence map overlay masks.
+        box_size: Size of keypoint boxes in pixels (for viz_boxes).
+        confmap_threshold: Threshold for confmap masks (for viz_masks).
+        log_table: Whether to also log to a wandb.Table (backwards compat).
+    """
+
+    def __init__(
+        self,
+        train_viz_fn: Callable,
+        val_viz_fn: Callable,
+        viz_enabled: bool = True,
+        viz_boxes: bool = False,
+        viz_masks: bool = False,
+        box_size: float = 5.0,
+        confmap_threshold: float = 0.1,
+        log_table: bool = False,
+    ):
+        """Initialize the callback.
+
+        Args:
+            train_viz_fn: Callable that returns VisualizationData for a training sample.
+            val_viz_fn: Callable that returns VisualizationData for a validation sample.
+            viz_enabled: If True, log pre-rendered matplotlib images.
+            viz_boxes: If True, log interactive keypoint boxes.
+            viz_masks: If True, log confidence map overlay masks.
+            box_size: Size of keypoint boxes in pixels (for viz_boxes).
+            confmap_threshold: Threshold for confmap mask generation (for viz_masks).
+            log_table: If True, also log images to a wandb.Table (for backwards compat).
+        """
+        super().__init__()
+        self.train_viz_fn = train_viz_fn
+        self.val_viz_fn = val_viz_fn
+        self.viz_enabled = viz_enabled
+        self.viz_boxes = viz_boxes
+        self.viz_masks = viz_masks
+        self.log_table = log_table
+
+        # Import here to avoid circular imports
+        from sleap_nn.training.utils import WandBRenderer
+
+        self.box_size = box_size
+        self.confmap_threshold = confmap_threshold
+
+        # Create renderers for each enabled mode
+        self.renderers = {}
+        if viz_enabled:
+            self.renderers["direct"] = WandBRenderer(
+                mode="direct", box_size=box_size, confmap_threshold=confmap_threshold
+            )
+        if viz_boxes:
+            self.renderers["boxes"] = WandBRenderer(
+                mode="boxes", box_size=box_size, confmap_threshold=confmap_threshold
+            )
+        if viz_masks:
+            self.renderers["masks"] = WandBRenderer(
+                mode="masks", box_size=box_size, confmap_threshold=confmap_threshold
+            )
+
+    def _get_wandb_logger(self, trainer):
+        """Get the WandbLogger from trainer's loggers."""
+        from lightning.pytorch.loggers import WandbLogger
+
+        for logger in trainer.loggers:
+            if isinstance(logger, WandbLogger):
+                return logger
+        return None
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        """Log visualization images at end of each epoch."""
+        if trainer.is_global_zero:
+            epoch = trainer.current_epoch
+
+            # Get the wandb logger to use its experiment for logging
+            wandb_logger = self._get_wandb_logger(trainer)
+
+            # Only do visualization work if wandb logger is available
+            if wandb_logger is not None:
+                # Get visualization data
+                train_data = self.train_viz_fn()
+                val_data = self.val_viz_fn()
+
+                # Render and log for each enabled mode
+                # Use the logger's experiment to let Lightning manage step tracking
+                log_dict = {}
+                for mode_name, renderer in self.renderers.items():
+                    suffix = "" if mode_name == "direct" else f"_{mode_name}"
+                    train_img = renderer.render(
+                        train_data, caption=f"Train Epoch {epoch}"
+                    )
+                    val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
+                    log_dict[f"viz/train/predictions{suffix}"] = train_img
+                    log_dict[f"viz/val/predictions{suffix}"] = val_img
+
+                if log_dict:
+                    # Include epoch so wandb can use it as x-axis (via define_metric)
+                    log_dict["epoch"] = epoch
+                    # Use commit=False to accumulate with other metrics in this step
+                    # Lightning will commit when it logs its own metrics
+                    wandb_logger.experiment.log(log_dict, commit=False)
+
+                # Optionally also log to table for backwards compat
+                if self.log_table and "direct" in self.renderers:
+                    train_img = self.renderers["direct"].render(
+                        train_data, caption=f"Train Epoch {epoch}"
+                    )
+                    val_img = self.renderers["direct"].render(
+                        val_data, caption=f"Val Epoch {epoch}"
+                    )
+                    table = wandb.Table(
+                        columns=["Epoch", "Train", "Validation"],
+                        data=[[epoch, train_img, val_img]],
+                    )
+                    wandb_logger.experiment.log(
+                        {"predictions_table": table}, commit=False
+                    )
+
+        # Sync all processes - barrier must be reached by ALL ranks
+        trainer.strategy.barrier()
+
+
+class WandBVizCallbackWithPAFs(WandBVizCallback):
+    """Extended WandBVizCallback that also logs PAF visualizations for bottom-up models."""
+
+    def __init__(
+        self,
+        train_viz_fn: Callable,
+        val_viz_fn: Callable,
+        train_pafs_viz_fn: Callable,
+        val_pafs_viz_fn: Callable,
+        viz_enabled: bool = True,
+        viz_boxes: bool = False,
+        viz_masks: bool = False,
+        box_size: float = 5.0,
+        confmap_threshold: float = 0.1,
+        log_table: bool = False,
+    ):
+        """Initialize the callback.
+
+        Args:
+            train_viz_fn: Callable returning VisualizationData for training sample.
+            val_viz_fn: Callable returning VisualizationData for validation sample.
+            train_pafs_viz_fn: Callable returning VisualizationData with PAFs for training.
+            val_pafs_viz_fn: Callable returning VisualizationData with PAFs for validation.
+            viz_enabled: If True, log pre-rendered matplotlib images.
+            viz_boxes: If True, log interactive keypoint boxes.
+            viz_masks: If True, log confidence map overlay masks.
+            box_size: Size of keypoint boxes in pixels.
+            confmap_threshold: Threshold for confmap mask generation.
+            log_table: If True, also log images to a wandb.Table.
+        """
+        super().__init__(
+            train_viz_fn=train_viz_fn,
+            val_viz_fn=val_viz_fn,
+            viz_enabled=viz_enabled,
+            viz_boxes=viz_boxes,
+            viz_masks=viz_masks,
+            box_size=box_size,
+            confmap_threshold=confmap_threshold,
+            log_table=log_table,
+        )
+        self.train_pafs_viz_fn = train_pafs_viz_fn
+        self.val_pafs_viz_fn = val_pafs_viz_fn
+
+        # Import here to avoid circular imports
+        from sleap_nn.training.utils import MatplotlibRenderer
+
+        self._mpl_renderer = MatplotlibRenderer()
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        """Log visualization images including PAFs at end of each epoch."""
+        if trainer.is_global_zero:
+            epoch = trainer.current_epoch
+
+            # Get the wandb logger to use its experiment for logging
+            wandb_logger = self._get_wandb_logger(trainer)
+
+            # Only do visualization work if wandb logger is available
+            if wandb_logger is not None:
+                # Get visualization data
+                train_data = self.train_viz_fn()
+                val_data = self.val_viz_fn()
+                train_pafs_data = self.train_pafs_viz_fn()
+                val_pafs_data = self.val_pafs_viz_fn()
+
+                # Render and log for each enabled mode
+                # Use the logger's experiment to let Lightning manage step tracking
+                log_dict = {}
+                for mode_name, renderer in self.renderers.items():
+                    suffix = "" if mode_name == "direct" else f"_{mode_name}"
+                    train_img = renderer.render(
+                        train_data, caption=f"Train Epoch {epoch}"
+                    )
+                    val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
+                    log_dict[f"viz/train/predictions{suffix}"] = train_img
+                    log_dict[f"viz/val/predictions{suffix}"] = val_img
+
+                # Render PAFs (always use matplotlib/direct for PAFs)
+                from io import BytesIO
+                import matplotlib.pyplot as plt
+                from PIL import Image
+
+                train_pafs_fig = self._mpl_renderer.render_pafs(train_pafs_data)
+                buf = BytesIO()
+                train_pafs_fig.savefig(
+                    buf, format="png", bbox_inches="tight", pad_inches=0
+                )
+                buf.seek(0)
+                plt.close(train_pafs_fig)
+                train_pafs_pil = Image.open(buf)
+                log_dict["viz/train/pafs"] = wandb.Image(
+                    train_pafs_pil, caption=f"Train PAFs Epoch {epoch}"
+                )
+
+                val_pafs_fig = self._mpl_renderer.render_pafs(val_pafs_data)
+                buf = BytesIO()
+                val_pafs_fig.savefig(
+                    buf, format="png", bbox_inches="tight", pad_inches=0
+                )
+                buf.seek(0)
+                plt.close(val_pafs_fig)
+                val_pafs_pil = Image.open(buf)
+                log_dict["viz/val/pafs"] = wandb.Image(
+                    val_pafs_pil, caption=f"Val PAFs Epoch {epoch}"
+                )
+
+                if log_dict:
+                    # Include epoch so wandb can use it as x-axis (via define_metric)
+                    log_dict["epoch"] = epoch
+                    # Use commit=False to accumulate with other metrics in this step
+                    # Lightning will commit when it logs its own metrics
+                    wandb_logger.experiment.log(log_dict, commit=False)
+
+                # Optionally also log to table
+                if self.log_table and "direct" in self.renderers:
+                    train_img = self.renderers["direct"].render(
+                        train_data, caption=f"Train Epoch {epoch}"
+                    )
+                    val_img = self.renderers["direct"].render(
+                        val_data, caption=f"Val Epoch {epoch}"
+                    )
+                    table = wandb.Table(
+                        columns=[
+                            "Epoch",
+                            "Train",
+                            "Validation",
+                            "Train PAFs",
+                            "Val PAFs",
+                        ],
+                        data=[
+                            [
+                                epoch,
+                                train_img,
+                                val_img,
+                                log_dict["viz/train/pafs"],
+                                log_dict["viz/val/pafs"],
+                            ]
+                        ],
+                    )
+                    wandb_logger.experiment.log(
+                        {"predictions_table": table}, commit=False
+                    )
+
+        # Sync all processes - barrier must be reached by ALL ranks
+        trainer.strategy.barrier()
+
+
+class UnifiedVizCallback(Callback):
+    """Unified callback for all visualization outputs during training.
+
+    This callback consolidates all visualization functionality into a single callback,
+    eliminating redundant dataset copies and inference runs. It handles:
+    - Local disk saving (matplotlib figures)
+    - WandB logging (multiple modes: direct, boxes, masks)
+    - Model-specific visualizations (PAFs for bottomup, class maps for multi_class_bottomup)
+
+    Benefits over separate callbacks:
+    - Uses ONE sample per epoch for all visualizations (no dataset deepcopy)
+    - Runs inference ONCE per sample (vs 4-8x in previous implementation)
+    - Outputs to multiple destinations from the same data
+    - Simpler code with less duplication
+
+    Attributes:
+        model_trainer: Reference to the ModelTrainer (for lazy access to lightning_model).
+        train_pipeline: Iterator over training visualization dataset.
+        val_pipeline: Iterator over validation visualization dataset.
+        model_type: Type of model (affects which visualizations are enabled).
+        save_local: Whether to save matplotlib figures to disk.
+        local_save_dir: Directory for local visualization saves.
+        log_wandb: Whether to log visualizations to wandb.
+        wandb_modes: List of wandb rendering modes ("direct", "boxes", "masks").
+        wandb_box_size: Size of keypoint boxes in pixels (for "boxes" mode).
+        wandb_confmap_threshold: Threshold for confmap masks (for "masks" mode).
+        log_wandb_table: Whether to also log to a wandb.Table.
+    """
+
+    def __init__(
+        self,
+        model_trainer,
+        train_dataset,
+        val_dataset,
+        model_type: str,
+        save_local: bool = True,
+        local_save_dir: Optional[Path] = None,
+        log_wandb: bool = False,
+        wandb_modes: Optional[list] = None,
+        wandb_box_size: float = 5.0,
+        wandb_confmap_threshold: float = 0.1,
+        log_wandb_table: bool = False,
+    ):
+        """Initialize the unified visualization callback.
+
+        Args:
+            model_trainer: ModelTrainer instance (lightning_model accessed lazily).
+            train_dataset: Training visualization dataset (will be cycled).
+            val_dataset: Validation visualization dataset (will be cycled).
+            model_type: Model type string (e.g., "bottomup", "multi_class_bottomup").
+            save_local: If True, save matplotlib figures to local_save_dir.
+            local_save_dir: Path to directory for saving visualization images.
+            log_wandb: If True, log visualizations to wandb.
+            wandb_modes: List of wandb rendering modes. Defaults to ["direct"].
+            wandb_box_size: Size of keypoint boxes in pixels.
+            wandb_confmap_threshold: Threshold for confidence map masks.
+            log_wandb_table: If True, also log to a wandb.Table.
+        """
+        super().__init__()
+        from itertools import cycle
+
+        self.model_trainer = model_trainer
+        self.train_pipeline = cycle(train_dataset)
+        self.val_pipeline = cycle(val_dataset)
+        self.model_type = model_type
+
+        # Local disk config
+        self.save_local = save_local
+        self.local_save_dir = local_save_dir
+
+        # WandB config
+        self.log_wandb = log_wandb
+        self.wandb_modes = wandb_modes or ["direct"]
+        self.wandb_box_size = wandb_box_size
+        self.wandb_confmap_threshold = wandb_confmap_threshold
+        self.log_wandb_table = log_wandb_table
+
+        # Auto-enable model-specific visualizations
+        self.viz_pafs = model_type == "bottomup"
+        self.viz_class_maps = model_type == "multi_class_bottomup"
+
+        # Initialize renderers
+        from sleap_nn.training.utils import MatplotlibRenderer, WandBRenderer
+
+        self._mpl_renderer = MatplotlibRenderer()
+
+        # Create wandb renderers for each enabled mode
+        self._wandb_renderers = {}
+        if log_wandb:
+            for mode in self.wandb_modes:
+                self._wandb_renderers[mode] = WandBRenderer(
+                    mode=mode,
+                    box_size=wandb_box_size,
+                    confmap_threshold=wandb_confmap_threshold,
+                )
+
+    def _get_wandb_logger(self, trainer):
+        """Get the WandbLogger from trainer's loggers."""
+        from lightning.pytorch.loggers import WandbLogger
+
+        for log in trainer.loggers:
+            if isinstance(log, WandbLogger):
+                return log
+        return None
+
+    def _get_viz_data(self, sample):
+        """Get visualization data with all needed fields based on model type.
+
+        Args:
+            sample: A sample from the visualization dataset.
+
+        Returns:
+            VisualizationData with appropriate fields populated.
+        """
+        # Build kwargs based on model type
+        kwargs = {}
+        if self.viz_pafs:
+            kwargs["include_pafs"] = True
+        if self.viz_class_maps:
+            kwargs["include_class_maps"] = True
+
+        # Access lightning_model lazily from model_trainer
+        return self.model_trainer.lightning_model.get_visualization_data(
+            sample, **kwargs
+        )
+
+    def _save_local_viz(self, data, prefix: str, epoch: int):
+        """Save visualization to local disk.
+
+        Args:
+            data: VisualizationData object.
+            prefix: Filename prefix (e.g., "train", "validation").
+            epoch: Current epoch number.
+        """
+        if not self.save_local or self.local_save_dir is None:
+            return
+
+        # Confmaps visualization
+        fig = self._mpl_renderer.render(data)
+        fig_path = self.local_save_dir / f"{prefix}.{epoch:04d}.png"
+        fig.savefig(fig_path, format="png")
+        plt.close(fig)
+
+        # PAFs visualization (for bottomup models)
+        if self.viz_pafs and data.pred_pafs is not None:
+            fig = self._mpl_renderer.render_pafs(data)
+            fig_path = self.local_save_dir / f"{prefix}.pafs_magnitude.{epoch:04d}.png"
+            fig.savefig(fig_path, format="png")
+            plt.close(fig)
+
+        # Class maps visualization (for multi_class_bottomup models)
+        if self.viz_class_maps and data.pred_class_maps is not None:
+            fig = self._render_class_maps(data)
+            fig_path = self.local_save_dir / f"{prefix}.class_maps.{epoch:04d}.png"
+            fig.savefig(fig_path, format="png")
+            plt.close(fig)
+
+    def _render_class_maps(self, data):
+        """Render class maps visualization.
+
+        Args:
+            data: VisualizationData with pred_class_maps populated.
+
+        Returns:
+            A matplotlib Figure object.
+        """
+        from sleap_nn.training.utils import plot_img, plot_confmaps
+
+        img = data.image
+        scale = 1.0
+        if img.shape[0] < 512:
+            scale = 2.0
+        if img.shape[0] < 256:
+            scale = 4.0
+
+        fig = plot_img(img, dpi=72 * scale, scale=scale)
+        plot_confmaps(
+            data.pred_class_maps,
+            output_scale=data.pred_class_maps.shape[0] / img.shape[0],
+        )
+        return fig
+
+    def _log_wandb_viz(self, data, prefix: str, epoch: int, wandb_logger):
+        """Log visualization to wandb.
+
+        Args:
+            data: VisualizationData object.
+            prefix: Log prefix (e.g., "train", "val").
+            epoch: Current epoch number.
+            wandb_logger: WandbLogger instance.
+        """
+        if not self.log_wandb or wandb_logger is None:
+            return
+
+        from io import BytesIO
+        from PIL import Image as PILImage
+
+        log_dict = {}
+
+        # Render confmaps for each enabled mode
+        for mode_name, renderer in self._wandb_renderers.items():
+            suffix = "" if mode_name == "direct" else f"_{mode_name}"
+            img = renderer.render(data, caption=f"{prefix.title()} Epoch {epoch}")
+            log_dict[f"viz/{prefix}/predictions{suffix}"] = img
+
+        # PAFs visualization (for bottomup models)
+        if self.viz_pafs and data.pred_pafs is not None:
+            pafs_fig = self._mpl_renderer.render_pafs(data)
+            buf = BytesIO()
+            pafs_fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
+            buf.seek(0)
+            plt.close(pafs_fig)
+            pafs_pil = PILImage.open(buf)
+            log_dict[f"viz/{prefix}/pafs"] = wandb.Image(
+                pafs_pil, caption=f"{prefix.title()} PAFs Epoch {epoch}"
+            )
+
+        # Class maps visualization (for multi_class_bottomup models)
+        if self.viz_class_maps and data.pred_class_maps is not None:
+            class_fig = self._render_class_maps(data)
+            buf = BytesIO()
+            class_fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
+            buf.seek(0)
+            plt.close(class_fig)
+            class_pil = PILImage.open(buf)
+            log_dict[f"viz/{prefix}/class_maps"] = wandb.Image(
+                class_pil, caption=f"{prefix.title()} Class Maps Epoch {epoch}"
+            )
+
+        if log_dict:
+            log_dict["epoch"] = epoch
+            wandb_logger.experiment.log(log_dict, commit=False)
+
+        # Optionally log to table for backwards compatibility
+        if self.log_wandb_table and "direct" in self._wandb_renderers:
+            train_img = self._wandb_renderers["direct"].render(
+                data, caption=f"{prefix.title()} Epoch {epoch}"
+            )
+            table_data = [[epoch, train_img]]
+            columns = ["Epoch", prefix.title()]
+
+            if self.viz_pafs and data.pred_pafs is not None:
+                columns.append(f"{prefix.title()} PAFs")
+                table_data[0].append(log_dict.get(f"viz/{prefix}/pafs"))
+
+            if self.viz_class_maps and data.pred_class_maps is not None:
+                columns.append(f"{prefix.title()} Class Maps")
+                table_data[0].append(log_dict.get(f"viz/{prefix}/class_maps"))
+
+            table = wandb.Table(columns=columns, data=table_data)
+            wandb_logger.experiment.log(
+                {f"predictions_table_{prefix}": table}, commit=False
+            )
+
+    def on_train_epoch_end(self, trainer, pl_module):
+        """Generate and output all visualizations at epoch end.
+
+        Args:
+            trainer: PyTorch Lightning trainer.
+            pl_module: Lightning module (not used, we use self.lightning_module).
+        """
+        if trainer.is_global_zero:
+            epoch = trainer.current_epoch
+            wandb_logger = self._get_wandb_logger(trainer) if self.log_wandb else None
+
+            # Get ONE sample for train visualization
+            train_sample = next(self.train_pipeline)
+            # Run inference ONCE with all needed data
+            train_data = self._get_viz_data(train_sample)
+            # Output to all destinations
+            self._save_local_viz(train_data, "train", epoch)
+            self._log_wandb_viz(train_data, "train", epoch, wandb_logger)
+
+            # Same for validation
+            val_sample = next(self.val_pipeline)
+            val_data = self._get_viz_data(val_sample)
+            self._save_local_viz(val_data, "validation", epoch)
+            self._log_wandb_viz(val_data, "val", epoch, wandb_logger)
+
+        # Sync all processes - barrier must be reached by ALL ranks
+        trainer.strategy.barrier()
+
+
 class MatplotlibSaver(Callback):
     """Callback for saving images rendered with matplotlib during training.
 
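
All three callbacks above push matplotlib figures to wandb through the same in-memory round trip: render the figure, save it as PNG into a buffer, reopen with PIL, and wrap it in `wandb.Image`. A self-contained sketch of just that step (the caption is illustrative):

```python
from io import BytesIO

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import wandb
from PIL import Image

fig, ax = plt.subplots()
ax.imshow([[0.0, 1.0], [1.0, 0.0]])  # stand-in for a rendered PAF magnitude map
buf = BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
buf.seek(0)
plt.close(fig)
image = wandb.Image(Image.open(buf), caption="Val PAFs Epoch 0")
```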
@@ -194,7 +808,7 @@ class MatplotlibSaver(Callback):
         ).as_posix()
 
         # Save rendered figure.
-        figure.savefig(figure_path, format="png", pad_inches=0)
+        figure.savefig(figure_path, format="png")
         plt.close(figure)
 
         # Sync all processes after file I/O
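
Dropping `pad_inches=0` here should be behavior-preserving: matplotlib applies `pad_inches` only when `bbox_inches="tight"` is also passed, and this call never passed it. A quick sketch contrasting the two call styles:

```python
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1], [1, 0])
fig.savefig("full_canvas.png", format="png")  # fixed canvas; pad_inches ignored here
fig.savefig("tight.png", format="png", bbox_inches="tight", pad_inches=0)  # tight crop
plt.close(fig)
```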
@@ -303,7 +917,11 @@ class ProgressReporterZMQ(Callback):
     def on_train_start(self, trainer, pl_module):
         """Called at the beginning of training process."""
         if trainer.is_global_zero:
-            self.send("train_begin")
+            # Include WandB URL if available
+            wandb_url = None
+            if wandb.run is not None:
+                wandb_url = wandb.run.url
+            self.send("train_begin", wandb_url=wandb_url)
         trainer.strategy.barrier()
 
     def on_train_end(self, trainer, pl_module):
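
On the consumer side, a hypothetical listener for the enriched `train_begin` event could look like the sketch below. The socket address and the payload keys (`event`, `wandb_url`) are assumptions; the actual schema comes from `ProgressReporterZMQ.send`, which this hunk does not show:

```python
import jsonpickle
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.SUB)
sock.setsockopt_string(zmq.SUBSCRIBE, "")
sock.connect("tcp://127.0.0.1:9001")  # address/port are assumptions

msg = jsonpickle.decode(sock.recv_string())  # blocks until the trainer publishes
if msg.get("event") == "train_begin" and msg.get("wandb_url"):
    print(f"Training started; follow the run at {msg['wandb_url']}")
```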
@@ -350,3 +968,638 @@ class ProgressReporterZMQ(Callback):
         return {
             k: float(v.item()) if hasattr(v, "item") else v for k, v in logs.items()
         }
+
+
+class EpochEndEvaluationCallback(Callback):
+    """Callback to run full evaluation metrics at end of validation epochs.
+
+    This callback collects predictions and ground truth during validation,
+    then runs the full evaluation pipeline (OKS, mAP, PCK, etc.) and logs
+    metrics to WandB.
+
+    Attributes:
+        skeleton: sio.Skeleton for creating instances.
+        videos: List of sio.Video objects.
+        eval_frequency: Run evaluation every N epochs (default: 1).
+        oks_stddev: OKS standard deviation (default: 0.025).
+        oks_scale: Optional OKS scale override.
+        metrics_to_log: List of metric keys to log.
+    """
+
+    def __init__(
+        self,
+        skeleton: "sio.Skeleton",
+        videos: list,
+        eval_frequency: int = 1,
+        oks_stddev: float = 0.025,
+        oks_scale: Optional[float] = None,
+        metrics_to_log: Optional[list] = None,
+    ):
+        """Initialize the callback.
+
+        Args:
+            skeleton: sio.Skeleton for creating instances.
+            videos: List of sio.Video objects.
+            eval_frequency: Run evaluation every N epochs (default: 1).
+            oks_stddev: OKS standard deviation (default: 0.025).
+            oks_scale: Optional OKS scale override.
+            metrics_to_log: List of metric keys to log. If None, logs all available.
+        """
+        super().__init__()
+        self.skeleton = skeleton
+        self.videos = videos
+        self.eval_frequency = eval_frequency
+        self.oks_stddev = oks_stddev
+        self.oks_scale = oks_scale
+        self.metrics_to_log = metrics_to_log or [
+            "mOKS",
+            "oks_voc.mAP",
+            "oks_voc.mAR",
+            "distance/avg",
+            "distance/p50",
+            "distance/p95",
+            "distance/p99",
+            "mPCK",
+            "PCK@5",
+            "PCK@10",
+            "visibility_precision",
+            "visibility_recall",
+        ]
+
+    def on_validation_epoch_start(self, trainer, pl_module):
+        """Enable prediction collection at the start of validation.
+
+        Skip during sanity check to avoid inference issues.
+        """
+        if trainer.sanity_checking:
+            return
+        pl_module._collect_val_predictions = True
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        """Run evaluation and log metrics at end of validation epoch."""
+        import sleap_io as sio
+        import numpy as np
+        from lightning.pytorch.loggers import WandbLogger
+        from sleap_nn.evaluation import Evaluator
+
+        # Determine if we should run evaluation this epoch (only on rank 0)
+        should_evaluate = (
+            trainer.current_epoch + 1
+        ) % self.eval_frequency == 0 and trainer.is_global_zero
+
+        if should_evaluate:
+            # Check if we have predictions
+            if not pl_module.val_predictions or not pl_module.val_ground_truth:
+                logger.warning("No predictions collected for epoch-end evaluation")
+            else:
+                try:
+                    # Build sio.Labels from accumulated predictions and ground truth
+                    pred_labels = self._build_pred_labels(
+                        pl_module.val_predictions, sio, np
+                    )
+                    gt_labels = self._build_gt_labels(
+                        pl_module.val_ground_truth, sio, np
+                    )
+
+                    # Check if we have valid frames to evaluate
+                    if len(pred_labels) == 0:
+                        logger.warning(
+                            "No valid predictions for epoch-end evaluation "
+                            "(all predictions may be empty or NaN)"
+                        )
+                    else:
+                        # Run evaluation
+                        evaluator = Evaluator(
+                            ground_truth_instances=gt_labels,
+                            predicted_instances=pred_labels,
+                            oks_stddev=self.oks_stddev,
+                            oks_scale=self.oks_scale,
+                            user_labels_only=False,  # All validation frames are "user" frames
+                        )
+                        metrics = evaluator.evaluate()
+
+                        # Log to WandB
+                        self._log_metrics(trainer, metrics, trainer.current_epoch)
+
+                        logger.info(
+                            f"Epoch {trainer.current_epoch} evaluation: "
+                            f"PCK@5={metrics['pck_metrics']['PCK@5']:.4f}, "
+                            f"mOKS={metrics['mOKS']['mOKS']:.4f}, "
+                            f"mAP={metrics['voc_metrics']['oks_voc.mAP']:.4f}"
+                        )
+
+                except Exception as e:
+                    logger.warning(f"Epoch-end evaluation failed: {e}")
+
+        # Cleanup - all ranks reset the flag, rank 0 clears the lists
+        pl_module._collect_val_predictions = False
+        if trainer.is_global_zero:
+            pl_module.val_predictions = []
+            pl_module.val_ground_truth = []
+
+        # Sync all processes - barrier must be reached by ALL ranks
+        trainer.strategy.barrier()
+
+    def _build_pred_labels(self, predictions: list, sio, np) -> "sio.Labels":
+        """Convert prediction dicts to sio.Labels."""
+        labeled_frames = []
+        for pred in predictions:
+            pred_peaks = pred["pred_peaks"]
+            pred_scores = pred["pred_scores"]
+
+            # Handle NaN/missing predictions
+            if pred_peaks is None or (
+                isinstance(pred_peaks, np.ndarray) and np.isnan(pred_peaks).all()
+            ):
+                continue
+
+            # Handle multi-instance predictions (bottomup)
+            if len(pred_peaks.shape) == 2:
+                # Single instance: (n_nodes, 2) -> (1, n_nodes, 2)
+                pred_peaks = pred_peaks.reshape(1, -1, 2)
+                pred_scores = pred_scores.reshape(1, -1)
+
+            instances = []
+            for inst_idx in range(len(pred_peaks)):
+                inst_points = pred_peaks[inst_idx]
+                inst_scores = pred_scores[inst_idx] if pred_scores is not None else None
+
+                # Skip if all NaN
+                if np.isnan(inst_points).all():
+                    continue
+
+                inst = sio.PredictedInstance.from_numpy(
+                    points_data=inst_points,
+                    skeleton=self.skeleton,
+                    point_scores=(
+                        inst_scores
+                        if inst_scores is not None
+                        else np.ones(len(inst_points))
+                    ),
+                    score=(
+                        float(np.nanmean(inst_scores))
+                        if inst_scores is not None
+                        else 1.0
+                    ),
+                )
+                instances.append(inst)
+
+            if instances:
+                lf = sio.LabeledFrame(
+                    video=self.videos[pred["video_idx"]],
+                    frame_idx=pred["frame_idx"],
+                    instances=instances,
+                )
+                labeled_frames.append(lf)
+
+        return sio.Labels(
+            videos=self.videos,
+            skeletons=[self.skeleton],
+            labeled_frames=labeled_frames,
+        )
+
+    def _build_gt_labels(self, ground_truth: list, sio, np) -> "sio.Labels":
+        """Convert ground truth dicts to sio.Labels."""
+        labeled_frames = []
+        for gt in ground_truth:
+            instances = []
+            gt_instances = gt["gt_instances"]
+
+            # Handle shape variations
+            if len(gt_instances.shape) == 2:
+                # (n_nodes, 2) -> (1, n_nodes, 2)
+                gt_instances = gt_instances.reshape(1, -1, 2)
+
+            for i in range(min(gt["num_instances"], len(gt_instances))):
+                inst_data = gt_instances[i]
+                if np.isnan(inst_data).all():
+                    continue
+                inst = sio.Instance.from_numpy(
+                    points_data=inst_data,
+                    skeleton=self.skeleton,
+                )
+                instances.append(inst)
+
+            if instances:
+                lf = sio.LabeledFrame(
+                    video=self.videos[gt["video_idx"]],
+                    frame_idx=gt["frame_idx"],
+                    instances=instances,
+                )
+                labeled_frames.append(lf)
+
+        return sio.Labels(
+            videos=self.videos,
+            skeletons=[self.skeleton],
+            labeled_frames=labeled_frames,
+        )
+
+    def _log_metrics(self, trainer, metrics: dict, epoch: int):
+        """Log evaluation metrics to WandB."""
+        import numpy as np
+        from lightning.pytorch.loggers import WandbLogger
+
+        # Get WandB logger
+        wandb_logger = None
+        for log in trainer.loggers:
+            if isinstance(log, WandbLogger):
+                wandb_logger = log
+                break
+
+        if wandb_logger is None:
+            return
+
+        log_dict = {"epoch": epoch}
+
+        # Extract key metrics with consistent naming
+        # All eval metrics use eval/val/ prefix since they're computed on validation data
+        if "mOKS" in self.metrics_to_log:
+            log_dict["eval/val/mOKS"] = metrics["mOKS"]["mOKS"]
+
+        if "oks_voc.mAP" in self.metrics_to_log:
+            log_dict["eval/val/oks_voc_mAP"] = metrics["voc_metrics"]["oks_voc.mAP"]
+
+        if "oks_voc.mAR" in self.metrics_to_log:
+            log_dict["eval/val/oks_voc_mAR"] = metrics["voc_metrics"]["oks_voc.mAR"]
+
+        # Distance metrics grouped under eval/val/distance/
+        if "distance/avg" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["avg"]
+            if not np.isnan(val):
+                log_dict["eval/val/distance/avg"] = val
+
+        if "distance/p50" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["p50"]
+            if not np.isnan(val):
+                log_dict["eval/val/distance/p50"] = val
+
+        if "distance/p95" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["p95"]
+            if not np.isnan(val):
+                log_dict["eval/val/distance/p95"] = val
+
+        if "distance/p99" in self.metrics_to_log:
+            val = metrics["distance_metrics"]["p99"]
+            if not np.isnan(val):
+                log_dict["eval/val/distance/p99"] = val
+
+        # PCK metrics
+        if "mPCK" in self.metrics_to_log:
+            log_dict["eval/val/mPCK"] = metrics["pck_metrics"]["mPCK"]
+
+        # PCK at specific thresholds (precomputed in evaluation.py)
+        if "PCK@5" in self.metrics_to_log:
+            log_dict["eval/val/PCK_5"] = metrics["pck_metrics"]["PCK@5"]
+
+        if "PCK@10" in self.metrics_to_log:
+            log_dict["eval/val/PCK_10"] = metrics["pck_metrics"]["PCK@10"]
+
+        # Visibility metrics
+        if "visibility_precision" in self.metrics_to_log:
+            val = metrics["visibility_metrics"]["precision"]
+            if not np.isnan(val):
+                log_dict["eval/val/visibility_precision"] = val
+
+        if "visibility_recall" in self.metrics_to_log:
+            val = metrics["visibility_metrics"]["recall"]
+            if not np.isnan(val):
+                log_dict["eval/val/visibility_recall"] = val
+
+        wandb_logger.experiment.log(log_dict, commit=False)
+
+        # Update best metrics in summary (excluding epoch)
+        for key, value in log_dict.items():
+            if key == "epoch":
+                continue
+            # Create summary key like "best/eval/val/mOKS"
+            summary_key = f"best/{key}"
+            current_best = wandb_logger.experiment.summary.get(summary_key)
+            # For distance metrics, lower is better; for others, higher is better
+            is_distance = "distance" in key
+            if current_best is None:
+                wandb_logger.experiment.summary[summary_key] = value
+            elif is_distance and value < current_best:
+                wandb_logger.experiment.summary[summary_key] = value
+            elif not is_distance and value > current_best:
+                wandb_logger.experiment.summary[summary_key] = value
+
+
+def match_centroids(
+    pred_centroids: "np.ndarray",
+    gt_centroids: "np.ndarray",
+    max_distance: float = 50.0,
+) -> tuple:
+    """Match predicted centroids to ground truth using Hungarian algorithm.
+
+    Args:
+        pred_centroids: Predicted centroid locations, shape (n_pred, 2).
+        gt_centroids: Ground truth centroid locations, shape (n_gt, 2).
+        max_distance: Maximum distance threshold for valid matches (in pixels).
+
+    Returns:
+        Tuple of:
+        - matched_pred_indices: Indices of matched predictions
+        - matched_gt_indices: Indices of matched ground truth
+        - unmatched_pred_indices: Indices of unmatched predictions (false positives)
+        - unmatched_gt_indices: Indices of unmatched ground truth (false negatives)
+    """
+    import numpy as np
+    from scipy.optimize import linear_sum_assignment
+    from scipy.spatial.distance import cdist
+
+    n_pred = len(pred_centroids)
+    n_gt = len(gt_centroids)
+
+    # Handle edge cases
+    if n_pred == 0 and n_gt == 0:
+        return np.array([]), np.array([]), np.array([]), np.array([])
+    if n_pred == 0:
+        return np.array([]), np.array([]), np.array([]), np.arange(n_gt)
+    if n_gt == 0:
+        return np.array([]), np.array([]), np.arange(n_pred), np.array([])
+
+    # Compute pairwise distances
+    cost_matrix = cdist(pred_centroids, gt_centroids)
+
+    # Run Hungarian algorithm for optimal matching
+    pred_indices, gt_indices = linear_sum_assignment(cost_matrix)
+
+    # Filter matches that exceed max_distance
+    matched_pred = []
+    matched_gt = []
+    for p_idx, g_idx in zip(pred_indices, gt_indices):
+        if cost_matrix[p_idx, g_idx] <= max_distance:
+            matched_pred.append(p_idx)
+            matched_gt.append(g_idx)
+
+    matched_pred = np.array(matched_pred)
+    matched_gt = np.array(matched_gt)
+
+    # Find unmatched indices
+    all_pred = set(range(n_pred))
+    all_gt = set(range(n_gt))
+    unmatched_pred = np.array(list(all_pred - set(matched_pred)))
+    unmatched_gt = np.array(list(all_gt - set(matched_gt)))
+
+    return matched_pred, matched_gt, unmatched_pred, unmatched_gt
+
+
+class CentroidEvaluationCallback(Callback):
+    """Callback to run centroid-specific evaluation metrics at end of validation epochs.
+
+    This callback is designed specifically for centroid models, which predict a single
+    point (centroid) per instance rather than full pose skeletons. It computes
+    distance-based metrics and detection metrics that are more appropriate for
+    point detection tasks than OKS/PCK metrics.
+
+    Metrics computed:
+    - Distance metrics: mean, median, p90, p95, max Euclidean distance
+    - Detection metrics: precision, recall, F1 score
+    - Counts: true positives, false positives, false negatives
+
+    Attributes:
+        videos: List of sio.Video objects.
+        eval_frequency: Run evaluation every N epochs (default: 1).
+        match_threshold: Maximum distance (pixels) for matching pred to GT (default: 50.0).
+    """
+
+    def __init__(
+        self,
+        videos: list,
+        eval_frequency: int = 1,
+        match_threshold: float = 50.0,
+    ):
+        """Initialize the callback.
+
+        Args:
+            videos: List of sio.Video objects.
+            eval_frequency: Run evaluation every N epochs (default: 1).
+            match_threshold: Maximum distance in pixels for a prediction to be
+                considered a match to a ground truth centroid (default: 50.0).
+        """
+        super().__init__()
+        self.videos = videos
+        self.eval_frequency = eval_frequency
+        self.match_threshold = match_threshold
+
+    def on_validation_epoch_start(self, trainer, pl_module):
+        """Enable prediction collection at the start of validation.
+
+        Skip during sanity check to avoid inference issues.
+        """
+        if trainer.sanity_checking:
+            return
+        pl_module._collect_val_predictions = True
+
+    def on_validation_epoch_end(self, trainer, pl_module):
+        """Run centroid evaluation and log metrics at end of validation epoch."""
+        import numpy as np
+        from lightning.pytorch.loggers import WandbLogger
+
+        # Determine if we should run evaluation this epoch (only on rank 0)
+        should_evaluate = (
+            trainer.current_epoch + 1
+        ) % self.eval_frequency == 0 and trainer.is_global_zero
+
+        if should_evaluate:
+            # Check if we have predictions
+            if not pl_module.val_predictions or not pl_module.val_ground_truth:
+                logger.warning(
+                    "No predictions collected for centroid epoch-end evaluation"
+                )
+            else:
+                try:
+                    metrics = self._compute_metrics(
+                        pl_module.val_predictions, pl_module.val_ground_truth, np
+                    )
+
+                    # Log to WandB
+                    self._log_metrics(trainer, metrics, trainer.current_epoch)
+
+                    logger.info(
+                        f"Epoch {trainer.current_epoch} centroid evaluation: "
+                        f"precision={metrics['precision']:.4f}, "
+                        f"recall={metrics['recall']:.4f}, "
+                        f"dist_avg={metrics['dist_avg']:.2f}px"
+                    )
+
+                except Exception as e:
+                    logger.warning(f"Centroid epoch-end evaluation failed: {e}")
+
+        # Cleanup - all ranks reset the flag, rank 0 clears the lists
+        pl_module._collect_val_predictions = False
+        if trainer.is_global_zero:
+            pl_module.val_predictions = []
+            pl_module.val_ground_truth = []
+
+        # Sync all processes - barrier must be reached by ALL ranks
+        trainer.strategy.barrier()
+
+    def _compute_metrics(self, predictions: list, ground_truth: list, np) -> dict:
+        """Compute centroid-specific metrics.
+
+        Args:
+            predictions: List of prediction dicts with "pred_peaks" key.
+            ground_truth: List of ground truth dicts with "gt_instances" key.
+            np: NumPy module.
+
+        Returns:
+            Dictionary of computed metrics.
+        """
+        all_distances = []
+        total_tp = 0
+        total_fp = 0
+        total_fn = 0
+
+        # Group predictions and GT by frame
+        pred_by_frame = {}
+        for pred in predictions:
+            key = (pred["video_idx"], pred["frame_idx"])
+            if key not in pred_by_frame:
+                pred_by_frame[key] = []
+            # pred_peaks shape: (n_inst, 1, 2) -> extract centroids as (n_inst, 2)
+            centroids = pred["pred_peaks"].reshape(-1, 2)
+            # Filter out NaN centroids
+            valid_mask = ~np.isnan(centroids).any(axis=1)
+            pred_by_frame[key].append(centroids[valid_mask])
+
+        gt_by_frame = {}
+        for gt in ground_truth:
+            key = (gt["video_idx"], gt["frame_idx"])
+            if key not in gt_by_frame:
+                gt_by_frame[key] = []
+            # gt_instances shape: (n_inst, 1, 2) -> extract centroids as (n_inst, 2)
+            centroids = gt["gt_instances"].reshape(-1, 2)
+            # Filter out NaN centroids
+            valid_mask = ~np.isnan(centroids).any(axis=1)
+            gt_by_frame[key].append(centroids[valid_mask])
+
+        # Process each frame
+        all_frames = set(pred_by_frame.keys()) | set(gt_by_frame.keys())
+        for frame_key in all_frames:
+            # Concatenate all predictions for this frame
+            if frame_key in pred_by_frame:
+                frame_preds = np.concatenate(pred_by_frame[frame_key], axis=0)
+            else:
+                frame_preds = np.zeros((0, 2))
+
+            # Concatenate all GT for this frame
+            if frame_key in gt_by_frame:
+                frame_gt = np.concatenate(gt_by_frame[frame_key], axis=0)
+            else:
+                frame_gt = np.zeros((0, 2))
+
+            # Match predictions to ground truth
+            matched_pred, matched_gt, unmatched_pred, unmatched_gt = match_centroids(
+                frame_preds, frame_gt, max_distance=self.match_threshold
+            )
+
+            # Compute distances for matched pairs
+            if len(matched_pred) > 0:
+                matched_pred_points = frame_preds[matched_pred]
+                matched_gt_points = frame_gt[matched_gt]
+                distances = np.linalg.norm(
+                    matched_pred_points - matched_gt_points, axis=1
+                )
+                all_distances.extend(distances.tolist())
+
+            # Update counts
+            total_tp += len(matched_pred)
+            total_fp += len(unmatched_pred)
+            total_fn += len(unmatched_gt)
+
+        # Compute aggregate metrics
+        all_distances = np.array(all_distances)
+
+        # Distance metrics (only if we have matches)
+        if len(all_distances) > 0:
+            dist_avg = float(np.mean(all_distances))
+            dist_median = float(np.median(all_distances))
+            dist_p90 = float(np.percentile(all_distances, 90))
+            dist_p95 = float(np.percentile(all_distances, 95))
+            dist_max = float(np.max(all_distances))
+        else:
+            dist_avg = dist_median = dist_p90 = dist_p95 = dist_max = float("nan")
+
+        # Detection metrics
+        precision = (
+            total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
+        )
+        recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
+        f1 = (
+            2 * precision * recall / (precision + recall)
+            if (precision + recall) > 0
+            else 0.0
+        )
+
+        return {
+            "dist_avg": dist_avg,
+            "dist_median": dist_median,
+            "dist_p90": dist_p90,
+            "dist_p95": dist_p95,
+            "dist_max": dist_max,
+            "precision": precision,
+            "recall": recall,
+            "f1": f1,
+            "n_true_positives": total_tp,
+            "n_false_positives": total_fp,
+            "n_false_negatives": total_fn,
+            "n_total_predictions": total_tp + total_fp,
+            "n_total_ground_truth": total_tp + total_fn,
+        }
+
+    def _log_metrics(self, trainer, metrics: dict, epoch: int):
+        """Log centroid evaluation metrics to WandB."""
+        import numpy as np
+        from lightning.pytorch.loggers import WandbLogger
+
+        # Get WandB logger
+        wandb_logger = None
+        for log in trainer.loggers:
+            if isinstance(log, WandbLogger):
+                wandb_logger = log
+                break
+
+        if wandb_logger is None:
+            return
+
+        log_dict = {"epoch": epoch}
+
+        # Distance metrics (with NaN handling)
+        if not np.isnan(metrics["dist_avg"]):
+            log_dict["eval/val/centroid_dist_avg"] = metrics["dist_avg"]
+        if not np.isnan(metrics["dist_median"]):
+            log_dict["eval/val/centroid_dist_median"] = metrics["dist_median"]
+        if not np.isnan(metrics["dist_p90"]):
+            log_dict["eval/val/centroid_dist_p90"] = metrics["dist_p90"]
+        if not np.isnan(metrics["dist_p95"]):
+            log_dict["eval/val/centroid_dist_p95"] = metrics["dist_p95"]
+        if not np.isnan(metrics["dist_max"]):
+            log_dict["eval/val/centroid_dist_max"] = metrics["dist_max"]
+
+        # Detection metrics
+        log_dict["eval/val/centroid_precision"] = metrics["precision"]
+        log_dict["eval/val/centroid_recall"] = metrics["recall"]
+        log_dict["eval/val/centroid_f1"] = metrics["f1"]
+
+        # Counts
+        log_dict["eval/val/centroid_n_tp"] = metrics["n_true_positives"]
+        log_dict["eval/val/centroid_n_fp"] = metrics["n_false_positives"]
+        log_dict["eval/val/centroid_n_fn"] = metrics["n_false_negatives"]
+
+        wandb_logger.experiment.log(log_dict, commit=False)
+
+        # Update best metrics in summary
+        for key, value in log_dict.items():
+            if key == "epoch":
+                continue
+            summary_key = f"best/{key}"
+            current_best = wandb_logger.experiment.summary.get(summary_key)
+            # For distance metrics, lower is better; for others, higher is better
+            is_distance = "dist" in key
+            if current_best is None:
+                wandb_logger.experiment.summary[summary_key] = value
+            elif is_distance and value < current_best:
+                wandb_logger.experiment.summary[summary_key] = value
+            elif not is_distance and value > current_best:
+                wandb_logger.experiment.summary[summary_key] = value
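
A quick self-contained check of the matching semantics `match_centroids` builds on: `linear_sum_assignment` accepts rectangular cost matrices, and the distance gate afterwards turns leftover predictions into false positives. The test points are made up:

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

pred = np.array([[10.0, 10.0], [100.0, 100.0], [300.0, 300.0]])
gt = np.array([[12.0, 11.0], [105.0, 98.0]])

cost = cdist(pred, gt)  # (3, 2) pairwise Euclidean distances
rows, cols = linear_sum_assignment(cost)  # optimal one-to-one assignment
matches = [(int(r), int(c)) for r, c in zip(rows, cols) if cost[r, c] <= 50.0]
print(matches)  # [(0, 0), (1, 1)]; prediction 2 is left unmatched (false positive)
```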