sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,191 @@
1
+ """Custom learning rate schedulers for sleap-nn training.
2
+
3
+ This module provides learning rate schedulers with warmup phases that are commonly
4
+ used in deep learning for pose estimation and computer vision tasks.
5
+ """
6
+
7
+ import math
8
+ from torch.optim.lr_scheduler import LRScheduler
9
+
10
+
11
+ class LinearWarmupCosineAnnealingLR(LRScheduler):
12
+ """Cosine annealing learning rate scheduler with linear warmup.
13
+
14
+ The learning rate increases linearly from `warmup_start_lr` to the optimizer's
15
+ base learning rate over `warmup_epochs`, then decreases following a cosine
16
+ curve to `eta_min` over the remaining epochs.
17
+
18
+ This schedule is widely used in vision transformers and modern CNN architectures
19
+ as it provides stable early training (warmup) and smooth convergence (cosine decay).
20
+
21
+ Args:
22
+ optimizer: Wrapped optimizer.
23
+ warmup_epochs: Number of epochs for the linear warmup phase.
24
+ max_epochs: Total number of training epochs.
25
+ warmup_start_lr: Learning rate at the start of warmup. Default: 0.0.
26
+ eta_min: Minimum learning rate at the end of the schedule. Default: 0.0.
27
+ last_epoch: The index of the last epoch. Default: -1.
28
+
29
+ Example:
30
+ >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
31
+ >>> scheduler = LinearWarmupCosineAnnealingLR(
32
+ ... optimizer, warmup_epochs=5, max_epochs=100, eta_min=1e-6
33
+ ... )
34
+ >>> for epoch in range(100):
35
+ ... train(...)
36
+ ... scheduler.step()
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ optimizer,
42
+ warmup_epochs: int,
43
+ max_epochs: int,
44
+ warmup_start_lr: float = 0.0,
45
+ eta_min: float = 0.0,
46
+ last_epoch: int = -1,
47
+ ):
48
+ """Initialize the scheduler.
49
+
50
+ Args:
51
+ optimizer: Wrapped optimizer.
52
+ warmup_epochs: Number of epochs for the linear warmup phase.
53
+ max_epochs: Total number of training epochs.
54
+ warmup_start_lr: Learning rate at the start of warmup. Default: 0.0.
55
+ eta_min: Minimum learning rate at the end of the schedule. Default: 0.0.
56
+ last_epoch: The index of the last epoch. Default: -1.
57
+ """
58
+ if warmup_epochs < 0:
59
+ raise ValueError(f"warmup_epochs must be >= 0, got {warmup_epochs}")
60
+ if max_epochs <= 0:
61
+ raise ValueError(f"max_epochs must be > 0, got {max_epochs}")
62
+ if warmup_epochs >= max_epochs:
63
+ raise ValueError(
64
+ f"warmup_epochs ({warmup_epochs}) must be < max_epochs ({max_epochs})"
65
+ )
66
+ if warmup_start_lr < 0:
67
+ raise ValueError(f"warmup_start_lr must be >= 0, got {warmup_start_lr}")
68
+ if eta_min < 0:
69
+ raise ValueError(f"eta_min must be >= 0, got {eta_min}")
70
+
71
+ self.warmup_epochs = warmup_epochs
72
+ self.max_epochs = max_epochs
73
+ self.warmup_start_lr = warmup_start_lr
74
+ self.eta_min = eta_min
75
+ super().__init__(optimizer, last_epoch)
76
+
77
+ def get_lr(self):
78
+ """Compute the learning rate at the current epoch."""
79
+ if self.last_epoch < self.warmup_epochs:
80
+ # Linear warmup phase
81
+ if self.warmup_epochs == 0:
82
+ return list(self.base_lrs)
83
+ alpha = self.last_epoch / self.warmup_epochs
84
+ return [
85
+ self.warmup_start_lr + alpha * (base_lr - self.warmup_start_lr)
86
+ for base_lr in self.base_lrs
87
+ ]
88
+ else:
89
+ # Cosine annealing phase
90
+ decay_epochs = self.max_epochs - self.warmup_epochs
91
+ if decay_epochs == 0:
92
+ return [self.eta_min for _ in self.base_lrs]
93
+ progress = (self.last_epoch - self.warmup_epochs) / decay_epochs
94
+ # Clamp progress to [0, 1] to handle epochs beyond max_epochs
95
+ progress = min(1.0, progress)
96
+ return [
97
+ self.eta_min
98
+ + (base_lr - self.eta_min) * (1 + math.cos(math.pi * progress)) / 2
99
+ for base_lr in self.base_lrs
100
+ ]
101
+
102
+
103
+ class LinearWarmupLinearDecayLR(LRScheduler):
104
+ """Linear warmup followed by linear decay learning rate scheduler.
105
+
106
+ The learning rate increases linearly from `warmup_start_lr` to the optimizer's
107
+ base learning rate over `warmup_epochs`, then decreases linearly to `end_lr`
108
+ over the remaining epochs.
109
+
110
+ This schedule provides a simple, interpretable learning rate trajectory and is
111
+ commonly used in transformer-based models and NLP tasks.
112
+
113
+ Args:
114
+ optimizer: Wrapped optimizer.
115
+ warmup_epochs: Number of epochs for the linear warmup phase.
116
+ max_epochs: Total number of training epochs.
117
+ warmup_start_lr: Learning rate at the start of warmup. Default: 0.0.
118
+ end_lr: Learning rate at the end of training. Default: 0.0.
119
+ last_epoch: The index of the last epoch. Default: -1.
120
+
121
+ Example:
122
+ >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
123
+ >>> scheduler = LinearWarmupLinearDecayLR(
124
+ ... optimizer, warmup_epochs=5, max_epochs=100, end_lr=1e-6
125
+ ... )
126
+ >>> for epoch in range(100):
127
+ ... train(...)
128
+ ... scheduler.step()
129
+ """
130
+
131
+ def __init__(
132
+ self,
133
+ optimizer,
134
+ warmup_epochs: int,
135
+ max_epochs: int,
136
+ warmup_start_lr: float = 0.0,
137
+ end_lr: float = 0.0,
138
+ last_epoch: int = -1,
139
+ ):
140
+ """Initialize the scheduler.
141
+
142
+ Args:
143
+ optimizer: Wrapped optimizer.
144
+ warmup_epochs: Number of epochs for the linear warmup phase.
145
+ max_epochs: Total number of training epochs.
146
+ warmup_start_lr: Learning rate at the start of warmup. Default: 0.0.
147
+ end_lr: Learning rate at the end of training. Default: 0.0.
148
+ last_epoch: The index of the last epoch. Default: -1.
149
+ """
150
+ if warmup_epochs < 0:
151
+ raise ValueError(f"warmup_epochs must be >= 0, got {warmup_epochs}")
152
+ if max_epochs <= 0:
153
+ raise ValueError(f"max_epochs must be > 0, got {max_epochs}")
154
+ if warmup_epochs >= max_epochs:
155
+ raise ValueError(
156
+ f"warmup_epochs ({warmup_epochs}) must be < max_epochs ({max_epochs})"
157
+ )
158
+ if warmup_start_lr < 0:
159
+ raise ValueError(f"warmup_start_lr must be >= 0, got {warmup_start_lr}")
160
+ if end_lr < 0:
161
+ raise ValueError(f"end_lr must be >= 0, got {end_lr}")
162
+
163
+ self.warmup_epochs = warmup_epochs
164
+ self.max_epochs = max_epochs
165
+ self.warmup_start_lr = warmup_start_lr
166
+ self.end_lr = end_lr
167
+ super().__init__(optimizer, last_epoch)
168
+
169
+ def get_lr(self):
170
+ """Compute the learning rate at the current epoch."""
171
+ if self.last_epoch < self.warmup_epochs:
172
+ # Linear warmup phase
173
+ if self.warmup_epochs == 0:
174
+ return list(self.base_lrs)
175
+ alpha = self.last_epoch / self.warmup_epochs
176
+ return [
177
+ self.warmup_start_lr + alpha * (base_lr - self.warmup_start_lr)
178
+ for base_lr in self.base_lrs
179
+ ]
180
+ else:
181
+ # Linear decay phase
182
+ decay_epochs = self.max_epochs - self.warmup_epochs
183
+ if decay_epochs == 0:
184
+ return [self.end_lr for _ in self.base_lrs]
185
+ progress = (self.last_epoch - self.warmup_epochs) / decay_epochs
186
+ # Clamp progress to [0, 1] to handle epochs beyond max_epochs
187
+ progress = min(1.0, progress)
188
+ return [
189
+ base_lr + progress * (self.end_lr - base_lr)
190
+ for base_lr in self.base_lrs
191
+ ]
@@ -1,13 +1,19 @@
1
1
  """Miscellaneous utility functions for training."""
2
2
 
3
+ from dataclasses import dataclass, field
4
+ from io import BytesIO
3
5
  import numpy as np
6
+ import matplotlib
7
+
8
+ matplotlib.use(
9
+ "Agg"
10
+ ) # Use non-interactive backend to avoid tkinter issues on Windows CI
4
11
  import matplotlib.pyplot as plt
5
12
  from loguru import logger
6
13
  from torch import nn
7
14
  import torch.distributed as dist
8
- import matplotlib
9
15
  import seaborn as sns
10
- from typing import List
16
+ from typing import List, Optional
11
17
  import shutil
12
18
  import os
13
19
  import subprocess
@@ -236,3 +242,362 @@ def plot_peaks(
236
242
  )
237
243
  )
238
244
  return handles
245
+
246
+
247
+ @dataclass
248
+ class VisualizationData:
249
+ """Container for visualization data from a single sample.
250
+
251
+ This dataclass decouples data extraction from rendering, allowing the same
252
+ data to be rendered to different output targets (matplotlib, wandb, etc.).
253
+
254
+ Attributes:
255
+ image: Input image as (H, W, C) numpy array, normalized to [0, 1].
256
+ pred_confmaps: Predicted confidence maps as (H, W, nodes) array, values in [0, 1].
257
+ pred_peaks: Predicted keypoints as (instances, nodes, 2) or (nodes, 2) array.
258
+ pred_peak_values: Confidence values as (instances, nodes) or (nodes,) array.
259
+ gt_instances: Ground truth keypoints, same shape as pred_peaks.
260
+ node_names: List of node/keypoint names, e.g., ["head", "thorax", ...].
261
+ output_scale: Ratio of confmap size to image size (confmap_h / image_h).
262
+ is_paired: Whether GT and predictions can be paired for error visualization.
263
+ pred_pafs: Part affinity fields for bottom-up models, optional.
264
+ pred_class_maps: Class maps for multi-class models, optional.
265
+ """
266
+
267
+ image: np.ndarray
268
+ pred_confmaps: np.ndarray
269
+ pred_peaks: np.ndarray
270
+ pred_peak_values: np.ndarray
271
+ gt_instances: np.ndarray
272
+ node_names: List[str] = field(default_factory=list)
273
+ output_scale: float = 1.0
274
+ is_paired: bool = True
275
+ pred_pafs: Optional[np.ndarray] = None
276
+ pred_class_maps: Optional[np.ndarray] = None
277
+
278
+
279
+ class MatplotlibRenderer:
280
+ """Renders VisualizationData to matplotlib figures."""
281
+
282
+ def render(self, data: VisualizationData) -> matplotlib.figure.Figure:
283
+ """Render visualization data to a matplotlib figure.
284
+
285
+ Args:
286
+ data: VisualizationData containing image, confmaps, peaks, etc.
287
+
288
+ Returns:
289
+ A matplotlib Figure object.
290
+ """
291
+ img = data.image
292
+ scale = 1.0
293
+ if img.shape[0] < 512:
294
+ scale = 2.0
295
+ if img.shape[0] < 256:
296
+ scale = 4.0
297
+
298
+ fig = plot_img(img, dpi=72 * scale, scale=scale)
299
+ plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
300
+ plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
301
+ return fig
302
+
303
+ def render_pafs(self, data: VisualizationData) -> matplotlib.figure.Figure:
304
+ """Render PAF magnitude visualization.
305
+
306
+ Args:
307
+ data: VisualizationData with pred_pafs populated.
308
+
309
+ Returns:
310
+ A matplotlib Figure object showing PAF magnitudes.
311
+ """
312
+ if data.pred_pafs is None:
313
+ raise ValueError("pred_pafs is None, cannot render PAFs")
314
+
315
+ img = data.image
316
+ scale = 1.0
317
+ if img.shape[0] < 512:
318
+ scale = 2.0
319
+ if img.shape[0] < 256:
320
+ scale = 4.0
321
+
322
+ # Compute PAF magnitude
323
+ pafs = data.pred_pafs # (H, W, 2*edges) or (H, W, edges, 2)
324
+ if pafs.ndim == 3:
325
+ n_edges = pafs.shape[-1] // 2
326
+ pafs = pafs.reshape(pafs.shape[0], pafs.shape[1], n_edges, 2)
327
+ magnitude = np.sqrt(pafs[..., 0] ** 2 + pafs[..., 1] ** 2)
328
+ magnitude = magnitude.max(axis=-1) # Max over edges
329
+
330
+ fig = plot_img(img, dpi=72 * scale, scale=scale)
331
+ ax = plt.gca()
332
+
333
+ # Calculate PAF output scale from actual PAF dimensions, not confmap output_scale
334
+ # PAFs may have a different output_stride than confmaps
335
+ paf_output_scale = magnitude.shape[0] / img.shape[0]
336
+
337
+ ax.imshow(
338
+ magnitude,
339
+ alpha=0.5,
340
+ origin="upper",
341
+ cmap="viridis",
342
+ extent=[
343
+ -0.5,
344
+ magnitude.shape[1] / paf_output_scale - 0.5,
345
+ magnitude.shape[0] / paf_output_scale - 0.5,
346
+ -0.5,
347
+ ],
348
+ )
349
+ return fig
350
+
351
+
352
+ class WandBRenderer:
353
+ """Renders VisualizationData to wandb.Image objects.
354
+
355
+ Supports multiple rendering modes:
356
+ - "direct": Pre-render with matplotlib, convert to wandb.Image
357
+ - "boxes": Use wandb boxes for interactive keypoint visualization
358
+ - "masks": Use wandb masks for confidence map overlay
359
+ """
360
+
361
+ def __init__(
362
+ self,
363
+ mode: str = "direct",
364
+ box_size: float = 5.0,
365
+ confmap_threshold: float = 0.1,
366
+ min_size: int = 512,
367
+ ):
368
+ """Initialize the renderer.
369
+
370
+ Args:
371
+ mode: Rendering mode - "direct", "boxes", or "masks".
372
+ box_size: Size of keypoint boxes in pixels (for "boxes" mode).
373
+ confmap_threshold: Threshold for confmap mask (for "masks" mode).
374
+ min_size: Minimum image dimension. Smaller images will be upscaled.
375
+ """
376
+ self.mode = mode
377
+ self.box_size = box_size
378
+ self.confmap_threshold = confmap_threshold
379
+ self.min_size = min_size
380
+ self._mpl_renderer = MatplotlibRenderer()
381
+
382
+ def render(
383
+ self, data: VisualizationData, caption: Optional[str] = None
384
+ ) -> "wandb.Image":
385
+ """Render visualization data to a wandb.Image.
386
+
387
+ Args:
388
+ data: VisualizationData containing image, confmaps, peaks, etc.
389
+ caption: Optional caption for the image.
390
+
391
+ Returns:
392
+ A wandb.Image object.
393
+ """
394
+ import wandb
395
+
396
+ if self.mode == "boxes":
397
+ return self._render_with_boxes(data, caption)
398
+ elif self.mode == "masks":
399
+ return self._render_with_masks(data, caption)
400
+ else: # "direct"
401
+ return self._render_direct(data, caption)
402
+
403
+ def _get_scale_factor(self, img_h: int, img_w: int) -> int:
404
+ """Calculate scale factor to ensure minimum image size."""
405
+ min_dim = min(img_h, img_w)
406
+ if min_dim >= self.min_size:
407
+ return 1
408
+ return int(np.ceil(self.min_size / min_dim))
409
+
410
+ def _render_direct(
411
+ self, data: VisualizationData, caption: Optional[str] = None
412
+ ) -> "wandb.Image":
413
+ """Pre-render with matplotlib, return as wandb.Image."""
414
+ import wandb
415
+ from PIL import Image
416
+
417
+ fig = self._mpl_renderer.render(data)
418
+
419
+ # Convert figure to PIL Image
420
+ buf = BytesIO()
421
+ fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
422
+ buf.seek(0)
423
+ plt.close(fig)
424
+
425
+ pil_image = Image.open(buf)
426
+ return wandb.Image(pil_image, caption=caption)
427
+
428
+ def _render_with_boxes(
429
+ self, data: VisualizationData, caption: Optional[str] = None
430
+ ) -> "wandb.Image":
431
+ """Use wandb boxes for interactive keypoint visualization."""
432
+ import wandb
433
+ from PIL import Image
434
+
435
+ # Prepare class labels from node names
436
+ class_labels = {i: name for i, name in enumerate(data.node_names)}
437
+ if not class_labels:
438
+ class_labels = {i: f"node_{i}" for i in range(data.pred_peaks.shape[-2])}
439
+
440
+ # Convert image to uint8
441
+ img_uint8 = (np.clip(data.image, 0, 1) * 255).astype(np.uint8)
442
+ # Handle single-channel images: squeeze (H, W, 1) -> (H, W)
443
+ if img_uint8.ndim == 3 and img_uint8.shape[2] == 1:
444
+ img_uint8 = img_uint8.squeeze(axis=2)
445
+ img_h, img_w = img_uint8.shape[:2]
446
+
447
+ # Scale up small images for better visibility in wandb
448
+ scale = self._get_scale_factor(img_h, img_w)
449
+ if scale > 1:
450
+ pil_img = Image.fromarray(img_uint8)
451
+ pil_img = pil_img.resize(
452
+ (img_w * scale, img_h * scale), resample=Image.BILINEAR
453
+ )
454
+ img_uint8 = np.array(pil_img)
455
+
456
+ # Build ground truth boxes (use percent domain for proper thumbnail scaling)
457
+ gt_box_data = self._peaks_to_boxes(
458
+ data.gt_instances, data.node_names, img_w, img_h, is_gt=True
459
+ )
460
+
461
+ # Build prediction boxes
462
+ pred_box_data = self._peaks_to_boxes(
463
+ data.pred_peaks,
464
+ data.node_names,
465
+ img_w,
466
+ img_h,
467
+ peak_values=data.pred_peak_values,
468
+ is_gt=False,
469
+ )
470
+
471
+ return wandb.Image(
472
+ img_uint8,
473
+ boxes={
474
+ "ground_truth": {"box_data": gt_box_data, "class_labels": class_labels},
475
+ "predictions": {
476
+ "box_data": pred_box_data,
477
+ "class_labels": class_labels,
478
+ },
479
+ },
480
+ caption=caption,
481
+ )
482
+
483
+ def _peaks_to_boxes(
484
+ self,
485
+ peaks: np.ndarray,
486
+ node_names: List[str],
487
+ img_w: int,
488
+ img_h: int,
489
+ peak_values: Optional[np.ndarray] = None,
490
+ is_gt: bool = False,
491
+ ) -> List[dict]:
492
+ """Convert peaks array to wandb box_data format.
493
+
494
+ Args:
495
+ peaks: Keypoints as (instances, nodes, 2) or (nodes, 2).
496
+ node_names: List of node names.
497
+ img_w: Image width in pixels.
498
+ img_h: Image height in pixels.
499
+ peak_values: Optional confidence values.
500
+ is_gt: Whether these are ground truth points.
501
+
502
+ Returns:
503
+ List of box dictionaries for wandb.
504
+ """
505
+ box_data = []
506
+
507
+ # Normalize shape to (instances, nodes, 2)
508
+ if peaks.ndim == 2:
509
+ peaks = peaks[np.newaxis, ...]
510
+ if peak_values is not None and peak_values.ndim == 1:
511
+ peak_values = peak_values[np.newaxis, ...]
512
+
513
+ # Convert box_size from pixels to percent
514
+ box_w_pct = self.box_size / img_w
515
+ box_h_pct = self.box_size / img_h
516
+
517
+ for inst_idx, instance in enumerate(peaks):
518
+ for node_idx, (x, y) in enumerate(instance):
519
+ if np.isnan(x) or np.isnan(y):
520
+ continue
521
+
522
+ node_name = (
523
+ node_names[node_idx]
524
+ if node_idx < len(node_names)
525
+ else f"node_{node_idx}"
526
+ )
527
+
528
+ # Convert pixel coordinates to percent (0-1 range)
529
+ x_pct = float(x) / img_w
530
+ y_pct = float(y) / img_h
531
+
532
+ box = {
533
+ "position": {
534
+ "middle": [x_pct, y_pct],
535
+ "width": box_w_pct,
536
+ "height": box_h_pct,
537
+ },
538
+ "domain": "percent",
539
+ "class_id": node_idx,
540
+ }
541
+
542
+ if is_gt:
543
+ box["box_caption"] = f"GT: {node_name}"
544
+ else:
545
+ if peak_values is not None:
546
+ conf = float(peak_values[inst_idx, node_idx])
547
+ box["box_caption"] = f"{node_name} ({conf:.2f})"
548
+ box["scores"] = {"confidence": conf}
549
+ else:
550
+ box["box_caption"] = node_name
551
+
552
+ box_data.append(box)
553
+
554
+ return box_data
555
+
556
+ def _render_with_masks(
557
+ self, data: VisualizationData, caption: Optional[str] = None
558
+ ) -> "wandb.Image":
559
+ """Use wandb masks for confidence map overlay.
560
+
561
+ Uses argmax approach: each pixel shows the dominant node.
562
+ """
563
+ import wandb
564
+
565
+ # Prepare class labels (0 = background, 1+ = nodes)
566
+ class_labels = {0: "background"}
567
+ for i, name in enumerate(data.node_names):
568
+ class_labels[i + 1] = name
569
+ if not data.node_names:
570
+ n_nodes = data.pred_confmaps.shape[-1]
571
+ for i in range(n_nodes):
572
+ class_labels[i + 1] = f"node_{i}"
573
+
574
+ # Create argmax mask from confmaps
575
+ confmaps = data.pred_confmaps # (H/stride, W/stride, nodes)
576
+ max_vals = confmaps.max(axis=-1)
577
+ argmax_map = confmaps.argmax(axis=-1) + 1 # +1 for background offset
578
+ argmax_map[max_vals < self.confmap_threshold] = 0 # Background
579
+
580
+ # Convert image to uint8
581
+ img_uint8 = (np.clip(data.image, 0, 1) * 255).astype(np.uint8)
582
+ # Handle single-channel images: (H, W, 1) -> (H, W)
583
+ if img_uint8.ndim == 3 and img_uint8.shape[2] == 1:
584
+ img_uint8 = img_uint8.squeeze(axis=2)
585
+ img_h, img_w = img_uint8.shape[:2]
586
+
587
+ # Resize mask to match image dimensions (confmaps are H/stride, W/stride)
588
+ from PIL import Image
589
+
590
+ mask_pil = Image.fromarray(argmax_map.astype(np.uint8))
591
+ mask_pil = mask_pil.resize((img_w, img_h), resample=Image.NEAREST)
592
+ argmax_map = np.array(mask_pil)
593
+
594
+ return wandb.Image(
595
+ img_uint8,
596
+ masks={
597
+ "confidence_maps": {
598
+ "mask_data": argmax_map.astype(np.uint8),
599
+ "class_labels": class_labels,
600
+ }
601
+ },
602
+ caption=caption,
603
+ )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: sleap-nn
3
- Version: 0.0.5
3
+ Version: 0.1.0
4
4
  Summary: Neural network backend for training and inference for animal pose estimation.
5
5
  Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
6
6
  License: BSD-3-Clause
@@ -13,10 +13,10 @@ Classifier: Programming Language :: Python :: 3.13
13
13
  Requires-Python: <3.14,>=3.11
14
14
  Description-Content-Type: text/markdown
15
15
  License-File: LICENSE
16
- Requires-Dist: sleap-io>=0.5.7
16
+ Requires-Dist: sleap-io<0.7.0,>=0.6.2
17
17
  Requires-Dist: numpy
18
18
  Requires-Dist: lightning
19
- Requires-Dist: kornia
19
+ Requires-Dist: skia-python>=87.0
20
20
  Requires-Dist: jsonpickle
21
21
  Requires-Dist: scipy
22
22
  Requires-Dist: attrs
@@ -32,37 +32,33 @@ Requires-Dist: hydra-core
32
32
  Requires-Dist: jupyter
33
33
  Requires-Dist: jupyterlab
34
34
  Requires-Dist: pyzmq
35
+ Requires-Dist: rich-click>=1.9.5
35
36
  Provides-Extra: torch
36
37
  Requires-Dist: torch; extra == "torch"
37
- Requires-Dist: torchvision<0.24.0,>=0.20.0; extra == "torch"
38
+ Requires-Dist: torchvision>=0.20.0; extra == "torch"
38
39
  Provides-Extra: torch-cpu
39
40
  Requires-Dist: torch; extra == "torch-cpu"
40
- Requires-Dist: torchvision<0.24.0,>=0.20.0; extra == "torch-cpu"
41
+ Requires-Dist: torchvision>=0.20.0; extra == "torch-cpu"
41
42
  Provides-Extra: torch-cuda118
42
43
  Requires-Dist: torch; extra == "torch-cuda118"
43
- Requires-Dist: torchvision<0.24.0,>=0.20.0; extra == "torch-cuda118"
44
+ Requires-Dist: torchvision>=0.20.0; extra == "torch-cuda118"
44
45
  Provides-Extra: torch-cuda128
45
46
  Requires-Dist: torch; extra == "torch-cuda128"
46
- Requires-Dist: torchvision<0.24.0,>=0.20.0; extra == "torch-cuda128"
47
- Provides-Extra: dev
48
- Requires-Dist: pytest; extra == "dev"
49
- Requires-Dist: pytest-cov; extra == "dev"
50
- Requires-Dist: black; extra == "dev"
51
- Requires-Dist: pydocstyle; extra == "dev"
52
- Requires-Dist: toml; extra == "dev"
53
- Requires-Dist: twine; extra == "dev"
54
- Requires-Dist: build; extra == "dev"
55
- Requires-Dist: ipython; extra == "dev"
56
- Requires-Dist: ruff; extra == "dev"
57
- Provides-Extra: docs
58
- Requires-Dist: mkdocs; extra == "docs"
59
- Requires-Dist: mkdocs-material; extra == "docs"
60
- Requires-Dist: mkdocs-jupyter; extra == "docs"
61
- Requires-Dist: mike; extra == "docs"
62
- Requires-Dist: mkdocstrings[python]; extra == "docs"
63
- Requires-Dist: mkdocs-gen-files; extra == "docs"
64
- Requires-Dist: mkdocs-literate-nav; extra == "docs"
65
- Requires-Dist: mkdocs-section-index; extra == "docs"
47
+ Requires-Dist: torchvision>=0.20.0; extra == "torch-cuda128"
48
+ Provides-Extra: torch-cuda130
49
+ Requires-Dist: torch; extra == "torch-cuda130"
50
+ Requires-Dist: torchvision>=0.20.0; extra == "torch-cuda130"
51
+ Provides-Extra: export
52
+ Requires-Dist: onnx>=1.15.0; extra == "export"
53
+ Requires-Dist: onnxruntime>=1.16.0; extra == "export"
54
+ Requires-Dist: onnxscript>=0.1.0; extra == "export"
55
+ Provides-Extra: export-gpu
56
+ Requires-Dist: onnx>=1.15.0; extra == "export-gpu"
57
+ Requires-Dist: onnxruntime-gpu>=1.16.0; extra == "export-gpu"
58
+ Requires-Dist: onnxscript>=0.1.0; extra == "export-gpu"
59
+ Provides-Extra: tensorrt
60
+ Requires-Dist: tensorrt>=10.13.0; (sys_platform == "linux" or sys_platform == "win32") and extra == "tensorrt"
61
+ Requires-Dist: torch-tensorrt>=2.5.0; (sys_platform == "linux" or sys_platform == "win32") and extra == "tensorrt"
66
62
  Dynamic: license-file
67
63
 
68
64
  # sleap-nn
@@ -120,22 +116,28 @@ powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
120
116
  > Replace `...` with the rest of your install command as needed.
121
117
 
122
118
  - Sync all dependencies based on your correct wheel using `uv sync`. `uv sync` creates a `.venv` (virtual environment) inside your current working directory. This environment is only active within that directory and can't be directly accessed from outside. To use all installed packages, you must run commands with `uv run` (e.g., `uv run sleap-nn train ...` or `uv run pytest ...`).
123
- - **Windows/Linux with NVIDIA GPU (CUDA 11.8):**
119
+ - **Windows/Linux with NVIDIA GPU (CUDA 13.0):**
124
120
 
125
121
  ```bash
126
- uv sync --extra dev --extra torch-cuda118
122
+ uv sync --extra torch-cuda130
127
123
  ```
128
124
 
129
125
  - **Windows/Linux with NVIDIA GPU (CUDA 12.8):**
130
126
 
131
127
  ```bash
132
- uv sync --extra dev --extra torch-cuda128
128
+ uv sync --extra torch-cuda128
133
129
  ```
134
-
135
- - **macOS with Apple Silicon (M1, M2, M3, M4) or CPU-only (no GPU or unsupported GPU):**
130
+
131
+ - **Windows/Linux with NVIDIA GPU (CUDA 11.8):**
132
+
133
+ ```bash
134
+ uv sync --extra torch-cuda118
135
+ ```
136
+
137
+ - **macOS with Apple Silicon (M1, M2, M3, M4) or CPU-only (no GPU or unsupported GPU):**
136
138
  Note: Even if torch-cpu is used on macOS, the MPS backend will be available.
137
139
  ```bash
138
- uv sync --extra dev --extra torch-cpu
140
+ uv sync --extra torch-cpu
139
141
  ```
140
142
 
141
143
  4. **Run tests**
@@ -152,6 +154,6 @@ powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
152
154
  > **Upgrading All Dependencies**
153
155
  > To ensure you have the latest versions of all dependencies, use the `--upgrade` flag with `uv sync`:
154
156
  > ```bash
155
- > uv sync --extra dev --upgrade
157
+ > uv sync --upgrade
156
158
  > ```
157
159
  > This will upgrade all installed packages in your environment to the latest available versions compatible with your `pyproject.toml`.