sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/architectures/convnext.py +0 -5
- sleap_nn/architectures/encoder_decoder.py +6 -25
- sleap_nn/architectures/swint.py +0 -8
- sleap_nn/cli.py +60 -364
- sleap_nn/config/data_config.py +5 -11
- sleap_nn/config/get_config.py +4 -5
- sleap_nn/config/trainer_config.py +0 -71
- sleap_nn/data/augmentation.py +241 -50
- sleap_nn/data/custom_datasets.py +34 -364
- sleap_nn/data/instance_cropping.py +1 -1
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/utils.py +17 -135
- sleap_nn/evaluation.py +22 -81
- sleap_nn/inference/bottomup.py +20 -86
- sleap_nn/inference/peak_finding.py +19 -88
- sleap_nn/inference/predictors.py +117 -224
- sleap_nn/legacy_models.py +11 -65
- sleap_nn/predict.py +9 -37
- sleap_nn/train.py +4 -69
- sleap_nn/training/callbacks.py +105 -1046
- sleap_nn/training/lightning_modules.py +65 -602
- sleap_nn/training/model_trainer.py +204 -201
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA +3 -15
- sleap_nn-0.1.0a1.dist-info/RECORD +65 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/WHEEL +1 -1
- sleap_nn/data/skia_augmentation.py +0 -414
- sleap_nn/export/__init__.py +0 -21
- sleap_nn/export/cli.py +0 -1778
- sleap_nn/export/exporters/__init__.py +0 -51
- sleap_nn/export/exporters/onnx_exporter.py +0 -80
- sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
- sleap_nn/export/metadata.py +0 -225
- sleap_nn/export/predictors/__init__.py +0 -63
- sleap_nn/export/predictors/base.py +0 -22
- sleap_nn/export/predictors/onnx.py +0 -154
- sleap_nn/export/predictors/tensorrt.py +0 -312
- sleap_nn/export/utils.py +0 -307
- sleap_nn/export/wrappers/__init__.py +0 -25
- sleap_nn/export/wrappers/base.py +0 -96
- sleap_nn/export/wrappers/bottomup.py +0 -243
- sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
- sleap_nn/export/wrappers/centered_instance.py +0 -56
- sleap_nn/export/wrappers/centroid.py +0 -58
- sleap_nn/export/wrappers/single_instance.py +0 -83
- sleap_nn/export/wrappers/topdown.py +0 -180
- sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
- sleap_nn/inference/postprocessing.py +0 -284
- sleap_nn/training/schedulers.py +0 -191
- sleap_nn-0.1.0.dist-info/RECORD +0 -88
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/top_level.txt +0 -0
sleap_nn/training/callbacks.py
CHANGED
|
@@ -85,15 +85,10 @@ class CSVLoggerCallback(Callback):
|
|
|
85
85
|
if key == "epoch":
|
|
86
86
|
log_data["epoch"] = trainer.current_epoch
|
|
87
87
|
elif key == "learning_rate":
|
|
88
|
-
# Handle
|
|
89
|
-
# 1. Direct "learning_rate" key
|
|
90
|
-
# 2. "train/lr" key (current format from lightning modules)
|
|
91
|
-
# 3. "lr-*" keys from LearningRateMonitor (legacy)
|
|
88
|
+
# Handle both direct logging and LearningRateMonitor format (lr-*)
|
|
92
89
|
value = metrics.get(key, None)
|
|
93
90
|
if value is None:
|
|
94
|
-
|
|
95
|
-
if value is None:
|
|
96
|
-
# Look for lr-* keys from LearningRateMonitor (legacy)
|
|
91
|
+
# Look for lr-* keys from LearningRateMonitor
|
|
97
92
|
for metric_key in metrics.keys():
|
|
98
93
|
if metric_key.startswith("lr-"):
|
|
99
94
|
value = metrics[metric_key]
|
|
@@ -286,49 +281,45 @@ class WandBVizCallback(Callback):
|
|
|
286
281
|
|
|
287
282
|
# Get the wandb logger to use its experiment for logging
|
|
288
283
|
wandb_logger = self._get_wandb_logger(trainer)
|
|
284
|
+
if wandb_logger is None:
|
|
285
|
+
return # No wandb logger, skip visualization logging
|
|
286
|
+
|
|
287
|
+
# Get visualization data
|
|
288
|
+
train_data = self.train_viz_fn()
|
|
289
|
+
val_data = self.val_viz_fn()
|
|
290
|
+
|
|
291
|
+
# Render and log for each enabled mode
|
|
292
|
+
# Use the logger's experiment to let Lightning manage step tracking
|
|
293
|
+
log_dict = {}
|
|
294
|
+
for mode_name, renderer in self.renderers.items():
|
|
295
|
+
suffix = "" if mode_name == "direct" else f"_{mode_name}"
|
|
296
|
+
train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
|
|
297
|
+
val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
|
|
298
|
+
log_dict[f"train_predictions{suffix}"] = train_img
|
|
299
|
+
log_dict[f"val_predictions{suffix}"] = val_img
|
|
300
|
+
|
|
301
|
+
if log_dict:
|
|
302
|
+
# Include epoch so wandb can use it as x-axis (via define_metric)
|
|
303
|
+
log_dict["epoch"] = epoch
|
|
304
|
+
# Use commit=False to accumulate with other metrics in this step
|
|
305
|
+
# Lightning will commit when it logs its own metrics
|
|
306
|
+
wandb_logger.experiment.log(log_dict, commit=False)
|
|
307
|
+
|
|
308
|
+
# Optionally also log to table for backwards compat
|
|
309
|
+
if self.log_table and "direct" in self.renderers:
|
|
310
|
+
train_img = self.renderers["direct"].render(
|
|
311
|
+
train_data, caption=f"Train Epoch {epoch}"
|
|
312
|
+
)
|
|
313
|
+
val_img = self.renderers["direct"].render(
|
|
314
|
+
val_data, caption=f"Val Epoch {epoch}"
|
|
315
|
+
)
|
|
316
|
+
table = wandb.Table(
|
|
317
|
+
columns=["Epoch", "Train", "Validation"],
|
|
318
|
+
data=[[epoch, train_img, val_img]],
|
|
319
|
+
)
|
|
320
|
+
wandb_logger.experiment.log({"predictions_table": table}, commit=False)
|
|
289
321
|
|
|
290
|
-
|
|
291
|
-
if wandb_logger is not None:
|
|
292
|
-
# Get visualization data
|
|
293
|
-
train_data = self.train_viz_fn()
|
|
294
|
-
val_data = self.val_viz_fn()
|
|
295
|
-
|
|
296
|
-
# Render and log for each enabled mode
|
|
297
|
-
# Use the logger's experiment to let Lightning manage step tracking
|
|
298
|
-
log_dict = {}
|
|
299
|
-
for mode_name, renderer in self.renderers.items():
|
|
300
|
-
suffix = "" if mode_name == "direct" else f"_{mode_name}"
|
|
301
|
-
train_img = renderer.render(
|
|
302
|
-
train_data, caption=f"Train Epoch {epoch}"
|
|
303
|
-
)
|
|
304
|
-
val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
|
|
305
|
-
log_dict[f"viz/train/predictions{suffix}"] = train_img
|
|
306
|
-
log_dict[f"viz/val/predictions{suffix}"] = val_img
|
|
307
|
-
|
|
308
|
-
if log_dict:
|
|
309
|
-
# Include epoch so wandb can use it as x-axis (via define_metric)
|
|
310
|
-
log_dict["epoch"] = epoch
|
|
311
|
-
# Use commit=False to accumulate with other metrics in this step
|
|
312
|
-
# Lightning will commit when it logs its own metrics
|
|
313
|
-
wandb_logger.experiment.log(log_dict, commit=False)
|
|
314
|
-
|
|
315
|
-
# Optionally also log to table for backwards compat
|
|
316
|
-
if self.log_table and "direct" in self.renderers:
|
|
317
|
-
train_img = self.renderers["direct"].render(
|
|
318
|
-
train_data, caption=f"Train Epoch {epoch}"
|
|
319
|
-
)
|
|
320
|
-
val_img = self.renderers["direct"].render(
|
|
321
|
-
val_data, caption=f"Val Epoch {epoch}"
|
|
322
|
-
)
|
|
323
|
-
table = wandb.Table(
|
|
324
|
-
columns=["Epoch", "Train", "Validation"],
|
|
325
|
-
data=[[epoch, train_img, val_img]],
|
|
326
|
-
)
|
|
327
|
-
wandb_logger.experiment.log(
|
|
328
|
-
{"predictions_table": table}, commit=False
|
|
329
|
-
)
|
|
330
|
-
|
|
331
|
-
# Sync all processes - barrier must be reached by ALL ranks
|
|
322
|
+
# Sync all processes
|
|
332
323
|
trainer.strategy.barrier()
|
|
333
324
|
|
|
334
325
|
|
|
@@ -387,377 +378,80 @@ class WandBVizCallbackWithPAFs(WandBVizCallback):
|
|
|
387
378
|
|
|
388
379
|
# Get the wandb logger to use its experiment for logging
|
|
389
380
|
wandb_logger = self._get_wandb_logger(trainer)
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
train_pafs_fig = self._mpl_renderer.render_pafs(train_pafs_data)
|
|
417
|
-
buf = BytesIO()
|
|
418
|
-
train_pafs_fig.savefig(
|
|
419
|
-
buf, format="png", bbox_inches="tight", pad_inches=0
|
|
420
|
-
)
|
|
421
|
-
buf.seek(0)
|
|
422
|
-
plt.close(train_pafs_fig)
|
|
423
|
-
train_pafs_pil = Image.open(buf)
|
|
424
|
-
log_dict["viz/train/pafs"] = wandb.Image(
|
|
425
|
-
train_pafs_pil, caption=f"Train PAFs Epoch {epoch}"
|
|
426
|
-
)
|
|
427
|
-
|
|
428
|
-
val_pafs_fig = self._mpl_renderer.render_pafs(val_pafs_data)
|
|
429
|
-
buf = BytesIO()
|
|
430
|
-
val_pafs_fig.savefig(
|
|
431
|
-
buf, format="png", bbox_inches="tight", pad_inches=0
|
|
432
|
-
)
|
|
433
|
-
buf.seek(0)
|
|
434
|
-
plt.close(val_pafs_fig)
|
|
435
|
-
val_pafs_pil = Image.open(buf)
|
|
436
|
-
log_dict["viz/val/pafs"] = wandb.Image(
|
|
437
|
-
val_pafs_pil, caption=f"Val PAFs Epoch {epoch}"
|
|
438
|
-
)
|
|
439
|
-
|
|
440
|
-
if log_dict:
|
|
441
|
-
# Include epoch so wandb can use it as x-axis (via define_metric)
|
|
442
|
-
log_dict["epoch"] = epoch
|
|
443
|
-
# Use commit=False to accumulate with other metrics in this step
|
|
444
|
-
# Lightning will commit when it logs its own metrics
|
|
445
|
-
wandb_logger.experiment.log(log_dict, commit=False)
|
|
446
|
-
|
|
447
|
-
# Optionally also log to table
|
|
448
|
-
if self.log_table and "direct" in self.renderers:
|
|
449
|
-
train_img = self.renderers["direct"].render(
|
|
450
|
-
train_data, caption=f"Train Epoch {epoch}"
|
|
451
|
-
)
|
|
452
|
-
val_img = self.renderers["direct"].render(
|
|
453
|
-
val_data, caption=f"Val Epoch {epoch}"
|
|
454
|
-
)
|
|
455
|
-
table = wandb.Table(
|
|
456
|
-
columns=[
|
|
457
|
-
"Epoch",
|
|
458
|
-
"Train",
|
|
459
|
-
"Validation",
|
|
460
|
-
"Train PAFs",
|
|
461
|
-
"Val PAFs",
|
|
462
|
-
],
|
|
463
|
-
data=[
|
|
464
|
-
[
|
|
465
|
-
epoch,
|
|
466
|
-
train_img,
|
|
467
|
-
val_img,
|
|
468
|
-
log_dict["viz/train/pafs"],
|
|
469
|
-
log_dict["viz/val/pafs"],
|
|
470
|
-
]
|
|
471
|
-
],
|
|
472
|
-
)
|
|
473
|
-
wandb_logger.experiment.log(
|
|
474
|
-
{"predictions_table": table}, commit=False
|
|
475
|
-
)
|
|
476
|
-
|
|
477
|
-
# Sync all processes - barrier must be reached by ALL ranks
|
|
478
|
-
trainer.strategy.barrier()
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
class UnifiedVizCallback(Callback):
|
|
482
|
-
"""Unified callback for all visualization outputs during training.
|
|
483
|
-
|
|
484
|
-
This callback consolidates all visualization functionality into a single callback,
|
|
485
|
-
eliminating redundant dataset copies and inference runs. It handles:
|
|
486
|
-
- Local disk saving (matplotlib figures)
|
|
487
|
-
- WandB logging (multiple modes: direct, boxes, masks)
|
|
488
|
-
- Model-specific visualizations (PAFs for bottomup, class maps for multi_class_bottomup)
|
|
489
|
-
|
|
490
|
-
Benefits over separate callbacks:
|
|
491
|
-
- Uses ONE sample per epoch for all visualizations (no dataset deepcopy)
|
|
492
|
-
- Runs inference ONCE per sample (vs 4-8x in previous implementation)
|
|
493
|
-
- Outputs to multiple destinations from the same data
|
|
494
|
-
- Simpler code with less duplication
|
|
495
|
-
|
|
496
|
-
Attributes:
|
|
497
|
-
model_trainer: Reference to the ModelTrainer (for lazy access to lightning_model).
|
|
498
|
-
train_pipeline: Iterator over training visualization dataset.
|
|
499
|
-
val_pipeline: Iterator over validation visualization dataset.
|
|
500
|
-
model_type: Type of model (affects which visualizations are enabled).
|
|
501
|
-
save_local: Whether to save matplotlib figures to disk.
|
|
502
|
-
local_save_dir: Directory for local visualization saves.
|
|
503
|
-
log_wandb: Whether to log visualizations to wandb.
|
|
504
|
-
wandb_modes: List of wandb rendering modes ("direct", "boxes", "masks").
|
|
505
|
-
wandb_box_size: Size of keypoint boxes in pixels (for "boxes" mode).
|
|
506
|
-
wandb_confmap_threshold: Threshold for confmap masks (for "masks" mode).
|
|
507
|
-
log_wandb_table: Whether to also log to a wandb.Table.
|
|
508
|
-
"""
|
|
509
|
-
|
|
510
|
-
def __init__(
|
|
511
|
-
self,
|
|
512
|
-
model_trainer,
|
|
513
|
-
train_dataset,
|
|
514
|
-
val_dataset,
|
|
515
|
-
model_type: str,
|
|
516
|
-
save_local: bool = True,
|
|
517
|
-
local_save_dir: Optional[Path] = None,
|
|
518
|
-
log_wandb: bool = False,
|
|
519
|
-
wandb_modes: Optional[list] = None,
|
|
520
|
-
wandb_box_size: float = 5.0,
|
|
521
|
-
wandb_confmap_threshold: float = 0.1,
|
|
522
|
-
log_wandb_table: bool = False,
|
|
523
|
-
):
|
|
524
|
-
"""Initialize the unified visualization callback.
|
|
525
|
-
|
|
526
|
-
Args:
|
|
527
|
-
model_trainer: ModelTrainer instance (lightning_model accessed lazily).
|
|
528
|
-
train_dataset: Training visualization dataset (will be cycled).
|
|
529
|
-
val_dataset: Validation visualization dataset (will be cycled).
|
|
530
|
-
model_type: Model type string (e.g., "bottomup", "multi_class_bottomup").
|
|
531
|
-
save_local: If True, save matplotlib figures to local_save_dir.
|
|
532
|
-
local_save_dir: Path to directory for saving visualization images.
|
|
533
|
-
log_wandb: If True, log visualizations to wandb.
|
|
534
|
-
wandb_modes: List of wandb rendering modes. Defaults to ["direct"].
|
|
535
|
-
wandb_box_size: Size of keypoint boxes in pixels.
|
|
536
|
-
wandb_confmap_threshold: Threshold for confidence map masks.
|
|
537
|
-
log_wandb_table: If True, also log to a wandb.Table.
|
|
538
|
-
"""
|
|
539
|
-
super().__init__()
|
|
540
|
-
from itertools import cycle
|
|
541
|
-
|
|
542
|
-
self.model_trainer = model_trainer
|
|
543
|
-
self.train_pipeline = cycle(train_dataset)
|
|
544
|
-
self.val_pipeline = cycle(val_dataset)
|
|
545
|
-
self.model_type = model_type
|
|
546
|
-
|
|
547
|
-
# Local disk config
|
|
548
|
-
self.save_local = save_local
|
|
549
|
-
self.local_save_dir = local_save_dir
|
|
550
|
-
|
|
551
|
-
# WandB config
|
|
552
|
-
self.log_wandb = log_wandb
|
|
553
|
-
self.wandb_modes = wandb_modes or ["direct"]
|
|
554
|
-
self.wandb_box_size = wandb_box_size
|
|
555
|
-
self.wandb_confmap_threshold = wandb_confmap_threshold
|
|
556
|
-
self.log_wandb_table = log_wandb_table
|
|
557
|
-
|
|
558
|
-
# Auto-enable model-specific visualizations
|
|
559
|
-
self.viz_pafs = model_type == "bottomup"
|
|
560
|
-
self.viz_class_maps = model_type == "multi_class_bottomup"
|
|
561
|
-
|
|
562
|
-
# Initialize renderers
|
|
563
|
-
from sleap_nn.training.utils import MatplotlibRenderer, WandBRenderer
|
|
564
|
-
|
|
565
|
-
self._mpl_renderer = MatplotlibRenderer()
|
|
566
|
-
|
|
567
|
-
# Create wandb renderers for each enabled mode
|
|
568
|
-
self._wandb_renderers = {}
|
|
569
|
-
if log_wandb:
|
|
570
|
-
for mode in self.wandb_modes:
|
|
571
|
-
self._wandb_renderers[mode] = WandBRenderer(
|
|
572
|
-
mode=mode,
|
|
573
|
-
box_size=wandb_box_size,
|
|
574
|
-
confmap_threshold=wandb_confmap_threshold,
|
|
575
|
-
)
|
|
576
|
-
|
|
577
|
-
def _get_wandb_logger(self, trainer):
|
|
578
|
-
"""Get the WandbLogger from trainer's loggers."""
|
|
579
|
-
from lightning.pytorch.loggers import WandbLogger
|
|
580
|
-
|
|
581
|
-
for log in trainer.loggers:
|
|
582
|
-
if isinstance(log, WandbLogger):
|
|
583
|
-
return log
|
|
584
|
-
return None
|
|
585
|
-
|
|
586
|
-
def _get_viz_data(self, sample):
|
|
587
|
-
"""Get visualization data with all needed fields based on model type.
|
|
588
|
-
|
|
589
|
-
Args:
|
|
590
|
-
sample: A sample from the visualization dataset.
|
|
591
|
-
|
|
592
|
-
Returns:
|
|
593
|
-
VisualizationData with appropriate fields populated.
|
|
594
|
-
"""
|
|
595
|
-
# Build kwargs based on model type
|
|
596
|
-
kwargs = {}
|
|
597
|
-
if self.viz_pafs:
|
|
598
|
-
kwargs["include_pafs"] = True
|
|
599
|
-
if self.viz_class_maps:
|
|
600
|
-
kwargs["include_class_maps"] = True
|
|
601
|
-
|
|
602
|
-
# Access lightning_model lazily from model_trainer
|
|
603
|
-
return self.model_trainer.lightning_model.get_visualization_data(
|
|
604
|
-
sample, **kwargs
|
|
605
|
-
)
|
|
606
|
-
|
|
607
|
-
def _save_local_viz(self, data, prefix: str, epoch: int):
|
|
608
|
-
"""Save visualization to local disk.
|
|
609
|
-
|
|
610
|
-
Args:
|
|
611
|
-
data: VisualizationData object.
|
|
612
|
-
prefix: Filename prefix (e.g., "train", "validation").
|
|
613
|
-
epoch: Current epoch number.
|
|
614
|
-
"""
|
|
615
|
-
if not self.save_local or self.local_save_dir is None:
|
|
616
|
-
return
|
|
617
|
-
|
|
618
|
-
# Confmaps visualization
|
|
619
|
-
fig = self._mpl_renderer.render(data)
|
|
620
|
-
fig_path = self.local_save_dir / f"{prefix}.{epoch:04d}.png"
|
|
621
|
-
fig.savefig(fig_path, format="png")
|
|
622
|
-
plt.close(fig)
|
|
623
|
-
|
|
624
|
-
# PAFs visualization (for bottomup models)
|
|
625
|
-
if self.viz_pafs and data.pred_pafs is not None:
|
|
626
|
-
fig = self._mpl_renderer.render_pafs(data)
|
|
627
|
-
fig_path = self.local_save_dir / f"{prefix}.pafs_magnitude.{epoch:04d}.png"
|
|
628
|
-
fig.savefig(fig_path, format="png")
|
|
629
|
-
plt.close(fig)
|
|
630
|
-
|
|
631
|
-
# Class maps visualization (for multi_class_bottomup models)
|
|
632
|
-
if self.viz_class_maps and data.pred_class_maps is not None:
|
|
633
|
-
fig = self._render_class_maps(data)
|
|
634
|
-
fig_path = self.local_save_dir / f"{prefix}.class_maps.{epoch:04d}.png"
|
|
635
|
-
fig.savefig(fig_path, format="png")
|
|
636
|
-
plt.close(fig)
|
|
637
|
-
|
|
638
|
-
def _render_class_maps(self, data):
|
|
639
|
-
"""Render class maps visualization.
|
|
640
|
-
|
|
641
|
-
Args:
|
|
642
|
-
data: VisualizationData with pred_class_maps populated.
|
|
643
|
-
|
|
644
|
-
Returns:
|
|
645
|
-
A matplotlib Figure object.
|
|
646
|
-
"""
|
|
647
|
-
from sleap_nn.training.utils import plot_img, plot_confmaps
|
|
648
|
-
|
|
649
|
-
img = data.image
|
|
650
|
-
scale = 1.0
|
|
651
|
-
if img.shape[0] < 512:
|
|
652
|
-
scale = 2.0
|
|
653
|
-
if img.shape[0] < 256:
|
|
654
|
-
scale = 4.0
|
|
655
|
-
|
|
656
|
-
fig = plot_img(img, dpi=72 * scale, scale=scale)
|
|
657
|
-
plot_confmaps(
|
|
658
|
-
data.pred_class_maps,
|
|
659
|
-
output_scale=data.pred_class_maps.shape[0] / img.shape[0],
|
|
660
|
-
)
|
|
661
|
-
return fig
|
|
662
|
-
|
|
663
|
-
def _log_wandb_viz(self, data, prefix: str, epoch: int, wandb_logger):
|
|
664
|
-
"""Log visualization to wandb.
|
|
665
|
-
|
|
666
|
-
Args:
|
|
667
|
-
data: VisualizationData object.
|
|
668
|
-
prefix: Log prefix (e.g., "train", "val").
|
|
669
|
-
epoch: Current epoch number.
|
|
670
|
-
wandb_logger: WandbLogger instance.
|
|
671
|
-
"""
|
|
672
|
-
if not self.log_wandb or wandb_logger is None:
|
|
673
|
-
return
|
|
674
|
-
|
|
675
|
-
from io import BytesIO
|
|
676
|
-
from PIL import Image as PILImage
|
|
677
|
-
|
|
678
|
-
log_dict = {}
|
|
679
|
-
|
|
680
|
-
# Render confmaps for each enabled mode
|
|
681
|
-
for mode_name, renderer in self._wandb_renderers.items():
|
|
682
|
-
suffix = "" if mode_name == "direct" else f"_{mode_name}"
|
|
683
|
-
img = renderer.render(data, caption=f"{prefix.title()} Epoch {epoch}")
|
|
684
|
-
log_dict[f"viz/{prefix}/predictions{suffix}"] = img
|
|
685
|
-
|
|
686
|
-
# PAFs visualization (for bottomup models)
|
|
687
|
-
if self.viz_pafs and data.pred_pafs is not None:
|
|
688
|
-
pafs_fig = self._mpl_renderer.render_pafs(data)
|
|
381
|
+
if wandb_logger is None:
|
|
382
|
+
return # No wandb logger, skip visualization logging
|
|
383
|
+
|
|
384
|
+
# Get visualization data
|
|
385
|
+
train_data = self.train_viz_fn()
|
|
386
|
+
val_data = self.val_viz_fn()
|
|
387
|
+
train_pafs_data = self.train_pafs_viz_fn()
|
|
388
|
+
val_pafs_data = self.val_pafs_viz_fn()
|
|
389
|
+
|
|
390
|
+
# Render and log for each enabled mode
|
|
391
|
+
# Use the logger's experiment to let Lightning manage step tracking
|
|
392
|
+
log_dict = {}
|
|
393
|
+
for mode_name, renderer in self.renderers.items():
|
|
394
|
+
suffix = "" if mode_name == "direct" else f"_{mode_name}"
|
|
395
|
+
train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
|
|
396
|
+
val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
|
|
397
|
+
log_dict[f"train_predictions{suffix}"] = train_img
|
|
398
|
+
log_dict[f"val_predictions{suffix}"] = val_img
|
|
399
|
+
|
|
400
|
+
# Render PAFs (always use matplotlib/direct for PAFs)
|
|
401
|
+
from io import BytesIO
|
|
402
|
+
import matplotlib.pyplot as plt
|
|
403
|
+
from PIL import Image
|
|
404
|
+
|
|
405
|
+
train_pafs_fig = self._mpl_renderer.render_pafs(train_pafs_data)
|
|
689
406
|
buf = BytesIO()
|
|
690
|
-
|
|
407
|
+
train_pafs_fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
|
|
691
408
|
buf.seek(0)
|
|
692
|
-
plt.close(
|
|
693
|
-
|
|
694
|
-
log_dict[
|
|
695
|
-
|
|
409
|
+
plt.close(train_pafs_fig)
|
|
410
|
+
train_pafs_pil = Image.open(buf)
|
|
411
|
+
log_dict["train_pafs"] = wandb.Image(
|
|
412
|
+
train_pafs_pil, caption=f"Train PAFs Epoch {epoch}"
|
|
696
413
|
)
|
|
697
414
|
|
|
698
|
-
|
|
699
|
-
if self.viz_class_maps and data.pred_class_maps is not None:
|
|
700
|
-
class_fig = self._render_class_maps(data)
|
|
415
|
+
val_pafs_fig = self._mpl_renderer.render_pafs(val_pafs_data)
|
|
701
416
|
buf = BytesIO()
|
|
702
|
-
|
|
417
|
+
val_pafs_fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
|
|
703
418
|
buf.seek(0)
|
|
704
|
-
plt.close(
|
|
705
|
-
|
|
706
|
-
log_dict[
|
|
707
|
-
|
|
708
|
-
)
|
|
709
|
-
|
|
710
|
-
if log_dict:
|
|
711
|
-
log_dict["epoch"] = epoch
|
|
712
|
-
wandb_logger.experiment.log(log_dict, commit=False)
|
|
713
|
-
|
|
714
|
-
# Optionally log to table for backwards compatibility
|
|
715
|
-
if self.log_wandb_table and "direct" in self._wandb_renderers:
|
|
716
|
-
train_img = self._wandb_renderers["direct"].render(
|
|
717
|
-
data, caption=f"{prefix.title()} Epoch {epoch}"
|
|
718
|
-
)
|
|
719
|
-
table_data = [[epoch, train_img]]
|
|
720
|
-
columns = ["Epoch", prefix.title()]
|
|
721
|
-
|
|
722
|
-
if self.viz_pafs and data.pred_pafs is not None:
|
|
723
|
-
columns.append(f"{prefix.title()} PAFs")
|
|
724
|
-
table_data[0].append(log_dict.get(f"viz/{prefix}/pafs"))
|
|
725
|
-
|
|
726
|
-
if self.viz_class_maps and data.pred_class_maps is not None:
|
|
727
|
-
columns.append(f"{prefix.title()} Class Maps")
|
|
728
|
-
table_data[0].append(log_dict.get(f"viz/{prefix}/class_maps"))
|
|
729
|
-
|
|
730
|
-
table = wandb.Table(columns=columns, data=table_data)
|
|
731
|
-
wandb_logger.experiment.log(
|
|
732
|
-
{f"predictions_table_{prefix}": table}, commit=False
|
|
419
|
+
plt.close(val_pafs_fig)
|
|
420
|
+
val_pafs_pil = Image.open(buf)
|
|
421
|
+
log_dict["val_pafs"] = wandb.Image(
|
|
422
|
+
val_pafs_pil, caption=f"Val PAFs Epoch {epoch}"
|
|
733
423
|
)
|
|
734
424
|
|
|
735
|
-
|
|
736
|
-
|
|
425
|
+
if log_dict:
|
|
426
|
+
# Include epoch so wandb can use it as x-axis (via define_metric)
|
|
427
|
+
log_dict["epoch"] = epoch
|
|
428
|
+
# Use commit=False to accumulate with other metrics in this step
|
|
429
|
+
# Lightning will commit when it logs its own metrics
|
|
430
|
+
wandb_logger.experiment.log(log_dict, commit=False)
|
|
431
|
+
|
|
432
|
+
# Optionally also log to table
|
|
433
|
+
if self.log_table and "direct" in self.renderers:
|
|
434
|
+
train_img = self.renderers["direct"].render(
|
|
435
|
+
train_data, caption=f"Train Epoch {epoch}"
|
|
436
|
+
)
|
|
437
|
+
val_img = self.renderers["direct"].render(
|
|
438
|
+
val_data, caption=f"Val Epoch {epoch}"
|
|
439
|
+
)
|
|
440
|
+
table = wandb.Table(
|
|
441
|
+
columns=["Epoch", "Train", "Validation", "Train PAFs", "Val PAFs"],
|
|
442
|
+
data=[
|
|
443
|
+
[
|
|
444
|
+
epoch,
|
|
445
|
+
train_img,
|
|
446
|
+
val_img,
|
|
447
|
+
log_dict["train_pafs"],
|
|
448
|
+
log_dict["val_pafs"],
|
|
449
|
+
]
|
|
450
|
+
],
|
|
451
|
+
)
|
|
452
|
+
wandb_logger.experiment.log({"predictions_table": table}, commit=False)
|
|
737
453
|
|
|
738
|
-
|
|
739
|
-
trainer: PyTorch Lightning trainer.
|
|
740
|
-
pl_module: Lightning module (not used, we use self.lightning_module).
|
|
741
|
-
"""
|
|
742
|
-
if trainer.is_global_zero:
|
|
743
|
-
epoch = trainer.current_epoch
|
|
744
|
-
wandb_logger = self._get_wandb_logger(trainer) if self.log_wandb else None
|
|
745
|
-
|
|
746
|
-
# Get ONE sample for train visualization
|
|
747
|
-
train_sample = next(self.train_pipeline)
|
|
748
|
-
# Run inference ONCE with all needed data
|
|
749
|
-
train_data = self._get_viz_data(train_sample)
|
|
750
|
-
# Output to all destinations
|
|
751
|
-
self._save_local_viz(train_data, "train", epoch)
|
|
752
|
-
self._log_wandb_viz(train_data, "train", epoch, wandb_logger)
|
|
753
|
-
|
|
754
|
-
# Same for validation
|
|
755
|
-
val_sample = next(self.val_pipeline)
|
|
756
|
-
val_data = self._get_viz_data(val_sample)
|
|
757
|
-
self._save_local_viz(val_data, "validation", epoch)
|
|
758
|
-
self._log_wandb_viz(val_data, "val", epoch, wandb_logger)
|
|
759
|
-
|
|
760
|
-
# Sync all processes - barrier must be reached by ALL ranks
|
|
454
|
+
# Sync all processes
|
|
761
455
|
trainer.strategy.barrier()
|
|
762
456
|
|
|
763
457
|
|
|
@@ -968,638 +662,3 @@ class ProgressReporterZMQ(Callback):
|
|
|
968
662
|
return {
|
|
969
663
|
k: float(v.item()) if hasattr(v, "item") else v for k, v in logs.items()
|
|
970
664
|
}
|
|
971
|
-
|
|
972
|
-
|
|
973
|
-
class EpochEndEvaluationCallback(Callback):
|
|
974
|
-
"""Callback to run full evaluation metrics at end of validation epochs.
|
|
975
|
-
|
|
976
|
-
This callback collects predictions and ground truth during validation,
|
|
977
|
-
then runs the full evaluation pipeline (OKS, mAP, PCK, etc.) and logs
|
|
978
|
-
metrics to WandB.
|
|
979
|
-
|
|
980
|
-
Attributes:
|
|
981
|
-
skeleton: sio.Skeleton for creating instances.
|
|
982
|
-
videos: List of sio.Video objects.
|
|
983
|
-
eval_frequency: Run evaluation every N epochs (default: 1).
|
|
984
|
-
oks_stddev: OKS standard deviation (default: 0.025).
|
|
985
|
-
oks_scale: Optional OKS scale override.
|
|
986
|
-
metrics_to_log: List of metric keys to log.
|
|
987
|
-
"""
|
|
988
|
-
|
|
989
|
-
def __init__(
|
|
990
|
-
self,
|
|
991
|
-
skeleton: "sio.Skeleton",
|
|
992
|
-
videos: list,
|
|
993
|
-
eval_frequency: int = 1,
|
|
994
|
-
oks_stddev: float = 0.025,
|
|
995
|
-
oks_scale: Optional[float] = None,
|
|
996
|
-
metrics_to_log: Optional[list] = None,
|
|
997
|
-
):
|
|
998
|
-
"""Initialize the callback.
|
|
999
|
-
|
|
1000
|
-
Args:
|
|
1001
|
-
skeleton: sio.Skeleton for creating instances.
|
|
1002
|
-
videos: List of sio.Video objects.
|
|
1003
|
-
eval_frequency: Run evaluation every N epochs (default: 1).
|
|
1004
|
-
oks_stddev: OKS standard deviation (default: 0.025).
|
|
1005
|
-
oks_scale: Optional OKS scale override.
|
|
1006
|
-
metrics_to_log: List of metric keys to log. If None, logs all available.
|
|
1007
|
-
"""
|
|
1008
|
-
super().__init__()
|
|
1009
|
-
self.skeleton = skeleton
|
|
1010
|
-
self.videos = videos
|
|
1011
|
-
self.eval_frequency = eval_frequency
|
|
1012
|
-
self.oks_stddev = oks_stddev
|
|
1013
|
-
self.oks_scale = oks_scale
|
|
1014
|
-
self.metrics_to_log = metrics_to_log or [
|
|
1015
|
-
"mOKS",
|
|
1016
|
-
"oks_voc.mAP",
|
|
1017
|
-
"oks_voc.mAR",
|
|
1018
|
-
"distance/avg",
|
|
1019
|
-
"distance/p50",
|
|
1020
|
-
"distance/p95",
|
|
1021
|
-
"distance/p99",
|
|
1022
|
-
"mPCK",
|
|
1023
|
-
"PCK@5",
|
|
1024
|
-
"PCK@10",
|
|
1025
|
-
"visibility_precision",
|
|
1026
|
-
"visibility_recall",
|
|
1027
|
-
]
|
|
1028
|
-
|
|
1029
|
-
def on_validation_epoch_start(self, trainer, pl_module):
|
|
1030
|
-
"""Enable prediction collection at the start of validation.
|
|
1031
|
-
|
|
1032
|
-
Skip during sanity check to avoid inference issues.
|
|
1033
|
-
"""
|
|
1034
|
-
if trainer.sanity_checking:
|
|
1035
|
-
return
|
|
1036
|
-
pl_module._collect_val_predictions = True
|
|
1037
|
-
|
|
1038
|
-
def on_validation_epoch_end(self, trainer, pl_module):
|
|
1039
|
-
"""Run evaluation and log metrics at end of validation epoch."""
|
|
1040
|
-
import sleap_io as sio
|
|
1041
|
-
import numpy as np
|
|
1042
|
-
from lightning.pytorch.loggers import WandbLogger
|
|
1043
|
-
from sleap_nn.evaluation import Evaluator
|
|
1044
|
-
|
|
1045
|
-
# Determine if we should run evaluation this epoch (only on rank 0)
|
|
1046
|
-
should_evaluate = (
|
|
1047
|
-
trainer.current_epoch + 1
|
|
1048
|
-
) % self.eval_frequency == 0 and trainer.is_global_zero
|
|
1049
|
-
|
|
1050
|
-
if should_evaluate:
|
|
1051
|
-
# Check if we have predictions
|
|
1052
|
-
if not pl_module.val_predictions or not pl_module.val_ground_truth:
|
|
1053
|
-
logger.warning("No predictions collected for epoch-end evaluation")
|
|
1054
|
-
else:
|
|
1055
|
-
try:
|
|
1056
|
-
# Build sio.Labels from accumulated predictions and ground truth
|
|
1057
|
-
pred_labels = self._build_pred_labels(
|
|
1058
|
-
pl_module.val_predictions, sio, np
|
|
1059
|
-
)
|
|
1060
|
-
gt_labels = self._build_gt_labels(
|
|
1061
|
-
pl_module.val_ground_truth, sio, np
|
|
1062
|
-
)
|
|
1063
|
-
|
|
1064
|
-
# Check if we have valid frames to evaluate
|
|
1065
|
-
if len(pred_labels) == 0:
|
|
1066
|
-
logger.warning(
|
|
1067
|
-
"No valid predictions for epoch-end evaluation "
|
|
1068
|
-
"(all predictions may be empty or NaN)"
|
|
1069
|
-
)
|
|
1070
|
-
else:
|
|
1071
|
-
# Run evaluation
|
|
1072
|
-
evaluator = Evaluator(
|
|
1073
|
-
ground_truth_instances=gt_labels,
|
|
1074
|
-
predicted_instances=pred_labels,
|
|
1075
|
-
oks_stddev=self.oks_stddev,
|
|
1076
|
-
oks_scale=self.oks_scale,
|
|
1077
|
-
user_labels_only=False, # All validation frames are "user" frames
|
|
1078
|
-
)
|
|
1079
|
-
metrics = evaluator.evaluate()
|
|
1080
|
-
|
|
1081
|
-
# Log to WandB
|
|
1082
|
-
self._log_metrics(trainer, metrics, trainer.current_epoch)
|
|
1083
|
-
|
|
1084
|
-
logger.info(
|
|
1085
|
-
f"Epoch {trainer.current_epoch} evaluation: "
|
|
1086
|
-
f"PCK@5={metrics['pck_metrics']['PCK@5']:.4f}, "
|
|
1087
|
-
f"mOKS={metrics['mOKS']['mOKS']:.4f}, "
|
|
1088
|
-
f"mAP={metrics['voc_metrics']['oks_voc.mAP']:.4f}"
|
|
1089
|
-
)
|
|
1090
|
-
|
|
1091
|
-
except Exception as e:
|
|
1092
|
-
logger.warning(f"Epoch-end evaluation failed: {e}")
|
|
1093
|
-
|
|
1094
|
-
# Cleanup - all ranks reset the flag, rank 0 clears the lists
|
|
1095
|
-
pl_module._collect_val_predictions = False
|
|
1096
|
-
if trainer.is_global_zero:
|
|
1097
|
-
pl_module.val_predictions = []
|
|
1098
|
-
pl_module.val_ground_truth = []
|
|
1099
|
-
|
|
1100
|
-
# Sync all processes - barrier must be reached by ALL ranks
|
|
1101
|
-
trainer.strategy.barrier()
|
|
1102
|
-
|
|
1103
|
-
def _build_pred_labels(self, predictions: list, sio, np) -> "sio.Labels":
    """Convert prediction dicts to sio.Labels.

    Args:
        predictions: List of dicts. Each one is read for the keys
            "pred_peaks", "pred_scores", "video_idx" and "frame_idx".
        sio: Module object providing PredictedInstance, LabeledFrame and
            Labels (presumably sleap_io -- confirm against the caller).
        np: NumPy module.

    Returns:
        An sio.Labels built from self.videos, self.skeleton and one
        LabeledFrame for every prediction that yielded at least one
        instance that is not entirely NaN.
    """
    labeled_frames = []
    for pred in predictions:
        pred_peaks = pred["pred_peaks"]
        pred_scores = pred["pred_scores"]

        # Nothing usable for this frame: peaks missing or an all-NaN array.
        if pred_peaks is None or (
            isinstance(pred_peaks, np.ndarray) and np.isnan(pred_peaks).all()
        ):
            continue

        # A 2-D array is treated as a single instance and gets a leading
        # instance axis so the loop below handles both layouts.
        if len(pred_peaks.shape) == 2:
            pred_peaks = pred_peaks.reshape(1, -1, 2)
            # Scores are optional (see the fallbacks below), so only reshape
            # them when they are actually present.
            if pred_scores is not None:
                pred_scores = pred_scores.reshape(1, -1)

        instances = []
        for inst_idx, inst_points in enumerate(pred_peaks):
            # Skip instances without a single finite coordinate.
            if np.isnan(inst_points).all():
                continue

            if pred_scores is None:
                # No scores supplied: fall back to a neutral score of 1.
                point_scores = np.ones(len(inst_points))
                score = 1.0
            else:
                point_scores = pred_scores[inst_idx]
                score = float(np.nanmean(point_scores))

            instances.append(
                sio.PredictedInstance.from_numpy(
                    points_data=inst_points,
                    skeleton=self.skeleton,
                    point_scores=point_scores,
                    score=score,
                )
            )

        # Frames whose instances were all dropped are left out entirely.
        if instances:
            labeled_frames.append(
                sio.LabeledFrame(
                    video=self.videos[pred["video_idx"]],
                    frame_idx=pred["frame_idx"],
                    instances=instances,
                )
            )

    return sio.Labels(
        videos=self.videos,
        skeletons=[self.skeleton],
        labeled_frames=labeled_frames,
    )
|
|
1160
|
-
|
|
1161
|
-
def _build_gt_labels(self, ground_truth: list, sio, np) -> "sio.Labels":
    """Convert ground truth dicts to sio.Labels.

    Args:
        ground_truth: List of dicts. Each one is read for the keys
            "gt_instances", "num_instances", "video_idx" and "frame_idx".
        sio: Module object providing Instance, LabeledFrame and Labels
            (presumably sleap_io -- confirm against the caller).
        np: NumPy module.

    Returns:
        An sio.Labels holding self.videos, self.skeleton and one LabeledFrame
        per ground truth entry that has at least one instance that is not
        entirely NaN.
    """
    frames = []
    for record in ground_truth:
        points = record["gt_instances"]

        # A 2-D array is a single instance; add the leading instance axis.
        if len(points.shape) == 2:
            points = points.reshape(1, -1, 2)

        # Never read past the declared instance count or the array itself.
        n_usable = min(record["num_instances"], len(points))
        candidates = (points[idx] for idx in range(n_usable))
        kept = [
            sio.Instance.from_numpy(
                points_data=inst_points,
                skeleton=self.skeleton,
            )
            for inst_points in candidates
            if not np.isnan(inst_points).all()
        ]

        # Frames without any usable instance are not emitted.
        if not kept:
            continue

        frames.append(
            sio.LabeledFrame(
                video=self.videos[record["video_idx"]],
                frame_idx=record["frame_idx"],
                instances=kept,
            )
        )

    return sio.Labels(
        videos=self.videos,
        skeletons=[self.skeleton],
        labeled_frames=frames,
    )
|
|
1196
|
-
|
|
1197
|
-
def _log_metrics(self, trainer, metrics: dict, epoch: int):
    """Log evaluation metrics to WandB.

    Does nothing when the trainer has no WandbLogger attached. Otherwise the
    metrics named in self.metrics_to_log are logged under "eval/val/..." keys
    and the run summary keeps a "best/<key>" entry per logged metric.

    Args:
        trainer: Trainer whose `loggers` are searched for a WandbLogger.
        metrics: Nested metrics dict, read for the groups "mOKS",
            "voc_metrics", "distance_metrics", "pck_metrics" and
            "visibility_metrics".
        epoch: Epoch number stored under the "epoch" key of the logged dict.
    """
    import numpy as np
    from lightning.pytorch.loggers import WandbLogger

    wandb_logger = next(
        (lg for lg in trainer.loggers if isinstance(lg, WandbLogger)), None
    )
    if wandb_logger is None:
        return

    # (name in metrics_to_log, WandB key, metrics group, field, skip NaN?)
    # Everything is logged under eval/val/ because it comes from validation data.
    metric_specs = (
        ("mOKS", "eval/val/mOKS", "mOKS", "mOKS", False),
        ("oks_voc.mAP", "eval/val/oks_voc_mAP", "voc_metrics", "oks_voc.mAP", False),
        ("oks_voc.mAR", "eval/val/oks_voc_mAR", "voc_metrics", "oks_voc.mAR", False),
        ("distance/avg", "eval/val/distance/avg", "distance_metrics", "avg", True),
        ("distance/p50", "eval/val/distance/p50", "distance_metrics", "p50", True),
        ("distance/p95", "eval/val/distance/p95", "distance_metrics", "p95", True),
        ("distance/p99", "eval/val/distance/p99", "distance_metrics", "p99", True),
        ("mPCK", "eval/val/mPCK", "pck_metrics", "mPCK", False),
        ("PCK@5", "eval/val/PCK_5", "pck_metrics", "PCK@5", False),
        ("PCK@10", "eval/val/PCK_10", "pck_metrics", "PCK@10", False),
        (
            "visibility_precision",
            "eval/val/visibility_precision",
            "visibility_metrics",
            "precision",
            True,
        ),
        (
            "visibility_recall",
            "eval/val/visibility_recall",
            "visibility_metrics",
            "recall",
            True,
        ),
    )

    payload = {"epoch": epoch}
    for name, wandb_key, group, field, skip_nan in metric_specs:
        if name not in self.metrics_to_log:
            continue
        value = metrics[group][field]
        if skip_nan and np.isnan(value):
            continue
        payload[wandb_key] = value

    experiment = wandb_logger.experiment
    experiment.log(payload, commit=False)

    # Keep the best value seen so far for every logged metric (not the epoch).
    summary = experiment.summary
    for key, value in payload.items():
        if key == "epoch":
            continue
        best_key = f"best/{key}"
        previous = summary.get(best_key)
        # Distances improve downwards; every other metric improves upwards.
        lower_is_better = "distance" in key
        if previous is None:
            improved = True
        elif lower_is_better:
            improved = value < previous
        else:
            improved = value > previous
        if improved:
            summary[best_key] = value
|
|
1285
|
-
|
|
1286
|
-
|
|
1287
|
-
def match_centroids(
    pred_centroids: "np.ndarray",
    gt_centroids: "np.ndarray",
    max_distance: float = 50.0,
) -> tuple:
    """Match predicted centroids to ground truth using Hungarian algorithm.

    The assignment minimises the total Euclidean distance between paired
    points. Pairs further apart than `max_distance` are then discarded and
    both of their members are reported as unmatched.

    Args:
        pred_centroids: Predicted centroid locations, shape (n_pred, 2).
        gt_centroids: Ground truth centroid locations, shape (n_gt, 2).
        max_distance: Maximum distance threshold for valid matches (in pixels).

    Returns:
        Tuple of integer index arrays (empty arrays are integer typed too, so
        every result can be used directly for indexing):
            - matched_pred_indices: Indices of matched predictions
            - matched_gt_indices: Indices of matched ground truth
            - unmatched_pred_indices: Indices of unmatched predictions
              (false positives), in ascending order
            - unmatched_gt_indices: Indices of unmatched ground truth
              (false negatives), in ascending order

    Note:
        NaN coordinates are not filtered here; callers should drop such rows
        before calling.
    """
    import numpy as np
    from scipy.optimize import linear_sum_assignment
    from scipy.spatial.distance import cdist

    n_pred = len(pred_centroids)
    n_gt = len(gt_centroids)

    # With either side empty nothing can be matched: every point that does
    # exist is unmatched. np.arange(0) covers the "both empty" case as well.
    if n_pred == 0 or n_gt == 0:
        return (
            np.empty(0, dtype=np.intp),
            np.empty(0, dtype=np.intp),
            np.arange(n_pred, dtype=np.intp),
            np.arange(n_gt, dtype=np.intp),
        )

    # Pairwise Euclidean distances, shape (n_pred, n_gt).
    cost_matrix = cdist(pred_centroids, gt_centroids)

    # Optimal one-to-one assignment over min(n_pred, n_gt) pairs.
    pred_indices, gt_indices = linear_sum_assignment(cost_matrix)

    # Keep only the pairs that are close enough to count as a detection.
    within_range = cost_matrix[pred_indices, gt_indices] <= max_distance
    matched_pred = pred_indices[within_range]
    matched_gt = gt_indices[within_range]

    # Everything that did not end up in a kept pair is unmatched.
    unmatched_pred = np.setdiff1d(np.arange(n_pred, dtype=np.intp), matched_pred)
    unmatched_gt = np.setdiff1d(np.arange(n_gt, dtype=np.intp), matched_gt)

    return matched_pred, matched_gt, unmatched_pred, unmatched_gt
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
|
-
class CentroidEvaluationCallback(Callback):
|
|
1348
|
-
"""Callback to run centroid-specific evaluation metrics at end of validation epochs.
|
|
1349
|
-
|
|
1350
|
-
This callback is designed specifically for centroid models, which predict a single
|
|
1351
|
-
point (centroid) per instance rather than full pose skeletons. It computes
|
|
1352
|
-
distance-based metrics and detection metrics that are more appropriate for
|
|
1353
|
-
point detection tasks than OKS/PCK metrics.
|
|
1354
|
-
|
|
1355
|
-
Metrics computed:
|
|
1356
|
-
- Distance metrics: mean, median, p90, p95, max Euclidean distance
|
|
1357
|
-
- Detection metrics: precision, recall, F1 score
|
|
1358
|
-
- Counts: true positives, false positives, false negatives
|
|
1359
|
-
|
|
1360
|
-
Attributes:
|
|
1361
|
-
videos: List of sio.Video objects.
|
|
1362
|
-
eval_frequency: Run evaluation every N epochs (default: 1).
|
|
1363
|
-
match_threshold: Maximum distance (pixels) for matching pred to GT (default: 50.0).
|
|
1364
|
-
"""
|
|
1365
|
-
|
|
1366
|
-
def __init__(
    self,
    videos: list,
    eval_frequency: int = 1,
    match_threshold: float = 50.0,
):
    """Store the evaluation settings on the callback.

    Args:
        videos: List of sio.Video objects.
        eval_frequency: Evaluate once every this many epochs (default: 1).
        match_threshold: Largest distance, in pixels, at which a predicted
            centroid still counts as matching a ground truth centroid
            (default: 50.0).
    """
    super().__init__()
    self.match_threshold = match_threshold
    self.eval_frequency = eval_frequency
    self.videos = videos
|
|
1384
|
-
|
|
1385
|
-
def on_validation_epoch_start(self, trainer, pl_module):
    """Switch on prediction collection when a validation epoch begins.

    The flag is left untouched while the trainer is sanity checking, to avoid
    inference issues during that phase.
    """
    if not trainer.sanity_checking:
        pl_module._collect_val_predictions = True
|
|
1393
|
-
|
|
1394
|
-
def on_validation_epoch_end(self, trainer, pl_module):
    """Run centroid evaluation and log metrics at end of validation epoch.

    Evaluation runs on global rank 0 only, on epochs selected by
    `self.eval_frequency`, and never during the sanity check (prediction
    collection is not enabled for that phase, so there is nothing to score).
    Failures inside the evaluation are logged as warnings and do not
    interrupt training.

    Cleanup and the final barrier run on every call, on every rank, whether
    or not evaluation took place.

    Args:
        trainer: Trainer; read for current_epoch, is_global_zero,
            sanity_checking and strategy.
        pl_module: Module holding val_predictions, val_ground_truth and the
            _collect_val_predictions flag.
    """
    import numpy as np

    # Determine if we should run evaluation this epoch (only on rank 0, and
    # not while sanity checking).
    should_evaluate = (
        (trainer.current_epoch + 1) % self.eval_frequency == 0
        and trainer.is_global_zero
        and not trainer.sanity_checking
    )

    if should_evaluate:
        # Check if we have predictions
        if not pl_module.val_predictions or not pl_module.val_ground_truth:
            logger.warning(
                "No predictions collected for centroid epoch-end evaluation"
            )
        else:
            # Best effort: a failed evaluation must not abort training.
            try:
                metrics = self._compute_metrics(
                    pl_module.val_predictions, pl_module.val_ground_truth, np
                )

                # Log to WandB
                self._log_metrics(trainer, metrics, trainer.current_epoch)

                logger.info(
                    f"Epoch {trainer.current_epoch} centroid evaluation: "
                    f"precision={metrics['precision']:.4f}, "
                    f"recall={metrics['recall']:.4f}, "
                    f"dist_avg={metrics['dist_avg']:.2f}px"
                )

            except Exception as e:
                logger.warning(f"Centroid epoch-end evaluation failed: {e}")

    # Cleanup - all ranks reset the flag, rank 0 clears the lists
    pl_module._collect_val_predictions = False
    if trainer.is_global_zero:
        pl_module.val_predictions = []
        pl_module.val_ground_truth = []

    # Sync all processes - barrier must be reached by ALL ranks
    trainer.strategy.barrier()
|
|
1437
|
-
|
|
1438
|
-
def _compute_metrics(self, predictions: list, ground_truth: list, np) -> dict:
    """Compute centroid-specific metrics.

    Predictions and ground truth are grouped per (video_idx, frame_idx),
    matched frame by frame with `match_centroids`, and the matches are
    aggregated into distance statistics and detection counts.

    Args:
        predictions: List of prediction dicts with "pred_peaks" key.
        ground_truth: List of ground truth dicts with "gt_instances" key.
        np: NumPy module.

    Returns:
        Dictionary of computed metrics: distance statistics over matched
        pairs (NaN when nothing matched), precision/recall/F1, and the raw
        true/false positive and false negative counts.
    """

    def _group_by_frame(records, points_key):
        # Collect the NaN-free points of every record under its frame key.
        grouped = {}
        for record in records:
            frame_id = (record["video_idx"], record["frame_idx"])
            # Flatten to one (x, y) row per point.
            points = record[points_key].reshape(-1, 2)
            has_nan = np.isnan(points).any(axis=1)
            grouped.setdefault(frame_id, []).append(points[~has_nan])
        return grouped

    pred_by_frame = _group_by_frame(predictions, "pred_peaks")
    gt_by_frame = _group_by_frame(ground_truth, "gt_instances")

    distances = []
    n_tp = 0
    n_fp = 0
    n_fn = 0

    # Visit every frame that has predictions, ground truth, or both.
    for frame_id in set(pred_by_frame.keys()) | set(gt_by_frame.keys()):
        if frame_id in pred_by_frame:
            frame_preds = np.concatenate(pred_by_frame[frame_id], axis=0)
        else:
            frame_preds = np.zeros((0, 2))

        if frame_id in gt_by_frame:
            frame_gt = np.concatenate(gt_by_frame[frame_id], axis=0)
        else:
            frame_gt = np.zeros((0, 2))

        hit_pred, hit_gt, extra_pred, missed_gt = match_centroids(
            frame_preds, frame_gt, max_distance=self.match_threshold
        )

        # Localisation error is only defined for matched pairs.
        if len(hit_pred) > 0:
            offsets = frame_preds[hit_pred] - frame_gt[hit_gt]
            distances.extend(np.linalg.norm(offsets, axis=1).tolist())

        n_tp += len(hit_pred)
        n_fp += len(extra_pred)
        n_fn += len(missed_gt)

    distances = np.array(distances)

    # Without a single match the distance statistics are undefined.
    if len(distances) > 0:
        dist_avg = float(np.mean(distances))
        dist_median = float(np.median(distances))
        dist_p90 = float(np.percentile(distances, 90))
        dist_p95 = float(np.percentile(distances, 95))
        dist_max = float(np.max(distances))
    else:
        undefined = float("nan")
        dist_avg = undefined
        dist_median = undefined
        dist_p90 = undefined
        dist_p95 = undefined
        dist_max = undefined

    n_predicted = n_tp + n_fp
    n_annotated = n_tp + n_fn

    # Ratios default to 0.0 whenever their denominator would be zero.
    precision = n_tp / n_predicted if n_predicted > 0 else 0.0
    recall = n_tp / n_annotated if n_annotated > 0 else 0.0
    if (precision + recall) > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0

    return {
        "dist_avg": dist_avg,
        "dist_median": dist_median,
        "dist_p90": dist_p90,
        "dist_p95": dist_p95,
        "dist_max": dist_max,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "n_true_positives": n_tp,
        "n_false_positives": n_fp,
        "n_false_negatives": n_fn,
        "n_total_predictions": n_predicted,
        "n_total_ground_truth": n_annotated,
    }
|
|
1550
|
-
|
|
1551
|
-
def _log_metrics(self, trainer, metrics: dict, epoch: int):
    """Log centroid evaluation metrics to WandB.

    Does nothing when the trainer has no WandbLogger attached. Otherwise the
    metrics are logged under "eval/val/centroid_*" keys and the run summary
    keeps a "best/<key>" entry per logged metric. Distances and the false
    positive / false negative counts are tracked as lower-is-better; all
    other metrics as higher-is-better.

    Args:
        trainer: Trainer whose `loggers` are searched for a WandbLogger.
        metrics: Metrics dict, read for the "dist_*" statistics, "precision",
            "recall", "f1" and the "n_true_positives" / "n_false_positives" /
            "n_false_negatives" counts.
        epoch: Epoch number stored under the "epoch" key of the logged dict.
    """
    import numpy as np
    from lightning.pytorch.loggers import WandbLogger

    # Get WandB logger
    wandb_logger = next(
        (lg for lg in trainer.loggers if isinstance(lg, WandbLogger)), None
    )
    if wandb_logger is None:
        return

    log_dict = {"epoch": epoch}

    # Distance metrics: NaN (no matched pairs) is skipped rather than logged.
    for stat in ("avg", "median", "p90", "p95", "max"):
        value = metrics[f"dist_{stat}"]
        if not np.isnan(value):
            log_dict[f"eval/val/centroid_dist_{stat}"] = value

    # Detection metrics
    log_dict["eval/val/centroid_precision"] = metrics["precision"]
    log_dict["eval/val/centroid_recall"] = metrics["recall"]
    log_dict["eval/val/centroid_f1"] = metrics["f1"]

    # Counts
    log_dict["eval/val/centroid_n_tp"] = metrics["n_true_positives"]
    log_dict["eval/val/centroid_n_fp"] = metrics["n_false_positives"]
    log_dict["eval/val/centroid_n_fn"] = metrics["n_false_negatives"]

    wandb_logger.experiment.log(log_dict, commit=False)

    # Update best metrics in summary (excluding epoch)
    summary = wandb_logger.experiment.summary
    for key, value in log_dict.items():
        if key == "epoch":
            continue
        summary_key = f"best/{key}"
        current_best = summary.get(summary_key)
        # Distances and error counts (false positives / false negatives)
        # improve downwards; everything else improves upwards.
        lower_is_better = "dist" in key or key.endswith(("_n_fp", "_n_fn"))
        if current_best is None:
            improved = True
        elif lower_is_better:
            improved = value < current_best
        else:
            improved = value > current_best
        if improved:
            summary[summary_key] = value
|