sleap-nn 0.1.0a2__py3-none-any.whl → 0.1.0a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +168 -39
  6. sleap_nn/evaluation.py +8 -0
  7. sleap_nn/export/__init__.py +21 -0
  8. sleap_nn/export/cli.py +1778 -0
  9. sleap_nn/export/exporters/__init__.py +51 -0
  10. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  11. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  12. sleap_nn/export/metadata.py +225 -0
  13. sleap_nn/export/predictors/__init__.py +63 -0
  14. sleap_nn/export/predictors/base.py +22 -0
  15. sleap_nn/export/predictors/onnx.py +154 -0
  16. sleap_nn/export/predictors/tensorrt.py +312 -0
  17. sleap_nn/export/utils.py +307 -0
  18. sleap_nn/export/wrappers/__init__.py +25 -0
  19. sleap_nn/export/wrappers/base.py +96 -0
  20. sleap_nn/export/wrappers/bottomup.py +243 -0
  21. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  22. sleap_nn/export/wrappers/centered_instance.py +56 -0
  23. sleap_nn/export/wrappers/centroid.py +58 -0
  24. sleap_nn/export/wrappers/single_instance.py +83 -0
  25. sleap_nn/export/wrappers/topdown.py +180 -0
  26. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  27. sleap_nn/inference/peak_finding.py +47 -17
  28. sleap_nn/inference/postprocessing.py +284 -0
  29. sleap_nn/inference/predictors.py +213 -106
  30. sleap_nn/predict.py +35 -7
  31. sleap_nn/train.py +64 -0
  32. sleap_nn/training/callbacks.py +69 -22
  33. sleap_nn/training/lightning_modules.py +332 -30
  34. sleap_nn/training/model_trainer.py +67 -67
  35. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/METADATA +13 -1
  36. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/RECORD +40 -19
  37. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/WHEEL +0 -0
  38. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/entry_points.txt +0 -0
  39. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/licenses/LICENSE +0 -0
  40. {sleap_nn-0.1.0a2.dist-info → sleap_nn-0.1.0a4.dist-info}/top_level.txt +0 -0
sleap_nn/export/wrappers/topdown_multiclass.py (new file)
@@ -0,0 +1,304 @@
+ """ONNX wrapper for top-down multiclass (supervised ID) models."""
+
+ from typing import Dict
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from sleap_nn.export.wrappers.base import BaseExportWrapper
+
+
+ class TopDownMultiClassONNXWrapper(BaseExportWrapper):
+     """ONNX-exportable wrapper for top-down multiclass (supervised ID) models.
+
+     This wrapper handles models that output both confidence maps for keypoint
+     detection and class logits for identity classification. It runs on instance
+     crops (centered around detected centroids).
+
+     Expects input images as uint8 tensors in [0, 255].
+
+     Attributes:
+         model: The underlying PyTorch model (centered instance + class vectors heads).
+         output_stride: Output stride of the confmap head.
+         input_scale: Scale factor applied to input images before inference.
+         n_classes: Number of identity classes.
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         output_stride: int = 2,
+         input_scale: float = 1.0,
+         n_classes: int = 2,
+     ):
+         """Initialize the wrapper.
+
+         Args:
+             model: The underlying PyTorch model.
+             output_stride: Output stride of the confidence maps.
+             input_scale: Scale factor for input images.
+             n_classes: Number of identity classes (e.g., 2 for male/female).
+         """
+         super().__init__(model)
+         self.output_stride = output_stride
+         self.input_scale = input_scale
+         self.n_classes = n_classes
+
+     def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
+         """Run top-down multiclass inference on crops.
+
+         Args:
+             image: Input image tensor of shape (batch, channels, height, width).
+                 Expected to be uint8 in [0, 255].
+
+         Returns:
+             Dictionary with keys:
+                 - "peaks": Predicted peak coordinates (batch, n_nodes, 2) in (x, y).
+                 - "peak_vals": Peak confidence values (batch, n_nodes).
+                 - "class_logits": Raw class logits (batch, n_classes).
+
+             The class assignment is done on CPU using Hungarian matching
+             via `get_class_inds_from_vectors()`.
+         """
+         # Normalize uint8 [0, 255] to float32 [0, 1]
+         image = self._normalize_uint8(image)
+
+         # Apply input scaling if needed
+         if self.input_scale != 1.0:
+             height = int(image.shape[-2] * self.input_scale)
+             width = int(image.shape[-1] * self.input_scale)
+             image = F.interpolate(
+                 image, size=(height, width), mode="bilinear", align_corners=False
+             )
+
+         # Forward pass
+         out = self.model(image)
+
+         # Extract outputs
+         confmaps = self._extract_tensor(out, ["centered", "instance", "confmap"])
+         class_logits = self._extract_tensor(out, ["class", "vector"])
+
+         # Find global peaks (one per node)
+         peaks, peak_vals = self._find_global_peaks(confmaps)
+
+         # Scale peaks back to input coordinates
+         peaks = peaks * (self.output_stride / self.input_scale)
+
+         return {
+             "peaks": peaks,
+             "peak_vals": peak_vals,
+             "class_logits": class_logits,
+         }
+
+
+ class TopDownMultiClassCombinedONNXWrapper(BaseExportWrapper):
+     """ONNX-exportable wrapper for combined centroid + multiclass instance models.
+
+     This wrapper combines a centroid detection model with a centered instance
+     multiclass model. It performs:
+     1. Centroid detection on full images
+     2. Cropping around each centroid using vectorized grid_sample
+     3. Instance keypoint detection + identity classification on each crop
+
+     Expects input images as uint8 tensors in [0, 255].
+     """
+
+     def __init__(
+         self,
+         centroid_model: nn.Module,
+         instance_model: nn.Module,
+         max_instances: int = 20,
+         crop_size: tuple = (192, 192),
+         centroid_output_stride: int = 4,
+         instance_output_stride: int = 2,
+         centroid_input_scale: float = 1.0,
+         instance_input_scale: float = 1.0,
+         n_nodes: int = 13,
+         n_classes: int = 2,
+     ):
+         """Initialize the combined wrapper.
+
+         Args:
+             centroid_model: Model for centroid detection.
+             instance_model: Model for instance keypoints + class prediction.
+             max_instances: Maximum number of instances to detect.
+             crop_size: Size of crops around centroids (height, width).
+             centroid_output_stride: Output stride of centroid model.
+             instance_output_stride: Output stride of instance model.
+             centroid_input_scale: Input scale for centroid model.
+             instance_input_scale: Input scale for instance model.
+             n_nodes: Number of keypoint nodes per instance.
+             n_classes: Number of identity classes.
+         """
+         super().__init__(centroid_model)  # Primary model is centroid
+         self.instance_model = instance_model
+         self.max_instances = max_instances
+         self.crop_size = crop_size
+         self.centroid_output_stride = centroid_output_stride
+         self.instance_output_stride = instance_output_stride
+         self.centroid_input_scale = centroid_input_scale
+         self.instance_input_scale = instance_input_scale
+         self.n_nodes = n_nodes
+         self.n_classes = n_classes
+
+         # Pre-compute base grid for crop extraction (same as TopDownONNXWrapper)
+         crop_h, crop_w = crop_size
+         y_crop = torch.linspace(-1, 1, crop_h, dtype=torch.float32)
+         x_crop = torch.linspace(-1, 1, crop_w, dtype=torch.float32)
+         grid_y, grid_x = torch.meshgrid(y_crop, x_crop, indexing="ij")
+         base_grid = torch.stack([grid_x, grid_y], dim=-1)
+         self.register_buffer("base_grid", base_grid, persistent=False)
+
+     def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
+         """Run combined top-down multiclass inference.
+
+         Args:
+             image: Input image tensor of shape (batch, channels, height, width).
+                 Expected to be uint8 in [0, 255].
+
+         Returns:
+             Dictionary with keys:
+                 - "centroids": Detected centroids (batch, max_instances, 2).
+                 - "centroid_vals": Centroid confidence values (batch, max_instances).
+                 - "peaks": Instance peaks (batch, max_instances, n_nodes, 2).
+                 - "peak_vals": Peak values (batch, max_instances, n_nodes).
+                 - "class_logits": Class logits per instance (batch, max_instances, n_classes).
+                 - "instance_valid": Validity mask (batch, max_instances).
+         """
+         # Normalize input
+         image = self._normalize_uint8(image)
+         batch_size, channels, height, width = image.shape
+
+         # Apply centroid input scaling
+         scaled_image = image
+         if self.centroid_input_scale != 1.0:
+             scaled_h = int(height * self.centroid_input_scale)
+             scaled_w = int(width * self.centroid_input_scale)
+             scaled_image = F.interpolate(
+                 scaled_image,
+                 size=(scaled_h, scaled_w),
+                 mode="bilinear",
+                 align_corners=False,
+             )
+
+         # Centroid detection
+         centroid_out = self.model(scaled_image)
+         centroid_cms = self._extract_tensor(centroid_out, ["centroid", "confmap"])
+         centroids, centroid_vals, instance_valid = self._find_topk_peaks(
+             centroid_cms, self.max_instances
+         )
+         centroids = centroids * (
+             self.centroid_output_stride / self.centroid_input_scale
+         )
+
+         # Extract crops using vectorized grid_sample (same as TopDownONNXWrapper)
+         crops = self._extract_crops(image, centroids)
+         crops_flat = crops.reshape(
+             batch_size * self.max_instances,
+             channels,
+             self.crop_size[0],
+             self.crop_size[1],
+         )
+
+         # Apply instance input scaling if needed
+         if self.instance_input_scale != 1.0:
+             scaled_h = int(self.crop_size[0] * self.instance_input_scale)
+             scaled_w = int(self.crop_size[1] * self.instance_input_scale)
+             crops_flat = F.interpolate(
+                 crops_flat,
+                 size=(scaled_h, scaled_w),
+                 mode="bilinear",
+                 align_corners=False,
+             )
+
+         # Instance model forward (batch all crops)
+         instance_out = self.instance_model(crops_flat)
+         instance_cms = self._extract_tensor(
+             instance_out, ["centered", "instance", "confmap"]
+         )
+         instance_class = self._extract_tensor(instance_out, ["class", "vector"])
+
+         # Find peaks in all crops
+         crop_peaks, crop_peak_vals = self._find_global_peaks(instance_cms)
+         crop_peaks = crop_peaks * (
+             self.instance_output_stride / self.instance_input_scale
+         )
+
+         # Reshape to batch x instances x nodes x 2
+         crop_peaks = crop_peaks.reshape(batch_size, self.max_instances, self.n_nodes, 2)
+         peak_vals = crop_peak_vals.reshape(batch_size, self.max_instances, self.n_nodes)
+
+         # Reshape class logits
+         class_logits = instance_class.reshape(
+             batch_size, self.max_instances, self.n_classes
+         )
+
+         # Transform peaks from crop coordinates to full image coordinates
+         crop_offset = centroids.unsqueeze(2) - image.new_tensor(
+             [self.crop_size[1] / 2.0, self.crop_size[0] / 2.0]
+         )
+         peaks = crop_peaks + crop_offset
+
+         # Zero out invalid instances
+         invalid_mask = ~instance_valid
+         centroids = centroids.masked_fill(invalid_mask.unsqueeze(-1), 0.0)
+         centroid_vals = centroid_vals.masked_fill(invalid_mask, 0.0)
+         peaks = peaks.masked_fill(invalid_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
+         peak_vals = peak_vals.masked_fill(invalid_mask.unsqueeze(-1), 0.0)
+         class_logits = class_logits.masked_fill(invalid_mask.unsqueeze(-1), 0.0)
+
+         return {
+             "centroids": centroids,
+             "centroid_vals": centroid_vals,
+             "peaks": peaks,
+             "peak_vals": peak_vals,
+             "class_logits": class_logits,
+             "instance_valid": instance_valid,
+         }
+
+     def _extract_crops(
+         self,
+         image: torch.Tensor,
+         centroids: torch.Tensor,
+     ) -> torch.Tensor:
+         """Extract crops around centroids using grid_sample.
+
+         This is the same vectorized implementation as TopDownONNXWrapper.
+         """
+         batch_size, channels, height, width = image.shape
+         crop_h, crop_w = self.crop_size
+         n_instances = centroids.shape[1]
+
+         scale_x = crop_w / width
+         scale_y = crop_h / height
+         scale = image.new_tensor([scale_x, scale_y])
+         base_grid = self.base_grid.to(device=image.device, dtype=image.dtype)
+         scaled_grid = base_grid * scale
+
+         scaled_grid = scaled_grid.unsqueeze(0).unsqueeze(0)
+         scaled_grid = scaled_grid.expand(batch_size, n_instances, -1, -1, -1)
+
+         norm_centroids = torch.zeros_like(centroids)
+         norm_centroids[..., 0] = (centroids[..., 0] / (width - 1)) * 2 - 1
+         norm_centroids[..., 1] = (centroids[..., 1] / (height - 1)) * 2 - 1
+         offset = norm_centroids.unsqueeze(2).unsqueeze(2)
+
+         sample_grid = scaled_grid + offset
+
+         image_expanded = image.unsqueeze(1).expand(-1, n_instances, -1, -1, -1)
+         image_flat = image_expanded.reshape(
+             batch_size * n_instances, channels, height, width
+         )
+         grid_flat = sample_grid.reshape(batch_size * n_instances, crop_h, crop_w, 2)
+
+         crops_flat = F.grid_sample(
+             image_flat,
+             grid_flat,
+             mode="bilinear",
+             padding_mode="zeros",
+             align_corners=True,
+         )
+
+         crops = crops_flat.reshape(batch_size, n_instances, channels, crop_h, crop_w)
+         return crops
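A wrapper like the combined one above is meant to be traced end to end, so it can in principle be exported with plain torch.onnx.export and run under onnxruntime. The sketch below is only an illustration of that generic path, not the package's own export CLI; the placeholder models (my_centroid_model, my_instance_model), the shapes, and the "image" input name are assumptions.

# Hedged sketch: exporting a combined wrapper with vanilla torch.onnx.export and
# running it under onnxruntime. Placeholder models and shapes are assumptions.
import numpy as np
import onnxruntime as ort
import torch

wrapper = TopDownMultiClassCombinedONNXWrapper(
    centroid_model=my_centroid_model,    # assumed: trained centroid model
    instance_model=my_instance_model,    # assumed: centered-instance + class head
    max_instances=8,
    crop_size=(192, 192),
    n_nodes=13,
    n_classes=2,
).eval()

dummy = torch.zeros(1, 1, 1024, 1024, dtype=torch.uint8)  # uint8 in [0, 255]
torch.onnx.export(
    wrapper,
    dummy,
    "topdown_multiclass.onnx",
    input_names=["image"],
    output_names=[
        "centroids", "centroid_vals", "peaks", "peak_vals",
        "class_logits", "instance_valid",
    ],
    dynamic_axes={"image": {0: "batch"}},
    opset_version=17,  # grid_sample requires a recent opset (GridSample, opset >= 16)
)

session = ort.InferenceSession("topdown_multiclass.onnx")
frame = np.zeros((1, 1, 1024, 1024), dtype=np.uint8)
outputs = session.run(None, {"image": frame})  # list ordered as output_names

Because the wrapper normalizes uint8 input and returns padded, fixed-size outputs with an instance_valid mask, downstream code can stay shape-static, which is what makes the ONNX/TensorRT export path in this release practical.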
sleap_nn/inference/peak_finding.py
@@ -3,9 +3,8 @@
  from typing import Optional, Tuple

  import kornia as K
- import numpy as np
  import torch
- from kornia.geometry.transform import crop_and_resize
+ import torch.nn.functional as F

  from sleap_nn.data.instance_cropping import make_centered_bboxes

@@ -13,7 +12,11 @@ from sleap_nn.data.instance_cropping import make_centered_bboxes
  def crop_bboxes(
      images: torch.Tensor, bboxes: torch.Tensor, sample_inds: torch.Tensor
  ) -> torch.Tensor:
-     """Crop bounding boxes from a batch of images.
+     """Crop bounding boxes from a batch of images using fast tensor indexing.
+
+     This uses tensor unfold operations to extract patches, which is significantly
+     faster than kornia's crop_and_resize (17-51x speedup) as it avoids perspective
+     transform computations.

      Args:
          images: Tensor of shape (samples, channels, height, width) of a batch of images.
@@ -27,7 +30,7 @@ def crop_bboxes(
              box should be cropped from.

      Returns:
-         A tensor of shape (n_bboxes, crop_height, crop_width, channels) of the same
+         A tensor of shape (n_bboxes, channels, crop_height, crop_width) of the same
          dtype as the input image. The crop size is inferred from the bounding box
          coordinates.

@@ -42,26 +45,53 @@ def crop_bboxes(

      See also: `make_centered_bboxes`
      """
+     n_crops = bboxes.shape[0]
+     if n_crops == 0:
+         # Return empty tensor; use default crop size since we can't infer from bboxes
+         return torch.empty(
+             0, images.shape[1], 0, 0, device=images.device, dtype=images.dtype
+         )
+
      # Compute bounding box size to use for crops.
-     height = abs(bboxes[0, 3, 1] - bboxes[0, 0, 1])
-     width = abs(bboxes[0, 1, 0] - bboxes[0, 0, 0])
-     box_size = tuple(torch.round(torch.Tensor((height + 1, width + 1))).to(torch.int32))
+     height = int(abs(bboxes[0, 3, 1] - bboxes[0, 0, 1]).item()) + 1
+     width = int(abs(bboxes[0, 1, 0] - bboxes[0, 0, 0]).item()) + 1

      # Store original dtype for conversion back after cropping.
      original_dtype = images.dtype
+     device = images.device
+     n_samples, channels, img_h, img_w = images.shape
+     half_h, half_w = height // 2, width // 2

-     # Kornia's crop_and_resize requires float32 input.
-     images_to_crop = images[sample_inds]
-     if not torch.is_floating_point(images_to_crop):
-         images_to_crop = images_to_crop.float()
-
-     # Crop.
-     crops = crop_and_resize(
-         images_to_crop,  # (n_boxes, channels, height, width)
-         boxes=bboxes,
-         size=box_size,
+     # Pad images for edge handling.
+     images_padded = F.pad(
+         images.float(), (half_w, half_w, half_h, half_h), mode="constant", value=0
      )

+     # Extract all possible patches using unfold (creates a view, no copy).
+     # Shape after unfold: (n_samples, channels, img_h, img_w, height, width)
+     patches = images_padded.unfold(2, height, 1).unfold(3, width, 1)
+
+     # Get crop centers from bboxes.
+     # The bbox top-left is at index 0, with (x, y) coordinates.
+     # We need the center of the crop (peak location), which is top-left + half_size.
+     # Ensure bboxes are on the same device as images for index computation.
+     bboxes_on_device = bboxes.to(device)
+     crop_x = (bboxes_on_device[:, 0, 0] + half_w).to(torch.long)
+     crop_y = (bboxes_on_device[:, 0, 1] + half_h).to(torch.long)
+
+     # Clamp indices to valid bounds to handle edge cases where centroids
+     # might be at or beyond image boundaries.
+     crop_x = torch.clamp(crop_x, 0, patches.shape[3] - 1)
+     crop_y = torch.clamp(crop_y, 0, patches.shape[2] - 1)
+
+     # Select crops using advanced indexing.
+     # Convert sample_inds to tensor if it's a list.
+     if not isinstance(sample_inds, torch.Tensor):
+         sample_inds = torch.tensor(sample_inds, device=device)
+     sample_inds_long = sample_inds.to(device=device, dtype=torch.long)
+     crops = patches[sample_inds_long, :, crop_y, crop_x]
+     # Shape: (n_crops, channels, height, width)
+
      # Cast back to original dtype and return.
      crops = crops.to(original_dtype)
      return crops
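The new crop_bboxes path replaces kornia's crop_and_resize with padding plus a double unfold, so each crop becomes a single index into a view of sliding windows rather than a perspective transform. A minimal standalone sketch of that trick, with toy sizes chosen here for illustration, checking an unfolded patch against a direct slice:

# Minimal sketch of the pad + unfold + index cropping trick (toy sizes, assumed names).
import torch
import torch.nn.functional as F

images = torch.arange(2 * 1 * 16 * 16, dtype=torch.float32).reshape(2, 1, 16, 16)
crop_h = crop_w = 5
half_h, half_w = crop_h // 2, crop_w // 2

# Pad so that crops centered near the border stay in bounds.
padded = F.pad(images, (half_w, half_w, half_h, half_h))

# unfold twice -> view of shape (2, 1, 16, 16, 5, 5); no data copy is made.
patches = padded.unfold(2, crop_h, 1).unfold(3, crop_w, 1)

# One crop centered at (y=7, x=3) of sample 1: index the patch grid directly.
y, x = 7, 3
crop = patches[1, :, y, x]  # (channels, crop_h, crop_w)

# Equivalent direct slice on the padded image.
direct = padded[1, :, y : y + crop_h, x : x + crop_w]
assert torch.equal(crop, direct)

Because unfold only produces a strided view, the cost of "extracting all possible patches" is an indexing bookkeeping step, which is where the reported speedup over crop_and_resize comes from.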
sleap_nn/inference/postprocessing.py (new file)
@@ -0,0 +1,284 @@
+ """Inference-level postprocessing filters for pose predictions.
+
+ This module provides filters that run after model inference but before tracking.
+ These filters are independent of tracking configuration and can be used standalone.
+ """
+
+ from typing import List, Literal
+
+ import numpy as np
+ import sleap_io as sio
+
+
+ def filter_overlapping_instances(
+     labels: sio.Labels,
+     threshold: float = 0.8,
+     method: Literal["iou", "oks"] = "iou",
+ ) -> sio.Labels:
+     """Filter overlapping instances using greedy non-maximum suppression.
+
+     Removes duplicate/overlapping instances by applying greedy NMS based on
+     either bounding box IOU or Object Keypoint Similarity (OKS). When two
+     instances overlap above the threshold, the lower-scoring one is removed.
+
+     This filter runs independently of tracking and can be used to clean up
+     model outputs before saving or further processing.
+
+     Args:
+         labels: Labels object with predicted instances to filter.
+         threshold: Similarity threshold for considering instances as overlapping.
+             Instances with similarity > threshold are candidates for removal.
+             Lower values are more aggressive (remove more).
+             Typical values: 0.3 (aggressive) to 0.8 (permissive).
+         method: Similarity metric to use for comparing instances.
+             "iou": Bounding box intersection-over-union.
+             "oks": Object Keypoint Similarity (pose-based).
+
+     Returns:
+         The input Labels object with overlapping instances removed.
+         Modification is done in place, but the object is also returned
+         for convenience.
+
+     Example:
+         >>> # Filter instances with >80% bounding box overlap
+         >>> labels = filter_overlapping_instances(labels, threshold=0.8, method="iou")
+         >>> # Filter using OKS similarity
+         >>> labels = filter_overlapping_instances(labels, threshold=0.5, method="oks")
+
+     Note:
+         - Only affects frames with 2+ predicted instances
+         - Uses instance.score for ranking; higher scores are preferred
+         - For IOU: bounding boxes computed from non-NaN keypoints
+         - For OKS: uses standard COCO OKS formula with bbox-derived scale
+     """
+     for lf in labels.labeled_frames:
+         if len(lf.instances) <= 1:
+             continue
+
+         # Separate predicted instances (have scores) from other instances
+         predicted = []
+         other = []
+         for inst in lf.instances:
+             if isinstance(inst, sio.PredictedInstance):
+                 predicted.append(inst)
+             else:
+                 other.append(inst)
+
+         # Only filter predicted instances
+         if len(predicted) <= 1:
+             continue
+
+         # Get scores
+         scores = np.array([_instance_score(inst) for inst in predicted])
+
+         # Apply greedy NMS with selected method
+         if method == "iou":
+             bboxes = np.array([_instance_bbox(inst) for inst in predicted])
+             keep_indices = _nms_greedy_iou(bboxes, scores, threshold)
+         elif method == "oks":
+             points = [inst.numpy() for inst in predicted]
+             keep_indices = _nms_greedy_oks(points, scores, threshold)
+         else:
+             raise ValueError(f"Unknown method: {method}. Use 'iou' or 'oks'.")
+
+         # Reconstruct instance list: kept predicted + other instances
+         kept_predicted = [predicted[i] for i in keep_indices]
+         lf.instances = kept_predicted + other
+
+     return labels
+
+
+ def _instance_bbox(instance: sio.PredictedInstance) -> np.ndarray:
+     """Compute axis-aligned bounding box from instance keypoints.
+
+     Args:
+         instance: Instance with keypoints.
+
+     Returns:
+         Bounding box as [xmin, ymin, xmax, ymax].
+         Returns [0, 0, 0, 0] if no valid keypoints.
+     """
+     pts = instance.numpy()  # (n_nodes, 2)
+     valid = ~np.isnan(pts).any(axis=1)
+
+     if not valid.any():
+         return np.array([0.0, 0.0, 0.0, 0.0])
+
+     pts = pts[valid]
+     return np.array(
+         [pts[:, 0].min(), pts[:, 1].min(), pts[:, 0].max(), pts[:, 1].max()]
+     )
+
+
+ def _instance_score(instance: sio.PredictedInstance) -> float:
+     """Get instance confidence score.
+
+     Args:
+         instance: Predicted instance.
+
+     Returns:
+         Instance score, or 1.0 if not available.
+     """
+     return getattr(instance, "score", 1.0)
+
+
+ def _nms_greedy_iou(
+     bboxes: np.ndarray,
+     scores: np.ndarray,
+     threshold: float,
+ ) -> List[int]:
+     """Apply greedy NMS using bounding box IOU.
+
+     Args:
+         bboxes: Bounding boxes of shape (N, 4) as [xmin, ymin, xmax, ymax].
+         scores: Confidence scores of shape (N,).
+         threshold: IOU threshold for suppression.
+
+     Returns:
+         List of indices to keep, in order of decreasing score.
+     """
+     if len(bboxes) == 0:
+         return []
+
+     # Sort by score descending
+     order = scores.argsort()[::-1].tolist()
+
+     keep = []
+     while order:
+         # Take highest scoring remaining instance
+         i = order.pop(0)
+         keep.append(i)
+
+         if not order:
+             break
+
+         # Compute IOU with all remaining instances
+         remaining_indices = np.array(order)
+         similarities = _compute_iou_one_to_many(bboxes[i], bboxes[remaining_indices])
+
+         # Keep only instances with similarity <= threshold
+         mask = similarities <= threshold
+         order = [order[j] for j in range(len(order)) if mask[j]]
+
+     return keep
+
+
+ def _nms_greedy_oks(
+     points_list: List[np.ndarray],
+     scores: np.ndarray,
+     threshold: float,
+ ) -> List[int]:
+     """Apply greedy NMS using Object Keypoint Similarity (OKS).
+
+     Args:
+         points_list: List of keypoint arrays, each of shape (n_nodes, 2).
+         scores: Confidence scores of shape (N,).
+         threshold: OKS threshold for suppression.
+
+     Returns:
+         List of indices to keep, in order of decreasing score.
+     """
+     if len(points_list) == 0:
+         return []
+
+     # Sort by score descending
+     order = scores.argsort()[::-1].tolist()
+
+     keep = []
+     while order:
+         # Take highest scoring remaining instance
+         i = order.pop(0)
+         keep.append(i)
+
+         if not order:
+             break
+
+         # Compute OKS with all remaining instances
+         similarities = np.array(
+             [_compute_oks(points_list[i], points_list[j]) for j in order]
+         )
+
+         # Keep only instances with similarity <= threshold
+         mask = similarities <= threshold
+         order = [order[j] for j in range(len(order)) if mask[j]]
+
+     return keep
+
+
+ def _compute_iou_one_to_many(box: np.ndarray, boxes: np.ndarray) -> np.ndarray:
+     """Compute IOU between one box and multiple boxes.
+
+     Args:
+         box: Single box of shape (4,) as [xmin, ymin, xmax, ymax].
+         boxes: Multiple boxes of shape (N, 4).
+
+     Returns:
+         IOU values of shape (N,).
+     """
+     # Intersection coordinates
+     inter_xmin = np.maximum(box[0], boxes[:, 0])
+     inter_ymin = np.maximum(box[1], boxes[:, 1])
+     inter_xmax = np.minimum(box[2], boxes[:, 2])
+     inter_ymax = np.minimum(box[3], boxes[:, 3])
+
+     # Intersection area (0 if no overlap)
+     inter_w = np.maximum(0.0, inter_xmax - inter_xmin)
+     inter_h = np.maximum(0.0, inter_ymax - inter_ymin)
+     inter_area = inter_w * inter_h
+
+     # Individual areas
+     area_a = (box[2] - box[0]) * (box[3] - box[1])
+     area_b = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
+
+     # Union area
+     union_area = area_a + area_b - inter_area
+
+     # IOU (avoid division by zero)
+     return np.where(union_area > 0, inter_area / union_area, 0.0)
+
+
+ def _compute_oks(
+     points_a: np.ndarray,
+     points_b: np.ndarray,
+     kappa: float = 0.1,
+ ) -> float:
+     """Compute Object Keypoint Similarity (OKS) between two instances.
+
+     Uses a simplified OKS formula where all keypoints have equal weight
+     and scale is derived from the bounding box of the reference instance.
+
+     Args:
+         points_a: Keypoints of first instance, shape (n_nodes, 2).
+         points_b: Keypoints of second instance, shape (n_nodes, 2).
+         kappa: Per-keypoint constant controlling falloff. Default 0.1.
+
+     Returns:
+         OKS value in [0, 1]. Higher means more similar.
+     """
+     # Find valid keypoints (present in both instances)
+     valid_a = ~np.isnan(points_a).any(axis=1)
+     valid_b = ~np.isnan(points_b).any(axis=1)
+     valid = valid_a & valid_b
+
+     if not valid.any():
+         return 0.0
+
+     # Compute scale from bounding box area of instance A
+     pts_a_valid = points_a[valid_a]
+     if len(pts_a_valid) < 2:
+         return 0.0
+
+     bbox_w = pts_a_valid[:, 0].max() - pts_a_valid[:, 0].min()
+     bbox_h = pts_a_valid[:, 1].max() - pts_a_valid[:, 1].min()
+     scale_sq = bbox_w * bbox_h
+
+     if scale_sq <= 0:
+         return 0.0
+
+     # Compute squared distances for valid keypoints
+     d_sq = np.sum((points_a[valid] - points_b[valid]) ** 2, axis=1)
+
+     # OKS formula: mean of exp(-d^2 / (2 * s^2 * k^2))
+     oks_per_kpt = np.exp(-d_sq / (2 * scale_sq * kappa**2))
+
+     return float(np.mean(oks_per_kpt))
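The filter added above applies greedy NMS per frame: keep the highest-scoring instance, drop anything that overlaps it above the threshold, and repeat. The standalone toy walkthrough below (made-up boxes and scores) mirrors the IOU variant to show which instances would survive.

# Toy walkthrough of the greedy IOU NMS used by filter_overlapping_instances.
# Boxes and scores are invented for illustration only.
import numpy as np

def iou_one_to_many(box, boxes):
    # Same IOU definition as _compute_iou_one_to_many above.
    xmin = np.maximum(box[0], boxes[:, 0])
    ymin = np.maximum(box[1], boxes[:, 1])
    xmax = np.minimum(box[2], boxes[:, 2])
    ymax = np.minimum(box[3], boxes[:, 3])
    inter = np.maximum(0.0, xmax - xmin) * np.maximum(0.0, ymax - ymin)
    area_a = (box[2] - box[0]) * (box[3] - box[1])
    area_b = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    union = area_a + area_b - inter
    return np.where(union > 0, inter / union, 0.0)

# Three hypothetical instances: 0 and 1 overlap heavily, 2 is elsewhere.
bboxes = np.array([
    [10.0, 10.0, 50.0, 50.0],      # score 0.9
    [12.0, 12.0, 52.0, 52.0],      # score 0.6, IOU with box 0 ~0.82
    [100.0, 100.0, 140.0, 140.0],  # score 0.8
])
scores = np.array([0.9, 0.6, 0.8])

keep, order = [], scores.argsort()[::-1].tolist()  # order by score: [0, 2, 1]
while order:
    i = order.pop(0)
    keep.append(i)
    if order:
        ious = iou_one_to_many(bboxes[i], bboxes[np.array(order)])
        order = [o for o, iou in zip(order, ious) if iou <= 0.8]

print(keep)  # with threshold 0.8, box 1 is suppressed -> [0, 2]

The same loop with _compute_oks in place of IOU gives the "oks" method, which is more forgiving of loose bounding boxes because it compares keypoint geometry directly.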