sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. sleap_nn/__init__.py +9 -2
  2. sleap_nn/architectures/convnext.py +5 -0
  3. sleap_nn/architectures/encoder_decoder.py +25 -6
  4. sleap_nn/architectures/swint.py +8 -0
  5. sleap_nn/cli.py +489 -46
  6. sleap_nn/config/data_config.py +51 -8
  7. sleap_nn/config/get_config.py +32 -24
  8. sleap_nn/config/trainer_config.py +88 -0
  9. sleap_nn/data/augmentation.py +61 -200
  10. sleap_nn/data/custom_datasets.py +433 -61
  11. sleap_nn/data/instance_cropping.py +71 -6
  12. sleap_nn/data/normalization.py +45 -2
  13. sleap_nn/data/providers.py +26 -0
  14. sleap_nn/data/resizing.py +2 -2
  15. sleap_nn/data/skia_augmentation.py +414 -0
  16. sleap_nn/data/utils.py +135 -17
  17. sleap_nn/evaluation.py +177 -42
  18. sleap_nn/export/__init__.py +21 -0
  19. sleap_nn/export/cli.py +1778 -0
  20. sleap_nn/export/exporters/__init__.py +51 -0
  21. sleap_nn/export/exporters/onnx_exporter.py +80 -0
  22. sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
  23. sleap_nn/export/metadata.py +225 -0
  24. sleap_nn/export/predictors/__init__.py +63 -0
  25. sleap_nn/export/predictors/base.py +22 -0
  26. sleap_nn/export/predictors/onnx.py +154 -0
  27. sleap_nn/export/predictors/tensorrt.py +312 -0
  28. sleap_nn/export/utils.py +307 -0
  29. sleap_nn/export/wrappers/__init__.py +25 -0
  30. sleap_nn/export/wrappers/base.py +96 -0
  31. sleap_nn/export/wrappers/bottomup.py +243 -0
  32. sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
  33. sleap_nn/export/wrappers/centered_instance.py +56 -0
  34. sleap_nn/export/wrappers/centroid.py +58 -0
  35. sleap_nn/export/wrappers/single_instance.py +83 -0
  36. sleap_nn/export/wrappers/topdown.py +180 -0
  37. sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
  38. sleap_nn/inference/__init__.py +6 -0
  39. sleap_nn/inference/bottomup.py +86 -20
  40. sleap_nn/inference/peak_finding.py +93 -16
  41. sleap_nn/inference/postprocessing.py +284 -0
  42. sleap_nn/inference/predictors.py +339 -137
  43. sleap_nn/inference/provenance.py +292 -0
  44. sleap_nn/inference/topdown.py +55 -47
  45. sleap_nn/legacy_models.py +65 -11
  46. sleap_nn/predict.py +224 -19
  47. sleap_nn/system_info.py +443 -0
  48. sleap_nn/tracking/tracker.py +8 -1
  49. sleap_nn/train.py +138 -44
  50. sleap_nn/training/callbacks.py +1258 -5
  51. sleap_nn/training/lightning_modules.py +902 -220
  52. sleap_nn/training/model_trainer.py +424 -111
  53. sleap_nn/training/schedulers.py +191 -0
  54. sleap_nn/training/utils.py +367 -2
  55. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
  56. sleap_nn-0.1.0.dist-info/RECORD +88 -0
  57. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
  58. sleap_nn-0.0.5.dist-info/RECORD +0 -63
  59. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
  60. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
  61. {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
@@ -5,14 +5,82 @@ import math
5
5
  import numpy as np
6
6
  import sleap_io as sio
7
7
  import torch
8
- from kornia.geometry.transform import crop_and_resize
8
+ from sleap_nn.data.skia_augmentation import crop_and_resize_skia as crop_and_resize
9
+
10
+
11
def compute_augmentation_padding(
    bbox_size: float,
    rotation_max: float = 0.0,
    scale_max: float = 1.0,
) -> int:
    """Compute padding needed to accommodate augmentation transforms.

    Rotation and scaling augmentations can push an instance's bounding box
    beyond its original extent. This returns the extra pixels (total across
    both sides, not per side) required so the full instance stays visible
    after the worst-case transform.

    Args:
        bbox_size: The size of the instance bounding box (max of width/height).
        rotation_max: Maximum absolute rotation angle in degrees. For symmetric
            rotation ranges like [-180, 180], pass 180.
        scale_max: Maximum scaling factor. For scale range [0.9, 1.1], pass 1.1.

    Returns:
        Padding in pixels to add around the bounding box (total, not per side).
    """
    # No rotation and no up-scaling: the box can only shrink or stay put.
    if rotation_max == 0.0 and scale_max <= 1.0:
        return 0

    # A square of side L rotated by θ has an axis-aligned bounding box of side
    # L * (|cos θ| + |sin θ|), which peaks at θ = 45° with factor sqrt(2).
    # Any rotation range reaching 45° therefore hits that worst case.
    if abs(rotation_max) >= 45:
        rotation_factor = math.sqrt(2)
    else:
        theta = math.radians(min(abs(rotation_max), 90))
        rotation_factor = abs(math.cos(theta)) + abs(math.sin(theta))

    # Scaling below 1.0 never enlarges the box, so clamp at 1.0.
    expansion_factor = rotation_factor * max(scale_max, 1.0)

    # Total growth of the box, rounded up to whole pixels.
    grown = bbox_size * expansion_factor
    return int(math.ceil(grown - bbox_size))
55
+
56
+
57
def find_max_instance_bbox_size(labels: "sio.Labels") -> float:
    """Return the largest bounding-box side length over all non-empty instances.

    Walks every labeled frame and measures each instance's points along x and
    y independently; the result is the biggest single-axis extent observed.

    Args:
        labels: A `sio.Labels` containing user-labeled instances.

    Returns:
        The maximum bounding box dimension (max of width or height) across all
        instances; 0.0 when no non-empty instances exist.
    """
    largest = 0.0
    for frame in labels:
        for instance in frame.instances:
            if instance.is_empty:
                continue
            points = instance.numpy()
            # Measure extent along x (axis 0) and y (axis 1) separately.
            # An all-NaN axis yields NaN, which is treated as zero extent.
            for axis in (0, 1):
                extent = np.nanmax(points[:, axis]) - np.nanmin(points[:, axis])
                if np.isnan(extent):
                    extent = 0
                largest = np.maximum(largest, extent)
    return float(largest)
9
78
 
10
79
 
11
80
  def find_instance_crop_size(
12
81
  labels: sio.Labels,
13
82
  padding: int = 0,
14
83
  maximum_stride: int = 2,
15
- input_scaling: float = 1.0,
16
84
  min_crop_size: Optional[int] = None,
17
85
  ) -> int:
18
86
  """Compute the size of the largest instance bounding box from labels.
@@ -23,8 +91,6 @@ def find_instance_crop_size(
23
91
  maximum_stride: Ensure that the returned crop size is divisible by this value.
24
92
  Useful for ensuring that the crop size will not be truncated in a given
25
93
  architecture.
26
- input_scaling: Float factor indicating the scale of the input images if any
27
- scaling will be done before cropping.
28
94
  min_crop_size: The crop size set by the user.
29
95
 
30
96
  Returns:
@@ -32,7 +98,7 @@ def find_instance_crop_size(
32
98
  will contain the instances when cropped. The returned crop size will be larger
33
99
  or equal to the input `min_crop_size`.
34
100
 
35
- This accounts for stride, padding and scaling when ensuring divisibility.
101
+ This accounts for stride and padding when ensuring divisibility.
36
102
  """
37
103
  # Check if user-specified crop size is divisible by max stride
38
104
  min_crop_size = 0 if min_crop_size is None else min_crop_size
@@ -46,7 +112,6 @@ def find_instance_crop_size(
46
112
  for inst in lf.instances:
47
113
  if not inst.is_empty: # only if at least one point is not nan
48
114
  pts = inst.numpy()
49
- pts *= input_scaling
50
115
  diff_x = np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])
51
116
  diff_x = 0 if np.isnan(diff_x) else diff_x
52
117
  max_length = np.maximum(max_length, diff_x)
@@ -4,6 +4,36 @@ import torch
4
4
  import torchvision.transforms.v2.functional as F
5
5
 
6
6
 
7
def normalize_on_gpu(image: torch.Tensor) -> torch.Tensor:
    """Convert an image tensor to float32 in [0, 1], intended to run after GPU transfer.

    Called from the model's forward() once the image is on-device. Keeping the
    host-to-device copy in uint8 and normalizing afterwards shrinks the PCIe
    transfer 4x versus shipping float32 (per the module's design notes, up to
    17x faster for the transfer+normalization stage combined).

    Two inputs are handled:

    1. Integer tensors (e.g. uint8 in [0, 255]) -> cast to float32, divide by 255.
    2. Float tensors whose max exceeds 1.0 (cast upstream without normalizing)
       -> divide by 255.

    Args:
        image: Tensor image, integer-typed or float with values in [0, 255].

    Returns:
        Float32 tensor in [0, 1]. Already-normalized float input is returned
        unchanged.
    """
    if torch.is_floating_point(image):
        # Heuristic: a max above 1.0 means the float data was never divided by 255.
        return image / 255.0 if image.max() > 1.0 else image
    return image.float() / 255.0
35
+
36
+
7
37
  def convert_to_grayscale(image: torch.Tensor):
8
38
  """Convert given image to Grayscale image (single-channel).
9
39
 
@@ -38,8 +68,21 @@ def convert_to_rgb(image: torch.Tensor):
38
68
  return image
39
69
 
40
70
 
41
def apply_normalization(image: torch.Tensor) -> torch.Tensor:
    """Scale a uint8 [0, 255] image tensor to float32 [0, 1].

    Used in the training preprocessing path, where the kornia augmentation ops
    require float32 input. At inference time, normalization is instead deferred
    to the GPU via `normalize_on_gpu()` in the model's forward() to reduce PCIe
    bandwidth.

    Args:
        image: Tensor image (typically uint8 with values in [0, 255]).

    Returns:
        Float32 tensor normalized to [0, 1]; floating-point input is returned
        unchanged.
    """
    if torch.is_floating_point(image):
        return image
    return image.to(torch.float32) / 255.0
@@ -71,6 +71,8 @@ def process_lf(
71
71
  for inst in instances_list:
72
72
  if not inst.is_empty:
73
73
  instances.append(inst.numpy())
74
+ if len(instances) == 0:
75
+ return None
74
76
  instances = np.stack(instances, axis=0)
75
77
 
76
78
  # Add singleton time dimension for single frames.
@@ -233,6 +235,8 @@ class LabelsReader(Thread):
233
235
  instances_key: bool = False,
234
236
  only_labeled_frames: bool = False,
235
237
  only_suggested_frames: bool = False,
238
+ exclude_user_labeled: bool = False,
239
+ only_predicted_frames: bool = False,
236
240
  ):
237
241
  """Initialize attribute of the class."""
238
242
  super().__init__()
@@ -245,6 +249,8 @@ class LabelsReader(Thread):
245
249
 
246
250
  self.only_labeled_frames = only_labeled_frames
247
251
  self.only_suggested_frames = only_suggested_frames
252
+ self.exclude_user_labeled = exclude_user_labeled
253
+ self.only_predicted_frames = only_predicted_frames
248
254
 
249
255
  # Filter to only user labeled instances
250
256
  if self.only_labeled_frames:
@@ -265,6 +271,20 @@ class LabelsReader(Thread):
265
271
  )
266
272
  self.filtered_lfs.append(new_lf)
267
273
 
274
+ # Filter out user labeled frames
275
+ elif self.exclude_user_labeled:
276
+ self.filtered_lfs = []
277
+ for lf in self.labels:
278
+ if not lf.has_user_instances:
279
+ self.filtered_lfs.append(lf)
280
+
281
+ # Filter to only predicted frames
282
+ elif self.only_predicted_frames:
283
+ self.filtered_lfs = []
284
+ for lf in self.labels:
285
+ if lf.has_predicted_instances:
286
+ self.filtered_lfs.append(lf)
287
+
268
288
  else:
269
289
  self.filtered_lfs = [lf for lf in self.labels]
270
290
 
@@ -300,6 +320,8 @@ class LabelsReader(Thread):
300
320
  instances_key: bool = False,
301
321
  only_labeled_frames: bool = False,
302
322
  only_suggested_frames: bool = False,
323
+ exclude_user_labeled: bool = False,
324
+ only_predicted_frames: bool = False,
303
325
  ):
304
326
  """Create LabelsReader from a .slp filename."""
305
327
  labels = sio.load_slp(filename)
@@ -310,6 +332,8 @@ class LabelsReader(Thread):
310
332
  instances_key,
311
333
  only_labeled_frames,
312
334
  only_suggested_frames,
335
+ exclude_user_labeled,
336
+ only_predicted_frames,
313
337
  )
314
338
 
315
339
  def run(self):
@@ -333,6 +357,8 @@ class LabelsReader(Thread):
333
357
  for inst in lf:
334
358
  if not inst.is_empty:
335
359
  instances.append(inst.numpy())
360
+ if len(instances) == 0:
361
+ continue
336
362
  instances = np.stack(instances, axis=0)
337
363
 
338
364
  # Add singleton time dimension for single frames.
sleap_nn/data/resizing.py CHANGED
@@ -63,7 +63,7 @@ def apply_pad_to_stride(image: torch.Tensor, max_stride: int) -> torch.Tensor:
63
63
  image,
64
64
  (0, pad_width, 0, pad_height),
65
65
  mode="constant",
66
- ).to(torch.float32)
66
+ )
67
67
  return image
68
68
 
69
69
 
@@ -136,7 +136,7 @@ def apply_sizematcher(
136
136
  image,
137
137
  (0, pad_width, 0, pad_height),
138
138
  mode="constant",
139
- ).to(torch.float32)
139
+ )
140
140
 
141
141
  return image, eff_scale_ratio
142
142
  else:
@@ -0,0 +1,414 @@
1
+ """Skia-based augmentation functions that operate on uint8 tensors.
2
+
3
+ This module provides augmentation functions using skia-python that:
4
+ 1. Match the exact API of sleap_nn.data.augmentation
5
+ 2. Operate on uint8 tensors throughout (avoiding float32 conversions)
6
+ 3. Provide ~1.5x faster augmentation compared to Kornia
7
+
8
+ Usage:
9
+ from sleap_nn.data.skia_augmentation import (
10
+ apply_intensity_augmentation_skia,
11
+ apply_geometric_augmentation_skia,
12
+ )
13
+
14
+ # Apply augmentations (uint8 in, uint8 out)
15
+ image, instances = apply_intensity_augmentation_skia(image, instances, **config)
16
+ image, instances = apply_geometric_augmentation_skia(image, instances, **config)
17
+ """
18
+
19
+ from typing import Optional, Tuple
20
+ import numpy as np
21
+ import torch
22
+ import skia
23
+
24
+
25
+ def apply_intensity_augmentation_skia(
26
+ image: torch.Tensor,
27
+ instances: torch.Tensor,
28
+ uniform_noise_min: float = 0.0,
29
+ uniform_noise_max: float = 0.04,
30
+ uniform_noise_p: float = 0.0,
31
+ gaussian_noise_mean: float = 0.02,
32
+ gaussian_noise_std: float = 0.004,
33
+ gaussian_noise_p: float = 0.0,
34
+ contrast_min: float = 0.5,
35
+ contrast_max: float = 2.0,
36
+ contrast_p: float = 0.0,
37
+ brightness_min: float = 1.0,
38
+ brightness_max: float = 1.0,
39
+ brightness_p: float = 0.0,
40
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
41
+ """Apply intensity augmentations on uint8 image tensor.
42
+
43
+ Matches API of sleap_nn.data.augmentation.apply_intensity_augmentation.
44
+
45
+ Args:
46
+ image: Input tensor of shape (1, C, H, W) with dtype uint8 or float32.
47
+ instances: Keypoints tensor (not modified, just passed through).
48
+ uniform_noise_min: Minimum uniform noise (0-1 scale, maps to 0-255).
49
+ uniform_noise_max: Maximum uniform noise (0-1 scale).
50
+ uniform_noise_p: Probability of uniform noise.
51
+ gaussian_noise_mean: Gaussian noise mean (0-1 scale).
52
+ gaussian_noise_std: Gaussian noise std (0-1 scale).
53
+ gaussian_noise_p: Probability of Gaussian noise.
54
+ contrast_min: Minimum contrast factor.
55
+ contrast_max: Maximum contrast factor.
56
+ contrast_p: Probability of contrast adjustment.
57
+ brightness_min: Minimum brightness factor.
58
+ brightness_max: Maximum brightness factor.
59
+ brightness_p: Probability of brightness adjustment.
60
+
61
+ Returns:
62
+ Tuple of (augmented_image, instances). Image dtype matches input.
63
+ """
64
+ # Convert to numpy for Skia processing
65
+ is_float = image.dtype == torch.float32
66
+ if is_float:
67
+ img_np = (image[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
68
+ else:
69
+ img_np = image[0].permute(1, 2, 0).numpy()
70
+
71
+ result = img_np.copy()
72
+
73
+ # Apply uniform noise (in uint8 space)
74
+ if uniform_noise_p > 0 and np.random.random() < uniform_noise_p:
75
+ noise = np.random.randint(
76
+ int(uniform_noise_min * 255),
77
+ int(uniform_noise_max * 255) + 1,
78
+ img_np.shape,
79
+ dtype=np.int16,
80
+ )
81
+ result = np.clip(result.astype(np.int16) + noise, 0, 255).astype(np.uint8)
82
+
83
+ # Apply Gaussian noise (in uint8 space)
84
+ if gaussian_noise_p > 0 and np.random.random() < gaussian_noise_p:
85
+ noise = np.random.normal(
86
+ gaussian_noise_mean * 255, gaussian_noise_std * 255, img_np.shape
87
+ ).astype(np.int16)
88
+ result = np.clip(result.astype(np.int16) + noise, 0, 255).astype(np.uint8)
89
+
90
+ # Apply contrast using lookup table (pure uint8)
91
+ if contrast_p > 0 and np.random.random() < contrast_p:
92
+ factor = np.random.uniform(contrast_min, contrast_max)
93
+ lut = np.arange(256, dtype=np.float32)
94
+ lut = np.clip((lut - 127.5) * factor + 127.5, 0, 255).astype(np.uint8)
95
+ result = lut[result]
96
+
97
+ # Apply brightness using lookup table (pure uint8)
98
+ if brightness_p > 0 and np.random.random() < brightness_p:
99
+ factor = np.random.uniform(brightness_min, brightness_max)
100
+ lut = np.arange(256, dtype=np.float32)
101
+ lut = np.clip(lut * factor, 0, 255).astype(np.uint8)
102
+ result = lut[result]
103
+
104
+ # Convert back to tensor
105
+ result_tensor = torch.from_numpy(result).permute(2, 0, 1).unsqueeze(0)
106
+ if is_float:
107
+ result_tensor = result_tensor.float() / 255.0
108
+
109
+ return result_tensor, instances
110
+
111
+
112
+ def apply_geometric_augmentation_skia(
113
+ image: torch.Tensor,
114
+ instances: torch.Tensor,
115
+ rotation_min: float = -15.0,
116
+ rotation_max: float = 15.0,
117
+ rotation_p: Optional[float] = None,
118
+ scale_min: float = 0.9,
119
+ scale_max: float = 1.1,
120
+ scale_p: Optional[float] = None,
121
+ translate_width: float = 0.02,
122
+ translate_height: float = 0.02,
123
+ translate_p: Optional[float] = None,
124
+ affine_p: float = 0.0,
125
+ erase_scale_min: float = 0.0001,
126
+ erase_scale_max: float = 0.01,
127
+ erase_ratio_min: float = 1.0,
128
+ erase_ratio_max: float = 1.0,
129
+ erase_p: float = 0.0,
130
+ mixup_lambda_min: float = 0.01,
131
+ mixup_lambda_max: float = 0.05,
132
+ mixup_p: float = 0.0,
133
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
134
+ """Apply geometric augmentations using Skia.
135
+
136
+ Matches API of sleap_nn.data.augmentation.apply_geometric_augmentation.
137
+
138
+ Args:
139
+ image: Input tensor of shape (1, C, H, W) with dtype uint8 or float32.
140
+ instances: Keypoints tensor of shape (1, n_instances, n_nodes, 2) or (1, n_nodes, 2).
141
+ rotation_min: Minimum rotation angle in degrees.
142
+ rotation_max: Maximum rotation angle in degrees.
143
+ rotation_p: Probability of rotation (independent). None = use affine_p.
144
+ scale_min: Minimum scale factor.
145
+ scale_max: Maximum scale factor.
146
+ scale_p: Probability of scaling (independent). None = use affine_p.
147
+ translate_width: Max horizontal translation as fraction of width.
148
+ translate_height: Max vertical translation as fraction of height.
149
+ translate_p: Probability of translation (independent). None = use affine_p.
150
+ affine_p: Probability of bundled affine transform.
151
+ erase_scale_min: Min proportion of image to erase.
152
+ erase_scale_max: Max proportion of image to erase.
153
+ erase_ratio_min: Min aspect ratio of erased area.
154
+ erase_ratio_max: Max aspect ratio of erased area.
155
+ erase_p: Probability of random erasing.
156
+ mixup_lambda_min: Min mixup strength (not implemented).
157
+ mixup_lambda_max: Max mixup strength (not implemented).
158
+ mixup_p: Probability of mixup (not implemented).
159
+
160
+ Returns:
161
+ Tuple of (augmented_image, augmented_instances). Image dtype matches input.
162
+ """
163
+ # Convert to numpy for Skia processing
164
+ is_float = image.dtype == torch.float32
165
+ if is_float:
166
+ img_np = (image[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
167
+ else:
168
+ img_np = image[0].permute(1, 2, 0).numpy().copy()
169
+
170
+ h, w = img_np.shape[:2]
171
+ cx, cy = w / 2, h / 2
172
+
173
+ # Build transformation matrix
174
+ matrix = skia.Matrix()
175
+ has_transform = False
176
+
177
+ use_independent = (
178
+ rotation_p is not None or scale_p is not None or translate_p is not None
179
+ )
180
+
181
+ if use_independent:
182
+ if (
183
+ rotation_p is not None
184
+ and rotation_p > 0
185
+ and np.random.random() < rotation_p
186
+ ):
187
+ angle = np.random.uniform(rotation_min, rotation_max)
188
+ rot_matrix = skia.Matrix()
189
+ rot_matrix.setRotate(angle, cx, cy)
190
+ matrix = matrix.preConcat(rot_matrix)
191
+ has_transform = True
192
+
193
+ if scale_p is not None and scale_p > 0 and np.random.random() < scale_p:
194
+ scale = np.random.uniform(scale_min, scale_max)
195
+ scale_matrix = skia.Matrix()
196
+ scale_matrix.setScale(scale, scale, cx, cy)
197
+ matrix = matrix.preConcat(scale_matrix)
198
+ has_transform = True
199
+
200
+ if (
201
+ translate_p is not None
202
+ and translate_p > 0
203
+ and np.random.random() < translate_p
204
+ ):
205
+ tx = np.random.uniform(-translate_width, translate_width) * w
206
+ ty = np.random.uniform(-translate_height, translate_height) * h
207
+ trans_matrix = skia.Matrix()
208
+ trans_matrix.setTranslate(tx, ty)
209
+ matrix = matrix.preConcat(trans_matrix)
210
+ has_transform = True
211
+
212
+ elif affine_p > 0 and np.random.random() < affine_p:
213
+ angle = np.random.uniform(rotation_min, rotation_max)
214
+ scale = np.random.uniform(scale_min, scale_max)
215
+ tx = np.random.uniform(-translate_width, translate_width) * w
216
+ ty = np.random.uniform(-translate_height, translate_height) * h
217
+
218
+ matrix.setRotate(angle, cx, cy)
219
+ matrix.preScale(scale, scale, cx, cy)
220
+ matrix.preTranslate(tx, ty)
221
+ has_transform = True
222
+
223
+ # Apply geometric transform
224
+ if has_transform:
225
+ img_np = _transform_image_skia(img_np, matrix)
226
+ instances = _transform_keypoints_tensor(instances, matrix)
227
+
228
+ # Apply random erasing
229
+ if erase_p > 0 and np.random.random() < erase_p:
230
+ img_np = _apply_random_erase(
231
+ img_np, erase_scale_min, erase_scale_max, erase_ratio_min, erase_ratio_max
232
+ )
233
+
234
+ # Convert back to tensor
235
+ result_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0)
236
+ if is_float:
237
+ result_tensor = result_tensor.float() / 255.0
238
+
239
+ return result_tensor, instances
240
+
241
+
242
+ def _transform_image_skia(image: np.ndarray, matrix: skia.Matrix) -> np.ndarray:
243
+ """Transform image using Skia matrix (uint8 in, uint8 out)."""
244
+ h, w = image.shape[:2]
245
+ channels = image.shape[2] if image.ndim == 3 else 1
246
+
247
+ # Skia needs RGBA
248
+ if channels == 1:
249
+ image_rgba = np.stack(
250
+ [image.squeeze()] * 3 + [np.full((h, w), 255, dtype=np.uint8)], axis=-1
251
+ )
252
+ elif channels == 3:
253
+ alpha = np.full((h, w, 1), 255, dtype=np.uint8)
254
+ image_rgba = np.concatenate([image, alpha], axis=-1)
255
+ else:
256
+ raise ValueError(f"Unsupported channels: {channels}")
257
+
258
+ image_rgba = np.ascontiguousarray(image_rgba, dtype=np.uint8)
259
+ skia_image = skia.Image.fromarray(
260
+ image_rgba, colorType=skia.ColorType.kRGBA_8888_ColorType
261
+ )
262
+
263
+ surface = skia.Surface(w, h)
264
+ canvas = surface.getCanvas()
265
+ canvas.clear(skia.Color4f(0, 0, 0, 1))
266
+ canvas.setMatrix(matrix)
267
+
268
+ paint = skia.Paint()
269
+ paint.setAntiAlias(True)
270
+ sampling = skia.SamplingOptions(skia.FilterMode.kLinear)
271
+ canvas.drawImage(skia_image, 0, 0, sampling, paint)
272
+
273
+ result = surface.makeImageSnapshot().toarray()
274
+
275
+ if channels == 1:
276
+ return result[:, :, 0:1]
277
+ return result[:, :, :channels]
278
+
279
+
280
+ def _transform_keypoints_tensor(
281
+ keypoints: torch.Tensor, matrix: skia.Matrix
282
+ ) -> torch.Tensor:
283
+ """Transform keypoints tensor using Skia matrix."""
284
+ if keypoints.numel() == 0:
285
+ return keypoints
286
+
287
+ original_shape = keypoints.shape
288
+ flat = keypoints.reshape(-1, 2).numpy()
289
+
290
+ # Handle NaN values
291
+ valid_mask = ~np.isnan(flat).any(axis=1)
292
+ transformed = flat.copy()
293
+
294
+ if valid_mask.any():
295
+ valid_pts = flat[valid_mask]
296
+ skia_pts = [skia.Point(float(p[0]), float(p[1])) for p in valid_pts]
297
+ mapped = matrix.mapPoints(skia_pts)
298
+ transformed[valid_mask] = np.array([[p.x(), p.y()] for p in mapped])
299
+
300
+ return torch.from_numpy(transformed.reshape(original_shape).astype(np.float32))
301
+
302
+
303
+ def _apply_random_erase(
304
+ image: np.ndarray,
305
+ scale_min: float,
306
+ scale_max: float,
307
+ ratio_min: float,
308
+ ratio_max: float,
309
+ ) -> np.ndarray:
310
+ """Apply random erasing (uint8)."""
311
+ h, w = image.shape[:2]
312
+ area = h * w
313
+
314
+ erase_area = np.random.uniform(scale_min, scale_max) * area
315
+ aspect_ratio = np.random.uniform(ratio_min, ratio_max)
316
+
317
+ erase_h = int(np.sqrt(erase_area * aspect_ratio))
318
+ erase_w = int(np.sqrt(erase_area / aspect_ratio))
319
+
320
+ if erase_h >= h or erase_w >= w:
321
+ return image
322
+
323
+ y = np.random.randint(0, h - erase_h)
324
+ x = np.random.randint(0, w - erase_w)
325
+
326
+ result = image.copy()
327
+ channels = image.shape[2] if image.ndim == 3 else 1
328
+ fill = np.random.randint(0, 256, size=(channels,), dtype=np.uint8)
329
+ result[y : y + erase_h, x : x + erase_w] = fill
330
+
331
+ return result
332
+
333
+
334
+ def crop_and_resize_skia(
335
+ image: torch.Tensor,
336
+ boxes: torch.Tensor,
337
+ size: Tuple[int, int],
338
+ ) -> torch.Tensor:
339
+ """Crop and resize image regions using Skia.
340
+
341
+ Replacement for kornia.geometry.transform.crop_and_resize.
342
+
343
+ Args:
344
+ image: Input tensor of shape (1, C, H, W).
345
+ boxes: Bounding boxes tensor of shape (1, 4, 2) with corners:
346
+ [top-left, top-right, bottom-right, bottom-left].
347
+ size: Output size (height, width).
348
+
349
+ Returns:
350
+ Cropped and resized tensor of shape (1, C, out_h, out_w).
351
+ """
352
+ is_float = image.dtype == torch.float32
353
+ if is_float:
354
+ img_np = (image[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
355
+ else:
356
+ img_np = image[0].permute(1, 2, 0).numpy()
357
+
358
+ h, w = img_np.shape[:2]
359
+ out_h, out_w = size
360
+ channels = img_np.shape[2] if img_np.ndim == 3 else 1
361
+
362
+ # Get box coordinates (top-left and bottom-right)
363
+ box = boxes[0].numpy() # (4, 2)
364
+ x1, y1 = box[0] # top-left
365
+ x2, y2 = box[2] # bottom-right
366
+
367
+ crop_w = x2 - x1
368
+ crop_h = y2 - y1
369
+
370
+ # Create transformation matrix
371
+ matrix = skia.Matrix()
372
+ scale_x = out_w / crop_w
373
+ scale_y = out_h / crop_h
374
+ matrix.setScale(scale_x, scale_y)
375
+ matrix.preTranslate(-x1, -y1)
376
+
377
+ # Skia needs RGBA
378
+ if channels == 1:
379
+ image_rgba = np.stack(
380
+ [img_np.squeeze()] * 3 + [np.full((h, w), 255, dtype=np.uint8)], axis=-1
381
+ )
382
+ elif channels == 3:
383
+ alpha = np.full((h, w, 1), 255, dtype=np.uint8)
384
+ image_rgba = np.concatenate([img_np, alpha], axis=-1)
385
+ else:
386
+ raise ValueError(f"Unsupported channels: {channels}")
387
+
388
+ image_rgba = np.ascontiguousarray(image_rgba, dtype=np.uint8)
389
+ skia_image = skia.Image.fromarray(
390
+ image_rgba, colorType=skia.ColorType.kRGBA_8888_ColorType
391
+ )
392
+
393
+ surface = skia.Surface(out_w, out_h)
394
+ canvas = surface.getCanvas()
395
+ canvas.clear(skia.Color4f(0, 0, 0, 1))
396
+ canvas.setMatrix(matrix)
397
+
398
+ paint = skia.Paint()
399
+ paint.setAntiAlias(True)
400
+ sampling = skia.SamplingOptions(skia.FilterMode.kLinear)
401
+ canvas.drawImage(skia_image, 0, 0, sampling, paint)
402
+
403
+ result = surface.makeImageSnapshot().toarray()
404
+
405
+ if channels == 1:
406
+ result = result[:, :, 0:1]
407
+ else:
408
+ result = result[:, :, :channels]
409
+
410
+ result_tensor = torch.from_numpy(result).permute(2, 0, 1).unsqueeze(0)
411
+ if is_float:
412
+ result_tensor = result_tensor.float() / 255.0
413
+
414
+ return result_tensor