sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61):
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,15 @@
1
- """This module implements data pipeline blocks for augmentation operations."""
1
+ """This module implements data pipeline blocks for augmentation operations.
2
2
 
3
- from typing import Any, Dict, Optional, Tuple, Union
4
- import kornia as K
3
+ Uses Skia (skia-python) for ~1.5x faster augmentation compared to Kornia.
4
+ """
5
+
6
+ from typing import Optional, Tuple
5
7
  import torch
6
- from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
7
- from kornia.augmentation.container import AugmentationSequential
8
- from kornia.augmentation.utils.param_validation import _range_bound
9
- from kornia.core import Tensor
8
+
9
+ from sleap_nn.data.skia_augmentation import (
10
+ apply_intensity_augmentation_skia,
11
+ apply_geometric_augmentation_skia,
12
+ )
10
13
 
11
14
 
12
15
  def apply_intensity_augmentation(
@@ -24,8 +27,8 @@ def apply_intensity_augmentation(
24
27
  brightness_min: Optional[float] = 1.0,
25
28
  brightness_max: Optional[float] = 1.0,
26
29
  brightness_p: float = 0.0,
27
- ) -> Tuple[torch.Tensor]:
28
- """Apply kornia intensity augmentation on image and instances.
30
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
31
+ """Apply intensity augmentation on image and instances.
29
32
 
30
33
  Args:
31
34
  image: Input image. Shape: (n_samples, C, H, W)
@@ -46,76 +49,36 @@ def apply_intensity_augmentation(
46
49
  Returns:
47
50
  Returns tuple: (image, instances) with augmentation applied.
48
51
  """
49
- aug_stack = []
50
- if uniform_noise_p > 0:
51
- aug_stack.append(
52
- RandomUniformNoise(
53
- noise=(uniform_noise_min, uniform_noise_max),
54
- p=uniform_noise_p,
55
- keepdim=True,
56
- same_on_batch=True,
57
- )
58
- )
59
- if gaussian_noise_p > 0:
60
- aug_stack.append(
61
- K.augmentation.RandomGaussianNoise(
62
- mean=gaussian_noise_mean,
63
- std=gaussian_noise_std,
64
- p=gaussian_noise_p,
65
- keepdim=True,
66
- same_on_batch=True,
67
- )
68
- )
69
- if contrast_p > 0:
70
- aug_stack.append(
71
- K.augmentation.RandomContrast(
72
- contrast=(contrast_min, contrast_max),
73
- p=contrast_p,
74
- keepdim=True,
75
- same_on_batch=True,
76
- )
77
- )
78
- if brightness_p > 0:
79
- aug_stack.append(
80
- K.augmentation.RandomBrightness(
81
- brightness=(brightness_min, brightness_max),
82
- p=brightness_p,
83
- keepdim=True,
84
- same_on_batch=True,
85
- )
86
- )
87
-
88
- augmenter = AugmentationSequential(
89
- *aug_stack,
90
- data_keys=["input", "keypoints"],
91
- keepdim=True,
92
- same_on_batch=True,
52
+ return apply_intensity_augmentation_skia(
53
+ image=image,
54
+ instances=instances,
55
+ uniform_noise_min=uniform_noise_min,
56
+ uniform_noise_max=uniform_noise_max,
57
+ uniform_noise_p=uniform_noise_p,
58
+ gaussian_noise_mean=gaussian_noise_mean,
59
+ gaussian_noise_std=gaussian_noise_std,
60
+ gaussian_noise_p=gaussian_noise_p,
61
+ contrast_min=contrast_min,
62
+ contrast_max=contrast_max,
63
+ contrast_p=contrast_p,
64
+ brightness_min=brightness_min,
65
+ brightness_max=brightness_max,
66
+ brightness_p=brightness_p,
93
67
  )
94
68
 
95
- inst_shape = instances.shape
96
- # Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
97
- # or
98
- # Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2)
99
- instances = instances.reshape(inst_shape[0], -1, 2)
100
- # (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2)
101
-
102
- aug_image, aug_instances = augmenter(image, instances)
103
-
104
- # After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
105
- # or
106
- # After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
107
- return aug_image, aug_instances.reshape(*inst_shape)
108
-
109
69
 
110
70
  def apply_geometric_augmentation(
111
71
  image: torch.Tensor,
112
72
  instances: torch.Tensor,
113
73
  rotation_min: Optional[float] = -15.0,
114
74
  rotation_max: Optional[float] = 15.0,
75
+ rotation_p: Optional[float] = None,
115
76
  scale_min: Optional[float] = 0.9,
116
77
  scale_max: Optional[float] = 1.1,
78
+ scale_p: Optional[float] = None,
117
79
  translate_width: Optional[float] = 0.02,
118
80
  translate_height: Optional[float] = 0.02,
81
+ translate_p: Optional[float] = None,
119
82
  affine_p: float = 0.0,
120
83
  erase_scale_min: Optional[float] = 0.0001,
121
84
  erase_scale_max: Optional[float] = 0.01,
@@ -125,19 +88,26 @@ def apply_geometric_augmentation(
125
88
  mixup_lambda_min: Optional[float] = 0.01,
126
89
  mixup_lambda_max: Optional[float] = 0.05,
127
90
  mixup_p: float = 0.0,
128
- ) -> Tuple[torch.Tensor]:
129
- """Apply kornia geometric augmentation on image and instances.
91
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
92
+ """Apply geometric augmentation on image and instances.
130
93
 
131
94
  Args:
132
95
  image: Input image. Shape: (n_samples, C, H, W)
133
96
  instances: Input keypoints. (n_samples, n_instances, n_nodes, 2) or (n_samples, n_nodes, 2)
134
97
  rotation_min: Minimum rotation angle in degrees. Default: -15.0.
135
98
  rotation_max: Maximum rotation angle in degrees. Default: 15.0.
99
+ rotation_p: Probability of applying random rotation independently. If None,
100
+ falls back to affine_p for bundled behavior. Default: None.
136
101
  scale_min: Minimum scaling factor for isotropic scaling. Default: 0.9.
137
102
  scale_max: Maximum scaling factor for isotropic scaling. Default: 1.1.
103
+ scale_p: Probability of applying random scaling independently. If None,
104
+ falls back to affine_p for bundled behavior. Default: None.
138
105
  translate_width: Maximum absolute fraction for horizontal translation. Default: 0.02.
139
106
  translate_height: Maximum absolute fraction for vertical translation. Default: 0.02.
140
- affine_p: Probability of applying random affine transformations. Default: 0.0.
107
+ translate_p: Probability of applying random translation independently. If None,
108
+ falls back to affine_p for bundled behavior. Default: None.
109
+ affine_p: Probability of applying random affine transformations (rotation, scale,
110
+ translate bundled). Used when individual *_p params are None. Default: 0.0.
141
111
  erase_scale_min: Minimum value of range of proportion of erased area against input image. Default: 0.0001.
142
112
  erase_scale_max: Maximum value of range of proportion of erased area against input image. Default: 0.01.
143
113
  erase_ratio_min: Minimum value of range of aspect ratio of erased area. Default: 1.
@@ -150,134 +120,25 @@ def apply_geometric_augmentation(
150
120
  Returns:
151
121
  Returns tuple: (image, instances) with augmentation applied.
152
122
  """
153
- aug_stack = []
154
- if affine_p > 0:
155
- aug_stack.append(
156
- K.augmentation.RandomAffine(
157
- degrees=(rotation_min, rotation_max),
158
- translate=(translate_width, translate_height),
159
- scale=(scale_min, scale_max),
160
- p=affine_p,
161
- keepdim=True,
162
- same_on_batch=True,
163
- )
164
- )
165
-
166
- if erase_p > 0:
167
- aug_stack.append(
168
- K.augmentation.RandomErasing(
169
- scale=(erase_scale_min, erase_scale_max),
170
- ratio=(erase_ratio_min, erase_ratio_max),
171
- p=erase_p,
172
- keepdim=True,
173
- same_on_batch=True,
174
- )
175
- )
176
- if mixup_p > 0:
177
- aug_stack.append(
178
- K.augmentation.RandomMixUpV2(
179
- lambda_val=(mixup_lambda_min, mixup_lambda_max),
180
- p=mixup_p,
181
- keepdim=True,
182
- same_on_batch=True,
183
- )
184
- )
185
-
186
- augmenter = AugmentationSequential(
187
- *aug_stack,
188
- data_keys=["input", "keypoints"],
189
- keepdim=True,
190
- same_on_batch=True,
123
+ return apply_geometric_augmentation_skia(
124
+ image=image,
125
+ instances=instances,
126
+ rotation_min=rotation_min,
127
+ rotation_max=rotation_max,
128
+ rotation_p=rotation_p,
129
+ scale_min=scale_min,
130
+ scale_max=scale_max,
131
+ scale_p=scale_p,
132
+ translate_width=translate_width,
133
+ translate_height=translate_height,
134
+ translate_p=translate_p,
135
+ affine_p=affine_p,
136
+ erase_scale_min=erase_scale_min,
137
+ erase_scale_max=erase_scale_max,
138
+ erase_ratio_min=erase_ratio_min,
139
+ erase_ratio_max=erase_ratio_max,
140
+ erase_p=erase_p,
141
+ mixup_lambda_min=mixup_lambda_min,
142
+ mixup_lambda_max=mixup_lambda_max,
143
+ mixup_p=mixup_p,
191
144
  )
192
-
193
- inst_shape = instances.shape
194
- # Before (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
195
- # or
196
- # Before (cropped image): (B=1, C, crop_H, crop_W), (n_samples, n_nodes, 2)
197
- instances = instances.reshape(inst_shape[0], -1, 2)
198
- # (n_samples, C, H, W), (n_samples, n_instances * n_nodes, 2) OR (n_samples, n_nodes, 2)
199
-
200
- aug_image, aug_instances = augmenter(image, instances)
201
-
202
- # After (full image): (n_samples, C, H, W), (n_samples, n_instances, n_nodes, 2)
203
- # or
204
- # After (cropped image): (n_samples, C, crop_H, crop_W), (n_samples, n_nodes, 2)
205
- return aug_image, aug_instances.reshape(*inst_shape)
206
-
207
-
208
- class RandomUniformNoise(IntensityAugmentationBase2D):
209
- """Data transformer for applying random uniform noise to input images.
210
-
211
- This is a custom Kornia augmentation inheriting from `IntensityAugmentationBase2D`.
212
- Uniform noise within (min_val, max_val) is applied to the entire input image.
213
-
214
- Note: Inverse transform is not implemented and re-applying the same transformation
215
- in the example below does not work when included in an AugmentationSequential class.
216
-
217
- Args:
218
- noise: 2-tuple (min_val, max_val); 0.0 <= min_val <= max_val <= 1.0.
219
- p: probability for applying an augmentation. This param controls the augmentation probabilities
220
- element-wise for a batch.
221
- p_batch: probability for applying an augmentation to a batch. This param controls the augmentation
222
- probabilities batch-wise.
223
- same_on_batch: apply the same transformation across the batch.
224
- keepdim: whether to keep the output shape the same as input `True` or broadcast it
225
- to the batch form `False`.
226
-
227
- Examples:
228
- >>> rng = torch.manual_seed(0)
229
- >>> img = torch.rand(1, 1, 2, 2)
230
- >>> RandomUniformNoise(min_val=0., max_val=0.1, p=1.)(img)
231
- tensor([[[[0.9607, 0.5865],
232
- [0.2705, 0.5920]]]])
233
-
234
- To apply the exact augmentation again, you may take the advantage of the previous parameter state:
235
- >>> input = torch.rand(1, 3, 32, 32)
236
- >>> aug = RandomUniformNoise(min_val=0., max_val=0.1, p=1.)
237
- >>> (aug(input) == aug(input, params=aug._params)).all()
238
- tensor(True)
239
-
240
- Ref: `kornia.augmentation._2d.intensity.gaussian_noise
241
- <https://kornia.readthedocs.io/en/latest/_modules/kornia/augmentation/_2d/intensity/gaussian_noise.html#RandomGaussianNoise>`_.
242
- """
243
-
244
- def __init__(
245
- self,
246
- noise: Tuple[float, float],
247
- p: float = 0.5,
248
- p_batch: float = 1.0,
249
- clip_output: bool = True,
250
- same_on_batch: bool = False,
251
- keepdim: bool = False,
252
- ) -> None:
253
- """Initialize the class."""
254
- super().__init__(
255
- p=p, p_batch=p_batch, same_on_batch=same_on_batch, keepdim=keepdim
256
- )
257
- self.flags = {
258
- "uniform_noise": _range_bound(noise, "uniform_noise", bounds=(0.0, 1.0))
259
- }
260
- self.clip_output = clip_output
261
-
262
- def apply_transform(
263
- self,
264
- input: Tensor,
265
- params: Dict[str, Tensor],
266
- flags: Dict[str, Any],
267
- transform: Optional[Tensor] = None,
268
- ) -> Tensor:
269
- """Compute the uniform noise, add, and clamp output."""
270
- if "uniform_noise" in params:
271
- uniform_noise = params["uniform_noise"]
272
- else:
273
- uniform_noise = (
274
- torch.FloatTensor(input.shape)
275
- .uniform_(flags["uniform_noise"][0], flags["uniform_noise"][1])
276
- .to(input.device)
277
- )
278
- self._params["uniform_noise"] = uniform_noise
279
- if self.clip_output:
280
- return torch.clamp(
281
- input + uniform_noise, 0.0, 1.0
282
- ) # RandomGaussianNoise doesn't clamp.
283
- return input + uniform_noise