sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/cli.py +36 -0
  3. sleap_nn/evaluation.py +8 -0
  4. sleap_nn/export/__init__.py +21 -0
  5. sleap_nn/export/cli.py +1778 -0
  6. sleap_nn/export/exporters/__init__.py +51 -0
  7. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  8. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  9. sleap_nn/export/metadata.py +225 -0
  10. sleap_nn/export/predictors/__init__.py +63 -0
  11. sleap_nn/export/predictors/base.py +22 -0
  12. sleap_nn/export/predictors/onnx.py +154 -0
  13. sleap_nn/export/predictors/tensorrt.py +312 -0
  14. sleap_nn/export/utils.py +307 -0
  15. sleap_nn/export/wrappers/__init__.py +25 -0
  16. sleap_nn/export/wrappers/base.py +96 -0
  17. sleap_nn/export/wrappers/bottomup.py +243 -0
  18. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  19. sleap_nn/export/wrappers/centered_instance.py +56 -0
  20. sleap_nn/export/wrappers/centroid.py +58 -0
  21. sleap_nn/export/wrappers/single_instance.py +83 -0
  22. sleap_nn/export/wrappers/topdown.py +180 -0
  23. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  24. sleap_nn/inference/postprocessing.py +284 -0
  25. sleap_nn/predict.py +29 -0
  26. sleap_nn/train.py +64 -0
  27. sleap_nn/training/callbacks.py +62 -20
  28. sleap_nn/training/lightning_modules.py +332 -30
  29. sleap_nn/training/model_trainer.py +35 -67
  30. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +12 -1
  31. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +35 -14
  32. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
  33. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
  34. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
  35. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
sleap_nn/export/utils.py
@@ -0,0 +1,307 @@
+ """Utilities for export workflows."""
+
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import List, Optional, Tuple
+
+ from omegaconf import DictConfig, OmegaConf
+
+ from sleap_nn.config.training_job_config import TrainingJobConfig
+ from sleap_nn.config.utils import get_backbone_type_from_cfg, get_model_type_from_cfg
+
+
+ def load_training_config(model_dir: str | Path) -> DictConfig:
+     """Load training configuration from a model directory."""
+     model_dir = Path(model_dir)
+     yaml_path = model_dir / "training_config.yaml"
+     json_path = model_dir / "training_config.json"
+
+     if yaml_path.exists():
+         return OmegaConf.load(yaml_path.as_posix())
+     if json_path.exists():
+         return TrainingJobConfig.load_sleap_config(json_path.as_posix())
+
+     raise FileNotFoundError(
+         f"No training_config.yaml or training_config.json found in {model_dir}"
+     )
+
+
+ def resolve_input_scale(cfg: DictConfig) -> float:
+     """Resolve preprocessing scale from config."""
+     scale = cfg.data_config.preprocessing.scale
+     # Check for list/tuple or OmegaConf ListConfig
+     if isinstance(scale, (list, tuple)) or (
+         hasattr(scale, "__iter__")
+         and hasattr(scale, "__len__")
+         and not isinstance(scale, str)
+     ):
+         return float(scale[0]) if len(scale) > 0 else 1.0
+     return float(scale)
+
+
+ def resolve_input_channels(cfg: DictConfig) -> int:
+     """Resolve input channels from backbone config."""
+     backbone_type = get_backbone_type_from_cfg(cfg)
+     return int(cfg.model_config.backbone_config[backbone_type].in_channels)
+
+
+ def resolve_output_stride(cfg: DictConfig, model_type: str) -> int:
+     """Resolve output stride from head config."""
+     head_cfg = cfg.model_config.head_configs[model_type]
+     if head_cfg is None:
+         return 1
+     if hasattr(head_cfg, "confmaps") and head_cfg.confmaps is not None:
+         return int(head_cfg.confmaps.output_stride)
+     if hasattr(head_cfg, "pafs") and head_cfg.pafs is not None:
+         return int(head_cfg.pafs.output_stride)
+     return 1
+
+
+ def resolve_pafs_output_stride(cfg: DictConfig) -> int:
+     """Resolve PAFs output stride for bottom-up models."""
+     bottomup_cfg = getattr(cfg.model_config.head_configs, "bottomup", None)
+     if bottomup_cfg is not None and bottomup_cfg.pafs is not None:
+         return int(bottomup_cfg.pafs.output_stride)
+     return 1
+
+
+ def resolve_class_maps_output_stride(cfg: DictConfig) -> int:
+     """Resolve class maps output stride for multiclass bottom-up models."""
+     mc_bottomup_cfg = getattr(
+         cfg.model_config.head_configs, "multi_class_bottomup", None
+     )
+     if mc_bottomup_cfg is not None and mc_bottomup_cfg.class_maps is not None:
+         return int(mc_bottomup_cfg.class_maps.output_stride)
+     return 8
+
+
+ def resolve_class_names(cfg: DictConfig, model_type: str) -> List[str]:
+     """Resolve class names for multiclass models."""
+     head_cfg = cfg.model_config.head_configs.get(model_type)
+     if head_cfg is None:
+         return []
+
+     # Top-down multiclass: class_vectors.classes
+     if hasattr(head_cfg, "class_vectors") and head_cfg.class_vectors is not None:
+         classes = getattr(head_cfg.class_vectors, "classes", None)
+         if classes:
+             return list(classes)
+
+     # Bottom-up multiclass: class_maps.classes
+     if hasattr(head_cfg, "class_maps") and head_cfg.class_maps is not None:
+         classes = getattr(head_cfg.class_maps, "classes", None)
+         if classes:
+             return list(classes)
+
+     return []
+
+
+ def resolve_n_classes(cfg: DictConfig, model_type: str) -> int:
+     """Resolve number of classes for multiclass models."""
+     class_names = resolve_class_names(cfg, model_type)
+     return len(class_names) if class_names else 0
+
+
+ def resolve_crop_size(cfg: DictConfig) -> Optional[Tuple[int, int]]:
+     """Resolve crop size from preprocessing config."""
+     crop_size = cfg.data_config.preprocessing.crop_size
+     if crop_size is None:
+         return None
+     # Check for list/tuple or OmegaConf ListConfig
+     if isinstance(crop_size, (list, tuple)) or (
+         hasattr(crop_size, "__iter__")
+         and hasattr(crop_size, "__len__")
+         and not isinstance(crop_size, (str, int))
+     ):
+         if len(crop_size) == 2:
+             return int(crop_size[0]), int(crop_size[1])
+         if len(crop_size) == 1:
+             return int(crop_size[0]), int(crop_size[0])
+     return int(crop_size), int(crop_size)
+
+
+ def resolve_node_names(cfg: DictConfig, model_type: str) -> List[str]:
+     """Resolve node names for metadata."""
+     skeleton_nodes = _node_names_from_skeletons(cfg.data_config.skeletons)
+     if skeleton_nodes:
+         return skeleton_nodes
+
+     head_cfg = cfg.model_config.head_configs.get(model_type)
+     if head_cfg is None:
+         return []
+
+     if hasattr(head_cfg, "confmaps") and head_cfg.confmaps is not None:
+         part_names = getattr(head_cfg.confmaps, "part_names", None)
+         if part_names:
+             return list(part_names)
+
+     if model_type == "centroid":
+         anchor = getattr(head_cfg.confmaps, "anchor_part", None) if head_cfg else None
+         return [anchor] if anchor else ["centroid"]
+
+     return []
+
+
+ def resolve_edge_inds(cfg: DictConfig, node_names: List[str]) -> List[Tuple[int, int]]:
+     """Resolve edge indices for metadata."""
+     edges = _edge_inds_from_skeletons(cfg.data_config.skeletons)
+     if edges:
+         return _normalize_edges(edges, node_names)
+
+     bottomup_cfg = getattr(cfg.model_config.head_configs, "bottomup", None)
+     if bottomup_cfg is not None and bottomup_cfg.pafs is not None:
+         edges = bottomup_cfg.pafs.edges
+         if edges:
+             return _normalize_edges(edges, node_names)
+
+     return []
+
+
+ def resolve_model_type(cfg: DictConfig) -> str:
+     """Return model type from config."""
+     return get_model_type_from_cfg(cfg)
+
+
+ def resolve_backbone_type(cfg: DictConfig) -> str:
+     """Return backbone type from config."""
+     return get_backbone_type_from_cfg(cfg)
+
+
+ def resolve_input_shape(
+     cfg: DictConfig,
+     input_height: Optional[int] = None,
+     input_width: Optional[int] = None,
+ ) -> Tuple[int, int, int, int]:
+     """Resolve a dummy input shape for export."""
+     channels = resolve_input_channels(cfg)
+     height = input_height or cfg.data_config.preprocessing.max_height or 512
+     width = input_width or cfg.data_config.preprocessing.max_width or 512
+     return 1, channels, int(height), int(width)
+
+
+ def _node_names_from_skeletons(skeletons) -> List[str]:
+     if not skeletons:
+         return []
+     skeleton = skeletons[0]
+     if hasattr(skeleton, "nodes"):
+         try:
+             return [node.name for node in skeleton.nodes]
+         except Exception:
+             pass
+     if isinstance(skeleton, dict):
+         nodes = skeleton.get("nodes")
+         if nodes:
+             if isinstance(nodes[0], dict):
+                 return [node.get("name", "") for node in nodes if node.get("name")]
+             return [str(node) for node in nodes]
+         node_names = skeleton.get("node_names")
+         if node_names:
+             return [str(name) for name in node_names]
+     return []
+
+
+ def _edge_inds_from_skeletons(skeletons) -> List:
+     if not skeletons:
+         return []
+     skeleton = skeletons[0]
+     if hasattr(skeleton, "edge_inds"):
+         try:
+             return list(skeleton.edge_inds)
+         except Exception:
+             pass
+     if isinstance(skeleton, dict):
+         edges = skeleton.get("edges") or skeleton.get("edge_inds")
+         if edges:
+             return list(edges)
+     return []
+
+
+ def _normalize_edges(edges: List, node_names: List[str]) -> List[Tuple[int, int]]:
+     if not edges:
+         return []
+     if not node_names:
+         return [(int(src), int(dst)) for src, dst in edges]
+
+     if isinstance(edges[0][0], str):
+         name_to_idx = {name: idx for idx, name in enumerate(node_names)}
+         normalized = []
+         for src, dst in edges:
+             if src in name_to_idx and dst in name_to_idx:
+                 normalized.append((name_to_idx[src], name_to_idx[dst]))
+         return normalized
+
+     return [(int(src), int(dst)) for src, dst in edges]
+
+
+ def build_bottomup_candidate_template(
+     n_nodes: int, max_peaks_per_node: int, edge_inds: List[Tuple[int, int]]
+ ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
+     """Build candidate template matching the ONNX wrapper's line_scores ordering.
+
+     The BottomUpONNXWrapper produces line_scores with shape (n_edges, k*k) where,
+     for each edge connecting (src_node, dst_node), position i*k + j corresponds to:
+     - src peak flat index: src_node * k + i
+     - dst peak flat index: dst_node * k + j
+
+     This function builds edge_inds and edge_peak_inds tensors that match this exact
+     ordering, so that line_scores_flat[idx] corresponds to edge_peak_inds[idx].
+
+     Args:
+         n_nodes: Number of nodes in the skeleton.
+         max_peaks_per_node: Maximum peaks per node (k) used during export.
+         edge_inds: List of (src_node, dst_node) tuples defining skeleton edges.
+
+     Returns:
+         Tuple of (peak_channel_inds, edge_inds_tensor, edge_peak_inds_tensor):
+         - peak_channel_inds: (n_nodes * k,) tensor mapping flat peak index to node
+         - edge_inds_tensor: (n_edges * k * k,) tensor of edge indices for each candidate
+         - edge_peak_inds_tensor: (n_edges * k * k, 2) tensor of (src, dst) peak indices
+
+     Example:
+         >>> from sleap_nn.export.utils import build_bottomup_candidate_template
+         >>> peak_ch, edge_inds, edge_peaks = build_bottomup_candidate_template(
+         ...     n_nodes=15, max_peaks_per_node=20, edge_inds=[(1, 2), (1, 5)]
+         ... )
+         >>> # Use with ONNX output:
+         >>> line_scores_flat = line_scores.reshape(-1)
+         >>> valid_scores = line_scores_flat[valid_mask]
+         >>> valid_edge_peaks = edge_peaks[valid_mask]
+
+     Note:
+         This function is necessary because `get_connection_candidates()` in
+         `sleap_nn.inference.paf_grouping` uses unstable argsort, which shuffles
+         peak indices within each node and breaks alignment with ONNX output ordering.
+     """
+     import torch
+
+     k = max_peaks_per_node
+     n_edges = len(edge_inds)
+
+     # peak_channel_inds: [0,0,...0, 1,1,...1, ...] (k times each)
+     peak_channel_inds = torch.arange(n_nodes, dtype=torch.int32).repeat_interleave(k)
+
+     edge_inds_list = []
+     edge_peak_inds_list = []
+
+     for edge_idx, (src_node, dst_node) in enumerate(edge_inds):
+         # Build k*k candidate pairs in row-major order (i*k + j):
+         # src indices: [src_node*k + 0, src_node*k + 0, ..., src_node*k + 1, ...]
+         # dst indices: [dst_node*k + 0, dst_node*k + 1, ..., dst_node*k + 0, ...]
+         src_base = src_node * k
+         dst_base = dst_node * k
+
+         src_indices = torch.arange(k, dtype=torch.int32).repeat_interleave(k) + src_base
+         dst_indices = torch.arange(k, dtype=torch.int32).repeat(k) + dst_base
+
+         edge_inds_list.append(torch.full((k * k,), edge_idx, dtype=torch.int32))
+         edge_peak_inds_list.append(torch.stack([src_indices, dst_indices], dim=1))
+
+     if edge_inds_list:
+         edge_inds_tensor = torch.cat(edge_inds_list)
+         edge_peak_inds_tensor = torch.cat(edge_peak_inds_list)
+     else:
+         edge_inds_tensor = torch.empty((0,), dtype=torch.int32)
+         edge_peak_inds_tensor = torch.empty((0, 2), dtype=torch.int32)
+
+     return peak_channel_inds, edge_inds_tensor, edge_peak_inds_tensor
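
Taken together, the resolve_* helpers above turn a training config into the handful of scalars and lists an exporter needs. A minimal sketch of how they might be combined, assuming a hypothetical model directory "models/bottomup" that contains a training_config.yaml:

from sleap_nn.export.utils import (
    build_bottomup_candidate_template,
    load_training_config,
    resolve_edge_inds,
    resolve_input_shape,
    resolve_model_type,
    resolve_node_names,
)

cfg = load_training_config("models/bottomup")  # hypothetical path
model_type = resolve_model_type(cfg)           # e.g. "bottomup"
nodes = resolve_node_names(cfg, model_type)    # skeleton node names
edges = resolve_edge_inds(cfg, nodes)          # (src, dst) index pairs
shape = resolve_input_shape(cfg)               # (1, C, H, W) dummy input shape

# The candidate template can then be precomputed once per skeleton and reused
# for every frame's ONNX output:
peak_ch, edge_inds_t, edge_peaks_t = build_bottomup_candidate_template(
    n_nodes=len(nodes), max_peaks_per_node=20, edge_inds=edges
)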
sleap_nn/export/wrappers/__init__.py
@@ -0,0 +1,25 @@
+ """ONNX/TensorRT export wrappers."""
+
+ from sleap_nn.export.wrappers.base import BaseExportWrapper
+ from sleap_nn.export.wrappers.centroid import CentroidONNXWrapper
+ from sleap_nn.export.wrappers.centered_instance import CenteredInstanceONNXWrapper
+ from sleap_nn.export.wrappers.topdown import TopDownONNXWrapper
+ from sleap_nn.export.wrappers.bottomup import BottomUpONNXWrapper
+ from sleap_nn.export.wrappers.single_instance import SingleInstanceONNXWrapper
+ from sleap_nn.export.wrappers.topdown_multiclass import (
+     TopDownMultiClassONNXWrapper,
+     TopDownMultiClassCombinedONNXWrapper,
+ )
+ from sleap_nn.export.wrappers.bottomup_multiclass import BottomUpMultiClassONNXWrapper
+
+ __all__ = [
+     "BaseExportWrapper",
+     "CentroidONNXWrapper",
+     "CenteredInstanceONNXWrapper",
+     "TopDownONNXWrapper",
+     "BottomUpONNXWrapper",
+     "SingleInstanceONNXWrapper",
+     "TopDownMultiClassONNXWrapper",
+     "TopDownMultiClassCombinedONNXWrapper",
+     "BottomUpMultiClassONNXWrapper",
+ ]
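
The wrappers package exposes one export wrapper per head type. A hedged sketch of dispatching on model type; the mapping keys here are illustrative (only "centroid", "bottomup", and "multi_class_bottomup" appear verbatim in the config helpers above), and the actual export CLI may select wrappers differently:

from sleap_nn.export import wrappers

# Assumed model-type strings; the real dispatch logic may differ.
WRAPPER_BY_MODEL_TYPE = {
    "single_instance": wrappers.SingleInstanceONNXWrapper,
    "centroid": wrappers.CentroidONNXWrapper,
    "centered_instance": wrappers.CenteredInstanceONNXWrapper,
    "bottomup": wrappers.BottomUpONNXWrapper,
    "multi_class_bottomup": wrappers.BottomUpMultiClassONNXWrapper,
}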
sleap_nn/export/wrappers/base.py
@@ -0,0 +1,96 @@
+ """Base classes and shared helpers for export wrappers."""
+
+ from __future__ import annotations
+
+ from typing import Iterable, List, Tuple
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+
+ class BaseExportWrapper(nn.Module):
+     """Base class for ONNX-exportable wrappers."""
+
+     def __init__(self, model: nn.Module):
+         """Initialize wrapper with the underlying model.
+
+         Args:
+             model: The PyTorch model to wrap for export.
+         """
+         super().__init__()
+         self.model = model
+
+     @staticmethod
+     def _normalize_uint8(image: torch.Tensor) -> torch.Tensor:
+         """Normalize uint8 (or [0, 255] float) images to [0, 1]."""
+         if image.dtype != torch.float32:
+             image = image.float()
+         return image / 255.0
+
+     @staticmethod
+     def _extract_tensor(output, key_hints: Iterable[str]) -> torch.Tensor:
+         """Extract a tensor from a model output dict by fuzzy key matching."""
+         if isinstance(output, dict):
+             for key in output:
+                 for hint in key_hints:
+                     if hint.lower() in key.lower():
+                         return output[key]
+             return next(iter(output.values()))
+         return output
+
+     @staticmethod
+     def _find_topk_peaks(
+         confmaps: torch.Tensor, k: int
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Top-K peak finding with NMS via max pooling."""
+         batch_size, _, height, width = confmaps.shape
+         pooled = F.max_pool2d(confmaps, kernel_size=3, stride=1, padding=1)
+         is_peak = (confmaps == pooled) & (confmaps > 0)
+
+         confmaps_flat = confmaps.reshape(batch_size, height * width)
+         is_peak_flat = is_peak.reshape(batch_size, height * width)
+         masked = torch.where(
+             is_peak_flat, confmaps_flat, torch.full_like(confmaps_flat, -1e9)
+         )
+         values, indices = torch.topk(masked, k=k, dim=1)
+
+         y = indices // width
+         x = indices % width
+         peaks = torch.stack([x.float(), y.float()], dim=-1)
+         valid = values > 0
+         return peaks, values, valid
+
+     @staticmethod
+     def _find_topk_peaks_per_node(
+         confmaps: torch.Tensor, k: int
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Top-K peak finding per channel with NMS via max pooling."""
+         batch_size, n_nodes, height, width = confmaps.shape
+         pooled = F.max_pool2d(confmaps, kernel_size=3, stride=1, padding=1)
+         is_peak = (confmaps == pooled) & (confmaps > 0)
+
+         confmaps_flat = confmaps.reshape(batch_size, n_nodes, height * width)
+         is_peak_flat = is_peak.reshape(batch_size, n_nodes, height * width)
+         masked = torch.where(
+             is_peak_flat, confmaps_flat, torch.full_like(confmaps_flat, -1e9)
+         )
+         values, indices = torch.topk(masked, k=k, dim=2)
+
+         y = indices // width
+         x = indices % width
+         peaks = torch.stack([x.float(), y.float()], dim=-1)
+         valid = values > 0
+         return peaks, values, valid
+
+     @staticmethod
+     def _find_global_peaks(
+         confmaps: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Find global maxima per channel."""
+         batch_size, channels, height, width = confmaps.shape
+         flat = confmaps.reshape(batch_size, channels, height * width)
+         values, indices = flat.max(dim=-1)
+         y = indices // width
+         x = indices % width
+         peaks = torch.stack([x.float(), y.float()], dim=-1)
+         return peaks, values
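
The peak finders above implement non-maximum suppression without any data-dependent shapes: a pixel is a peak iff it equals its 3x3 max-pooled neighborhood and is positive, which keeps the op ONNX-friendly. A self-contained check of that trick (standalone torch, independent of the wrapper class):

import torch
from torch.nn import functional as F

cm = torch.zeros(1, 1, 8, 8)  # (batch, channels, H, W) confidence map
cm[0, 0, 2, 3] = 0.9          # one synthetic peak
cm[0, 0, 5, 6] = 0.4          # a second, weaker peak

pooled = F.max_pool2d(cm, kernel_size=3, stride=1, padding=1)
is_peak = (cm == pooled) & (cm > 0)
print(is_peak.nonzero())      # -> [[0, 0, 2, 3], [0, 0, 5, 6]] as (b, c, y, x)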
sleap_nn/export/wrappers/bottomup.py
@@ -0,0 +1,243 @@
+ """Bottom-up ONNX wrapper."""
+
+ from __future__ import annotations
+
+ from typing import Dict, Tuple
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from sleap_nn.export.wrappers.base import BaseExportWrapper
+
+
+ class BottomUpONNXWrapper(BaseExportWrapper):
+     """ONNX-exportable wrapper for bottom-up inference up to PAF scoring.
+
+     Expects input images as uint8 tensors in [0, 255].
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         skeleton_edges: list,
+         n_nodes: int,
+         max_peaks_per_node: int = 20,
+         n_line_points: int = 10,
+         cms_output_stride: int = 4,
+         pafs_output_stride: int = 8,
+         max_edge_length_ratio: float = 0.25,
+         dist_penalty_weight: float = 1.0,
+         input_scale: float = 1.0,
+     ) -> None:
+         """Initialize bottom-up ONNX wrapper.
+
+         Args:
+             model: Bottom-up model producing confidence maps and PAFs.
+             skeleton_edges: List of (src, dst) edge tuples defining the skeleton.
+             n_nodes: Number of nodes in the skeleton.
+             max_peaks_per_node: Maximum peaks to detect per node type.
+             n_line_points: Points to sample along PAF edges.
+             cms_output_stride: Confidence map output stride.
+             pafs_output_stride: PAF output stride.
+             max_edge_length_ratio: Maximum edge length as ratio of image size.
+             dist_penalty_weight: Weight for distance penalty in scoring.
+             input_scale: Input scaling factor.
+         """
+         super().__init__(model)
+         self.n_nodes = n_nodes
+         self.n_edges = len(skeleton_edges)
+         self.max_peaks_per_node = max_peaks_per_node
+         self.n_line_points = n_line_points
+         self.cms_output_stride = cms_output_stride
+         self.pafs_output_stride = pafs_output_stride
+         self.max_edge_length_ratio = max_edge_length_ratio
+         self.dist_penalty_weight = dist_penalty_weight
+         self.input_scale = input_scale
+
+         edge_src = torch.tensor([e[0] for e in skeleton_edges], dtype=torch.long)
+         edge_dst = torch.tensor([e[1] for e in skeleton_edges], dtype=torch.long)
+         self.register_buffer("edge_src", edge_src)
+         self.register_buffer("edge_dst", edge_dst)
+
+         line_samples = torch.linspace(0, 1, n_line_points, dtype=torch.float32)
+         self.register_buffer("line_samples", line_samples)
+
+     def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
+         """Run bottom-up inference and return fixed-size outputs.
+
+         Note: confmaps and pafs are NOT returned, to avoid a D2H transfer
+         bottleneck. Peak detection and PAF scoring are performed on GPU within
+         this wrapper.
+         """
+         image = self._normalize_uint8(image)
+         if self.input_scale != 1.0:
+             height = int(image.shape[-2] * self.input_scale)
+             width = int(image.shape[-1] * self.input_scale)
+             image = F.interpolate(
+                 image, size=(height, width), mode="bilinear", align_corners=False
+             )
+
+         batch_size, _, height, width = image.shape
+
+         out = self.model(image)
+         if isinstance(out, dict):
+             confmaps = self._extract_tensor(out, ["confmap", "multiinstance"])
+             pafs = self._extract_tensor(out, ["paf", "affinity"])
+         else:
+             confmaps, pafs = out[:2]
+
+         peaks, peak_vals, peak_mask = self._find_topk_peaks_per_node(
+             confmaps, self.max_peaks_per_node
+         )
+
+         peaks = peaks * self.cms_output_stride
+
+         # Compute max_edge_length to match the PyTorch implementation:
+         #   max_edge_length = ratio * max(paf_dims) * pafs_stride
+         # PAFs shape is (batch, 2 * n_edges, H, W).
+         _, n_paf_channels, paf_height, paf_width = pafs.shape
+         max_paf_dim = max(n_paf_channels, paf_height, paf_width)
+         max_edge_length = torch.tensor(
+             self.max_edge_length_ratio * max_paf_dim * self.pafs_output_stride,
+             dtype=peaks.dtype,
+             device=peaks.device,
+         )
+
+         line_scores, candidate_mask = self._score_all_candidates(
+             pafs, peaks, peak_mask, max_edge_length
+         )
+
+         # Only return the final outputs needed for CPU-side grouping. Do NOT
+         # return confmaps/pafs: they are large (~29 MB/batch) and cause a D2H
+         # transfer bottleneck. Peak detection and PAF scoring are already done
+         # on GPU above.
+         return {
+             "peaks": peaks,
+             "peak_vals": peak_vals,
+             "peak_mask": peak_mask,
+             "line_scores": line_scores,
+             "candidate_mask": candidate_mask,
+         }
+
+     def _score_all_candidates(
+         self,
+         pafs: torch.Tensor,
+         peaks: torch.Tensor,
+         peak_mask: torch.Tensor,
+         max_edge_length: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Score all K*K candidate connections for each edge."""
+         batch_size = peaks.shape[0]
+         k = self.max_peaks_per_node
+         n_edges = self.n_edges
+
+         _, _, paf_height, paf_width = pafs.shape
+
+         src_peaks = peaks[:, self.edge_src, :, :]
+         dst_peaks = peaks[:, self.edge_dst, :, :]
+
+         src_mask = peak_mask[:, self.edge_src, :]
+         dst_mask = peak_mask[:, self.edge_dst, :]
+
+         src_peaks_exp = src_peaks.unsqueeze(3).expand(-1, -1, -1, k, -1)
+         dst_peaks_exp = dst_peaks.unsqueeze(2).expand(-1, -1, k, -1, -1)
+
+         src_mask_exp = src_mask.unsqueeze(3).expand(-1, -1, -1, k)
+         dst_mask_exp = dst_mask.unsqueeze(2).expand(-1, -1, k, -1)
+         candidate_mask = src_mask_exp & dst_mask_exp
+
+         src_peaks_flat = src_peaks_exp.reshape(batch_size, n_edges, k * k, 2)
+         dst_peaks_flat = dst_peaks_exp.reshape(batch_size, n_edges, k * k, 2)
+         candidate_mask_flat = candidate_mask.reshape(batch_size, n_edges, k * k)
+
+         spatial_vecs = dst_peaks_flat - src_peaks_flat
+         spatial_lengths = torch.norm(spatial_vecs, dim=-1, keepdim=True).clamp(min=1e-6)
+         spatial_vecs_norm = spatial_vecs / spatial_lengths
+
+         line_samples = self.line_samples.view(1, 1, 1, -1, 1)
+         src_exp = src_peaks_flat.unsqueeze(3)
+         dst_exp = dst_peaks_flat.unsqueeze(3)
+         line_points = src_exp + line_samples * (dst_exp - src_exp)
+
+         line_points_paf = line_points / self.pafs_output_stride
+         line_x = line_points_paf[..., 0].clamp(0, paf_width - 1)
+         line_y = line_points_paf[..., 1].clamp(0, paf_height - 1)
+
+         line_scores = self._sample_and_score_lines(
+             pafs,
+             line_x,
+             line_y,
+             spatial_vecs_norm,
+             spatial_lengths.squeeze(-1),
+             max_edge_length,
+         )
+
+         line_scores = line_scores.masked_fill(~candidate_mask_flat, -2.0)
+         return line_scores, candidate_mask_flat
+
+     def _sample_and_score_lines(
+         self,
+         pafs: torch.Tensor,
+         line_x: torch.Tensor,
+         line_y: torch.Tensor,
+         spatial_vecs_norm: torch.Tensor,
+         spatial_lengths: torch.Tensor,
+         max_edge_length: torch.Tensor,
+     ) -> torch.Tensor:
+         """Sample PAF values along lines and compute scores."""
+         batch_size, n_edges, k2, n_points = line_x.shape
+         _, _, paf_height, paf_width = pafs.shape
+
+         all_scores = []
+         for edge_idx in range(n_edges):
+             paf_x = pafs[:, 2 * edge_idx, :, :]
+             paf_y = pafs[:, 2 * edge_idx + 1, :, :]
+
+             lx = line_x[:, edge_idx, :, :]
+             ly = line_y[:, edge_idx, :, :]
+
+             lx_norm = (lx / (paf_width - 1)) * 2 - 1
+             ly_norm = (ly / (paf_height - 1)) * 2 - 1
+
+             grid = torch.stack([lx_norm, ly_norm], dim=-1)
+
+             paf_x_samples = F.grid_sample(
+                 paf_x.unsqueeze(1),
+                 grid,
+                 mode="bilinear",
+                 padding_mode="zeros",
+                 align_corners=True,
+             ).squeeze(1)
+
+             paf_y_samples = F.grid_sample(
+                 paf_y.unsqueeze(1),
+                 grid,
+                 mode="bilinear",
+                 padding_mode="zeros",
+                 align_corners=True,
+             ).squeeze(1)
+
+             paf_samples = torch.stack([paf_x_samples, paf_y_samples], dim=-1)
+             disp_vec = spatial_vecs_norm[:, edge_idx, :, :]
+
+             dot_products = (paf_samples * disp_vec.unsqueeze(2)).sum(dim=-1)
+             mean_scores = dot_products.mean(dim=-1)
+
+             edge_lengths = spatial_lengths[:, edge_idx, :]
+             dist_penalty = self._compute_distance_penalty(edge_lengths, max_edge_length)
+
+             all_scores.append(mean_scores + dist_penalty)
+
+         return torch.stack(all_scores, dim=1)
+
+     def _compute_distance_penalty(
+         self, distances: torch.Tensor, max_edge_length: torch.Tensor
+     ) -> torch.Tensor:
+         """Compute distance penalty for edge candidates.
+
+         Matches the PyTorch implementation in sleap_nn.inference.paf_grouping.
+         The penalty is 0 when distance <= max_edge_length, and negative when longer.
+         """
+         # Match PyTorch: penalty = clamp((max_edge_length / distance) - 1, max=0) * weight
+         penalty = torch.clamp((max_edge_length / distances) - 1, max=0)
+         return penalty * self.dist_penalty_weight
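
A quick numeric check of the penalty formula above, penalty = clamp(max_edge_length / distance - 1, max=0) * weight: connections at or under the length cap are unpenalized, and longer ones are scored increasingly negatively:

import torch

max_edge_length = torch.tensor(100.0)
distances = torch.tensor([50.0, 100.0, 200.0, 400.0])
penalty = torch.clamp((max_edge_length / distances) - 1, max=0)
print(penalty)  # tensor([ 0.0000,  0.0000, -0.5000, -0.7500])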