sleap-nn 0.1.0__py3-none-any.whl → 0.1.0a1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sleap_nn/__init__.py +1 -1
- sleap_nn/architectures/convnext.py +0 -5
- sleap_nn/architectures/encoder_decoder.py +6 -25
- sleap_nn/architectures/swint.py +0 -8
- sleap_nn/cli.py +60 -364
- sleap_nn/config/data_config.py +5 -11
- sleap_nn/config/get_config.py +4 -5
- sleap_nn/config/trainer_config.py +0 -71
- sleap_nn/data/augmentation.py +241 -50
- sleap_nn/data/custom_datasets.py +34 -364
- sleap_nn/data/instance_cropping.py +1 -1
- sleap_nn/data/resizing.py +2 -2
- sleap_nn/data/utils.py +17 -135
- sleap_nn/evaluation.py +22 -81
- sleap_nn/inference/bottomup.py +20 -86
- sleap_nn/inference/peak_finding.py +19 -88
- sleap_nn/inference/predictors.py +117 -224
- sleap_nn/legacy_models.py +11 -65
- sleap_nn/predict.py +9 -37
- sleap_nn/train.py +4 -69
- sleap_nn/training/callbacks.py +105 -1046
- sleap_nn/training/lightning_modules.py +65 -602
- sleap_nn/training/model_trainer.py +204 -201
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/METADATA +3 -15
- sleap_nn-0.1.0a1.dist-info/RECORD +65 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/WHEEL +1 -1
- sleap_nn/data/skia_augmentation.py +0 -414
- sleap_nn/export/__init__.py +0 -21
- sleap_nn/export/cli.py +0 -1778
- sleap_nn/export/exporters/__init__.py +0 -51
- sleap_nn/export/exporters/onnx_exporter.py +0 -80
- sleap_nn/export/exporters/tensorrt_exporter.py +0 -291
- sleap_nn/export/metadata.py +0 -225
- sleap_nn/export/predictors/__init__.py +0 -63
- sleap_nn/export/predictors/base.py +0 -22
- sleap_nn/export/predictors/onnx.py +0 -154
- sleap_nn/export/predictors/tensorrt.py +0 -312
- sleap_nn/export/utils.py +0 -307
- sleap_nn/export/wrappers/__init__.py +0 -25
- sleap_nn/export/wrappers/base.py +0 -96
- sleap_nn/export/wrappers/bottomup.py +0 -243
- sleap_nn/export/wrappers/bottomup_multiclass.py +0 -195
- sleap_nn/export/wrappers/centered_instance.py +0 -56
- sleap_nn/export/wrappers/centroid.py +0 -58
- sleap_nn/export/wrappers/single_instance.py +0 -83
- sleap_nn/export/wrappers/topdown.py +0 -180
- sleap_nn/export/wrappers/topdown_multiclass.py +0 -304
- sleap_nn/inference/postprocessing.py +0 -284
- sleap_nn/training/schedulers.py +0 -191
- sleap_nn-0.1.0.dist-info/RECORD +0 -88
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/entry_points.txt +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/licenses/LICENSE +0 -0
- {sleap_nn-0.1.0.dist-info → sleap_nn-0.1.0a1.dist-info}/top_level.txt +0 -0
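The per-file line counts above can be reproduced locally once both wheels are on disk (e.g. fetched with `pip download sleap-nn==0.1.0 --no-deps`). A minimal sketch, assuming the two wheel files sit in the working directory under the placeholder names below:

"""Sketch: recompute per-file +/- line counts between two local wheels."""
import difflib
import zipfile

OLD_WHEEL = "sleap_nn-0.1.0-py3-none-any.whl"    # placeholder path
NEW_WHEEL = "sleap_nn-0.1.0a1-py3-none-any.whl"  # placeholder path


def wheel_texts(path):
    """Map each archive member to its decoded lines (binary members are lossy-decoded)."""
    with zipfile.ZipFile(path) as zf:
        return {
            name: zf.read(name).decode("utf-8", errors="replace").splitlines()
            for name in zf.namelist()
        }


old, new = wheel_texts(OLD_WHEEL), wheel_texts(NEW_WHEEL)
for name in sorted(set(old) | set(new)):
    added = removed = 0
    for line in difflib.unified_diff(old.get(name, []), new.get(name, []), lineterm=""):
        if line.startswith("+") and not line.startswith("+++"):
            added += 1
        elif line.startswith("-") and not line.startswith("---"):
            removed += 1
    if added or removed:
        print(f"{name} +{added} -{removed}")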
sleap_nn-0.1.0a1.dist-info/RECORD
ADDED
@@ -0,0 +1,65 @@
+sleap_nn/.DS_Store,sha256=HY8amA79eHkt7o5VUiNsMxkc9YwW6WIPyZbYRj_JdSU,6148
+sleap_nn/__init__.py,sha256=l5Lwiad8GOurqkAhMwWw8-UcpH6af2TnMURf-oKj_U8,1362
+sleap_nn/cli.py,sha256=U4hpEcOxK7a92GeItY95E2DRm5P1ME1GqU__mxaDcW0,21167
+sleap_nn/evaluation.py,sha256=3u7y85wFoBgCwOB2xOGTJIDrd2dUPWOo4m0s0oW3da4,31095
+sleap_nn/legacy_models.py,sha256=8aGK30DZv3pW2IKDBEWH1G2mrytjaxPQD4miPUehj0M,20258
+sleap_nn/predict.py,sha256=8QKjRbS-L-6HF1NFJWioBPv3HSzUpFr2oGEB5hRJzQA,35523
+sleap_nn/system_info.py,sha256=7tWe3y6s872nDbrZoHIdSs-w4w46Z4dEV2qCV-Fe7No,14711
+sleap_nn/train.py,sha256=XvVhzMXL9rNQLx1-6jIcp5BAO1pR7AZjdphMn5ZX-_I,27558
+sleap_nn/architectures/__init__.py,sha256=w0XxQcx-CYyooszzvxRkKWiJkUg-26IlwQoGna8gn40,46
+sleap_nn/architectures/common.py,sha256=MLv-zdHsWL5Q2ct_Wv6SQbRS-5hrFtjK_pvBEfwx-vU,3660
+sleap_nn/architectures/convnext.py,sha256=l9lMJDxIMb-9MI3ShOtVwbOUMuwOLtSQlxiVyYHqjvE,13953
+sleap_nn/architectures/encoder_decoder.py,sha256=f3DUFJo6RrIUposdC3Ytyblr5J0tAeZ_si9dm_m_PhM,28339
+sleap_nn/architectures/heads.py,sha256=5E-7kQ-b2gsL0EviQ8z3KS1DAAMT4F2ZnEzx7eSG5gg,21001
+sleap_nn/architectures/model.py,sha256=1_dsP_4T9fsEVJjDt3er0haMKtbeM6w6JC6tc2jD0Gw,7139
+sleap_nn/architectures/swint.py,sha256=S66Wd0j8Hp-rGlv1C60WSw3AwGyAyGetgfwpL0nIK_M,14687
+sleap_nn/architectures/unet.py,sha256=rAy2Omi6tv1MNW2nBn0Tw-94Nw_-1wFfCT3-IUyPcgo,11723
+sleap_nn/architectures/utils.py,sha256=L0KVs0gbtG8U75Sl40oH_r_w2ySawh3oQPqIGi54HGo,2171
+sleap_nn/config/__init__.py,sha256=l0xV1uJsGJfMPfWAqlUR7Ivu4cSCWsP-3Y9ueyPESuk,42
+sleap_nn/config/data_config.py,sha256=5a5YlXm4V9qGvkqgFNy6o0XJ_Q06UFjpYJXmNHfvXEI,24021
+sleap_nn/config/get_config.py,sha256=rjNUffKU9z-ohLwrOVmJNGCqwUM93eh68h4KJfrSy8Y,42396
+sleap_nn/config/model_config.py,sha256=XFIbqFno7IkX0Se5WF_2_7aUalAlC2SvpDe-uP2TttM,57582
+sleap_nn/config/trainer_config.py,sha256=ZMXxns6VYakgYHRhkM541Eje76DdaTdDi4FFPNjJtP4,28413
+sleap_nn/config/training_job_config.py,sha256=v12_ME_tBUg8JFwOxJNW4sDQn-SedDhiJOGz-TlRwT0,5861
+sleap_nn/config/utils.py,sha256=GgWgVs7_N7ifsJ5OQG3_EyOagNyN3Dx7wS2BAlkaRkg,5553
+sleap_nn/data/__init__.py,sha256=eMNvFJFa3gv5Rq8oK5wzo6zt1pOlwUGYf8EQii6bq7c,54
+sleap_nn/data/augmentation.py,sha256=Kqw_DayPth_DBsmaO1G8Voou_-cYZuSPOjSQWSajgRI,13618
+sleap_nn/data/confidence_maps.py,sha256=PTRqZWSAz1S7viJhxu7QgIC1aHiek97c_dCUsKUwG1o,6217
+sleap_nn/data/custom_datasets.py,sha256=SO-aNB1-bB9DL5Zw-oGYDsliBxwI4iKX_FmwgZjKOgQ,99975
+sleap_nn/data/edge_maps.py,sha256=75qG_7zHRw7fC8JUCVI2tzYakIoxxneWWmcrTwjcHPo,12519
+sleap_nn/data/identity.py,sha256=7vNup6PudST4yDLyDT9wDO-cunRirTEvx4sP77xrlfk,5193
+sleap_nn/data/instance_centroids.py,sha256=SF-3zJt_VMTbZI5ssbrvmZQZDd3684bn55EAtvcbQ6o,2172
+sleap_nn/data/instance_cropping.py,sha256=2dYq5OTwkFN1PdMjoxyuMuHq1OEe03m3Vzqvcs_dkPE,8304
+sleap_nn/data/normalization.py,sha256=5xEvcguG-fvAGObl4nWPZ9TEM5gvv0uYPGDuni34XII,2930
+sleap_nn/data/providers.py,sha256=0x6GFP1s1c08ji4p0M5V6p-dhT4Z9c-SI_Aw1DWX-uM,14272
+sleap_nn/data/resizing.py,sha256=YFpSQduIBkRK39FYmrqDL-v8zMySlEs6TJxh6zb_0ZU,5076
+sleap_nn/data/utils.py,sha256=rT0w7KMOTlzaeKWq1TqjbgC4Lvjz_G96McllvEOqXx8,5641
+sleap_nn/inference/__init__.py,sha256=eVkCmKrxHlDFJIlZTf8B5XEOcSyw-gPQymXMY5uShOM,170
+sleap_nn/inference/bottomup.py,sha256=NqN-G8TzAOsvCoL3bttEjA1iGsuveLOnOCXIUeFCdSA,13684
+sleap_nn/inference/identity.py,sha256=GjNDL9MfGqNyQaK4AE8JQCAE8gpMuE_Y-3r3Gpa53CE,6540
+sleap_nn/inference/paf_grouping.py,sha256=7Fo9lCAj-zcHgv5rI5LIMYGcixCGNt_ZbSNs8Dik7l8,69973
+sleap_nn/inference/peak_finding.py,sha256=L9LdYKt_Bfw7cxo6xEpgF8wXcZAwq5plCfmKJ839N40,13014
+sleap_nn/inference/predictors.py,sha256=U114RlgOXKGm5iz1lnTfE3aN9S0WCh6gWhVP3KVewfc,158046
+sleap_nn/inference/provenance.py,sha256=0BekXyvpLMb0Vv6DjpctlLduG9RN-Q8jt5zDm783eZE,11204
+sleap_nn/inference/single_instance.py,sha256=rOns_5TsJ1rb-lwmHG3ZY-pOhXGN2D-SfW9RmBxxzcI,4089
+sleap_nn/inference/topdown.py,sha256=Ha0Nwx-XCH_rebIuIGhP0qW68QpjLB3XRr9rxt05JLs,35108
+sleap_nn/inference/utils.py,sha256=JnaJK4S_qLtHkWOSkHf4oRZjOmgnU9BGADQnntgGxxs,4689
+sleap_nn/tracking/__init__.py,sha256=rGR35wpSW-n5d3cMiQUzQQ_Dy5II5DPjlXAoPw2QhmM,31
+sleap_nn/tracking/track_instance.py,sha256=9k0uVy9VmpleaLcJh7sVWSeFUPXiw7yj95EYNdXJcks,1373
+sleap_nn/tracking/tracker.py,sha256=_WT-HFruzyOsvcq3AtLm3vnI9MYSwyBmq-HlQvj1vmU,41955
+sleap_nn/tracking/utils.py,sha256=uHVd_mzzZjviVDdLSKXJJ1T96n5ObKvkqIuGsl9Yy8U,11276
+sleap_nn/tracking/candidates/__init__.py,sha256=1O7NObIwshM7j1rLHmImbFphvkM9wY1j4j1TvO5scSE,49
+sleap_nn/tracking/candidates/fixed_window.py,sha256=D80KMlTnenuQveQVVhk9j0G8yx6K324C7nMLHgG76e0,6296
+sleap_nn/tracking/candidates/local_queues.py,sha256=Nx3R5wwEwq0gbfH-fi3oOumfkQo8_sYe5GN47pD9Be8,7305
+sleap_nn/training/__init__.py,sha256=vNTKsIJPZHJwFSKn5PmjiiRJunR_9e7y4_v0S6rdF8U,32
+sleap_nn/training/callbacks.py,sha256=TVnQ6plNC2MnlTiY2rSCRuw2WRk5cQSziek_VPUcOEg,25994
+sleap_nn/training/lightning_modules.py,sha256=G3c4xJkYWW-iSRawzkgTqkGd4lTsbPiMTcB5Nvq7jes,85512
+sleap_nn/training/losses.py,sha256=gbdinUURh4QUzjmNd2UJpt4FXwecqKy9gHr65JZ1bZk,1632
+sleap_nn/training/model_trainer.py,sha256=loCmEX0DfBtdV_pN-W8s31fn2_L-lbpWaq3OQXeSp-0,59337
+sleap_nn/training/utils.py,sha256=ivdkZEI0DkTCm6NPszsaDOh9jSfozkONZdl6TvvQUWI,20398
+sleap_nn-0.1.0a1.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
+sleap_nn-0.1.0a1.dist-info/METADATA,sha256=h3d4WPIu_JunY32jaRqJ4-fXp4KruTWT57FWb3L6dps,5637
+sleap_nn-0.1.0a1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+sleap_nn-0.1.0a1.dist-info/entry_points.txt,sha256=zfl5Y3hidZxWBvo8qXvu5piJAXJ_l6v7xVFm0gNiUoI,46
+sleap_nn-0.1.0a1.dist-info/top_level.txt,sha256=Kz68iQ55K75LWgSeqz4V4SCMGeFFYH-KGBOyhQh3xZE,9
+sleap_nn-0.1.0a1.dist-info/RECORD,,
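Each RECORD entry has the form path,sha256=<urlsafe-base64 digest without padding>,<size in bytes>, and the RECORD file itself is listed last with empty hash and size fields. A minimal sketch for recomputing one entry from an unpacked wheel (the path below is a placeholder):

"""Sketch: recompute a RECORD-style hash entry for one extracted file."""
import base64
import hashlib
from pathlib import Path

path = Path("sleap_nn/__init__.py")  # placeholder: any file unpacked from the wheel
data = path.read_bytes()
# Wheel RECORD hashes are urlsafe base64 of the raw sha256 digest, with '=' padding stripped.
digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
print(f"{path},sha256={digest},{len(data)}")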
sleap_nn/data/skia_augmentation.py
DELETED
@@ -1,414 +0,0 @@
-"""Skia-based augmentation functions that operate on uint8 tensors.
-
-This module provides augmentation functions using skia-python that:
-1. Match the exact API of sleap_nn.data.augmentation
-2. Operate on uint8 tensors throughout (avoiding float32 conversions)
-3. Provide ~1.5x faster augmentation compared to Kornia
-
-Usage:
-    from sleap_nn.data.skia_augmentation import (
-        apply_intensity_augmentation_skia,
-        apply_geometric_augmentation_skia,
-    )
-
-    # Apply augmentations (uint8 in, uint8 out)
-    image, instances = apply_intensity_augmentation_skia(image, instances, **config)
-    image, instances = apply_geometric_augmentation_skia(image, instances, **config)
-"""
-
-from typing import Optional, Tuple
-import numpy as np
-import torch
-import skia
-
-
-def apply_intensity_augmentation_skia(
-    image: torch.Tensor,
-    instances: torch.Tensor,
-    uniform_noise_min: float = 0.0,
-    uniform_noise_max: float = 0.04,
-    uniform_noise_p: float = 0.0,
-    gaussian_noise_mean: float = 0.02,
-    gaussian_noise_std: float = 0.004,
-    gaussian_noise_p: float = 0.0,
-    contrast_min: float = 0.5,
-    contrast_max: float = 2.0,
-    contrast_p: float = 0.0,
-    brightness_min: float = 1.0,
-    brightness_max: float = 1.0,
-    brightness_p: float = 0.0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Apply intensity augmentations on uint8 image tensor.
-
-    Matches API of sleap_nn.data.augmentation.apply_intensity_augmentation.
-
-    Args:
-        image: Input tensor of shape (1, C, H, W) with dtype uint8 or float32.
-        instances: Keypoints tensor (not modified, just passed through).
-        uniform_noise_min: Minimum uniform noise (0-1 scale, maps to 0-255).
-        uniform_noise_max: Maximum uniform noise (0-1 scale).
-        uniform_noise_p: Probability of uniform noise.
-        gaussian_noise_mean: Gaussian noise mean (0-1 scale).
-        gaussian_noise_std: Gaussian noise std (0-1 scale).
-        gaussian_noise_p: Probability of Gaussian noise.
-        contrast_min: Minimum contrast factor.
-        contrast_max: Maximum contrast factor.
-        contrast_p: Probability of contrast adjustment.
-        brightness_min: Minimum brightness factor.
-        brightness_max: Maximum brightness factor.
-        brightness_p: Probability of brightness adjustment.
-
-    Returns:
-        Tuple of (augmented_image, instances). Image dtype matches input.
-    """
-    # Convert to numpy for Skia processing
-    is_float = image.dtype == torch.float32
-    if is_float:
-        img_np = (image[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
-    else:
-        img_np = image[0].permute(1, 2, 0).numpy()
-
-    result = img_np.copy()
-
-    # Apply uniform noise (in uint8 space)
-    if uniform_noise_p > 0 and np.random.random() < uniform_noise_p:
-        noise = np.random.randint(
-            int(uniform_noise_min * 255),
-            int(uniform_noise_max * 255) + 1,
-            img_np.shape,
-            dtype=np.int16,
-        )
-        result = np.clip(result.astype(np.int16) + noise, 0, 255).astype(np.uint8)
-
-    # Apply Gaussian noise (in uint8 space)
-    if gaussian_noise_p > 0 and np.random.random() < gaussian_noise_p:
-        noise = np.random.normal(
-            gaussian_noise_mean * 255, gaussian_noise_std * 255, img_np.shape
-        ).astype(np.int16)
-        result = np.clip(result.astype(np.int16) + noise, 0, 255).astype(np.uint8)
-
-    # Apply contrast using lookup table (pure uint8)
-    if contrast_p > 0 and np.random.random() < contrast_p:
-        factor = np.random.uniform(contrast_min, contrast_max)
-        lut = np.arange(256, dtype=np.float32)
-        lut = np.clip((lut - 127.5) * factor + 127.5, 0, 255).astype(np.uint8)
-        result = lut[result]
-
-    # Apply brightness using lookup table (pure uint8)
-    if brightness_p > 0 and np.random.random() < brightness_p:
-        factor = np.random.uniform(brightness_min, brightness_max)
-        lut = np.arange(256, dtype=np.float32)
-        lut = np.clip(lut * factor, 0, 255).astype(np.uint8)
-        result = lut[result]
-
-    # Convert back to tensor
-    result_tensor = torch.from_numpy(result).permute(2, 0, 1).unsqueeze(0)
-    if is_float:
-        result_tensor = result_tensor.float() / 255.0
-
-    return result_tensor, instances
-
-
-def apply_geometric_augmentation_skia(
-    image: torch.Tensor,
-    instances: torch.Tensor,
-    rotation_min: float = -15.0,
-    rotation_max: float = 15.0,
-    rotation_p: Optional[float] = None,
-    scale_min: float = 0.9,
-    scale_max: float = 1.1,
-    scale_p: Optional[float] = None,
-    translate_width: float = 0.02,
-    translate_height: float = 0.02,
-    translate_p: Optional[float] = None,
-    affine_p: float = 0.0,
-    erase_scale_min: float = 0.0001,
-    erase_scale_max: float = 0.01,
-    erase_ratio_min: float = 1.0,
-    erase_ratio_max: float = 1.0,
-    erase_p: float = 0.0,
-    mixup_lambda_min: float = 0.01,
-    mixup_lambda_max: float = 0.05,
-    mixup_p: float = 0.0,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Apply geometric augmentations using Skia.
-
-    Matches API of sleap_nn.data.augmentation.apply_geometric_augmentation.
-
-    Args:
-        image: Input tensor of shape (1, C, H, W) with dtype uint8 or float32.
-        instances: Keypoints tensor of shape (1, n_instances, n_nodes, 2) or (1, n_nodes, 2).
-        rotation_min: Minimum rotation angle in degrees.
-        rotation_max: Maximum rotation angle in degrees.
-        rotation_p: Probability of rotation (independent). None = use affine_p.
-        scale_min: Minimum scale factor.
-        scale_max: Maximum scale factor.
-        scale_p: Probability of scaling (independent). None = use affine_p.
-        translate_width: Max horizontal translation as fraction of width.
-        translate_height: Max vertical translation as fraction of height.
-        translate_p: Probability of translation (independent). None = use affine_p.
-        affine_p: Probability of bundled affine transform.
-        erase_scale_min: Min proportion of image to erase.
-        erase_scale_max: Max proportion of image to erase.
-        erase_ratio_min: Min aspect ratio of erased area.
-        erase_ratio_max: Max aspect ratio of erased area.
-        erase_p: Probability of random erasing.
-        mixup_lambda_min: Min mixup strength (not implemented).
-        mixup_lambda_max: Max mixup strength (not implemented).
-        mixup_p: Probability of mixup (not implemented).
-
-    Returns:
-        Tuple of (augmented_image, augmented_instances). Image dtype matches input.
-    """
-    # Convert to numpy for Skia processing
-    is_float = image.dtype == torch.float32
-    if is_float:
-        img_np = (image[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
-    else:
-        img_np = image[0].permute(1, 2, 0).numpy().copy()
-
-    h, w = img_np.shape[:2]
-    cx, cy = w / 2, h / 2
-
-    # Build transformation matrix
-    matrix = skia.Matrix()
-    has_transform = False
-
-    use_independent = (
-        rotation_p is not None or scale_p is not None or translate_p is not None
-    )
-
-    if use_independent:
-        if (
-            rotation_p is not None
-            and rotation_p > 0
-            and np.random.random() < rotation_p
-        ):
-            angle = np.random.uniform(rotation_min, rotation_max)
-            rot_matrix = skia.Matrix()
-            rot_matrix.setRotate(angle, cx, cy)
-            matrix = matrix.preConcat(rot_matrix)
-            has_transform = True
-
-        if scale_p is not None and scale_p > 0 and np.random.random() < scale_p:
-            scale = np.random.uniform(scale_min, scale_max)
-            scale_matrix = skia.Matrix()
-            scale_matrix.setScale(scale, scale, cx, cy)
-            matrix = matrix.preConcat(scale_matrix)
-            has_transform = True
-
-        if (
-            translate_p is not None
-            and translate_p > 0
-            and np.random.random() < translate_p
-        ):
-            tx = np.random.uniform(-translate_width, translate_width) * w
-            ty = np.random.uniform(-translate_height, translate_height) * h
-            trans_matrix = skia.Matrix()
-            trans_matrix.setTranslate(tx, ty)
-            matrix = matrix.preConcat(trans_matrix)
-            has_transform = True
-
-    elif affine_p > 0 and np.random.random() < affine_p:
-        angle = np.random.uniform(rotation_min, rotation_max)
-        scale = np.random.uniform(scale_min, scale_max)
-        tx = np.random.uniform(-translate_width, translate_width) * w
-        ty = np.random.uniform(-translate_height, translate_height) * h
-
-        matrix.setRotate(angle, cx, cy)
-        matrix.preScale(scale, scale, cx, cy)
-        matrix.preTranslate(tx, ty)
-        has_transform = True
-
-    # Apply geometric transform
-    if has_transform:
-        img_np = _transform_image_skia(img_np, matrix)
-        instances = _transform_keypoints_tensor(instances, matrix)
-
-    # Apply random erasing
-    if erase_p > 0 and np.random.random() < erase_p:
-        img_np = _apply_random_erase(
-            img_np, erase_scale_min, erase_scale_max, erase_ratio_min, erase_ratio_max
-        )
-
-    # Convert back to tensor
-    result_tensor = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0)
-    if is_float:
-        result_tensor = result_tensor.float() / 255.0
-
-    return result_tensor, instances
-
-
-def _transform_image_skia(image: np.ndarray, matrix: skia.Matrix) -> np.ndarray:
-    """Transform image using Skia matrix (uint8 in, uint8 out)."""
-    h, w = image.shape[:2]
-    channels = image.shape[2] if image.ndim == 3 else 1
-
-    # Skia needs RGBA
-    if channels == 1:
-        image_rgba = np.stack(
-            [image.squeeze()] * 3 + [np.full((h, w), 255, dtype=np.uint8)], axis=-1
-        )
-    elif channels == 3:
-        alpha = np.full((h, w, 1), 255, dtype=np.uint8)
-        image_rgba = np.concatenate([image, alpha], axis=-1)
-    else:
-        raise ValueError(f"Unsupported channels: {channels}")
-
-    image_rgba = np.ascontiguousarray(image_rgba, dtype=np.uint8)
-    skia_image = skia.Image.fromarray(
-        image_rgba, colorType=skia.ColorType.kRGBA_8888_ColorType
-    )
-
-    surface = skia.Surface(w, h)
-    canvas = surface.getCanvas()
-    canvas.clear(skia.Color4f(0, 0, 0, 1))
-    canvas.setMatrix(matrix)
-
-    paint = skia.Paint()
-    paint.setAntiAlias(True)
-    sampling = skia.SamplingOptions(skia.FilterMode.kLinear)
-    canvas.drawImage(skia_image, 0, 0, sampling, paint)
-
-    result = surface.makeImageSnapshot().toarray()
-
-    if channels == 1:
-        return result[:, :, 0:1]
-    return result[:, :, :channels]
-
-
-def _transform_keypoints_tensor(
-    keypoints: torch.Tensor, matrix: skia.Matrix
-) -> torch.Tensor:
-    """Transform keypoints tensor using Skia matrix."""
-    if keypoints.numel() == 0:
-        return keypoints
-
-    original_shape = keypoints.shape
-    flat = keypoints.reshape(-1, 2).numpy()
-
-    # Handle NaN values
-    valid_mask = ~np.isnan(flat).any(axis=1)
-    transformed = flat.copy()
-
-    if valid_mask.any():
-        valid_pts = flat[valid_mask]
-        skia_pts = [skia.Point(float(p[0]), float(p[1])) for p in valid_pts]
-        mapped = matrix.mapPoints(skia_pts)
-        transformed[valid_mask] = np.array([[p.x(), p.y()] for p in mapped])
-
-    return torch.from_numpy(transformed.reshape(original_shape).astype(np.float32))
-
-
-def _apply_random_erase(
-    image: np.ndarray,
-    scale_min: float,
-    scale_max: float,
-    ratio_min: float,
-    ratio_max: float,
-) -> np.ndarray:
-    """Apply random erasing (uint8)."""
-    h, w = image.shape[:2]
-    area = h * w
-
-    erase_area = np.random.uniform(scale_min, scale_max) * area
-    aspect_ratio = np.random.uniform(ratio_min, ratio_max)
-
-    erase_h = int(np.sqrt(erase_area * aspect_ratio))
-    erase_w = int(np.sqrt(erase_area / aspect_ratio))
-
-    if erase_h >= h or erase_w >= w:
-        return image
-
-    y = np.random.randint(0, h - erase_h)
-    x = np.random.randint(0, w - erase_w)
-
-    result = image.copy()
-    channels = image.shape[2] if image.ndim == 3 else 1
-    fill = np.random.randint(0, 256, size=(channels,), dtype=np.uint8)
-    result[y : y + erase_h, x : x + erase_w] = fill
-
-    return result
-
-
-def crop_and_resize_skia(
-    image: torch.Tensor,
-    boxes: torch.Tensor,
-    size: Tuple[int, int],
-) -> torch.Tensor:
-    """Crop and resize image regions using Skia.
-
-    Replacement for kornia.geometry.transform.crop_and_resize.
-
-    Args:
-        image: Input tensor of shape (1, C, H, W).
-        boxes: Bounding boxes tensor of shape (1, 4, 2) with corners:
-            [top-left, top-right, bottom-right, bottom-left].
-        size: Output size (height, width).
-
-    Returns:
-        Cropped and resized tensor of shape (1, C, out_h, out_w).
-    """
-    is_float = image.dtype == torch.float32
-    if is_float:
-        img_np = (image[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
-    else:
-        img_np = image[0].permute(1, 2, 0).numpy()
-
-    h, w = img_np.shape[:2]
-    out_h, out_w = size
-    channels = img_np.shape[2] if img_np.ndim == 3 else 1
-
-    # Get box coordinates (top-left and bottom-right)
-    box = boxes[0].numpy()  # (4, 2)
-    x1, y1 = box[0]  # top-left
-    x2, y2 = box[2]  # bottom-right
-
-    crop_w = x2 - x1
-    crop_h = y2 - y1
-
-    # Create transformation matrix
-    matrix = skia.Matrix()
-    scale_x = out_w / crop_w
-    scale_y = out_h / crop_h
-    matrix.setScale(scale_x, scale_y)
-    matrix.preTranslate(-x1, -y1)
-
-    # Skia needs RGBA
-    if channels == 1:
-        image_rgba = np.stack(
-            [img_np.squeeze()] * 3 + [np.full((h, w), 255, dtype=np.uint8)], axis=-1
-        )
-    elif channels == 3:
-        alpha = np.full((h, w, 1), 255, dtype=np.uint8)
-        image_rgba = np.concatenate([img_np, alpha], axis=-1)
-    else:
-        raise ValueError(f"Unsupported channels: {channels}")
-
-    image_rgba = np.ascontiguousarray(image_rgba, dtype=np.uint8)
-    skia_image = skia.Image.fromarray(
-        image_rgba, colorType=skia.ColorType.kRGBA_8888_ColorType
-    )
-
-    surface = skia.Surface(out_w, out_h)
-    canvas = surface.getCanvas()
-    canvas.clear(skia.Color4f(0, 0, 0, 1))
-    canvas.setMatrix(matrix)
-
-    paint = skia.Paint()
-    paint.setAntiAlias(True)
-    sampling = skia.SamplingOptions(skia.FilterMode.kLinear)
-    canvas.drawImage(skia_image, 0, 0, sampling, paint)
-
-    result = surface.makeImageSnapshot().toarray()
-
-    if channels == 1:
-        result = result[:, :, 0:1]
-    else:
-        result = result[:, :, :channels]
-
-    result_tensor = torch.from_numpy(result).permute(2, 0, 1).unsqueeze(0)
-    if is_float:
-        result_tensor = result_tensor.float() / 255.0
-
-    return result_tensor
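The deleted module documented itself as an API-compatible, uint8-only alternative to the Kornia-based functions in sleap_nn.data.augmentation, so removing it leaves that module as the augmentation path. A hedged sketch of the equivalent calls, with parameter names taken from the deleted signatures (the retained functions are assumed to match them, as the deleted docstrings claim):

# Sketch only: assumes sleap_nn.data.augmentation exposes the signatures the
# deleted skia module claimed to mirror exactly.
import torch

from sleap_nn.data.augmentation import (
    apply_geometric_augmentation,
    apply_intensity_augmentation,
)

image = torch.zeros((1, 1, 256, 256), dtype=torch.float32)  # (1, C, H, W)
instances = torch.zeros((1, 2, 5, 2), dtype=torch.float32)  # (1, n_instances, n_nodes, 2)

# Intensity: parameter names as in the deleted apply_intensity_augmentation_skia.
image, instances = apply_intensity_augmentation(
    image, instances, contrast_p=0.5, contrast_min=0.5, contrast_max=2.0
)
# Geometric: bundled affine transform, as in the deleted apply_geometric_augmentation_skia.
image, instances = apply_geometric_augmentation(
    image, instances, affine_p=1.0, rotation_min=-15.0, rotation_max=15.0
)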
sleap_nn/export/__init__.py
DELETED
@@ -1,21 +0,0 @@
-"""Export utilities for sleap-nn."""
-
-from sleap_nn.export.exporters import export_model, export_to_onnx, export_to_tensorrt
-from sleap_nn.export.metadata import ExportMetadata
-from sleap_nn.export.predictors import (
-    load_exported_model,
-    ONNXPredictor,
-    TensorRTPredictor,
-)
-from sleap_nn.export.utils import build_bottomup_candidate_template
-
-__all__ = [
-    "export_model",
-    "export_to_onnx",
-    "export_to_tensorrt",
-    "load_exported_model",
-    "ONNXPredictor",
-    "TensorRTPredictor",
-    "ExportMetadata",
-    "build_bottomup_candidate_template",
-]