sleap-nn 0.0.5__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sleap_nn/__init__.py +9 -2
- sleap_nn/architectures/convnext.py +5 -0
- sleap_nn/architectures/encoder_decoder.py +25 -6
- sleap_nn/architectures/swint.py +8 -0
- sleap_nn/cli.py +489 -46
- sleap_nn/config/data_config.py +51 -8
- sleap_nn/config/get_config.py +32 -24
- sleap_nn/config/trainer_config.py +88 -0
- sleap_nn/data/augmentation.py +61 -200
- sleap_nn/data/custom_datasets.py +433 -61
- sleap_nn/data/instance_cropping.py +71 -6
- sleap_nn/data/normalization.py +45 -2
- sleap_nn/data/providers.py +26 -0
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/skia_augmentation.py +414 -0
- sleap_nn/data/utils.py +135 -17
- sleap_nn/evaluation.py +177 -42
- sleap_nn/export/__init__.py +21 -0
- sleap_nn/export/cli.py +1778 -0
- sleap_nn/export/exporters/__init__.py +51 -0
- sleap_nn/export/exporters/onnx_exporter.py +80 -0
- sleap_nn/export/exporters/tensorrt_exporter.py +291 -0
- sleap_nn/export/metadata.py +225 -0
- sleap_nn/export/predictors/__init__.py +63 -0
- sleap_nn/export/predictors/base.py +22 -0
- sleap_nn/export/predictors/onnx.py +154 -0
- sleap_nn/export/predictors/tensorrt.py +312 -0
- sleap_nn/export/utils.py +307 -0
- sleap_nn/export/wrappers/__init__.py +25 -0
- sleap_nn/export/wrappers/base.py +96 -0
- sleap_nn/export/wrappers/bottomup.py +243 -0
- sleap_nn/export/wrappers/bottomup_multiclass.py +195 -0
- sleap_nn/export/wrappers/centered_instance.py +56 -0
- sleap_nn/export/wrappers/centroid.py +58 -0
- sleap_nn/export/wrappers/single_instance.py +83 -0
- sleap_nn/export/wrappers/topdown.py +180 -0
- sleap_nn/export/wrappers/topdown_multiclass.py +304 -0
- sleap_nn/inference/__init__.py +6 -0
- sleap_nn/inference/bottomup.py +86 -20
- sleap_nn/inference/peak_finding.py +93 -16
- sleap_nn/inference/postprocessing.py +284 -0
- sleap_nn/inference/predictors.py +339 -137
- sleap_nn/inference/provenance.py +292 -0
- sleap_nn/inference/topdown.py +55 -47
- sleap_nn/legacy_models.py +65 -11
- sleap_nn/predict.py +224 -19
- sleap_nn/system_info.py +443 -0
- sleap_nn/tracking/tracker.py +8 -1
- sleap_nn/train.py +138 -44
- sleap_nn/training/callbacks.py +1258 -5
- sleap_nn/training/lightning_modules.py +902 -220
- sleap_nn/training/model_trainer.py +424 -111
- sleap_nn/training/schedulers.py +191 -0
- sleap_nn/training/utils.py +367 -2
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/METADATA +35 -33
- sleap_nn-0.1.0.dist-info/RECORD +88 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/WHEEL +1 -1
- sleap_nn-0.0.5.dist-info/RECORD +0 -63
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.0.5.dist-info → sleap_nn-0.1.0.dist-info}/top_level.txt +0 -0
|
@@ -5,14 +5,82 @@ import math
|
|
|
5
5
|
import numpy as np
|
|
6
6
|
import sleap_io as sio
|
|
7
7
|
import torch
|
|
8
|
-
from
|
|
8
|
+
from sleap_nn.data.skia_augmentation import crop_and_resize_skia as crop_and_resize
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def compute_augmentation_padding(
    bbox_size: float,
    rotation_max: float = 0.0,
    scale_max: float = 1.0,
) -> int:
    """Compute padding needed to accommodate augmentation transforms.

    When rotation and scaling augmentations are applied, the bounding box of an
    instance can expand beyond its original size. This function calculates the
    padding needed to ensure the full instance remains visible after augmentation.

    The instance is treated as a square of side `bbox_size`, which is a
    conservative bound for any box whose longest side is `bbox_size`.

    Args:
        bbox_size: The size of the instance bounding box (max of width/height).
        rotation_max: Maximum absolute rotation angle in degrees. For symmetric
            rotation ranges like [-180, 180], pass 180. The sign is ignored.
        scale_max: Maximum scaling factor. For scale range [0.9, 1.1], pass 1.1.
            Values below 1.0 are treated as 1.0 (shrinking never needs padding).

    Returns:
        Padding in pixels to add around the bounding box (total, not per side).
        Never negative.
    """
    if rotation_max == 0.0 and scale_max <= 1.0:
        return 0

    # A square of side L rotated by θ has an axis-aligned bounding box of side
    # L * (|cos θ| + |sin θ|). On [0°, 45°] this grows monotonically from 1 to
    # sqrt(2); beyond 45° it shrinks again. Any range that reaches 45° therefore
    # contains the worst case, so the factor is capped at sqrt(2).
    max_angle = abs(rotation_max)
    if max_angle >= 45:
        rotation_factor = math.sqrt(2)
    else:
        rotation_rad = math.radians(max_angle)
        rotation_factor = math.cos(rotation_rad) + math.sin(rotation_rad)

    # Combined expansion factor (down-scaling never requires extra room).
    expansion_factor = rotation_factor * max(scale_max, 1.0)

    # Total padding needed (both sides).
    padding = bbox_size * expansion_factor - bbox_size

    # Subtract a tiny tolerance before ceil() so floating-point noise such as
    # 100 * 1.1 == 110.00000000000001 does not add a spurious extra pixel.
    return max(0, int(math.ceil(padding - 1e-6)))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def find_max_instance_bbox_size(labels: sio.Labels) -> float:
    """Return the largest bounding-box side over every non-empty instance.

    For each instance, the extent along x and the extent along y are measured
    independently, ignoring NaN (missing) keypoints. The largest of all those
    extents is returned.

    Args:
        labels: A `sio.Labels` containing user-labeled instances.

    Returns:
        The maximum bounding box dimension (max of width or height) across all
        instances, or 0.0 when no non-empty instance exists.
    """
    largest = 0.0
    for frame in labels:
        for instance in frame.instances:
            if instance.is_empty:
                continue
            points = instance.numpy()
            # Axis 0 is x, axis 1 is y; both contribute to the same running max.
            for axis in (0, 1):
                coords = points[:, axis]
                extent = np.nanmax(coords) - np.nanmin(coords)
                if np.isnan(extent):
                    extent = 0
                largest = np.maximum(largest, extent)
    return float(largest)
|
|
9
78
|
|
|
10
79
|
|
|
11
80
|
def find_instance_crop_size(
|
|
12
81
|
labels: sio.Labels,
|
|
13
82
|
padding: int = 0,
|
|
14
83
|
maximum_stride: int = 2,
|
|
15
|
-
input_scaling: float = 1.0,
|
|
16
84
|
min_crop_size: Optional[int] = None,
|
|
17
85
|
) -> int:
|
|
18
86
|
"""Compute the size of the largest instance bounding box from labels.
|
|
@@ -23,8 +91,6 @@ def find_instance_crop_size(
|
|
|
23
91
|
maximum_stride: Ensure that the returned crop size is divisible by this value.
|
|
24
92
|
Useful for ensuring that the crop size will not be truncated in a given
|
|
25
93
|
architecture.
|
|
26
|
-
input_scaling: Float factor indicating the scale of the input images if any
|
|
27
|
-
scaling will be done before cropping.
|
|
28
94
|
min_crop_size: The crop size set by the user.
|
|
29
95
|
|
|
30
96
|
Returns:
|
|
@@ -32,7 +98,7 @@ def find_instance_crop_size(
|
|
|
32
98
|
will contain the instances when cropped. The returned crop size will be larger
|
|
33
99
|
or equal to the input `min_crop_size`.
|
|
34
100
|
|
|
35
|
-
This accounts for stride
|
|
101
|
+
This accounts for stride and padding when ensuring divisibility.
|
|
36
102
|
"""
|
|
37
103
|
# Check if user-specified crop size is divisible by max stride
|
|
38
104
|
min_crop_size = 0 if min_crop_size is None else min_crop_size
|
|
@@ -46,7 +112,6 @@ def find_instance_crop_size(
|
|
|
46
112
|
for inst in lf.instances:
|
|
47
113
|
if not inst.is_empty: # only if at least one point is not nan
|
|
48
114
|
pts = inst.numpy()
|
|
49
|
-
pts *= input_scaling
|
|
50
115
|
diff_x = np.nanmax(pts[:, 0]) - np.nanmin(pts[:, 0])
|
|
51
116
|
diff_x = 0 if np.isnan(diff_x) else diff_x
|
|
52
117
|
max_length = np.maximum(max_length, diff_x)
|
sleap_nn/data/normalization.py
CHANGED
|
@@ -4,6 +4,36 @@ import torch
|
|
|
4
4
|
import torchvision.transforms.v2.functional as F
|
|
5
5
|
|
|
6
6
|
|
|
7
|
+
def normalize_on_gpu(image: torch.Tensor) -> torch.Tensor:
    """Normalize image tensor on GPU after transfer.

    This function is called in the model's forward() method after the image has been
    transferred to GPU. It converts uint8 images to float32 and normalizes to [0, 1].

    By performing normalization on GPU after transfer, we reduce PCIe bandwidth by 4x
    (transferring 1 byte/pixel as uint8 instead of 4 bytes/pixel as float32). This
    provides up to 17x speedup for the transfer+normalization stage.

    This function handles two cases:
    1. uint8 tensor with values in [0, 255] -> convert to float32 and divide by 255
    2. float32 tensor with values in [0, 255] (e.g., from preprocessing that cast to
       float32 without normalizing) -> divide by 255

    Floating-point tensors whose maximum is <= 1.0 (and empty floating-point
    tensors) are returned unchanged.

    NOTE: the `image.max() > 1.0` range check converts a device scalar to a Python
    bool, which forces a host/device synchronization for GPU tensors.

    Args:
        image: Tensor image that may be uint8 or float32 with values in [0, 255] range.

    Returns:
        Float32 tensor normalized to [0, 1] range.
    """
    if not torch.is_floating_point(image):
        # Integer (e.g. uint8) -> float32 normalized.
        return image.float() / 255.0
    # A floating-point image with values > 1 is taken to be in [0, 255]. Empty
    # tensors skip the check because `max()` raises on a zero-element tensor.
    if image.numel() > 0 and image.max() > 1.0:
        return image / 255.0
    return image
|
|
35
|
+
|
|
36
|
+
|
|
7
37
|
def convert_to_grayscale(image: torch.Tensor):
|
|
8
38
|
"""Convert given image to Grayscale image (single-channel).
|
|
9
39
|
|
|
@@ -38,8 +68,21 @@ def convert_to_rgb(image: torch.Tensor):
|
|
|
38
68
|
return image
|
|
39
69
|
|
|
40
70
|
|
|
41
|
-
def apply_normalization(image: torch.Tensor) -> torch.Tensor:
    """Normalize image tensor from uint8 [0, 255] to float32 [0, 1].

    This function is used during training data preprocessing where augmentation
    operations (kornia) require float32 input.

    For inference, normalization is deferred to GPU via `normalize_on_gpu()` in the
    model's forward() method to reduce PCIe bandwidth.

    Floating-point inputs are returned unchanged; no range check is performed on
    them.

    Args:
        image: Tensor image (typically uint8 with values in [0, 255]).

    Returns:
        Float32 tensor normalized to [0, 1] range.
    """
    if torch.is_floating_point(image):
        return image
    return image.to(torch.float32) / 255.0
|
sleap_nn/data/providers.py
CHANGED
|
@@ -71,6 +71,8 @@ def process_lf(
|
|
|
71
71
|
for inst in instances_list:
|
|
72
72
|
if not inst.is_empty:
|
|
73
73
|
instances.append(inst.numpy())
|
|
74
|
+
if len(instances) == 0:
|
|
75
|
+
return None
|
|
74
76
|
instances = np.stack(instances, axis=0)
|
|
75
77
|
|
|
76
78
|
# Add singleton time dimension for single frames.
|
|
@@ -233,6 +235,8 @@ class LabelsReader(Thread):
|
|
|
233
235
|
instances_key: bool = False,
|
|
234
236
|
only_labeled_frames: bool = False,
|
|
235
237
|
only_suggested_frames: bool = False,
|
|
238
|
+
exclude_user_labeled: bool = False,
|
|
239
|
+
only_predicted_frames: bool = False,
|
|
236
240
|
):
|
|
237
241
|
"""Initialize attribute of the class."""
|
|
238
242
|
super().__init__()
|
|
@@ -245,6 +249,8 @@ class LabelsReader(Thread):
|
|
|
245
249
|
|
|
246
250
|
self.only_labeled_frames = only_labeled_frames
|
|
247
251
|
self.only_suggested_frames = only_suggested_frames
|
|
252
|
+
self.exclude_user_labeled = exclude_user_labeled
|
|
253
|
+
self.only_predicted_frames = only_predicted_frames
|
|
248
254
|
|
|
249
255
|
# Filter to only user labeled instances
|
|
250
256
|
if self.only_labeled_frames:
|
|
@@ -265,6 +271,20 @@ class LabelsReader(Thread):
|
|
|
265
271
|
)
|
|
266
272
|
self.filtered_lfs.append(new_lf)
|
|
267
273
|
|
|
274
|
+
# Filter out user labeled frames
|
|
275
|
+
elif self.exclude_user_labeled:
|
|
276
|
+
self.filtered_lfs = []
|
|
277
|
+
for lf in self.labels:
|
|
278
|
+
if not lf.has_user_instances:
|
|
279
|
+
self.filtered_lfs.append(lf)
|
|
280
|
+
|
|
281
|
+
# Filter to only predicted frames
|
|
282
|
+
elif self.only_predicted_frames:
|
|
283
|
+
self.filtered_lfs = []
|
|
284
|
+
for lf in self.labels:
|
|
285
|
+
if lf.has_predicted_instances:
|
|
286
|
+
self.filtered_lfs.append(lf)
|
|
287
|
+
|
|
268
288
|
else:
|
|
269
289
|
self.filtered_lfs = [lf for lf in self.labels]
|
|
270
290
|
|
|
@@ -300,6 +320,8 @@ class LabelsReader(Thread):
|
|
|
300
320
|
instances_key: bool = False,
|
|
301
321
|
only_labeled_frames: bool = False,
|
|
302
322
|
only_suggested_frames: bool = False,
|
|
323
|
+
exclude_user_labeled: bool = False,
|
|
324
|
+
only_predicted_frames: bool = False,
|
|
303
325
|
):
|
|
304
326
|
"""Create LabelsReader from a .slp filename."""
|
|
305
327
|
labels = sio.load_slp(filename)
|
|
@@ -310,6 +332,8 @@ class LabelsReader(Thread):
|
|
|
310
332
|
instances_key,
|
|
311
333
|
only_labeled_frames,
|
|
312
334
|
only_suggested_frames,
|
|
335
|
+
exclude_user_labeled,
|
|
336
|
+
only_predicted_frames,
|
|
313
337
|
)
|
|
314
338
|
|
|
315
339
|
def run(self):
|
|
@@ -333,6 +357,8 @@ class LabelsReader(Thread):
|
|
|
333
357
|
for inst in lf:
|
|
334
358
|
if not inst.is_empty:
|
|
335
359
|
instances.append(inst.numpy())
|
|
360
|
+
if len(instances) == 0:
|
|
361
|
+
continue
|
|
336
362
|
instances = np.stack(instances, axis=0)
|
|
337
363
|
|
|
338
364
|
# Add singleton time dimension for single frames.
|
sleap_nn/data/resizing.py
CHANGED
|
@@ -63,7 +63,7 @@ def apply_pad_to_stride(image: torch.Tensor, max_stride: int) -> torch.Tensor:
|
|
|
63
63
|
image,
|
|
64
64
|
(0, pad_width, 0, pad_height),
|
|
65
65
|
mode="constant",
|
|
66
|
-
)
|
|
66
|
+
)
|
|
67
67
|
return image
|
|
68
68
|
|
|
69
69
|
|
|
@@ -136,7 +136,7 @@ def apply_sizematcher(
|
|
|
136
136
|
image,
|
|
137
137
|
(0, pad_width, 0, pad_height),
|
|
138
138
|
mode="constant",
|
|
139
|
-
)
|
|
139
|
+
)
|
|
140
140
|
|
|
141
141
|
return image, eff_scale_ratio
|
|
142
142
|
else:
|
|
@@ -0,0 +1,414 @@
|
|
|
1
|
+
"""Skia-based augmentation functions that operate on uint8 tensors.
|
|
2
|
+
|
|
3
|
+
This module provides augmentation functions using skia-python that:
|
|
4
|
+
1. Match the exact API of sleap_nn.data.augmentation
|
|
5
|
+
2. Operate on uint8 tensors throughout (avoiding float32 conversions)
|
|
6
|
+
3. Provide ~1.5x faster augmentation compared to Kornia
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
from sleap_nn.data.skia_augmentation import (
|
|
10
|
+
apply_intensity_augmentation_skia,
|
|
11
|
+
apply_geometric_augmentation_skia,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# Apply augmentations (uint8 in, uint8 out)
|
|
15
|
+
image, instances = apply_intensity_augmentation_skia(image, instances, **config)
|
|
16
|
+
image, instances = apply_geometric_augmentation_skia(image, instances, **config)
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from typing import Optional, Tuple
|
|
20
|
+
import numpy as np
|
|
21
|
+
import torch
|
|
22
|
+
import skia
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def apply_intensity_augmentation_skia(
    image: torch.Tensor,
    instances: torch.Tensor,
    uniform_noise_min: float = 0.0,
    uniform_noise_max: float = 0.04,
    uniform_noise_p: float = 0.0,
    gaussian_noise_mean: float = 0.02,
    gaussian_noise_std: float = 0.004,
    gaussian_noise_p: float = 0.0,
    contrast_min: float = 0.5,
    contrast_max: float = 2.0,
    contrast_p: float = 0.0,
    brightness_min: float = 1.0,
    brightness_max: float = 1.0,
    brightness_p: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply intensity augmentations on uint8 image tensor.

    Matches API of sleap_nn.data.augmentation.apply_intensity_augmentation.

    Each augmentation is drawn independently with its own probability, in this
    order: uniform noise, Gaussian noise, contrast, brightness. All of them operate
    in uint8 space and saturate at 0 and 255.

    Args:
        image: Input tensor of shape (1, C, H, W) with dtype uint8 or float32.
            Float32 inputs are expected in [0, 1]. Out-of-range values are clipped.
        instances: Keypoints tensor (not modified, just passed through).
        uniform_noise_min: Minimum uniform noise (0-1 scale, maps to 0-255).
        uniform_noise_max: Maximum uniform noise (0-1 scale).
        uniform_noise_p: Probability of uniform noise.
        gaussian_noise_mean: Gaussian noise mean (0-1 scale).
        gaussian_noise_std: Gaussian noise std (0-1 scale).
        gaussian_noise_p: Probability of Gaussian noise.
        contrast_min: Minimum contrast factor.
        contrast_max: Maximum contrast factor.
        contrast_p: Probability of contrast adjustment.
        brightness_min: Minimum brightness factor.
        brightness_max: Maximum brightness factor.
        brightness_p: Probability of brightness adjustment.

    Returns:
        Tuple of (augmented_image, instances). Image dtype matches input.
    """
    # Convert to HWC uint8 numpy for processing. Float values are rounded (not
    # truncated) and clipped. This makes a uint8 -> float -> uint8 round trip
    # lossless, and out-of-range values saturate instead of wrapping modulo 256.
    is_float = image.dtype == torch.float32
    if is_float:
        scaled = image[0].permute(1, 2, 0).numpy() * 255.0
        img_np = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
    else:
        img_np = image[0].permute(1, 2, 0).numpy()

    # Copy so the caller's tensor memory is never modified.
    result = img_np.copy()

    # Apply uniform noise (in uint8 space).
    if uniform_noise_p > 0 and np.random.random() < uniform_noise_p:
        noise = np.random.randint(
            int(uniform_noise_min * 255),
            int(uniform_noise_max * 255) + 1,
            img_np.shape,
            dtype=np.int16,
        )
        result = np.clip(result.astype(np.int16) + noise, 0, 255).astype(np.uint8)

    # Apply Gaussian noise (in uint8 space). Round to the nearest integer, because
    # astype() truncates toward zero and would bias the noise toward 0. Clip to
    # [-255, 255] first so extreme samples cannot overflow int16.
    if gaussian_noise_p > 0 and np.random.random() < gaussian_noise_p:
        noise = np.random.normal(
            gaussian_noise_mean * 255, gaussian_noise_std * 255, img_np.shape
        )
        noise = np.clip(np.rint(noise), -255, 255).astype(np.int16)
        result = np.clip(result.astype(np.int16) + noise, 0, 255).astype(np.uint8)

    # Apply contrast around mid-gray using a rounded lookup table (pure uint8).
    if contrast_p > 0 and np.random.random() < contrast_p:
        factor = np.random.uniform(contrast_min, contrast_max)
        lut = np.arange(256, dtype=np.float32)
        lut = np.clip(np.rint((lut - 127.5) * factor + 127.5), 0, 255).astype(
            np.uint8
        )
        result = lut[result]

    # Apply multiplicative brightness using a rounded lookup table (pure uint8).
    if brightness_p > 0 and np.random.random() < brightness_p:
        factor = np.random.uniform(brightness_min, brightness_max)
        lut = np.arange(256, dtype=np.float32)
        lut = np.clip(np.rint(lut * factor), 0, 255).astype(np.uint8)
        result = lut[result]

    # Convert back to a (1, C, H, W) tensor with the input's dtype.
    result_tensor = torch.from_numpy(result).permute(2, 0, 1).unsqueeze(0)
    if is_float:
        result_tensor = result_tensor.float() / 255.0

    return result_tensor, instances
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def apply_geometric_augmentation_skia(
    image: torch.Tensor,
    instances: torch.Tensor,
    rotation_min: float = -15.0,
    rotation_max: float = 15.0,
    rotation_p: Optional[float] = None,
    scale_min: float = 0.9,
    scale_max: float = 1.1,
    scale_p: Optional[float] = None,
    translate_width: float = 0.02,
    translate_height: float = 0.02,
    translate_p: Optional[float] = None,
    affine_p: float = 0.0,
    erase_scale_min: float = 0.0001,
    erase_scale_max: float = 0.01,
    erase_ratio_min: float = 1.0,
    erase_ratio_max: float = 1.0,
    erase_p: float = 0.0,
    mixup_lambda_min: float = 0.01,
    mixup_lambda_max: float = 0.05,
    mixup_p: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply geometric augmentations using Skia.

    Matches API of sleap_nn.data.augmentation.apply_geometric_augmentation.

    Two sampling modes are supported:

    * Independent: if any of `rotation_p`, `scale_p` or `translate_p` is not None,
      each of those transforms is drawn on its own with its own probability, and
      `affine_p` is ignored.
    * Bundled: otherwise, with probability `affine_p` a rotation, a scale and a
      translation are drawn together.

    In both modes the matrix is built by Skia pre-concatenation as
    rotation * scale * translation, with rotation and scale pivoting on the image
    center. The same matrix warps the image and maps the keypoints. Random
    erasing, if triggered, runs after the geometric transform and leaves the
    keypoints untouched.

    Args:
        image: Input tensor of shape (1, C, H, W) with dtype uint8 or float32.
            Float32 inputs are expected in [0, 1]. Out-of-range values are clipped.
        instances: Keypoints tensor of shape (1, n_instances, n_nodes, 2) or (1, n_nodes, 2).
        rotation_min: Minimum rotation angle in degrees.
        rotation_max: Maximum rotation angle in degrees.
        rotation_p: Probability of rotation (independent). None = use affine_p.
        scale_min: Minimum scale factor.
        scale_max: Maximum scale factor.
        scale_p: Probability of scaling (independent). None = use affine_p.
        translate_width: Max horizontal translation as fraction of width.
        translate_height: Max vertical translation as fraction of height.
        translate_p: Probability of translation (independent). None = use affine_p.
        affine_p: Probability of bundled affine transform.
        erase_scale_min: Min proportion of image to erase.
        erase_scale_max: Max proportion of image to erase.
        erase_ratio_min: Min aspect ratio of erased area.
        erase_ratio_max: Max aspect ratio of erased area.
        erase_p: Probability of random erasing.
        mixup_lambda_min: Min mixup strength (not implemented).
        mixup_lambda_max: Max mixup strength (not implemented).
        mixup_p: Probability of mixup (not implemented).

    Returns:
        Tuple of (augmented_image, augmented_instances). Image dtype matches input.
    """
    # Convert to HWC uint8 numpy. Float values are rounded and clipped rather than
    # truncated, so a uint8 -> float -> uint8 round trip is lossless and
    # out-of-range values saturate instead of wrapping modulo 256.
    is_float = image.dtype == torch.float32
    if is_float:
        scaled = image[0].permute(1, 2, 0).numpy() * 255.0
        img_np = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
    else:
        img_np = image[0].permute(1, 2, 0).numpy().copy()

    h, w = img_np.shape[:2]
    cx, cy = w / 2, h / 2

    # Build transformation matrix.
    matrix = skia.Matrix()
    has_transform = False

    use_independent = (
        rotation_p is not None or scale_p is not None or translate_p is not None
    )

    if use_independent:
        if (
            rotation_p is not None
            and rotation_p > 0
            and np.random.random() < rotation_p
        ):
            angle = np.random.uniform(rotation_min, rotation_max)
            rot_matrix = skia.Matrix()
            rot_matrix.setRotate(angle, cx, cy)
            matrix = matrix.preConcat(rot_matrix)
            has_transform = True

        if scale_p is not None and scale_p > 0 and np.random.random() < scale_p:
            scale = np.random.uniform(scale_min, scale_max)
            scale_matrix = skia.Matrix()
            scale_matrix.setScale(scale, scale, cx, cy)
            matrix = matrix.preConcat(scale_matrix)
            has_transform = True

        if (
            translate_p is not None
            and translate_p > 0
            and np.random.random() < translate_p
        ):
            tx = np.random.uniform(-translate_width, translate_width) * w
            ty = np.random.uniform(-translate_height, translate_height) * h
            trans_matrix = skia.Matrix()
            trans_matrix.setTranslate(tx, ty)
            matrix = matrix.preConcat(trans_matrix)
            has_transform = True

    elif affine_p > 0 and np.random.random() < affine_p:
        angle = np.random.uniform(rotation_min, rotation_max)
        scale = np.random.uniform(scale_min, scale_max)
        tx = np.random.uniform(-translate_width, translate_width) * w
        ty = np.random.uniform(-translate_height, translate_height) * h

        matrix.setRotate(angle, cx, cy)
        matrix.preScale(scale, scale, cx, cy)
        matrix.preTranslate(tx, ty)
        has_transform = True

    # Apply the geometric transform to the image and keypoints together so they
    # stay aligned.
    if has_transform:
        img_np = _transform_image_skia(img_np, matrix)
        instances = _transform_keypoints_tensor(instances, matrix)

    # Apply random erasing (image only).
    if erase_p > 0 and np.random.random() < erase_p:
        img_np = _apply_random_erase(
            img_np, erase_scale_min, erase_scale_max, erase_ratio_min, erase_ratio_max
        )

    # Convert back to a (1, C, H, W) tensor with the input's dtype.
    result_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0)
    if is_float:
        result_tensor = result_tensor.float() / 255.0

    return result_tensor, instances
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def _transform_image_skia(image: np.ndarray, matrix: skia.Matrix) -> np.ndarray:
    """Warp a uint8 image with a Skia matrix (uint8 in, uint8 out).

    The image may be (H, W), (H, W, 1) or (H, W, 3). It is drawn through `matrix`
    onto an opaque-black canvas of the same height and width, using bilinear
    sampling. Single-channel results are returned as (H, W, 1).

    Raises:
        ValueError: If the image has a channel count other than 1 or 3.
    """
    h, w = image.shape[:2]
    channels = image.shape[2] if image.ndim == 3 else 1

    # Skia draws RGBA images, so expand to four channels with an opaque alpha.
    alpha = np.full((h, w, 1), 255, dtype=np.uint8)
    if channels == 1:
        # reshape (not squeeze) keeps both spatial axes even when h or w is 1.
        gray = image.reshape(h, w, 1)
        image_rgba = np.concatenate([gray, gray, gray, alpha], axis=-1)
    elif channels == 3:
        image_rgba = np.concatenate([image, alpha], axis=-1)
    else:
        raise ValueError(f"Unsupported channels: {channels}")

    image_rgba = np.ascontiguousarray(image_rgba, dtype=np.uint8)
    skia_image = skia.Image.fromarray(
        image_rgba, colorType=skia.ColorType.kRGBA_8888_ColorType
    )

    surface = skia.Surface(w, h)
    canvas = surface.getCanvas()
    # Regions that map from outside the source image stay opaque black.
    canvas.clear(skia.Color4f(0, 0, 0, 1))
    canvas.setMatrix(matrix)

    paint = skia.Paint()
    paint.setAntiAlias(True)
    sampling = skia.SamplingOptions(skia.FilterMode.kLinear)
    canvas.drawImage(skia_image, 0, 0, sampling, paint)

    result = surface.makeImageSnapshot().toarray()

    if channels == 1:
        # R, G and B are identical for gray input; keep one of them.
        return result[:, :, 0:1]
    return result[:, :, :channels]
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def _transform_keypoints_tensor(
    keypoints: torch.Tensor, matrix: skia.Matrix
) -> torch.Tensor:
    """Map keypoints through a Skia matrix, leaving NaN points untouched.

    Args:
        keypoints: Tensor that is viewed as rows of (x, y) pairs via
            `reshape(-1, 2)`; any leading shape is preserved.
        matrix: Skia transform applied to every point with no NaN coordinate.

    Returns:
        A float32 tensor with the same shape as `keypoints`. Rows with a NaN in
        either coordinate are copied through unchanged. Empty input is returned
        as-is.
    """
    if keypoints.numel() == 0:
        return keypoints

    shape = keypoints.shape
    points = keypoints.reshape(-1, 2).numpy()
    out = np.array(points, copy=True)

    keep = ~np.isnan(points).any(axis=1)
    if np.any(keep):
        src = [skia.Point(float(x), float(y)) for x, y in points[keep]]
        dst = matrix.mapPoints(src)
        out[keep] = np.array([[pt.x(), pt.y()] for pt in dst])

    return torch.from_numpy(out.reshape(shape).astype(np.float32))
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def _apply_random_erase(
    image: np.ndarray,
    scale_min: float,
    scale_max: float,
    ratio_min: float,
    ratio_max: float,
) -> np.ndarray:
    """Fill one random rectangle of a uint8 image with a random color.

    The rectangle's area is a uniformly sampled fraction (`scale_min`..`scale_max`)
    of the image area. Its aspect ratio is sampled from `ratio_min`..`ratio_max`,
    giving height = sqrt(area * ratio) and width = sqrt(area / ratio). Every
    placement that fits inside the image is equally likely, including ones that
    touch the bottom and right edges.

    Args:
        image: (H, W) or (H, W, C) uint8 array. It is not modified.
        scale_min: Minimum fraction of the image area to erase.
        scale_max: Maximum fraction of the image area to erase.
        ratio_min: Minimum height/width ratio of the erased rectangle.
        ratio_max: Maximum height/width ratio of the erased rectangle.

    Returns:
        A new array with one rectangle filled with a single random color per
        channel. The input array itself is returned when the sampled rectangle is
        not strictly smaller than the image in both dimensions.
    """
    h, w = image.shape[:2]
    area = h * w

    erase_area = np.random.uniform(scale_min, scale_max) * area
    aspect_ratio = np.random.uniform(ratio_min, ratio_max)

    erase_h = int(np.sqrt(erase_area * aspect_ratio))
    erase_w = int(np.sqrt(erase_area / aspect_ratio))

    if erase_h >= h or erase_w >= w:
        return image

    # Valid top-left rows are 0..h - erase_h inclusive. randint's upper bound is
    # exclusive, so +1 lets the rectangle reach the bottom/right edges.
    y = np.random.randint(0, h - erase_h + 1)
    x = np.random.randint(0, w - erase_w + 1)

    result = image.copy()
    channels = image.shape[2] if image.ndim == 3 else 1
    fill = np.random.randint(0, 256, size=(channels,), dtype=np.uint8)
    result[y : y + erase_h, x : x + erase_w] = fill

    return result
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def crop_and_resize_skia(
    image: torch.Tensor,
    boxes: torch.Tensor,
    size: Tuple[int, int],
) -> torch.Tensor:
    """Crop and resize image regions using Skia.

    Replacement for kornia.geometry.transform.crop_and_resize. As in kornia, box
    corners are inclusive pixel coordinates. The top-left corner lands on output
    pixel (0, 0) and the bottom-right corner on output pixel
    (out_w - 1, out_h - 1). A box whose corners are `out_w - 1` pixels apart is
    therefore copied at scale 1 with no resampling zoom.

    Args:
        image: Input tensor of shape (1, C, H, W) with C in {1, 3}, dtype uint8 or
            float32. Float32 inputs are expected in [0, 1]; out-of-range values are
            clipped.
        boxes: Bounding boxes tensor of shape (1, 4, 2) with corners:
            [top-left, top-right, bottom-right, bottom-left].
        size: Output size (height, width).

    Returns:
        Cropped and resized tensor of shape (1, C, out_h, out_w). Areas outside the
        source image are black. The dtype matches the input.

    Raises:
        ValueError: If the image has a channel count other than 1 or 3.
    """
    # Convert to HWC uint8 numpy. Float values are rounded and clipped rather than
    # truncated, so a uint8 -> float -> uint8 round trip is lossless.
    is_float = image.dtype == torch.float32
    if is_float:
        scaled = image[0].permute(1, 2, 0).numpy() * 255.0
        img_np = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
    else:
        img_np = image[0].permute(1, 2, 0).numpy()

    h, w = img_np.shape[:2]
    out_h, out_w = size
    channels = img_np.shape[2] if img_np.ndim == 3 else 1

    # Box corners (inclusive pixel coordinates): top-left and bottom-right.
    box = boxes[0].numpy()  # (4, 2)
    x1, y1 = float(box[0][0]), float(box[0][1])
    x2, y2 = float(box[2][0]), float(box[2][1])

    # Map the inclusive corner pixels onto the first and last output pixels.
    # A zero-extent box falls back to scale 1 instead of dividing by zero.
    span_x = x2 - x1
    span_y = y2 - y1
    scale_x = (out_w - 1) / span_x if span_x != 0 else 1.0
    scale_y = (out_h - 1) / span_y if span_y != 0 else 1.0

    # Skia samples pixel centers at (i + 0.5, j + 0.5). Shift by half a pixel on
    # both sides so source pixel index x1 maps to destination index 0:
    #     dst = (src - x1 - 0.5) * scale + 0.5
    matrix = skia.Matrix()
    matrix.setTranslate(0.5, 0.5)
    matrix.preScale(scale_x, scale_y, 0, 0)
    matrix.preTranslate(-x1 - 0.5, -y1 - 0.5)

    # Skia draws RGBA images, so expand to four channels with an opaque alpha.
    alpha = np.full((h, w, 1), 255, dtype=np.uint8)
    if channels == 1:
        # reshape (not squeeze) keeps both spatial axes even when h or w is 1.
        gray = img_np.reshape(h, w, 1)
        image_rgba = np.concatenate([gray, gray, gray, alpha], axis=-1)
    elif channels == 3:
        image_rgba = np.concatenate([img_np, alpha], axis=-1)
    else:
        raise ValueError(f"Unsupported channels: {channels}")

    image_rgba = np.ascontiguousarray(image_rgba, dtype=np.uint8)
    skia_image = skia.Image.fromarray(
        image_rgba, colorType=skia.ColorType.kRGBA_8888_ColorType
    )

    surface = skia.Surface(out_w, out_h)
    canvas = surface.getCanvas()
    canvas.clear(skia.Color4f(0, 0, 0, 1))
    canvas.setMatrix(matrix)

    paint = skia.Paint()
    paint.setAntiAlias(True)
    sampling = skia.SamplingOptions(skia.FilterMode.kLinear)
    canvas.drawImage(skia_image, 0, 0, sampling, paint)

    result = surface.makeImageSnapshot().toarray()

    if channels == 1:
        # R, G and B are identical for gray input; keep one of them.
        result = result[:, :, 0:1]
    else:
        result = result[:, :, :channels]

    result_tensor = torch.from_numpy(result).permute(2, 0, 1).unsqueeze(0)
    if is_float:
        result_tensor = result_tensor.float() / 255.0

    return result_tensor
|