sleap-nn 0.1.0a1__py3-none-any.whl → 0.1.0a3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. sleap_nn/__init__.py +1 -1
  2. sleap_nn/cli.py +36 -0
  3. sleap_nn/config/trainer_config.py +18 -0
  4. sleap_nn/evaluation.py +81 -22
  5. sleap_nn/export/__init__.py +21 -0
  6. sleap_nn/export/cli.py +1778 -0
  7. sleap_nn/export/exporters/__init__.py +51 -0
  8. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  9. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  10. sleap_nn/export/metadata.py +225 -0
  11. sleap_nn/export/predictors/__init__.py +63 -0
  12. sleap_nn/export/predictors/base.py +22 -0
  13. sleap_nn/export/predictors/onnx.py +154 -0
  14. sleap_nn/export/predictors/tensorrt.py +312 -0
  15. sleap_nn/export/utils.py +307 -0
  16. sleap_nn/export/wrappers/__init__.py +25 -0
  17. sleap_nn/export/wrappers/base.py +96 -0
  18. sleap_nn/export/wrappers/bottomup.py +243 -0
  19. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  20. sleap_nn/export/wrappers/centered_instance.py +56 -0
  21. sleap_nn/export/wrappers/centroid.py +58 -0
  22. sleap_nn/export/wrappers/single_instance.py +83 -0
  23. sleap_nn/export/wrappers/topdown.py +180 -0
  24. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  25. sleap_nn/inference/bottomup.py +86 -20
  26. sleap_nn/inference/postprocessing.py +284 -0
  27. sleap_nn/predict.py +29 -0
  28. sleap_nn/train.py +64 -0
  29. sleap_nn/training/callbacks.py +324 -8
  30. sleap_nn/training/lightning_modules.py +542 -32
  31. sleap_nn/training/model_trainer.py +48 -57
  32. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/METADATA +13 -2
  33. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/RECORD +37 -16
  34. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/WHEEL +0 -0
  35. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/entry_points.txt +0 -0
  36. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/licenses/LICENSE +0 -0
  37. {sleap_nn-0.1.0a1.dist-info → sleap_nn-0.1.0a3.dist-info}/top_level.txt +0 -0
sleap_nn/export/wrappers/bottomup_multiclass.py
@@ -0,0 +1,195 @@
+ """ONNX wrapper for bottom-up multiclass (supervised ID) models."""
+
+ from typing import Dict
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from sleap_nn.export.wrappers.base import BaseExportWrapper
+
+
+ class BottomUpMultiClassONNXWrapper(BaseExportWrapper):
+     """ONNX-exportable wrapper for bottom-up multiclass (supervised ID) models.
+
+     This wrapper handles models that output both confidence maps for keypoint
+     detection and class maps for identity classification. Unlike PAF-based
+     bottom-up models, multiclass models use class maps to assign identity to
+     each detected peak, then group peaks by identity.
+
+     The wrapper performs:
+     1. Peak detection in confidence maps (GPU)
+     2. Class probability sampling at peak locations (GPU)
+     3. Returns fixed-size tensors for CPU-side grouping
+
+     Expects input images as uint8 tensors in [0, 255].
+
+     Attributes:
+         model: The underlying PyTorch model.
+         n_nodes: Number of keypoint nodes in the skeleton.
+         n_classes: Number of identity classes.
+         max_peaks_per_node: Maximum number of peaks to detect per node.
+         cms_output_stride: Output stride of the confidence map head.
+         class_maps_output_stride: Output stride of the class maps head.
+         input_scale: Scale factor applied to input images before inference.
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         n_nodes: int,
+         n_classes: int = 2,
+         max_peaks_per_node: int = 20,
+         cms_output_stride: int = 4,
+         class_maps_output_stride: int = 8,
+         input_scale: float = 1.0,
+     ):
+         """Initialize the wrapper.
+
+         Args:
+             model: The underlying PyTorch model.
+             n_nodes: Number of keypoint nodes.
+             n_classes: Number of identity classes (e.g., 2 for male/female).
+             max_peaks_per_node: Maximum peaks per node to detect.
+             cms_output_stride: Output stride of confidence maps.
+             class_maps_output_stride: Output stride of class maps.
+             input_scale: Scale factor for input images.
+         """
+         super().__init__(model)
+         self.n_nodes = n_nodes
+         self.n_classes = n_classes
+         self.max_peaks_per_node = max_peaks_per_node
+         self.cms_output_stride = cms_output_stride
+         self.class_maps_output_stride = class_maps_output_stride
+         self.input_scale = input_scale
+
+     def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
+         """Run bottom-up multiclass inference.
+
+         Args:
+             image: Input image tensor of shape (batch, channels, height, width).
+                 Expected to be uint8 in [0, 255].
+
+         Returns:
+             Dictionary with keys:
+                 - "peaks": Detected peak coordinates (batch, n_nodes, max_peaks, 2).
+                   Coordinates are in input image space (x, y).
+                 - "peak_vals": Peak confidence values (batch, n_nodes, max_peaks).
+                 - "peak_mask": Boolean mask for valid peaks (batch, n_nodes, max_peaks).
+                 - "class_probs": Class probabilities at each peak location
+                   (batch, n_nodes, max_peaks, n_classes).
+
+             Postprocessing on CPU uses `classify_peaks_from_maps()` to group
+             peaks by identity using Hungarian matching.
+         """
+         # Normalize uint8 [0, 255] to float32 [0, 1]
+         image = self._normalize_uint8(image)
+
+         # Apply input scaling if needed
+         if self.input_scale != 1.0:
+             height = int(image.shape[-2] * self.input_scale)
+             width = int(image.shape[-1] * self.input_scale)
+             image = F.interpolate(
+                 image, size=(height, width), mode="bilinear", align_corners=False
+             )
+
+         batch_size = image.shape[0]
+
+         # Forward pass
+         out = self.model(image)
+
+         # Extract outputs
+         # Note: Use "classmaps" as a single hint to avoid "map" matching "confmaps"
+         confmaps = self._extract_tensor(out, ["confmap", "multiinstance"])
+         class_maps = self._extract_tensor(out, ["classmaps", "classmapshead"])
+
+         # Find top-k peaks per node
+         peaks, peak_vals, peak_mask = self._find_topk_peaks_per_node(
+             confmaps, self.max_peaks_per_node
+         )
+
+         # Scale peaks to input image space
+         peaks = peaks * self.cms_output_stride
+
+         # Sample class maps at peak locations
+         class_probs = self._sample_class_maps_at_peaks(class_maps, peaks, peak_mask)
+
+         # Scale peaks for output (accounting for input scale)
+         if self.input_scale != 1.0:
+             peaks = peaks / self.input_scale
+
+         return {
+             "peaks": peaks,
+             "peak_vals": peak_vals,
+             "peak_mask": peak_mask,
+             "class_probs": class_probs,
+         }
+
+     def _sample_class_maps_at_peaks(
+         self,
+         class_maps: torch.Tensor,
+         peaks: torch.Tensor,
+         peak_mask: torch.Tensor,
+     ) -> torch.Tensor:
+         """Sample class map values at peak locations.
+
+         Args:
+             class_maps: Class maps of shape (batch, n_classes, height, width).
+             peaks: Peak coordinates in cms_output_stride space,
+                 shape (batch, n_nodes, max_peaks, 2) in (x, y) order.
+             peak_mask: Boolean mask for valid peaks (batch, n_nodes, max_peaks).
+
+         Returns:
+             Class probabilities at each peak location,
+             shape (batch, n_nodes, max_peaks, n_classes).
+         """
+         batch_size, n_classes, cm_height, cm_width = class_maps.shape
+         _, n_nodes, max_peaks, _ = peaks.shape
+         device = peaks.device
+
+         # Initialize output tensor
+         class_probs = torch.zeros(
+             (batch_size, n_nodes, max_peaks, n_classes),
+             device=device,
+             dtype=class_maps.dtype,
+         )
+
+         # Convert peak coordinates to class map space
+         # peaks are in full image space (after cms_output_stride scaling)
+         peaks_cm = peaks / self.class_maps_output_stride
+
+         # Clamp coordinates to valid range
+         peaks_cm_x = peaks_cm[..., 0].clamp(0, cm_width - 1)
+         peaks_cm_y = peaks_cm[..., 1].clamp(0, cm_height - 1)
+
+         # Use grid_sample for bilinear interpolation
+         # Normalize coordinates to [-1, 1] for grid_sample
+         grid_x = (peaks_cm_x / (cm_width - 1)) * 2 - 1
+         grid_y = (peaks_cm_y / (cm_height - 1)) * 2 - 1
+
+         # Reshape for grid_sample: (batch, n_nodes * max_peaks, 1, 2)
+         grid = torch.stack([grid_x, grid_y], dim=-1)
+         grid_flat = grid.reshape(batch_size, n_nodes * max_peaks, 1, 2)
+
+         # Sample class maps: (batch, n_classes, n_nodes * max_peaks, 1)
+         sampled = F.grid_sample(
+             class_maps,
+             grid_flat,
+             mode="bilinear",
+             padding_mode="zeros",
+             align_corners=True,
+         )
+
+         # Reshape to (batch, n_nodes, max_peaks, n_classes)
+         sampled = sampled.squeeze(-1)  # (batch, n_classes, n_nodes * max_peaks)
+         sampled = sampled.permute(0, 2, 1)  # (batch, n_nodes * max_peaks, n_classes)
+         sampled = sampled.reshape(batch_size, n_nodes, max_peaks, n_classes)
+
+         # Apply softmax to get probabilities (optional - depends on training)
+         # For now, return raw values as the grouping function expects logits
+         class_probs = sampled
+
+         # Mask invalid peaks
+         class_probs = class_probs * peak_mask.unsqueeze(-1).float()
+
+         return class_probs
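For reference, a minimal sketch of tracing this wrapper to ONNX with torch.onnx.export. The packaged export path lives in sleap_nn/export/exporters/onnx_exporter.py; the model, input shape, file name, and opset below are placeholder assumptions, not the package's actual defaults.

import torch

# `model` and `n_nodes` are placeholders for a trained bottom-up multiclass
# backbone and its skeleton size.
wrapper = BottomUpMultiClassONNXWrapper(model, n_nodes=n_nodes, n_classes=2).eval()
dummy = torch.zeros(1, 1, 512, 512, dtype=torch.uint8)  # uint8 frame in [0, 255]
torch.onnx.export(
    wrapper,
    dummy,
    "bottomup_multiclass.onnx",
    input_names=["image"],
    output_names=["peaks", "peak_vals", "peak_mask", "class_probs"],
    dynamic_axes={"image": {0: "batch"}},
    opset_version=17,
)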

sleap_nn/export/wrappers/centered_instance.py
@@ -0,0 +1,56 @@
+ """Centered-instance ONNX wrapper."""
+
+ from __future__ import annotations
+
+ from typing import Dict
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from sleap_nn.export.wrappers.base import BaseExportWrapper
+
+
+ class CenteredInstanceONNXWrapper(BaseExportWrapper):
+     """ONNX-exportable wrapper for centered-instance models.
+
+     Expects input images as uint8 tensors in [0, 255].
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         output_stride: int = 4,
+         input_scale: float = 1.0,
+     ):
+         """Initialize centered instance ONNX wrapper.
+
+         Args:
+             model: Centered instance model for pose estimation.
+             output_stride: Output stride for confidence maps.
+             input_scale: Input scaling factor.
+         """
+         super().__init__(model)
+         self.output_stride = output_stride
+         self.input_scale = input_scale
+
+     def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
+         """Run centered-instance inference on crops."""
+         image = self._normalize_uint8(image)
+         if self.input_scale != 1.0:
+             height = int(image.shape[-2] * self.input_scale)
+             width = int(image.shape[-1] * self.input_scale)
+             image = F.interpolate(
+                 image, size=(height, width), mode="bilinear", align_corners=False
+             )
+
+         confmaps = self._extract_tensor(
+             self.model(image), ["centered", "instance", "confmap"]
+         )
+         peaks, values = self._find_global_peaks(confmaps)
+         peaks = peaks * (self.output_stride / self.input_scale)
+
+         return {
+             "peaks": peaks,
+             "peak_vals": values,
+         }
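To illustrate the output-stride scaling above, a minimal eager-mode call; `instance_model` and the crop batch are placeholders, not values from the package.

import torch

# `instance_model` stands in for a trained centered-instance backbone.
wrapper = CenteredInstanceONNXWrapper(instance_model, output_stride=4, input_scale=1.0).eval()
crops = torch.zeros(8, 1, 192, 192, dtype=torch.uint8)  # one crop per detected instance
with torch.no_grad():
    out = wrapper(crops)
# out["peaks"]: (8, n_nodes, 2) crop-space (x, y); out["peak_vals"]: (8, n_nodes).
# With output_stride=4 and input_scale=1.0, a confmap argmax at (24, 30)
# maps to (96, 120) in crop pixels.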

sleap_nn/export/wrappers/centroid.py
@@ -0,0 +1,58 @@
+ """Centroid ONNX wrapper."""
+
+ from __future__ import annotations
+
+ from typing import Dict
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from sleap_nn.export.wrappers.base import BaseExportWrapper
+
+
+ class CentroidONNXWrapper(BaseExportWrapper):
+     """ONNX-exportable wrapper for centroid models.
+
+     Expects input images as uint8 tensors in [0, 255].
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         max_instances: int = 20,
+         output_stride: int = 2,
+         input_scale: float = 1.0,
+     ):
+         """Initialize centroid ONNX wrapper.
+
+         Args:
+             model: Centroid detection model.
+             max_instances: Maximum number of instances to detect.
+             output_stride: Output stride for confidence maps.
+             input_scale: Input scaling factor.
+         """
+         super().__init__(model)
+         self.max_instances = max_instances
+         self.output_stride = output_stride
+         self.input_scale = input_scale
+
+     def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
+         """Run centroid inference and return fixed-size outputs."""
+         image = self._normalize_uint8(image)
+         if self.input_scale != 1.0:
+             height = int(image.shape[-2] * self.input_scale)
+             width = int(image.shape[-1] * self.input_scale)
+             image = F.interpolate(
+                 image, size=(height, width), mode="bilinear", align_corners=False
+             )
+
+         confmaps = self._extract_tensor(self.model(image), ["centroid", "confmap"])
+         peaks, values, valid = self._find_topk_peaks(confmaps, self.max_instances)
+         peaks = peaks * (self.output_stride / self.input_scale)
+
+         return {
+             "centroids": peaks,
+             "centroid_vals": values,
+             "instance_valid": valid,
+         }
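A sketch of consuming the fixed-size centroid outputs downstream, assuming `instance_valid` is a boolean mask (as its use in the top-down wrapper below suggests); `centroid_model` and the frame shapes are placeholders.

import torch

wrapper = CentroidONNXWrapper(centroid_model, max_instances=20).eval()
frames = torch.zeros(2, 1, 1024, 1024, dtype=torch.uint8)
with torch.no_grad():
    out = wrapper(frames)
valid = out["instance_valid"]               # (2, 20) mask of real detections
kept = out["centroids"][0][valid[0]]        # (n_valid, 2) image-space (x, y)
scores = out["centroid_vals"][0][valid[0]]  # (n_valid,) confidences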

sleap_nn/export/wrappers/single_instance.py
@@ -0,0 +1,83 @@
+ """Single-instance ONNX wrapper."""
+
+ from __future__ import annotations
+
+ from typing import Dict
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from sleap_nn.export.wrappers.base import BaseExportWrapper
+
+
+ class SingleInstanceONNXWrapper(BaseExportWrapper):
+     """ONNX-exportable wrapper for single-instance models.
+
+     This wrapper handles full-frame inference assuming a single instance per frame.
+     For each body part (channel), it finds the global maximum in the confidence map.
+
+     Expects input images as uint8 tensors in [0, 255].
+
+     Attributes:
+         model: The trained backbone model that outputs confidence maps.
+         output_stride: Output stride of the model (e.g., 4 means confmaps are 1/4 the
+             input resolution).
+         input_scale: Factor to scale input images before inference.
+     """
+
+     def __init__(
+         self,
+         model: nn.Module,
+         output_stride: int = 4,
+         input_scale: float = 1.0,
+     ):
+         """Initialize the single-instance wrapper.
+
+         Args:
+             model: The trained backbone model.
+             output_stride: Output stride of the model. Default: 4.
+             input_scale: Factor to scale input images. Default: 1.0.
+         """
+         super().__init__(model)
+         self.output_stride = output_stride
+         self.input_scale = input_scale
+
+     def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
+         """Run single-instance inference.
+
+         Args:
+             image: Input image tensor of shape (batch, channels, height, width).
+                 Expected as uint8 [0, 255] values.
+
+         Returns:
+             Dictionary with:
+                 peaks: Peak coordinates of shape (batch, n_nodes, 2) in (x, y) format.
+                 peak_vals: Peak confidence values of shape (batch, n_nodes).
+         """
+         # Normalize uint8 [0, 255] to float32 [0, 1]
+         image = self._normalize_uint8(image)
+
+         # Apply input scaling if needed
+         if self.input_scale != 1.0:
+             height = int(image.shape[-2] * self.input_scale)
+             width = int(image.shape[-1] * self.input_scale)
+             image = F.interpolate(
+                 image, size=(height, width), mode="bilinear", align_corners=False
+             )
+
+         # Run model to get confidence maps: (batch, n_nodes, height, width)
+         confmaps = self._extract_tensor(
+             self.model(image), ["single", "instance", "confmap"]
+         )
+
+         # Find global peak for each channel: (batch, n_nodes, 2), (batch, n_nodes)
+         peaks, values = self._find_global_peaks(confmaps)
+
+         # Scale peaks from confmap coordinates to image coordinates
+         peaks = peaks * (self.output_stride / self.input_scale)
+
+         return {
+             "peaks": peaks,
+             "peak_vals": values,
+         }
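A sketch of driving an exported single-instance model directly with ONNX Runtime. The packaged predictor lives in sleap_nn/export/predictors/onnx.py; the file name, the "image" input name, and the output ordering below are assumptions about how the model was exported.

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("single_instance.onnx", providers=["CPUExecutionProvider"])
frame = np.zeros((1, 1, 384, 384), dtype=np.uint8)    # uint8 frame in [0, 255]
peaks, peak_vals = sess.run(None, {"image": frame})   # (1, n_nodes, 2), (1, n_nodes)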

sleap_nn/export/wrappers/topdown.py
@@ -0,0 +1,180 @@
+ """Top-down ONNX wrapper."""
+
+ from __future__ import annotations
+
+ from typing import Dict, Tuple
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+ from sleap_nn.export.wrappers.base import BaseExportWrapper
+
+
+ class TopDownONNXWrapper(BaseExportWrapper):
+     """ONNX-exportable wrapper for top-down (centroid + centered-instance) inference.
+
+     Expects input images as uint8 tensors in [0, 255].
+     """
+
+     def __init__(
+         self,
+         centroid_model: nn.Module,
+         instance_model: nn.Module,
+         max_instances: int = 20,
+         crop_size: Tuple[int, int] = (192, 192),
+         centroid_output_stride: int = 2,
+         instance_output_stride: int = 4,
+         centroid_input_scale: float = 1.0,
+         instance_input_scale: float = 1.0,
+         n_nodes: int = 1,
+     ) -> None:
+         """Initialize top-down ONNX wrapper.
+
+         Args:
+             centroid_model: Centroid detection model.
+             instance_model: Instance pose estimation model.
+             max_instances: Maximum number of instances to detect.
+             crop_size: Size of instance crops (height, width).
+             centroid_output_stride: Centroid model output stride.
+             instance_output_stride: Instance model output stride.
+             centroid_input_scale: Centroid input scaling factor.
+             instance_input_scale: Instance input scaling factor.
+             n_nodes: Number of skeleton nodes.
+         """
+         super().__init__(centroid_model)
+         self.centroid_model = centroid_model
+         self.instance_model = instance_model
+         self.max_instances = max_instances
+         self.crop_size = crop_size
+         self.centroid_output_stride = centroid_output_stride
+         self.instance_output_stride = instance_output_stride
+         self.centroid_input_scale = centroid_input_scale
+         self.instance_input_scale = instance_input_scale
+         self.n_nodes = n_nodes
+
+         crop_h, crop_w = crop_size
+         y_crop = torch.linspace(-1, 1, crop_h, dtype=torch.float32)
+         x_crop = torch.linspace(-1, 1, crop_w, dtype=torch.float32)
+         grid_y, grid_x = torch.meshgrid(y_crop, x_crop, indexing="ij")
+         base_grid = torch.stack([grid_x, grid_y], dim=-1)
+         self.register_buffer("base_grid", base_grid, persistent=False)
+
+     def forward(self, image: torch.Tensor) -> Dict[str, torch.Tensor]:
+         """Run top-down inference and return fixed-size outputs."""
+         image = self._normalize_uint8(image)
+         batch_size, channels, height, width = image.shape
+
+         scaled_image = image
+         if self.centroid_input_scale != 1.0:
+             scaled_h = int(height * self.centroid_input_scale)
+             scaled_w = int(width * self.centroid_input_scale)
+             scaled_image = F.interpolate(
+                 scaled_image,
+                 size=(scaled_h, scaled_w),
+                 mode="bilinear",
+                 align_corners=False,
+             )
+
+         centroid_out = self.centroid_model(scaled_image)
+         centroid_cms = self._extract_tensor(centroid_out, ["centroid", "confmap"])
+
+         centroids, centroid_vals, instance_valid = self._find_topk_peaks(
+             centroid_cms, self.max_instances
+         )
+         centroids = centroids * (
+             self.centroid_output_stride / self.centroid_input_scale
+         )
+
+         crops = self._extract_crops(image, centroids)
+         crops_flat = crops.reshape(
+             batch_size * self.max_instances,
+             channels,
+             self.crop_size[0],
+             self.crop_size[1],
+         )
+
+         if self.instance_input_scale != 1.0:
+             scaled_h = int(self.crop_size[0] * self.instance_input_scale)
+             scaled_w = int(self.crop_size[1] * self.instance_input_scale)
+             crops_flat = F.interpolate(
+                 crops_flat,
+                 size=(scaled_h, scaled_w),
+                 mode="bilinear",
+                 align_corners=False,
+             )
+
+         instance_out = self.instance_model(crops_flat)
+         instance_cms = self._extract_tensor(
+             instance_out, ["centered", "instance", "confmap"]
+         )
+
+         crop_peaks, crop_peak_vals = self._find_global_peaks(instance_cms)
+         crop_peaks = crop_peaks * (
+             self.instance_output_stride / self.instance_input_scale
+         )
+
+         crop_peaks = crop_peaks.reshape(batch_size, self.max_instances, self.n_nodes, 2)
+         peak_vals = crop_peak_vals.reshape(batch_size, self.max_instances, self.n_nodes)
+
+         crop_offset = centroids.unsqueeze(2) - image.new_tensor(
+             [self.crop_size[1] / 2.0, self.crop_size[0] / 2.0]
+         )
+         peaks = crop_peaks + crop_offset
+
+         invalid_mask = ~instance_valid
+         centroids = centroids.masked_fill(invalid_mask.unsqueeze(-1), 0.0)
+         centroid_vals = centroid_vals.masked_fill(invalid_mask, 0.0)
+         peaks = peaks.masked_fill(invalid_mask.unsqueeze(-1).unsqueeze(-1), 0.0)
+         peak_vals = peak_vals.masked_fill(invalid_mask.unsqueeze(-1), 0.0)
+
+         return {
+             "centroids": centroids,
+             "centroid_vals": centroid_vals,
+             "peaks": peaks,
+             "peak_vals": peak_vals,
+             "instance_valid": instance_valid,
+         }
+
+     def _extract_crops(
+         self,
+         image: torch.Tensor,
+         centroids: torch.Tensor,
+     ) -> torch.Tensor:
+         """Extract crops around centroids using grid_sample."""
+         batch_size, channels, height, width = image.shape
+         crop_h, crop_w = self.crop_size
+         n_instances = centroids.shape[1]
+
+         scale_x = crop_w / width
+         scale_y = crop_h / height
+         scale = image.new_tensor([scale_x, scale_y])
+         base_grid = self.base_grid.to(device=image.device, dtype=image.dtype)
+         scaled_grid = base_grid * scale
+
+         scaled_grid = scaled_grid.unsqueeze(0).unsqueeze(0)
+         scaled_grid = scaled_grid.expand(batch_size, n_instances, -1, -1, -1)
+
+         norm_centroids = torch.zeros_like(centroids)
+         norm_centroids[..., 0] = (centroids[..., 0] / (width - 1)) * 2 - 1
+         norm_centroids[..., 1] = (centroids[..., 1] / (height - 1)) * 2 - 1
+         offset = norm_centroids.unsqueeze(2).unsqueeze(2)
+
+         sample_grid = scaled_grid + offset
+
+         image_expanded = image.unsqueeze(1).expand(-1, n_instances, -1, -1, -1)
+         image_flat = image_expanded.reshape(
+             batch_size * n_instances, channels, height, width
+         )
+         grid_flat = sample_grid.reshape(batch_size * n_instances, crop_h, crop_w, 2)
+
+         crops_flat = F.grid_sample(
+             image_flat,
+             grid_flat,
+             mode="bilinear",
+             padding_mode="zeros",
+             align_corners=True,
+         )
+
+         crops = crops_flat.reshape(batch_size, n_instances, channels, crop_h, crop_w)
+         return crops
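For orientation, a sketch of composing the two stages eagerly; `centroid_model`, `instance_model`, and the sizes below are placeholders that would have to match the training configs.

import torch

wrapper = TopDownONNXWrapper(
    centroid_model=centroid_model,
    instance_model=instance_model,
    max_instances=8,
    crop_size=(192, 192),
    centroid_output_stride=2,
    instance_output_stride=4,
    n_nodes=13,
).eval()
with torch.no_grad():
    out = wrapper(torch.zeros(1, 1, 1024, 1024, dtype=torch.uint8))
# out["peaks"]: (1, 8, 13, 2) image-space coordinates, zeroed wherever
# out["instance_valid"] is False.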