sleap-nn 0.0.5-py3-none-any.whl → 0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
sleap_nn/export/wrappers/topdown_multiclass.py (new file)
@@ -0,0 +1,304 @@
+"""ONNX wrapper for top-down multiclass (supervised ID) models."""
+
+from typing import Dict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from sleap_nn.export.wrappers.base import BaseExportWrapper
+
+
+class TopDownMultiClassONNXWrapper(BaseExportWrapper):
+    """ONNX-exportable wrapper for top-down multiclass (supervised ID) models.
+
+    This wrapper handles models that output both confidence maps for keypoint
+    detection and class logits for identity classification. It runs on instance
+    crops (centered around detected centroids).
+
+    Expects input images as uint8 tensors in [0, 255].
+
+    Attributes:
+        model: The underlying PyTorch model (centered instance + class vectors heads).
+        output_stride: Output stride of the confmap head.
+        input_scale: Scale factor applied to input images before inference.
+        n_classes: Number of identity classes.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        output_stride: int = 2,
+        input_scale: float = 1.0,
+        n_classes: int = 2,
+    ):
+        """Initialize the wrapper.
+
+        Args:
+            model: The underlying PyTorch model.
+            output_stride: Output stride of the confidence maps.
+            input_scale: Scale factor for input images.
+            n_classes: Number of identity classes (e.g., 2 for male/female).
+        """
+        super().__init__(model)
+        self.output_stride = output_stride
+        self.input_scale = input_scale
+        self.n_classes = n_classes
+
+    def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
+        """Run top-down multiclass inference on crops.
+
+        Args:
+            image: Input image tensor of shape (batch, channels, height, width).
+                Expected to be uint8 in [0, 255].
+
+        Returns:
+            Dictionary with keys:
+                - "peaks": Predicted peak coordinates (batch, n_nodes, 2) in (x, y).
+                - "peak_vals": Peak confidence values (batch, n_nodes).
+                - "class_logits": Raw class logits (batch, n_classes).
+
+            The class assignment is done on CPU using Hungarian matching
+            via `get_class_inds_from_vectors()`.
+        """
+        # Normalize uint8 [0, 255] to float32 [0, 1]
+        image = self._normalize_uint8(image)
+
+        # Apply input scaling if needed
+        if self.input_scale != 1.0:
+            height = int(image.shape[-2] * self.input_scale)
+            width = int(image.shape[-1] * self.input_scale)
+            image = F.interpolate(
+                image, size=(height, width), mode="bilinear", align_corners=False
+            )
+
+        # Forward pass
+        out = self.model(image)
+
+        # Extract outputs
+        confmaps = self._extract_tensor(out, ["centered", "instance", "confmap"])
+        class_logits = self._extract_tensor(out, ["class", "vector"])
+
+        # Find global peaks (one per node)
+        peaks, peak_vals = self._find_global_peaks(confmaps)
+
+        # Scale peaks back to input coordinates
+        peaks = peaks * (self.output_stride / self.input_scale)
+
+        return {
+            "peaks": peaks,
+            "peak_vals": peak_vals,
+            "class_logits": class_logits,
+        }
+
+
+class TopDownMultiClassCombinedONNXWrapper(BaseExportWrapper):
+    """ONNX-exportable wrapper for combined centroid + multiclass instance models.
+
+    This wrapper combines a centroid detection model with a centered instance
+    multiclass model. It performs:
+    1. Centroid detection on full images
+    2. Cropping around each centroid using vectorized grid_sample
+    3. Instance keypoint detection + identity classification on each crop
+
+    Expects input images as uint8 tensors in [0, 255].
+    """
+
+    def __init__(
+        self,
+        centroid_model: nn.Module,
+        instance_model: nn.Module,
+        max_instances: int = 20,
+        crop_size: tuple = (192, 192),
+        centroid_output_stride: int = 4,
+        instance_output_stride: int = 2,
+        centroid_input_scale: float = 1.0,
+        instance_input_scale: float = 1.0,
+        n_nodes: int = 13,
+        n_classes: int = 2,
+    ):
+        """Initialize the combined wrapper.
+
+        Args:
+            centroid_model: Model for centroid detection.
+            instance_model: Model for instance keypoints + class prediction.
+            max_instances: Maximum number of instances to detect.
+            crop_size: Size of crops around centroids (height, width).
+            centroid_output_stride: Output stride of centroid model.
+            instance_output_stride: Output stride of instance model.
+            centroid_input_scale: Input scale for centroid model.
+            instance_input_scale: Input scale for instance model.
+            n_nodes: Number of keypoint nodes per instance.
+            n_classes: Number of identity classes.
+        """
+        super().__init__(centroid_model)  # Primary model is centroid
+        self.instance_model = instance_model
+        self.max_instances = max_instances
+        self.crop_size = crop_size
+        self.centroid_output_stride = centroid_output_stride
+        self.instance_output_stride = instance_output_stride
+        self.centroid_input_scale = centroid_input_scale
+        self.instance_input_scale = instance_input_scale
+        self.n_nodes = n_nodes
+        self.n_classes = n_classes
+
+        # Pre-compute base grid for crop extraction (same as TopDownONNXWrapper)
+        crop_h, crop_w = crop_size
+        y_crop = torch.linspace(-1, 1, crop_h, dtype=torch.float32)
+        x_crop = torch.linspace(-1, 1, crop_w, dtype=torch.float32)
+        grid_y, grid_x = torch.meshgrid(y_crop, x_crop, indexing="ij")
+        base_grid = torch.stack([grid_x, grid_y], dim=-1)
+        self.register_buffer("base_grid", base_grid, persistent=False)
+
+    def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
+        """Run combined top-down multiclass inference.
+
+        Args:
+            image: Input image tensor of shape (batch, channels, height, width).
+                Expected to be uint8 in [0, 255].
+
+        Returns:
+            Dictionary with keys:
+                - "centroids": Detected centroids (batch, max_instances, 2).
+                - "centroid_vals": Centroid confidence values (batch, max_instances).
+                - "peaks": Instance peaks (batch, max_instances, n_nodes, 2).
+                - "peak_vals": Peak values (batch, max_instances, n_nodes).
+                - "class_logits": Class logits per instance (batch, max_instances, n_classes).
+                - "instance_valid": Validity mask (batch, max_instances).
+        """
+        # Normalize input
+        image = self._normalize_uint8(image)
+        batch_size, channels, height, width = image.shape
+
+        # Apply centroid input scaling
+        scaled_image = image
+        if self.centroid_input_scale != 1.0:
+            scaled_h = int(height * self.centroid_input_scale)
+            scaled_w = int(width * self.centroid_input_scale)
+            scaled_image = F.interpolate(
+                scaled_image,
+                size=(scaled_h, scaled_w),
+                mode="bilinear",
+                align_corners=False,
+            )
+
+        # Centroid detection
+        centroid_out = self.model(scaled_image)
+        centroid_cms = self._extract_tensor(centroid_out, ["centroid", "confmap"])
+        centroids, centroid_vals, instance_valid = self._find_topk_peaks(
+            centroid_cms, self.max_instances
+        )
+        centroids = centroids * (
+            self.centroid_output_stride / self.centroid_input_scale
+        )
+
+        # Extract crops using vectorized grid_sample (same as TopDownONNXWrapper)
+        crops = self._extract_crops(image, centroids)
+        crops_flat = crops.reshape(
+            batch_size * self.max_instances,
+            channels,
+            self.crop_size[0],
+            self.crop_size[1],
+        )
+
+        # Apply instance input scaling if needed
+        if self.instance_input_scale != 1.0:
+            scaled_h = int(self.crop_size[0] * self.instance_input_scale)
+            scaled_w = int(self.crop_size[1] * self.instance_input_scale)
+            crops_flat = F.interpolate(
+                crops_flat,
+                size=(scaled_h, scaled_w),
+                mode="bilinear",
+                align_corners=False,
+            )
+
+        # Instance model forward (batch all crops)
+        instance_out = self.instance_model(crops_flat)
+        instance_cms = self._extract_tensor(
+            instance_out, ["centered", "instance", "confmap"]
+        )
+        instance_class = self._extract_tensor(instance_out, ["class", "vector"])
+
+        # Find peaks in all crops
+        crop_peaks, crop_peak_vals = self._find_global_peaks(instance_cms)
+        crop_peaks = crop_peaks * (
+            self.instance_output_stride / self.instance_input_scale
+        )
+
+        # Reshape to batch x instances x nodes x 2
+        crop_peaks = crop_peaks.reshape(batch_size, self.max_instances, self.n_nodes, 2)
+        peak_vals = crop_peak_vals.reshape(batch_size, self.max_instances, self.n_nodes)
+
+        # Reshape class logits
+        class_logits = instance_class.reshape(
+            batch_size, self.max_instances, self.n_classes
+        )
+
+        # Transform peaks from crop coordinates to full image coordinates
+        crop_offset = centroids.unsqueeze(2) - image.new_tensor(
+            [self.crop_size[1] / 2.0, self.crop_size[0] / 2.0]
+        )
+        peaks = crop_peaks + crop_offset
+
+        # Zero out invalid instances
+        invalid_mask = ~instance_valid
+        centroids = centroids.masked_fill(invalid_mask.unsqueeze(-1), 0.0)
+        centroid_vals = centroid_vals.masked_fill(invalid_mask, 0.0)
+        peaks = peaks.masked_fill(invalid_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
+        peak_vals = peak_vals.masked_fill(invalid_mask.unsqueeze(-1), 0.0)
+        class_logits = class_logits.masked_fill(invalid_mask.unsqueeze(-1), 0.0)
+
+        return {
+            "centroids": centroids,
+            "centroid_vals": centroid_vals,
+            "peaks": peaks,
+            "peak_vals": peak_vals,
+            "class_logits": class_logits,
+            "instance_valid": instance_valid,
+        }
+
+    def _extract_crops(
+        self,
+        image: torch.Tensor,
+        centroids: torch.Tensor,
+    ) -> torch.Tensor:
+        """Extract crops around centroids using grid_sample.
+
+        This is the same vectorized implementation as TopDownONNXWrapper.
+        """
+        batch_size, channels, height, width = image.shape
+        crop_h, crop_w = self.crop_size
+        n_instances = centroids.shape[1]
+
+        scale_x = crop_w / width
+        scale_y = crop_h / height
+        scale = image.new_tensor([scale_x, scale_y])
+        base_grid = self.base_grid.to(device=image.device, dtype=image.dtype)
+        scaled_grid = base_grid * scale
+
+        scaled_grid = scaled_grid.unsqueeze(0).unsqueeze(0)
+        scaled_grid = scaled_grid.expand(batch_size, n_instances, -1, -1, -1)
+
+        norm_centroids = torch.zeros_like(centroids)
+        norm_centroids[..., 0] = (centroids[..., 0] / (width - 1)) * 2 - 1
+        norm_centroids[..., 1] = (centroids[..., 1] / (height - 1)) * 2 - 1
+        offset = norm_centroids.unsqueeze(2).unsqueeze(2)
+
+        sample_grid = scaled_grid + offset
+
+        image_expanded = image.unsqueeze(1).expand(-1, n_instances, -1, -1, -1)
+        image_flat = image_expanded.reshape(
+            batch_size * n_instances, channels, height, width
+        )
+        grid_flat = sample_grid.reshape(batch_size * n_instances, crop_h, crop_w, 2)
+
+        crops_flat = F.grid_sample(
+            image_flat,
+            grid_flat,
+            mode="bilinear",
+            padding_mode="zeros",
+            align_corners=True,
+        )
+
+        crops = crops_flat.reshape(batch_size, n_instances, channels, crop_h, crop_w)
+        return crops
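The crop extraction in `TopDownMultiClassCombinedONNXWrapper._extract_crops` avoids a Python loop over instances by building one sampling grid per (batch, instance) pair and issuing a single `F.grid_sample` call. Below is a minimal, self-contained sketch of that gridding pattern; it is not part of the diff, and the toy image, centroid values, and sizes are made up for illustration.

import torch
import torch.nn.functional as F

# Toy inputs (illustrative only): one image, two centroids in (x, y) pixel coords.
batch, channels, height, width = 1, 1, 64, 64
crop_h, crop_w = 16, 16
image = torch.rand(batch, channels, height, width)
centroids = torch.tensor([[[32.0, 32.0], [10.0, 50.0]]])  # (batch, n_instances, 2)
n_instances = centroids.shape[1]

# Base grid spanning [-1, 1], shrunk so it covers only a crop_h x crop_w window.
gy, gx = torch.meshgrid(
    torch.linspace(-1, 1, crop_h), torch.linspace(-1, 1, crop_w), indexing="ij"
)
window = torch.stack([gx, gy], dim=-1) * torch.tensor([crop_w / width, crop_h / height])

# Shift each window so it is centered on its centroid in normalized coordinates.
norm_c = centroids.clone()
norm_c[..., 0] = centroids[..., 0] / (width - 1) * 2 - 1
norm_c[..., 1] = centroids[..., 1] / (height - 1) * 2 - 1
grid = window + norm_c[:, :, None, None, :]  # (batch, n_instances, crop_h, crop_w, 2)

# Flatten (batch, n_instances) so one grid_sample call crops every instance at once.
image_flat = (
    image.unsqueeze(1)
    .expand(-1, n_instances, -1, -1, -1)
    .reshape(batch * n_instances, channels, height, width)
)
grid_flat = grid.reshape(batch * n_instances, crop_h, crop_w, 2)
crops = F.grid_sample(image_flat, grid_flat, mode="bilinear", align_corners=True)
print(crops.shape)  # torch.Size([2, 1, 16, 16])

Expressing the cropping as fixed tensor operations (a precomputed buffer plus one `grid_sample`) keeps the graph free of data-dependent Python control flow, which is what makes the wrapper ONNX-exportable.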
sleap_nn/inference/__init__.py
@@ -1 +1,7 @@
 """Inference-related modules."""
+
+from sleap_nn.inference.provenance import (
+    build_inference_provenance,
+    build_tracking_only_provenance,
+    merge_provenance,
+)
sleap_nn/inference/bottomup.py
@@ -1,5 +1,6 @@
 """Inference modules for BottomUp models."""
 
+import logging
 from typing import Dict, Optional
 import torch
 import lightning as L
@@ -7,6 +8,8 @@ from sleap_nn.inference.peak_finding import find_local_peaks
 from sleap_nn.inference.paf_grouping import PAFScorer
 from sleap_nn.inference.identity import classify_peaks_from_maps
 
+logger = logging.getLogger(__name__)
+
 
 class BottomUpInferenceModel(L.LightningModule):
     """BottomUp Inference model.
@@ -63,8 +66,28 @@ class BottomUpInferenceModel(L.LightningModule):
         return_pafs: Optional[bool] = False,
         return_paf_graph: Optional[bool] = False,
         input_scale: float = 1.0,
+        max_peaks_per_node: Optional[int] = None,
     ):
-        """Initialise the model attributes."""
+        """Initialise the model attributes.
+
+        Args:
+            torch_model: A `nn.Module` that accepts images and predicts confidence maps.
+            paf_scorer: A `PAFScorer` instance for grouping instances.
+            cms_output_stride: Output stride of confidence maps relative to images.
+            pafs_output_stride: Output stride of PAFs relative to images.
+            peak_threshold: Minimum confidence map value for valid peaks.
+            refinement: Peak refinement method: None, "integral", or "local".
+            integral_patch_size: Size of patches for integral refinement.
+            return_confmaps: If True, return confidence maps in output.
+            return_pafs: If True, return PAFs in output.
+            return_paf_graph: If True, return intermediate PAF graph in output.
+            input_scale: Scale factor applied to input images.
+            max_peaks_per_node: Maximum number of peaks allowed per node before
+                skipping PAF scoring. If any node has more peaks than this limit,
+                empty predictions are returned. This prevents combinatorial explosion
+                during early training when confidence maps are noisy. Set to None to
+                disable this check (default). Recommended value: 100.
+        """
         super().__init__()
         self.torch_model = torch_model
         self.paf_scorer = paf_scorer
@@ -77,6 +100,7 @@ class BottomUpInferenceModel(L.LightningModule):
         self.return_pafs = return_pafs
         self.return_paf_graph = return_paf_graph
         self.input_scale = input_scale
+        self.max_peaks_per_node = max_peaks_per_node
 
     def _generate_cms_peaks(self, cms):
         # TODO: append nans to batch them -> tensor (vectorize the initial paf grouping steps)
@@ -124,26 +148,68 @@ class BottomUpInferenceModel(L.LightningModule):
         )  # (batch, h, w, 2*edges)
         cms_peaks, cms_peak_vals, cms_peak_channel_inds = self._generate_cms_peaks(cms)
 
-        (
-            predicted_instances,
-            predicted_peak_scores,
-            predicted_instance_scores,
-            edge_inds,
-            edge_peak_inds,
-            line_scores,
-        ) = self.paf_scorer.predict(
-            pafs=pafs,
-            peaks=cms_peaks,
-            peak_vals=cms_peak_vals,
-            peak_channel_inds=cms_peak_channel_inds,
-        )
-
-        predicted_instances = [p / self.input_scale for p in predicted_instances]
-        predicted_instances_adjusted = []
-        for idx, p in enumerate(predicted_instances):
-            predicted_instances_adjusted.append(
-                p / inputs["eff_scale"][idx].to(p.device)
+        # Check if too many peaks per node (prevents combinatorial explosion)
+        skip_paf_scoring = False
+        if self.max_peaks_per_node is not None:
+            n_nodes = cms.shape[1]
+            for b in range(self.batch_size):
+                for node_idx in range(n_nodes):
+                    n_peaks = int((cms_peak_channel_inds[b] == node_idx).sum().item())
+                    if n_peaks > self.max_peaks_per_node:
+                        logger.warning(
+                            f"Skipping PAF scoring: node {node_idx} has {n_peaks} peaks "
+                            f"(max_peaks_per_node={self.max_peaks_per_node}). "
+                            f"Model may need more training."
+                        )
+                        skip_paf_scoring = True
+                        break
+                if skip_paf_scoring:
+                    break
+
+        if skip_paf_scoring:
+            # Return empty predictions for each sample
+            device = cms.device
+            n_nodes = cms.shape[1]
+            predicted_instances_adjusted = []
+            predicted_peak_scores = []
+            predicted_instance_scores = []
+            for _ in range(self.batch_size):
+                predicted_instances_adjusted.append(
+                    torch.full((0, n_nodes, 2), float("nan"), device=device)
+                )
+                predicted_peak_scores.append(
+                    torch.full((0, n_nodes), float("nan"), device=device)
+                )
+                predicted_instance_scores.append(torch.tensor([], device=device))
+            edge_inds = [
+                torch.tensor([], dtype=torch.int32, device=device)
+            ] * self.batch_size
+            edge_peak_inds = [
+                torch.tensor([], dtype=torch.int32, device=device).reshape(0, 2)
+            ] * self.batch_size
+            line_scores = [torch.tensor([], device=device)] * self.batch_size
+        else:
+            (
+                predicted_instances,
+                predicted_peak_scores,
+                predicted_instance_scores,
+                edge_inds,
+                edge_peak_inds,
+                line_scores,
+            ) = self.paf_scorer.predict(
+                pafs=pafs,
+                peaks=cms_peaks,
+                peak_vals=cms_peak_vals,
+                peak_channel_inds=cms_peak_channel_inds,
             )
+
+            predicted_instances = [p / self.input_scale for p in predicted_instances]
+            predicted_instances_adjusted = []
+            for idx, p in enumerate(predicted_instances):
+                predicted_instances_adjusted.append(
+                    p / inputs["eff_scale"][idx].to(p.device)
+                )
+
         out = {
             "pred_instance_peaks": predicted_instances_adjusted,
             "pred_peak_values": predicted_peak_scores,
sleap_nn/inference/peak_finding.py
@@ -2,18 +2,60 @@
 
 from typing import Optional, Tuple
 
-import kornia as K
-import numpy as np
 import torch
-from kornia.geometry.transform import crop_and_resize
+import torch.nn.functional as F
 
 from sleap_nn.data.instance_cropping import make_centered_bboxes
 
 
+def morphological_dilation(image: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
+    """Apply morphological dilation using max pooling.
+
+    This is a pure PyTorch replacement for kornia.morphology.dilation.
+    For non-maximum suppression, it computes the maximum of 8 neighbors
+    (excluding the center pixel).
+
+    Args:
+        image: Input tensor of shape (B, 1, H, W).
+        kernel: Dilation kernel (3x3 expected for NMS).
+
+    Returns:
+        Dilated tensor of same shape as input.
+    """
+    # Pad the image to handle border pixels
+    padded = F.pad(image, (1, 1, 1, 1), mode="constant", value=float("-inf"))
+
+    # Extract 3x3 patches using unfold
+    # Shape: (B, 1, H, W, 3, 3)
+    patches = padded.unfold(2, 3, 1).unfold(3, 3, 1)
+
+    # Reshape to (B, 1, H, W, 9)
+    b, c, h, w, kh, kw = patches.shape
+    patches = patches.reshape(b, c, h, w, -1)
+
+    # Apply kernel mask (kernel has 0 at center, 1 elsewhere for NMS)
+    # Reshape kernel to (1, 1, 1, 1, 9)
+    kernel_flat = kernel.reshape(-1).to(patches.device)
+    kernel_mask = kernel_flat > 0
+
+    # Set non-kernel positions to -inf so they don't affect max
+    patches_masked = patches.clone()
+    patches_masked[..., ~kernel_mask] = float("-inf")
+
+    # Take max over the kernel neighborhood
+    max_vals = patches_masked.max(dim=-1)[0]
+
+    return max_vals
+
+
 def crop_bboxes(
     images: torch.Tensor, bboxes: torch.Tensor, sample_inds: torch.Tensor
 ) -> torch.Tensor:
-    """Crop bounding boxes from a batch of images.
+    """Crop bounding boxes from a batch of images using fast tensor indexing.
+
+    This uses tensor unfold operations to extract patches, which is significantly
+    faster than kornia's crop_and_resize (17-51x speedup) as it avoids perspective
+    transform computations.
 
     Args:
         images: Tensor of shape (samples, channels, height, width) of a batch of images.
@@ -27,7 +69,7 @@ def crop_bboxes(
             box should be cropped from.
 
     Returns:
-        A tensor of shape (n_bboxes, crop_height, crop_width, channels) of the same
+        A tensor of shape (n_bboxes, channels, crop_height, crop_width) of the same
         dtype as the input image. The crop size is inferred from the bounding box
         coordinates.
 
@@ -42,20 +84,55 @@ def crop_bboxes(
 
     See also: `make_centered_bboxes`
     """
+    n_crops = bboxes.shape[0]
+    if n_crops == 0:
+        # Return empty tensor; use default crop size since we can't infer from bboxes
+        return torch.empty(
+            0, images.shape[1], 0, 0, device=images.device, dtype=images.dtype
+        )
+
     # Compute bounding box size to use for crops.
-    height = abs(bboxes[0, 3, 1] - bboxes[0, 0, 1])
-    width = abs(bboxes[0, 1, 0] - bboxes[0, 0, 0])
-    box_size = tuple(torch.round(torch.Tensor((height + 1, width + 1))).to(torch.int32))
-
-    # Crop.
-    crops = crop_and_resize(
-        images[sample_inds],  # (n_boxes, channels, height, width)
-        boxes=bboxes,
-        size=box_size,
+    height = int(abs(bboxes[0, 3, 1] - bboxes[0, 0, 1]).item()) + 1
+    width = int(abs(bboxes[0, 1, 0] - bboxes[0, 0, 0]).item()) + 1
+
+    # Store original dtype for conversion back after cropping.
+    original_dtype = images.dtype
+    device = images.device
+    n_samples, channels, img_h, img_w = images.shape
+    half_h, half_w = height // 2, width // 2
+
+    # Pad images for edge handling.
+    images_padded = F.pad(
+        images.float(), (half_w, half_w, half_h, half_h), mode="constant", value=0
     )
 
+    # Extract all possible patches using unfold (creates a view, no copy).
+    # Shape after unfold: (n_samples, channels, img_h, img_w, height, width)
+    patches = images_padded.unfold(2, height, 1).unfold(3, width, 1)
+
+    # Get crop centers from bboxes.
+    # The bbox top-left is at index 0, with (x, y) coordinates.
+    # We need the center of the crop (peak location), which is top-left + half_size.
+    # Ensure bboxes are on the same device as images for index computation.
+    bboxes_on_device = bboxes.to(device)
+    crop_x = (bboxes_on_device[:, 0, 0] + half_w).to(torch.long)
+    crop_y = (bboxes_on_device[:, 0, 1] + half_h).to(torch.long)
+
+    # Clamp indices to valid bounds to handle edge cases where centroids
+    # might be at or beyond image boundaries.
+    crop_x = torch.clamp(crop_x, 0, patches.shape[3] - 1)
+    crop_y = torch.clamp(crop_y, 0, patches.shape[2] - 1)
+
+    # Select crops using advanced indexing.
+    # Convert sample_inds to tensor if it's a list.
+    if not isinstance(sample_inds, torch.Tensor):
+        sample_inds = torch.tensor(sample_inds, device=device)
+    sample_inds_long = sample_inds.to(device=device, dtype=torch.long)
+    crops = patches[sample_inds_long, :, crop_y, crop_x]
+    # Shape: (n_crops, channels, height, width)
+
     # Cast back to original dtype and return.
-    crops = crops.to(images.dtype)
+    crops = crops.to(original_dtype)
     return crops
 
 
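For reference, here is a small self-contained example of calling the reworked `crop_bboxes` (the import path assumes the 0.1.0 module layout shown in this diff). The box corners follow the layout described in the docstring above, top-left point first in (x, y) order; in the library these boxes would normally come from `make_centered_bboxes`, and all values below are made up for illustration.

import torch

from sleap_nn.inference.peak_finding import crop_bboxes

# Toy 8x8 image and two 5x5 boxes, both cropped from sample 0.
images = torch.arange(1 * 1 * 8 * 8, dtype=torch.float32).reshape(1, 1, 8, 8)
centroids = torch.tensor([[3.0, 3.0], [5.0, 4.0]])  # (x, y) peak locations

half = 2.0  # 5x5 crop -> 2 pixels on each side of the center
corners = torch.stack(
    [
        centroids + torch.tensor([-half, -half]),  # top-left
        centroids + torch.tensor([half, -half]),   # top-right
        centroids + torch.tensor([half, half]),    # bottom-right
        centroids + torch.tensor([-half, half]),   # bottom-left
    ],
    dim=1,
)  # (n_bboxes, 4, 2)

sample_inds = torch.tensor([0, 0])
crops = crop_bboxes(images, corners, sample_inds)
print(crops.shape)            # torch.Size([2, 1, 5, 5])
print(crops[0, 0, 2, 2].item())  # center pixel equals images[0, 0, 3, 3] == 27.0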
@@ -236,7 +313,7 @@ def find_local_peaks_rough(
     flat_img = cms.reshape(-1, 1, height, width)
 
     # Perform dilation filtering to find local maxima per channel and reshape back.
-    max_img = K.morphology.dilation(flat_img, kernel.to(flat_img.device))
+    max_img = morphological_dilation(flat_img, kernel.to(flat_img.device))
     max_img = max_img.reshape(-1, channels, height, width)
 
     # Filter for maxima and threshold.
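To close, a short usage sketch of the new `morphological_dilation` helper in the non-maximum suppression role its docstring describes. The kernel layout (ones with a zero center) follows the comments in the function itself; the toy confidence maps and the 0.2 threshold here are made up for illustration and are not taken from `find_local_peaks_rough`.

import torch

from sleap_nn.inference.peak_finding import morphological_dilation

cms = torch.rand(2, 1, 32, 32)  # (B, 1, H, W) confidence maps
kernel = torch.ones(3, 3)
kernel[1, 1] = 0  # exclude the center so a peak is not compared with itself

neighbor_max = morphological_dilation(cms, kernel)
# A strict local maximum is larger than all 8 of its neighbors; also threshold it.
local_max_mask = (cms > neighbor_max) & (cms > 0.2)
print(int(local_max_mask.sum()), "peaks found")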