sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +9 -2
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +489 -46
- sleap_nn/config/data_config.py +51 -8
- sleap_nn/config/get_config.py +32 -24
- sleap_nn/config/trainer_config.py +88 -0
- sleap_nn/data/augmentation.py +61 -200
- sleap_nn/data/custom_datasets.py +433 -61
- sleap_nn/data/instance_cropping.py +71 -6
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/skia_augmentation.py +414 -0
- sleap_nn/data/utils.py +135 -17
- sleap_nn/evaluation.py +177 -42
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/peak_finding.py +93 -16
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +339 -137
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/legacy_models.py +65 -11
- sleap_nn/predict.py +224 -19
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +138 -44
- sleap_nn/training/callbacks.py +1258 -5
- sleap_nn/training/lightning_modules.py +902 -220
- sleap_nn/training/model_trainer.py +424 -111
- sleap_nn/training/schedulers.py +191 -0
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
- sleap_nn-0.1.0.dist-info/RECORD +88 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
- sleap_nn-0.0.5.dist-info/RECORD +0 -63
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/training/callbacks.py
CHANGED
|
@@ -2,10 +2,15 @@
|
|
|
2
2
|
|
|
3
3
|
import zmq
|
|
4
4
|
import jsonpickle
|
|
5
|
-
from typing import Callable, Optional
|
|
5
|
+
from typing import Callable, Optional, Union
|
|
6
6
|
from lightning.pytorch.callbacks import Callback
|
|
7
|
+
from lightning.pytorch.callbacks.progress import TQDMProgressBar
|
|
7
8
|
from loguru import logger
|
|
8
9
|
import matplotlib
|
|
10
|
+
|
|
11
|
+
matplotlib.use(
|
|
12
|
+
"Agg"
|
|
13
|
+
) # Use non-interactive backend to avoid tkinter issues on Windows CI
|
|
9
14
|
import matplotlib.pyplot as plt
|
|
10
15
|
from PIL import Image
|
|
11
16
|
from pathlib import Path
|
|
@@ -14,6 +19,32 @@ import csv
|
|
|
14
19
|
from sleap_nn import RANK
|
|
15
20
|
|
|
16
21
|
|
|
22
|
+
class SleapProgressBar(TQDMProgressBar):
|
|
23
|
+
"""Custom progress bar with better formatting for small metric values.
|
|
24
|
+
|
|
25
|
+
The default TQDMProgressBar truncates small floats like 1e-5 to "0.000".
|
|
26
|
+
This subclass formats metrics using scientific notation when appropriate.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
def get_metrics(
|
|
30
|
+
self, trainer, pl_module
|
|
31
|
+
) -> dict[str, Union[int, str, float, dict[str, float]]]:
|
|
32
|
+
"""Override to format metrics with scientific notation for small values."""
|
|
33
|
+
items = super().get_metrics(trainer, pl_module)
|
|
34
|
+
formatted = {}
|
|
35
|
+
for k, v in items.items():
|
|
36
|
+
if isinstance(v, float):
|
|
37
|
+
# Use scientific notation for very small values
|
|
38
|
+
if v != 0 and abs(v) < 0.001:
|
|
39
|
+
formatted[k] = f"{v:.2e}"
|
|
40
|
+
else:
|
|
41
|
+
# Use 4 decimal places for normal values
|
|
42
|
+
formatted[k] = f"{v:.4f}"
|
|
43
|
+
else:
|
|
44
|
+
formatted[k] = v
|
|
45
|
+
return formatted
|
|
46
|
+
|
|
47
|
+
|
|
17
48
|
class CSVLoggerCallback(Callback):
|
|
18
49
|
"""Callback for logging metrics to csv.
|
|
19
50
|
|
|
@@ -53,6 +84,21 @@ class CSVLoggerCallback(Callback):
|
|
|
53
84
|
for key in self.keys:
|
|
54
85
|
if key == "epoch":
|
|
55
86
|
log_data["epoch"] = trainer.current_epoch
|
|
87
|
+
elif key == "learning_rate":
|
|
88
|
+
# Handle multiple formats:
|
|
89
|
+
# 1. Direct "learning_rate" key
|
|
90
|
+
# 2. "train/lr" key (current format from lightning modules)
|
|
91
|
+
# 3. "lr-*" keys from LearningRateMonitor (legacy)
|
|
92
|
+
value = metrics.get(key, None)
|
|
93
|
+
if value is None:
|
|
94
|
+
value = metrics.get("train/lr", None)
|
|
95
|
+
if value is None:
|
|
96
|
+
# Look for lr-* keys from LearningRateMonitor (legacy)
|
|
97
|
+
for metric_key in metrics.keys():
|
|
98
|
+
if metric_key.startswith("lr-"):
|
|
99
|
+
value = metrics[metric_key]
|
|
100
|
+
break
|
|
101
|
+
log_data[key] = value.item() if value is not None else None
|
|
56
102
|
else:
|
|
57
103
|
value = metrics.get(key, None)
|
|
58
104
|
log_data[key] = value.item() if value is not None else None
|
|
@@ -66,7 +112,11 @@ class CSVLoggerCallback(Callback):
|
|
|
66
112
|
|
|
67
113
|
|
|
68
114
|
class WandBPredImageLogger(Callback):
|
|
69
|
-
"""Callback for writing image predictions to wandb.
|
|
115
|
+
"""Callback for writing image predictions to wandb as a Table.
|
|
116
|
+
|
|
117
|
+
.. deprecated::
|
|
118
|
+
This callback logs images to a wandb.Table which doesn't support
|
|
119
|
+
step sliders. Use WandBVizCallback instead for better UX.
|
|
70
120
|
|
|
71
121
|
Attributes:
|
|
72
122
|
viz_folder: Path to viz directory.
|
|
@@ -141,12 +191,576 @@ class WandBPredImageLogger(Callback):
|
|
|
141
191
|
]
|
|
142
192
|
]
|
|
143
193
|
table = wandb.Table(columns=column_names, data=data)
|
|
144
|
-
|
|
194
|
+
# Use commit=False to accumulate with other metrics in this step
|
|
195
|
+
wandb.log({f"{self.wandb_run_name}": table}, commit=False)
|
|
145
196
|
|
|
146
197
|
# Sync all processes after wandb logging
|
|
147
198
|
trainer.strategy.barrier()
|
|
148
199
|
|
|
149
200
|
|
|
201
|
+
class WandBVizCallback(Callback):
|
|
202
|
+
"""Callback for logging visualization images directly to wandb with slider support.
|
|
203
|
+
|
|
204
|
+
This callback logs images using wandb.log() which enables step slider navigation
|
|
205
|
+
in the wandb UI. Multiple visualization modes can be enabled simultaneously:
|
|
206
|
+
- viz_enabled: Pre-render with matplotlib (same as disk viz)
|
|
207
|
+
- viz_boxes: Interactive keypoint boxes with filtering
|
|
208
|
+
- viz_masks: Confidence map overlay with per-node toggling
|
|
209
|
+
|
|
210
|
+
Attributes:
|
|
211
|
+
train_viz_fn: Function that returns VisualizationData for training sample.
|
|
212
|
+
val_viz_fn: Function that returns VisualizationData for validation sample.
|
|
213
|
+
viz_enabled: Whether to log pre-rendered matplotlib images.
|
|
214
|
+
viz_boxes: Whether to log interactive keypoint boxes.
|
|
215
|
+
viz_masks: Whether to log confidence map overlay masks.
|
|
216
|
+
box_size: Size of keypoint boxes in pixels (for viz_boxes).
|
|
217
|
+
confmap_threshold: Threshold for confmap masks (for viz_masks).
|
|
218
|
+
log_table: Whether to also log to a wandb.Table (backwards compat).
|
|
219
|
+
"""
|
|
220
|
+
|
|
221
|
+
def __init__(
|
|
222
|
+
self,
|
|
223
|
+
train_viz_fn: Callable,
|
|
224
|
+
val_viz_fn: Callable,
|
|
225
|
+
viz_enabled: bool = True,
|
|
226
|
+
viz_boxes: bool = False,
|
|
227
|
+
viz_masks: bool = False,
|
|
228
|
+
box_size: float = 5.0,
|
|
229
|
+
confmap_threshold: float = 0.1,
|
|
230
|
+
log_table: bool = False,
|
|
231
|
+
):
|
|
232
|
+
"""Initialize the callback.
|
|
233
|
+
|
|
234
|
+
Args:
|
|
235
|
+
train_viz_fn: Callable that returns VisualizationData for a training sample.
|
|
236
|
+
val_viz_fn: Callable that returns VisualizationData for a validation sample.
|
|
237
|
+
viz_enabled: If True, log pre-rendered matplotlib images.
|
|
238
|
+
viz_boxes: If True, log interactive keypoint boxes.
|
|
239
|
+
viz_masks: If True, log confidence map overlay masks.
|
|
240
|
+
box_size: Size of keypoint boxes in pixels (for viz_boxes).
|
|
241
|
+
confmap_threshold: Threshold for confmap mask generation (for viz_masks).
|
|
242
|
+
log_table: If True, also log images to a wandb.Table (for backwards compat).
|
|
243
|
+
"""
|
|
244
|
+
super().__init__()
|
|
245
|
+
self.train_viz_fn = train_viz_fn
|
|
246
|
+
self.val_viz_fn = val_viz_fn
|
|
247
|
+
self.viz_enabled = viz_enabled
|
|
248
|
+
self.viz_boxes = viz_boxes
|
|
249
|
+
self.viz_masks = viz_masks
|
|
250
|
+
self.log_table = log_table
|
|
251
|
+
|
|
252
|
+
# Import here to avoid circular imports
|
|
253
|
+
from sleap_nn.training.utils import WandBRenderer
|
|
254
|
+
|
|
255
|
+
self.box_size = box_size
|
|
256
|
+
self.confmap_threshold = confmap_threshold
|
|
257
|
+
|
|
258
|
+
# Create renderers for each enabled mode
|
|
259
|
+
self.renderers = {}
|
|
260
|
+
if viz_enabled:
|
|
261
|
+
self.renderers["direct"] = WandBRenderer(
|
|
262
|
+
mode="direct", box_size=box_size, confmap_threshold=confmap_threshold
|
|
263
|
+
)
|
|
264
|
+
if viz_boxes:
|
|
265
|
+
self.renderers["boxes"] = WandBRenderer(
|
|
266
|
+
mode="boxes", box_size=box_size, confmap_threshold=confmap_threshold
|
|
267
|
+
)
|
|
268
|
+
if viz_masks:
|
|
269
|
+
self.renderers["masks"] = WandBRenderer(
|
|
270
|
+
mode="masks", box_size=box_size, confmap_threshold=confmap_threshold
|
|
271
|
+
)
|
|
272
|
+
|
|
273
|
+
def _get_wandb_logger(self, trainer):
|
|
274
|
+
"""Get the WandbLogger from trainer's loggers."""
|
|
275
|
+
from lightning.pytorch.loggers import WandbLogger
|
|
276
|
+
|
|
277
|
+
for logger in trainer.loggers:
|
|
278
|
+
if isinstance(logger, WandbLogger):
|
|
279
|
+
return logger
|
|
280
|
+
return None
|
|
281
|
+
|
|
282
|
+
def on_train_epoch_end(self, trainer, pl_module):
|
|
283
|
+
"""Log visualization images at end of each epoch."""
|
|
284
|
+
if trainer.is_global_zero:
|
|
285
|
+
epoch = trainer.current_epoch
|
|
286
|
+
|
|
287
|
+
# Get the wandb logger to use its experiment for logging
|
|
288
|
+
wandb_logger = self._get_wandb_logger(trainer)
|
|
289
|
+
|
|
290
|
+
# Only do visualization work if wandb logger is available
|
|
291
|
+
if wandb_logger is not None:
|
|
292
|
+
# Get visualization data
|
|
293
|
+
train_data = self.train_viz_fn()
|
|
294
|
+
val_data = self.val_viz_fn()
|
|
295
|
+
|
|
296
|
+
# Render and log for each enabled mode
|
|
297
|
+
# Use the logger's experiment to let Lightning manage step tracking
|
|
298
|
+
log_dict = {}
|
|
299
|
+
for mode_name, renderer in self.renderers.items():
|
|
300
|
+
suffix = "" if mode_name == "direct" else f"_{mode_name}"
|
|
301
|
+
train_img = renderer.render(
|
|
302
|
+
train_data, caption=f"Train Epoch {epoch}"
|
|
303
|
+
)
|
|
304
|
+
val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
|
|
305
|
+
log_dict[f"viz/train/predictions{suffix}"] = train_img
|
|
306
|
+
log_dict[f"viz/val/predictions{suffix}"] = val_img
|
|
307
|
+
|
|
308
|
+
if log_dict:
|
|
309
|
+
# Include epoch so wandb can use it as x-axis (via define_metric)
|
|
310
|
+
log_dict["epoch"] = epoch
|
|
311
|
+
# Use commit=False to accumulate with other metrics in this step
|
|
312
|
+
# Lightning will commit when it logs its own metrics
|
|
313
|
+
wandb_logger.experiment.log(log_dict, commit=False)
|
|
314
|
+
|
|
315
|
+
# Optionally also log to table for backwards compat
|
|
316
|
+
if self.log_table and "direct" in self.renderers:
|
|
317
|
+
train_img = self.renderers["direct"].render(
|
|
318
|
+
train_data, caption=f"Train Epoch {epoch}"
|
|
319
|
+
)
|
|
320
|
+
val_img = self.renderers["direct"].render(
|
|
321
|
+
val_data, caption=f"Val Epoch {epoch}"
|
|
322
|
+
)
|
|
323
|
+
table = wandb.Table(
|
|
324
|
+
columns=["Epoch", "Train", "Validation"],
|
|
325
|
+
data=[[epoch, train_img, val_img]],
|
|
326
|
+
)
|
|
327
|
+
wandb_logger.experiment.log(
|
|
328
|
+
{"predictions_table": table}, commit=False
|
|
329
|
+
)
|
|
330
|
+
|
|
331
|
+
# Sync all processes - barrier must be reached by ALL ranks
|
|
332
|
+
trainer.strategy.barrier()
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
class WandBVizCallbackWithPAFs(WandBVizCallback):
|
|
336
|
+
"""Extended WandBVizCallback that also logs PAF visualizations for bottom-up models."""
|
|
337
|
+
|
|
338
|
+
def __init__(
|
|
339
|
+
self,
|
|
340
|
+
train_viz_fn: Callable,
|
|
341
|
+
val_viz_fn: Callable,
|
|
342
|
+
train_pafs_viz_fn: Callable,
|
|
343
|
+
val_pafs_viz_fn: Callable,
|
|
344
|
+
viz_enabled: bool = True,
|
|
345
|
+
viz_boxes: bool = False,
|
|
346
|
+
viz_masks: bool = False,
|
|
347
|
+
box_size: float = 5.0,
|
|
348
|
+
confmap_threshold: float = 0.1,
|
|
349
|
+
log_table: bool = False,
|
|
350
|
+
):
|
|
351
|
+
"""Initialize the callback.
|
|
352
|
+
|
|
353
|
+
Args:
|
|
354
|
+
train_viz_fn: Callable returning VisualizationData for training sample.
|
|
355
|
+
val_viz_fn: Callable returning VisualizationData for validation sample.
|
|
356
|
+
train_pafs_viz_fn: Callable returning VisualizationData with PAFs for training.
|
|
357
|
+
val_pafs_viz_fn: Callable returning VisualizationData with PAFs for validation.
|
|
358
|
+
viz_enabled: If True, log pre-rendered matplotlib images.
|
|
359
|
+
viz_boxes: If True, log interactive keypoint boxes.
|
|
360
|
+
viz_masks: If True, log confidence map overlay masks.
|
|
361
|
+
box_size: Size of keypoint boxes in pixels.
|
|
362
|
+
confmap_threshold: Threshold for confmap mask generation.
|
|
363
|
+
log_table: If True, also log images to a wandb.Table.
|
|
364
|
+
"""
|
|
365
|
+
super().__init__(
|
|
366
|
+
train_viz_fn=train_viz_fn,
|
|
367
|
+
val_viz_fn=val_viz_fn,
|
|
368
|
+
viz_enabled=viz_enabled,
|
|
369
|
+
viz_boxes=viz_boxes,
|
|
370
|
+
viz_masks=viz_masks,
|
|
371
|
+
box_size=box_size,
|
|
372
|
+
confmap_threshold=confmap_threshold,
|
|
373
|
+
log_table=log_table,
|
|
374
|
+
)
|
|
375
|
+
self.train_pafs_viz_fn = train_pafs_viz_fn
|
|
376
|
+
self.val_pafs_viz_fn = val_pafs_viz_fn
|
|
377
|
+
|
|
378
|
+
# Import here to avoid circular imports
|
|
379
|
+
from sleap_nn.training.utils import MatplotlibRenderer
|
|
380
|
+
|
|
381
|
+
self._mpl_renderer = MatplotlibRenderer()
|
|
382
|
+
|
|
383
|
+
def on_train_epoch_end(self, trainer, pl_module):
|
|
384
|
+
"""Log visualization images including PAFs at end of each epoch."""
|
|
385
|
+
if trainer.is_global_zero:
|
|
386
|
+
epoch = trainer.current_epoch
|
|
387
|
+
|
|
388
|
+
# Get the wandb logger to use its experiment for logging
|
|
389
|
+
wandb_logger = self._get_wandb_logger(trainer)
|
|
390
|
+
|
|
391
|
+
# Only do visualization work if wandb logger is available
|
|
392
|
+
if wandb_logger is not None:
|
|
393
|
+
# Get visualization data
|
|
394
|
+
train_data = self.train_viz_fn()
|
|
395
|
+
val_data = self.val_viz_fn()
|
|
396
|
+
train_pafs_data = self.train_pafs_viz_fn()
|
|
397
|
+
val_pafs_data = self.val_pafs_viz_fn()
|
|
398
|
+
|
|
399
|
+
# Render and log for each enabled mode
|
|
400
|
+
# Use the logger's experiment to let Lightning manage step tracking
|
|
401
|
+
log_dict = {}
|
|
402
|
+
for mode_name, renderer in self.renderers.items():
|
|
403
|
+
suffix = "" if mode_name == "direct" else f"_{mode_name}"
|
|
404
|
+
train_img = renderer.render(
|
|
405
|
+
train_data, caption=f"Train Epoch {epoch}"
|
|
406
|
+
)
|
|
407
|
+
val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
|
|
408
|
+
log_dict[f"viz/train/predictions{suffix}"] = train_img
|
|
409
|
+
log_dict[f"viz/val/predictions{suffix}"] = val_img
|
|
410
|
+
|
|
411
|
+
# Render PAFs (always use matplotlib/direct for PAFs)
|
|
412
|
+
from io import BytesIO
|
|
413
|
+
import matplotlib.pyplot as plt
|
|
414
|
+
from PIL import Image
|
|
415
|
+
|
|
416
|
+
train_pafs_fig = self._mpl_renderer.render_pafs(train_pafs_data)
|
|
417
|
+
buf = BytesIO()
|
|
418
|
+
train_pafs_fig.savefig(
|
|
419
|
+
buf, format="png", bbox_inches="tight", pad_inches=0
|
|
420
|
+
)
|
|
421
|
+
buf.seek(0)
|
|
422
|
+
plt.close(train_pafs_fig)
|
|
423
|
+
train_pafs_pil = Image.open(buf)
|
|
424
|
+
log_dict["viz/train/pafs"] = wandb.Image(
|
|
425
|
+
train_pafs_pil, caption=f"Train PAFs Epoch {epoch}"
|
|
426
|
+
)
|
|
427
|
+
|
|
428
|
+
val_pafs_fig = self._mpl_renderer.render_pafs(val_pafs_data)
|
|
429
|
+
buf = BytesIO()
|
|
430
|
+
val_pafs_fig.savefig(
|
|
431
|
+
buf, format="png", bbox_inches="tight", pad_inches=0
|
|
432
|
+
)
|
|
433
|
+
buf.seek(0)
|
|
434
|
+
plt.close(val_pafs_fig)
|
|
435
|
+
val_pafs_pil = Image.open(buf)
|
|
436
|
+
log_dict["viz/val/pafs"] = wandb.Image(
|
|
437
|
+
val_pafs_pil, caption=f"Val PAFs Epoch {epoch}"
|
|
438
|
+
)
|
|
439
|
+
|
|
440
|
+
if log_dict:
|
|
441
|
+
# Include epoch so wandb can use it as x-axis (via define_metric)
|
|
442
|
+
log_dict["epoch"] = epoch
|
|
443
|
+
# Use commit=False to accumulate with other metrics in this step
|
|
444
|
+
# Lightning will commit when it logs its own metrics
|
|
445
|
+
wandb_logger.experiment.log(log_dict, commit=False)
|
|
446
|
+
|
|
447
|
+
# Optionally also log to table
|
|
448
|
+
if self.log_table and "direct" in self.renderers:
|
|
449
|
+
train_img = self.renderers["direct"].render(
|
|
450
|
+
train_data, caption=f"Train Epoch {epoch}"
|
|
451
|
+
)
|
|
452
|
+
val_img = self.renderers["direct"].render(
|
|
453
|
+
val_data, caption=f"Val Epoch {epoch}"
|
|
454
|
+
)
|
|
455
|
+
table = wandb.Table(
|
|
456
|
+
columns=[
|
|
457
|
+
"Epoch",
|
|
458
|
+
"Train",
|
|
459
|
+
"Validation",
|
|
460
|
+
"Train PAFs",
|
|
461
|
+
"Val PAFs",
|
|
462
|
+
],
|
|
463
|
+
data=[
|
|
464
|
+
[
|
|
465
|
+
epoch,
|
|
466
|
+
train_img,
|
|
467
|
+
val_img,
|
|
468
|
+
log_dict["viz/train/pafs"],
|
|
469
|
+
log_dict["viz/val/pafs"],
|
|
470
|
+
]
|
|
471
|
+
],
|
|
472
|
+
)
|
|
473
|
+
wandb_logger.experiment.log(
|
|
474
|
+
{"predictions_table": table}, commit=False
|
|
475
|
+
)
|
|
476
|
+
|
|
477
|
+
# Sync all processes - barrier must be reached by ALL ranks
|
|
478
|
+
trainer.strategy.barrier()
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
class UnifiedVizCallback(Callback):
|
|
482
|
+
"""Unified callback for all visualization outputs during training.
|
|
483
|
+
|
|
484
|
+
This callback consolidates all visualization functionality into a single callback,
|
|
485
|
+
eliminating redundant dataset copies and inference runs. It handles:
|
|
486
|
+
- Local disk saving (matplotlib figures)
|
|
487
|
+
- WandB logging (multiple modes: direct, boxes, masks)
|
|
488
|
+
- Model-specific visualizations (PAFs for bottomup, class maps for multi_class_bottomup)
|
|
489
|
+
|
|
490
|
+
Benefits over separate callbacks:
|
|
491
|
+
- Uses ONE sample per epoch for all visualizations (no dataset deepcopy)
|
|
492
|
+
- Runs inference ONCE per sample (vs 4-8x in previous implementation)
|
|
493
|
+
- Outputs to multiple destinations from the same data
|
|
494
|
+
- Simpler code with less duplication
|
|
495
|
+
|
|
496
|
+
Attributes:
|
|
497
|
+
model_trainer: Reference to the ModelTrainer (for lazy access to lightning_model).
|
|
498
|
+
train_pipeline: Iterator over training visualization dataset.
|
|
499
|
+
val_pipeline: Iterator over validation visualization dataset.
|
|
500
|
+
model_type: Type of model (affects which visualizations are enabled).
|
|
501
|
+
save_local: Whether to save matplotlib figures to disk.
|
|
502
|
+
local_save_dir: Directory for local visualization saves.
|
|
503
|
+
log_wandb: Whether to log visualizations to wandb.
|
|
504
|
+
wandb_modes: List of wandb rendering modes ("direct", "boxes", "masks").
|
|
505
|
+
wandb_box_size: Size of keypoint boxes in pixels (for "boxes" mode).
|
|
506
|
+
wandb_confmap_threshold: Threshold for confmap masks (for "masks" mode).
|
|
507
|
+
log_wandb_table: Whether to also log to a wandb.Table.
|
|
508
|
+
"""
|
|
509
|
+
|
|
510
|
+
def __init__(
|
|
511
|
+
self,
|
|
512
|
+
model_trainer,
|
|
513
|
+
train_dataset,
|
|
514
|
+
val_dataset,
|
|
515
|
+
model_type: str,
|
|
516
|
+
save_local: bool = True,
|
|
517
|
+
local_save_dir: Optional[Path] = None,
|
|
518
|
+
log_wandb: bool = False,
|
|
519
|
+
wandb_modes: Optional[list] = None,
|
|
520
|
+
wandb_box_size: float = 5.0,
|
|
521
|
+
wandb_confmap_threshold: float = 0.1,
|
|
522
|
+
log_wandb_table: bool = False,
|
|
523
|
+
):
|
|
524
|
+
"""Initialize the unified visualization callback.
|
|
525
|
+
|
|
526
|
+
Args:
|
|
527
|
+
model_trainer: ModelTrainer instance (lightning_model accessed lazily).
|
|
528
|
+
train_dataset: Training visualization dataset (will be cycled).
|
|
529
|
+
val_dataset: Validation visualization dataset (will be cycled).
|
|
530
|
+
model_type: Model type string (e.g., "bottomup", "multi_class_bottomup").
|
|
531
|
+
save_local: If True, save matplotlib figures to local_save_dir.
|
|
532
|
+
local_save_dir: Path to directory for saving visualization images.
|
|
533
|
+
log_wandb: If True, log visualizations to wandb.
|
|
534
|
+
wandb_modes: List of wandb rendering modes. Defaults to ["direct"].
|
|
535
|
+
wandb_box_size: Size of keypoint boxes in pixels.
|
|
536
|
+
wandb_confmap_threshold: Threshold for confidence map masks.
|
|
537
|
+
log_wandb_table: If True, also log to a wandb.Table.
|
|
538
|
+
"""
|
|
539
|
+
super().__init__()
|
|
540
|
+
from itertools import cycle
|
|
541
|
+
|
|
542
|
+
self.model_trainer = model_trainer
|
|
543
|
+
self.train_pipeline = cycle(train_dataset)
|
|
544
|
+
self.val_pipeline = cycle(val_dataset)
|
|
545
|
+
self.model_type = model_type
|
|
546
|
+
|
|
547
|
+
# Local disk config
|
|
548
|
+
self.save_local = save_local
|
|
549
|
+
self.local_save_dir = local_save_dir
|
|
550
|
+
|
|
551
|
+
# WandB config
|
|
552
|
+
self.log_wandb = log_wandb
|
|
553
|
+
self.wandb_modes = wandb_modes or ["direct"]
|
|
554
|
+
self.wandb_box_size = wandb_box_size
|
|
555
|
+
self.wandb_confmap_threshold = wandb_confmap_threshold
|
|
556
|
+
self.log_wandb_table = log_wandb_table
|
|
557
|
+
|
|
558
|
+
# Auto-enable model-specific visualizations
|
|
559
|
+
self.viz_pafs = model_type == "bottomup"
|
|
560
|
+
self.viz_class_maps = model_type == "multi_class_bottomup"
|
|
561
|
+
|
|
562
|
+
# Initialize renderers
|
|
563
|
+
from sleap_nn.training.utils import MatplotlibRenderer, WandBRenderer
|
|
564
|
+
|
|
565
|
+
self._mpl_renderer = MatplotlibRenderer()
|
|
566
|
+
|
|
567
|
+
# Create wandb renderers for each enabled mode
|
|
568
|
+
self._wandb_renderers = {}
|
|
569
|
+
if log_wandb:
|
|
570
|
+
for mode in self.wandb_modes:
|
|
571
|
+
self._wandb_renderers[mode] = WandBRenderer(
|
|
572
|
+
mode=mode,
|
|
573
|
+
box_size=wandb_box_size,
|
|
574
|
+
confmap_threshold=wandb_confmap_threshold,
|
|
575
|
+
)
|
|
576
|
+
|
|
577
|
+
def _get_wandb_logger(self, trainer):
|
|
578
|
+
"""Get the WandbLogger from trainer's loggers."""
|
|
579
|
+
from lightning.pytorch.loggers import WandbLogger
|
|
580
|
+
|
|
581
|
+
for log in trainer.loggers:
|
|
582
|
+
if isinstance(log, WandbLogger):
|
|
583
|
+
return log
|
|
584
|
+
return None
|
|
585
|
+
|
|
586
|
+
def _get_viz_data(self, sample):
|
|
587
|
+
"""Get visualization data with all needed fields based on model type.
|
|
588
|
+
|
|
589
|
+
Args:
|
|
590
|
+
sample: A sample from the visualization dataset.
|
|
591
|
+
|
|
592
|
+
Returns:
|
|
593
|
+
VisualizationData with appropriate fields populated.
|
|
594
|
+
"""
|
|
595
|
+
# Build kwargs based on model type
|
|
596
|
+
kwargs = {}
|
|
597
|
+
if self.viz_pafs:
|
|
598
|
+
kwargs["include_pafs"] = True
|
|
599
|
+
if self.viz_class_maps:
|
|
600
|
+
kwargs["include_class_maps"] = True
|
|
601
|
+
|
|
602
|
+
# Access lightning_model lazily from model_trainer
|
|
603
|
+
return self.model_trainer.lightning_model.get_visualization_data(
|
|
604
|
+
sample, **kwargs
|
|
605
|
+
)
|
|
606
|
+
|
|
607
|
+
def _save_local_viz(self, data, prefix: str, epoch: int):
|
|
608
|
+
"""Save visualization to local disk.
|
|
609
|
+
|
|
610
|
+
Args:
|
|
611
|
+
data: VisualizationData object.
|
|
612
|
+
prefix: Filename prefix (e.g., "train", "validation").
|
|
613
|
+
epoch: Current epoch number.
|
|
614
|
+
"""
|
|
615
|
+
if not self.save_local or self.local_save_dir is None:
|
|
616
|
+
return
|
|
617
|
+
|
|
618
|
+
# Confmaps visualization
|
|
619
|
+
fig = self._mpl_renderer.render(data)
|
|
620
|
+
fig_path = self.local_save_dir / f"{prefix}.{epoch:04d}.png"
|
|
621
|
+
fig.savefig(fig_path, format="png")
|
|
622
|
+
plt.close(fig)
|
|
623
|
+
|
|
624
|
+
# PAFs visualization (for bottomup models)
|
|
625
|
+
if self.viz_pafs and data.pred_pafs is not None:
|
|
626
|
+
fig = self._mpl_renderer.render_pafs(data)
|
|
627
|
+
fig_path = self.local_save_dir / f"{prefix}.pafs_magnitude.{epoch:04d}.png"
|
|
628
|
+
fig.savefig(fig_path, format="png")
|
|
629
|
+
plt.close(fig)
|
|
630
|
+
|
|
631
|
+
# Class maps visualization (for multi_class_bottomup models)
|
|
632
|
+
if self.viz_class_maps and data.pred_class_maps is not None:
|
|
633
|
+
fig = self._render_class_maps(data)
|
|
634
|
+
fig_path = self.local_save_dir / f"{prefix}.class_maps.{epoch:04d}.png"
|
|
635
|
+
fig.savefig(fig_path, format="png")
|
|
636
|
+
plt.close(fig)
|
|
637
|
+
|
|
638
|
+
def _render_class_maps(self, data):
|
|
639
|
+
"""Render class maps visualization.
|
|
640
|
+
|
|
641
|
+
Args:
|
|
642
|
+
data: VisualizationData with pred_class_maps populated.
|
|
643
|
+
|
|
644
|
+
Returns:
|
|
645
|
+
A matplotlib Figure object.
|
|
646
|
+
"""
|
|
647
|
+
from sleap_nn.training.utils import plot_img, plot_confmaps
|
|
648
|
+
|
|
649
|
+
img = data.image
|
|
650
|
+
scale = 1.0
|
|
651
|
+
if img.shape[0] < 512:
|
|
652
|
+
scale = 2.0
|
|
653
|
+
if img.shape[0] < 256:
|
|
654
|
+
scale = 4.0
|
|
655
|
+
|
|
656
|
+
fig = plot_img(img, dpi=72 * scale, scale=scale)
|
|
657
|
+
plot_confmaps(
|
|
658
|
+
data.pred_class_maps,
|
|
659
|
+
output_scale=data.pred_class_maps.shape[0] / img.shape[0],
|
|
660
|
+
)
|
|
661
|
+
return fig
|
|
662
|
+
|
|
663
|
+
def _log_wandb_viz(self, data, prefix: str, epoch: int, wandb_logger):
|
|
664
|
+
"""Log visualization to wandb.
|
|
665
|
+
|
|
666
|
+
Args:
|
|
667
|
+
data: VisualizationData object.
|
|
668
|
+
prefix: Log prefix (e.g., "train", "val").
|
|
669
|
+
epoch: Current epoch number.
|
|
670
|
+
wandb_logger: WandbLogger instance.
|
|
671
|
+
"""
|
|
672
|
+
if not self.log_wandb or wandb_logger is None:
|
|
673
|
+
return
|
|
674
|
+
|
|
675
|
+
from io import BytesIO
|
|
676
|
+
from PIL import Image as PILImage
|
|
677
|
+
|
|
678
|
+
log_dict = {}
|
|
679
|
+
|
|
680
|
+
# Render confmaps for each enabled mode
|
|
681
|
+
for mode_name, renderer in self._wandb_renderers.items():
|
|
682
|
+
suffix = "" if mode_name == "direct" else f"_{mode_name}"
|
|
683
|
+
img = renderer.render(data, caption=f"{prefix.title()} Epoch {epoch}")
|
|
684
|
+
log_dict[f"viz/{prefix}/predictions{suffix}"] = img
|
|
685
|
+
|
|
686
|
+
# PAFs visualization (for bottomup models)
|
|
687
|
+
if self.viz_pafs and data.pred_pafs is not None:
|
|
688
|
+
pafs_fig = self._mpl_renderer.render_pafs(data)
|
|
689
|
+
buf = BytesIO()
|
|
690
|
+
pafs_fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
|
|
691
|
+
buf.seek(0)
|
|
692
|
+
plt.close(pafs_fig)
|
|
693
|
+
pafs_pil = PILImage.open(buf)
|
|
694
|
+
log_dict[f"viz/{prefix}/pafs"] = wandb.Image(
|
|
695
|
+
pafs_pil, caption=f"{prefix.title()} PAFs Epoch {epoch}"
|
|
696
|
+
)
|
|
697
|
+
|
|
698
|
+
# Class maps visualization (for multi_class_bottomup models)
|
|
699
|
+
if self.viz_class_maps and data.pred_class_maps is not None:
|
|
700
|
+
class_fig = self._render_class_maps(data)
|
|
701
|
+
buf = BytesIO()
|
|
702
|
+
class_fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
|
|
703
|
+
buf.seek(0)
|
|
704
|
+
plt.close(class_fig)
|
|
705
|
+
class_pil = PILImage.open(buf)
|
|
706
|
+
log_dict[f"viz/{prefix}/class_maps"] = wandb.Image(
|
|
707
|
+
class_pil, caption=f"{prefix.title()} Class Maps Epoch {epoch}"
|
|
708
|
+
)
|
|
709
|
+
|
|
710
|
+
if log_dict:
|
|
711
|
+
log_dict["epoch"] = epoch
|
|
712
|
+
wandb_logger.experiment.log(log_dict, commit=False)
|
|
713
|
+
|
|
714
|
+
# Optionally log to table for backwards compatibility
|
|
715
|
+
if self.log_wandb_table and "direct" in self._wandb_renderers:
|
|
716
|
+
train_img = self._wandb_renderers["direct"].render(
|
|
717
|
+
data, caption=f"{prefix.title()} Epoch {epoch}"
|
|
718
|
+
)
|
|
719
|
+
table_data = [[epoch, train_img]]
|
|
720
|
+
columns = ["Epoch", prefix.title()]
|
|
721
|
+
|
|
722
|
+
if self.viz_pafs and data.pred_pafs is not None:
|
|
723
|
+
columns.append(f"{prefix.title()} PAFs")
|
|
724
|
+
table_data[0].append(log_dict.get(f"viz/{prefix}/pafs"))
|
|
725
|
+
|
|
726
|
+
if self.viz_class_maps and data.pred_class_maps is not None:
|
|
727
|
+
columns.append(f"{prefix.title()} Class Maps")
|
|
728
|
+
table_data[0].append(log_dict.get(f"viz/{prefix}/class_maps"))
|
|
729
|
+
|
|
730
|
+
table = wandb.Table(columns=columns, data=table_data)
|
|
731
|
+
wandb_logger.experiment.log(
|
|
732
|
+
{f"predictions_table_{prefix}": table}, commit=False
|
|
733
|
+
)
|
|
734
|
+
|
|
735
|
+
def on_train_epoch_end(self, trainer, pl_module):
|
|
736
|
+
"""Generate and output all visualizations at epoch end.
|
|
737
|
+
|
|
738
|
+
Args:
|
|
739
|
+
trainer: PyTorch Lightning trainer.
|
|
740
|
+
pl_module: Lightning module (not used, we use self.lightning_module).
|
|
741
|
+
"""
|
|
742
|
+
if trainer.is_global_zero:
|
|
743
|
+
epoch = trainer.current_epoch
|
|
744
|
+
wandb_logger = self._get_wandb_logger(trainer) if self.log_wandb else None
|
|
745
|
+
|
|
746
|
+
# Get ONE sample for train visualization
|
|
747
|
+
train_sample = next(self.train_pipeline)
|
|
748
|
+
# Run inference ONCE with all needed data
|
|
749
|
+
train_data = self._get_viz_data(train_sample)
|
|
750
|
+
# Output to all destinations
|
|
751
|
+
self._save_local_viz(train_data, "train", epoch)
|
|
752
|
+
self._log_wandb_viz(train_data, "train", epoch, wandb_logger)
|
|
753
|
+
|
|
754
|
+
# Same for validation
|
|
755
|
+
val_sample = next(self.val_pipeline)
|
|
756
|
+
val_data = self._get_viz_data(val_sample)
|
|
757
|
+
self._save_local_viz(val_data, "validation", epoch)
|
|
758
|
+
self._log_wandb_viz(val_data, "val", epoch, wandb_logger)
|
|
759
|
+
|
|
760
|
+
# Sync all processes - barrier must be reached by ALL ranks
|
|
761
|
+
trainer.strategy.barrier()
|
|
762
|
+
|
|
763
|
+
|
|
150
764
|
class MatplotlibSaver(Callback):
|
|
151
765
|
"""Callback for saving images rendered with matplotlib during training.
|
|
152
766
|
|
|
@@ -194,7 +808,7 @@ class MatplotlibSaver(Callback):
|
|
|
194
808
|
).as_posix()
|
|
195
809
|
|
|
196
810
|
# Save rendered figure.
|
|
197
|
-
figure.savefig(figure_path, format="png"
|
|
811
|
+
figure.savefig(figure_path, format="png")
|
|
198
812
|
plt.close(figure)
|
|
199
813
|
|
|
200
814
|
# Sync all processes after file I/O
|
|
@@ -303,7 +917,11 @@ class ProgressReporterZMQ(Callback):
|
|
|
303
917
|
def on_train_start(self, trainer, pl_module):
    """Send a ``train_begin`` message from rank 0, then synchronize all ranks.

    The message carries ``wandb_url``: the URL of the active WandB run, or
    ``None`` when ``wandb.run`` is not set.
    """
    if trainer.is_global_zero:
        active_run = wandb.run
        self.send(
            "train_begin",
            wandb_url=None if active_run is None else active_run.url,
        )
    trainer.strategy.barrier()
|
|
308
926
|
|
|
309
927
|
def on_train_end(self, trainer, pl_module):
|
|
@@ -350,3 +968,638 @@ class ProgressReporterZMQ(Callback):
|
|
|
350
968
|
return {
|
|
351
969
|
k: float(v.item()) if hasattr(v, "item") else v for k, v in logs.items()
|
|
352
970
|
}
|
|
971
|
+
|
|
972
|
+
|
|
973
|
+
class EpochEndEvaluationCallback(Callback):
    """Callback to run full evaluation metrics at end of validation epochs.

    This callback collects predictions and ground truth during validation,
    then runs the full evaluation pipeline (OKS, mAP, PCK, etc.) and logs
    metrics to WandB.

    Attributes:
        skeleton: sio.Skeleton for creating instances.
        videos: List of sio.Video objects.
        eval_frequency: Run evaluation every N epochs (default: 1).
        oks_stddev: OKS standard deviation (default: 0.025).
        oks_scale: Optional OKS scale override.
        metrics_to_log: List of metric keys to log.
    """

    # Table driving `_log_metrics`. Each entry is:
    # (metrics_to_log key, group in the evaluator output, key within that group,
    #  WandB log key, whether a NaN value is dropped instead of logged).
    # All eval metrics use the eval/val/ prefix since they're computed on
    # validation data.
    _METRIC_SPECS = (
        ("mOKS", "mOKS", "mOKS", "eval/val/mOKS", False),
        ("oks_voc.mAP", "voc_metrics", "oks_voc.mAP", "eval/val/oks_voc_mAP", False),
        ("oks_voc.mAR", "voc_metrics", "oks_voc.mAR", "eval/val/oks_voc_mAR", False),
        ("distance/avg", "distance_metrics", "avg", "eval/val/distance/avg", True),
        ("distance/p50", "distance_metrics", "p50", "eval/val/distance/p50", True),
        ("distance/p95", "distance_metrics", "p95", "eval/val/distance/p95", True),
        ("distance/p99", "distance_metrics", "p99", "eval/val/distance/p99", True),
        ("mPCK", "pck_metrics", "mPCK", "eval/val/mPCK", False),
        ("PCK@5", "pck_metrics", "PCK@5", "eval/val/PCK_5", False),
        ("PCK@10", "pck_metrics", "PCK@10", "eval/val/PCK_10", False),
        (
            "visibility_precision",
            "visibility_metrics",
            "precision",
            "eval/val/visibility_precision",
            True,
        ),
        (
            "visibility_recall",
            "visibility_metrics",
            "recall",
            "eval/val/visibility_recall",
            True,
        ),
    )

    def __init__(
        self,
        skeleton: "sio.Skeleton",
        videos: list,
        eval_frequency: int = 1,
        oks_stddev: float = 0.025,
        oks_scale: Optional[float] = None,
        metrics_to_log: Optional[list] = None,
    ):
        """Initialize the callback.

        Args:
            skeleton: sio.Skeleton for creating instances.
            videos: List of sio.Video objects.
            eval_frequency: Run evaluation every N epochs (default: 1).
            oks_stddev: OKS standard deviation (default: 0.025).
            oks_scale: Optional OKS scale override.
            metrics_to_log: List of metric keys to log. If None, logs all available.
        """
        super().__init__()
        self.skeleton = skeleton
        self.videos = videos
        self.eval_frequency = eval_frequency
        self.oks_stddev = oks_stddev
        self.oks_scale = oks_scale
        self.metrics_to_log = metrics_to_log or [
            "mOKS",
            "oks_voc.mAP",
            "oks_voc.mAR",
            "distance/avg",
            "distance/p50",
            "distance/p95",
            "distance/p99",
            "mPCK",
            "PCK@5",
            "PCK@10",
            "visibility_precision",
            "visibility_recall",
        ]

    def _is_eval_epoch(self, trainer) -> bool:
        """Return True if the current epoch is one selected by ``eval_frequency``."""
        return (trainer.current_epoch + 1) % self.eval_frequency == 0

    def on_validation_epoch_start(self, trainer, pl_module):
        """Enable prediction collection at the start of validation.

        Collection is skipped during the sanity check, to avoid inference issues.
        It is also skipped on epochs that will not be evaluated, so predictions
        are not gathered only to be thrown away.
        """
        if trainer.sanity_checking or not self._is_eval_epoch(trainer):
            return
        pl_module._collect_val_predictions = True

    def on_validation_epoch_end(self, trainer, pl_module):
        """Run evaluation and log metrics at end of validation epoch.

        Evaluation runs only on global rank 0, outside the sanity check, and on
        epochs selected by ``eval_frequency``. Afterwards every rank resets the
        collection flag, empties its buffers, and meets at a barrier.
        """
        if (
            trainer.is_global_zero
            and not trainer.sanity_checking
            and self._is_eval_epoch(trainer)
        ):
            self._evaluate(trainer, pl_module)

        # Cleanup on every rank. The flag was set on all ranks, so the buffers
        # are emptied everywhere too; any rank that collected then cannot keep
        # growing them across epochs.
        pl_module._collect_val_predictions = False
        pl_module.val_predictions = []
        pl_module.val_ground_truth = []

        # Sync all processes - barrier must be reached by ALL ranks
        trainer.strategy.barrier()

    def _evaluate(self, trainer, pl_module):
        """Evaluate the collected predictions against ground truth and log (rank 0).

        Failures are logged as warnings rather than raised, so training continues.
        """
        if not pl_module.val_predictions or not pl_module.val_ground_truth:
            logger.warning("No predictions collected for epoch-end evaluation")
            return

        try:
            import numpy as np
            import sleap_io as sio
            from sleap_nn.evaluation import Evaluator

            # Build sio.Labels from accumulated predictions and ground truth
            pred_labels = self._build_pred_labels(pl_module.val_predictions, sio, np)
            gt_labels = self._build_gt_labels(pl_module.val_ground_truth, sio, np)

            # Check if we have valid frames to evaluate
            if len(pred_labels) == 0:
                logger.warning(
                    "No valid predictions for epoch-end evaluation "
                    "(all predictions may be empty or NaN)"
                )
                return

            evaluator = Evaluator(
                ground_truth_instances=gt_labels,
                predicted_instances=pred_labels,
                oks_stddev=self.oks_stddev,
                oks_scale=self.oks_scale,
                user_labels_only=False,  # All validation frames are "user" frames
            )
            metrics = evaluator.evaluate()

            # Log to WandB
            self._log_metrics(trainer, metrics, trainer.current_epoch)

            logger.info(
                f"Epoch {trainer.current_epoch} evaluation: "
                f"PCK@5={metrics['pck_metrics']['PCK@5']:.4f}, "
                f"mOKS={metrics['mOKS']['mOKS']:.4f}, "
                f"mAP={metrics['voc_metrics']['oks_voc.mAP']:.4f}"
            )

        except Exception as e:
            logger.warning(f"Epoch-end evaluation failed: {e}")

    def _build_pred_labels(self, predictions: list, sio, np) -> "sio.Labels":
        """Convert prediction dicts to sio.Labels.

        Each dict is read for ``pred_peaks``, ``pred_scores``, ``video_idx`` and
        ``frame_idx``. An entry is skipped when its peaks are missing or entirely
        NaN. An instance is skipped when all of its points are NaN. A frame with
        no remaining instances is not added.
        """
        labeled_frames = []
        for pred in predictions:
            pred_peaks = pred["pred_peaks"]
            pred_scores = pred["pred_scores"]

            # Handle NaN/missing predictions
            if pred_peaks is None or (
                isinstance(pred_peaks, np.ndarray) and np.isnan(pred_peaks).all()
            ):
                continue

            # Promote a single instance (n_nodes, 2) to (1, n_nodes, 2). Scores
            # may be None (handled below), so only reshape them when present.
            if len(pred_peaks.shape) == 2:
                pred_peaks = pred_peaks.reshape(1, -1, 2)
                if pred_scores is not None:
                    pred_scores = pred_scores.reshape(1, -1)

            instances = []
            for inst_idx, inst_points in enumerate(pred_peaks):
                inst_scores = pred_scores[inst_idx] if pred_scores is not None else None

                # Skip if all NaN
                if np.isnan(inst_points).all():
                    continue

                inst = sio.PredictedInstance.from_numpy(
                    points_data=inst_points,
                    skeleton=self.skeleton,
                    point_scores=(
                        inst_scores
                        if inst_scores is not None
                        else np.ones(len(inst_points))
                    ),
                    score=(
                        float(np.nanmean(inst_scores))
                        if inst_scores is not None
                        else 1.0
                    ),
                )
                instances.append(inst)

            if instances:
                labeled_frames.append(
                    sio.LabeledFrame(
                        video=self.videos[pred["video_idx"]],
                        frame_idx=pred["frame_idx"],
                        instances=instances,
                    )
                )

        return sio.Labels(
            videos=self.videos,
            skeletons=[self.skeleton],
            labeled_frames=labeled_frames,
        )

    def _build_gt_labels(self, ground_truth: list, sio, np) -> "sio.Labels":
        """Convert ground truth dicts to sio.Labels.

        Each dict is read for ``gt_instances``, ``num_instances``, ``video_idx``
        and ``frame_idx``. At most ``num_instances`` instances are used, and
        all-NaN instances are skipped.
        """
        labeled_frames = []
        for gt in ground_truth:
            instances = []
            gt_instances = gt["gt_instances"]

            # Handle shape variations: (n_nodes, 2) -> (1, n_nodes, 2)
            if len(gt_instances.shape) == 2:
                gt_instances = gt_instances.reshape(1, -1, 2)

            for i in range(min(gt["num_instances"], len(gt_instances))):
                inst_data = gt_instances[i]
                if np.isnan(inst_data).all():
                    continue
                instances.append(
                    sio.Instance.from_numpy(
                        points_data=inst_data,
                        skeleton=self.skeleton,
                    )
                )

            if instances:
                labeled_frames.append(
                    sio.LabeledFrame(
                        video=self.videos[gt["video_idx"]],
                        frame_idx=gt["frame_idx"],
                        instances=instances,
                    )
                )

        return sio.Labels(
            videos=self.videos,
            skeletons=[self.skeleton],
            labeled_frames=labeled_frames,
        )

    def _log_metrics(self, trainer, metrics: dict, epoch: int):
        """Log evaluation metrics to WandB and update running bests in the summary.

        Does nothing when no ``WandbLogger`` is attached to the trainer. Running
        bests are stored under ``best/<log key>``. Distance metrics keep the
        minimum and all other metrics keep the maximum. NaN values never become
        the best, and an existing NaN best is replaced.
        """
        import numpy as np
        from lightning.pytorch.loggers import WandbLogger

        wandb_logger = next(
            (log for log in trainer.loggers if isinstance(log, WandbLogger)), None
        )
        if wandb_logger is None:
            return

        log_dict = {"epoch": epoch}
        for name, group, sub_key, log_key, drop_nan in self._METRIC_SPECS:
            if name not in self.metrics_to_log:
                continue
            value = metrics[group][sub_key]
            if drop_nan and np.isnan(value):
                continue
            log_dict[log_key] = value

        wandb_logger.experiment.log(log_dict, commit=False)

        # Update best metrics in summary (excluding epoch)
        summary = wandb_logger.experiment.summary
        for key, value in log_dict.items():
            if key == "epoch":
                continue
            # Every comparison against NaN is False, so a NaN stored as "best"
            # could never be replaced; never store one.
            if np.isnan(value):
                continue
            summary_key = f"best/{key}"
            current_best = summary.get(summary_key)
            # For distance metrics, lower is better; for others, higher is better
            lower_is_better = "distance" in key
            if current_best is None or np.isnan(current_best):
                improved = True
            elif lower_is_better:
                improved = value < current_best
            else:
                improved = value > current_best
            if improved:
                summary[summary_key] = value
|
1285
|
+
|
|
1286
|
+
|
|
1287
|
+
def match_centroids(
    pred_centroids: "np.ndarray",
    gt_centroids: "np.ndarray",
    max_distance: float = 50.0,
) -> tuple:
    """Match predicted centroids to ground truth using Hungarian algorithm.

    Pairs farther apart than ``max_distance`` never count as matches. They are
    penalized inside the assignment itself rather than only filtered out
    afterwards, so the solver first maximizes the number of pairs within
    ``max_distance`` and then minimizes their total distance. Filtering after an
    unconstrained assignment can trade one valid close pair for two invalid far
    ones, which under-counts matches. Points with NaN coordinates never match.

    Args:
        pred_centroids: Predicted centroid locations, shape (n_pred, 2).
        gt_centroids: Ground truth centroid locations, shape (n_gt, 2).
        max_distance: Maximum distance threshold for valid matches (in pixels).

    Returns:
        Tuple of integer index arrays (possibly empty):
            - matched_pred_indices: Indices of matched predictions (ascending)
            - matched_gt_indices: Indices of matched ground truth, aligned
              element-wise with ``matched_pred_indices``
            - unmatched_pred_indices: Sorted indices of unmatched predictions
              (false positives)
            - unmatched_gt_indices: Sorted indices of unmatched ground truth
              (false negatives)
    """
    import numpy as np
    from scipy.optimize import linear_sum_assignment
    from scipy.spatial.distance import cdist

    n_pred = len(pred_centroids)
    n_gt = len(gt_centroids)

    # Edge cases: nothing can match. Use integer arrays so callers can index
    # with the results directly.
    if n_pred == 0 or n_gt == 0:
        empty = np.array([], dtype=int)
        return empty, empty.copy(), np.arange(n_pred), np.arange(n_gt)

    distances = cdist(
        np.asarray(pred_centroids, dtype=float),
        np.asarray(gt_centroids, dtype=float),
    )
    # NaN <= max_distance is False, so NaN coordinates are never valid.
    valid = distances <= max_distance

    # Give every invalid pair a cost larger than the biggest possible sum of
    # valid pair costs. An assignment with more valid pairs then always wins.
    n_pairs = min(n_pred, n_gt)
    max_valid = float(distances[valid].max()) if valid.any() else 0.0
    penalty = max_valid * n_pairs + 1.0
    cost_matrix = np.where(valid, distances, penalty)

    # Run Hungarian algorithm for optimal matching
    pred_indices, gt_indices = linear_sum_assignment(cost_matrix)

    # Keep only assigned pairs that are within max_distance.
    keep = valid[pred_indices, gt_indices]
    matched_pred = pred_indices[keep].astype(int)
    matched_gt = gt_indices[keep].astype(int)

    unmatched_pred = np.setdiff1d(np.arange(n_pred), matched_pred)
    unmatched_gt = np.setdiff1d(np.arange(n_gt), matched_gt)

    return matched_pred, matched_gt, unmatched_pred, unmatched_gt
|
|
1345
|
+
|
|
1346
|
+
|
|
1347
|
+
class CentroidEvaluationCallback(Callback):
    """Callback to run centroid-specific evaluation metrics at end of validation epochs.

    This callback is designed specifically for centroid models, which predict a single
    point (centroid) per instance rather than full pose skeletons. It computes
    distance-based metrics and detection metrics that are more appropriate for
    point detection tasks than OKS/PCK metrics.

    Metrics computed:
        - Distance metrics: mean, median, p90, p95, max Euclidean distance
        - Detection metrics: precision, recall, F1 score
        - Counts: true positives, false positives, false negatives

    Attributes:
        videos: List of sio.Video objects.
        eval_frequency: Run evaluation every N epochs (default: 1).
        match_threshold: Maximum distance (pixels) for matching pred to GT (default: 50.0).
    """

    # Distance statistics produced by `_compute_metrics` as "dist_<stat>".
    _DIST_STATS = ("avg", "median", "p90", "p95", "max")

    # Logged keys (matched by suffix) where a smaller value is better, in
    # addition to every distance key: the false-positive and false-negative counts.
    _LOWER_IS_BETTER_SUFFIXES = ("_n_fp", "_n_fn")

    def __init__(
        self,
        videos: list,
        eval_frequency: int = 1,
        match_threshold: float = 50.0,
    ):
        """Initialize the callback.

        Args:
            videos: List of sio.Video objects.
            eval_frequency: Run evaluation every N epochs (default: 1).
            match_threshold: Maximum distance in pixels for a prediction to be
                considered a match to a ground truth centroid (default: 50.0).
        """
        super().__init__()
        self.videos = videos
        self.eval_frequency = eval_frequency
        self.match_threshold = match_threshold

    def _is_eval_epoch(self, trainer) -> bool:
        """Return True if the current epoch is one selected by ``eval_frequency``."""
        return (trainer.current_epoch + 1) % self.eval_frequency == 0

    def on_validation_epoch_start(self, trainer, pl_module):
        """Enable prediction collection at the start of validation.

        Collection is skipped during the sanity check, to avoid inference issues.
        It is also skipped on epochs that will not be evaluated, so predictions
        are not gathered only to be thrown away.
        """
        if trainer.sanity_checking or not self._is_eval_epoch(trainer):
            return
        pl_module._collect_val_predictions = True

    def on_validation_epoch_end(self, trainer, pl_module):
        """Run centroid evaluation and log metrics at end of validation epoch.

        Evaluation runs only on global rank 0, outside the sanity check, and on
        epochs selected by ``eval_frequency``. Afterwards every rank resets the
        collection flag, empties its buffers, and meets at a barrier.
        """
        if (
            trainer.is_global_zero
            and not trainer.sanity_checking
            and self._is_eval_epoch(trainer)
        ):
            self._evaluate(trainer, pl_module)

        # Cleanup on every rank. The flag was set on all ranks, so the buffers
        # are emptied everywhere too; any rank that collected then cannot keep
        # growing them across epochs.
        pl_module._collect_val_predictions = False
        pl_module.val_predictions = []
        pl_module.val_ground_truth = []

        # Sync all processes - barrier must be reached by ALL ranks
        trainer.strategy.barrier()

    def _evaluate(self, trainer, pl_module):
        """Compute and log centroid metrics from the collected buffers (rank 0).

        Failures are logged as warnings rather than raised, so training continues.
        """
        if not pl_module.val_predictions or not pl_module.val_ground_truth:
            logger.warning(
                "No predictions collected for centroid epoch-end evaluation"
            )
            return

        try:
            import numpy as np

            metrics = self._compute_metrics(
                pl_module.val_predictions, pl_module.val_ground_truth, np
            )

            # Log to WandB
            self._log_metrics(trainer, metrics, trainer.current_epoch)

            logger.info(
                f"Epoch {trainer.current_epoch} centroid evaluation: "
                f"precision={metrics['precision']:.4f}, "
                f"recall={metrics['recall']:.4f}, "
                f"dist_avg={metrics['dist_avg']:.2f}px"
            )

        except Exception as e:
            logger.warning(f"Centroid epoch-end evaluation failed: {e}")

    def _compute_metrics(self, predictions: list, ground_truth: list, np) -> dict:
        """Compute centroid-specific metrics.

        Predictions and ground truth are grouped by ``(video_idx, frame_idx)``.
        Each frame is matched with ``match_centroids`` using
        ``self.match_threshold``, and the counts and distances are aggregated
        over all frames. Entries whose points are ``None`` contribute no
        centroids; ground truth in such a frame then counts as missed.

        Args:
            predictions: List of prediction dicts with "pred_peaks" key.
            ground_truth: List of ground truth dicts with "gt_instances" key.
            np: NumPy module.

        Returns:
            Dictionary of computed metrics. Distance statistics are NaN when
            nothing matched.
        """
        from collections import defaultdict

        def _valid_centroids(points):
            # Flatten e.g. (n_inst, 1, 2) to (n_inst, 2) rows and drop any row
            # containing NaN.
            centroids = np.asarray(points).reshape(-1, 2)
            return centroids[~np.isnan(centroids).any(axis=1)]

        # Group predictions and GT by frame
        pred_by_frame = defaultdict(list)
        for pred in predictions:
            if pred["pred_peaks"] is None:
                continue
            key = (pred["video_idx"], pred["frame_idx"])
            pred_by_frame[key].append(_valid_centroids(pred["pred_peaks"]))

        gt_by_frame = defaultdict(list)
        for gt in ground_truth:
            if gt["gt_instances"] is None:
                continue
            key = (gt["video_idx"], gt["frame_idx"])
            gt_by_frame[key].append(_valid_centroids(gt["gt_instances"]))

        all_distances = []
        total_tp = total_fp = total_fn = 0
        no_points = np.zeros((0, 2))

        # Process each frame. `.get` is used so the defaultdicts are not mutated.
        for frame_key in set(pred_by_frame) | set(gt_by_frame):
            pred_parts = pred_by_frame.get(frame_key)
            gt_parts = gt_by_frame.get(frame_key)
            frame_preds = (
                np.concatenate(pred_parts, axis=0) if pred_parts else no_points
            )
            frame_gt = np.concatenate(gt_parts, axis=0) if gt_parts else no_points

            # Match predictions to ground truth
            matched_pred, matched_gt, unmatched_pred, unmatched_gt = match_centroids(
                frame_preds, frame_gt, max_distance=self.match_threshold
            )

            # Compute distances for matched pairs
            if len(matched_pred) > 0:
                distances = np.linalg.norm(
                    frame_preds[matched_pred] - frame_gt[matched_gt], axis=1
                )
                all_distances.extend(distances.tolist())

            total_tp += len(matched_pred)
            total_fp += len(unmatched_pred)
            total_fn += len(unmatched_gt)

        all_distances = np.array(all_distances)

        # Distance metrics (only if we have matches)
        if len(all_distances) > 0:
            dist_avg = float(np.mean(all_distances))
            dist_median = float(np.median(all_distances))
            dist_p90 = float(np.percentile(all_distances, 90))
            dist_p95 = float(np.percentile(all_distances, 95))
            dist_max = float(np.max(all_distances))
        else:
            dist_avg = dist_median = dist_p90 = dist_p95 = dist_max = float("nan")

        # Detection metrics
        n_predicted = total_tp + total_fp
        n_actual = total_tp + total_fn
        precision = total_tp / n_predicted if n_predicted > 0 else 0.0
        recall = total_tp / n_actual if n_actual > 0 else 0.0
        f1 = (
            2 * precision * recall / (precision + recall)
            if (precision + recall) > 0
            else 0.0
        )

        return {
            "dist_avg": dist_avg,
            "dist_median": dist_median,
            "dist_p90": dist_p90,
            "dist_p95": dist_p95,
            "dist_max": dist_max,
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "n_true_positives": total_tp,
            "n_false_positives": total_fp,
            "n_false_negatives": total_fn,
            "n_total_predictions": n_predicted,
            "n_total_ground_truth": n_actual,
        }

    def _log_metrics(self, trainer, metrics: dict, epoch: int):
        """Log centroid evaluation metrics to WandB and update running bests.

        Does nothing when no ``WandbLogger`` is attached to the trainer. Running
        bests are stored under ``best/<log key>``. Distances and FP/FN counts keep
        the minimum; precision, recall, F1 and the TP count keep the maximum.
        """
        import numpy as np
        from lightning.pytorch.loggers import WandbLogger

        wandb_logger = next(
            (log for log in trainer.loggers if isinstance(log, WandbLogger)), None
        )
        if wandb_logger is None:
            return

        log_dict = {"epoch": epoch}

        # Distance statistics are NaN when nothing matched; omit them then.
        for stat in self._DIST_STATS:
            value = metrics[f"dist_{stat}"]
            if not np.isnan(value):
                log_dict[f"eval/val/centroid_dist_{stat}"] = value

        # Detection metrics
        log_dict["eval/val/centroid_precision"] = metrics["precision"]
        log_dict["eval/val/centroid_recall"] = metrics["recall"]
        log_dict["eval/val/centroid_f1"] = metrics["f1"]

        # Counts
        log_dict["eval/val/centroid_n_tp"] = metrics["n_true_positives"]
        log_dict["eval/val/centroid_n_fp"] = metrics["n_false_positives"]
        log_dict["eval/val/centroid_n_fn"] = metrics["n_false_negatives"]

        wandb_logger.experiment.log(log_dict, commit=False)

        # Update best metrics in summary
        summary = wandb_logger.experiment.summary
        for key, value in log_dict.items():
            if key == "epoch":
                continue
            summary_key = f"best/{key}"
            current_best = summary.get(summary_key)
            # Distances and false-positive/false-negative counts improve downward;
            # everything else improves upward.
            lower_is_better = "dist" in key or key.endswith(
                self._LOWER_IS_BETTER_SUFFIXES
            )
            if current_best is None:
                improved = True
            elif lower_is_better:
                improved = value < current_best
            else:
                improved = value > current_best
            if improved:
                summary[summary_key] = value
|