sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sleap_nn/__init__.py +2 -4
  2. sleap_nn/architectures/convnext.py +0 -5
  3. sleap_nn/architectures/encoder_decoder.py +6 -25
  4. sleap_nn/architectures/swint.py +0 -8
  5. sleap_nn/cli.py +60 -364
  6. sleap_nn/config/data_config.py +5 -11
  7. sleap_nn/config/get_config.py +4 -10
  8. sleap_nn/config/trainer_config.py +0 -76
  9. sleap_nn/data/augmentation.py +241 -50
  10. sleap_nn/data/custom_datasets.py +39 -411
  11. sleap_nn/data/instance_cropping.py +1 -1
  12. sleap_nn/data/resizing.py +2 -2
  13. sleap_nn/data/utils.py +17 -135
  14. sleap_nn/evaluation.py +22 -81
  15. sleap_nn/inference/bottomup.py +20 -86
  16. sleap_nn/inference/peak_finding.py +19 -88
  17. sleap_nn/inference/predictors.py +117 -224
  18. sleap_nn/legacy_models.py +11 -65
  19. sleap_nn/predict.py +9 -37
  20. sleap_nn/train.py +4 -74
  21. sleap_nn/training/callbacks.py +105 -1046
  22. sleap_nn/training/lightning_modules.py +65 -602
  23. sleap_nn/training/model_trainer.py +184 -211
  24. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/METADATA +3 -15
  25. sleap_nn-0.1.0a0.dist-info/RECORD +65 -0
  26. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/WHEEL +1 -1
  27. sleap_nn/data/skia_augmentation.py +0 -414
  28. sleap_nn/export/__init__.py +0 -21
  29. sleap_nn/export/cli.py +0 -1778
  30. sleap_nn/export/exporters/__init__.py +0 -51
  31. sleap_nn/export/exporters/onnx_exporter.py +0 -80
  32. sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
  33. sleap_nn/export/metadata.py +0 -225
  34. sleap_nn/export/predictors/__init__.py +0 -63
  35. sleap_nn/export/predictors/base.py +0 -22
  36. sleap_nn/export/predictors/onnx.py +0 -154
  37. sleap_nn/export/predictors/tensorrt.py +0 -312
  38. sleap_nn/export/utils.py +0 -307
  39. sleap_nn/export/wrappers/__init__.py +0 -25
  40. sleap_nn/export/wrappers/base.py +0 -96
  41. sleap_nn/export/wrappers/bottomup.py +0 -243
  42. sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
  43. sleap_nn/export/wrappers/centered_instance.py +0 -56
  44. sleap_nn/export/wrappers/centroid.py +0 -58
  45. sleap_nn/export/wrappers/single_instance.py +0 -83
  46. sleap_nn/export/wrappers/topdown.py +0 -180
  47. sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
  48. sleap_nn/inference/postprocessing.py +0 -284
  49. sleap_nn/training/schedulers.py +0 -191
  50. sleap_nn-0.1.0.dist-info/RECORD +0 -88
  51. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/entry_points.txt +0 -0
  52. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/licenses/LICENSE +0 -0
  53. {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.10.2)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,414 +0,0 @@
1
- """Skia-based augmentation functions that operate on uint8 tensors.
2
-
3
- This module provides augmentation functions using skia-python that:
4
- 1. Match the exact API of sleap_nn.data.augmentation
5
- 2. Operate on uint8 tensors throughout (avoiding float32 conversions)
6
- 3. Provide ~1.5x faster augmentation compared to Kornia
7
-
8
- Usage:
9
- from sleap_nn.data.skia_augmentation import (
10
- apply_intensity_augmentation_skia,
11
- apply_geometric_augmentation_skia,
12
- )
13
-
14
- # Apply augmentations (uint8 in, uint8 out)
15
- image, instances = apply_intensity_augmentation_skia(image, instances, **config)
16
- image, instances = apply_geometric_augmentation_skia(image, instances, **config)
17
- """
18
-
19
- from typing import Optional, Tuple
20
- import numpy as np
21
- import torch
22
- import skia
23
-
24
-
25
- def apply_intensity_augmentation_skia(
26
- image: torch.Tensor,
27
- instances: torch.Tensor,
28
- uniform_noise_min: float = 0.0,
29
- uniform_noise_max: float = 0.04,
30
- uniform_noise_p: float = 0.0,
31
- gaussian_noise_mean: float = 0.02,
32
- gaussian_noise_std: float = 0.004,
33
- gaussian_noise_p: float = 0.0,
34
- contrast_min: float = 0.5,
35
- contrast_max: float = 2.0,
36
- contrast_p: float = 0.0,
37
- brightness_min: float = 1.0,
38
- brightness_max: float = 1.0,
39
- brightness_p: float = 0.0,
40
- ) -> Tuple[torch.Tensor, torch.Tensor]:
41
- """Apply intensity augmentations on uint8 image tensor.
42
-
43
- Matches API of sleap_nn.data.augmentation.apply_intensity_augmentation.
44
-
45
- Args:
46
- image: Input tensor of shape (1, C, H, W) with dtype uint8 or float32.
47
- instances: Keypoints tensor (not modified, just passed through).
48
- uniform_noise_min: Minimum uniform noise (0-1 scale, maps to 0-255).
49
- uniform_noise_max: Maximum uniform noise (0-1 scale).
50
- uniform_noise_p: Probability of uniform noise.
51
- gaussian_noise_mean: Gaussian noise mean (0-1 scale).
52
- gaussian_noise_std: Gaussian noise std (0-1 scale).
53
- gaussian_noise_p: Probability of Gaussian noise.
54
- contrast_min: Minimum contrast factor.
55
- contrast_max: Maximum contrast factor.
56
- contrast_p: Probability of contrast adjustment.
57
- brightness_min: Minimum brightness factor.
58
- brightness_max: Maximum brightness factor.
59
- brightness_p: Probability of brightness adjustment.
60
-
61
- Returns:
62
- Tuple of (augmented_image, instances). Image dtype matches input.
63
- """
64
- # Convert to numpy for Skia processing
65
- is_float = image.dtype == torch.float32
66
- if is_float:
67
- img_np = (image[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
68
- else:
69
- img_np = image[0].permute(1, 2, 0).numpy()
70
-
71
- result = img_np.copy()
72
-
73
- # Apply uniform noise (in uint8 space)
74
- if uniform_noise_p > 0 and np.random.random() < uniform_noise_p:
75
- noise = np.random.randint(
76
- int(uniform_noise_min * 255),
77
- int(uniform_noise_max * 255) + 1,
78
- img_np.shape,
79
- dtype=np.int16,
80
- )
81
- result = np.clip(result.astype(np.int16) + noise, 0, 255).astype(np.uint8)
82
-
83
- # Apply Gaussian noise (in uint8 space)
84
- if gaussian_noise_p > 0 and np.random.random() < gaussian_noise_p:
85
- noise = np.random.normal(
86
- gaussian_noise_mean * 255, gaussian_noise_std * 255, img_np.shape
87
- ).astype(np.int16)
88
- result = np.clip(result.astype(np.int16) + noise, 0, 255).astype(np.uint8)
89
-
90
- # Apply contrast using lookup table (pure uint8)
91
- if contrast_p > 0 and np.random.random() < contrast_p:
92
- factor = np.random.uniform(contrast_min, contrast_max)
93
- lut = np.arange(256, dtype=np.float32)
94
- lut = np.clip((lut - 127.5) * factor + 127.5, 0, 255).astype(np.uint8)
95
- result = lut[result]
96
-
97
- # Apply brightness using lookup table (pure uint8)
98
- if brightness_p > 0 and np.random.random() < brightness_p:
99
- factor = np.random.uniform(brightness_min, brightness_max)
100
- lut = np.arange(256, dtype=np.float32)
101
- lut = np.clip(lut * factor, 0, 255).astype(np.uint8)
102
- result = lut[result]
103
-
104
- # Convert back to tensor
105
- result_tensor = torch.from_numpy(result).permute(2, 0, 1).unsqueeze(0)
106
- if is_float:
107
- result_tensor = result_tensor.float() / 255.0
108
-
109
- return result_tensor, instances
110
-
111
-
112
- def apply_geometric_augmentation_skia(
113
- image: torch.Tensor,
114
- instances: torch.Tensor,
115
- rotation_min: float = -15.0,
116
- rotation_max: float = 15.0,
117
- rotation_p: Optional[float] = None,
118
- scale_min: float = 0.9,
119
- scale_max: float = 1.1,
120
- scale_p: Optional[float] = None,
121
- translate_width: float = 0.02,
122
- translate_height: float = 0.02,
123
- translate_p: Optional[float] = None,
124
- affine_p: float = 0.0,
125
- erase_scale_min: float = 0.0001,
126
- erase_scale_max: float = 0.01,
127
- erase_ratio_min: float = 1.0,
128
- erase_ratio_max: float = 1.0,
129
- erase_p: float = 0.0,
130
- mixup_lambda_min: float = 0.01,
131
- mixup_lambda_max: float = 0.05,
132
- mixup_p: float = 0.0,
133
- ) -> Tuple[torch.Tensor, torch.Tensor]:
134
- """Apply geometric augmentations using Skia.
135
-
136
- Matches API of sleap_nn.data.augmentation.apply_geometric_augmentation.
137
-
138
- Args:
139
- image: Input tensor of shape (1, C, H, W) with dtype uint8 or float32.
140
- instances: Keypoints tensor of shape (1, n_instances, n_nodes, 2) or (1, n_nodes, 2).
141
- rotation_min: Minimum rotation angle in degrees.
142
- rotation_max: Maximum rotation angle in degrees.
143
- rotation_p: Probability of rotation (independent). None = use affine_p.
144
- scale_min: Minimum scale factor.
145
- scale_max: Maximum scale factor.
146
- scale_p: Probability of scaling (independent). None = use affine_p.
147
- translate_width: Max horizontal translation as fraction of width.
148
- translate_height: Max vertical translation as fraction of height.
149
- translate_p: Probability of translation (independent). None = use affine_p.
150
- affine_p: Probability of bundled affine transform.
151
- erase_scale_min: Min proportion of image to erase.
152
- erase_scale_max: Max proportion of image to erase.
153
- erase_ratio_min: Min aspect ratio of erased area.
154
- erase_ratio_max: Max aspect ratio of erased area.
155
- erase_p: Probability of random erasing.
156
- mixup_lambda_min: Min mixup strength (not implemented).
157
- mixup_lambda_max: Max mixup strength (not implemented).
158
- mixup_p: Probability of mixup (not implemented).
159
-
160
- Returns:
161
- Tuple of (augmented_image, augmented_instances). Image dtype matches input.
162
- """
163
- # Convert to numpy for Skia processing
164
- is_float = image.dtype == torch.float32
165
- if is_float:
166
- img_np = (image[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
167
- else:
168
- img_np = image[0].permute(1, 2, 0).numpy().copy()
169
-
170
- h, w = img_np.shape[:2]
171
- cx, cy = w / 2, h / 2
172
-
173
- # Build transformation matrix
174
- matrix = skia.Matrix()
175
- has_transform = False
176
-
177
- use_independent = (
178
- rotation_p is not None or scale_p is not None or translate_p is not None
179
- )
180
-
181
- if use_independent:
182
- if (
183
- rotation_p is not None
184
- and rotation_p > 0
185
- and np.random.random() < rotation_p
186
- ):
187
- angle = np.random.uniform(rotation_min, rotation_max)
188
- rot_matrix = skia.Matrix()
189
- rot_matrix.setRotate(angle, cx, cy)
190
- matrix = matrix.preConcat(rot_matrix)
191
- has_transform = True
192
-
193
- if scale_p is not None and scale_p > 0 and np.random.random() < scale_p:
194
- scale = np.random.uniform(scale_min, scale_max)
195
- scale_matrix = skia.Matrix()
196
- scale_matrix.setScale(scale, scale, cx, cy)
197
- matrix = matrix.preConcat(scale_matrix)
198
- has_transform = True
199
-
200
- if (
201
- translate_p is not None
202
- and translate_p > 0
203
- and np.random.random() < translate_p
204
- ):
205
- tx = np.random.uniform(-translate_width, translate_width) * w
206
- ty = np.random.uniform(-translate_height, translate_height) * h
207
- trans_matrix = skia.Matrix()
208
- trans_matrix.setTranslate(tx, ty)
209
- matrix = matrix.preConcat(trans_matrix)
210
- has_transform = True
211
-
212
- elif affine_p > 0 and np.random.random() < affine_p:
213
- angle = np.random.uniform(rotation_min, rotation_max)
214
- scale = np.random.uniform(scale_min, scale_max)
215
- tx = np.random.uniform(-translate_width, translate_width) * w
216
- ty = np.random.uniform(-translate_height, translate_height) * h
217
-
218
- matrix.setRotate(angle, cx, cy)
219
- matrix.preScale(scale, scale, cx, cy)
220
- matrix.preTranslate(tx, ty)
221
- has_transform = True
222
-
223
- # Apply geometric transform
224
- if has_transform:
225
- img_np = _transform_image_skia(img_np, matrix)
226
- instances = _transform_keypoints_tensor(instances, matrix)
227
-
228
- # Apply random erasing
229
- if erase_p > 0 and np.random.random() < erase_p:
230
- img_np = _apply_random_erase(
231
- img_np, erase_scale_min, erase_scale_max, erase_ratio_min, erase_ratio_max
232
- )
233
-
234
- # Convert back to tensor
235
- result_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0)
236
- if is_float:
237
- result_tensor = result_tensor.float() / 255.0
238
-
239
- return result_tensor, instances
240
-
241
-
242
- def _transform_image_skia(image: np.ndarray, matrix: skia.Matrix) -> np.ndarray:
243
- """Transform image using Skia matrix (uint8 in, uint8 out)."""
244
- h, w = image.shape[:2]
245
- channels = image.shape[2] if image.ndim == 3 else 1
246
-
247
- # Skia needs RGBA
248
- if channels == 1:
249
- image_rgba = np.stack(
250
- [image.squeeze()] * 3 + [np.full((h, w), 255, dtype=np.uint8)], axis=-1
251
- )
252
- elif channels == 3:
253
- alpha = np.full((h, w, 1), 255, dtype=np.uint8)
254
- image_rgba = np.concatenate([image, alpha], axis=-1)
255
- else:
256
- raise ValueError(f"Unsupported channels: {channels}")
257
-
258
- image_rgba = np.ascontiguousarray(image_rgba, dtype=np.uint8)
259
- skia_image = skia.Image.fromarray(
260
- image_rgba, colorType=skia.ColorType.kRGBA_8888_ColorType
261
- )
262
-
263
- surface = skia.Surface(w, h)
264
- canvas = surface.getCanvas()
265
- canvas.clear(skia.Color4f(0, 0, 0, 1))
266
- canvas.setMatrix(matrix)
267
-
268
- paint = skia.Paint()
269
- paint.setAntiAlias(True)
270
- sampling = skia.SamplingOptions(skia.FilterMode.kLinear)
271
- canvas.drawImage(skia_image, 0, 0, sampling, paint)
272
-
273
- result = surface.makeImageSnapshot().toarray()
274
-
275
- if channels == 1:
276
- return result[:, :, 0:1]
277
- return result[:, :, :channels]
278
-
279
-
280
- def _transform_keypoints_tensor(
281
- keypoints: torch.Tensor, matrix: skia.Matrix
282
- ) -> torch.Tensor:
283
- """Transform keypoints tensor using Skia matrix."""
284
- if keypoints.numel() == 0:
285
- return keypoints
286
-
287
- original_shape = keypoints.shape
288
- flat = keypoints.reshape(-1, 2).numpy()
289
-
290
- # Handle NaN values
291
- valid_mask = ~np.isnan(flat).any(axis=1)
292
- transformed = flat.copy()
293
-
294
- if valid_mask.any():
295
- valid_pts = flat[valid_mask]
296
- skia_pts = [skia.Point(float(p[0]), float(p[1])) for p in valid_pts]
297
- mapped = matrix.mapPoints(skia_pts)
298
- transformed[valid_mask] = np.array([[p.x(), p.y()] for p in mapped])
299
-
300
- return torch.from_numpy(transformed.reshape(original_shape).astype(np.float32))
301
-
302
-
303
- def _apply_random_erase(
304
- image: np.ndarray,
305
- scale_min: float,
306
- scale_max: float,
307
- ratio_min: float,
308
- ratio_max: float,
309
- ) -> np.ndarray:
310
- """Apply random erasing (uint8)."""
311
- h, w = image.shape[:2]
312
- area = h * w
313
-
314
- erase_area = np.random.uniform(scale_min, scale_max) * area
315
- aspect_ratio = np.random.uniform(ratio_min, ratio_max)
316
-
317
- erase_h = int(np.sqrt(erase_area * aspect_ratio))
318
- erase_w = int(np.sqrt(erase_area / aspect_ratio))
319
-
320
- if erase_h >= h or erase_w >= w:
321
- return image
322
-
323
- y = np.random.randint(0, h - erase_h)
324
- x = np.random.randint(0, w - erase_w)
325
-
326
- result = image.copy()
327
- channels = image.shape[2] if image.ndim == 3 else 1
328
- fill = np.random.randint(0, 256, size=(channels,), dtype=np.uint8)
329
- result[y : y + erase_h, x : x + erase_w] = fill
330
-
331
- return result
332
-
333
-
334
- def crop_and_resize_skia(
335
- image: torch.Tensor,
336
- boxes: torch.Tensor,
337
- size: Tuple[int, int],
338
- ) -> torch.Tensor:
339
- """Crop and resize image regions using Skia.
340
-
341
- Replacement for kornia.geometry.transform.crop_and_resize.
342
-
343
- Args:
344
- image: Input tensor of shape (1, C, H, W).
345
- boxes: Bounding boxes tensor of shape (1, 4, 2) with corners:
346
- [top-left, top-right, bottom-right, bottom-left].
347
- size: Output size (height, width).
348
-
349
- Returns:
350
- Cropped and resized tensor of shape (1, C, out_h, out_w).
351
- """
352
- is_float = image.dtype == torch.float32
353
- if is_float:
354
- img_np = (image[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
355
- else:
356
- img_np = image[0].permute(1, 2, 0).numpy()
357
-
358
- h, w = img_np.shape[:2]
359
- out_h, out_w = size
360
- channels = img_np.shape[2] if img_np.ndim == 3 else 1
361
-
362
- # Get box coordinates (top-left and bottom-right)
363
- box = boxes[0].numpy() # (4, 2)
364
- x1, y1 = box[0] # top-left
365
- x2, y2 = box[2] # bottom-right
366
-
367
- crop_w = x2 - x1
368
- crop_h = y2 - y1
369
-
370
- # Create transformation matrix
371
- matrix = skia.Matrix()
372
- scale_x = out_w / crop_w
373
- scale_y = out_h / crop_h
374
- matrix.setScale(scale_x, scale_y)
375
- matrix.preTranslate(-x1, -y1)
376
-
377
- # Skia needs RGBA
378
- if channels == 1:
379
- image_rgba = np.stack(
380
- [img_np.squeeze()] * 3 + [np.full((h, w), 255, dtype=np.uint8)], axis=-1
381
- )
382
- elif channels == 3:
383
- alpha = np.full((h, w, 1), 255, dtype=np.uint8)
384
- image_rgba = np.concatenate([img_np, alpha], axis=-1)
385
- else:
386
- raise ValueError(f"Unsupported channels: {channels}")
387
-
388
- image_rgba = np.ascontiguousarray(image_rgba, dtype=np.uint8)
389
- skia_image = skia.Image.fromarray(
390
- image_rgba, colorType=skia.ColorType.kRGBA_8888_ColorType
391
- )
392
-
393
- surface = skia.Surface(out_w, out_h)
394
- canvas = surface.getCanvas()
395
- canvas.clear(skia.Color4f(0, 0, 0, 1))
396
- canvas.setMatrix(matrix)
397
-
398
- paint = skia.Paint()
399
- paint.setAntiAlias(True)
400
- sampling = skia.SamplingOptions(skia.FilterMode.kLinear)
401
- canvas.drawImage(skia_image, 0, 0, sampling, paint)
402
-
403
- result = surface.makeImageSnapshot().toarray()
404
-
405
- if channels == 1:
406
- result = result[:, :, 0:1]
407
- else:
408
- result = result[:, :, :channels]
409
-
410
- result_tensor = torch.from_numpy(result).permute(2, 0, 1).unsqueeze(0)
411
- if is_float:
412
- result_tensor = result_tensor.float() / 255.0
413
-
414
- return result_tensor
@@ -1,21 +0,0 @@
1
- """Export utilities for sleap-nn."""
2
-
3
- from sleap_nn.export.exporters import export_model, export_to_onnx, export_to_tensorrt
4
- from sleap_nn.export.metadata import ExportMetadata
5
- from sleap_nn.export.predictors import (
6
- load_exported_model,
7
- ONNXPredictor,
8
- TensorRTPredictor,
9
- )
10
- from sleap_nn.export.utils import build_bottomup_candidate_template
11
-
12
- __all__ = [
13
- "export_model",
14
- "export_to_onnx",
15
- "export_to_tensorrt",
16
- "load_exported_model",
17
- "ONNXPredictor",
18
- "TensorRTPredictor",
19
- "ExportMetadata",
20
- "build_bottomup_candidate_template",
21
- ]