sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +1 -1
- sleap_nn/cli.py +36 -0
- sleap_nn/evaluation.py +8 -0
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/predict.py +29 -0
- sleap_nn/train.py +64 -0
- sleap_nn/training/callbacks.py +62 -20
- sleap_nn/training/lightning_modules.py +332 -30
- sleap_nn/training/model_trainer.py +35 -67
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +12 -1
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +35 -14
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
sleap_nn/export/utils.py
ADDED
@@ -0,0 +1,307 @@
"""Utilities for export workflows."""

from __future__ import annotations

from pathlib import Path
from typing import List, Optional, Tuple

from omegaconf import DictConfig, OmegaConf

from sleap_nn.config.training_job_config import TrainingJobConfig
from sleap_nn.config.utils import get_backbone_type_from_cfg, get_model_type_from_cfg


def load_training_config(model_dir: str | Path) -> DictConfig:
    """Load training configuration from a model directory."""
    model_dir = Path(model_dir)
    yaml_path = model_dir / "training_config.yaml"
    json_path = model_dir / "training_config.json"

    if yaml_path.exists():
        return OmegaConf.load(yaml_path.as_posix())
    if json_path.exists():
        return TrainingJobConfig.load_sleap_config(json_path.as_posix())

    raise FileNotFoundError(
        f"No training_config.yaml or training_config.json found in {model_dir}"
    )
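Together with the resolvers below, this gives a single entry point for turning a trained model directory into export parameters. A minimal usage sketch, assuming a trained model directory (the path here is hypothetical):

    from sleap_nn.export.utils import (
        load_training_config,
        resolve_input_scale,
        resolve_model_type,
    )

    cfg = load_training_config("models/my_topdown_run")  # hypothetical directory
    model_type = resolve_model_type(cfg)
    scale = resolve_input_scale(cfg)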


def resolve_input_scale(cfg: DictConfig) -> float:
    """Resolve preprocessing scale from config."""
    scale = cfg.data_config.preprocessing.scale
    # Check for list/tuple or OmegaConf ListConfig
    if isinstance(scale, (list, tuple)) or (
        hasattr(scale, "__iter__")
        and hasattr(scale, "__len__")
        and not isinstance(scale, str)
    ):
        return float(scale[0]) if len(scale) > 0 else 1.0
    return float(scale)


def resolve_input_channels(cfg: DictConfig) -> int:
    """Resolve input channels from backbone config."""
    backbone_type = get_backbone_type_from_cfg(cfg)
    return int(cfg.model_config.backbone_config[backbone_type].in_channels)


def resolve_output_stride(cfg: DictConfig, model_type: str) -> int:
    """Resolve output stride from head config."""
    head_cfg = cfg.model_config.head_configs[model_type]
    if head_cfg is None:
        return 1
    if hasattr(head_cfg, "confmaps") and head_cfg.confmaps is not None:
        return int(head_cfg.confmaps.output_stride)
    if hasattr(head_cfg, "pafs") and head_cfg.pafs is not None:
        return int(head_cfg.pafs.output_stride)
    return 1


def resolve_pafs_output_stride(cfg: DictConfig) -> int:
    """Resolve PAFs output stride for bottom-up models."""
    bottomup_cfg = getattr(cfg.model_config.head_configs, "bottomup", None)
    if bottomup_cfg is not None and bottomup_cfg.pafs is not None:
        return int(bottomup_cfg.pafs.output_stride)
    return 1


def resolve_class_maps_output_stride(cfg: DictConfig) -> int:
    """Resolve class maps output stride for multiclass bottom-up models."""
    mc_bottomup_cfg = getattr(
        cfg.model_config.head_configs, "multi_class_bottomup", None
    )
    if mc_bottomup_cfg is not None and mc_bottomup_cfg.class_maps is not None:
        return int(mc_bottomup_cfg.class_maps.output_stride)
    return 8


def resolve_class_names(cfg: DictConfig, model_type: str) -> List[str]:
    """Resolve class names for multiclass models."""
    head_cfg = cfg.model_config.head_configs.get(model_type)
    if head_cfg is None:
        return []

    # Top-down multiclass: class_vectors.classes
    if hasattr(head_cfg, "class_vectors") and head_cfg.class_vectors is not None:
        classes = getattr(head_cfg.class_vectors, "classes", None)
        if classes:
            return list(classes)

    # Bottom-up multiclass: class_maps.classes
    if hasattr(head_cfg, "class_maps") and head_cfg.class_maps is not None:
        classes = getattr(head_cfg.class_maps, "classes", None)
        if classes:
            return list(classes)

    return []


def resolve_n_classes(cfg: DictConfig, model_type: str) -> int:
    """Resolve number of classes for multiclass models."""
    class_names = resolve_class_names(cfg, model_type)
    return len(class_names) if class_names else 0


def resolve_crop_size(cfg: DictConfig) -> Optional[Tuple[int, int]]:
    """Resolve crop size from preprocessing config."""
    crop_size = cfg.data_config.preprocessing.crop_size
    if crop_size is None:
        return None
    # Check for list/tuple or OmegaConf ListConfig
    if isinstance(crop_size, (list, tuple)) or (
        hasattr(crop_size, "__iter__")
        and hasattr(crop_size, "__len__")
        and not isinstance(crop_size, (str, int))
    ):
        if len(crop_size) == 2:
            return int(crop_size[0]), int(crop_size[1])
        if len(crop_size) == 1:
            return int(crop_size[0]), int(crop_size[0])
    return int(crop_size), int(crop_size)
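Both resolve_input_scale and resolve_crop_size accept a scalar or a sequence (including an OmegaConf ListConfig), which is why they duck-type on __iter__/__len__ instead of checking isinstance against list alone. A small illustration with a hand-built config:

    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {"data_config": {"preprocessing": {"scale": [0.5, 0.5], "crop_size": 160}}}
    )
    # resolve_input_scale(cfg) -> 0.5 (first element of the sequence)
    # resolve_crop_size(cfg) -> (160, 160) (scalar broadcast to a square crop)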


def resolve_node_names(cfg: DictConfig, model_type: str) -> List[str]:
    """Resolve node names for metadata."""
    skeleton_nodes = _node_names_from_skeletons(cfg.data_config.skeletons)
    if skeleton_nodes:
        return skeleton_nodes

    head_cfg = cfg.model_config.head_configs.get(model_type)
    if head_cfg is None:
        return []

    if hasattr(head_cfg, "confmaps") and head_cfg.confmaps is not None:
        part_names = getattr(head_cfg.confmaps, "part_names", None)
        if part_names:
            return list(part_names)

    if model_type == "centroid":
        anchor = getattr(head_cfg.confmaps, "anchor_part", None) if head_cfg else None
        return [anchor] if anchor else ["centroid"]

    return []


def resolve_edge_inds(cfg: DictConfig, node_names: List[str]) -> List[Tuple[int, int]]:
    """Resolve edge indices for metadata."""
    edges = _edge_inds_from_skeletons(cfg.data_config.skeletons)
    if edges:
        return _normalize_edges(edges, node_names)

    bottomup_cfg = getattr(cfg.model_config.head_configs, "bottomup", None)
    if bottomup_cfg is not None and bottomup_cfg.pafs is not None:
        edges = bottomup_cfg.pafs.edges
        if edges:
            return _normalize_edges(edges, node_names)

    return []


def resolve_model_type(cfg: DictConfig) -> str:
    """Return model type from config."""
    return get_model_type_from_cfg(cfg)


def resolve_backbone_type(cfg: DictConfig) -> str:
    """Return backbone type from config."""
    return get_backbone_type_from_cfg(cfg)


def resolve_input_shape(
    cfg: DictConfig,
    input_height: Optional[int] = None,
    input_width: Optional[int] = None,
) -> Tuple[int, int, int, int]:
    """Resolve a dummy input shape for export."""
    channels = resolve_input_channels(cfg)
    height = input_height or cfg.data_config.preprocessing.max_height or 512
    width = input_width or cfg.data_config.preprocessing.max_width or 512
    return 1, channels, int(height), int(width)
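resolve_input_shape falls back to 512x512 when the config does not pin a maximum size, producing a static dummy shape for tracing. A hedged sketch of feeding it to torch.onnx.export (the wrapped model and output path are illustrative; the package's own export entry point lives in sleap_nn/export/cli.py):

    import torch

    shape = resolve_input_shape(cfg)  # e.g. (1, 1, 512, 512)
    dummy = torch.randint(0, 255, shape, dtype=torch.uint8)  # wrappers expect uint8
    torch.onnx.export(wrapped_model, dummy, "model.onnx", opset_version=17)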


def _node_names_from_skeletons(skeletons) -> List[str]:
    """Extract node names from the first skeleton (object- or dict-style)."""
    if not skeletons:
        return []
    skeleton = skeletons[0]
    if hasattr(skeleton, "nodes"):
        try:
            return [node.name for node in skeleton.nodes]
        except Exception:
            pass
    if isinstance(skeleton, dict):
        nodes = skeleton.get("nodes")
        if nodes:
            if isinstance(nodes[0], dict):
                return [node.get("name", "") for node in nodes if node.get("name")]
            return [str(node) for node in nodes]
        node_names = skeleton.get("node_names")
        if node_names:
            return [str(name) for name in node_names]
    return []


def _edge_inds_from_skeletons(skeletons) -> List:
    """Extract edge index pairs from the first skeleton (object- or dict-style)."""
    if not skeletons:
        return []
    skeleton = skeletons[0]
    if hasattr(skeleton, "edge_inds"):
        try:
            return list(skeleton.edge_inds)
        except Exception:
            pass
    if isinstance(skeleton, dict):
        edges = skeleton.get("edges") or skeleton.get("edge_inds")
        if edges:
            return list(edges)
    return []


def _normalize_edges(edges: List, node_names: List[str]) -> List[Tuple[int, int]]:
    """Convert edges given as names or indices to (src_idx, dst_idx) int tuples."""
    if not edges:
        return []
    if not node_names:
        return [(int(src), int(dst)) for src, dst in edges]

    if isinstance(edges[0][0], str):
        name_to_idx = {name: idx for idx, name in enumerate(node_names)}
        normalized = []
        for src, dst in edges:
            if src in name_to_idx and dst in name_to_idx:
                normalized.append((name_to_idx[src], name_to_idx[dst]))
        return normalized

    return [(int(src), int(dst)) for src, dst in edges]
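A concrete illustration of the name-to-index path (values chosen for illustration):

    node_names = ["head", "thorax", "abdomen"]
    edges = [("head", "thorax"), ("thorax", "abdomen")]
    _normalize_edges(edges, node_names)  # -> [(0, 1), (1, 2)]
    # Numeric edges pass through cast to plain ints; names missing from
    # node_names are silently dropped rather than raising.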


def build_bottomup_candidate_template(
    n_nodes: int, max_peaks_per_node: int, edge_inds: List[Tuple[int, int]]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
    """Build candidate template matching ONNX wrapper's line_scores ordering.

    The ONNX BottomUpONNXWrapper produces line_scores with shape (n_edges, k*k) where
    for each edge connecting (src_node, dst_node), position i*k + j corresponds to:
    - src peak flat index: src_node * k + i
    - dst peak flat index: dst_node * k + j

    This function builds edge_inds and edge_peak_inds tensors that match this exact
    ordering, so that line_scores_flat[idx] corresponds to edge_peak_inds[idx].

    Args:
        n_nodes: Number of nodes in the skeleton.
        max_peaks_per_node: Maximum peaks per node (k) used during export.
        edge_inds: List of (src_node, dst_node) tuples defining skeleton edges.

    Returns:
        Tuple of (peak_channel_inds, edge_inds_tensor, edge_peak_inds_tensor):
        - peak_channel_inds: (n_nodes * k,) tensor mapping flat peak index to node
        - edge_inds_tensor: (n_edges * k * k,) tensor of edge indices for each candidate
        - edge_peak_inds_tensor: (n_edges * k * k, 2) tensor of (src, dst) peak indices

    Example:
        >>> from sleap_nn.export.utils import build_bottomup_candidate_template
        >>> peak_ch, edge_inds, edge_peaks = build_bottomup_candidate_template(
        ...     n_nodes=15, max_peaks_per_node=20, edge_inds=[(1, 2), (1, 5)]
        ... )
        >>> # Use with ONNX output:
        >>> line_scores_flat = line_scores.reshape(-1)
        >>> valid_scores = line_scores_flat[valid_mask]
        >>> valid_edge_peaks = edge_peaks[valid_mask]

    Note:
        This function is necessary because `get_connection_candidates()` in
        `sleap_nn.inference.paf_grouping` uses unstable argsort, which shuffles
        peak indices within each node and breaks alignment with ONNX output ordering.
    """
    import torch

    k = max_peaks_per_node
    n_edges = len(edge_inds)

    # peak_channel_inds: [0,0,...0, 1,1,...1, ...] (k times each)
    peak_channel_inds = torch.arange(n_nodes, dtype=torch.int32).repeat_interleave(k)

    edge_inds_list = []
    edge_peak_inds_list = []

    for edge_idx, (src_node, dst_node) in enumerate(edge_inds):
        # Build k*k candidate pairs in row-major order (i*k + j)
        # src indices: [src_node*k + 0, src_node*k + 0, ..., src_node*k + 1, ...]
        # dst indices: [dst_node*k + 0, dst_node*k + 1, ..., dst_node*k + 0, ...]
        src_base = src_node * k
        dst_base = dst_node * k

        src_indices = torch.arange(k, dtype=torch.int32).repeat_interleave(k) + src_base
        dst_indices = torch.arange(k, dtype=torch.int32).repeat(k) + dst_base

        edge_inds_list.append(torch.full((k * k,), edge_idx, dtype=torch.int32))
        edge_peak_inds_list.append(torch.stack([src_indices, dst_indices], dim=1))

    if edge_inds_list:
        edge_inds_tensor = torch.cat(edge_inds_list)
        edge_peak_inds_tensor = torch.cat(edge_peak_inds_list)
    else:
        edge_inds_tensor = torch.empty((0,), dtype=torch.int32)
        edge_peak_inds_tensor = torch.empty((0, 2), dtype=torch.int32)

    return peak_channel_inds, edge_inds_tensor, edge_peak_inds_tensor
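A hedged sketch of lining the template up with the wrapper outputs on the CPU side, for one image in a batch (here outputs stands for the ONNX session results, keyed like the BottomUpONNXWrapper return dict shown further down; the actual grouping lives in sleap_nn.inference.paf_grouping):

    peak_ch, edge_inds_t, edge_peaks_t = build_bottomup_candidate_template(
        n_nodes=n_nodes, max_peaks_per_node=20, edge_inds=skeleton_edges
    )
    line_scores_flat = outputs["line_scores"][0].reshape(-1)  # (n_edges * k * k,)
    valid = outputs["candidate_mask"][0].reshape(-1)
    scores = line_scores_flat[valid]
    pairs = edge_peaks_t[valid]  # aligned with scores by construction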
sleap_nn/export/wrappers/__init__.py
ADDED
@@ -0,0 +1,25 @@
"""ONNX/TensorRT export wrappers."""

from sleap_nn.export.wrappers.base import BaseExportWrapper
from sleap_nn.export.wrappers.centroid import CentroidONNXWrapper
from sleap_nn.export.wrappers.centered_instance import CenteredInstanceONNXWrapper
from sleap_nn.export.wrappers.topdown import TopDownONNXWrapper
from sleap_nn.export.wrappers.bottomup import BottomUpONNXWrapper
from sleap_nn.export.wrappers.single_instance import SingleInstanceONNXWrapper
from sleap_nn.export.wrappers.topdown_multiclass import (
    TopDownMultiClassONNXWrapper,
    TopDownMultiClassCombinedONNXWrapper,
)
from sleap_nn.export.wrappers.bottomup_multiclass import BottomUpMultiClassONNXWrapper

__all__ = [
    "BaseExportWrapper",
    "CentroidONNXWrapper",
    "CenteredInstanceONNXWrapper",
    "TopDownONNXWrapper",
    "BottomUpONNXWrapper",
    "SingleInstanceONNXWrapper",
    "TopDownMultiClassONNXWrapper",
    "TopDownMultiClassCombinedONNXWrapper",
    "BottomUpMultiClassONNXWrapper",
]
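Since every wrapper is a plain nn.Module, exporting one is ordinary torch.onnx tracing. A minimal sketch for the centroid case; the constructor arguments beyond the model are not shown in this diff, so they are elided here:

    import torch
    from sleap_nn.export.wrappers import CentroidONNXWrapper

    wrapper = CentroidONNXWrapper(model)  # plus any stride/scale args the class takes
    dummy = torch.randint(0, 255, (1, 1, 512, 512), dtype=torch.uint8)
    torch.onnx.export(wrapper, dummy, "centroid.onnx", opset_version=17)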
sleap_nn/export/wrappers/base.py
ADDED
@@ -0,0 +1,96 @@
"""Base classes and shared helpers for export wrappers."""

from __future__ import annotations

from typing import Iterable, List, Tuple

import torch
from torch import nn
from torch.nn import functional as F


class BaseExportWrapper(nn.Module):
    """Base class for ONNX-exportable wrappers."""

    def __init__(self, model: nn.Module):
        """Initialize wrapper with the underlying model.

        Args:
            model: The PyTorch model to wrap for export.
        """
        super().__init__()
        self.model = model

    @staticmethod
    def _normalize_uint8(image: torch.Tensor) -> torch.Tensor:
        """Convert uint8 (or [0, 255] float) images to floats in [0, 1]."""
        if image.dtype != torch.float32:
            image = image.float()
        return image / 255.0

    @staticmethod
    def _extract_tensor(output, key_hints: Iterable[str]) -> torch.Tensor:
        """Pull a tensor out of a dict output by fuzzy key match (first value as fallback)."""
        if isinstance(output, dict):
            for key in output:
                for hint in key_hints:
                    if hint.lower() in key.lower():
                        return output[key]
            return next(iter(output.values()))
        return output

    @staticmethod
    def _find_topk_peaks(
        confmaps: torch.Tensor, k: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Top-K peak finding with NMS via max pooling."""
        batch_size, _, height, width = confmaps.shape
        pooled = F.max_pool2d(confmaps, kernel_size=3, stride=1, padding=1)
        is_peak = (confmaps == pooled) & (confmaps > 0)

        confmaps_flat = confmaps.reshape(batch_size, height * width)
        is_peak_flat = is_peak.reshape(batch_size, height * width)
        masked = torch.where(
            is_peak_flat, confmaps_flat, torch.full_like(confmaps_flat, -1e9)
        )
        values, indices = torch.topk(masked, k=k, dim=1)

        y = indices // width
        x = indices % width
        peaks = torch.stack([x.float(), y.float()], dim=-1)
        valid = values > 0
        return peaks, values, valid

    @staticmethod
    def _find_topk_peaks_per_node(
        confmaps: torch.Tensor, k: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Top-K peak finding per channel with NMS via max pooling."""
        batch_size, n_nodes, height, width = confmaps.shape
        pooled = F.max_pool2d(confmaps, kernel_size=3, stride=1, padding=1)
        is_peak = (confmaps == pooled) & (confmaps > 0)

        confmaps_flat = confmaps.reshape(batch_size, n_nodes, height * width)
        is_peak_flat = is_peak.reshape(batch_size, n_nodes, height * width)
        masked = torch.where(
            is_peak_flat, confmaps_flat, torch.full_like(confmaps_flat, -1e9)
        )
        values, indices = torch.topk(masked, k=k, dim=2)

        y = indices // width
        x = indices % width
        peaks = torch.stack([x.float(), y.float()], dim=-1)
        valid = values > 0
        return peaks, values, valid

    @staticmethod
    def _find_global_peaks(
        confmaps: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Find global maxima per channel."""
        batch_size, channels, height, width = confmaps.shape
        flat = confmaps.reshape(batch_size, channels, height * width)
        values, indices = flat.max(dim=-1)
        y = indices // width
        x = indices % width
        peaks = torch.stack([x.float(), y.float()], dim=-1)
        return peaks, values
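The max-pooling comparison is the standard ONNX-friendly form of non-maximum suppression: a pixel is a local peak iff it equals the 3x3 maximum centered on it. A standalone toy check of the idea:

    import torch
    import torch.nn.functional as F

    cm = torch.zeros(1, 1, 5, 5)
    cm[0, 0, 2, 2] = 1.0  # a single bump
    pooled = F.max_pool2d(cm, kernel_size=3, stride=1, padding=1)
    is_peak = (cm == pooled) & (cm > 0)
    assert is_peak[0, 0, 2, 2].item() and is_peak.sum().item() == 1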
sleap_nn/export/wrappers/bottomup.py
ADDED
@@ -0,0 +1,243 @@
"""Bottom-up ONNX wrapper."""

from __future__ import annotations

from typing import Dict, Tuple

import torch
from torch import nn
from torch.nn import functional as F

from sleap_nn.export.wrappers.base import BaseExportWrapper


class BottomUpONNXWrapper(BaseExportWrapper):
    """ONNX-exportable wrapper for bottom-up inference up to PAF scoring.

    Expects input images as uint8 tensors in [0, 255].
    """

    def __init__(
        self,
        model: nn.Module,
        skeleton_edges: list,
        n_nodes: int,
        max_peaks_per_node: int = 20,
        n_line_points: int = 10,
        cms_output_stride: int = 4,
        pafs_output_stride: int = 8,
        max_edge_length_ratio: float = 0.25,
        dist_penalty_weight: float = 1.0,
        input_scale: float = 1.0,
    ) -> None:
        """Initialize bottom-up ONNX wrapper.

        Args:
            model: Bottom-up model producing confidence maps and PAFs.
            skeleton_edges: List of (src, dst) edge tuples defining the skeleton.
            n_nodes: Number of nodes in the skeleton.
            max_peaks_per_node: Maximum peaks to detect per node type.
            n_line_points: Points to sample along PAF edges.
            cms_output_stride: Confidence map output stride.
            pafs_output_stride: PAF output stride.
            max_edge_length_ratio: Maximum edge length as a ratio of image size.
            dist_penalty_weight: Weight for distance penalty in scoring.
            input_scale: Input scaling factor.
        """
        super().__init__(model)
        self.n_nodes = n_nodes
        self.n_edges = len(skeleton_edges)
        self.max_peaks_per_node = max_peaks_per_node
        self.n_line_points = n_line_points
        self.cms_output_stride = cms_output_stride
        self.pafs_output_stride = pafs_output_stride
        self.max_edge_length_ratio = max_edge_length_ratio
        self.dist_penalty_weight = dist_penalty_weight
        self.input_scale = input_scale

        edge_src = torch.tensor([e[0] for e in skeleton_edges], dtype=torch.long)
        edge_dst = torch.tensor([e[1] for e in skeleton_edges], dtype=torch.long)
        self.register_buffer("edge_src", edge_src)
        self.register_buffer("edge_dst", edge_dst)

        line_samples = torch.linspace(0, 1, n_line_points, dtype=torch.float32)
        self.register_buffer("line_samples", line_samples)

    def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run bottom-up inference and return fixed-size outputs.

        Note: confmaps and pafs are NOT returned, to avoid a D2H transfer
        bottleneck. Peak detection and PAF scoring are performed on GPU
        within this wrapper.
        """
        image = self._normalize_uint8(image)
        if self.input_scale != 1.0:
            height = int(image.shape[-2] * self.input_scale)
            width = int(image.shape[-1] * self.input_scale)
            image = F.interpolate(
                image, size=(height, width), mode="bilinear", align_corners=False
            )

        batch_size, _, height, width = image.shape

        out = self.model(image)
        if isinstance(out, dict):
            confmaps = self._extract_tensor(out, ["confmap", "multiinstance"])
            pafs = self._extract_tensor(out, ["paf", "affinity"])
        else:
            confmaps, pafs = out[:2]

        peaks, peak_vals, peak_mask = self._find_topk_peaks_per_node(
            confmaps, self.max_peaks_per_node
        )

        peaks = peaks * self.cms_output_stride

        # Compute max_edge_length to match the PyTorch implementation:
        # max_edge_length = ratio * max(paf_dims) * pafs_stride
        # PAFs shape is (batch, 2*edges, H, W)
        _, n_paf_channels, paf_height, paf_width = pafs.shape
        max_paf_dim = max(n_paf_channels, paf_height, paf_width)
        max_edge_length = torch.tensor(
            self.max_edge_length_ratio * max_paf_dim * self.pafs_output_stride,
            dtype=peaks.dtype,
            device=peaks.device,
        )

        line_scores, candidate_mask = self._score_all_candidates(
            pafs, peaks, peak_mask, max_edge_length
        )

        # Only return final outputs needed for CPU-side grouping.
        # Do NOT return confmaps/pafs - they are large (~29 MB/batch) and
        # cause a D2H transfer bottleneck. Peak detection and PAF scoring
        # are already done on GPU above.
        return {
            "peaks": peaks,
            "peak_vals": peak_vals,
            "peak_mask": peak_mask,
            "line_scores": line_scores,
            "candidate_mask": candidate_mask,
        }
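For a fixed batch size and k = max_peaks_per_node, every output above has a static shape, so the exported graph carries no data-dependent output sizes (shapes follow from the code):

    # peaks:          (batch, n_nodes, k, 2)   xy, scaled back by cms_output_stride
    # peak_vals:      (batch, n_nodes, k)
    # peak_mask:      (batch, n_nodes, k)      True where a real peak was found
    # line_scores:    (batch, n_edges, k * k)  invalid candidates filled with -2.0
    # candidate_mask: (batch, n_edges, k * k)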

    def _score_all_candidates(
        self,
        pafs: torch.Tensor,
        peaks: torch.Tensor,
        peak_mask: torch.Tensor,
        max_edge_length: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Score all K*K candidate connections for each edge."""
        batch_size = peaks.shape[0]
        k = self.max_peaks_per_node
        n_edges = self.n_edges

        _, _, paf_height, paf_width = pafs.shape

        src_peaks = peaks[:, self.edge_src, :, :]
        dst_peaks = peaks[:, self.edge_dst, :, :]

        src_mask = peak_mask[:, self.edge_src, :]
        dst_mask = peak_mask[:, self.edge_dst, :]

        src_peaks_exp = src_peaks.unsqueeze(3).expand(-1, -1, -1, k, -1)
        dst_peaks_exp = dst_peaks.unsqueeze(2).expand(-1, -1, k, -1, -1)

        src_mask_exp = src_mask.unsqueeze(3).expand(-1, -1, -1, k)
        dst_mask_exp = dst_mask.unsqueeze(2).expand(-1, -1, k, -1)
        candidate_mask = src_mask_exp & dst_mask_exp

        src_peaks_flat = src_peaks_exp.reshape(batch_size, n_edges, k * k, 2)
        dst_peaks_flat = dst_peaks_exp.reshape(batch_size, n_edges, k * k, 2)
        candidate_mask_flat = candidate_mask.reshape(batch_size, n_edges, k * k)

        spatial_vecs = dst_peaks_flat - src_peaks_flat
        spatial_lengths = torch.norm(spatial_vecs, dim=-1, keepdim=True).clamp(min=1e-6)
        spatial_vecs_norm = spatial_vecs / spatial_lengths

        line_samples = self.line_samples.view(1, 1, 1, -1, 1)
        src_exp = src_peaks_flat.unsqueeze(3)
        dst_exp = dst_peaks_flat.unsqueeze(3)
        line_points = src_exp + line_samples * (dst_exp - src_exp)

        line_points_paf = line_points / self.pafs_output_stride
        line_x = line_points_paf[..., 0].clamp(0, paf_width - 1)
        line_y = line_points_paf[..., 1].clamp(0, paf_height - 1)

        line_scores = self._sample_and_score_lines(
            pafs,
            line_x,
            line_y,
            spatial_vecs_norm,
            spatial_lengths.squeeze(-1),
            max_edge_length,
        )

        line_scores = line_scores.masked_fill(~candidate_mask_flat, -2.0)
        return line_scores, candidate_mask_flat

    def _sample_and_score_lines(
        self,
        pafs: torch.Tensor,
        line_x: torch.Tensor,
        line_y: torch.Tensor,
        spatial_vecs_norm: torch.Tensor,
        spatial_lengths: torch.Tensor,
        max_edge_length: torch.Tensor,
    ) -> torch.Tensor:
        """Sample PAF values along lines and compute scores."""
        batch_size, n_edges, k2, n_points = line_x.shape
        _, _, paf_height, paf_width = pafs.shape

        all_scores = []
        for edge_idx in range(n_edges):
            paf_x = pafs[:, 2 * edge_idx, :, :]
            paf_y = pafs[:, 2 * edge_idx + 1, :, :]

            lx = line_x[:, edge_idx, :, :]
            ly = line_y[:, edge_idx, :, :]

            lx_norm = (lx / (paf_width - 1)) * 2 - 1
            ly_norm = (ly / (paf_height - 1)) * 2 - 1

            grid = torch.stack([lx_norm, ly_norm], dim=-1)

            paf_x_samples = F.grid_sample(
                paf_x.unsqueeze(1),
                grid,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=True,
            ).squeeze(1)

            paf_y_samples = F.grid_sample(
                paf_y.unsqueeze(1),
                grid,
                mode="bilinear",
                padding_mode="zeros",
                align_corners=True,
            ).squeeze(1)

            paf_samples = torch.stack([paf_x_samples, paf_y_samples], dim=-1)
            disp_vec = spatial_vecs_norm[:, edge_idx, :, :]

            dot_products = (paf_samples * disp_vec.unsqueeze(2)).sum(dim=-1)
            mean_scores = dot_products.mean(dim=-1)

            edge_lengths = spatial_lengths[:, edge_idx, :]
            dist_penalty = self._compute_distance_penalty(edge_lengths, max_edge_length)

            all_scores.append(mean_scores + dist_penalty)

        return torch.stack(all_scores, dim=1)
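grid_sample expects coordinates normalized to [-1, 1]; with align_corners=True the mapping is linear between pixel centers, so x_norm = (x / (W - 1)) * 2 - 1 sends column 0 to -1 and column W - 1 to +1 exactly. As a worked example, with paf_width = 64 a sample at x = 31.5 normalizes to (31.5 / 63) * 2 - 1 = 0.0, the center of the grid.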

    def _compute_distance_penalty(
        self, distances: torch.Tensor, max_edge_length: torch.Tensor
    ) -> torch.Tensor:
        """Compute distance penalty for edge candidates.

        Matches the PyTorch implementation in sleap_nn.inference.paf_grouping.
        Penalty is 0 when distance <= max_edge_length, and negative when longer.
        """
        # Match PyTorch: penalty = clamp((max_edge_length / distance) - 1, max=0) * weight
        penalty = torch.clamp((max_edge_length / distances) - 1, max=0)
        return penalty * self.dist_penalty_weight
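As a worked illustration of the penalty curve (values chosen for illustration, not package defaults): with max_edge_length m = 100 and dist_penalty_weight = 1.0, penalty(d) = min(m / d - 1, 0), so a 100 px candidate is unpenalized, a 200 px candidate scores -0.5, a 400 px candidate scores -0.75, and the penalty approaches -1.0 as d grows.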