sleap-nn: sleap_nn-0.1.0-py3-none-any.whl → sleap_nn-0.1.0a1-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sleap_nn/__init__.py +1 -1
- sleap_nn/architectures/convnext.py +0 -5
- sleap_nn/architectures/encoder_decoder.py +6 -25
- sleap_nn/architectures/swint.py +0 -8
- sleap_nn/cli.py +60 -364
- sleap_nn/config/data_config.py +5 -11
- sleap_nn/config/get_config.py +4 -5
- sleap_nn/config/trainer_config.py +0 -71
- sleap_nn/data/augmentation.py +241 -50
- sleap_nn/data/custom_datasets.py +34 -364
- sleap_nn/data/instance_cropping.py +1 -1
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/utils.py +17 -135
- sleap_nn/evaluation.py +22 -81
- sleap_nn/inference/bottomup.py +20 -86
- sleap_nn/inference/peak_finding.py +19 -88
- sleap_nn/inference/predictors.py +117 -224
- sleap_nn/legacy_models.py +11 -65
- sleap_nn/predict.py +9 -37
- sleap_nn/train.py +4 -69
- sleap_nn/training/callbacks.py +105 -1046
- sleap_nn/training/lightning_modules.py +65 -602
- sleap_nn/training/model_trainer.py +204 -201
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA +3 -15
- sleap_nn-0.1.0a1.dist-info/RECORD +65 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/WHEEL +1 -1
- sleap_nn/data/skia_augmentation.py +0 -414
- sleap_nn/export/__init__.py +0 -21
- sleap_nn/export/cli.py +0 -1778
- sleap_nn/export/exporters/__init__.py +0 -51
- sleap_nn/export/exporters/onnx_exporter.py +0 -80
- sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
- sleap_nn/export/metadata.py +0 -225
- sleap_nn/export/predictors/__init__.py +0 -63
- sleap_nn/export/predictors/base.py +0 -22
- sleap_nn/export/predictors/onnx.py +0 -154
- sleap_nn/export/predictors/tensorrt.py +0 -312
- sleap_nn/export/utils.py +0 -307
- sleap_nn/export/wrappers/__init__.py +0 -25
- sleap_nn/export/wrappers/base.py +0 -96
- sleap_nn/export/wrappers/bottomup.py +0 -243
- sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
- sleap_nn/export/wrappers/centered_instance.py +0 -56
- sleap_nn/export/wrappers/centroid.py +0 -58
- sleap_nn/export/wrappers/single_instance.py +0 -83
- sleap_nn/export/wrappers/topdown.py +0 -180
- sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
- sleap_nn/inference/postprocessing.py +0 -284
- sleap_nn/training/schedulers.py +0 -191
- sleap_nn-0.1.0.dist-info/RECORD +0 -88
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/top_level.txt +0 -0
sleap_nn/export/wrappers/bottomup_multiclass.py (removed)
@@ -1,195 +0,0 @@
-"""ONNX wrapper for bottom-up multiclass (supervised ID) models."""
-
-from typing import Dict
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from sleap_nn.export.wrappers.base import BaseExportWrapper
-
-
-class BottomUpMultiClassONNXWrapper(BaseExportWrapper):
-    """ONNX-exportable wrapper for bottom-up multiclass (supervised ID) models.
-
-    This wrapper handles models that output both confidence maps for keypoint
-    detection and class maps for identity classification. Unlike PAF-based
-    bottom-up models, multiclass models use class maps to assign identity to
-    each detected peak, then group peaks by identity.
-
-    The wrapper performs:
-    1. Peak detection in confidence maps (GPU)
-    2. Class probability sampling at peak locations (GPU)
-    3. Returns fixed-size tensors for CPU-side grouping
-
-    Expects input images as uint8 tensors in [0, 255].
-
-    Attributes:
-        model: The underlying PyTorch model.
-        n_nodes: Number of keypoint nodes in the skeleton.
-        n_classes: Number of identity classes.
-        max_peaks_per_node: Maximum number of peaks to detect per node.
-        cms_output_stride: Output stride of the confidence map head.
-        class_maps_output_stride: Output stride of the class maps head.
-        input_scale: Scale factor applied to input images before inference.
-    """
-
-    def __init__(
-        self,
-        model: nn.Module,
-        n_nodes: int,
-        n_classes: int = 2,
-        max_peaks_per_node: int = 20,
-        cms_output_stride: int = 4,
-        class_maps_output_stride: int = 8,
-        input_scale: float = 1.0,
-    ):
-        """Initialize the wrapper.
-
-        Args:
-            model: The underlying PyTorch model.
-            n_nodes: Number of keypoint nodes.
-            n_classes: Number of identity classes (e.g., 2 for male/female).
-            max_peaks_per_node: Maximum peaks per node to detect.
-            cms_output_stride: Output stride of confidence maps.
-            class_maps_output_stride: Output stride of class maps.
-            input_scale: Scale factor for input images.
-        """
-        super().__init__(model)
-        self.n_nodes = n_nodes
-        self.n_classes = n_classes
-        self.max_peaks_per_node = max_peaks_per_node
-        self.cms_output_stride = cms_output_stride
-        self.class_maps_output_stride = class_maps_output_stride
-        self.input_scale = input_scale
-
-    def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
-        """Run bottom-up multiclass inference.
-
-        Args:
-            image: Input image tensor of shape (batch, channels, height, width).
-                Expected to be uint8 in [0, 255].
-
-        Returns:
-            Dictionary with keys:
-            - "peaks": Detected peak coordinates (batch, n_nodes, max_peaks, 2).
-              Coordinates are in input image space (x, y).
-            - "peak_vals": Peak confidence values (batch, n_nodes, max_peaks).
-            - "peak_mask": Boolean mask for valid peaks (batch, n_nodes, max_peaks).
-            - "class_probs": Class probabilities at each peak location
-              (batch, n_nodes, max_peaks, n_classes).
-
-        Postprocessing on CPU uses `classify_peaks_from_maps()` to group
-        peaks by identity using Hungarian matching.
-        """
-        # Normalize uint8 [0, 255] to float32 [0, 1]
-        image = self._normalize_uint8(image)
-
-        # Apply input scaling if needed
-        if self.input_scale != 1.0:
-            height = int(image.shape[-2] * self.input_scale)
-            width = int(image.shape[-1] * self.input_scale)
-            image = F.interpolate(
-                image, size=(height, width), mode="bilinear", align_corners=False
-            )
-
-        batch_size = image.shape[0]
-
-        # Forward pass
-        out = self.model(image)
-
-        # Extract outputs
-        # Note: Use "classmaps" as a single hint to avoid "map" matching "confmaps"
-        confmaps = self._extract_tensor(out, ["confmap", "multiinstance"])
-        class_maps = self._extract_tensor(out, ["classmaps", "classmapshead"])
-
-        # Find top-k peaks per node
-        peaks, peak_vals, peak_mask = self._find_topk_peaks_per_node(
-            confmaps, self.max_peaks_per_node
-        )
-
-        # Scale peaks to input image space
-        peaks = peaks * self.cms_output_stride
-
-        # Sample class maps at peak locations
-        class_probs = self._sample_class_maps_at_peaks(class_maps, peaks, peak_mask)
-
-        # Scale peaks for output (accounting for input scale)
-        if self.input_scale != 1.0:
-            peaks = peaks / self.input_scale
-
-        return {
-            "peaks": peaks,
-            "peak_vals": peak_vals,
-            "peak_mask": peak_mask,
-            "class_probs": class_probs,
-        }
-
-    def _sample_class_maps_at_peaks(
-        self,
-        class_maps: torch.Tensor,
-        peaks: torch.Tensor,
-        peak_mask: torch.Tensor,
-    ) -> torch.Tensor:
-        """Sample class map values at peak locations.
-
-        Args:
-            class_maps: Class maps of shape (batch, n_classes, height, width).
-            peaks: Peak coordinates in cms_output_stride space,
-                shape (batch, n_nodes, max_peaks, 2) in (x, y) order.
-            peak_mask: Boolean mask for valid peaks (batch, n_nodes, max_peaks).
-
-        Returns:
-            Class probabilities at each peak location,
-            shape (batch, n_nodes, max_peaks, n_classes).
-        """
-        batch_size, n_classes, cm_height, cm_width = class_maps.shape
-        _, n_nodes, max_peaks, _ = peaks.shape
-        device = peaks.device
-
-        # Initialize output tensor
-        class_probs = torch.zeros(
-            (batch_size, n_nodes, max_peaks, n_classes),
-            device=device,
-            dtype=class_maps.dtype,
-        )
-
-        # Convert peak coordinates to class map space
-        # peaks are in full image space (after cms_output_stride scaling)
-        peaks_cm = peaks / self.class_maps_output_stride
-
-        # Clamp coordinates to valid range
-        peaks_cm_x = peaks_cm[..., 0].clamp(0, cm_width - 1)
-        peaks_cm_y = peaks_cm[..., 1].clamp(0, cm_height - 1)
-
-        # Use grid_sample for bilinear interpolation
-        # Normalize coordinates to [-1, 1] for grid_sample
-        grid_x = (peaks_cm_x / (cm_width - 1)) * 2 - 1
-        grid_y = (peaks_cm_y / (cm_height - 1)) * 2 - 1
-
-        # Reshape for grid_sample: (batch, n_nodes * max_peaks, 1, 2)
-        grid = torch.stack([grid_x, grid_y], dim=-1)
-        grid_flat = grid.reshape(batch_size, n_nodes * max_peaks, 1, 2)
-
-        # Sample class maps: (batch, n_classes, n_nodes * max_peaks, 1)
-        sampled = F.grid_sample(
-            class_maps,
-            grid_flat,
-            mode="bilinear",
-            padding_mode="zeros",
-            align_corners=True,
-        )
-
-        # Reshape to (batch, n_nodes, max_peaks, n_classes)
-        sampled = sampled.squeeze(-1)  # (batch, n_classes, n_nodes * max_peaks)
-        sampled = sampled.permute(0, 2, 1)  # (batch, n_nodes * max_peaks, n_classes)
-        sampled = sampled.reshape(batch_size, n_nodes, max_peaks, n_classes)
-
-        # Apply softmax to get probabilities (optional - depends on training)
-        # For now, return raw values as the grouping function expects logits
-        class_probs = sampled
-
-        # Mask invalid peaks
-        class_probs = class_probs * peak_mask.unsqueeze(-1).float()
-
-        return class_probs
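The `_sample_class_maps_at_peaks` step above reduces to a single batched `F.grid_sample` call: convert peak coordinates from image space to class-map space, normalize them to [-1, 1], and read one class vector per peak. A minimal, self-contained sketch with dummy tensors (the stride of 8 and all shapes are illustrative, not taken from a real sleap-nn model):

import torch
import torch.nn.functional as F

# Dummy data: a 2-class class map at stride 8 for a 512x512 input, and two
# peaks for a single node, given in full-image (x, y) coordinates.
batch, n_classes, cm_h, cm_w = 1, 2, 64, 64
class_maps = torch.rand(batch, n_classes, cm_h, cm_w)
peaks = torch.tensor([[[[100.0, 200.0], [300.0, 50.0]]]])  # (batch, n_nodes=1, max_peaks=2, 2)

peaks_cm = peaks / 8.0  # image space -> class-map space
grid_x = (peaks_cm[..., 0].clamp(0, cm_w - 1) / (cm_w - 1)) * 2 - 1
grid_y = (peaks_cm[..., 1].clamp(0, cm_h - 1) / (cm_h - 1)) * 2 - 1
grid = torch.stack([grid_x, grid_y], dim=-1).reshape(batch, -1, 1, 2)

# Bilinearly sample a class vector at every peak location in one call.
sampled = F.grid_sample(class_maps, grid, mode="bilinear", align_corners=True)
class_probs = sampled.squeeze(-1).permute(0, 2, 1).reshape(batch, 1, 2, n_classes)
print(class_probs.shape)  # torch.Size([1, 1, 2, 2])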
sleap_nn/export/wrappers/centered_instance.py (removed)
@@ -1,56 +0,0 @@
-"""Centered-instance ONNX wrapper."""
-
-from __future__ import annotations
-
-from typing import Dict
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from sleap_nn.export.wrappers.base import BaseExportWrapper
-
-
-class CenteredInstanceONNXWrapper(BaseExportWrapper):
-    """ONNX-exportable wrapper for centered-instance models.
-
-    Expects input images as uint8 tensors in [0, 255].
-    """
-
-    def __init__(
-        self,
-        model: nn.Module,
-        output_stride: int = 4,
-        input_scale: float = 1.0,
-    ):
-        """Initialize centered instance ONNX wrapper.
-
-        Args:
-            model: Centered instance model for pose estimation.
-            output_stride: Output stride for confidence maps.
-            input_scale: Input scaling factor.
-        """
-        super().__init__(model)
-        self.output_stride = output_stride
-        self.input_scale = input_scale
-
-    def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
-        """Run centered-instance inference on crops."""
-        image = self._normalize_uint8(image)
-        if self.input_scale != 1.0:
-            height = int(image.shape[-2] * self.input_scale)
-            width = int(image.shape[-1] * self.input_scale)
-            image = F.interpolate(
-                image, size=(height, width), mode="bilinear", align_corners=False
-            )
-
-        confmaps = self._extract_tensor(
-            self.model(image), ["centered", "instance", "confmap"]
-        )
-        peaks, values = self._find_global_peaks(confmaps)
-        peaks = peaks * (self.output_stride / self.input_scale)
-
-        return {
-            "peaks": peaks,
-            "peak_vals": values,
-        }
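`_find_global_peaks`, used by the centered-instance and single-instance wrappers, lives in the removed base wrapper and is not shown in this diff. A minimal sketch of per-channel global argmax peak finding under that assumption (the wrapper then multiplies the result by output_stride / input_scale to land in image coordinates):

import torch

def find_global_peaks(cms: torch.Tensor):
    # cms: (batch, n_nodes, height, width) confidence maps.
    # Returns peaks (batch, n_nodes, 2) in (x, y) grid units and values (batch, n_nodes).
    w = cms.shape[-1]
    vals, idx = cms.flatten(2).max(dim=-1)
    xs = (idx % w).float()
    ys = torch.div(idx, w, rounding_mode="floor").float()
    return torch.stack([xs, ys], dim=-1), vals

cms = torch.rand(2, 5, 96, 96)  # e.g. a 5-node skeleton
peaks, vals = find_global_peaks(cms)
print(peaks.shape, vals.shape)  # torch.Size([2, 5, 2]) torch.Size([2, 5])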
sleap_nn/export/wrappers/centroid.py (removed)
@@ -1,58 +0,0 @@
-"""Centroid ONNX wrapper."""
-
-from __future__ import annotations
-
-from typing import Dict
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from sleap_nn.export.wrappers.base import BaseExportWrapper
-
-
-class CentroidONNXWrapper(BaseExportWrapper):
-    """ONNX-exportable wrapper for centroid models.
-
-    Expects input images as uint8 tensors in [0, 255].
-    """
-
-    def __init__(
-        self,
-        model: nn.Module,
-        max_instances: int = 20,
-        output_stride: int = 2,
-        input_scale: float = 1.0,
-    ):
-        """Initialize centroid ONNX wrapper.
-
-        Args:
-            model: Centroid detection model.
-            max_instances: Maximum number of instances to detect.
-            output_stride: Output stride for confidence maps.
-            input_scale: Input scaling factor.
-        """
-        super().__init__(model)
-        self.max_instances = max_instances
-        self.output_stride = output_stride
-        self.input_scale = input_scale
-
-    def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
-        """Run centroid inference and return fixed-size outputs."""
-        image = self._normalize_uint8(image)
-        if self.input_scale != 1.0:
-            height = int(image.shape[-2] * self.input_scale)
-            width = int(image.shape[-1] * self.input_scale)
-            image = F.interpolate(
-                image, size=(height, width), mode="bilinear", align_corners=False
-            )
-
-        confmaps = self._extract_tensor(self.model(image), ["centroid", "confmap"])
-        peaks, values, valid = self._find_topk_peaks(confmaps, self.max_instances)
-        peaks = peaks * (self.output_stride / self.input_scale)
-
-        return {
-            "centroids": peaks,
-            "centroid_vals": values,
-            "instance_valid": valid,
-        }
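`_find_topk_peaks` is likewise defined in the removed base wrapper. A minimal sketch of fixed-size top-k peak finding over a single-channel centroid map, with a validity mask so the exported graph always produces the same output shapes (the 0.2 threshold is an arbitrary placeholder):

import torch

def find_topk_peaks(cms: torch.Tensor, max_instances: int, threshold: float = 0.2):
    # cms: (batch, 1, height, width) centroid confidence maps.
    # Returns peaks (batch, k, 2) in (x, y), values (batch, k), validity mask (batch, k).
    w = cms.shape[-1]
    vals, idx = torch.topk(cms.flatten(1), max_instances, dim=-1)
    xs = (idx % w).float()
    ys = torch.div(idx, w, rounding_mode="floor").float()
    peaks = torch.stack([xs, ys], dim=-1)
    return peaks, vals, vals > threshold

cms = torch.rand(1, 1, 128, 128)
peaks, vals, valid = find_topk_peaks(cms, max_instances=20)
print(peaks.shape, vals.shape, valid.shape)  # (1, 20, 2) (1, 20) (1, 20)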
sleap_nn/export/wrappers/single_instance.py (removed)
@@ -1,83 +0,0 @@
-"""Single-instance ONNX wrapper."""
-
-from __future__ import annotations
-
-from typing import Dict
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from sleap_nn.export.wrappers.base import BaseExportWrapper
-
-
-class SingleInstanceONNXWrapper(BaseExportWrapper):
-    """ONNX-exportable wrapper for single-instance models.
-
-    This wrapper handles full-frame inference assuming a single instance per frame.
-    For each body part (channel), it finds the global maximum in the confidence map.
-
-    Expects input images as uint8 tensors in [0, 255].
-
-    Attributes:
-        model: The trained backbone model that outputs confidence maps.
-        output_stride: Output stride of the model (e.g., 4 means confmaps are 1/4 the
-            input resolution).
-        input_scale: Factor to scale input images before inference.
-    """
-
-    def __init__(
-        self,
-        model: nn.Module,
-        output_stride: int = 4,
-        input_scale: float = 1.0,
-    ):
-        """Initialize the single-instance wrapper.
-
-        Args:
-            model: The trained backbone model.
-            output_stride: Output stride of the model. Default: 4.
-            input_scale: Factor to scale input images. Default: 1.0.
-        """
-        super().__init__(model)
-        self.output_stride = output_stride
-        self.input_scale = input_scale
-
-    def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
-        """Run single-instance inference.
-
-        Args:
-            image: Input image tensor of shape (batch, channels, height, width).
-                Expected as uint8 [0, 255] values.
-
-        Returns:
-            Dictionary with:
-            peaks: Peak coordinates of shape (batch, n_nodes, 2) in (x, y) format.
-            peak_vals: Peak confidence values of shape (batch, n_nodes).
-        """
-        # Normalize uint8 [0, 255] to float32 [0, 1]
-        image = self._normalize_uint8(image)
-
-        # Apply input scaling if needed
-        if self.input_scale != 1.0:
-            height = int(image.shape[-2] * self.input_scale)
-            width = int(image.shape[-1] * self.input_scale)
-            image = F.interpolate(
-                image, size=(height, width), mode="bilinear", align_corners=False
-            )
-
-        # Run model to get confidence maps: (batch, n_nodes, height, width)
-        confmaps = self._extract_tensor(
-            self.model(image), ["single", "instance", "confmap"]
-        )
-
-        # Find global peak for each channel: (batch, n_nodes, 2), (batch, n_nodes)
-        peaks, values = self._find_global_peaks(confmaps)
-
-        # Scale peaks from confmap coordinates to image coordinates
-        peaks = peaks * (self.output_stride / self.input_scale)
-
-        return {
-            "peaks": peaks,
-            "peak_vals": values,
-        }
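The point of these wrappers is that preprocessing, the backbone, and the GPU-side decoding trace into one ONNX graph. A minimal sketch of that export pattern with `torch.onnx.export`; `TinySingleInstanceNet`, the opset, and the file name are hypothetical stand-ins, not the removed exporter code:

import torch
from torch import nn

class TinySingleInstanceNet(nn.Module):
    # Hypothetical stand-in for a trained sleap-nn backbone; it mimics the
    # wrapper contract (uint8 image in, per-node peaks and values out).
    def __init__(self, n_nodes: int = 5):
        super().__init__()
        self.head = nn.Conv2d(1, n_nodes, kernel_size=3, padding=1)

    def forward(self, image: torch.Tensor):
        x = image.float() / 255.0                      # uint8 [0, 255] -> float [0, 1]
        cms = torch.sigmoid(self.head(x))              # (batch, n_nodes, H, W)
        w = cms.shape[-1]
        vals, idx = cms.flatten(2).max(dim=-1)         # per-node global peak
        xs = (idx % w).float()
        ys = torch.div(idx, w, rounding_mode="floor").float()
        return torch.stack([xs, ys], dim=-1), vals     # peaks (b, n, 2), peak_vals (b, n)

model = TinySingleInstanceNet().eval()
dummy = torch.randint(0, 256, (1, 1, 256, 256), dtype=torch.uint8)
torch.onnx.export(
    model,
    (dummy,),
    "single_instance_demo.onnx",
    input_names=["image"],
    output_names=["peaks", "peak_vals"],
    dynamic_axes={"image": {0: "batch"}, "peaks": {0: "batch"}, "peak_vals": {0: "batch"}},
    opset_version=17,
)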
sleap_nn/export/wrappers/topdown.py (removed)
@@ -1,180 +0,0 @@
-"""Top-down ONNX wrapper."""
-
-from __future__ import annotations
-
-from typing import Dict, Tuple
-
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-from sleap_nn.export.wrappers.base import BaseExportWrapper
-
-
-class TopDownONNXWrapper(BaseExportWrapper):
-    """ONNX-exportable wrapper for top-down (centroid + centered-instance) inference.
-
-    Expects input images as uint8 tensors in [0, 255].
-    """
-
-    def __init__(
-        self,
-        centroid_model: nn.Module,
-        instance_model: nn.Module,
-        max_instances: int = 20,
-        crop_size: Tuple[int, int] = (192, 192),
-        centroid_output_stride: int = 2,
-        instance_output_stride: int = 4,
-        centroid_input_scale: float = 1.0,
-        instance_input_scale: float = 1.0,
-        n_nodes: int = 1,
-    ) -> None:
-        """Initialize top-down ONNX wrapper.
-
-        Args:
-            centroid_model: Centroid detection model.
-            instance_model: Instance pose estimation model.
-            max_instances: Maximum number of instances to detect.
-            crop_size: Size of instance crops (height, width).
-            centroid_output_stride: Centroid model output stride.
-            instance_output_stride: Instance model output stride.
-            centroid_input_scale: Centroid input scaling factor.
-            instance_input_scale: Instance input scaling factor.
-            n_nodes: Number of skeleton nodes.
-        """
-        super().__init__(centroid_model)
-        self.centroid_model = centroid_model
-        self.instance_model = instance_model
-        self.max_instances = max_instances
-        self.crop_size = crop_size
-        self.centroid_output_stride = centroid_output_stride
-        self.instance_output_stride = instance_output_stride
-        self.centroid_input_scale = centroid_input_scale
-        self.instance_input_scale = instance_input_scale
-        self.n_nodes = n_nodes
-
-        crop_h, crop_w = crop_size
-        y_crop = torch.linspace(-1, 1, crop_h, dtype=torch.float32)
-        x_crop = torch.linspace(-1, 1, crop_w, dtype=torch.float32)
-        grid_y, grid_x = torch.meshgrid(y_crop, x_crop, indexing="ij")
-        base_grid = torch.stack([grid_x, grid_y], dim=-1)
-        self.register_buffer("base_grid", base_grid, persistent=False)
-
-    def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
-        """Run top-down inference and return fixed-size outputs."""
-        image = self._normalize_uint8(image)
-        batch_size, channels, height, width = image.shape
-
-        scaled_image = image
-        if self.centroid_input_scale != 1.0:
-            scaled_h = int(height * self.centroid_input_scale)
-            scaled_w = int(width * self.centroid_input_scale)
-            scaled_image = F.interpolate(
-                scaled_image,
-                size=(scaled_h, scaled_w),
-                mode="bilinear",
-                align_corners=False,
-            )
-
-        centroid_out = self.centroid_model(scaled_image)
-        centroid_cms = self._extract_tensor(centroid_out, ["centroid", "confmap"])
-
-        centroids, centroid_vals, instance_valid = self._find_topk_peaks(
-            centroid_cms, self.max_instances
-        )
-        centroids = centroids * (
-            self.centroid_output_stride / self.centroid_input_scale
-        )
-
-        crops = self._extract_crops(image, centroids)
-        crops_flat = crops.reshape(
-            batch_size * self.max_instances,
-            channels,
-            self.crop_size[0],
-            self.crop_size[1],
-        )
-
-        if self.instance_input_scale != 1.0:
-            scaled_h = int(self.crop_size[0] * self.instance_input_scale)
-            scaled_w = int(self.crop_size[1] * self.instance_input_scale)
-            crops_flat = F.interpolate(
-                crops_flat,
-                size=(scaled_h, scaled_w),
-                mode="bilinear",
-                align_corners=False,
-            )
-
-        instance_out = self.instance_model(crops_flat)
-        instance_cms = self._extract_tensor(
-            instance_out, ["centered", "instance", "confmap"]
-        )
-
-        crop_peaks, crop_peak_vals = self._find_global_peaks(instance_cms)
-        crop_peaks = crop_peaks * (
-            self.instance_output_stride / self.instance_input_scale
-        )
-
-        crop_peaks = crop_peaks.reshape(batch_size, self.max_instances, self.n_nodes, 2)
-        peak_vals = crop_peak_vals.reshape(batch_size, self.max_instances, self.n_nodes)
-
-        crop_offset = centroids.unsqueeze(2) - image.new_tensor(
-            [self.crop_size[1] / 2.0, self.crop_size[0] / 2.0]
-        )
-        peaks = crop_peaks + crop_offset
-
-        invalid_mask = ~instance_valid
-        centroids = centroids.masked_fill(invalid_mask.unsqueeze(-1), 0.0)
-        centroid_vals = centroid_vals.masked_fill(invalid_mask, 0.0)
-        peaks = peaks.masked_fill(invalid_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
-        peak_vals = peak_vals.masked_fill(invalid_mask.unsqueeze(-1), 0.0)
-
-        return {
-            "centroids": centroids,
-            "centroid_vals": centroid_vals,
-            "peaks": peaks,
-            "peak_vals": peak_vals,
-            "instance_valid": instance_valid,
-        }
-
-    def _extract_crops(
-        self,
-        image: torch.Tensor,
-        centroids: torch.Tensor,
-    ) -> torch.Tensor:
-        """Extract crops around centroids using grid_sample."""
-        batch_size, channels, height, width = image.shape
-        crop_h, crop_w = self.crop_size
-        n_instances = centroids.shape[1]
-
-        scale_x = crop_w / width
-        scale_y = crop_h / height
-        scale = image.new_tensor([scale_x, scale_y])
-        base_grid = self.base_grid.to(device=image.device, dtype=image.dtype)
-        scaled_grid = base_grid * scale
-
-        scaled_grid = scaled_grid.unsqueeze(0).unsqueeze(0)
-        scaled_grid = scaled_grid.expand(batch_size, n_instances, -1, -1, -1)
-
-        norm_centroids = torch.zeros_like(centroids)
-        norm_centroids[..., 0] = (centroids[..., 0] / (width - 1)) * 2 - 1
-        norm_centroids[..., 1] = (centroids[..., 1] / (height - 1)) * 2 - 1
-        offset = norm_centroids.unsqueeze(2).unsqueeze(2)
-
-        sample_grid = scaled_grid + offset
-
-        image_expanded = image.unsqueeze(1).expand(-1, n_instances, -1, -1, -1)
-        image_flat = image_expanded.reshape(
-            batch_size * n_instances, channels, height, width
-        )
-        grid_flat = sample_grid.reshape(batch_size * n_instances, crop_h, crop_w, 2)
-
-        crops_flat = F.grid_sample(
-            image_flat,
-            grid_flat,
-            mode="bilinear",
-            padding_mode="zeros",
-            align_corners=True,
-        )
-
-        crops = crops_flat.reshape(batch_size, n_instances, channels, crop_h, crop_w)
-        return crops
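`_extract_crops` above avoids data-dependent control flow by shifting one normalized base grid to every centroid and sampling all crops in a single batched `F.grid_sample` call. A minimal sketch of the same grid math on a dummy image (all sizes illustrative):

import torch
import torch.nn.functional as F

batch, channels, height, width = 1, 1, 256, 256
crop_h, crop_w = 64, 64
image = torch.arange(height * width, dtype=torch.float32).reshape(1, 1, height, width)
centroids = torch.tensor([[[64.0, 64.0], [192.0, 128.0]]])  # (batch, n_instances, 2) in (x, y)
n_instances = centroids.shape[1]

# Base grid spanning the crop window, scaled from crop size to image size.
ys = torch.linspace(-1, 1, crop_h)
xs = torch.linspace(-1, 1, crop_w)
gy, gx = torch.meshgrid(ys, xs, indexing="ij")
base_grid = torch.stack([gx, gy], dim=-1) * torch.tensor([crop_w / width, crop_h / height])

# Shift the window so it is centered on each centroid (normalized coordinates).
norm_c = torch.stack(
    [(centroids[..., 0] / (width - 1)) * 2 - 1, (centroids[..., 1] / (height - 1)) * 2 - 1],
    dim=-1,
)
grid = base_grid.expand(batch, n_instances, crop_h, crop_w, 2) + norm_c[:, :, None, None, :]

# Sample every crop in one call: (batch * n_instances, channels, crop_h, crop_w).
crops = F.grid_sample(
    image.expand(batch, n_instances, channels, height, width).reshape(-1, channels, height, width),
    grid.reshape(-1, crop_h, crop_w, 2),
    mode="bilinear",
    align_corners=True,
)
print(crops.shape)  # torch.Size([2, 1, 64, 64])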