sleap-nn 0.0.5__py3-none-any.whl → 0.1.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,10 +2,15 @@
2
2
 
3
3
  import zmq
4
4
  import jsonpickle
5
- from typing import Callable, Optional
5
+ from typing import Callable, Optional, Union
6
6
  from lightning.pytorch.callbacks import Callback
7
+ from lightning.pytorch.callbacks.progress import TQDMProgressBar
7
8
  from loguru import logger
8
9
  import matplotlib
10
+
11
+ matplotlib.use(
12
+ "Agg"
13
+ ) # Use non-interactive backend to avoid tkinter issues on Windows CI
9
14
  import matplotlib.pyplot as plt
10
15
  from PIL import Image
11
16
  from pathlib import Path
@@ -14,6 +19,32 @@ import csv
14
19
  from sleap_nn import RANK
15
20
 
16
21
 
22
class SleapProgressBar(TQDMProgressBar):
    """Progress bar that keeps small metric values readable.

    Lightning's stock ``TQDMProgressBar`` renders floats with fixed decimals,
    so values like 1e-5 display as "0.000". This subclass switches to
    scientific notation for sub-0.001 magnitudes.
    """

    def get_metrics(
        self, trainer, pl_module
    ) -> dict[str, Union[int, str, float, dict[str, float]]]:
        """Return progress-bar metrics, formatting floats for readability."""
        raw = super().get_metrics(trainer, pl_module)
        display = {}
        for name, value in raw.items():
            if not isinstance(value, float):
                # Non-floats (e.g. step counts, version strings) pass through.
                display[name] = value
            elif value != 0 and abs(value) < 0.001:
                # Scientific notation so tiny values aren't truncated to zero.
                display[name] = f"{value:.2e}"
            else:
                # Four decimal places for ordinary magnitudes.
                display[name] = f"{value:.4f}"
        return display
46
+
47
+
17
48
  class CSVLoggerCallback(Callback):
18
49
  """Callback for logging metrics to csv.
19
50
 
@@ -53,6 +84,16 @@ class CSVLoggerCallback(Callback):
53
84
  for key in self.keys:
54
85
  if key == "epoch":
55
86
  log_data["epoch"] = trainer.current_epoch
87
+ elif key == "learning_rate":
88
+ # Handle both direct logging and LearningRateMonitor format (lr-*)
89
+ value = metrics.get(key, None)
90
+ if value is None:
91
+ # Look for lr-* keys from LearningRateMonitor
92
+ for metric_key in metrics.keys():
93
+ if metric_key.startswith("lr-"):
94
+ value = metrics[metric_key]
95
+ break
96
+ log_data[key] = value.item() if value is not None else None
56
97
  else:
57
98
  value = metrics.get(key, None)
58
99
  log_data[key] = value.item() if value is not None else None
@@ -66,7 +107,11 @@ class CSVLoggerCallback(Callback):
66
107
 
67
108
 
68
109
  class WandBPredImageLogger(Callback):
69
- """Callback for writing image predictions to wandb.
110
+ """Callback for writing image predictions to wandb as a Table.
111
+
112
+ .. deprecated::
113
+ This callback logs images to a wandb.Table which doesn't support
114
+ step sliders. Use WandBVizCallback instead for better UX.
70
115
 
71
116
  Attributes:
72
117
  viz_folder: Path to viz directory.
@@ -141,12 +186,275 @@ class WandBPredImageLogger(Callback):
141
186
  ]
142
187
  ]
143
188
  table = wandb.Table(columns=column_names, data=data)
144
- wandb.log({f"{self.wandb_run_name}": table})
189
+ # Use commit=False to accumulate with other metrics in this step
190
+ wandb.log({f"{self.wandb_run_name}": table}, commit=False)
145
191
 
146
192
  # Sync all processes after wandb logging
147
193
  trainer.strategy.barrier()
148
194
 
149
195
 
196
class WandBVizCallback(Callback):
    """Callback for logging visualization images directly to wandb with slider support.

    This callback logs images using wandb.log() which enables step slider navigation
    in the wandb UI. Multiple visualization modes can be enabled simultaneously:
    - viz_enabled: Pre-render with matplotlib (same as disk viz)
    - viz_boxes: Interactive keypoint boxes with filtering
    - viz_masks: Confidence map overlay with per-node toggling

    Attributes:
        train_viz_fn: Function that returns VisualizationData for training sample.
        val_viz_fn: Function that returns VisualizationData for validation sample.
        viz_enabled: Whether to log pre-rendered matplotlib images.
        viz_boxes: Whether to log interactive keypoint boxes.
        viz_masks: Whether to log confidence map overlay masks.
        box_size: Size of keypoint boxes in pixels (for viz_boxes).
        confmap_threshold: Threshold for confmap masks (for viz_masks).
        log_table: Whether to also log to a wandb.Table (backwards compat).
    """

    def __init__(
        self,
        train_viz_fn: Callable,
        val_viz_fn: Callable,
        viz_enabled: bool = True,
        viz_boxes: bool = False,
        viz_masks: bool = False,
        box_size: float = 5.0,
        confmap_threshold: float = 0.1,
        log_table: bool = False,
    ):
        """Initialize the callback.

        Args:
            train_viz_fn: Callable that returns VisualizationData for a training sample.
            val_viz_fn: Callable that returns VisualizationData for a validation sample.
            viz_enabled: If True, log pre-rendered matplotlib images.
            viz_boxes: If True, log interactive keypoint boxes.
            viz_masks: If True, log confidence map overlay masks.
            box_size: Size of keypoint boxes in pixels (for viz_boxes).
            confmap_threshold: Threshold for confmap mask generation (for viz_masks).
            log_table: If True, also log images to a wandb.Table (for backwards compat).
        """
        super().__init__()
        self.train_viz_fn = train_viz_fn
        self.val_viz_fn = val_viz_fn
        self.viz_enabled = viz_enabled
        self.viz_boxes = viz_boxes
        self.viz_masks = viz_masks
        self.log_table = log_table

        # Import here to avoid circular imports
        from sleap_nn.training.utils import WandBRenderer

        self.box_size = box_size
        self.confmap_threshold = confmap_threshold

        # Create one renderer per enabled visualization mode; the key ("direct",
        # "boxes", "masks") is also used to suffix the wandb metric names.
        self.renderers = {}
        if viz_enabled:
            self.renderers["direct"] = WandBRenderer(
                mode="direct", box_size=box_size, confmap_threshold=confmap_threshold
            )
        if viz_boxes:
            self.renderers["boxes"] = WandBRenderer(
                mode="boxes", box_size=box_size, confmap_threshold=confmap_threshold
            )
        if viz_masks:
            self.renderers["masks"] = WandBRenderer(
                mode="masks", box_size=box_size, confmap_threshold=confmap_threshold
            )

    def _get_wandb_logger(self, trainer):
        """Return the first WandbLogger attached to the trainer, or None."""
        from lightning.pytorch.loggers import WandbLogger

        # NOTE: the loop variable is deliberately not named `logger` to avoid
        # shadowing the module-level loguru `logger` import.
        for trainer_logger in trainer.loggers:
            if isinstance(trainer_logger, WandbLogger):
                return trainer_logger
        return None

    def on_train_epoch_end(self, trainer, pl_module):
        """Log visualization images at end of each epoch."""
        if trainer.is_global_zero:
            epoch = trainer.current_epoch

            # Get the wandb logger to use its experiment for logging
            wandb_logger = self._get_wandb_logger(trainer)
            if wandb_logger is None:
                return  # No wandb logger, skip visualization logging

            # Get visualization data
            train_data = self.train_viz_fn()
            val_data = self.val_viz_fn()

            # Render and log for each enabled mode
            # Use the logger's experiment to let Lightning manage step tracking
            log_dict = {}
            for mode_name, renderer in self.renderers.items():
                suffix = "" if mode_name == "direct" else f"_{mode_name}"
                train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
                val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
                log_dict[f"train_predictions{suffix}"] = train_img
                log_dict[f"val_predictions{suffix}"] = val_img

            if log_dict:
                # Include epoch so wandb can use it as x-axis (via define_metric)
                log_dict["epoch"] = epoch
                # Use commit=False to accumulate with other metrics in this step
                # Lightning will commit when it logs its own metrics
                wandb_logger.experiment.log(log_dict, commit=False)

            # Optionally also log to table for backwards compat
            if self.log_table and "direct" in self.renderers:
                train_img = self.renderers["direct"].render(
                    train_data, caption=f"Train Epoch {epoch}"
                )
                val_img = self.renderers["direct"].render(
                    val_data, caption=f"Val Epoch {epoch}"
                )
                table = wandb.Table(
                    columns=["Epoch", "Train", "Validation"],
                    data=[[epoch, train_img, val_img]],
                )
                wandb_logger.experiment.log({"predictions_table": table}, commit=False)

        # Sync all processes
        trainer.strategy.barrier()
324
+
325
+
326
class WandBVizCallbackWithPAFs(WandBVizCallback):
    """Extended WandBVizCallback that also logs PAF visualizations for bottom-up models."""

    def __init__(
        self,
        train_viz_fn: Callable,
        val_viz_fn: Callable,
        train_pafs_viz_fn: Callable,
        val_pafs_viz_fn: Callable,
        viz_enabled: bool = True,
        viz_boxes: bool = False,
        viz_masks: bool = False,
        box_size: float = 5.0,
        confmap_threshold: float = 0.1,
        log_table: bool = False,
    ):
        """Initialize the callback.

        Args:
            train_viz_fn: Callable returning VisualizationData for training sample.
            val_viz_fn: Callable returning VisualizationData for validation sample.
            train_pafs_viz_fn: Callable returning VisualizationData with PAFs for training.
            val_pafs_viz_fn: Callable returning VisualizationData with PAFs for validation.
            viz_enabled: If True, log pre-rendered matplotlib images.
            viz_boxes: If True, log interactive keypoint boxes.
            viz_masks: If True, log confidence map overlay masks.
            box_size: Size of keypoint boxes in pixels.
            confmap_threshold: Threshold for confmap mask generation.
            log_table: If True, also log images to a wandb.Table.
        """
        super().__init__(
            train_viz_fn=train_viz_fn,
            val_viz_fn=val_viz_fn,
            viz_enabled=viz_enabled,
            viz_boxes=viz_boxes,
            viz_masks=viz_masks,
            box_size=box_size,
            confmap_threshold=confmap_threshold,
            log_table=log_table,
        )
        self.train_pafs_viz_fn = train_pafs_viz_fn
        self.val_pafs_viz_fn = val_pafs_viz_fn

        # Import here to avoid circular imports
        from sleap_nn.training.utils import MatplotlibRenderer

        self._mpl_renderer = MatplotlibRenderer()

    def _paf_wandb_image(self, data, caption):
        """Render PAF data with matplotlib and wrap the result as a wandb.Image.

        Args:
            data: VisualizationData with PAFs to render.
            caption: Caption string for the logged wandb image.

        Returns:
            A wandb.Image wrapping the rendered PNG.
        """
        from io import BytesIO

        # PAFs are always rendered with matplotlib (direct mode); the PNG is
        # kept in memory rather than written to disk.
        fig = self._mpl_renderer.render_pafs(data)
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        buf.seek(0)
        plt.close(fig)
        # Note: the buffer must stay open until wandb.Image consumes the
        # lazily-loaded PIL image, so it is intentionally not closed here.
        return wandb.Image(Image.open(buf), caption=caption)

    def on_train_epoch_end(self, trainer, pl_module):
        """Log visualization images including PAFs at end of each epoch."""
        if trainer.is_global_zero:
            epoch = trainer.current_epoch

            # Get the wandb logger to use its experiment for logging
            wandb_logger = self._get_wandb_logger(trainer)
            if wandb_logger is None:
                return  # No wandb logger, skip visualization logging

            # Get visualization data
            train_data = self.train_viz_fn()
            val_data = self.val_viz_fn()
            train_pafs_data = self.train_pafs_viz_fn()
            val_pafs_data = self.val_pafs_viz_fn()

            # Render and log for each enabled mode
            # Use the logger's experiment to let Lightning manage step tracking
            log_dict = {}
            for mode_name, renderer in self.renderers.items():
                suffix = "" if mode_name == "direct" else f"_{mode_name}"
                train_img = renderer.render(train_data, caption=f"Train Epoch {epoch}")
                val_img = renderer.render(val_data, caption=f"Val Epoch {epoch}")
                log_dict[f"train_predictions{suffix}"] = train_img
                log_dict[f"val_predictions{suffix}"] = val_img

            # Render PAFs (shared helper avoids duplicating the buffer dance)
            log_dict["train_pafs"] = self._paf_wandb_image(
                train_pafs_data, caption=f"Train PAFs Epoch {epoch}"
            )
            log_dict["val_pafs"] = self._paf_wandb_image(
                val_pafs_data, caption=f"Val PAFs Epoch {epoch}"
            )

            if log_dict:
                # Include epoch so wandb can use it as x-axis (via define_metric)
                log_dict["epoch"] = epoch
                # Use commit=False to accumulate with other metrics in this step
                # Lightning will commit when it logs its own metrics
                wandb_logger.experiment.log(log_dict, commit=False)

            # Optionally also log to table
            if self.log_table and "direct" in self.renderers:
                train_img = self.renderers["direct"].render(
                    train_data, caption=f"Train Epoch {epoch}"
                )
                val_img = self.renderers["direct"].render(
                    val_data, caption=f"Val Epoch {epoch}"
                )
                table = wandb.Table(
                    columns=["Epoch", "Train", "Validation", "Train PAFs", "Val PAFs"],
                    data=[
                        [
                            epoch,
                            train_img,
                            val_img,
                            log_dict["train_pafs"],
                            log_dict["val_pafs"],
                        ]
                    ],
                )
                wandb_logger.experiment.log({"predictions_table": table}, commit=False)

        # Sync all processes
        trainer.strategy.barrier()
456
+
457
+
150
458
  class MatplotlibSaver(Callback):
151
459
  """Callback for saving images rendered with matplotlib during training.
152
460
 
@@ -194,7 +502,7 @@ class MatplotlibSaver(Callback):
194
502
  ).as_posix()
195
503
 
196
504
  # Save rendered figure.
197
- figure.savefig(figure_path, format="png", pad_inches=0)
505
+ figure.savefig(figure_path, format="png")
198
506
  plt.close(figure)
199
507
 
200
508
  # Sync all processes after file I/O
@@ -303,7 +611,11 @@ class ProgressReporterZMQ(Callback):
303
611
  def on_train_start(self, trainer, pl_module):
304
612
  """Called at the beginning of training process."""
305
613
  if trainer.is_global_zero:
306
- self.send("train_begin")
614
+ # Include WandB URL if available
615
+ wandb_url = None
616
+ if wandb.run is not None:
617
+ wandb_url = wandb.run.url
618
+ self.send("train_begin", wandb_url=wandb_url)
307
619
  trainer.strategy.barrier()
308
620
 
309
621
  def on_train_end(self, trainer, pl_module):