nrtk-albumentations 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nrtk-albumentations might be problematic. Click here for more details.

Files changed (62)
  1. albumentations/__init__.py +21 -0
  2. albumentations/augmentations/__init__.py +23 -0
  3. albumentations/augmentations/blur/__init__.py +0 -0
  4. albumentations/augmentations/blur/functional.py +438 -0
  5. albumentations/augmentations/blur/transforms.py +1633 -0
  6. albumentations/augmentations/crops/__init__.py +0 -0
  7. albumentations/augmentations/crops/functional.py +494 -0
  8. albumentations/augmentations/crops/transforms.py +3647 -0
  9. albumentations/augmentations/dropout/__init__.py +0 -0
  10. albumentations/augmentations/dropout/channel_dropout.py +134 -0
  11. albumentations/augmentations/dropout/coarse_dropout.py +567 -0
  12. albumentations/augmentations/dropout/functional.py +1017 -0
  13. albumentations/augmentations/dropout/grid_dropout.py +166 -0
  14. albumentations/augmentations/dropout/mask_dropout.py +274 -0
  15. albumentations/augmentations/dropout/transforms.py +461 -0
  16. albumentations/augmentations/dropout/xy_masking.py +186 -0
  17. albumentations/augmentations/geometric/__init__.py +0 -0
  18. albumentations/augmentations/geometric/distortion.py +1238 -0
  19. albumentations/augmentations/geometric/flip.py +752 -0
  20. albumentations/augmentations/geometric/functional.py +4151 -0
  21. albumentations/augmentations/geometric/pad.py +676 -0
  22. albumentations/augmentations/geometric/resize.py +956 -0
  23. albumentations/augmentations/geometric/rotate.py +864 -0
  24. albumentations/augmentations/geometric/transforms.py +1962 -0
  25. albumentations/augmentations/mixing/__init__.py +0 -0
  26. albumentations/augmentations/mixing/domain_adaptation.py +787 -0
  27. albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
  28. albumentations/augmentations/mixing/functional.py +878 -0
  29. albumentations/augmentations/mixing/transforms.py +832 -0
  30. albumentations/augmentations/other/__init__.py +0 -0
  31. albumentations/augmentations/other/lambda_transform.py +180 -0
  32. albumentations/augmentations/other/type_transform.py +261 -0
  33. albumentations/augmentations/pixel/__init__.py +0 -0
  34. albumentations/augmentations/pixel/functional.py +4226 -0
  35. albumentations/augmentations/pixel/transforms.py +7556 -0
  36. albumentations/augmentations/spectrogram/__init__.py +0 -0
  37. albumentations/augmentations/spectrogram/transform.py +220 -0
  38. albumentations/augmentations/text/__init__.py +0 -0
  39. albumentations/augmentations/text/functional.py +272 -0
  40. albumentations/augmentations/text/transforms.py +299 -0
  41. albumentations/augmentations/transforms3d/__init__.py +0 -0
  42. albumentations/augmentations/transforms3d/functional.py +393 -0
  43. albumentations/augmentations/transforms3d/transforms.py +1422 -0
  44. albumentations/augmentations/utils.py +249 -0
  45. albumentations/core/__init__.py +0 -0
  46. albumentations/core/bbox_utils.py +920 -0
  47. albumentations/core/composition.py +1885 -0
  48. albumentations/core/hub_mixin.py +299 -0
  49. albumentations/core/keypoints_utils.py +521 -0
  50. albumentations/core/label_manager.py +339 -0
  51. albumentations/core/pydantic.py +239 -0
  52. albumentations/core/serialization.py +352 -0
  53. albumentations/core/transforms_interface.py +976 -0
  54. albumentations/core/type_definitions.py +127 -0
  55. albumentations/core/utils.py +605 -0
  56. albumentations/core/validation.py +129 -0
  57. albumentations/pytorch/__init__.py +1 -0
  58. albumentations/pytorch/transforms.py +189 -0
  59. nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
  60. nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
  61. nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
  62. nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1962 @@
1
"""Geometric transformation classes for image augmentation.

This module provides transforms that change the geometry of images and of the
targets that accompany them (masks, bounding boxes, keypoints, volumes).
Its public API (``__all__``) consists of:

- ``Affine``: translation, rotation, scaling and shear.
- ``ShiftScaleRotate``
- ``Perspective``: random four-point perspective warp.
- ``GridElasticDeform``
- ``RandomGridShuffle``: grid shuffling.
- ``Morphological``

Flipping, padding and other distortion-style transforms are not part of this
module's public API; see the sibling ``flip``, ``pad`` and ``distortion``
modules instead.
"""
8
+
9
+ from __future__ import annotations
10
+
11
+ import random
12
+ from typing import Annotated, Any, Literal, cast
13
+ from warnings import warn
14
+
15
+ import cv2
16
+ import numpy as np
17
+ from albucore import batch_transform, is_grayscale_image, is_rgb_image
18
+ from pydantic import (
19
+ AfterValidator,
20
+ Field,
21
+ ValidationInfo,
22
+ field_validator,
23
+ model_validator,
24
+ )
25
+ from typing_extensions import Self
26
+
27
+ from albumentations.augmentations.utils import check_range
28
+ from albumentations.core.bbox_utils import (
29
+ BboxProcessor,
30
+ denormalize_bboxes,
31
+ normalize_bboxes,
32
+ )
33
+ from albumentations.core.pydantic import (
34
+ NonNegativeFloatRangeType,
35
+ OnePlusIntRangeType,
36
+ SymmetricRangeType,
37
+ check_range_bounds,
38
+ )
39
+ from albumentations.core.transforms_interface import (
40
+ BaseTransformInitSchema,
41
+ DualTransform,
42
+ )
43
+ from albumentations.core.type_definitions import ALL_TARGETS
44
+ from albumentations.core.utils import to_tuple
45
+
46
+ from . import functional as fgeometric
47
+
48
# Public names re-exported by ``from ... import *``.
__all__ = [
    "Affine",
    "GridElasticDeform",
    "Morphological",
    "Perspective",
    "RandomGridShuffle",
    "ShiftScaleRotate",
]

# Number of entries in a padding spec. Presumably 2 for an (x, y) pad and 4 for
# a per-side pad. NOTE(review): confirm against their uses later in the module.
NUM_PADS_XY, NUM_PADS_ALL_SIDES = 2, 4
59
+
60
+
61
class Perspective(DualTransform):
    """Warp the input by randomly displacing its four corners (perspective transform).

    On every call, a standard deviation is drawn uniformly from ``scale``. The corner
    offsets are then generated with that spread, ordered, and converted into a
    perspective matrix. The same matrix is applied to images, masks, bounding boxes
    and keypoints.

    Args:
        scale (float or tuple of float): Range from which the standard deviation of the
            normal distributions is drawn. Those distributions give the random distances
            of the sub-image corners from the full image corners. A single float ``s`` is
            interpreted as the range ``(0, s)``. Default: (0.05, 0.1).
        keep_size (bool): If True, resize the warped result back to its original size.
            If False, outputs may have a different shape than the inputs. Default: True.
        border_mode (OpenCV flag): OpenCV border mode used for padding.
            Default: cv2.BORDER_CONSTANT.
        fill (tuple[float, ...] | float): Image padding value, used when border_mode is
            cv2.BORDER_CONSTANT. Default: 0.
        fill_mask (tuple[float, ...] | float): Mask padding value, used when border_mode is
            cv2.BORDER_CONSTANT. Default: 0.
        fit_output (bool): If True, adjust the size and position of the output plane so the
            whole warped image stays visible. A resize follows when ``keep_size`` is True.
            If False, parts of the warped image may fall outside the output plane.
            Avoid combining True with large ``scale`` values, which can produce very large
            images. Default: False.
        interpolation (int): OpenCV interpolation flag for images; one of the OpenCV
            interpolation types. Default: cv2.INTER_LINEAR
        mask_interpolation (int): OpenCV interpolation flag for masks; one of
            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA,
            cv2.INTER_LANCZOS4. Default: cv2.INTER_NEAREST.
        p (float): Probability of applying the transform. Default: 0.5.

    Targets:
        image, mask, keypoints, bboxes, volume, mask3d

    Image types:
        uint8, float32

    Note:
        The effect comes from randomly moving the four image corners; ``scale`` controls
        how far they move.

        With ``keep_size=True`` the output keeps the input size, so some parts of the
        warped image may be cut off or padded.

        With ``fit_output=True`` the whole warped image is kept visible, which can make
        the output larger when ``keep_size`` is False.

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>> import cv2
        >>>
        >>> # Prepare sample data
        >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
        >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
        >>> bboxes = np.array([[10, 10, 50, 50], [40, 40, 80, 80]], dtype=np.float32)
        >>> bbox_labels = [1, 2]
        >>> keypoints = np.array([[20, 30], [60, 70]], dtype=np.float32)
        >>> keypoint_labels = [0, 1]
        >>>
        >>> # Define transform with parameters as tuples when possible
        >>> transform = A.Compose([
        ...     A.Perspective(
        ...         scale=(0.05, 0.1),
        ...         keep_size=True,
        ...         fit_output=False,
        ...         border_mode=cv2.BORDER_CONSTANT,
        ...         p=1.0
        ...     ),
        ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
        ...    keypoint_params=A.KeypointParams(format='xy', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply the transform
        >>> transformed = transform(
        ...     image=image,
        ...     mask=mask,
        ...     bboxes=bboxes,
        ...     bbox_labels=bbox_labels,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Get the transformed data
        >>> transformed_image = transformed['image']      # Perspective-transformed image
        >>> transformed_mask = transformed['mask']        # Perspective-transformed mask
        >>> transformed_bboxes = transformed['bboxes']    # Perspective-transformed bounding boxes
        >>> transformed_bbox_labels = transformed['bbox_labels']  # Labels for transformed bboxes
        >>> transformed_keypoints = transformed['keypoints']  # Perspective-transformed keypoints
        >>> transformed_keypoint_labels = transformed['keypoint_labels']  # Labels for transformed keypoints

    """

    _targets = ALL_TARGETS

    class InitSchema(BaseTransformInitSchema):
        scale: NonNegativeFloatRangeType
        keep_size: bool
        fit_output: bool
        interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ]
        mask_interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ]
        fill: tuple[float, ...] | float
        fill_mask: tuple[float, ...] | float
        border_mode: Literal[
            cv2.BORDER_CONSTANT,
            cv2.BORDER_REPLICATE,
            cv2.BORDER_REFLECT,
            cv2.BORDER_WRAP,
            cv2.BORDER_REFLECT_101,
        ]

    def __init__(
        self,
        scale: tuple[float, float] | float = (0.05, 0.1),
        keep_size: bool = True,
        fit_output: bool = False,
        interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ] = cv2.INTER_LINEAR,
        mask_interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ] = cv2.INTER_NEAREST,
        border_mode: Literal[
            cv2.BORDER_CONSTANT,
            cv2.BORDER_REPLICATE,
            cv2.BORDER_REFLECT,
            cv2.BORDER_WRAP,
            cv2.BORDER_REFLECT_101,
        ] = cv2.BORDER_CONSTANT,
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
        p: float = 0.5,
    ):
        super().__init__(p)
        # Sampling range for the corner-offset spread.
        self.scale = cast("tuple[float, float]", scale)
        # Output-geometry options.
        self.keep_size = keep_size
        self.fit_output = fit_output
        # Border handling for pixels that map from outside the input.
        self.border_mode = border_mode
        self.fill = fill
        self.fill_mask = fill_mask
        # Resampling for images and for masks respectively.
        self.interpolation = interpolation
        self.mask_interpolation = mask_interpolation

    def apply(
        self,
        img: np.ndarray,
        matrix: np.ndarray,
        max_height: int,
        max_width: int,
        **params: Any,
    ) -> np.ndarray:
        """Warp a single image with the sampled perspective matrix.

        Args:
            img (np.ndarray): Image to warp.
            matrix (np.ndarray): Perspective matrix from ``get_params_dependent_on_data``.
            max_height (int): Output height computed alongside ``matrix``.
            max_width (int): Output width computed alongside ``matrix``.
            **params (Any): Remaining per-call parameters (unused here).

        Returns:
            np.ndarray: The warped image.

        """
        warped = fgeometric.perspective(
            img,
            matrix,
            max_width,
            max_height,
            self.fill,
            self.border_mode,
            self.keep_size,
            self.interpolation,
        )
        return warped

    @batch_transform("spatial", has_batch_dim=True, has_depth_dim=False)
    def apply_to_images(self, images: np.ndarray, **params: Any) -> np.ndarray:
        """Warp a batch of images, delegating to ``apply``.

        Args:
            images (np.ndarray): Batch of images to warp.
            **params (Any): Per-call parameters forwarded to ``apply``.

        Returns:
            np.ndarray: The warped batch.

        """
        return self.apply(images, **params)

    @batch_transform("spatial", has_batch_dim=False, has_depth_dim=True)
    def apply_to_volume(self, volume: np.ndarray, **params: Any) -> np.ndarray:
        """Warp a volume, delegating to ``apply``.

        Args:
            volume (np.ndarray): Volume to warp.
            **params (Any): Per-call parameters forwarded to ``apply``.

        Returns:
            np.ndarray: The warped volume.

        """
        return self.apply(volume, **params)

    @batch_transform("spatial", has_batch_dim=True, has_depth_dim=True)
    def apply_to_volumes(self, volumes: np.ndarray, **params: Any) -> np.ndarray:
        """Warp a batch of volumes, delegating to ``apply``.

        Args:
            volumes (np.ndarray): Batch of volumes to warp.
            **params (Any): Per-call parameters forwarded to ``apply``.

        Returns:
            np.ndarray: The warped batch of volumes.

        """
        return self.apply(volumes, **params)

    @batch_transform("spatial", has_batch_dim=True, has_depth_dim=False)
    def apply_to_mask3d(self, mask3d: np.ndarray, **params: Any) -> np.ndarray:
        """Warp a 3D mask, delegating to ``apply_to_mask``.

        Args:
            mask3d (np.ndarray): 3D mask to warp.
            **params (Any): Per-call parameters forwarded to ``apply_to_mask``.

        Returns:
            np.ndarray: The warped 3D mask.

        """
        return self.apply_to_mask(mask3d, **params)

    def apply_to_mask(
        self,
        mask: np.ndarray,
        matrix: np.ndarray,
        max_height: int,
        max_width: int,
        **params: Any,
    ) -> np.ndarray:
        """Warp a mask using the mask-specific fill value and interpolation.

        Args:
            mask (np.ndarray): Mask to warp.
            matrix (np.ndarray): Perspective matrix from ``get_params_dependent_on_data``.
            max_height (int): Output height computed alongside ``matrix``.
            max_width (int): Output width computed alongside ``matrix``.
            **params (Any): Remaining per-call parameters (unused here).

        Returns:
            np.ndarray: The warped mask.

        """
        warped_mask = fgeometric.perspective(
            mask,
            matrix,
            max_width,
            max_height,
            self.fill_mask,
            self.border_mode,
            self.keep_size,
            self.mask_interpolation,
        )
        return warped_mask

    def apply_to_bboxes(
        self,
        bboxes: np.ndarray,
        matrix_bbox: np.ndarray,
        max_height: int,
        max_width: int,
        **params: Any,
    ) -> np.ndarray:
        """Map a batch of bounding boxes through the perspective matrix.

        Args:
            bboxes (np.ndarray): Bounding boxes to transform.
            matrix_bbox (np.ndarray): Matrix to use for boxes (same as ``matrix`` here).
            max_height (int): Output height computed alongside the matrix.
            max_width (int): Output width computed alongside the matrix.
            **params (Any): Per-call parameters; ``params["shape"]`` is read.

        Returns:
            np.ndarray: The transformed bounding boxes.

        """
        input_shape = params["shape"]
        return fgeometric.perspective_bboxes(
            bboxes,
            input_shape,
            matrix_bbox,
            max_width,
            max_height,
            self.keep_size,
        )

    def apply_to_keypoints(
        self,
        keypoints: np.ndarray,
        matrix: np.ndarray,
        max_height: int,
        max_width: int,
        **params: Any,
    ) -> np.ndarray:
        """Map a batch of keypoints through the perspective matrix.

        Args:
            keypoints (np.ndarray): Keypoints to transform.
            matrix (np.ndarray): Perspective matrix from ``get_params_dependent_on_data``.
            max_height (int): Output height computed alongside ``matrix``.
            max_width (int): Output width computed alongside ``matrix``.
            **params (Any): Per-call parameters; ``params["shape"]`` is read.

        Returns:
            np.ndarray: The transformed keypoints.

        """
        input_shape = params["shape"]
        return fgeometric.perspective_keypoints(
            keypoints,
            input_shape,
            matrix,
            max_width,
            max_height,
            self.keep_size,
        )

    def get_params_dependent_on_data(
        self,
        params: dict[str, Any],
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Sample the corner displacement and derive the warp matrix and output size.

        Args:
            params (dict[str, Any]): Per-call parameters; ``params["shape"]`` is read.
            data (dict[str, Any]): Input data (not used directly).

        Returns:
            dict[str, Any]: ``matrix``, ``max_height``, ``max_width`` and
            ``matrix_bbox`` (the same matrix, consumed by ``apply_to_bboxes``).

        """
        height_width = params["shape"][:2]
        # Draw the spread first, then the corner offsets (keeps RNG call order).
        sigma = self.py_random.uniform(*self.scale)
        corners = fgeometric.order_points(
            fgeometric.generate_perspective_points(
                height_width,
                sigma,
                self.random_generator,
            ),
        )

        matrix, max_width, max_height = fgeometric.compute_perspective_params(
            corners,
            height_width,
        )
        if self.fit_output:
            # Enlarge the output plane so the full warped image stays in view.
            matrix, max_width, max_height = fgeometric.expand_transform(
                matrix,
                height_width,
            )

        return dict(
            matrix=matrix,
            max_height=max_height,
            max_width=max_width,
            matrix_bbox=matrix,
        )
443
+
444
+
445
+ class Affine(DualTransform):
446
+ """Augmentation to apply affine transformations to images.
447
+
448
+ Affine transformations involve:
449
+
450
+ - Translation ("move" image on the x-/y-axis)
451
+ - Rotation
452
+ - Scaling ("zoom" in/out)
453
+ - Shear (move one side of the image, turning a square into a trapezoid)
454
+
455
+ All such transformations can create "new" pixels in the image without a defined content, e.g.
456
+ if the image is translated to the left, pixels are created on the right.
457
+ A method has to be defined to deal with these pixel values.
458
+ The parameters `fill` and `fill_mask` of this class deal with this.
459
+
460
+ Some transformations involve interpolations between several pixels
461
+ of the input image to generate output pixel values. The parameters `interpolation` and
462
+ `mask_interpolation` deals with the method of interpolation used for this.
463
+
464
+ Args:
465
+ scale (number, tuple of number or dict): Scaling factor to use, where ``1.0`` denotes "no change" and
466
+ ``0.5`` is zoomed out to ``50`` percent of the original size.
467
+ * If a single number, then that value will be used for all images.
468
+ * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from the interval ``[a, b]``.
469
+ That the same range will be used for both x- and y-axis. To keep the aspect ratio, set
470
+ ``keep_ratio=True``, then the same value will be used for both x- and y-axis.
471
+ * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
472
+ Each of these keys can have the same values as described above.
473
+ Using a dictionary allows to set different values for the two axis and sampling will then happen
474
+ *independently* per axis, resulting in samples that differ between the axes. Note that when
475
+ the ``keep_ratio=True``, the x- and y-axis ranges should be the same.
476
+ translate_percent (None, number, tuple of number or dict): Translation as a fraction of the image height/width
477
+ (x-translation, y-translation), where ``0`` denotes "no change"
478
+ and ``0.5`` denotes "half of the axis size".
479
+ * If ``None`` then equivalent to ``0.0`` unless `translate_px` has a value other than ``None``.
480
+ * If a single number, then that value will be used for all images.
481
+ * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from the interval ``[a, b]``.
482
+ That sampled fraction value will be used identically for both x- and y-axis.
483
+ * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
484
+ Each of these keys can have the same values as described above.
485
+ Using a dictionary allows to set different values for the two axis and sampling will then happen
486
+ *independently* per axis, resulting in samples that differ between the axes.
487
+ translate_px (None, int, tuple of int or dict): Translation in pixels.
488
+ * If ``None`` then equivalent to ``0`` unless `translate_percent` has a value other than ``None``.
489
+ * If a single int, then that value will be used for all images.
490
+ * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from
491
+ the discrete interval ``[a..b]``. That number will be used identically for both x- and y-axis.
492
+ * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
493
+ Each of these keys can have the same values as described above.
494
+ Using a dictionary allows to set different values for the two axis and sampling will then happen
495
+ *independently* per axis, resulting in samples that differ between the axes.
496
+ rotate (number or tuple of number): Rotation in degrees (**NOT** radians), i.e. expected value range is
497
+ around ``[-360, 360]``. Rotation happens around the *center* of the image,
498
+ not the top left corner as in some other frameworks.
499
+ * If a number, then that value will be used for all images.
500
+ * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from the interval ``[a, b]``
501
+ and used as the rotation value.
502
+ shear (number, tuple of number or dict): Shear in degrees (**NOT** radians), i.e. expected value range is
503
+ around ``[-360, 360]``, with reasonable values being in the range of ``[-45, 45]``.
504
+ * If a number, then that value will be used for all images as
505
+ the shear on the x-axis (no shear on the y-axis will be done).
506
+ * If a tuple ``(a, b)``, then two value will be uniformly sampled per image
507
+ from the interval ``[a, b]`` and be used as the x- and y-shear value.
508
+ * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
509
+ Each of these keys can have the same values as described above.
510
+ Using a dictionary allows to set different values for the two axis and sampling will then happen
511
+ *independently* per axis, resulting in samples that differ between the axes.
512
+ interpolation (int): OpenCV interpolation flag.
513
+ mask_interpolation (int): OpenCV interpolation flag.
514
+ fill (tuple[float, ...] | float): The constant value to use when filling in newly created pixels.
515
+ (E.g. translating by 1px to the right will create a new 1px-wide column of pixels
516
+ on the left of the image).
517
+ The value is only used when `mode=constant`. The expected value range is ``[0, 255]`` for ``uint8`` images.
518
+ fill_mask (tuple[float, ...] | float): Same as fill but only for masks.
519
+ border_mode (int): OpenCV border flag.
520
+ fit_output (bool): If True, the image plane size and position will be adjusted to tightly capture
521
+ the whole image after affine transformation (`translate_percent` and `translate_px` are ignored).
522
+ Otherwise (``False``), parts of the transformed image may end up outside the image plane.
523
+ Fitting the output shape can be useful to avoid corners of the image being outside the image plane
524
+ after applying rotations. Default: False
525
+ keep_ratio (bool): When True, the original aspect ratio will be kept when the random scale is applied.
526
+ Default: False.
527
+ rotate_method (Literal["largest_box", "ellipse"]): rotation method used for the bounding boxes.
528
+ Should be one of "largest_box" or "ellipse"[1]. Default: "largest_box"
529
+ balanced_scale (bool): When True, scaling factors are chosen to be either entirely below or above 1,
530
+ ensuring balanced scaling. Default: False.
531
+
532
+ This is important because without it, scaling tends to lean towards upscaling. For example, if we want
533
+ the image to zoom in and out by 2x, we may pick an interval [0.5, 2]. Since the interval [0.5, 1] is
534
+ three times smaller than [1, 2], values above 1 are picked three times more often if sampled directly
535
+ from [0.5, 2]. With `balanced_scale`, the function ensures that half the time, the scaling
536
+ factor is picked from below 1 (zooming out), and the other half from above 1 (zooming in).
537
+ This makes the zooming in and out process more balanced.
538
+ p (float): probability of applying the transform. Default: 0.5.
539
+
540
+ Targets:
541
+ image, mask, keypoints, bboxes, volume, mask3d
542
+
543
+ Image types:
544
+ uint8, float32
545
+
546
+ References:
547
+ Towards Rotation Invariance in Object Detection: https://arxiv.org/abs/2109.13488
548
+
549
+ Examples:
550
+ >>> import numpy as np
551
+ >>> import albumentations as A
552
+ >>> import cv2
553
+ >>>
554
+ >>> # Prepare sample data
555
+ >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
556
+ >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
557
+ >>> bboxes = np.array([[10, 10, 50, 50], [40, 40, 80, 80]], dtype=np.float32)
558
+ >>> bbox_labels = [1, 2]
559
+ >>> keypoints = np.array([[20, 30], [60, 70]], dtype=np.float32)
560
+ >>> keypoint_labels = [0, 1]
561
+ >>>
562
+ >>> # Define transform with different parameter types
563
+ >>> transform = A.Compose([
564
+ ... A.Affine(
565
+ ... # Tuple for scale (will be used for both x and y)
566
+ ... scale=(0.8, 1.2),
567
+ ... # Dictionary with tuples for different x/y translations
568
+ ... translate_percent={"x": (-0.2, 0.2), "y": (-0.1, 0.1)},
569
+ ... # Tuple for rotation range
570
+ ... rotate=(-30, 30),
571
+ ... # Dictionary with tuples for different x/y shearing
572
+ ... shear={"x": (-10, 10), "y": (-5, 5)},
573
+ ... # Interpolation methods
574
+ ... interpolation=cv2.INTER_LINEAR,
575
+ ... mask_interpolation=cv2.INTER_NEAREST,
576
+ ... # Other parameters
577
+ ... fit_output=False,
578
+ ... keep_ratio=True,
579
+ ... rotate_method="largest_box",
580
+ ... balanced_scale=True,
581
+ ... border_mode=cv2.BORDER_CONSTANT,
582
+ ... fill=0,
583
+ ... fill_mask=0,
584
+ ... p=1.0
585
+ ... ),
586
+ ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
587
+ ... keypoint_params=A.KeypointParams(format='xy', label_fields=['keypoint_labels']))
588
+ >>>
589
+ >>> # Apply the transform
590
+ >>> transformed = transform(
591
+ ... image=image,
592
+ ... mask=mask,
593
+ ... bboxes=bboxes,
594
+ ... bbox_labels=bbox_labels,
595
+ ... keypoints=keypoints,
596
+ ... keypoint_labels=keypoint_labels
597
+ ... )
598
+ >>>
599
+ >>> # Get the transformed data
600
+ >>> transformed_image = transformed['image'] # Image with affine transforms applied
601
+ >>> transformed_mask = transformed['mask'] # Mask with affine transforms applied
602
+ >>> transformed_bboxes = transformed['bboxes'] # Bounding boxes with affine transforms applied
603
+ >>> transformed_bbox_labels = transformed['bbox_labels'] # Labels for transformed bboxes
604
+ >>> transformed_keypoints = transformed['keypoints'] # Keypoints with affine transforms applied
605
+ >>> transformed_keypoint_labels = transformed['keypoint_labels'] # Labels for transformed keypoints
606
+ >>>
607
+ >>> # Simpler example with only essential parameters
608
+ >>> simple_transform = A.Compose([
609
+ ... A.Affine(
610
+ ... scale=1.1, # Single scalar value for scale
611
+ ... rotate=15, # Single scalar value for rotation (degrees)
612
+ ... translate_px=30, # Single scalar value for translation (pixels)
613
+ ... p=1.0
614
+ ... ),
615
+ ... ])
616
+ >>> simple_result = simple_transform(image=image)
617
+ >>> simple_transformed = simple_result['image']
618
+
619
+ """
620
+
621
+ _targets = ALL_TARGETS
622
+
623
class InitSchema(BaseTransformInitSchema):
    # Raw user inputs. The validators below normalise scale/shear (and the chosen
    # translate field) into {"x": pair, "y": pair} dicts and rotate into a pair.
    scale: tuple[float, float] | float | dict[str, float | tuple[float, float]]
    translate_percent: tuple[float, float] | float | dict[str, float | tuple[float, float]] | None
    translate_px: tuple[float, float] | float | dict[str, float | tuple[float, float]] | None
    rotate: tuple[float, float] | float
    shear: tuple[float, float] | float | dict[str, float | tuple[float, float]]
    interpolation: Literal[
        cv2.INTER_NEAREST,
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_AREA,
        cv2.INTER_LANCZOS4,
    ]
    mask_interpolation: Literal[
        cv2.INTER_NEAREST,
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_AREA,
        cv2.INTER_LANCZOS4,
    ]

    fill: tuple[float, ...] | float
    fill_mask: tuple[float, ...] | float
    border_mode: Literal[
        cv2.BORDER_CONSTANT,
        cv2.BORDER_REPLICATE,
        cv2.BORDER_REFLECT,
        cv2.BORDER_WRAP,
        cv2.BORDER_REFLECT_101,
    ]

    fit_output: bool
    keep_ratio: bool
    rotate_method: Literal["largest_box", "ellipse"]
    balanced_scale: bool

    @field_validator("shear", "scale")
    @classmethod
    def _process_shear(
        cls,
        value: tuple[float, float] | float | dict[str, float | tuple[float, float]],
        info: ValidationInfo,
    ) -> dict[str, tuple[float, float]]:
        """Normalise ``shear`` or ``scale`` into per-axis ranges (despite the name)."""
        field = info.field_name
        return cls._handle_dict_arg(value, field)

    @field_validator("rotate")
    @classmethod
    def _process_rotate(
        cls,
        value: tuple[float, float] | float,
    ) -> tuple[float, float]:
        """Normalise ``rotate`` into a two-element range via ``to_tuple``."""
        angle_range = to_tuple(value, value)
        return angle_range

    @model_validator(mode="after")
    def _handle_translate(self) -> Self:
        """Require at most one translate field, defaulting to ``translate_px=0``.

        Whichever translate field ends up set is normalised into per-axis ranges.
        """
        percent_given = self.translate_percent is not None
        px_given = self.translate_px is not None

        if percent_given and px_given:
            msg = "Expected either translate_percent or translate_px to be provided, but both were provided."
            raise ValueError(msg)
        if not (percent_given or px_given):
            self.translate_px = 0

        # Same assignment order as before: percent first, then pixels.
        for field_name, fallback in (("translate_percent", 0.0), ("translate_px", 0)):
            current = getattr(self, field_name)
            if current is not None:
                setattr(
                    self,
                    field_name,
                    self._handle_dict_arg(current, field_name, default=fallback),
                )

        return self

    @staticmethod
    def _handle_dict_arg(
        val: tuple[float, float]
        | dict[str, float | tuple[float, float]]
        | float
        | tuple[int, int]
        | dict[str, int | tuple[int, int]],
        name: str | None,
        default: float = 1.0,
    ) -> dict[str, tuple[float, float]]:
        """Expand a scalar, pair, or {"x"/"y"} dict into ``{"x": pair, "y": pair}``.

        A missing axis key in a dict falls back to ``default``.

        Raises:
            ValueError: If ``val`` is a dict with neither an "x" nor a "y" key.
        """
        if isinstance(val, dict):
            if not ("x" in val or "y" in val):
                raise ValueError(
                    f'Expected {name} dictionary to contain at least key "x" or key "y". Found neither of them.',
                )
            x_range = val.get("x", default)
            y_range = val.get("y", default)
            return {"x": to_tuple(x_range, x_range), "y": to_tuple(y_range, y_range)}
        if isinstance(val, float):
            return {"x": (val, val), "y": (val, val)}
        return {"x": to_tuple(val, val), "y": to_tuple(val, val)}
722
+
723
def __init__(
    self,
    scale: tuple[float, float] | float | dict[str, float | tuple[float, float]] = (1.0, 1.0),
    translate_percent: tuple[float, float] | float | dict[str, float | tuple[float, float]] | None = None,
    translate_px: tuple[int, int] | int | dict[str, int | tuple[int, int]] | None = None,
    rotate: tuple[float, float] | float = 0.0,
    shear: tuple[float, float] | float | dict[str, float | tuple[float, float]] = (0.0, 0.0),
    interpolation: Literal[
        cv2.INTER_NEAREST,
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_AREA,
        cv2.INTER_LANCZOS4,
    ] = cv2.INTER_LINEAR,
    mask_interpolation: Literal[
        cv2.INTER_NEAREST,
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_AREA,
        cv2.INTER_LANCZOS4,
    ] = cv2.INTER_NEAREST,
    fit_output: bool = False,
    keep_ratio: bool = False,
    rotate_method: Literal["largest_box", "ellipse"] = "largest_box",
    balanced_scale: bool = False,
    border_mode: Literal[
        cv2.BORDER_CONSTANT,
        cv2.BORDER_REPLICATE,
        cv2.BORDER_REFLECT,
        cv2.BORDER_WRAP,
        cv2.BORDER_REFLECT_101,
    ] = cv2.BORDER_CONSTANT,
    fill: tuple[float, ...] | float = 0,
    fill_mask: tuple[float, ...] | float = 0,
    p: float = 0.5,
):
    """Store the affine parameters on the instance.

    The ``cast`` calls only narrow static types. The runtime values are presumably
    already normalized by ``InitSchema`` before reaching this point (e.g. ``scale``
    as a per-axis dict) -- confirm against the validation machinery.

    Raises:
        ValueError: If ``keep_ratio`` is True but the x and y scale ranges differ.

    """
    super().__init__(p=p)

    # Geometric parameter ranges.
    self.scale = cast("dict[str, tuple[float, float]]", scale)
    self.translate_percent = cast("dict[str, tuple[float, float]]", translate_percent)
    self.translate_px = cast("dict[str, tuple[int, int]]", translate_px)
    self.rotate = cast("tuple[float, float]", rotate)
    self.shear = cast("dict[str, tuple[float, float]]", shear)

    # Resampling and border handling.
    self.interpolation = interpolation
    self.mask_interpolation = mask_interpolation
    self.border_mode = border_mode
    self.fill = fill
    self.fill_mask = fill_mask

    # Behavioural flags.
    self.fit_output = fit_output
    self.keep_ratio = keep_ratio
    self.rotate_method = rotate_method
    self.balanced_scale = balanced_scale

    if not self.keep_ratio:
        return
    if self.scale["x"] != self.scale["y"]:
        raise ValueError(
            f"When keep_ratio is True, the x and y scale range should be identical. got {self.scale}",
        )
+
781
def apply(
    self,
    img: np.ndarray,
    matrix: np.ndarray,
    output_shape: tuple[int, int],
    **params: Any,
) -> np.ndarray:
    """Warp an image with the sampled affine matrix.

    Args:
        img (np.ndarray): Image to warp.
        matrix (np.ndarray): Affine matrix produced by ``get_params_dependent_on_data``.
        output_shape (tuple[int, int]): Target shape forwarded to ``warp_affine``.
        **params (Any): Remaining per-call parameters (unused here).

    Returns:
        np.ndarray: The warped image.

    """
    warp_options = {
        "interpolation": self.interpolation,
        "fill": self.fill,
        "border_mode": self.border_mode,
        "output_shape": output_shape,
    }
    return fgeometric.warp_affine(img, matrix, **warp_options)
+
809
def apply_to_mask(
    self,
    mask: np.ndarray,
    matrix: np.ndarray,
    output_shape: tuple[int, int],
    **params: Any,
) -> np.ndarray:
    """Warp a mask with the sampled affine matrix.

    Uses the mask-specific interpolation and fill value rather than the image ones.

    Args:
        mask (np.ndarray): Mask to warp.
        matrix (np.ndarray): Affine matrix produced by ``get_params_dependent_on_data``.
        output_shape (tuple[int, int]): Target shape forwarded to ``warp_affine``.
        **params (Any): Remaining per-call parameters (unused here).

    Returns:
        np.ndarray: The warped mask.

    """
    warp_options = {
        "interpolation": self.mask_interpolation,
        "fill": self.fill_mask,
        "border_mode": self.border_mode,
        "output_shape": output_shape,
    }
    return fgeometric.warp_affine(mask, matrix, **warp_options)
+
837
def apply_to_bboxes(
    self,
    bboxes: np.ndarray,
    bbox_matrix: np.ndarray,
    output_shape: tuple[int, int],
    **params: Any,
) -> np.ndarray:
    """Transform bounding boxes with the bbox-specific affine matrix.

    Args:
        bboxes (np.ndarray): Bounding boxes to transform.
        bbox_matrix (np.ndarray): Matrix built around the bbox centre in ``get_params_dependent_on_data``.
        output_shape (tuple[int, int]): Shape of the warped output.
        **params (Any): Per-call parameters; ``params["shape"]`` must hold the input shape.

    Returns:
        np.ndarray: Transformed bounding boxes.

    """
    source_hw = params["shape"][:2]
    return fgeometric.bboxes_affine(
        bboxes,
        bbox_matrix,
        self.rotate_method,
        source_hw,
        self.border_mode,
        output_shape,
    )
+
865
def apply_to_keypoints(
    self,
    keypoints: np.ndarray,
    matrix: np.ndarray,
    scale: dict[str, float],
    **params: Any,
) -> np.ndarray:
    """Transform keypoints with the image affine matrix.

    Args:
        keypoints (np.ndarray): Keypoints to transform.
        matrix (np.ndarray): Affine matrix used for the image.
        scale (dict[str, float]): Sampled per-axis scale factors.
        **params (Any): Per-call parameters; ``params["shape"]`` must hold the input shape.

    Returns:
        np.ndarray: Transformed keypoints.

    """
    source_shape = params["shape"]
    return fgeometric.keypoints_affine(keypoints, matrix, source_shape, scale, self.border_mode)
+
892
@batch_transform("spatial", has_batch_dim=True, has_depth_dim=False)
def apply_to_images(self, images: np.ndarray, **params: Any) -> np.ndarray:
    """Warp a batch of images by delegating to ``apply``.

    The ``batch_transform`` decorator presumably adapts the batch layout so the
    single-image ``apply`` can process it -- see its implementation.

    Args:
        images (np.ndarray): Batch of images.
        **params (Any): Per-call parameters forwarded to ``apply``.

    Returns:
        np.ndarray: Warped images.

    """
    warped = self.apply(images, **params)
    return warped
+
906
@batch_transform("spatial", has_batch_dim=False, has_depth_dim=True)
def apply_to_volume(self, volume: np.ndarray, **params: Any) -> np.ndarray:
    """Warp a volume by delegating to ``apply``.

    The ``batch_transform`` decorator presumably adapts the depth axis so the
    single-image ``apply`` can process it -- see its implementation.

    Args:
        volume (np.ndarray): Volume to warp.
        **params (Any): Per-call parameters forwarded to ``apply``.

    Returns:
        np.ndarray: Warped volume.

    """
    warped = self.apply(volume, **params)
    return warped
+
920
@batch_transform("spatial", has_batch_dim=True, has_depth_dim=True)
def apply_to_volumes(self, volumes: np.ndarray, **params: Any) -> np.ndarray:
    """Warp a batch of volumes by delegating to ``apply``.

    The ``batch_transform`` decorator presumably adapts the batch and depth axes so
    the single-image ``apply`` can process them -- see its implementation.

    Args:
        volumes (np.ndarray): Batch of volumes.
        **params (Any): Per-call parameters forwarded to ``apply``.

    Returns:
        np.ndarray: Warped volumes.

    """
    warped = self.apply(volumes, **params)
    return warped
+
934
@batch_transform("spatial", has_batch_dim=True, has_depth_dim=False)
def apply_to_mask3d(self, mask3d: np.ndarray, **params: Any) -> np.ndarray:
    """Warp a 3D mask by delegating to ``apply_to_mask``.

    Mask interpolation and mask fill apply, as in ``apply_to_mask``.

    Args:
        mask3d (np.ndarray): 3D mask to warp.
        **params (Any): Per-call parameters forwarded to ``apply_to_mask``.

    Returns:
        np.ndarray: Warped 3D mask.

    """
    warped = self.apply_to_mask(mask3d, **params)
    return warped
+
948
+ @staticmethod
949
+ def _get_scale(
950
+ scale: dict[str, tuple[float, float]],
951
+ keep_ratio: bool,
952
+ balanced_scale: bool,
953
+ random_state: random.Random,
954
+ ) -> dict[str, float]:
955
+ result_scale = {}
956
+ for key, value in scale.items():
957
+ if isinstance(value, (int, float)):
958
+ result_scale[key] = float(value)
959
+ elif isinstance(value, tuple):
960
+ if balanced_scale:
961
+ lower_interval = (value[0], 1.0) if value[0] < 1 else None
962
+ upper_interval = (1.0, value[1]) if value[1] > 1 else None
963
+
964
+ if lower_interval is not None and upper_interval is not None:
965
+ selected_interval = random_state.choice(
966
+ [lower_interval, upper_interval],
967
+ )
968
+ elif lower_interval is not None:
969
+ selected_interval = lower_interval
970
+ elif upper_interval is not None:
971
+ selected_interval = upper_interval
972
+ else:
973
+ result_scale[key] = 1.0
974
+ continue
975
+
976
+ result_scale[key] = random_state.uniform(*selected_interval)
977
+ else:
978
+ result_scale[key] = random_state.uniform(*value)
979
+ else:
980
+ raise TypeError(
981
+ f"Invalid scale value for key {key}: {value}. Expected a float or a tuple of two floats.",
982
+ )
983
+
984
+ if keep_ratio:
985
+ result_scale["y"] = result_scale["x"]
986
+
987
+ return result_scale
988
+
989
def get_params_dependent_on_data(
    self,
    params: dict[str, Any],
    data: dict[str, Any],
) -> dict[str, Any]:
    """Sample the affine parameters and build the warp matrices for one call.

    Args:
        params (dict[str, Any]): Per-call parameters; ``params["shape"]`` holds the input shape.
        data (dict[str, Any]): Input data (unused here).

    Returns:
        dict[str, Any]: ``rotate``, ``scale``, ``matrix`` (for images/masks/keypoints),
        ``bbox_matrix`` (for bounding boxes) and ``output_shape``.

    """
    image_shape = params["shape"][:2]

    # These draws use self.py_random; keep their order stable for reproducibility.
    translate = self._get_translate_params(image_shape)
    shear = self._get_shear_params()
    scale = self._get_scale(self.scale, self.keep_ratio, self.balanced_scale, self.py_random)
    rotate = self.py_random.uniform(*self.rotate)

    # The image and bbox matrices differ only in the shift (centre) they are built around.
    matrix, bbox_matrix = (
        fgeometric.create_affine_transformation_matrix(translate, shear, scale, rotate, shift)
        for shift in (fgeometric.center(image_shape), fgeometric.center_bbox(image_shape))
    )

    output_shape = image_shape
    if self.fit_output:
        matrix, output_shape = fgeometric.compute_affine_warp_output_shape(matrix, image_shape)
        bbox_matrix, _ = fgeometric.compute_affine_warp_output_shape(bbox_matrix, image_shape)

    return {
        "rotate": rotate,
        "scale": scale,
        "matrix": matrix,
        "bbox_matrix": bbox_matrix,
        "output_shape": output_shape,
    }
+
1054
+ def _get_translate_params(self, image_shape: tuple[int, int]) -> dict[str, int]:
1055
+ height, width = image_shape[:2]
1056
+ if self.translate_px is not None:
1057
+ return {
1058
+ "x": self.py_random.randint(int(self.translate_px["x"][0]), int(self.translate_px["x"][1])),
1059
+ "y": self.py_random.randint(int(self.translate_px["y"][0]), int(self.translate_px["y"][1])),
1060
+ }
1061
+ if self.translate_percent is not None:
1062
+ translate = {key: self.py_random.uniform(*value) for key, value in self.translate_percent.items()}
1063
+ return cast(
1064
+ "dict[str, int]",
1065
+ {"x": int(translate["x"] * width), "y": int(translate["y"] * height)},
1066
+ )
1067
+ return cast("dict[str, int]", {"x": 0, "y": 0})
1068
+
1069
+ def _get_shear_params(self) -> dict[str, float]:
1070
+ return {
1071
+ "x": -self.py_random.uniform(*self.shear["x"]),
1072
+ "y": -self.py_random.uniform(*self.shear["y"]),
1073
+ }
1074
+
1075
+
1076
class ShiftScaleRotate(Affine):
    """Randomly apply affine transforms: translate, scale and rotate the input.

    Args:
        shift_limit ((float, float) or float): shift factor range for both height and width. If shift_limit
            is a single float value, the range will be (-shift_limit, shift_limit). Absolute values for lower and
            upper bounds should lie in range [-1, 1]. Default: (-0.0625, 0.0625).
        scale_limit ((float, float) or float): scaling factor range. If scale_limit is a single float value, the
            range will be (-scale_limit, scale_limit). Note that the scale_limit will be biased by 1.
            If scale_limit is a tuple, like (low, high), sampling will be done from the range (1 + low, 1 + high).
            Default: (-0.1, 0.1).
        rotate_limit ((float, float) or float): rotation range. If rotate_limit is a single value, the
            range will be (-rotate_limit, rotate_limit). Default: (-45, 45).
        interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
            Default: cv2.INTER_LINEAR.
        border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of:
            cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
            Default: cv2.BORDER_CONSTANT
        fill (tuple[float, ...] | float): padding value if border_mode is cv2.BORDER_CONSTANT.
        fill_mask (tuple[float, ...] | float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks.
        shift_limit_x ((float, float) or float): shift factor range for width. If it is set then this value
            instead of shift_limit will be used for shifting width. If shift_limit_x is a single float value,
            the range will be (-shift_limit_x, shift_limit_x). Absolute values for lower and upper bounds should lie in
            the range [-1, 1]. Default: None.
        shift_limit_y ((float, float) or float): shift factor range for height. If it is set then this value
            instead of shift_limit will be used for shifting height. If shift_limit_y is a single float value,
            the range will be (-shift_limit_y, shift_limit_y). Absolute values for lower and upper bounds should lie
            in the range [-1, 1]. Default: None.
        rotate_method (str): rotation method used for the bounding boxes. Should be one of "largest_box" or "ellipse".
            Default: "largest_box"
        mask_interpolation (OpenCV flag): Flag that is used to specify the interpolation algorithm for mask.
            Should be one of: cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
            Default: cv2.INTER_NEAREST.
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image, mask, keypoints, bboxes, volume, mask3d

    Image types:
        uint8, float32

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>> import cv2
        >>>
        >>> # Prepare sample data
        >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
        >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
        >>> bboxes = np.array([[10, 10, 50, 50], [40, 40, 80, 80]], dtype=np.float32)
        >>> bbox_labels = [1, 2]
        >>> keypoints = np.array([[20, 30], [60, 70]], dtype=np.float32)
        >>> keypoint_labels = [0, 1]
        >>>
        >>> # Define transform with parameters as tuples when possible
        >>> transform = A.Compose([
        ...     A.ShiftScaleRotate(
        ...         shift_limit=(-0.0625, 0.0625),
        ...         scale_limit=(-0.1, 0.1),
        ...         rotate_limit=(-45, 45),
        ...         interpolation=cv2.INTER_LINEAR,
        ...         border_mode=cv2.BORDER_CONSTANT,
        ...         rotate_method="largest_box",
        ...         p=1.0
        ...     ),
        ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
        ...    keypoint_params=A.KeypointParams(format='xy', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply the transform
        >>> transformed = transform(
        ...     image=image,
        ...     mask=mask,
        ...     bboxes=bboxes,
        ...     bbox_labels=bbox_labels,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Get the transformed data
        >>> transformed_image = transformed['image']  # Shifted, scaled and rotated image
        >>> transformed_mask = transformed['mask']  # Shifted, scaled and rotated mask
        >>> transformed_bboxes = transformed['bboxes']  # Shifted, scaled and rotated bounding boxes
        >>> transformed_bbox_labels = transformed['bbox_labels']  # Labels for transformed bboxes
        >>> transformed_keypoints = transformed['keypoints']  # Shifted, scaled and rotated keypoints
        >>> transformed_keypoint_labels = transformed['keypoint_labels']  # Labels for transformed keypoints

    """

    _targets = ALL_TARGETS

    class InitSchema(BaseTransformInitSchema):
        shift_limit: SymmetricRangeType
        scale_limit: SymmetricRangeType
        rotate_limit: SymmetricRangeType
        interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ]

        border_mode: Literal[
            cv2.BORDER_CONSTANT,
            cv2.BORDER_REPLICATE,
            cv2.BORDER_REFLECT,
            cv2.BORDER_WRAP,
            cv2.BORDER_REFLECT_101,
        ]

        fill: tuple[float, ...] | float
        fill_mask: tuple[float, ...] | float

        shift_limit_x: tuple[float, float] | float | None
        shift_limit_y: tuple[float, float] | float | None
        rotate_method: Literal["largest_box", "ellipse"]
        mask_interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ]

        @model_validator(mode="after")
        def _check_shift_limit(self) -> Self:
            # Per-axis limits fall back to the shared shift_limit and must lie in [-1, 1].
            bounds = -1, 1
            self.shift_limit_x = to_tuple(
                self.shift_limit_x if self.shift_limit_x is not None else self.shift_limit,
            )
            check_range(self.shift_limit_x, *bounds, "shift_limit_x")
            self.shift_limit_y = to_tuple(
                self.shift_limit_y if self.shift_limit_y is not None else self.shift_limit,
            )
            check_range(self.shift_limit_y, *bounds, "shift_limit_y")

            return self

        @field_validator("scale_limit")
        @classmethod
        def _check_scale_limit(
            cls,
            value: tuple[float, float] | float,
            info: ValidationInfo,
        ) -> tuple[float, float]:
            # Bias the limit by 1 so it becomes a multiplicative scale range, which must be non-negative.
            bounds = 0, float("inf")
            result = to_tuple(value, bias=1.0)
            check_range(result, *bounds, str(info.field_name))
            return result

    def __init__(
        self,
        shift_limit: tuple[float, float] | float = (-0.0625, 0.0625),
        scale_limit: tuple[float, float] | float = (-0.1, 0.1),
        rotate_limit: tuple[float, float] | float = (-45, 45),
        interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ] = cv2.INTER_LINEAR,
        border_mode: int = cv2.BORDER_CONSTANT,
        shift_limit_x: tuple[float, float] | float | None = None,
        shift_limit_y: tuple[float, float] | float | None = None,
        rotate_method: Literal["largest_box", "ellipse"] = "largest_box",
        mask_interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ] = cv2.INTER_NEAREST,
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
        p: float = 0.5,
    ):
        # Map the legacy parameters onto Affine: shifts become translate_percent, no shear.
        shift_limit_x = cast("tuple[float, float]", shift_limit_x)
        shift_limit_y = cast("tuple[float, float]", shift_limit_y)
        super().__init__(
            scale=scale_limit,
            translate_percent={"x": shift_limit_x, "y": shift_limit_y},
            rotate=rotate_limit,
            shear=(0, 0),
            interpolation=interpolation,
            mask_interpolation=mask_interpolation,
            fill=fill,
            fill_mask=fill_mask,
            border_mode=border_mode,
            fit_output=False,
            keep_ratio=False,
            rotate_method=rotate_method,
            p=p,
        )
        warn(
            "ShiftScaleRotate is a special case of Affine transform. Please use Affine transform instead.",
            UserWarning,
            stacklevel=2,
        )
        self.shift_limit_x = shift_limit_x
        self.shift_limit_y = shift_limit_y

        self.scale_limit = cast("tuple[float, float]", scale_limit)
        self.rotate_limit = cast("tuple[float, float]", rotate_limit)
        self.border_mode = border_mode
        self.fill = fill
        self.fill_mask = fill_mask

    def get_transform_init_args(self) -> dict[str, Any]:
        """Get the transform initialization arguments.

        ``scale_limit`` is stored biased by +1, so it is un-biased here to round-trip
        through ``__init__``.

        Returns:
            dict[str, Any]: Transform initialization arguments.

        """
        return {
            "shift_limit_x": self.shift_limit_x,
            "shift_limit_y": self.shift_limit_y,
            "scale_limit": to_tuple(self.scale_limit, bias=-1.0),
            "rotate_limit": self.rotate_limit,
            "interpolation": self.interpolation,
            "border_mode": self.border_mode,
            "fill": self.fill,
            "fill_mask": self.fill_mask,
            "rotate_method": self.rotate_method,
            "mask_interpolation": self.mask_interpolation,
        }
+
1305
+
1306
class GridElasticDeform(DualTransform):
    """Apply elastic deformations to images, masks, bounding boxes, and keypoints using a grid-based approach.

    This transformation overlays a grid on the input and applies random displacements to the grid points,
    resulting in local elastic distortions. The granularity and intensity of the distortions can be
    controlled using the dimensions of the overlaying distortion grid and the magnitude parameter.


    Args:
        num_grid_xy (tuple[int, int]): Number of grid cells along the width and height.
            Specified as (grid_width, grid_height). Each value must be at least 1.
        magnitude (int): Maximum pixel-wise displacement for distortion. Must be greater than 0.
        interpolation (int): Interpolation method to be used for the image transformation.
            Default: cv2.INTER_LINEAR
        mask_interpolation (int): Interpolation method to be used for mask transformation.
            Default: cv2.INTER_NEAREST
        p (float): Probability of applying the transform. Default: 1.0.

    Targets:
        image, mask, bboxes, keypoints, volume, mask3d

    Image types:
        uint8, float32

    Number of channels:
        1, 3

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>> import cv2
        >>>
        >>> # Prepare sample data
        >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
        >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
        >>> bboxes = np.array([[10, 10, 50, 50], [40, 40, 80, 80]], dtype=np.float32)
        >>> bbox_labels = [1, 2]
        >>> keypoints = np.array([[20, 30], [60, 70]], dtype=np.float32)
        >>> keypoint_labels = [0, 1]
        >>>
        >>> # Define transform with parameters as tuples when possible
        >>> transform = A.Compose([
        ...     A.GridElasticDeform(
        ...         num_grid_xy=(4, 4),
        ...         magnitude=10,
        ...         interpolation=cv2.INTER_LINEAR,
        ...         mask_interpolation=cv2.INTER_NEAREST,
        ...         p=1.0
        ...     ),
        ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
        ...    keypoint_params=A.KeypointParams(format='xy', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply the transform
        >>> transformed = transform(
        ...     image=image,
        ...     mask=mask,
        ...     bboxes=bboxes,
        ...     bbox_labels=bbox_labels,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Get the transformed data
        >>> transformed_image = transformed['image']  # Elastically deformed image
        >>> transformed_mask = transformed['mask']  # Elastically deformed mask
        >>> transformed_bboxes = transformed['bboxes']  # Elastically deformed bounding boxes
        >>> transformed_bbox_labels = transformed['bbox_labels']  # Labels for transformed bboxes
        >>> transformed_keypoints = transformed['keypoints']  # Elastically deformed keypoints
        >>> transformed_keypoint_labels = transformed['keypoint_labels']  # Labels for transformed keypoints

    Note:
        This transformation is particularly useful for data augmentation in medical imaging
        and other domains where elastic deformations can simulate realistic variations.

    """

    _targets = ALL_TARGETS

    class InitSchema(BaseTransformInitSchema):
        num_grid_xy: Annotated[tuple[int, int], AfterValidator(check_range_bounds(1, None))]
        magnitude: int = Field(gt=0)
        interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ]
        mask_interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ]

    def __init__(
        self,
        num_grid_xy: tuple[int, int],
        magnitude: int,
        interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ] = cv2.INTER_LINEAR,
        mask_interpolation: Literal[
            cv2.INTER_NEAREST,
            cv2.INTER_LINEAR,
            cv2.INTER_CUBIC,
            cv2.INTER_AREA,
            cv2.INTER_LANCZOS4,
        ] = cv2.INTER_NEAREST,
        p: float = 1.0,
    ):
        super().__init__(p=p)
        self.num_grid_xy = num_grid_xy
        self.magnitude = magnitude
        self.interpolation = interpolation
        self.mask_interpolation = mask_interpolation

    @staticmethod
    def _generate_mesh(polygons: np.ndarray, dimensions: np.ndarray) -> np.ndarray:
        # Each mesh row is the cell rectangle (4 values) followed by its distorted polygon.
        return np.hstack((dimensions.reshape(-1, 4), polygons))

    def get_params_dependent_on_data(
        self,
        params: dict[str, Any],
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Get the parameters dependent on the data.

        Args:
            params (dict[str, Any]): Parameters; ``params["shape"]`` holds the input shape.
            data (dict[str, Any]): Data (unused here).

        Returns:
            dict[str, Any]: ``{"generated_mesh": mesh}`` with one row per grid cell.

        """
        image_shape = params["shape"][:2]
        grid_width, grid_height = self.num_grid_xy

        # ``num_grid_xy`` is (grid_width, grid_height), while ``split_uniform_grid`` takes the
        # grid as (rows, cols) and returns tiles in row-major order (per its docstring).
        # Passing ``num_grid_xy`` unswapped transposed non-square grids. The reshape below
        # then put tiles into the wrong cells, so neighbouring polygons no longer shared
        # their displaced vertices. Square grids are unaffected by this swap.
        tiles = fgeometric.split_uniform_grid(
            image_shape,
            (grid_height, grid_width),
            self.random_generator,
        )

        # Tiles are (y_min, x_min, y_max, x_max); reorder to [x_min, y_min, x_max, y_max]
        # and lay them out as (grid_height, grid_width, 4).
        dimensions = np.array(
            [[tile[1], tile[0], tile[3], tile[2]] for tile in tiles],
        ).reshape((grid_height, grid_width, 4))

        polygons = fgeometric.generate_distorted_grid_polygons(
            dimensions,
            self.magnitude,
            self.random_generator,
        )

        generated_mesh = self._generate_mesh(polygons, dimensions)

        return {"generated_mesh": generated_mesh}

    def apply(
        self,
        img: np.ndarray,
        generated_mesh: np.ndarray,
        **params: Any,
    ) -> np.ndarray:
        """Apply the GridElasticDeform transform to an image.

        Args:
            img (np.ndarray): Image to be transformed.
            generated_mesh (np.ndarray): Generated mesh.
            **params (Any): Additional parameters.

        Returns:
            np.ndarray: Distorted image.

        Raises:
            ValueError: If the image is neither RGB nor grayscale.

        """
        if not is_rgb_image(img) and not is_grayscale_image(img):
            raise ValueError("GridElasticDeform transform is only supported for RGB and grayscale images.")
        return fgeometric.distort_image(img, generated_mesh, self.interpolation)

    def apply_to_mask(
        self,
        mask: np.ndarray,
        generated_mesh: np.ndarray,
        **params: Any,
    ) -> np.ndarray:
        """Apply the GridElasticDeform transform to a mask.

        Args:
            mask (np.ndarray): Mask to be transformed.
            generated_mesh (np.ndarray): Generated mesh.
            **params (Any): Additional parameters.

        Returns:
            np.ndarray: Distorted mask.

        """
        return fgeometric.distort_image(mask, generated_mesh, self.mask_interpolation)

    def apply_to_bboxes(
        self,
        bboxes: np.ndarray,
        generated_mesh: np.ndarray,
        **params: Any,
    ) -> np.ndarray:
        """Apply the GridElasticDeform transform to bounding boxes.

        Boxes are denormalized to pixels, distorted, and normalized back.

        Args:
            bboxes (np.ndarray): Bounding boxes to be transformed.
            generated_mesh (np.ndarray): Generated mesh.
            **params (Any): Additional parameters; ``params["shape"]`` holds the input shape.

        Returns:
            np.ndarray: Distorted bounding boxes.

        """
        bboxes_denorm = denormalize_bboxes(bboxes, params["shape"][:2])
        return normalize_bboxes(
            fgeometric.bbox_distort_image(
                bboxes_denorm,
                generated_mesh,
                params["shape"][:2],
            ),
            params["shape"][:2],
        )

    def apply_to_keypoints(
        self,
        keypoints: np.ndarray,
        generated_mesh: np.ndarray,
        **params: Any,
    ) -> np.ndarray:
        """Apply the GridElasticDeform transform to keypoints.

        Args:
            keypoints (np.ndarray): Keypoints to be transformed.
            generated_mesh (np.ndarray): Generated mesh.
            **params (Any): Additional parameters; ``params["shape"]`` holds the input shape.

        Returns:
            np.ndarray: Distorted keypoints.

        """
        return fgeometric.distort_image_keypoints(
            keypoints,
            generated_mesh,
            params["shape"][:2],
        )
+
1559
+
1560
+ class RandomGridShuffle(DualTransform):
1561
+ """Randomly shuffles the grid's cells on an image, mask, or keypoints,
1562
+ effectively rearranging patches within the image.
1563
+ This transformation divides the image into a grid and then permutes these grid cells based on a random mapping.
1564
+
1565
+ Args:
1566
+ grid (tuple[int, int]): Size of the grid for splitting the image into cells. Each cell is shuffled randomly.
1567
+ For example, (3, 3) will divide the image into a 3x3 grid, resulting in 9 cells to be shuffled.
1568
+ Default: (3, 3)
1569
+ p (float): Probability that the transform will be applied. Should be in the range [0, 1].
1570
+ Default: 0.5
1571
+
1572
+ Targets:
1573
+ image, mask, keypoints, bboxes, volume, mask3d
1574
+
1575
+ Image types:
1576
+ uint8, float32
1577
+
1578
+ Note:
1579
+ - This transform maintains consistency across all targets. If applied to an image and its corresponding
1580
+ mask or keypoints, the same shuffling will be applied to all.
1581
+ - The number of cells in the grid should be at least 2 (i.e., grid should be at least (1, 2), (2, 1), or (2, 2))
1582
+ for the transform to have any effect.
1583
+ - Keypoints are moved along with their corresponding grid cell.
1584
+ - This transform could be useful when only micro features are important for the model, and memorizing
1585
+ the global structure could be harmful. For example:
1586
+ - Identifying the type of cell phone used to take a picture based on micro artifacts generated by
1587
+ phone post-processing algorithms, rather than the semantic features of the photo.
1588
+ See more at https://ieeexplore.ieee.org/abstract/document/8622031
1589
+ - Identifying stress, glucose, hydration levels based on skin images.
1590
+
1591
+ Mathematical Formulation:
1592
+ 1. The image is divided into a grid of size (m, n) as specified by the 'grid' parameter.
1593
+ 2. A random permutation P of integers from 0 to (m*n - 1) is generated.
1594
+ 3. Each cell in the grid is assigned a number from 0 to (m*n - 1) in row-major order.
1595
+ 4. The cells are then rearranged according to the permutation P.
1596
+
1597
+ Examples:
1598
+ >>> import numpy as np
1599
+ >>> import albumentations as A
1600
+ >>> # Prepare sample data
1601
+ >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
1602
+ >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
1603
+ >>> bboxes = np.array([[10, 10, 50, 50], [40, 40, 80, 80]], dtype=np.float32)
1604
+ >>> bbox_labels = [1, 2]
1605
+ >>> keypoints = np.array([[20, 30], [60, 70]], dtype=np.float32)
1606
+ >>> keypoint_labels = [0, 1]
1607
+ >>>
1608
+ >>> # Define transform with grid as a tuple
1609
+ >>> transform = A.Compose([
1610
+ ... A.RandomGridShuffle(grid=(3, 3), p=1.0),
1611
+ ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
1612
+ ... keypoint_params=A.KeypointParams(format='xy', label_fields=['keypoint_labels']))
1613
+ >>>
1614
+ >>> # Apply the transform
1615
+ >>> transformed = transform(
1616
+ ... image=image,
1617
+ ... mask=mask,
1618
+ ... bboxes=bboxes,
1619
+ ... bbox_labels=bbox_labels,
1620
+ ... keypoints=keypoints,
1621
+ ... keypoint_labels=keypoint_labels
1622
+ ... )
1623
+ >>>
1624
+ >>> # Get the transformed data
1625
+ >>> transformed_image = transformed['image'] # Grid-shuffled image
1626
+ >>> transformed_mask = transformed['mask'] # Grid-shuffled mask
1627
+ >>> transformed_bboxes = transformed['bboxes'] # Grid-shuffled bounding boxes
1628
+ >>> transformed_keypoints = transformed['keypoints'] # Grid-shuffled keypoints
1629
+ >>>
1630
+ >>> # Visualization example with a simpler grid
1631
+ >>> simple_image = np.array([
1632
+ ... [1, 1, 1, 2, 2, 2],
1633
+ ... [1, 1, 1, 2, 2, 2],
1634
+ ... [1, 1, 1, 2, 2, 2],
1635
+ ... [3, 3, 3, 4, 4, 4],
1636
+ ... [3, 3, 3, 4, 4, 4],
1637
+ ... [3, 3, 3, 4, 4, 4]
1638
+ ... ])
1639
+ >>> simple_transform = A.RandomGridShuffle(grid=(2, 2), p=1.0)
1640
+ >>> simple_result = simple_transform(image=simple_image)
1641
+ >>> simple_transformed = simple_result['image']
1642
+ >>> # The result could look like:
1643
+ >>> # array([[4, 4, 4, 2, 2, 2],
1644
+ >>> # [4, 4, 4, 2, 2, 2],
1645
+ >>> # [4, 4, 4, 2, 2, 2],
1646
+ >>> # [3, 3, 3, 1, 1, 1],
1647
+ >>> # [3, 3, 3, 1, 1, 1],
1648
+ >>> # [3, 3, 3, 1, 1, 1]])
1649
+
1650
+ """
1651
+
1652
+ class InitSchema(BaseTransformInitSchema):
1653
+ grid: Annotated[tuple[int, int], AfterValidator(check_range_bounds(1, None))]
1654
+
1655
+ _targets = ALL_TARGETS
1656
+
1657
+ def __init__(
1658
+ self,
1659
+ grid: tuple[int, int] = (3, 3),
1660
+ p: float = 0.5,
1661
+ ):
1662
+ super().__init__(p=p)
1663
+ self.grid = grid
1664
+
1665
+ def apply(
1666
+ self,
1667
+ img: np.ndarray,
1668
+ tiles: np.ndarray,
1669
+ mapping: list[int],
1670
+ **params: Any,
1671
+ ) -> np.ndarray:
1672
+ """Apply the RandomGridShuffle transform to an image.
1673
+
1674
+ Args:
1675
+ img (np.ndarray): Image to be transformed.
1676
+ tiles (np.ndarray): Tiles to be transformed.
1677
+ mapping (list[int]): Mapping of the tiles.
1678
+ **params (Any): Additional parameters.
1679
+
1680
+ """
1681
+ return fgeometric.swap_tiles_on_image(img, tiles, mapping)
1682
+
1683
+ def apply_to_bboxes(
1684
+ self,
1685
+ bboxes: np.ndarray,
1686
+ tiles: np.ndarray,
1687
+ mapping: np.ndarray,
1688
+ **params: Any,
1689
+ ) -> np.ndarray:
1690
+ """Apply the RandomGridShuffle transform to bounding boxes.
1691
+
1692
+ Args:
1693
+ bboxes (np.ndarray): Bounding boxes to be transformed.
1694
+ tiles (np.ndarray): Tiles to be transformed.
1695
+ mapping (np.ndarray): Mapping of the tiles.
1696
+ **params (Any): Additional parameters.
1697
+
1698
+ """
1699
+ image_shape = params["shape"][:2]
1700
+ bboxes_denorm = denormalize_bboxes(bboxes, image_shape)
1701
+ processor = cast("BboxProcessor", self.get_processor("bboxes"))
1702
+ if processor is None:
1703
+ return bboxes
1704
+ bboxes_returned = fgeometric.bboxes_grid_shuffle(
1705
+ bboxes_denorm,
1706
+ tiles,
1707
+ mapping,
1708
+ image_shape,
1709
+ min_area=processor.params.min_area,
1710
+ min_visibility=processor.params.min_visibility,
1711
+ )
1712
+ return normalize_bboxes(bboxes_returned, image_shape)
1713
+
1714
+ def apply_to_keypoints(
1715
+ self,
1716
+ keypoints: np.ndarray,
1717
+ tiles: np.ndarray,
1718
+ mapping: np.ndarray,
1719
+ **params: Any,
1720
+ ) -> np.ndarray:
1721
+ """Apply the RandomGridShuffle transform to keypoints.
1722
+
1723
+ Args:
1724
+ keypoints (np.ndarray): Keypoints to be transformed.
1725
+ tiles (np.ndarray): Tiles to be transformed.
1726
+ mapping (np.ndarray): Mapping of the tiles.
1727
+ **params (Any): Additional parameters.
1728
+
1729
+ """
1730
+ return fgeometric.swap_tiles_on_keypoints(keypoints, tiles, mapping)
1731
+
1732
+ @batch_transform("spatial", has_batch_dim=True, has_depth_dim=False)
1733
+ def apply_to_images(self, images: np.ndarray, **params: Any) -> np.ndarray:
1734
+ """Apply the RandomGridShuffle transform to a batch of images.
1735
+
1736
+ Args:
1737
+ images (np.ndarray): Images to be transformed.
1738
+ **params (Any): Additional parameters.
1739
+
1740
+ """
1741
+ return self.apply(images, **params)
1742
+
1743
+ @batch_transform("spatial", has_batch_dim=False, has_depth_dim=True)
1744
+ def apply_to_volume(self, volume: np.ndarray, **params: Any) -> np.ndarray:
1745
+ """Apply the RandomGridShuffle transform to a volume.
1746
+
1747
+ Args:
1748
+ volume (np.ndarray): Volume to be transformed.
1749
+ **params (Any): Additional parameters.
1750
+
1751
+ """
1752
+ return self.apply(volume, **params)
1753
+
1754
+ @batch_transform("spatial", has_batch_dim=True, has_depth_dim=True)
1755
+ def apply_to_volumes(self, volumes: np.ndarray, **params: Any) -> np.ndarray:
1756
+ """Apply the RandomGridShuffle transform to a batch of volumes.
1757
+
1758
+ Args:
1759
+ volumes (np.ndarray): Volumes to be transformed.
1760
+ **params (Any): Additional parameters.
1761
+
1762
+ """
1763
+ return self.apply(volumes, **params)
1764
+
1765
+ @batch_transform("spatial", has_batch_dim=True, has_depth_dim=False)
1766
+ def apply_to_mask3d(self, mask3d: np.ndarray, **params: Any) -> np.ndarray:
1767
+ """Apply the RandomGridShuffle transform to a 3D mask.
1768
+
1769
+ Args:
1770
+ mask3d (np.ndarray): 3D mask to be transformed.
1771
+ **params (Any): Additional parameters.
1772
+
1773
+ """
1774
+ return self.apply(mask3d, **params)
1775
+
1776
+ def get_params_dependent_on_data(
1777
+ self,
1778
+ params: dict[str, Any],
1779
+ data: dict[str, Any],
1780
+ ) -> dict[str, np.ndarray]:
1781
+ """Get the parameters dependent on the data.
1782
+
1783
+ Args:
1784
+ params (dict[str, Any]): Parameters.
1785
+ data (dict[str, Any]): Data.
1786
+
1787
+ Returns:
1788
+ dict[str, np.ndarray]: Parameters.
1789
+
1790
+ """
1791
+ image_shape = params["shape"][:2]
1792
+
1793
+ original_tiles = fgeometric.split_uniform_grid(
1794
+ image_shape,
1795
+ self.grid,
1796
+ self.random_generator,
1797
+ )
1798
+ shape_groups = fgeometric.create_shape_groups(original_tiles)
1799
+ mapping = fgeometric.shuffle_tiles_within_shape_groups(
1800
+ shape_groups,
1801
+ self.random_generator,
1802
+ )
1803
+
1804
+ return {"tiles": original_tiles, "mapping": mapping}
1805
+
1806
+
1807
+ class Morphological(DualTransform):
1808
+ """Apply a morphological operation (dilation or erosion) to an image,
1809
+ with particular value for enhancing document scans.
1810
+
1811
+ Morphological operations modify the structure of the image.
1812
+ Dilation expands the white (foreground) regions in a binary or grayscale image, while erosion shrinks them.
1813
+ These operations are beneficial in document processing, for example:
1814
+ - Dilation helps in closing up gaps within text or making thin lines thicker,
1815
+ enhancing legibility for OCR (Optical Character Recognition).
1816
+ - Erosion can remove small white noise and detach connected objects,
1817
+ making the structure of larger objects more pronounced.
1818
+
1819
+ Args:
1820
+ scale (int or tuple/list of int): Specifies the size of the structuring element (kernel) used for the operation.
1821
+ - If an integer is provided, a square kernel of that size will be used.
1822
+ - If a tuple or list is provided, it should contain two integers representing the minimum
1823
+ and maximum sizes for the dilation kernel.
1824
+ operation (Literal["erosion", "dilation"]): The morphological operation to apply.
1825
+ Default is 'dilation'.
1826
+ p (float, optional): The probability of applying this transformation. Default is 0.5.
1827
+
1828
+ Targets:
1829
+ image, mask, keypoints, bboxes, volume, mask3d
1830
+
1831
+ Image types:
1832
+ uint8, float32
1833
+
1834
+ References:
1835
+ Nougat: https://github.com/facebookresearch/nougat
1836
+
1837
+ Examples:
1838
+ >>> import numpy as np
1839
+ >>> import albumentations as A
1840
+ >>> import cv2
1841
+ >>>
1842
+ >>> # Create a document-like binary image with text
1843
+ >>> image = np.ones((200, 500), dtype=np.uint8) * 255 # White background
1844
+ >>> # Add some "text" (black pixels)
1845
+ >>> cv2.putText(image, "Document Text", (50, 100), cv2.FONT_HERSHEY_SIMPLEX, 1, 0, 2)
1846
+ >>> # Add some "noise" (small black dots)
1847
+ >>> for _ in range(50):
1848
+ ... x, y = np.random.randint(0, image.shape[1]), np.random.randint(0, image.shape[0])
1849
+ ... cv2.circle(image, (x, y), 1, 0, -1)
1850
+ >>>
1851
+ >>> # Create a mask representing text regions
1852
+ >>> mask = np.zeros_like(image)
1853
+ >>> mask[image < 128] = 1 # Binary mask where text exists
1854
+ >>>
1855
+ >>> # Example 1: Apply dilation to thicken text and fill gaps
1856
+ >>> dilation_transform = A.Morphological(
1857
+ ... scale=3, # Size of the structuring element
1858
+ ... operation="dilation", # Expand white regions (or black if inverted)
1859
+ ... p=1.0 # Always apply
1860
+ ... )
1861
+ >>> result = dilation_transform(image=image, mask=mask)
1862
+ >>> dilated_image = result['image'] # Text is thicker, gaps are filled
1863
+ >>> dilated_mask = result['mask'] # Mask is expanded around text regions
1864
+ >>>
1865
+ >>> # Example 2: Apply erosion to thin text or remove noise
1866
+ >>> erosion_transform = A.Morphological(
1867
+ ... scale=(2, 3), # Random kernel size between 2 and 3
1868
+ ... operation="erosion", # Shrink white regions (or expand black if inverted)
1869
+ ... p=1.0 # Always apply
1870
+ ... )
1871
+ >>> result = erosion_transform(image=image, mask=mask)
1872
+ >>> eroded_image = result['image'] # Text is thinner, small noise may be removed
1873
+ >>> eroded_mask = result['mask'] # Mask is contracted around text regions
1874
+ >>>
1875
+ >>> # Note: For document processing, dilation often helps enhance readability for OCR
1876
+ >>> # while erosion can help remove noise or separate connected components
1877
+
1878
+ """
1879
+
1880
+ _targets = ALL_TARGETS
1881
+
1882
+ class InitSchema(BaseTransformInitSchema):
1883
+ scale: OnePlusIntRangeType
1884
+ operation: Literal["erosion", "dilation"]
1885
+
1886
+ def __init__(
1887
+ self,
1888
+ scale: tuple[int, int] | int = (2, 3),
1889
+ operation: Literal["erosion", "dilation"] = "dilation",
1890
+ p: float = 0.5,
1891
+ ):
1892
+ super().__init__(p=p)
1893
+ self.scale = cast("tuple[int, int]", scale)
1894
+ self.operation = operation
1895
+
1896
+ def apply(
1897
+ self,
1898
+ img: np.ndarray,
1899
+ kernel: tuple[int, int],
1900
+ **params: Any,
1901
+ ) -> np.ndarray:
1902
+ """Apply the Morphological transform to the input image.
1903
+
1904
+ Args:
1905
+ img (np.ndarray): The input image to apply the Morphological transform to.
1906
+ kernel (tuple[int, int]): The structuring element (kernel) used for the operation.
1907
+ **params (Any): Additional parameters for the transform.
1908
+
1909
+ """
1910
+ return fgeometric.morphology(img, kernel, self.operation)
1911
+
1912
+ def apply_to_bboxes(
1913
+ self,
1914
+ bboxes: np.ndarray,
1915
+ kernel: tuple[int, int],
1916
+ **params: Any,
1917
+ ) -> np.ndarray:
1918
+ """Apply the Morphological transform to the input bounding boxes.
1919
+
1920
+ Args:
1921
+ bboxes (np.ndarray): The input bounding boxes to apply the Morphological transform to.
1922
+ kernel (tuple[int, int]): The structuring element (kernel) used for the operation.
1923
+ **params (Any): Additional parameters for the transform.
1924
+
1925
+ """
1926
+ image_shape = params["shape"]
1927
+
1928
+ denormalized_boxes = denormalize_bboxes(bboxes, image_shape)
1929
+
1930
+ result = fgeometric.bboxes_morphology(
1931
+ denormalized_boxes,
1932
+ kernel,
1933
+ self.operation,
1934
+ image_shape,
1935
+ )
1936
+
1937
+ return normalize_bboxes(result, image_shape)
1938
+
1939
+ def apply_to_keypoints(
1940
+ self,
1941
+ keypoints: np.ndarray,
1942
+ **params: Any,
1943
+ ) -> np.ndarray:
1944
+ """Apply the Morphological transform to the input keypoints.
1945
+
1946
+ Args:
1947
+ keypoints (np.ndarray): The input keypoints to apply the Morphological transform to.
1948
+ **params (Any): Additional parameters for the transform.
1949
+
1950
+ """
1951
+ return keypoints
1952
+
1953
+ def get_params(self) -> dict[str, float]:
1954
+ """Generate parameters for the Morphological transform.
1955
+
1956
+ Returns:
1957
+ dict[str, float]: The parameters of the transform.
1958
+
1959
+ """
1960
+ return {
1961
+ "kernel": cv2.getStructuringElement(cv2.MORPH_ELLIPSE, self.scale),
1962
+ }