sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +6 -1
- sleap_nn/cli.py +142 -3
- sleap_nn/config/data_config.py +44 -7
- sleap_nn/config/get_config.py +22 -20
- sleap_nn/config/trainer_config.py +12 -0
- sleap_nn/data/augmentation.py +54 -2
- sleap_nn/data/custom_datasets.py +22 -22
- sleap_nn/data/instance_cropping.py +70 -5
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/evaluation.py +99 -23
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/peak_finding.py +10 -2
- sleap_nn/inference/predictors.py +115 -20
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/predict.py +187 -10
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +64 -40
- sleap_nn/training/callbacks.py +317 -5
- sleap_nn/training/lightning_modules.py +325 -180
- sleap_nn/training/model_trainer.py +308 -22
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +22 -32
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/RECORD +30 -28
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
sleap_nn/training/callbacks.py
CHANGED
|
@@ -2,10 +2,15 @@
|
|
|
2
2
|
|
|
3
3
|
import zmq
|
|
4
4
|
import jsonpickle
|
|
5
|
-
from typing import Callable, Optional
|
|
5
|
+
from typing import Callable, Optional, Union
|
|
6
6
|
from lightning.pytorch.callbacks import Callback
|
|
7
|
+
from lightning.pytorch.callbacks.progress import TQDMProgressBar
|
|
7
8
|
from loguru import logger
|
|
8
9
|
import matplotlib
|
|
10
|
+
|
|
11
|
+
matplotlib.use(
|
|
12
|
+
"Agg"
|
|
13
|
+
) # Use non-interactive backend to avoid tkinter issues on Windows CI
|
|
9
14
|
import matplotlib.pyplot as plt
|
|
10
15
|
from PIL import Image
|
|
11
16
|
from pathlib import Path
|
|
@@ -14,6 +19,32 @@ import csv
|
|
|
14
19
|
from sleap_nn import RANK
|
|
15
20
|
|
|
16
21
|
|
|
22
|
+
class SleapProgressBar(TQDMProgressBar):
    """TQDM progress bar that keeps very small metric values readable.

    The stock ``TQDMProgressBar`` shows small floats such as 1e-5 as "0.000".
    This subclass pre-formats float metrics as strings, switching to
    scientific notation when the magnitude is below 0.001.
    """

    def get_metrics(
        self, trainer, pl_module
    ) -> dict[str, Union[int, str, float, dict[str, float]]]:
        """Return the progress-bar metrics with float values pre-formatted.

        Args:
            trainer: Trainer passed through to the parent implementation.
            pl_module: Module passed through to the parent implementation.

        Returns:
            A new dict with the same keys as the parent's result. Float values
            are replaced by strings (``.2e`` for non-zero values with
            ``abs(v) < 0.001``, ``.4f`` otherwise); all other values are
            passed through unchanged.
        """

        def _render(value):
            # Only plain floats are reformatted; ints, strings, nested dicts,
            # etc. are returned as-is.
            if not isinstance(value, float):
                return value
            is_tiny = value != 0 and abs(value) < 0.001
            return f"{value:.2e}" if is_tiny else f"{value:.4f}"

        raw_metrics = super().get_metrics(trainer, pl_module)
        return {name: _render(value) for name, value in raw_metrics.items()}
|
|
46
|
+
|
|
47
|
+
|
|
17
48
|
class CSVLoggerCallback(Callback):
|
|
18
49
|
"""Callback for logging metrics to csv.
|
|
19
50
|
|
|
@@ -53,6 +84,16 @@ class CSVLoggerCallback(Callback):
|
|
|
53
84
|
for key in self.keys:
|
|
54
85
|
if key == "epoch":
|
|
55
86
|
log_data["epoch"] = trainer.current_epoch
|
|
87
|
+
elif key == "learning_rate":
|
|
88
|
+
# Handle both direct logging and LearningRateMonitor format (lr-*)
|
|
89
|
+
value = metrics.get(key, None)
|
|
90
|
+
if value is None:
|
|
91
|
+
# Look for lr-* keys from LearningRateMonitor
|
|
92
|
+
for metric_key in metrics.keys():
|
|
93
|
+
if metric_key.startswith("lr-"):
|
|
94
|
+
value = metrics[metric_key]
|
|
95
|
+
break
|
|
96
|
+
log_data[key] = value.item() if value is not None else None
|
|
56
97
|
else:
|
|
57
98
|
value = metrics.get(key, None)
|
|
58
99
|
log_data[key] = value.item() if value is not None else None
|
|
@@ -66,7 +107,11 @@ class CSVLoggerCallback(Callback):
|
|
|
66
107
|
|
|
67
108
|
|
|
68
109
|
class WandBPredImageLogger(Callback):
|
|
69
|
-
"""Callback for writing image predictions to wandb.
|
|
110
|
+
"""Callback for writing image predictions to wandb as a Table.
|
|
111
|
+
|
|
112
|
+
.. deprecated::
|
|
113
|
+
This callback logs images to a wandb.Table which doesn't support
|
|
114
|
+
step sliders. Use WandBVizCallback instead for better UX.
|
|
70
115
|
|
|
71
116
|
Attributes:
|
|
72
117
|
viz_folder: Path to viz directory.
|
|
@@ -141,12 +186,275 @@ class WandBPredImageLogger(Callback):
|
|
|
141
186
|
]
|
|
142
187
|
]
|
|
143
188
|
table = wandb.Table(columns=column_names, data=data)
|
|
144
|
-
|
|
189
|
+
# Use commit=False to accumulate with other metrics in this step
|
|
190
|
+
wandb.log({f"{self.wandb_run_name}": table}, commit=False)
|
|
145
191
|
|
|
146
192
|
# Sync all processes after wandb logging
|
|
147
193
|
trainer.strategy.barrier()
|
|
148
194
|
|
|
149
195
|
|
|
196
|
+
class WandBVizCallback(Callback):
    """Callback for logging visualization images directly to wandb with slider support.

    This callback logs images using wandb.log() which enables step slider navigation
    in the wandb UI. Multiple visualization modes can be enabled simultaneously:
    - viz_enabled: Pre-render with matplotlib (same as disk viz)
    - viz_boxes: Interactive keypoint boxes with filtering
    - viz_masks: Confidence map overlay with per-node toggling

    Attributes:
        train_viz_fn: Function that returns VisualizationData for training sample.
        val_viz_fn: Function that returns VisualizationData for validation sample.
        viz_enabled: Whether to log pre-rendered matplotlib images.
        viz_boxes: Whether to log interactive keypoint boxes.
        viz_masks: Whether to log confidence map overlay masks.
        box_size: Size of keypoint boxes in pixels (for viz_boxes).
        confmap_threshold: Threshold for confmap masks (for viz_masks).
        log_table: Whether to also log to a wandb.Table (backwards compat).
        renderers: Dict mapping mode name ("direct", "boxes", "masks") to the
            `WandBRenderer` created for each enabled mode.
    """

    def __init__(
        self,
        train_viz_fn: Callable,
        val_viz_fn: Callable,
        viz_enabled: bool = True,
        viz_boxes: bool = False,
        viz_masks: bool = False,
        box_size: float = 5.0,
        confmap_threshold: float = 0.1,
        log_table: bool = False,
    ):
        """Initialize the callback.

        Args:
            train_viz_fn: Callable that returns VisualizationData for a training sample.
            val_viz_fn: Callable that returns VisualizationData for a validation sample.
            viz_enabled: If True, log pre-rendered matplotlib images.
            viz_boxes: If True, log interactive keypoint boxes.
            viz_masks: If True, log confidence map overlay masks.
            box_size: Size of keypoint boxes in pixels (for viz_boxes).
            confmap_threshold: Threshold for confmap mask generation (for viz_masks).
            log_table: If True, also log images to a wandb.Table (for backwards compat).
        """
        super().__init__()
        self.train_viz_fn = train_viz_fn
        self.val_viz_fn = val_viz_fn
        self.viz_enabled = viz_enabled
        self.viz_boxes = viz_boxes
        self.viz_masks = viz_masks
        self.log_table = log_table
        self.box_size = box_size
        self.confmap_threshold = confmap_threshold

        # Import here to avoid circular imports
        from sleap_nn.training.utils import WandBRenderer

        # Create one renderer per enabled mode. Insertion order
        # (direct, boxes, masks) matches the order they are logged in.
        enabled_modes = {
            "direct": viz_enabled,
            "boxes": viz_boxes,
            "masks": viz_masks,
        }
        self.renderers = {
            mode: WandBRenderer(
                mode=mode, box_size=box_size, confmap_threshold=confmap_threshold
            )
            for mode, enabled in enabled_modes.items()
            if enabled
        }

    def _get_wandb_logger(self, trainer):
        """Return the first WandbLogger among the trainer's loggers, or None."""
        from lightning.pytorch.loggers import WandbLogger

        # The loop variable is deliberately not called `logger` so it does not
        # shadow the module-level loguru logger.
        return next(
            (
                candidate
                for candidate in trainer.loggers
                if isinstance(candidate, WandbLogger)
            ),
            None,
        )

    def _log_visualizations(self, trainer):
        """Render and log the visualizations for the current epoch.

        Intended to run on the global-zero rank only. Returns without logging
        anything when the trainer has no WandbLogger attached.
        """
        epoch = trainer.current_epoch

        # Get the wandb logger to use its experiment for logging
        wandb_logger = self._get_wandb_logger(trainer)
        if wandb_logger is None:
            return  # No wandb logger, skip visualization logging

        # Get visualization data
        train_data = self.train_viz_fn()
        val_data = self.val_viz_fn()

        # Render and log for each enabled mode
        # Use the logger's experiment to let Lightning manage step tracking
        log_dict = {}
        for mode_name, renderer in self.renderers.items():
            suffix = "" if mode_name == "direct" else f"_{mode_name}"
            log_dict[f"train_predictions{suffix}"] = renderer.render(
                train_data, caption=f"Train Epoch {epoch}"
            )
            log_dict[f"val_predictions{suffix}"] = renderer.render(
                val_data, caption=f"Val Epoch {epoch}"
            )

        if log_dict:
            # Include epoch so wandb can use it as x-axis (via define_metric)
            log_dict["epoch"] = epoch
            # Use commit=False to accumulate with other metrics in this step
            # Lightning will commit when it logs its own metrics
            wandb_logger.experiment.log(log_dict, commit=False)

        # Optionally also log to table for backwards compat
        if self.log_table and "direct" in self.renderers:
            direct_renderer = self.renderers["direct"]
            # Images are rendered afresh for the table rather than reusing the
            # objects placed in log_dict above.
            train_img = direct_renderer.render(
                train_data, caption=f"Train Epoch {epoch}"
            )
            val_img = direct_renderer.render(val_data, caption=f"Val Epoch {epoch}")
            table = wandb.Table(
                columns=["Epoch", "Train", "Validation"],
                data=[[epoch, train_img, val_img]],
            )
            wandb_logger.experiment.log({"predictions_table": table}, commit=False)

    def on_train_epoch_end(self, trainer, pl_module):
        """Log visualization images at end of each epoch.

        Logging happens on the global-zero rank only. The barrier is reached on
        every code path: previously rank zero returned before it when no
        WandbLogger was configured, leaving the other ranks waiting.
        """
        if trainer.is_global_zero:
            self._log_visualizations(trainer)

        # Sync all processes
        trainer.strategy.barrier()
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
class WandBVizCallbackWithPAFs(WandBVizCallback):
    """Extended WandBVizCallback that also logs PAF visualizations for bottom-up models.

    Attributes:
        train_pafs_viz_fn: Callable returning VisualizationData with PAFs for training.
        val_pafs_viz_fn: Callable returning VisualizationData with PAFs for validation.
    """

    def __init__(
        self,
        train_viz_fn: Callable,
        val_viz_fn: Callable,
        train_pafs_viz_fn: Callable,
        val_pafs_viz_fn: Callable,
        viz_enabled: bool = True,
        viz_boxes: bool = False,
        viz_masks: bool = False,
        box_size: float = 5.0,
        confmap_threshold: float = 0.1,
        log_table: bool = False,
    ):
        """Initialize the callback.

        Args:
            train_viz_fn: Callable returning VisualizationData for training sample.
            val_viz_fn: Callable returning VisualizationData for validation sample.
            train_pafs_viz_fn: Callable returning VisualizationData with PAFs for training.
            val_pafs_viz_fn: Callable returning VisualizationData with PAFs for validation.
            viz_enabled: If True, log pre-rendered matplotlib images.
            viz_boxes: If True, log interactive keypoint boxes.
            viz_masks: If True, log confidence map overlay masks.
            box_size: Size of keypoint boxes in pixels.
            confmap_threshold: Threshold for confmap mask generation.
            log_table: If True, also log images to a wandb.Table.
        """
        super().__init__(
            train_viz_fn=train_viz_fn,
            val_viz_fn=val_viz_fn,
            viz_enabled=viz_enabled,
            viz_boxes=viz_boxes,
            viz_masks=viz_masks,
            box_size=box_size,
            confmap_threshold=confmap_threshold,
            log_table=log_table,
        )
        self.train_pafs_viz_fn = train_pafs_viz_fn
        self.val_pafs_viz_fn = val_pafs_viz_fn

        # Import here to avoid circular imports
        from sleap_nn.training.utils import MatplotlibRenderer

        self._mpl_renderer = MatplotlibRenderer()

    def _render_pafs_image(self, pafs_data, caption: str):
        """Render PAF data with matplotlib and wrap the PNG in a wandb.Image."""
        from io import BytesIO

        figure = self._mpl_renderer.render_pafs(pafs_data)
        buf = BytesIO()
        try:
            figure.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        finally:
            # Close the figure even if saving fails so figures do not pile up.
            plt.close(figure)
        buf.seek(0)
        return wandb.Image(Image.open(buf), caption=caption)

    def _log_visualizations_with_pafs(self, trainer):
        """Render and log confmap and PAF visualizations for the current epoch.

        Intended to run on the global-zero rank only. Returns without logging
        anything when the trainer has no WandbLogger attached.
        """
        epoch = trainer.current_epoch

        # Get the wandb logger to use its experiment for logging
        wandb_logger = self._get_wandb_logger(trainer)
        if wandb_logger is None:
            return  # No wandb logger, skip visualization logging

        # Get visualization data
        train_data = self.train_viz_fn()
        val_data = self.val_viz_fn()
        train_pafs_data = self.train_pafs_viz_fn()
        val_pafs_data = self.val_pafs_viz_fn()

        # Render and log for each enabled mode
        # Use the logger's experiment to let Lightning manage step tracking
        log_dict = {}
        for mode_name, renderer in self.renderers.items():
            suffix = "" if mode_name == "direct" else f"_{mode_name}"
            log_dict[f"train_predictions{suffix}"] = renderer.render(
                train_data, caption=f"Train Epoch {epoch}"
            )
            log_dict[f"val_predictions{suffix}"] = renderer.render(
                val_data, caption=f"Val Epoch {epoch}"
            )

        # Render PAFs (always use matplotlib/direct for PAFs)
        log_dict["train_pafs"] = self._render_pafs_image(
            train_pafs_data, caption=f"Train PAFs Epoch {epoch}"
        )
        log_dict["val_pafs"] = self._render_pafs_image(
            val_pafs_data, caption=f"Val PAFs Epoch {epoch}"
        )

        # log_dict always holds at least the two PAF entries at this point.
        # Include epoch so wandb can use it as x-axis (via define_metric)
        log_dict["epoch"] = epoch
        # Use commit=False to accumulate with other metrics in this step
        # Lightning will commit when it logs its own metrics
        wandb_logger.experiment.log(log_dict, commit=False)

        # Optionally also log to table
        if self.log_table and "direct" in self.renderers:
            direct_renderer = self.renderers["direct"]
            train_img = direct_renderer.render(
                train_data, caption=f"Train Epoch {epoch}"
            )
            val_img = direct_renderer.render(val_data, caption=f"Val Epoch {epoch}")
            table = wandb.Table(
                columns=["Epoch", "Train", "Validation", "Train PAFs", "Val PAFs"],
                data=[
                    [
                        epoch,
                        train_img,
                        val_img,
                        log_dict["train_pafs"],
                        log_dict["val_pafs"],
                    ]
                ],
            )
            wandb_logger.experiment.log({"predictions_table": table}, commit=False)

    def on_train_epoch_end(self, trainer, pl_module):
        """Log visualization images including PAFs at end of each epoch.

        Logging happens on the global-zero rank only. The barrier is reached on
        every code path: previously rank zero returned before it when no
        WandbLogger was configured, leaving the other ranks waiting.
        """
        if trainer.is_global_zero:
            self._log_visualizations_with_pafs(trainer)

        # Sync all processes
        trainer.strategy.barrier()
|
|
456
|
+
|
|
457
|
+
|
|
150
458
|
class MatplotlibSaver(Callback):
|
|
151
459
|
"""Callback for saving images rendered with matplotlib during training.
|
|
152
460
|
|
|
@@ -194,7 +502,7 @@ class MatplotlibSaver(Callback):
|
|
|
194
502
|
).as_posix()
|
|
195
503
|
|
|
196
504
|
# Save rendered figure.
|
|
197
|
-
figure.savefig(figure_path, format="png"
|
|
505
|
+
figure.savefig(figure_path, format="png")
|
|
198
506
|
plt.close(figure)
|
|
199
507
|
|
|
200
508
|
# Sync all processes after file I/O
|
|
@@ -303,7 +611,11 @@ class ProgressReporterZMQ(Callback):
|
|
|
303
611
|
def on_train_start(self, trainer, pl_module):
    """Announce the start of training, then synchronize all processes.

    On the global-zero rank a "train_begin" message is sent; it carries the
    URL of the active wandb run, or None when no run is active.
    """
    if trainer.is_global_zero:
        active_run = wandb.run
        # Include WandB URL if available
        run_url = None if active_run is None else active_run.url
        self.send("train_begin", wandb_url=run_url)
    trainer.strategy.barrier()
|
|
308
620
|
|
|
309
621
|
def on_train_end(self, trainer, pl_module):
|