nrtk-albumentations 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nrtk-albumentations might be problematic. Click here for more details.
- albumentations/__init__.py +21 -0
- albumentations/augmentations/__init__.py +23 -0
- albumentations/augmentations/blur/__init__.py +0 -0
- albumentations/augmentations/blur/functional.py +438 -0
- albumentations/augmentations/blur/transforms.py +1633 -0
- albumentations/augmentations/crops/__init__.py +0 -0
- albumentations/augmentations/crops/functional.py +494 -0
- albumentations/augmentations/crops/transforms.py +3647 -0
- albumentations/augmentations/dropout/__init__.py +0 -0
- albumentations/augmentations/dropout/channel_dropout.py +134 -0
- albumentations/augmentations/dropout/coarse_dropout.py +567 -0
- albumentations/augmentations/dropout/functional.py +1017 -0
- albumentations/augmentations/dropout/grid_dropout.py +166 -0
- albumentations/augmentations/dropout/mask_dropout.py +274 -0
- albumentations/augmentations/dropout/transforms.py +461 -0
- albumentations/augmentations/dropout/xy_masking.py +186 -0
- albumentations/augmentations/geometric/__init__.py +0 -0
- albumentations/augmentations/geometric/distortion.py +1238 -0
- albumentations/augmentations/geometric/flip.py +752 -0
- albumentations/augmentations/geometric/functional.py +4151 -0
- albumentations/augmentations/geometric/pad.py +676 -0
- albumentations/augmentations/geometric/resize.py +956 -0
- albumentations/augmentations/geometric/rotate.py +864 -0
- albumentations/augmentations/geometric/transforms.py +1962 -0
- albumentations/augmentations/mixing/__init__.py +0 -0
- albumentations/augmentations/mixing/domain_adaptation.py +787 -0
- albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
- albumentations/augmentations/mixing/functional.py +878 -0
- albumentations/augmentations/mixing/transforms.py +832 -0
- albumentations/augmentations/other/__init__.py +0 -0
- albumentations/augmentations/other/lambda_transform.py +180 -0
- albumentations/augmentations/other/type_transform.py +261 -0
- albumentations/augmentations/pixel/__init__.py +0 -0
- albumentations/augmentations/pixel/functional.py +4226 -0
- albumentations/augmentations/pixel/transforms.py +7556 -0
- albumentations/augmentations/spectrogram/__init__.py +0 -0
- albumentations/augmentations/spectrogram/transform.py +220 -0
- albumentations/augmentations/text/__init__.py +0 -0
- albumentations/augmentations/text/functional.py +272 -0
- albumentations/augmentations/text/transforms.py +299 -0
- albumentations/augmentations/transforms3d/__init__.py +0 -0
- albumentations/augmentations/transforms3d/functional.py +393 -0
- albumentations/augmentations/transforms3d/transforms.py +1422 -0
- albumentations/augmentations/utils.py +249 -0
- albumentations/core/__init__.py +0 -0
- albumentations/core/bbox_utils.py +920 -0
- albumentations/core/composition.py +1885 -0
- albumentations/core/hub_mixin.py +299 -0
- albumentations/core/keypoints_utils.py +521 -0
- albumentations/core/label_manager.py +339 -0
- albumentations/core/pydantic.py +239 -0
- albumentations/core/serialization.py +352 -0
- albumentations/core/transforms_interface.py +976 -0
- albumentations/core/type_definitions.py +127 -0
- albumentations/core/utils.py +605 -0
- albumentations/core/validation.py +129 -0
- albumentations/pytorch/__init__.py +1 -0
- albumentations/pytorch/transforms.py +189 -0
- nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
- nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
- nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
- nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1238 @@
|
|
|
1
|
+
"""Geometric distortion transforms for image augmentation.
|
|
2
|
+
|
|
3
|
+
This module provides various geometric distortion transformations that modify the spatial arrangement
|
|
4
|
+
of pixels in images while preserving their intensity values. These transforms can create
|
|
5
|
+
non-rigid deformations that are useful for data augmentation, especially when training models
|
|
6
|
+
that need to be robust to geometric variations.
|
|
7
|
+
|
|
8
|
+
Available transforms:
|
|
9
|
+
- ElasticTransform: Creates random elastic deformations by displacing pixels along random vectors
|
|
10
|
+
- GridDistortion: Distorts the image by moving the nodes of a grid placed on the image
|
|
11
|
+
- OpticalDistortion: Simulates lens distortion effects (barrel/pincushion) using camera or fisheye models
|
|
12
|
+
- PiecewiseAffine: Divides the image into a grid and applies random affine transformations to each cell
|
|
13
|
+
- ThinPlateSpline: Applies smooth deformations based on the thin plate spline interpolation technique
|
|
14
|
+
|
|
15
|
+
All transforms inherit from BaseDistortion, which provides a common interface and functionality
|
|
16
|
+
for applying distortion maps to various target types (images, masks, bounding boxes, keypoints).
|
|
17
|
+
These transforms are particularly useful for:
|
|
18
|
+
|
|
19
|
+
- Data augmentation to increase training set diversity
|
|
20
|
+
- Simulating real-world distortion effects like camera lens aberrations
|
|
21
|
+
- Creating more challenging test cases for computer vision models
|
|
22
|
+
- Medical image analysis where anatomy might appear in different shapes
|
|
23
|
+
|
|
24
|
+
Each transform supports customization through various parameters controlling the strength,
|
|
25
|
+
type, and characteristics of the distortion, as well as interpolation methods for different
|
|
26
|
+
target types.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
from typing import Annotated, Any, Literal, cast
|
|
32
|
+
from warnings import warn
|
|
33
|
+
|
|
34
|
+
import cv2
|
|
35
|
+
import numpy as np
|
|
36
|
+
from albucore import batch_transform
|
|
37
|
+
from pydantic import (
|
|
38
|
+
AfterValidator,
|
|
39
|
+
Field,
|
|
40
|
+
ValidationInfo,
|
|
41
|
+
field_validator,
|
|
42
|
+
)
|
|
43
|
+
|
|
44
|
+
from albumentations.augmentations.utils import check_range
|
|
45
|
+
from albumentations.core.bbox_utils import (
|
|
46
|
+
denormalize_bboxes,
|
|
47
|
+
normalize_bboxes,
|
|
48
|
+
)
|
|
49
|
+
from albumentations.core.pydantic import (
|
|
50
|
+
NonNegativeFloatRangeType,
|
|
51
|
+
SymmetricRangeType,
|
|
52
|
+
check_range_bounds,
|
|
53
|
+
)
|
|
54
|
+
from albumentations.core.transforms_interface import (
|
|
55
|
+
BaseTransformInitSchema,
|
|
56
|
+
DualTransform,
|
|
57
|
+
)
|
|
58
|
+
from albumentations.core.type_definitions import (
|
|
59
|
+
ALL_TARGETS,
|
|
60
|
+
BIG_INTEGER,
|
|
61
|
+
)
|
|
62
|
+
from albumentations.core.utils import to_tuple
|
|
63
|
+
|
|
64
|
+
from . import functional as fgeometric
|
|
65
|
+
|
|
66
|
+
# Public API of this module: the names exposed by a wildcard import
# (``from ... import *``). The BaseDistortion helper class is not listed here.
__all__ = [
    "ElasticTransform",
    "GridDistortion",
    "OpticalDistortion",
    "PiecewiseAffine",
    "ThinPlateSpline",
]
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class BaseDistortion(DualTransform):
    """Common foundation for transforms that warp their inputs through a pair of remapping grids.

    Concrete distortions (optical, grid, elastic, ...) only need to produce the coordinate maps
    ``map_x`` and ``map_y``. This class pushes those maps through every supported target so that
    images, masks, bounding boxes and keypoints stay aligned with each other.

    Args:
        interpolation (int): OpenCV interpolation flag used when remapping images. One of:
            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
            Required at this level (no default).
        mask_interpolation (int): OpenCV interpolation flag used when remapping masks. One of:
            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
            Required at this level (no default).
        keypoint_remapping_method (Literal["direct", "mask"]): How keypoints are moved.
            - "mask": Mask-based remapping. Faster, especially for many keypoints, but may be
              less accurate for large distortions. Recommended for large images or many keypoints.
            - "direct": Inverse mapping. More accurate for large distortions but slower.
            Required at this level; concrete subclasses such as ElasticTransform default it to "mask".
        p (float): Probability of applying the transform. Required at this level (no default).
        border_mode (int): OpenCV border flag forwarded to the remap call for images and masks. One of:
            cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP,
            cv2.BORDER_REFLECT_101. Default: cv2.BORDER_CONSTANT.
        fill (tuple[float, ...] | float): Value forwarded to the image remap call together with
            ``border_mode`` (presumably the constant border value - confirm in ``fgeometric.remap``).
            Default: 0.
        fill_mask (tuple[float, ...] | float): Same as ``fill`` but forwarded to the mask remap call.
            Default: 0.

    Targets:
        image, mask, bboxes, keypoints, volume, mask3d

    Image types:
        uint8, float32

    Note:
        - This is an abstract base class and is not meant to be used on its own.
        - Subclasses implement `get_params_dependent_on_data` and return the distortion maps
          under the keys "map_x" and "map_y".
        - The same maps are applied to every target (image, mask, bboxes, keypoints), which keeps
          the augmented sample coherent.
        - BaseDistortion is not listed in this module's ``__all__``, so the example below imports
          it from the module directly.

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>> import cv2
        >>> from albumentations.augmentations.geometric.distortion import BaseDistortion
        >>>
        >>> class CustomDistortion(BaseDistortion):
        ...     def __init__(self, distort_limit=0.3, *args, **kwargs):
        ...         super().__init__(*args, **kwargs)
        ...         self.distort_limit = distort_limit
        ...
        ...     def get_params_dependent_on_data(self, params, data):
        ...         height, width = params["shape"][:2]
        ...         xs, ys = np.meshgrid(np.arange(width), np.arange(height))
        ...
        ...         # Offsets from the image centre, as a fraction of the image size
        ...         dx = (xs - width / 2) / width
        ...         dy = (ys - height / 2) / height
        ...
        ...         # Simple radial distortion: the factor grows with the distance from the centre
        ...         factor = 1 + self.distort_limit * np.sqrt(dx * dx + dy * dy)
        ...         map_x = (xs + dx * factor).astype(np.float32)
        ...         map_y = (ys + dy * factor).astype(np.float32)
        ...         return {"map_x": map_x, "map_y": map_y}
        >>>
        >>> # Sample inputs
        >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
        >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
        >>> bboxes = np.array([[10, 10, 50, 50], [40, 40, 80, 80]], dtype=np.float32)
        >>> bbox_labels = [1, 2]
        >>> keypoints = np.array([[20, 30], [60, 70]], dtype=np.float32)
        >>> keypoint_labels = [0, 1]
        >>>
        >>> # Pipeline built around the custom distortion
        >>> transform = A.Compose([
        ...     CustomDistortion(
        ...         distort_limit=0.2,
        ...         interpolation=cv2.INTER_LINEAR,
        ...         mask_interpolation=cv2.INTER_NEAREST,
        ...         keypoint_remapping_method="mask",
        ...         p=1.0
        ...     )
        ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
        ...    keypoint_params=A.KeypointParams(format='xy', label_fields=['keypoint_labels']))
        >>>
        >>> # Run it on all targets at once
        >>> transformed = transform(
        ...     image=image,
        ...     mask=mask,
        ...     bboxes=bboxes,
        ...     bbox_labels=bbox_labels,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Read back the warped targets
        >>> transformed_image = transformed['image']
        >>> transformed_mask = transformed['mask']
        >>> transformed_bboxes = transformed['bboxes']
        >>> transformed_keypoints = transformed['keypoints']

    """

    _targets = ALL_TARGETS

    class InitSchema(BaseTransformInitSchema):
        # Validation schema: every constructor argument is restricted to the listed OpenCV flags
        # (or, for the fills, to a scalar / tuple of floats).
        interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ]
        mask_interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ]
        keypoint_remapping_method: Literal["direct", "mask"]
        border_mode: Literal[
            cv2.BORDER_CONSTANT,
            cv2.BORDER_REPLICATE,
            cv2.BORDER_REFLECT,
            cv2.BORDER_WRAP,
            cv2.BORDER_REFLECT_101,
        ]
        fill: tuple[float, ...] | float
        fill_mask: tuple[float, ...] | float

    def __init__(
        self,
        interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ],
        mask_interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ],
        keypoint_remapping_method: Literal["direct", "mask"],
        p: float,
        border_mode: Literal[
            cv2.BORDER_CONSTANT,
            cv2.BORDER_REPLICATE,
            cv2.BORDER_REFLECT,
            cv2.BORDER_WRAP,
            cv2.BORDER_REFLECT_101,
        ] = cv2.BORDER_CONSTANT,
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
    ):
        # Only the probability is handled by the parent; everything else is stored verbatim
        # and read back by the apply_* methods below.
        super().__init__(p=p)
        self.interpolation = interpolation
        self.mask_interpolation = mask_interpolation
        self.keypoint_remapping_method = keypoint_remapping_method
        self.border_mode = border_mode
        self.fill = fill
        self.fill_mask = fill_mask

    def apply(
        self,
        img: np.ndarray,
        map_x: np.ndarray,
        map_y: np.ndarray,
        **params: Any,
    ) -> np.ndarray:
        """Warp an image with the given coordinate maps.

        Args:
            img (np.ndarray): Image to warp.
            map_x (np.ndarray): Distortion map for the x coordinate.
            map_y (np.ndarray): Distortion map for the y coordinate.
            **params (Any): Remaining transform parameters (unused here).

        Returns:
            np.ndarray: The warped image.

        """
        # Image-side settings: image interpolation and image fill value.
        image_settings = (self.interpolation, self.border_mode, self.fill)
        return fgeometric.remap(img, map_x, map_y, *image_settings)

    @batch_transform("spatial", has_batch_dim=True, has_depth_dim=False)
    def apply_to_images(self, images: np.ndarray, **params: Any) -> np.ndarray:
        """Warp a batch of images by delegating to `apply`.

        Args:
            images (np.ndarray): Batch of images to warp.
            **params (Any): Transform parameters, including "map_x" and "map_y".

        Returns:
            np.ndarray: The batch of warped images.

        """
        return self.apply(images, **params)

    @batch_transform("spatial", has_batch_dim=False, has_depth_dim=True)
    def apply_to_volume(self, volume: np.ndarray, **params: Any) -> np.ndarray:
        """Warp a volume by delegating to `apply`.

        Args:
            volume (np.ndarray): Volume to warp.
            **params (Any): Transform parameters, including "map_x" and "map_y".

        Returns:
            np.ndarray: The warped volume.

        """
        return self.apply(volume, **params)

    @batch_transform("spatial", has_batch_dim=True, has_depth_dim=True)
    def apply_to_volumes(self, volumes: np.ndarray, **params: Any) -> np.ndarray:
        """Warp a batch of volumes by delegating to `apply`.

        Args:
            volumes (np.ndarray): Batch of volumes to warp.
            **params (Any): Transform parameters, including "map_x" and "map_y".

        Returns:
            np.ndarray: The batch of warped volumes.

        """
        return self.apply(volumes, **params)

    @batch_transform("spatial", has_batch_dim=True, has_depth_dim=False)
    def apply_to_mask3d(self, mask3d: np.ndarray, **params: Any) -> np.ndarray:
        """Warp a 3D mask by delegating to `apply_to_mask`.

        Args:
            mask3d (np.ndarray): 3D mask to warp.
            **params (Any): Transform parameters, including "map_x" and "map_y".

        Returns:
            np.ndarray: The warped 3D mask.

        """
        return self.apply_to_mask(mask3d, **params)

    def apply_to_mask(
        self,
        mask: np.ndarray,
        map_x: np.ndarray,
        map_y: np.ndarray,
        **params: Any,
    ) -> np.ndarray:
        """Warp a mask with the given coordinate maps.

        Args:
            mask (np.ndarray): Mask to warp.
            map_x (np.ndarray): Distortion map for the x coordinate.
            map_y (np.ndarray): Distortion map for the y coordinate.
            **params (Any): Remaining transform parameters (unused here).

        Returns:
            np.ndarray: The warped mask.

        """
        # Mask-side settings: mask interpolation and mask fill value; the border mode is shared.
        mask_settings = (self.mask_interpolation, self.border_mode, self.fill_mask)
        return fgeometric.remap(mask, map_x, map_y, *mask_settings)

    def apply_to_bboxes(
        self,
        bboxes: np.ndarray,
        map_x: np.ndarray,
        map_y: np.ndarray,
        **params: Any,
    ) -> np.ndarray:
        """Warp bounding boxes with the given coordinate maps.

        The boxes are denormalized, remapped, and normalized again with the same
        image shape.

        Args:
            bboxes (np.ndarray): Bounding boxes to warp.
            map_x (np.ndarray): Distortion map for the x coordinate.
            map_y (np.ndarray): Distortion map for the y coordinate.
            **params (Any): Transform parameters; "shape" is read from here.

        Returns:
            np.ndarray: The warped bounding boxes.

        """
        # Only the first two entries of the shape are used.
        spatial_shape = params["shape"][:2]
        pixel_boxes = denormalize_bboxes(bboxes, spatial_shape)
        warped_boxes = fgeometric.remap_bboxes(pixel_boxes, map_x, map_y, spatial_shape)
        return normalize_bboxes(warped_boxes, spatial_shape)

    def apply_to_keypoints(
        self,
        keypoints: np.ndarray,
        map_x: np.ndarray,
        map_y: np.ndarray,
        **params: Any,
    ) -> np.ndarray:
        """Warp keypoints with the given coordinate maps.

        Args:
            keypoints (np.ndarray): Keypoints to warp.
            map_x (np.ndarray): Distortion map for the x coordinate.
            map_y (np.ndarray): Distortion map for the y coordinate.
            **params (Any): Transform parameters; "shape" is read from here.

        Returns:
            np.ndarray: The warped keypoints.

        """
        # "direct" selects the inverse-mapping helper; any other value falls back to the
        # mask-based helper.
        remap_fn = (
            fgeometric.remap_keypoints
            if self.keypoint_remapping_method == "direct"
            else fgeometric.remap_keypoints_via_mask
        )
        return remap_fn(keypoints, map_x, map_y, params["shape"])
|
|
402
|
+
|
|
403
|
+
|
|
404
|
+
class ElasticTransform(BaseDistortion):
    """Elastically deform images, masks, bounding boxes, and keypoints.

    Random displacement fields are generated, smoothed with a Gaussian filter so the result looks
    natural, and then added to the regular pixel grid. The resulting maps are applied to every
    target. This is useful as augmentation for tasks such as segmentation or object detection,
    where the relative positions of features should survive while realistic deformations are
    introduced.

    Args:
        alpha (float): Scaling factor for the random displacement fields. Higher values give
            more pronounced distortions. Must be >= 0. Default: 1.0
        sigma (float): Standard deviation of the Gaussian filter that smooths the displacement
            fields. Higher values give smoother, more global distortions. Must be >= 1.
            Default: 50.0
        interpolation (int): OpenCV interpolation flag used for images. One of:
            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
            Default: cv2.INTER_LINEAR
        approximate (bool): If True, a fixed (17, 17) kernel is used for the Gaussian smoothing,
            which can be faster but potentially less accurate for large sigma values. If False,
            the kernel size (0, 0) is passed instead. Default: False
        same_dxdy (bool): Whether to use the same random displacement field for both the x and y
            directions. Can speed up the transform at the cost of less diverse distortions.
            Default: False
        mask_interpolation (int): OpenCV interpolation flag used for masks. One of:
            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
            Default: cv2.INTER_NEAREST.
        noise_distribution (Literal["gaussian", "uniform"]): Distribution the displacement fields
            are drawn from.
            "gaussian" uses a normal distribution (more natural deformations).
            "uniform" uses a uniform distribution (more mechanical deformations).
            Default: "gaussian".
        keypoint_remapping_method (Literal["direct", "mask"]): How keypoints are moved.
            - "mask": Mask-based remapping. Faster, especially for many keypoints, but may be
              less accurate for large distortions. Recommended for large images or many keypoints.
            - "direct": Inverse mapping. More accurate for large distortions but slower.
            Default: "mask"
        border_mode (int): OpenCV border flag passed on to BaseDistortion. One of:
            cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP,
            cv2.BORDER_REFLECT_101. Default: cv2.BORDER_CONSTANT.
        fill (tuple[float, ...] | float): Fill value for images, passed on to BaseDistortion.
            Default: 0.
        fill_mask (tuple[float, ...] | float): Fill value for masks, passed on to BaseDistortion.
            Default: 0.
        p (float): Probability of applying the transform. Default: 0.5

    Targets:
        image, mask, bboxes, keypoints, volume, mask3d

    Image types:
        uint8, float32

    Note:
        - All targets (image, mask, bboxes, keypoints) share the same displacement fields, so
          they stay consistent with each other.
        - 'approximate' chooses between a precise and an approximate way of generating the
          displacement fields. The approximate one can be faster but may be less accurate for
          large sigma values.
        - Bounding boxes that end up outside the image after transformation will be removed.
        - Keypoints that end up outside the image after transformation will be removed.

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
        >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
        >>> transform = A.Compose([
        ...     A.ElasticTransform(alpha=1, sigma=50, p=0.5),
        ... ])
        >>> transformed = transform(image=image, mask=mask)
        >>> transformed_image = transformed['image']
        >>> transformed_mask = transformed['mask']

    """

    class InitSchema(BaseDistortion.InitSchema):
        # Extends the base schema with the elastic-specific arguments and their bounds.
        alpha: Annotated[float, Field(ge=0)]
        sigma: Annotated[float, Field(ge=1)]
        approximate: bool
        same_dxdy: bool
        noise_distribution: Literal["gaussian", "uniform"]
        keypoint_remapping_method: Literal["direct", "mask"]

    def __init__(
        self,
        alpha: float = 1,
        sigma: float = 50,
        interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ] = cv2.INTER_LINEAR,
        approximate: bool = False,
        same_dxdy: bool = False,
        mask_interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ] = cv2.INTER_NEAREST,
        noise_distribution: Literal["gaussian", "uniform"] = "gaussian",
        keypoint_remapping_method: Literal["direct", "mask"] = "mask",
        border_mode: Literal[
            cv2.BORDER_CONSTANT,
            cv2.BORDER_REPLICATE,
            cv2.BORDER_REFLECT,
            cv2.BORDER_WRAP,
            cv2.BORDER_REFLECT_101,
        ] = cv2.BORDER_CONSTANT,
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
        p: float = 0.5,
    ):
        # Everything related to remapping is owned by the base class.
        super().__init__(
            interpolation=interpolation,
            mask_interpolation=mask_interpolation,
            keypoint_remapping_method=keypoint_remapping_method,
            p=p,
            border_mode=border_mode,
            fill=fill,
            fill_mask=fill_mask,
        )
        # Settings that only affect how the displacement fields are generated.
        self.alpha = alpha
        self.sigma = sigma
        self.approximate = approximate
        self.same_dxdy = same_dxdy
        self.noise_distribution = noise_distribution

    def get_params_dependent_on_data(
        self,
        params: dict[str, Any],
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Build the remapping grids for one call of the elastic transform.

        Args:
            params (dict[str, Any]): Transform parameters; "shape" is read from here.
            data (dict[str, Any]): Input data of the transform (unused here).

        Returns:
            dict[str, Any]: The float32 coordinate maps under the keys "map_x" and "map_y".

        """
        height, width = params["shape"][:2]

        # A fixed 17x17 kernel in approximate mode; otherwise (0, 0) is handed over
        # (presumably so the kernel size is derived from sigma - confirm in fgeometric).
        if self.approximate:
            kernel_size = (17, 17)
        else:
            kernel_size = (0, 0)

        shift_x, shift_y = fgeometric.generate_displacement_fields(
            (height, width),
            self.alpha,
            self.sigma,
            same_dxdy=self.same_dxdy,
            kernel_size=kernel_size,
            random_generator=self.random_generator,
            noise_distribution=self.noise_distribution,
        )

        # Identity pixel grid plus the per-pixel displacement gives the sampling coordinates.
        grid_x, grid_y = np.meshgrid(np.arange(width), np.arange(height))
        return {
            "map_x": (grid_x + shift_x).astype(np.float32),
            "map_y": (grid_y + shift_y).astype(np.float32),
        }
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
class PiecewiseAffine(BaseDistortion):
|
|
565
|
+
"""Apply piecewise affine transformations to the input image.
|
|
566
|
+
|
|
567
|
+
This augmentation places a regular grid of points on an image and randomly moves the neighborhood of these points
|
|
568
|
+
around via affine transformations. This leads to local distortions in the image.
|
|
569
|
+
|
|
570
|
+
Args:
|
|
571
|
+
scale (tuple[float, float] | float): Standard deviation of the normal distributions. These are used to sample
|
|
572
|
+
the random distances of the subimage's corners from the full image's corners.
|
|
573
|
+
If scale is a single float value, the range will be (0, scale).
|
|
574
|
+
Recommended values are in the range (0.01, 0.05) for small distortions,
|
|
575
|
+
and (0.05, 0.1) for larger distortions. Default: (0.03, 0.05).
|
|
576
|
+
nb_rows (tuple[int, int] | int): Number of rows of points that the regular grid should have.
|
|
577
|
+
Must be at least 2. For large images, you might want to pick a higher value than 4.
|
|
578
|
+
If a single int, then that value will always be used as the number of rows.
|
|
579
|
+
If a tuple (a, b), then a value from the discrete interval [a..b] will be uniformly sampled per image.
|
|
580
|
+
Default: 4.
|
|
581
|
+
nb_cols (tuple[int, int] | int): Number of columns of points that the regular grid should have.
|
|
582
|
+
Must be at least 2. For large images, you might want to pick a higher value than 4.
|
|
583
|
+
If a single int, then that value will always be used as the number of columns.
|
|
584
|
+
If a tuple (a, b), then a value from the discrete interval [a..b] will be uniformly sampled per image.
|
|
585
|
+
Default: 4.
|
|
586
|
+
interpolation (OpenCV flag): Flag that is used to specify the interpolation algorithm.
|
|
587
|
+
Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
|
|
588
|
+
Default: cv2.INTER_LINEAR.
|
|
589
|
+
mask_interpolation (OpenCV flag): Flag that is used to specify the interpolation algorithm for mask.
|
|
590
|
+
Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
|
|
591
|
+
Default: cv2.INTER_NEAREST.
|
|
592
|
+
absolute_scale (bool): If set to True, the value of the scale parameter will be treated as an absolute
|
|
593
|
+
pixel value. If set to False, it will be treated as a fraction of the image height and width.
|
|
594
|
+
Default: False.
|
|
595
|
+
keypoint_remapping_method (Literal["direct", "mask"]): Method to use for keypoint remapping.
|
|
596
|
+
- "mask": Uses mask-based remapping. Faster, especially for many keypoints, but may be
|
|
597
|
+
less accurate for large distortions. Recommended for large images or many keypoints.
|
|
598
|
+
- "direct": Uses inverse mapping. More accurate for large distortions but slower.
|
|
599
|
+
Default: "mask"
|
|
600
|
+
p (float): Probability of applying the transform. Default: 0.5.
|
|
601
|
+
|
|
602
|
+
Targets:
|
|
603
|
+
image, mask, keypoints, bboxes, volume, mask3d
|
|
604
|
+
|
|
605
|
+
Image types:
|
|
606
|
+
uint8, float32
|
|
607
|
+
|
|
608
|
+
Note:
|
|
609
|
+
- This augmentation is very slow. Consider using `ElasticTransform` instead, which is at least 10x faster.
|
|
610
|
+
- The augmentation may not always produce visible effects, especially with small scale values.
|
|
611
|
+
- For keypoints and bounding boxes, the transformation might move them outside the image boundaries.
|
|
612
|
+
In such cases, the keypoints will be set to (-1, -1) and the bounding boxes will be removed.
|
|
613
|
+
|
|
614
|
+
Examples:
|
|
615
|
+
>>> import numpy as np
|
|
616
|
+
>>> import albumentations as A
|
|
617
|
+
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
|
|
618
|
+
>>> transform = A.Compose([
|
|
619
|
+
... A.PiecewiseAffine(scale=(0.03, 0.05), nb_rows=4, nb_cols=4, p=0.5),
|
|
620
|
+
... ])
|
|
621
|
+
>>> transformed = transform(image=image)
|
|
622
|
+
>>> transformed_image = transformed["image"]
|
|
623
|
+
|
|
624
|
+
"""
|
|
625
|
+
|
|
626
|
+
class InitSchema(BaseDistortion.InitSchema):
    # Constructor-argument schema for PiecewiseAffine; fields that are not
    # declared here come from BaseDistortion.InitSchema.
    scale: NonNegativeFloatRangeType
    nb_rows: tuple[int, int] | int
    nb_cols: tuple[int, int] | int
    absolute_scale: bool

    @field_validator("nb_rows", "nb_cols")
    @classmethod
    def _process_range(
        cls,
        value: tuple[int, int] | int,
        info: ValidationInfo,
    ) -> tuple[int, int]:
        """Normalise a grid-size argument and range-check it.

        Args:
            value (tuple[int, int] | int): Raw ``nb_rows`` / ``nb_cols`` argument.
            info (ValidationInfo): Validation context; only ``field_name`` is used.

        Returns:
            tuple[int, int]: The value produced by ``to_tuple(value, value)``.

        """
        # NOTE(review): to_tuple(value, value) presumably turns an int n into (n, n)
        # and leaves a pair unchanged - confirm against to_tuple.
        grid_range = to_tuple(value, value)
        # Lower bound 2, upper bound BIG_INTEGER; the field name is handed to
        # check_range alongside the bounds.
        check_range(grid_range, 2, BIG_INTEGER, info.field_name)
        return grid_range
|
|
643
|
+
|
|
644
|
+
def __init__(
    self,
    scale: tuple[float, float] | float = (0.03, 0.05),
    nb_rows: tuple[int, int] | int = (4, 4),
    nb_cols: tuple[int, int] | int = (4, 4),
    interpolation: Literal[
        cv2.INTER_NEAREST,
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_AREA,
        cv2.INTER_LANCZOS4,
    ] = cv2.INTER_LINEAR,
    mask_interpolation: Literal[
        cv2.INTER_NEAREST,
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_AREA,
        cv2.INTER_LANCZOS4,
    ] = cv2.INTER_NEAREST,
    absolute_scale: bool = False,
    keypoint_remapping_method: Literal["direct", "mask"] = "mask",
    p: float = 0.5,
    border_mode: Literal[
        cv2.BORDER_CONSTANT,
        cv2.BORDER_REPLICATE,
        cv2.BORDER_REFLECT,
        cv2.BORDER_WRAP,
        cv2.BORDER_REFLECT_101,
    ] = cv2.BORDER_CONSTANT,
    fill: tuple[float, ...] | float = 0,
    fill_mask: tuple[float, ...] | float = 0,
):
    # Interpolation, border handling, fill values, keypoint remapping and the
    # application probability are all delegated to the parent initialiser.
    super().__init__(
        interpolation=interpolation,
        mask_interpolation=mask_interpolation,
        border_mode=border_mode,
        fill=fill,
        fill_mask=fill_mask,
        keypoint_remapping_method=keypoint_remapping_method,
        p=p,
    )

    # Emitted on every construction; stacklevel=2 attributes the warning to the
    # code that instantiated the transform rather than to this method.
    slowness_notice = (
        "This augmenter is very slow. Try to use ``ElasticTransform`` instead, which is at least 10x faster."
    )
    warn(slowness_notice, stacklevel=2)

    # cast() only informs the type checker; the values are stored unchanged.
    # NOTE(review): presumably InitSchema validation has already turned these
    # into 2-tuples before this point - confirm against the base class.
    self.absolute_scale = absolute_scale
    self.scale = cast("tuple[float, float]", scale)
    self.nb_rows = cast("tuple[int, int]", nb_rows)
    self.nb_cols = cast("tuple[int, int]", nb_cols)
|
|
695
|
+
|
|
696
|
+
def get_params_dependent_on_data(
    self,
    params: dict[str, Any],
    data: dict[str, Any],
) -> dict[str, Any]:
    """Sample a grid size and scale, then build the piecewise-affine remap maps.

    Args:
        params (dict[str, Any]): Transform parameters; ``params["shape"]`` is read
            and its first two entries are used as the image shape.
        data (dict[str, Any]): Input data (not used by this method).

    Returns:
        dict[str, Any]: Dictionary with the keys ``"map_x"`` and ``"map_y"`` as
            returned by ``fgeometric.create_piecewise_affine_maps``.

    """
    target_shape = params["shape"][:2]

    # The draws happen in a fixed order (rows, columns, scale) so that results
    # stay reproducible for a given random state.
    sampled_rows = self.py_random.randint(*self.nb_rows)
    sampled_cols = self.py_random.randint(*self.nb_cols)
    sampled_scale = self.py_random.uniform(*self.scale)

    # Both grid dimensions are forced to be at least 2.
    grid_size = (
        np.clip(sampled_rows, 2, None),
        np.clip(sampled_cols, 2, None),
    )

    maps = fgeometric.create_piecewise_affine_maps(
        image_shape=target_shape,
        grid=grid_size,
        scale=sampled_scale,
        absolute_scale=self.absolute_scale,
        random_generator=self.random_generator,
    )
    map_x, map_y = maps

    return {"map_x": map_x, "map_y": map_y}
|
|
726
|
+
|
|
727
|
+
|
|
728
|
+
class OpticalDistortion(BaseDistortion):
    """Distort images, masks, bounding boxes, and keypoints the way a camera lens would.

    Two distortion models are available:
        1. Camera matrix model (the original one):
           Based on OpenCV's camera calibration model, with distortion coefficients k1=k2=k

        2. Fisheye model:
           Direct radial distortion: r_dist = r * (1 + gamma * r²)

    Args:
        distort_limit (float | tuple[float, float]): Range from which the distortion
            coefficient is drawn.
            Recommended range for the camera model: (-0.05, 0.05)
            Recommended range for the fisheye model: (-0.3, 0.3)
            Default: (-0.05, 0.05)

        mode (Literal['camera', 'fisheye']): Which distortion model to apply:
            - 'camera': Original camera matrix model
            - 'fisheye': Fisheye lens model
            Default: 'camera'

        interpolation (OpenCV flag): Interpolation algorithm for the image.
            One of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC,
            cv2.INTER_AREA, cv2.INTER_LANCZOS4. Default: cv2.INTER_LINEAR.

        mask_interpolation (OpenCV flag): Interpolation algorithm for the mask.
            One of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC,
            cv2.INTER_AREA, cv2.INTER_LANCZOS4. Default: cv2.INTER_NEAREST.

        keypoint_remapping_method (Literal["direct", "mask"]): How keypoints are remapped.
            - "mask": Mask-based remapping. Faster, especially with many keypoints, but
              possibly less accurate for large distortions. Recommended for large images
              or many keypoints.
            - "direct": Inverse mapping. More accurate for large distortions but slower.
            Default: "mask"

        p (float): Probability of applying the transform. Default: 0.5.

        border_mode (OpenCV flag): One of cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE,
            cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101. Passed through
            to BaseDistortion. Default: cv2.BORDER_CONSTANT.

        fill (tuple[float, ...] | float): Passed through to BaseDistortion. Default: 0.

        fill_mask (tuple[float, ...] | float): Passed through to BaseDistortion. Default: 0.

    Targets:
        image, mask, bboxes, keypoints, volume, mask3d

    Image types:
        uint8, float32

    Note:
        - OpenCV's initUndistortRectifyMap and remap functions are used to apply the distortion.
        - The distortion coefficient (k) is sampled at random from the distort_limit range.
        - Bounding boxes and keypoints are transformed together with the image so that
          all targets stay consistent.
        - The fisheye model applies the radial distortion directly.

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>>
        >>> image = np.zeros((100, 100, 3), dtype=np.uint8)
        >>> mask = np.zeros((100, 100), dtype=np.uint8)
        >>> bboxes = np.array([[10, 10, 40, 40]])
        >>> bbox_labels = [1]
        >>> keypoints = np.array([[50, 50]])
        >>> keypoint_labels = [0]
        >>>
        >>> transform = A.Compose([
        ...     A.OpticalDistortion(distort_limit=0.1, p=1.0),
        ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
        ...    keypoint_params=A.KeypointParams(format='xy', label_fields=['keypoint_labels']))
        >>>
        >>> transformed = transform(
        ...     image=image,
        ...     mask=mask,
        ...     bboxes=bboxes,
        ...     bbox_labels=bbox_labels,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels,
        ... )
        >>> transformed_image = transformed['image']
        >>> transformed_mask = transformed['mask']
        >>> transformed_bboxes = transformed['bboxes']
        >>> transformed_keypoints = transformed['keypoints']

    """

    class InitSchema(BaseDistortion.InitSchema):
        # Fields specific to this transform; the rest come from BaseDistortion.InitSchema.
        distort_limit: SymmetricRangeType
        mode: Literal["camera", "fisheye"]
        keypoint_remapping_method: Literal["direct", "mask"]

    def __init__(
        self,
        distort_limit: tuple[float, float] | float = (-0.05, 0.05),
        interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ] = cv2.INTER_LINEAR,
        mask_interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ] = cv2.INTER_NEAREST,
        mode: Literal["camera", "fisheye"] = "camera",
        keypoint_remapping_method: Literal["direct", "mask"] = "mask",
        p: float = 0.5,
        border_mode: Literal[
            cv2.BORDER_CONSTANT,
            cv2.BORDER_REPLICATE,
            cv2.BORDER_REFLECT,
            cv2.BORDER_WRAP,
            cv2.BORDER_REFLECT_101,
        ] = cv2.BORDER_CONSTANT,
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
    ):
        # Generic remapping options are owned by the parent class.
        super().__init__(
            p=p,
            interpolation=interpolation,
            mask_interpolation=mask_interpolation,
            border_mode=border_mode,
            fill=fill,
            fill_mask=fill_mask,
            keypoint_remapping_method=keypoint_remapping_method,
        )
        self.mode = mode
        # cast() is a typing aid only; the value is stored as received.
        self.distort_limit = cast("tuple[float, float]", distort_limit)

    def get_params_dependent_on_data(
        self,
        params: dict[str, Any],
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Sample a distortion coefficient and build the matching remap maps.

        Args:
            params (dict[str, Any]): Transform parameters; ``params["shape"]`` is read
                and its first two entries are used as the image shape.
            data (dict[str, Any]): Input data (not used by this method).

        Returns:
            dict[str, Any]: Dictionary with the keys ``"map_x"`` and ``"map_y"``.

        """
        target_shape = params["shape"][:2]

        # One coefficient per call, drawn from the configured range.
        coefficient = self.py_random.uniform(*self.distort_limit)

        # Any mode other than "camera" falls through to the fisheye builder.
        build_maps = (
            fgeometric.get_camera_matrix_distortion_maps
            if self.mode == "camera"
            else fgeometric.get_fisheye_distortion_maps
        )
        map_x, map_y = build_maps(target_shape, coefficient)

        return {"map_x": map_x, "map_y": map_y}
|
|
870
|
+
|
|
871
|
+
|
|
872
|
+
class GridDistortion(BaseDistortion):
    """Warp images, masks, bounding boxes, and keypoints with a randomly distorted grid.

    The image is split into a grid and every cell is distorted at random, which
    produces localized warping. This is particularly useful as data augmentation
    for medical image analysis, OCR, and other domains in which local geometric
    variations are meaningful.

    Args:
        num_steps (int): Number of grid cells along each side of the image. Larger
            values give more granular distortions. Must be at least 1. Default: 5.
        distort_limit (float or tuple[float, float]): Distortion range. A single float
            is interpreted as the range (-distort_limit, distort_limit). Larger values
            give stronger distortions. Should lie between -1 and 1.
            Default: (-0.3, 0.3).
        interpolation (int): OpenCV interpolation method for the image.
            Options include cv2.INTER_LINEAR, cv2.INTER_CUBIC, etc. Default: cv2.INTER_LINEAR.
        normalized (bool): If True, the distortion is prevented from moving pixels
            outside the image boundaries. Distortions may be less extreme, but no
            information is lost. Default: True.
        mask_interpolation (OpenCV flag): Interpolation algorithm for the mask.
            One of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC,
            cv2.INTER_AREA, cv2.INTER_LANCZOS4. Default: cv2.INTER_NEAREST.
        keypoint_remapping_method (Literal["direct", "mask"]): How keypoints are remapped.
            - "mask": Mask-based remapping. Faster, especially with many keypoints, but
              possibly less accurate for large distortions. Recommended for large images
              or many keypoints.
            - "direct": Inverse mapping. More accurate for large distortions but slower.
            Default: "mask"
        p (float): Probability of applying the transform. Default: 0.5.
        border_mode (OpenCV flag): One of cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE,
            cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101. Passed through
            to BaseDistortion. Default: cv2.BORDER_CONSTANT.
        fill (tuple[float, ...] | float): Passed through to BaseDistortion. Default: 0.
        fill_mask (tuple[float, ...] | float): Passed through to BaseDistortion. Default: 0.

    Targets:
        image, mask, bboxes, keypoints, volume, mask3d

    Image types:
        uint8, float32

    Note:
        - All targets (image, mask, bboxes, keypoints) receive the same distortion,
          which keeps them consistent with each other.
        - With normalized=True the distortion is adjusted so that every pixel stays
          inside the image boundaries.

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>>
        >>> image = np.zeros((100, 100, 3), dtype=np.uint8)
        >>> mask = np.zeros((100, 100), dtype=np.uint8)
        >>> bboxes = np.array([[10, 10, 40, 40]])
        >>> bbox_labels = [1]
        >>> keypoints = np.array([[50, 50]])
        >>> keypoint_labels = [0]
        >>>
        >>> transform = A.Compose([
        ...     A.GridDistortion(num_steps=5, distort_limit=0.3, p=1.0),
        ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
        ...    keypoint_params=A.KeypointParams(format='xy', label_fields=['keypoint_labels']))
        >>>
        >>> transformed = transform(
        ...     image=image,
        ...     mask=mask,
        ...     bboxes=bboxes,
        ...     bbox_labels=bbox_labels,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels,
        ... )
        >>> transformed_image = transformed['image']
        >>> transformed_mask = transformed['mask']
        >>> transformed_bboxes = transformed['bboxes']
        >>> transformed_keypoints = transformed['keypoints']

    """

    class InitSchema(BaseDistortion.InitSchema):
        # Fields specific to this transform; the rest come from BaseDistortion.InitSchema.
        num_steps: Annotated[int, Field(ge=1)]
        distort_limit: SymmetricRangeType
        normalized: bool
        keypoint_remapping_method: Literal["direct", "mask"]

        @field_validator("distort_limit")
        @classmethod
        def _check_limits(
            cls,
            v: tuple[float, float],
            info: ValidationInfo,
        ) -> tuple[float, float]:
            """Range-check ``distort_limit`` against the bounds -1 and 1.

            Args:
                v (tuple[float, float]): Value of the field being validated.
                info (ValidationInfo): Validation context; only ``field_name`` is used.

            Returns:
                tuple[float, float]: The value produced by ``to_tuple(v)``.

            """
            limits = to_tuple(v)
            check_range(limits, -1, 1, info.field_name)
            return limits

    def __init__(
        self,
        num_steps: int = 5,
        distort_limit: tuple[float, float] | float = (-0.3, 0.3),
        interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ] = cv2.INTER_LINEAR,
        normalized: bool = True,
        mask_interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ] = cv2.INTER_NEAREST,
        keypoint_remapping_method: Literal["direct", "mask"] = "mask",
        p: float = 0.5,
        border_mode: Literal[
            cv2.BORDER_CONSTANT,
            cv2.BORDER_REPLICATE,
            cv2.BORDER_REFLECT,
            cv2.BORDER_WRAP,
            cv2.BORDER_REFLECT_101,
        ] = cv2.BORDER_CONSTANT,
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
    ):
        # Generic remapping options are owned by the parent class.
        super().__init__(
            p=p,
            interpolation=interpolation,
            mask_interpolation=mask_interpolation,
            border_mode=border_mode,
            fill=fill,
            fill_mask=fill_mask,
            keypoint_remapping_method=keypoint_remapping_method,
        )
        self.normalized = normalized
        self.num_steps = num_steps
        # cast() is a typing aid only; the value is stored as received.
        self.distort_limit = cast("tuple[float, float]", distort_limit)

    def get_params_dependent_on_data(
        self,
        params: dict[str, Any],
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Sample per-step distortion factors and build the grid remap maps.

        Args:
            params (dict[str, Any]): Transform parameters; ``params["shape"]`` is read
                and its first two entries are used as the image shape.
            data (dict[str, Any]): Input data (not used by this method).

        Returns:
            dict[str, Any]: Dictionary with the keys ``"map_x"`` and ``"map_y"``.

        """
        target_shape = params["shape"][:2]
        factor_count = self.num_steps + 1

        # All x factors are drawn before any y factor; keeping this order keeps the
        # output reproducible for a given random state.
        steps_x = []
        for _ in range(factor_count):
            steps_x.append(1 + self.py_random.uniform(*self.distort_limit))
        steps_y = []
        for _ in range(factor_count):
            steps_y.append(1 + self.py_random.uniform(*self.distort_limit))

        if self.normalized:
            # Replace the raw factors with the normalized ones.
            adjusted = fgeometric.normalize_grid_distortion_steps(
                target_shape,
                self.num_steps,
                steps_x,
                steps_y,
            )
            steps_x = adjusted["steps_x"]
            steps_y = adjusted["steps_y"]

        map_x, map_y = fgeometric.generate_grid(
            target_shape,
            steps_x,
            steps_y,
            self.num_steps,
        )

        return {"map_x": map_x, "map_y": map_y}
|
|
1028
|
+
|
|
1029
|
+
|
|
1030
|
+
class ThinPlateSpline(BaseDistortion):
    r"""Produce smooth, non-rigid deformations with a Thin Plate Spline (TPS) transformation.

    Picture the image printed on a thin metal plate that can be bent and warped smoothly:
    - Control points behave like pins that push or pull the plate
    - The plate resists sharp bending, so the deformation stays smooth
    - Continuity is maintained (no tears or folds)
    - Regions between control points are interpolated naturally

    The transform proceeds in four steps:
    1. A regular grid of control points is created (the pins in the plate)
    2. These points are displaced at random (pushing/pulling the pins)
    3. A smooth interpolation is computed (the plate bending)
    4. The resulting deformation is applied to the image


    Args:
        scale_range (tuple[float, float]): Range for the random displacement of control points.
            Values should be in [0.0, 1.0]:
            - 0.0: No displacement (identity transform)
            - 0.1: Subtle warping
            - 0.2-0.4: Moderate deformation (recommended range)
            - 0.5+: Strong warping
            Default: (0.2, 0.4)

        num_control_points (int): Number of control points per side.
            A grid of num_control_points x num_control_points points is created.
            - 2: Minimal deformation (affine-like)
            - 3-4: Moderate flexibility (recommended)
            - 5+: More local deformation control
            Must be >= 2. Default: 4

        interpolation (int): OpenCV interpolation flag used when sampling the image.
            See also: cv2.INTER_*
            Default: cv2.INTER_LINEAR

        mask_interpolation (int): OpenCV interpolation flag used when sampling the mask.
            See also: cv2.INTER_*
            Default: cv2.INTER_NEAREST

        keypoint_remapping_method (Literal["direct", "mask"]): How keypoints are remapped.
            - "mask": Mask-based remapping. Faster, especially with many keypoints, but
              possibly less accurate for large distortions. Recommended for large images
              or many keypoints.
            - "direct": Inverse mapping. More accurate for large distortions but slower.
            Default: "mask"

        p (float): Probability of applying the transform. Default: 0.5

        border_mode (OpenCV flag): One of cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE,
            cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101. Passed through
            to BaseDistortion. Default: cv2.BORDER_CONSTANT

        fill (tuple[float, ...] | float): Passed through to BaseDistortion. Default: 0

        fill_mask (tuple[float, ...] | float): Passed through to BaseDistortion. Default: 0

    Targets:
        image, mask, keypoints, bboxes, volume, mask3d

    Image types:
        uint8, float32

    Note:
        - Smoothness and continuity are preserved by the transformation
        - Stronger scale values may lead to more extreme deformations
        - A higher number of control points allows more local deformations
        - All targets receive the same deformation

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>> import cv2
        >>>
        >>> # Create sample data
        >>> image = np.zeros((100, 100, 3), dtype=np.uint8)
        >>> mask = np.zeros((100, 100), dtype=np.uint8)
        >>> mask[25:75, 25:75] = 1  # Square mask
        >>> bboxes = np.array([[10, 10, 40, 40]])  # Single box
        >>> bbox_labels = [1]
        >>> keypoints = np.array([[50, 50]])  # Single keypoint at center
        >>> keypoint_labels = [0]
        >>>
        >>> # Set up transform with Compose to handle all targets
        >>> transform = A.Compose([
        ...     A.ThinPlateSpline(scale_range=(0.2, 0.4), p=1.0)
        ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
        ...    keypoint_params=A.KeypointParams(format='xy', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply to all targets
        >>> result = transform(
        ...     image=image,
        ...     mask=mask,
        ...     bboxes=bboxes,
        ...     bbox_labels=bbox_labels,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Access transformed results
        >>> transformed_image = result['image']
        >>> transformed_mask = result['mask']
        >>> transformed_bboxes = result['bboxes']
        >>> transformed_bbox_labels = result['bbox_labels']
        >>> transformed_keypoints = result['keypoints']
        >>> transformed_keypoint_labels = result['keypoint_labels']

    References:
        - "Principal Warps: Thin-Plate Splines and the Decomposition of Deformations"
          by F.L. Bookstein
          https://doi.org/10.1109/34.24792

        - Thin Plate Splines in Computer Vision:
          https://en.wikipedia.org/wiki/Thin_plate_spline

        - Similar implementation in Kornia:
          https://kornia.readthedocs.io/en/latest/augmentation.html#kornia.augmentation.RandomThinPlateSpline

    See Also:
        - ElasticTransform: For different type of non-rigid deformation
        - GridDistortion: For grid-based warping
        - OpticalDistortion: For lens-like distortions

    """

    class InitSchema(BaseDistortion.InitSchema):
        # Fields specific to this transform; the rest come from BaseDistortion.InitSchema.
        scale_range: Annotated[tuple[float, float], AfterValidator(check_range_bounds(0, 1))]
        num_control_points: int = Field(ge=2)
        keypoint_remapping_method: Literal["direct", "mask"]

    def __init__(
        self,
        scale_range: tuple[float, float] = (0.2, 0.4),
        num_control_points: int = 4,
        interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ] = cv2.INTER_LINEAR,
        mask_interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ] = cv2.INTER_NEAREST,
        keypoint_remapping_method: Literal["direct", "mask"] = "mask",
        p: float = 0.5,
        border_mode: Literal[
            cv2.BORDER_CONSTANT,
            cv2.BORDER_REPLICATE,
            cv2.BORDER_REFLECT,
            cv2.BORDER_WRAP,
            cv2.BORDER_REFLECT_101,
        ] = cv2.BORDER_CONSTANT,
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
    ):
        # Generic remapping options are owned by the parent class.
        super().__init__(
            p=p,
            interpolation=interpolation,
            mask_interpolation=mask_interpolation,
            border_mode=border_mode,
            fill=fill,
            fill_mask=fill_mask,
            keypoint_remapping_method=keypoint_remapping_method,
        )
        self.num_control_points = num_control_points
        self.scale_range = scale_range

    def get_params_dependent_on_data(
        self,
        params: dict[str, Any],
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Displace the control points at random and derive the TPS remap maps.

        Args:
            params (dict[str, Any]): Transform parameters; ``params["shape"]`` is read
                and its first two entries are used as height and width.
            data (dict[str, Any]): Input data (not used by this method).

        Returns:
            dict[str, Any]: Dictionary with the keys ``"map_x"`` and ``"map_y"``, each a
                float32 array reshaped to ``(height, width)``.

        """
        height, width = params["shape"][:2]
        control_src = fgeometric.generate_control_points(self.num_control_points)

        # The sampled scale is divided by 10 and used as the standard deviation of
        # the zero-mean normal noise added to every control point.
        sigma = self.py_random.uniform(*self.scale_range) / 10
        displacement = self.random_generator.normal(0, sigma, control_src.shape)
        control_dst = control_src + displacement

        # Fit the spline that maps the source points to the displaced points.
        weights, affine = fgeometric.compute_tps_weights(control_src, control_dst)

        # One (x, y) row per pixel, in row-major order.
        grid_x, grid_y = np.meshgrid(np.arange(width), np.arange(height))
        pixel_coords = np.stack([grid_x.flatten(), grid_y.flatten()], axis=1).astype(np.float32)

        # Pixel coordinates are divided by (width, height) before the spline is
        # evaluated, and the result is scaled back to pixel units in place.
        transformed = fgeometric.tps_transform(
            pixel_coords / [width, height],
            control_src,
            weights,
            affine,
        )
        transformed *= [width, height]

        map_x = transformed[:, 0].reshape(height, width).astype(np.float32)
        map_y = transformed[:, 1].reshape(height, width).astype(np.float32)
        return {"map_x": map_x, "map_y": map_y}
|