sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +9 -2
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +489 -46
- sleap_nn/config/data_config.py +51 -8
- sleap_nn/config/get_config.py +32 -24
- sleap_nn/config/trainer_config.py +88 -0
- sleap_nn/data/augmentation.py +61 -200
- sleap_nn/data/custom_datasets.py +433 -61
- sleap_nn/data/instance_cropping.py +71 -6
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/skia_augmentation.py +414 -0
- sleap_nn/data/utils.py +135 -17
- sleap_nn/evaluation.py +177 -42
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/peak_finding.py +93 -16
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +339 -137
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/legacy_models.py +65 -11
- sleap_nn/predict.py +224 -19
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +138 -44
- sleap_nn/training/callbacks.py +1258 -5
- sleap_nn/training/lightning_modules.py +902 -220
- sleap_nn/training/model_trainer.py +424 -111
- sleap_nn/training/schedulers.py +191 -0
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
- sleap_nn-0.1.0.dist-info/RECORD +88 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
- sleap_nn-0.0.5.dist-info/RECORD +0 -63
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/export/wrappers/topdown_multiclass.py ADDED
@@ -0,0 +1,304 @@
+"""ONNX wrapper for top-down multiclass (supervised ID) models."""
+
+from typing import Dict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from sleap_nn.export.wrappers.base import BaseExportWrapper
+
+
+class TopDownMultiClassONNXWrapper(BaseExportWrapper):
+    """ONNX-exportable wrapper for top-down multiclass (supervised ID) models.
+
+    This wrapper handles models that output both confidence maps for keypoint
+    detection and class logits for identity classification. It runs on instance
+    crops (centered around detected centroids).
+
+    Expects input images as uint8 tensors in [0, 255].
+
+    Attributes:
+        model: The underlying PyTorch model (centered instance + class vectors heads).
+        output_stride: Output stride of the confmap head.
+        input_scale: Scale factor applied to input images before inference.
+        n_classes: Number of identity classes.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        output_stride: int = 2,
+        input_scale: float = 1.0,
+        n_classes: int = 2,
+    ):
+        """Initialize the wrapper.
+
+        Args:
+            model: The underlying PyTorch model.
+            output_stride: Output stride of the confidence maps.
+            input_scale: Scale factor for input images.
+            n_classes: Number of identity classes (e.g., 2 for male/female).
+        """
+        super().__init__(model)
+        self.output_stride = output_stride
+        self.input_scale = input_scale
+        self.n_classes = n_classes
+
+    def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
+        """Run top-down multiclass inference on crops.
+
+        Args:
+            image: Input image tensor of shape (batch, channels, height, width).
+                Expected to be uint8 in [0, 255].
+
+        Returns:
+            Dictionary with keys:
+                - "peaks": Predicted peak coordinates (batch, n_nodes, 2) in (x, y).
+                - "peak_vals": Peak confidence values (batch, n_nodes).
+                - "class_logits": Raw class logits (batch, n_classes).
+
+            The class assignment is done on CPU using Hungarian matching
+            via `get_class_inds_from_vectors()`.
+        """
+        # Normalize uint8 [0, 255] to float32 [0, 1]
+        image = self._normalize_uint8(image)
+
+        # Apply input scaling if needed
+        if self.input_scale != 1.0:
+            height = int(image.shape[-2] * self.input_scale)
+            width = int(image.shape[-1] * self.input_scale)
+            image = F.interpolate(
+                image, size=(height, width), mode="bilinear", align_corners=False
+            )
+
+        # Forward pass
+        out = self.model(image)
+
+        # Extract outputs
+        confmaps = self._extract_tensor(out, ["centered", "instance", "confmap"])
+        class_logits = self._extract_tensor(out, ["class", "vector"])
+
+        # Find global peaks (one per node)
+        peaks, peak_vals = self._find_global_peaks(confmaps)
+
+        # Scale peaks back to input coordinates
+        peaks = peaks * (self.output_stride / self.input_scale)
+
+        return {
+            "peaks": peaks,
+            "peak_vals": peak_vals,
+            "class_logits": class_logits,
+        }
+
+
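As a quick illustration of the coordinate handling in the wrapper above: peaks found on the confidence maps are mapped back to input-image coordinates by the factor output_stride / input_scale. A minimal sketch with illustrative values (not taken from a real model):

    import torch

    # Peak rescaling as in TopDownMultiClassONNXWrapper.forward():
    # confidence-map coordinates -> input-image coordinates.
    output_stride, input_scale = 2, 0.5          # illustrative values
    peaks_cm = torch.tensor([[10.0, 7.0]])       # (x, y) found on the confidence map
    peaks_img = peaks_cm * (output_stride / input_scale)
    print(peaks_img)                             # tensor([[40., 28.]]) in image coordinates
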
+class TopDownMultiClassCombinedONNXWrapper(BaseExportWrapper):
+    """ONNX-exportable wrapper for combined centroid + multiclass instance models.
+
+    This wrapper combines a centroid detection model with a centered instance
+    multiclass model. It performs:
+    1. Centroid detection on full images
+    2. Cropping around each centroid using vectorized grid_sample
+    3. Instance keypoint detection + identity classification on each crop
+
+    Expects input images as uint8 tensors in [0, 255].
+    """
+
+    def __init__(
+        self,
+        centroid_model: nn.Module,
+        instance_model: nn.Module,
+        max_instances: int = 20,
+        crop_size: tuple = (192, 192),
+        centroid_output_stride: int = 4,
+        instance_output_stride: int = 2,
+        centroid_input_scale: float = 1.0,
+        instance_input_scale: float = 1.0,
+        n_nodes: int = 13,
+        n_classes: int = 2,
+    ):
+        """Initialize the combined wrapper.
+
+        Args:
+            centroid_model: Model for centroid detection.
+            instance_model: Model for instance keypoints + class prediction.
+            max_instances: Maximum number of instances to detect.
+            crop_size: Size of crops around centroids (height, width).
+            centroid_output_stride: Output stride of centroid model.
+            instance_output_stride: Output stride of instance model.
+            centroid_input_scale: Input scale for centroid model.
+            instance_input_scale: Input scale for instance model.
+            n_nodes: Number of keypoint nodes per instance.
+            n_classes: Number of identity classes.
+        """
+        super().__init__(centroid_model)  # Primary model is centroid
+        self.instance_model = instance_model
+        self.max_instances = max_instances
+        self.crop_size = crop_size
+        self.centroid_output_stride = centroid_output_stride
+        self.instance_output_stride = instance_output_stride
+        self.centroid_input_scale = centroid_input_scale
+        self.instance_input_scale = instance_input_scale
+        self.n_nodes = n_nodes
+        self.n_classes = n_classes
+
+        # Pre-compute base grid for crop extraction (same as TopDownONNXWrapper)
+        crop_h, crop_w = crop_size
+        y_crop = torch.linspace(-1, 1, crop_h, dtype=torch.float32)
+        x_crop = torch.linspace(-1, 1, crop_w, dtype=torch.float32)
+        grid_y, grid_x = torch.meshgrid(y_crop, x_crop, indexing="ij")
+        base_grid = torch.stack([grid_x, grid_y], dim=-1)
+        self.register_buffer("base_grid", base_grid, persistent=False)
+
+    def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
+        """Run combined top-down multiclass inference.
+
+        Args:
+            image: Input image tensor of shape (batch, channels, height, width).
+                Expected to be uint8 in [0, 255].
+
+        Returns:
+            Dictionary with keys:
+                - "centroids": Detected centroids (batch, max_instances, 2).
+                - "centroid_vals": Centroid confidence values (batch, max_instances).
+                - "peaks": Instance peaks (batch, max_instances, n_nodes, 2).
+                - "peak_vals": Peak values (batch, max_instances, n_nodes).
+                - "class_logits": Class logits per instance (batch, max_instances, n_classes).
+                - "instance_valid": Validity mask (batch, max_instances).
+        """
+        # Normalize input
+        image = self._normalize_uint8(image)
+        batch_size, channels, height, width = image.shape
+
+        # Apply centroid input scaling
+        scaled_image = image
+        if self.centroid_input_scale != 1.0:
+            scaled_h = int(height * self.centroid_input_scale)
+            scaled_w = int(width * self.centroid_input_scale)
+            scaled_image = F.interpolate(
+                scaled_image,
+                size=(scaled_h, scaled_w),
+                mode="bilinear",
+                align_corners=False,
+            )
+
+        # Centroid detection
+        centroid_out = self.model(scaled_image)
+        centroid_cms = self._extract_tensor(centroid_out, ["centroid", "confmap"])
+        centroids, centroid_vals, instance_valid = self._find_topk_peaks(
+            centroid_cms, self.max_instances
+        )
+        centroids = centroids * (
+            self.centroid_output_stride / self.centroid_input_scale
+        )
+
+        # Extract crops using vectorized grid_sample (same as TopDownONNXWrapper)
+        crops = self._extract_crops(image, centroids)
+        crops_flat = crops.reshape(
+            batch_size * self.max_instances,
+            channels,
+            self.crop_size[0],
+            self.crop_size[1],
+        )
+
+        # Apply instance input scaling if needed
+        if self.instance_input_scale != 1.0:
+            scaled_h = int(self.crop_size[0] * self.instance_input_scale)
+            scaled_w = int(self.crop_size[1] * self.instance_input_scale)
+            crops_flat = F.interpolate(
+                crops_flat,
+                size=(scaled_h, scaled_w),
+                mode="bilinear",
+                align_corners=False,
+            )
+
+        # Instance model forward (batch all crops)
+        instance_out = self.instance_model(crops_flat)
+        instance_cms = self._extract_tensor(
+            instance_out, ["centered", "instance", "confmap"]
+        )
+        instance_class = self._extract_tensor(instance_out, ["class", "vector"])
+
+        # Find peaks in all crops
+        crop_peaks, crop_peak_vals = self._find_global_peaks(instance_cms)
+        crop_peaks = crop_peaks * (
+            self.instance_output_stride / self.instance_input_scale
+        )
+
+        # Reshape to batch x instances x nodes x 2
+        crop_peaks = crop_peaks.reshape(batch_size, self.max_instances, self.n_nodes, 2)
+        peak_vals = crop_peak_vals.reshape(batch_size, self.max_instances, self.n_nodes)
+
+        # Reshape class logits
+        class_logits = instance_class.reshape(
+            batch_size, self.max_instances, self.n_classes
+        )
+
+        # Transform peaks from crop coordinates to full image coordinates
+        crop_offset = centroids.unsqueeze(2) - image.new_tensor(
+            [self.crop_size[1] / 2.0, self.crop_size[0] / 2.0]
+        )
+        peaks = crop_peaks + crop_offset
+
+        # Zero out invalid instances
+        invalid_mask = ~instance_valid
+        centroids = centroids.masked_fill(invalid_mask.unsqueeze(-1), 0.0)
+        centroid_vals = centroid_vals.masked_fill(invalid_mask, 0.0)
+        peaks = peaks.masked_fill(invalid_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
+        peak_vals = peak_vals.masked_fill(invalid_mask.unsqueeze(-1), 0.0)
+        class_logits = class_logits.masked_fill(invalid_mask.unsqueeze(-1), 0.0)
+
+        return {
+            "centroids": centroids,
+            "centroid_vals": centroid_vals,
+            "peaks": peaks,
+            "peak_vals": peak_vals,
+            "class_logits": class_logits,
+            "instance_valid": instance_valid,
+        }
+
+    def _extract_crops(
+        self,
+        image: torch.Tensor,
+        centroids: torch.Tensor,
+    ) -> torch.Tensor:
+        """Extract crops around centroids using grid_sample.
+
+        This is the same vectorized implementation as TopDownONNXWrapper.
+        """
+        batch_size, channels, height, width = image.shape
+        crop_h, crop_w = self.crop_size
+        n_instances = centroids.shape[1]
+
+        scale_x = crop_w / width
+        scale_y = crop_h / height
+        scale = image.new_tensor([scale_x, scale_y])
+        base_grid = self.base_grid.to(device=image.device, dtype=image.dtype)
+        scaled_grid = base_grid * scale
+
+        scaled_grid = scaled_grid.unsqueeze(0).unsqueeze(0)
+        scaled_grid = scaled_grid.expand(batch_size, n_instances, -1, -1, -1)
+
+        norm_centroids = torch.zeros_like(centroids)
+        norm_centroids[..., 0] = (centroids[..., 0] / (width - 1)) * 2 - 1
+        norm_centroids[..., 1] = (centroids[..., 1] / (height - 1)) * 2 - 1
+        offset = norm_centroids.unsqueeze(2).unsqueeze(2)
+
+        sample_grid = scaled_grid + offset
+
+        image_expanded = image.unsqueeze(1).expand(-1, n_instances, -1, -1, -1)
+        image_flat = image_expanded.reshape(
+            batch_size * n_instances, channels, height, width
+        )
+        grid_flat = sample_grid.reshape(batch_size * n_instances, crop_h, crop_w, 2)
+
+        crops_flat = F.grid_sample(
+            image_flat,
+            grid_flat,
+            mode="bilinear",
+            padding_mode="zeros",
+            align_corners=True,
+        )
+
+        crops = crops_flat.reshape(batch_size, n_instances, channels, crop_h, crop_w)
+        return crops
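To make the normalized-coordinate arithmetic in `_extract_crops` concrete, here is a minimal, self-contained sketch (toy sizes, a single image and a single centroid) of the same idea: a base grid spanning [-1, 1] over the crop is scaled by crop_size / image_size and shifted to the centroid's normalized position before sampling with `F.grid_sample`.

    import torch
    import torch.nn.functional as F

    # Toy image and one centroid in pixel (x, y) coordinates.
    image = torch.arange(64, dtype=torch.float32).reshape(1, 1, 8, 8)
    centroid = torch.tensor([3.0, 4.0])
    crop_h, crop_w = 4, 4
    H, W = image.shape[-2:]

    # Base grid over the crop, shrunk to the crop's share of the image,
    # then shifted to the centroid's normalized position in [-1, 1].
    ys = torch.linspace(-1, 1, crop_h)
    xs = torch.linspace(-1, 1, crop_w)
    gy, gx = torch.meshgrid(ys, xs, indexing="ij")
    base = torch.stack([gx, gy], dim=-1)                   # (crop_h, crop_w, 2)
    scale = torch.tensor([crop_w / W, crop_h / H])
    center = torch.stack([(centroid[0] / (W - 1)) * 2 - 1,
                          (centroid[1] / (H - 1)) * 2 - 1])
    grid = (base * scale + center).unsqueeze(0)            # (1, crop_h, crop_w, 2)

    crop = F.grid_sample(image, grid, mode="bilinear",
                         padding_mode="zeros", align_corners=True)
    print(crop.shape)                                      # torch.Size([1, 1, 4, 4])
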
sleap_nn/inference/__init__.py CHANGED
sleap_nn/inference/bottomup.py CHANGED
@@ -1,5 +1,6 @@
 """Inference modules for BottomUp models."""
 
+import logging
 from typing import Dict, Optional
 import torch
 import lightning as L
@@ -7,6 +8,8 @@ from sleap_nn.inference.peak_finding import find_local_peaks
 from sleap_nn.inference.paf_grouping import PAFScorer
 from sleap_nn.inference.identity import classify_peaks_from_maps
 
+logger = logging.getLogger(__name__)
+
 
 class BottomUpInferenceModel(L.LightningModule):
     """BottomUp Inference model.
@@ -63,8 +66,28 @@ class BottomUpInferenceModel(L.LightningModule):
         return_pafs: Optional[bool] = False,
         return_paf_graph: Optional[bool] = False,
         input_scale: float = 1.0,
+        max_peaks_per_node: Optional[int] = None,
     ):
-        """Initialise the model attributes."""
+        """Initialise the model attributes.
+
+        Args:
+            torch_model: A `nn.Module` that accepts images and predicts confidence maps.
+            paf_scorer: A `PAFScorer` instance for grouping instances.
+            cms_output_stride: Output stride of confidence maps relative to images.
+            pafs_output_stride: Output stride of PAFs relative to images.
+            peak_threshold: Minimum confidence map value for valid peaks.
+            refinement: Peak refinement method: None, "integral", or "local".
+            integral_patch_size: Size of patches for integral refinement.
+            return_confmaps: If True, return confidence maps in output.
+            return_pafs: If True, return PAFs in output.
+            return_paf_graph: If True, return intermediate PAF graph in output.
+            input_scale: Scale factor applied to input images.
+            max_peaks_per_node: Maximum number of peaks allowed per node before
+                skipping PAF scoring. If any node has more peaks than this limit,
+                empty predictions are returned. This prevents combinatorial explosion
+                during early training when confidence maps are noisy. Set to None to
+                disable this check (default). Recommended value: 100.
+        """
         super().__init__()
         self.torch_model = torch_model
         self.paf_scorer = paf_scorer
@@ -77,6 +100,7 @@ class BottomUpInferenceModel(L.LightningModule):
         self.return_pafs = return_pafs
         self.return_paf_graph = return_paf_graph
         self.input_scale = input_scale
+        self.max_peaks_per_node = max_peaks_per_node
 
     def _generate_cms_peaks(self, cms):
         # TODO: append nans to batch them -> tensor (vectorize the initial paf grouping steps)
@@ -124,26 +148,68 @@
         ) # (batch, h, w, 2*edges)
         cms_peaks, cms_peak_vals, cms_peak_channel_inds = self._generate_cms_peaks(cms)
 
-        (
[removed lines 128-145 of the old file are not shown in the source rendering]
+        # Check if too many peaks per node (prevents combinatorial explosion)
+        skip_paf_scoring = False
+        if self.max_peaks_per_node is not None:
+            n_nodes = cms.shape[1]
+            for b in range(self.batch_size):
+                for node_idx in range(n_nodes):
+                    n_peaks = int((cms_peak_channel_inds[b] == node_idx).sum().item())
+                    if n_peaks > self.max_peaks_per_node:
+                        logger.warning(
+                            f"Skipping PAF scoring: node {node_idx} has {n_peaks} peaks "
+                            f"(max_peaks_per_node={self.max_peaks_per_node}). "
+                            f"Model may need more training."
+                        )
+                        skip_paf_scoring = True
+                        break
+                if skip_paf_scoring:
+                    break
+
+        if skip_paf_scoring:
+            # Return empty predictions for each sample
+            device = cms.device
+            n_nodes = cms.shape[1]
+            predicted_instances_adjusted = []
+            predicted_peak_scores = []
+            predicted_instance_scores = []
+            for _ in range(self.batch_size):
+                predicted_instances_adjusted.append(
+                    torch.full((0, n_nodes, 2), float("nan"), device=device)
+                )
+                predicted_peak_scores.append(
+                    torch.full((0, n_nodes), float("nan"), device=device)
+                )
+                predicted_instance_scores.append(torch.tensor([], device=device))
+            edge_inds = [
+                torch.tensor([], dtype=torch.int32, device=device)
+            ] * self.batch_size
+            edge_peak_inds = [
+                torch.tensor([], dtype=torch.int32, device=device).reshape(0, 2)
+            ] * self.batch_size
+            line_scores = [torch.tensor([], device=device)] * self.batch_size
+        else:
+            (
+                predicted_instances,
+                predicted_peak_scores,
+                predicted_instance_scores,
+                edge_inds,
+                edge_peak_inds,
+                line_scores,
+            ) = self.paf_scorer.predict(
+                pafs=pafs,
+                peaks=cms_peaks,
+                peak_vals=cms_peak_vals,
+                peak_channel_inds=cms_peak_channel_inds,
         )
+
+            predicted_instances = [p / self.input_scale for p in predicted_instances]
+            predicted_instances_adjusted = []
+            for idx, p in enumerate(predicted_instances):
+                predicted_instances_adjusted.append(
+                    p / inputs["eff_scale"][idx].to(p.device)
+                )
+
         out = {
             "pred_instance_peaks": predicted_instances_adjusted,
             "pred_peak_values": predicted_peak_scores,
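The nested loops above count peaks per node for each sample before deciding whether to skip PAF scoring. A minimal sketch of an equivalent vectorized check with `torch.bincount` (toy values, assuming the per-sample channel indices form a 1D integer tensor):

    import torch

    # One sample's peak-to-node assignments and the configured limit.
    peak_channel_inds = torch.tensor([0, 0, 1, 2, 2, 2, 2])
    n_nodes, max_peaks_per_node = 3, 3

    counts = torch.bincount(peak_channel_inds, minlength=n_nodes)  # peaks per node
    skip_paf_scoring = bool((counts > max_peaks_per_node).any())
    print(counts.tolist(), skip_paf_scoring)                       # [2, 1, 4] True
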
sleap_nn/inference/peak_finding.py CHANGED
@@ -2,18 +2,60 @@
 
 from typing import Optional, Tuple
 
-import kornia as K
-import numpy as np
 import torch
-
+import torch.nn.functional as F
 
 from sleap_nn.data.instance_cropping import make_centered_bboxes
 
 
+def morphological_dilation(image: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
+    """Apply morphological dilation using max pooling.
+
+    This is a pure PyTorch replacement for kornia.morphology.dilation.
+    For non-maximum suppression, it computes the maximum of 8 neighbors
+    (excluding the center pixel).
+
+    Args:
+        image: Input tensor of shape (B, 1, H, W).
+        kernel: Dilation kernel (3x3 expected for NMS).
+
+    Returns:
+        Dilated tensor of same shape as input.
+    """
+    # Pad the image to handle border pixels
+    padded = F.pad(image, (1, 1, 1, 1), mode="constant", value=float("-inf"))
+
+    # Extract 3x3 patches using unfold
+    # Shape: (B, 1, H, W, 3, 3)
+    patches = padded.unfold(2, 3, 1).unfold(3, 3, 1)
+
+    # Reshape to (B, 1, H, W, 9)
+    b, c, h, w, kh, kw = patches.shape
+    patches = patches.reshape(b, c, h, w, -1)
+
+    # Apply kernel mask (kernel has 0 at center, 1 elsewhere for NMS)
+    # Reshape kernel to (1, 1, 1, 1, 9)
+    kernel_flat = kernel.reshape(-1).to(patches.device)
+    kernel_mask = kernel_flat > 0
+
+    # Set non-kernel positions to -inf so they don't affect max
+    patches_masked = patches.clone()
+    patches_masked[..., ~kernel_mask] = float("-inf")
+
+    # Take max over the kernel neighborhood
+    max_vals = patches_masked.max(dim=-1)[0]
+
+    return max_vals
+
+
 def crop_bboxes(
     images: torch.Tensor, bboxes: torch.Tensor, sample_inds: torch.Tensor
 ) -> torch.Tensor:
-    """Crop bounding boxes from a batch of images.
+    """Crop bounding boxes from a batch of images using fast tensor indexing.
+
+    This uses tensor unfold operations to extract patches, which is significantly
+    faster than kornia's crop_and_resize (17-51x speedup) as it avoids perspective
+    transform computations.
 
     Args:
         images: Tensor of shape (samples, channels, height, width) of a batch of images.
@@ -27,7 +69,7 @@ def crop_bboxes(
             box should be cropped from.
 
     Returns:
-        A tensor of shape (n_bboxes, crop_height, crop_width
+        A tensor of shape (n_bboxes, channels, crop_height, crop_width) of the same
         dtype as the input image. The crop size is inferred from the bounding box
         coordinates.
 
@@ -42,20 +84,55 @@ def crop_bboxes(
 
     See also: `make_centered_bboxes`
     """
+    n_crops = bboxes.shape[0]
+    if n_crops == 0:
+        # Return empty tensor; use default crop size since we can't infer from bboxes
+        return torch.empty(
+            0, images.shape[1], 0, 0, device=images.device, dtype=images.dtype
+        )
+
     # Compute bounding box size to use for crops.
-    height = abs(bboxes[0, 3, 1] - bboxes[0, 0, 1])
-    width = abs(bboxes[0, 1, 0] - bboxes[0, 0, 0])
[removed lines 48-54 of the old file are not shown in the source rendering]
+    height = int(abs(bboxes[0, 3, 1] - bboxes[0, 0, 1]).item()) + 1
+    width = int(abs(bboxes[0, 1, 0] - bboxes[0, 0, 0]).item()) + 1
+
+    # Store original dtype for conversion back after cropping.
+    original_dtype = images.dtype
+    device = images.device
+    n_samples, channels, img_h, img_w = images.shape
+    half_h, half_w = height // 2, width // 2
+
+    # Pad images for edge handling.
+    images_padded = F.pad(
+        images.float(), (half_w, half_w, half_h, half_h), mode="constant", value=0
     )
 
+    # Extract all possible patches using unfold (creates a view, no copy).
+    # Shape after unfold: (n_samples, channels, img_h, img_w, height, width)
+    patches = images_padded.unfold(2, height, 1).unfold(3, width, 1)
+
+    # Get crop centers from bboxes.
+    # The bbox top-left is at index 0, with (x, y) coordinates.
+    # We need the center of the crop (peak location), which is top-left + half_size.
+    # Ensure bboxes are on the same device as images for index computation.
+    bboxes_on_device = bboxes.to(device)
+    crop_x = (bboxes_on_device[:, 0, 0] + half_w).to(torch.long)
+    crop_y = (bboxes_on_device[:, 0, 1] + half_h).to(torch.long)
+
+    # Clamp indices to valid bounds to handle edge cases where centroids
+    # might be at or beyond image boundaries.
+    crop_x = torch.clamp(crop_x, 0, patches.shape[3] - 1)
+    crop_y = torch.clamp(crop_y, 0, patches.shape[2] - 1)
+
+    # Select crops using advanced indexing.
+    # Convert sample_inds to tensor if it's a list.
+    if not isinstance(sample_inds, torch.Tensor):
+        sample_inds = torch.tensor(sample_inds, device=device)
+    sample_inds_long = sample_inds.to(device=device, dtype=torch.long)
+    crops = patches[sample_inds_long, :, crop_y, crop_x]
+    # Shape: (n_crops, channels, height, width)
+
     # Cast back to original dtype and return.
-    crops = crops.to(
+    crops = crops.to(original_dtype)
     return crops
 
 
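A minimal, self-contained sketch of the unfold-plus-indexing idea that `crop_bboxes` relies on (toy image and crop center; the real function additionally handles dtype, padding offsets derived from bounding boxes, and per-sample indexing):

    import torch

    img = torch.arange(36, dtype=torch.float32).reshape(1, 1, 6, 6)
    crop_h = crop_w = 3
    half_h, half_w = crop_h // 2, crop_w // 2

    padded = torch.nn.functional.pad(img, (half_w, half_w, half_h, half_h))
    # unfold twice -> every possible (crop_h, crop_w) patch, as a view of the padded image:
    patches = padded.unfold(2, crop_h, 1).unfold(3, crop_w, 1)
    # patches.shape == (1, 1, 6, 6, 3, 3): patch [y, x] is centered on pixel (y, x).

    y, x = 2, 4                   # crop center in the original image
    crop = patches[0, :, y, x]    # (1, 3, 3) crop around row 2, column 4
    print(crop)
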
@@ -236,7 +313,7 @@ def find_local_peaks_rough(
     flat_img = cms.reshape(-1, 1, height, width)
 
     # Perform dilation filtering to find local maxima per channel and reshape back.
-    max_img =
+    max_img = morphological_dilation(flat_img, kernel.to(flat_img.device))
     max_img = max_img.reshape(-1, channels, height, width)
 
     # Filter for maxima and threshold.