sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +9 -2
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +489 -46
- sleap_nn/config/data_config.py +51 -8
- sleap_nn/config/get_config.py +32 -24
- sleap_nn/config/trainer_config.py +88 -0
- sleap_nn/data/augmentation.py +61 -200
- sleap_nn/data/custom_datasets.py +433 -61
- sleap_nn/data/instance_cropping.py +71 -6
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/skia_augmentation.py +414 -0
- sleap_nn/data/utils.py +135 -17
- sleap_nn/evaluation.py +177 -42
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/peak_finding.py +93 -16
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +339 -137
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/legacy_models.py +65 -11
- sleap_nn/predict.py +224 -19
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +138 -44
- sleap_nn/training/callbacks.py +1258 -5
- sleap_nn/training/lightning_modules.py +902 -220
- sleap_nn/training/model_trainer.py +424 -111
- sleap_nn/training/schedulers.py +191 -0
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
- sleap_nn-0.1.0.dist-info/RECORD +88 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
- sleap_nn-0.0.5.dist-info/RECORD +0 -63
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""Custom learning rate schedulers for sleap-nn training.
|
|
2
|
+
|
|
3
|
+
This module provides learning rate schedulers with warmup phases that are commonly
|
|
4
|
+
used in deep learning for pose estimation and computer vision tasks.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
from torch.optim.lr_scheduler import LRScheduler
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LinearWarmupCosineAnnealingLR(LRScheduler):
    """Learning rate schedule: linear warmup, then cosine annealing.

    For the first `warmup_epochs` epochs the learning rate ramps linearly from
    `warmup_start_lr` up to each parameter group's base learning rate. From
    then on it follows a half-cosine curve down to `eta_min` at `max_epochs`.
    Epochs past `max_epochs` stay clamped at `eta_min`.

    Warmup stabilizes early training while the cosine tail gives a smooth
    approach to the final learning rate; this combination is standard for
    vision transformers and modern CNNs.

    Args:
        optimizer: Wrapped optimizer.
        warmup_epochs: Number of epochs for the linear warmup phase.
        max_epochs: Total number of training epochs.
        warmup_start_lr: Learning rate at the start of warmup. Default: 0.0.
        eta_min: Minimum learning rate at the end of the schedule. Default: 0.0.
        last_epoch: The index of the last epoch. Default: -1.

    Raises:
        ValueError: If any argument is out of range (negative learning rates,
            non-positive `max_epochs`, or `warmup_epochs >= max_epochs`).

    Example:
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        >>> scheduler = LinearWarmupCosineAnnealingLR(
        ...     optimizer, warmup_epochs=5, max_epochs=100, eta_min=1e-6
        ... )
        >>> for epoch in range(100):
        ...     train(...)
        ...     scheduler.step()
    """

    def __init__(
        self,
        optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.0,
        eta_min: float = 0.0,
        last_epoch: int = -1,
    ):
        """Validate arguments and initialize the scheduler.

        Args:
            optimizer: Wrapped optimizer.
            warmup_epochs: Number of epochs for the linear warmup phase.
            max_epochs: Total number of training epochs.
            warmup_start_lr: Learning rate at the start of warmup. Default: 0.0.
            eta_min: Minimum learning rate at the end of the schedule. Default: 0.0.
            last_epoch: The index of the last epoch. Default: -1.
        """
        # Reject invalid configurations up front, before the base class
        # touches the optimizer's parameter groups.
        if warmup_epochs < 0:
            raise ValueError(f"warmup_epochs must be >= 0, got {warmup_epochs}")
        if max_epochs <= 0:
            raise ValueError(f"max_epochs must be > 0, got {max_epochs}")
        if warmup_epochs >= max_epochs:
            raise ValueError(
                f"warmup_epochs ({warmup_epochs}) must be < max_epochs ({max_epochs})"
            )
        if warmup_start_lr < 0:
            raise ValueError(f"warmup_start_lr must be >= 0, got {warmup_start_lr}")
        if eta_min < 0:
            raise ValueError(f"eta_min must be >= 0, got {eta_min}")

        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        # Base-class init captures base_lrs and performs the initial step.
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Compute the learning rate at the current epoch."""
        epoch = self.last_epoch

        # --- Warmup phase: linear ramp from warmup_start_lr to base_lr. ---
        if epoch < self.warmup_epochs:
            if self.warmup_epochs == 0:
                # Degenerate case: no warmup configured, use base rates as-is.
                return list(self.base_lrs)
            frac = epoch / self.warmup_epochs
            return [
                self.warmup_start_lr + frac * (lr0 - self.warmup_start_lr)
                for lr0 in self.base_lrs
            ]

        # --- Cosine annealing phase: half-cosine decay toward eta_min. ---
        span = self.max_epochs - self.warmup_epochs
        if span == 0:
            # Defensive: no decay window left, pin to the floor.
            return [self.eta_min for _ in self.base_lrs]
        # Clamp progress so epochs beyond max_epochs stay at eta_min.
        t = min(1.0, (epoch - self.warmup_epochs) / span)
        cos_factor = (1 + math.cos(math.pi * t)) / 2
        return [
            self.eta_min + (lr0 - self.eta_min) * cos_factor
            for lr0 in self.base_lrs
        ]
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class LinearWarmupLinearDecayLR(LRScheduler):
    """Learning rate schedule: linear warmup, then linear decay.

    For the first `warmup_epochs` epochs the learning rate ramps linearly from
    `warmup_start_lr` up to each parameter group's base learning rate, then
    decreases linearly to `end_lr` at `max_epochs`. Epochs past `max_epochs`
    stay clamped at `end_lr`.

    The fully linear trajectory is easy to reason about and is a common choice
    for transformer-based models.

    Args:
        optimizer: Wrapped optimizer.
        warmup_epochs: Number of epochs for the linear warmup phase.
        max_epochs: Total number of training epochs.
        warmup_start_lr: Learning rate at the start of warmup. Default: 0.0.
        end_lr: Learning rate at the end of training. Default: 0.0.
        last_epoch: The index of the last epoch. Default: -1.

    Raises:
        ValueError: If any argument is out of range (negative learning rates,
            non-positive `max_epochs`, or `warmup_epochs >= max_epochs`).

    Example:
        >>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        >>> scheduler = LinearWarmupLinearDecayLR(
        ...     optimizer, warmup_epochs=5, max_epochs=100, end_lr=1e-6
        ... )
        >>> for epoch in range(100):
        ...     train(...)
        ...     scheduler.step()
    """

    def __init__(
        self,
        optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.0,
        end_lr: float = 0.0,
        last_epoch: int = -1,
    ):
        """Validate arguments and initialize the scheduler.

        Args:
            optimizer: Wrapped optimizer.
            warmup_epochs: Number of epochs for the linear warmup phase.
            max_epochs: Total number of training epochs.
            warmup_start_lr: Learning rate at the start of warmup. Default: 0.0.
            end_lr: Learning rate at the end of training. Default: 0.0.
            last_epoch: The index of the last epoch. Default: -1.
        """
        # Reject invalid configurations up front, before the base class
        # touches the optimizer's parameter groups.
        if warmup_epochs < 0:
            raise ValueError(f"warmup_epochs must be >= 0, got {warmup_epochs}")
        if max_epochs <= 0:
            raise ValueError(f"max_epochs must be > 0, got {max_epochs}")
        if warmup_epochs >= max_epochs:
            raise ValueError(
                f"warmup_epochs ({warmup_epochs}) must be < max_epochs ({max_epochs})"
            )
        if warmup_start_lr < 0:
            raise ValueError(f"warmup_start_lr must be >= 0, got {warmup_start_lr}")
        if end_lr < 0:
            raise ValueError(f"end_lr must be >= 0, got {end_lr}")

        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.end_lr = end_lr
        # Base-class init captures base_lrs and performs the initial step.
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Compute the learning rate at the current epoch."""
        epoch = self.last_epoch

        # --- Warmup phase: linear ramp from warmup_start_lr to base_lr. ---
        if epoch < self.warmup_epochs:
            if self.warmup_epochs == 0:
                # Degenerate case: no warmup configured, use base rates as-is.
                return list(self.base_lrs)
            frac = epoch / self.warmup_epochs
            return [
                self.warmup_start_lr + frac * (lr0 - self.warmup_start_lr)
                for lr0 in self.base_lrs
            ]

        # --- Decay phase: linear interpolation from base_lr to end_lr. ---
        span = self.max_epochs - self.warmup_epochs
        if span == 0:
            # Defensive: no decay window left, pin to the final rate.
            return [self.end_lr for _ in self.base_lrs]
        # Clamp progress so epochs beyond max_epochs stay at end_lr.
        t = min(1.0, (epoch - self.warmup_epochs) / span)
        return [
            lr0 + t * (self.end_lr - lr0)
            for lr0 in self.base_lrs
        ]
|
sleap_nn/training/utils.py
CHANGED
|
@@ -1,13 +1,19 @@
|
|
|
1
1
|
"""Miscellaneous utility functions for training."""
|
|
2
2
|
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from io import BytesIO
|
|
3
5
|
import numpy as np
|
|
6
|
+
import matplotlib
|
|
7
|
+
|
|
8
|
+
matplotlib.use(
|
|
9
|
+
"Agg"
|
|
10
|
+
) # Use non-interactive backend to avoid tkinter issues on Windows CI
|
|
4
11
|
import matplotlib.pyplot as plt
|
|
5
12
|
from loguru import logger
|
|
6
13
|
from torch import nn
|
|
7
14
|
import torch.distributed as dist
|
|
8
|
-
import matplotlib
|
|
9
15
|
import seaborn as sns
|
|
10
|
-
from typing import List
|
|
16
|
+
from typing import List, Optional
|
|
11
17
|
import shutil
|
|
12
18
|
import os
|
|
13
19
|
import subprocess
|
|
@@ -236,3 +242,362 @@ def plot_peaks(
|
|
|
236
242
|
)
|
|
237
243
|
)
|
|
238
244
|
return handles
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
@dataclass
class VisualizationData:
    """Container for visualization data from a single sample.

    This dataclass decouples data extraction from rendering, allowing the same
    data to be rendered to different output targets (matplotlib, wandb, etc.).

    Attributes:
        image: Input image as (H, W, C) numpy array, normalized to [0, 1].
        pred_confmaps: Predicted confidence maps as (H, W, nodes) array, values in [0, 1].
        pred_peaks: Predicted keypoints as (instances, nodes, 2) or (nodes, 2) array.
        pred_peak_values: Confidence values as (instances, nodes) or (nodes,) array.
        gt_instances: Ground truth keypoints, same shape as pred_peaks.
        node_names: List of node/keypoint names, e.g., ["head", "thorax", ...].
        output_scale: Ratio of confmap size to image size (confmap_h / image_h).
        is_paired: Whether GT and predictions can be paired for error visualization.
        pred_pafs: Part affinity fields for bottom-up models, optional.
        pred_class_maps: Class maps for multi-class models, optional.
    """

    # Required per-sample arrays (no defaults, must be supplied positionally
    # or by keyword at construction time).
    image: np.ndarray
    pred_confmaps: np.ndarray
    pred_peaks: np.ndarray
    pred_peak_values: np.ndarray
    gt_instances: np.ndarray
    # default_factory avoids the shared-mutable-default pitfall for the list.
    node_names: List[str] = field(default_factory=list)
    # 1.0 means confmaps are at full image resolution (output_stride == 1).
    output_scale: float = 1.0
    # When False, renderers should not draw GT-to-prediction error lines.
    is_paired: bool = True
    # Only populated for model types that produce these extra outputs.
    pred_pafs: Optional[np.ndarray] = None
    pred_class_maps: Optional[np.ndarray] = None
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
class MatplotlibRenderer:
    """Renders VisualizationData to matplotlib figures."""

    @staticmethod
    def _display_scale(img: np.ndarray) -> float:
        """Return the figure upscaling factor for an image.

        Small images are drawn at a larger scale so overlays stay legible:
        2x for heights under 512 px and 4x for heights under 256 px.

        Args:
            img: Image array whose first dimension is height.

        Returns:
            Scale factor of 1.0, 2.0, or 4.0.
        """
        if img.shape[0] < 256:
            return 4.0
        if img.shape[0] < 512:
            return 2.0
        return 1.0

    def render(self, data: VisualizationData) -> matplotlib.figure.Figure:
        """Render visualization data to a matplotlib figure.

        Draws the input image, overlays the predicted confidence maps, and
        plots ground-truth vs. predicted peaks on the current axes.

        Args:
            data: VisualizationData containing image, confmaps, peaks, etc.

        Returns:
            A matplotlib Figure object.
        """
        scale = self._display_scale(data.image)
        fig = plot_img(data.image, dpi=72 * scale, scale=scale)
        plot_confmaps(data.pred_confmaps, output_scale=data.output_scale)
        plot_peaks(data.gt_instances, data.pred_peaks, paired=data.is_paired)
        return fig

    def render_pafs(self, data: VisualizationData) -> matplotlib.figure.Figure:
        """Render PAF magnitude visualization.

        Collapses the part affinity fields to a single per-pixel magnitude
        (max over edges) and overlays it on the input image.

        Args:
            data: VisualizationData with pred_pafs populated.

        Returns:
            A matplotlib Figure object showing PAF magnitudes.

        Raises:
            ValueError: If `data.pred_pafs` is None.
        """
        if data.pred_pafs is None:
            raise ValueError("pred_pafs is None, cannot render PAFs")

        img = data.image
        scale = self._display_scale(img)

        # Compute PAF magnitude
        pafs = data.pred_pafs  # (H, W, 2*edges) or (H, W, edges, 2)
        if pafs.ndim == 3:
            # Flat channel layout: unstack interleaved (x, y) components.
            n_edges = pafs.shape[-1] // 2
            pafs = pafs.reshape(pafs.shape[0], pafs.shape[1], n_edges, 2)
        magnitude = np.sqrt(pafs[..., 0] ** 2 + pafs[..., 1] ** 2)
        magnitude = magnitude.max(axis=-1)  # Max over edges

        fig = plot_img(img, dpi=72 * scale, scale=scale)
        ax = plt.gca()

        # Calculate PAF output scale from actual PAF dimensions, not confmap
        # output_scale — PAFs may have a different output_stride than confmaps.
        paf_output_scale = magnitude.shape[0] / img.shape[0]

        # extent maps the (possibly downscaled) magnitude grid back onto
        # image pixel coordinates so the overlay aligns with the image.
        ax.imshow(
            magnitude,
            alpha=0.5,
            origin="upper",
            cmap="viridis",
            extent=[
                -0.5,
                magnitude.shape[1] / paf_output_scale - 0.5,
                magnitude.shape[0] / paf_output_scale - 0.5,
                -0.5,
            ],
        )
        return fig
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
class WandBRenderer:
    """Renders VisualizationData to wandb.Image objects.

    Supports multiple rendering modes:
    - "direct": Pre-render with matplotlib, convert to wandb.Image
    - "boxes": Use wandb boxes for interactive keypoint visualization
    - "masks": Use wandb masks for confidence map overlay
    """

    def __init__(
        self,
        mode: str = "direct",
        box_size: float = 5.0,
        confmap_threshold: float = 0.1,
        min_size: int = 512,
    ):
        """Initialize the renderer.

        Args:
            mode: Rendering mode - "direct", "boxes", or "masks".
            box_size: Size of keypoint boxes in pixels (for "boxes" mode).
            confmap_threshold: Threshold for confmap mask (for "masks" mode).
            min_size: Minimum image dimension. Smaller images will be upscaled.
        """
        self.mode = mode
        self.box_size = box_size
        self.confmap_threshold = confmap_threshold
        self.min_size = min_size
        # "direct" mode delegates actual drawing to the matplotlib renderer.
        self._mpl_renderer = MatplotlibRenderer()

    def render(
        self, data: VisualizationData, caption: Optional[str] = None
    ) -> "wandb.Image":
        """Render visualization data to a wandb.Image.

        Args:
            data: VisualizationData containing image, confmaps, peaks, etc.
            caption: Optional caption for the image.

        Returns:
            A wandb.Image object.
        """
        # NOTE(review): wandb is not used directly here — presumably imported
        # so a missing install fails fast before dispatching to a helper.
        import wandb

        # Any unrecognized mode falls through to "direct" rendering.
        if self.mode == "boxes":
            return self._render_with_boxes(data, caption)
        elif self.mode == "masks":
            return self._render_with_masks(data, caption)
        else:  # "direct"
            return self._render_direct(data, caption)

    def _get_scale_factor(self, img_h: int, img_w: int) -> int:
        """Calculate scale factor to ensure minimum image size."""
        min_dim = min(img_h, img_w)
        if min_dim >= self.min_size:
            return 1
        # Integer upscale so the shorter side reaches at least min_size.
        return int(np.ceil(self.min_size / min_dim))

    def _render_direct(
        self, data: VisualizationData, caption: Optional[str] = None
    ) -> "wandb.Image":
        """Pre-render with matplotlib, return as wandb.Image."""
        import wandb
        from PIL import Image

        fig = self._mpl_renderer.render(data)

        # Convert figure to PIL Image via an in-memory PNG buffer.
        buf = BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        buf.seek(0)
        # Close the figure immediately to avoid accumulating open figures
        # across many logged samples.
        plt.close(fig)

        pil_image = Image.open(buf)
        return wandb.Image(pil_image, caption=caption)

    def _render_with_boxes(
        self, data: VisualizationData, caption: Optional[str] = None
    ) -> "wandb.Image":
        """Use wandb boxes for interactive keypoint visualization."""
        import wandb
        from PIL import Image

        # Prepare class labels from node names; fall back to generic names
        # sized by the node axis of pred_peaks when none are provided.
        class_labels = {i: name for i, name in enumerate(data.node_names)}
        if not class_labels:
            class_labels = {i: f"node_{i}" for i in range(data.pred_peaks.shape[-2])}

        # Convert image to uint8 (input is assumed normalized to [0, 1]).
        img_uint8 = (np.clip(data.image, 0, 1) * 255).astype(np.uint8)
        # Handle single-channel images: squeeze (H, W, 1) -> (H, W)
        if img_uint8.ndim == 3 and img_uint8.shape[2] == 1:
            img_uint8 = img_uint8.squeeze(axis=2)
        img_h, img_w = img_uint8.shape[:2]

        # Scale up small images for better visibility in wandb
        scale = self._get_scale_factor(img_h, img_w)
        if scale > 1:
            pil_img = Image.fromarray(img_uint8)
            pil_img = pil_img.resize(
                (img_w * scale, img_h * scale), resample=Image.BILINEAR
            )
            img_uint8 = np.array(pil_img)

        # Build ground truth boxes (use percent domain for proper thumbnail
        # scaling — percent coordinates stay valid after the upscale above).
        gt_box_data = self._peaks_to_boxes(
            data.gt_instances, data.node_names, img_w, img_h, is_gt=True
        )

        # Build prediction boxes
        pred_box_data = self._peaks_to_boxes(
            data.pred_peaks,
            data.node_names,
            img_w,
            img_h,
            peak_values=data.pred_peak_values,
            is_gt=False,
        )

        return wandb.Image(
            img_uint8,
            boxes={
                "ground_truth": {"box_data": gt_box_data, "class_labels": class_labels},
                "predictions": {
                    "box_data": pred_box_data,
                    "class_labels": class_labels,
                },
            },
            caption=caption,
        )

    def _peaks_to_boxes(
        self,
        peaks: np.ndarray,
        node_names: List[str],
        img_w: int,
        img_h: int,
        peak_values: Optional[np.ndarray] = None,
        is_gt: bool = False,
    ) -> List[dict]:
        """Convert peaks array to wandb box_data format.

        Args:
            peaks: Keypoints as (instances, nodes, 2) or (nodes, 2).
            node_names: List of node names.
            img_w: Image width in pixels.
            img_h: Image height in pixels.
            peak_values: Optional confidence values.
            is_gt: Whether these are ground truth points.

        Returns:
            List of box dictionaries for wandb.
        """
        box_data = []

        # Normalize shape to (instances, nodes, 2)
        if peaks.ndim == 2:
            peaks = peaks[np.newaxis, ...]
            if peak_values is not None and peak_values.ndim == 1:
                peak_values = peak_values[np.newaxis, ...]

        # Convert box_size from pixels to percent
        box_w_pct = self.box_size / img_w
        box_h_pct = self.box_size / img_h

        for inst_idx, instance in enumerate(peaks):
            for node_idx, (x, y) in enumerate(instance):
                # NaN coordinates mark missing/invisible keypoints — skip.
                if np.isnan(x) or np.isnan(y):
                    continue

                node_name = (
                    node_names[node_idx]
                    if node_idx < len(node_names)
                    else f"node_{node_idx}"
                )

                # Convert pixel coordinates to percent (0-1 range)
                x_pct = float(x) / img_w
                y_pct = float(y) / img_h

                # Small fixed-size box centered on the keypoint ("middle" is
                # the box center in wandb's percent domain).
                box = {
                    "position": {
                        "middle": [x_pct, y_pct],
                        "width": box_w_pct,
                        "height": box_h_pct,
                    },
                    "domain": "percent",
                    "class_id": node_idx,
                }

                if is_gt:
                    box["box_caption"] = f"GT: {node_name}"
                else:
                    if peak_values is not None:
                        conf = float(peak_values[inst_idx, node_idx])
                        box["box_caption"] = f"{node_name} ({conf:.2f})"
                        # "scores" enables confidence filtering in the wandb UI.
                        box["scores"] = {"confidence": conf}
                    else:
                        box["box_caption"] = node_name

                box_data.append(box)

        return box_data

    def _render_with_masks(
        self, data: VisualizationData, caption: Optional[str] = None
    ) -> "wandb.Image":
        """Use wandb masks for confidence map overlay.

        Uses argmax approach: each pixel shows the dominant node.
        """
        import wandb

        # Prepare class labels (0 = background, 1+ = nodes)
        class_labels = {0: "background"}
        for i, name in enumerate(data.node_names):
            class_labels[i + 1] = name
        if not data.node_names:
            n_nodes = data.pred_confmaps.shape[-1]
            for i in range(n_nodes):
                class_labels[i + 1] = f"node_{i}"

        # Create argmax mask from confmaps
        confmaps = data.pred_confmaps  # (H/stride, W/stride, nodes)
        max_vals = confmaps.max(axis=-1)
        argmax_map = confmaps.argmax(axis=-1) + 1  # +1 for background offset
        # Pixels where no node is confident enough become background.
        argmax_map[max_vals < self.confmap_threshold] = 0  # Background

        # Convert image to uint8
        img_uint8 = (np.clip(data.image, 0, 1) * 255).astype(np.uint8)
        # Handle single-channel images: (H, W, 1) -> (H, W)
        if img_uint8.ndim == 3 and img_uint8.shape[2] == 1:
            img_uint8 = img_uint8.squeeze(axis=2)
        img_h, img_w = img_uint8.shape[:2]

        # Resize mask to match image dimensions (confmaps are H/stride, W/stride).
        # NEAREST resampling preserves the integer class labels.
        from PIL import Image

        mask_pil = Image.fromarray(argmax_map.astype(np.uint8))
        mask_pil = mask_pil.resize((img_w, img_h), resample=Image.NEAREST)
        argmax_map = np.array(mask_pil)

        return wandb.Image(
            img_uint8,
            masks={
                "confidence_maps": {
                    "mask_data": argmax_map.astype(np.uint8),
                    "class_labels": class_labels,
                }
            },
            caption=caption,
        )
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: sleap-nn
|
|
3
|
-
Version: 0.0
|
|
3
|
+
Version: 0.1.0
|
|
4
4
|
Summary: Neural network backend for training and inference for animal pose estimation.
|
|
5
5
|
Author-email: Divya Seshadri Murali <dimurali@salk.edu>, Elizabeth Berrigan <eberrigan@salk.edu>, Vincent Tu <vitu@ucsd.edu>, Liezl Maree <lmaree@salk.edu>, David Samy <davidasamy@gmail.com>, Talmo Pereira <talmo@salk.edu>
|
|
6
6
|
License: BSD-3-Clause
|
|
@@ -13,10 +13,10 @@ Classifier: Programming Language :: Python :: 3.13
|
|
|
13
13
|
Requires-Python: <3.14,>=3.11
|
|
14
14
|
Description-Content-Type: text/markdown
|
|
15
15
|
License-File: LICENSE
|
|
16
|
-
Requires-Dist: sleap-io
|
|
16
|
+
Requires-Dist: sleap-io<0.7.0,>=0.6.2
|
|
17
17
|
Requires-Dist: numpy
|
|
18
18
|
Requires-Dist: lightning
|
|
19
|
-
Requires-Dist:
|
|
19
|
+
Requires-Dist: skia-python>=87.0
|
|
20
20
|
Requires-Dist: jsonpickle
|
|
21
21
|
Requires-Dist: scipy
|
|
22
22
|
Requires-Dist: attrs
|
|
@@ -32,37 +32,33 @@ Requires-Dist: hydra-core
|
|
|
32
32
|
Requires-Dist: jupyter
|
|
33
33
|
Requires-Dist: jupyterlab
|
|
34
34
|
Requires-Dist: pyzmq
|
|
35
|
+
Requires-Dist: rich-click>=1.9.5
|
|
35
36
|
Provides-Extra: torch
|
|
36
37
|
Requires-Dist: torch; extra == "torch"
|
|
37
|
-
Requires-Dist: torchvision
|
|
38
|
+
Requires-Dist: torchvision>=0.20.0; extra == "torch"
|
|
38
39
|
Provides-Extra: torch-cpu
|
|
39
40
|
Requires-Dist: torch; extra == "torch-cpu"
|
|
40
|
-
Requires-Dist: torchvision
|
|
41
|
+
Requires-Dist: torchvision>=0.20.0; extra == "torch-cpu"
|
|
41
42
|
Provides-Extra: torch-cuda118
|
|
42
43
|
Requires-Dist: torch; extra == "torch-cuda118"
|
|
43
|
-
Requires-Dist: torchvision
|
|
44
|
+
Requires-Dist: torchvision>=0.20.0; extra == "torch-cuda118"
|
|
44
45
|
Provides-Extra: torch-cuda128
|
|
45
46
|
Requires-Dist: torch; extra == "torch-cuda128"
|
|
46
|
-
Requires-Dist: torchvision
|
|
47
|
-
Provides-Extra:
|
|
48
|
-
Requires-Dist:
|
|
49
|
-
Requires-Dist:
|
|
50
|
-
|
|
51
|
-
Requires-Dist:
|
|
52
|
-
Requires-Dist:
|
|
53
|
-
Requires-Dist:
|
|
54
|
-
|
|
55
|
-
Requires-Dist:
|
|
56
|
-
Requires-Dist:
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
Requires-Dist:
|
|
60
|
-
Requires-Dist:
|
|
61
|
-
Requires-Dist: mike; extra == "docs"
|
|
62
|
-
Requires-Dist: mkdocstrings[python]; extra == "docs"
|
|
63
|
-
Requires-Dist: mkdocs-gen-files; extra == "docs"
|
|
64
|
-
Requires-Dist: mkdocs-literate-nav; extra == "docs"
|
|
65
|
-
Requires-Dist: mkdocs-section-index; extra == "docs"
|
|
47
|
+
Requires-Dist: torchvision>=0.20.0; extra == "torch-cuda128"
|
|
48
|
+
Provides-Extra: torch-cuda130
|
|
49
|
+
Requires-Dist: torch; extra == "torch-cuda130"
|
|
50
|
+
Requires-Dist: torchvision>=0.20.0; extra == "torch-cuda130"
|
|
51
|
+
Provides-Extra: export
|
|
52
|
+
Requires-Dist: onnx>=1.15.0; extra == "export"
|
|
53
|
+
Requires-Dist: onnxruntime>=1.16.0; extra == "export"
|
|
54
|
+
Requires-Dist: onnxscript>=0.1.0; extra == "export"
|
|
55
|
+
Provides-Extra: export-gpu
|
|
56
|
+
Requires-Dist: onnx>=1.15.0; extra == "export-gpu"
|
|
57
|
+
Requires-Dist: onnxruntime-gpu>=1.16.0; extra == "export-gpu"
|
|
58
|
+
Requires-Dist: onnxscript>=0.1.0; extra == "export-gpu"
|
|
59
|
+
Provides-Extra: tensorrt
|
|
60
|
+
Requires-Dist: tensorrt>=10.13.0; (sys_platform == "linux" or sys_platform == "win32") and extra == "tensorrt"
|
|
61
|
+
Requires-Dist: torch-tensorrt>=2.5.0; (sys_platform == "linux" or sys_platform == "win32") and extra == "tensorrt"
|
|
66
62
|
Dynamic: license-file
|
|
67
63
|
|
|
68
64
|
# sleap-nn
|
|
@@ -120,22 +116,28 @@ powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
|
|
|
120
116
|
> Replace `...` with the rest of your install command as needed.
|
|
121
117
|
|
|
122
118
|
- Sync all dependencies based on your correct wheel using `uv sync`. `uv sync` creates a `.venv` (virtual environment) inside your current working directory. This environment is only active within that directory and can't be directly accessed from outside. To use all installed packages, you must run commands with `uv run` (e.g., `uv run sleap-nn train ...` or `uv run pytest ...`).
|
|
123
|
-
- **Windows/Linux with NVIDIA GPU (CUDA
|
|
119
|
+
- **Windows/Linux with NVIDIA GPU (CUDA 13.0):**
|
|
124
120
|
|
|
125
121
|
```bash
|
|
126
|
-
uv sync --extra
|
|
122
|
+
uv sync --extra torch-cuda130
|
|
127
123
|
```
|
|
128
124
|
|
|
129
125
|
- **Windows/Linux with NVIDIA GPU (CUDA 12.8):**
|
|
130
126
|
|
|
131
127
|
```bash
|
|
132
|
-
uv sync --extra
|
|
128
|
+
uv sync --extra torch-cuda128
|
|
133
129
|
```
|
|
134
|
-
|
|
135
|
-
- **
|
|
130
|
+
|
|
131
|
+
- **Windows/Linux with NVIDIA GPU (CUDA 11.8):**
|
|
132
|
+
|
|
133
|
+
```bash
|
|
134
|
+
uv sync --extra torch-cuda118
|
|
135
|
+
```
|
|
136
|
+
|
|
137
|
+
- **macOS with Apple Silicon (M1, M2, M3, M4) or CPU-only (no GPU or unsupported GPU):**
|
|
136
138
|
Note: Even if torch-cpu is used on macOS, the MPS backend will be available.
|
|
137
139
|
```bash
|
|
138
|
-
uv sync --extra
|
|
140
|
+
uv sync --extra torch-cpu
|
|
139
141
|
```
|
|
140
142
|
|
|
141
143
|
4. **Run tests**
|
|
@@ -152,6 +154,6 @@ powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
|
|
|
152
154
|
> **Upgrading All Dependencies**
|
|
153
155
|
> To ensure you have the latest versions of all dependencies, use the `--upgrade` flag with `uv sync`:
|
|
154
156
|
> ```bash
|
|
155
|
-
> uv sync --
|
|
157
|
+
> uv sync --upgrade
|
|
156
158
|
> ```
|
|
157
159
|
> This will upgrade all installed packages in your environment to the latest available versions compatible with your `pyproject.toml`.
|