nrtk-albumentations 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nrtk-albumentations might be problematic. Click here for more details.
- albumentations/__init__.py +21 -0
- albumentations/augmentations/__init__.py +23 -0
- albumentations/augmentations/blur/__init__.py +0 -0
- albumentations/augmentations/blur/functional.py +438 -0
- albumentations/augmentations/blur/transforms.py +1633 -0
- albumentations/augmentations/crops/__init__.py +0 -0
- albumentations/augmentations/crops/functional.py +494 -0
- albumentations/augmentations/crops/transforms.py +3647 -0
- albumentations/augmentations/dropout/__init__.py +0 -0
- albumentations/augmentations/dropout/channel_dropout.py +134 -0
- albumentations/augmentations/dropout/coarse_dropout.py +567 -0
- albumentations/augmentations/dropout/functional.py +1017 -0
- albumentations/augmentations/dropout/grid_dropout.py +166 -0
- albumentations/augmentations/dropout/mask_dropout.py +274 -0
- albumentations/augmentations/dropout/transforms.py +461 -0
- albumentations/augmentations/dropout/xy_masking.py +186 -0
- albumentations/augmentations/geometric/__init__.py +0 -0
- albumentations/augmentations/geometric/distortion.py +1238 -0
- albumentations/augmentations/geometric/flip.py +752 -0
- albumentations/augmentations/geometric/functional.py +4151 -0
- albumentations/augmentations/geometric/pad.py +676 -0
- albumentations/augmentations/geometric/resize.py +956 -0
- albumentations/augmentations/geometric/rotate.py +864 -0
- albumentations/augmentations/geometric/transforms.py +1962 -0
- albumentations/augmentations/mixing/__init__.py +0 -0
- albumentations/augmentations/mixing/domain_adaptation.py +787 -0
- albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
- albumentations/augmentations/mixing/functional.py +878 -0
- albumentations/augmentations/mixing/transforms.py +832 -0
- albumentations/augmentations/other/__init__.py +0 -0
- albumentations/augmentations/other/lambda_transform.py +180 -0
- albumentations/augmentations/other/type_transform.py +261 -0
- albumentations/augmentations/pixel/__init__.py +0 -0
- albumentations/augmentations/pixel/functional.py +4226 -0
- albumentations/augmentations/pixel/transforms.py +7556 -0
- albumentations/augmentations/spectrogram/__init__.py +0 -0
- albumentations/augmentations/spectrogram/transform.py +220 -0
- albumentations/augmentations/text/__init__.py +0 -0
- albumentations/augmentations/text/functional.py +272 -0
- albumentations/augmentations/text/transforms.py +299 -0
- albumentations/augmentations/transforms3d/__init__.py +0 -0
- albumentations/augmentations/transforms3d/functional.py +393 -0
- albumentations/augmentations/transforms3d/transforms.py +1422 -0
- albumentations/augmentations/utils.py +249 -0
- albumentations/core/__init__.py +0 -0
- albumentations/core/bbox_utils.py +920 -0
- albumentations/core/composition.py +1885 -0
- albumentations/core/hub_mixin.py +299 -0
- albumentations/core/keypoints_utils.py +521 -0
- albumentations/core/label_manager.py +339 -0
- albumentations/core/pydantic.py +239 -0
- albumentations/core/serialization.py +352 -0
- albumentations/core/transforms_interface.py +976 -0
- albumentations/core/type_definitions.py +127 -0
- albumentations/core/utils.py +605 -0
- albumentations/core/validation.py +129 -0
- albumentations/pytorch/__init__.py +1 -0
- albumentations/pytorch/transforms.py +189 -0
- nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
- nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
- nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
- nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1885 @@
|
|
|
1
|
+
"""Module for composing multiple transforms into augmentation pipelines.
|
|
2
|
+
|
|
3
|
+
This module provides classes for combining multiple transformations into cohesive
|
|
4
|
+
augmentation pipelines. It includes various composition strategies such as sequential
|
|
5
|
+
application, random selection, and conditional application of transforms. These
|
|
6
|
+
composition classes handle the coordination between different transforms, ensuring
|
|
7
|
+
proper data flow and maintaining consistent behavior across the augmentation pipeline.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import random
|
|
13
|
+
import warnings
|
|
14
|
+
from collections import defaultdict
|
|
15
|
+
from collections.abc import Iterator, Sequence
|
|
16
|
+
from typing import Any, Union, cast
|
|
17
|
+
|
|
18
|
+
import cv2
|
|
19
|
+
import numpy as np
|
|
20
|
+
|
|
21
|
+
from .bbox_utils import BboxParams, BboxProcessor
|
|
22
|
+
from .hub_mixin import HubMixin
|
|
23
|
+
from .keypoints_utils import KeypointParams, KeypointsProcessor
|
|
24
|
+
from .serialization import (
|
|
25
|
+
SERIALIZABLE_REGISTRY,
|
|
26
|
+
Serializable,
|
|
27
|
+
get_shortest_class_fullname,
|
|
28
|
+
instantiate_nonserializable,
|
|
29
|
+
)
|
|
30
|
+
from .transforms_interface import BasicTransform
|
|
31
|
+
from .utils import DataProcessor, format_args, get_shape
|
|
32
|
+
|
|
33
|
+
# Names exported by ``from <module> import *``.
__all__ = [
    "BaseCompose",
    "BboxParams",
    "Compose",
    "KeypointParams",
    "OneOf",
    "OneOrOther",
    "RandomOrder",
    "ReplayCompose",
    "SelectiveChannelTransform",
    "Sequential",
    "SomeOf",
]

NUM_ONEOF_TRANSFORMS = 2
# Number of spaces added per nesting level by ``indented_repr``.
REPR_INDENT_STEP = 2

# A pipeline element is either a single transform or a nested composition.
TransformType = Union[BasicTransform, "BaseCompose"]
TransformsSeqType = list[TransformType]

AVAILABLE_KEYS = (
    "image",
    "mask",
    "masks",
    "bboxes",
    "keypoints",
    "volume",
    "volumes",
    "mask3d",
    "masks3d",
)

# Mask-like targets: single/multiple 2D masks, then single/multiple 3D masks.
MASK_KEYS = ("mask", "masks", "mask3d", "masks3d")

# Keys related to image data
IMAGE_KEYS = {"image", "images"}
CHECKED_SINGLE = {"image", "mask"}
CHECKED_MULTI = {"masks", "images", "volumes", "masks3d"}
CHECK_BBOX_PARAM = {"bboxes"}
CHECK_KEYPOINTS_PARAM = {"keypoints"}

# Keys related to volumetric data
VOLUME_KEYS = {"volume", "volumes"}
CHECKED_VOLUME = {"volume"}
CHECKED_VOLUMES = {"volumes"}
CHECKED_MASK3D = {"mask3d"}
CHECKED_MASKS3D = {"masks3d"}
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class BaseCompose(Serializable):
|
|
76
|
+
"""Base class for composing multiple transforms together.
|
|
77
|
+
|
|
78
|
+
This class serves as a foundation for creating compositions of transforms
|
|
79
|
+
in the Albumentations library. It provides basic functionality for
|
|
80
|
+
managing a sequence of transforms and applying them to data.
|
|
81
|
+
|
|
82
|
+
The class supports dynamic pipeline modification after initialization using
|
|
83
|
+
mathematical operators:
|
|
84
|
+
- Addition (`+`): Add transforms to the end of the pipeline
|
|
85
|
+
- Right addition (`__radd__`): Add transforms to the beginning of the pipeline
|
|
86
|
+
- Subtraction (`-`): Remove transforms by class from the pipeline
|
|
87
|
+
|
|
88
|
+
Attributes:
|
|
89
|
+
transforms (List[TransformType]): A list of transforms to be applied.
|
|
90
|
+
p (float): Probability of applying the compose. Should be in the range [0, 1].
|
|
91
|
+
replay_mode (bool): If True, the compose is in replay mode.
|
|
92
|
+
_additional_targets (Dict[str, str]): Additional targets for transforms.
|
|
93
|
+
_available_keys (Set[str]): Set of available keys for data.
|
|
94
|
+
processors (Dict[str, Union[BboxProcessor, KeypointsProcessor]]): Processors for specific data types.
|
|
95
|
+
|
|
96
|
+
Args:
|
|
97
|
+
transforms (TransformsSeqType): A sequence of transforms to compose.
|
|
98
|
+
p (float): Probability of applying the compose.
|
|
99
|
+
|
|
100
|
+
Raises:
|
|
101
|
+
ValueError: If an invalid additional target is specified.
|
|
102
|
+
|
|
103
|
+
Note:
|
|
104
|
+
- Subclasses should implement the __call__ method to define how
|
|
105
|
+
the composition is applied to data.
|
|
106
|
+
- The class supports serialization and deserialization of transforms.
|
|
107
|
+
- It provides methods for adding targets, setting deterministic behavior,
|
|
108
|
+
and checking data validity post-transform.
|
|
109
|
+
- All compose classes support pipeline modification operators:
|
|
110
|
+
- `compose + transform` adds individual transform(s) to the end
|
|
111
|
+
- `transform + compose` adds individual transform(s) to the beginning
|
|
112
|
+
- `compose - TransformClass` removes transforms by class type
|
|
113
|
+
- Only BasicTransform instances (not BaseCompose) can be added
|
|
114
|
+
- All operator operations return new instances without modifying the original.
|
|
115
|
+
|
|
116
|
+
Examples:
|
|
117
|
+
>>> import albumentations as A
|
|
118
|
+
>>> # Create base pipeline
|
|
119
|
+
>>> compose = A.Compose([A.HorizontalFlip(p=1.0)])
|
|
120
|
+
>>>
|
|
121
|
+
>>> # Add transforms using operators
|
|
122
|
+
>>> extended = compose + A.VerticalFlip(p=1.0) # Append
|
|
123
|
+
>>> extended = compose + [A.Blur(), A.Rotate()] # Append multiple
|
|
124
|
+
>>> extended = A.RandomCrop(256, 256) + compose # Prepend
|
|
125
|
+
>>>
|
|
126
|
+
>>> # Remove transforms by class
|
|
127
|
+
>>> compose = A.Compose([A.HorizontalFlip(p=0.5), A.VerticalFlip(p=1.0)])
|
|
128
|
+
>>> reduced = compose - A.HorizontalFlip # Remove by class
|
|
129
|
+
|
|
130
|
+
"""
|
|
131
|
+
|
|
132
|
+
_transforms_dict: dict[int, BasicTransform] | None = None
|
|
133
|
+
check_each_transform: tuple[DataProcessor, ...] | None = None
|
|
134
|
+
main_compose: bool = True
|
|
135
|
+
|
|
136
|
+
def __init__(
    self,
    transforms: TransformsSeqType,
    p: float,
    mask_interpolation: int | None = None,
    seed: int | None = None,
    save_applied_params: bool = False,
    **kwargs: Any,
):
    """Store the pipeline and initialise the composition's bookkeeping state.

    Args:
        transforms (TransformsSeqType): Transforms to compose. A single transform
            or compose is accepted with a warning and wrapped into a list.
        p (float): Probability of applying the compose.
        mask_interpolation (int | None): Forwarded to ``set_mask_interpolation``.
        seed (int | None): Forwarded to ``set_random_seed``.
        save_applied_params (bool): Stored on the instance as-is.
        **kwargs (Any): Accepted but not used by this initializer.

    """
    passed_single_transform = isinstance(transforms, (BaseCompose, BasicTransform))
    if passed_single_transform:
        warnings.warn(
            "transforms is single transform, but a sequence is expected! Transform will be wrapped into list.",
            stacklevel=2,
        )
        transforms = [transforms]

    self.transforms = transforms
    self.p = p
    self.replay_mode = False

    # These containers must exist before _set_keys() runs, because it reads them.
    self._additional_targets: dict[str, str] = {}
    self._available_keys: set[str] = set()
    self.processors: dict[str, BboxProcessor | KeypointsProcessor] = {}

    self._set_keys()
    self.set_mask_interpolation(mask_interpolation)
    self.set_random_seed(seed)

    self.save_applied_params = save_applied_params
|
|
163
|
+
|
|
164
|
+
def _track_transform_params(self, transform: TransformType, data: dict[str, Any]) -> None:
|
|
165
|
+
"""Track transform parameters if tracking is enabled."""
|
|
166
|
+
if "applied_transforms" in data and hasattr(transform, "params") and transform.params:
|
|
167
|
+
data["applied_transforms"].append((transform.__class__.__name__, transform.params.copy()))
|
|
168
|
+
|
|
169
|
+
def set_random_state(
    self,
    random_generator: np.random.Generator,
    py_random: random.Random,
) -> None:
    """Adopt the given generators and hand the same objects to every child.

    Args:
        random_generator (np.random.Generator): numpy random generator to use
        py_random (random.Random): python random generator to use

    """
    self.random_generator = random_generator
    self.py_random = py_random

    # Children receive the very same generator objects, not copies.
    children = (t for t in self.transforms if isinstance(t, (BasicTransform, BaseCompose)))
    for child in children:
        child.set_random_state(random_generator, py_random)
|
|
188
|
+
|
|
189
|
+
def set_random_seed(self, seed: int | None) -> None:
    """Create fresh generators from ``seed`` and reseed every child transform.

    Args:
        seed (int | None): Random seed to use

    """
    # Keep the seed exactly as given.
    self.seed = seed

    # The base class uses the seed unchanged for both generators.
    self.random_generator = np.random.default_rng(seed)
    self.py_random = random.Random(seed)

    # Each child is reseeded with the same value.
    children = (t for t in self.transforms if isinstance(t, (BasicTransform, BaseCompose)))
    for child in children:
        child.set_random_seed(seed)
|
|
207
|
+
|
|
208
|
+
def set_mask_interpolation(self, mask_interpolation: int | None) -> None:
    """Remember the mask interpolation mode and push it down the pipeline.

    Args:
        mask_interpolation (int | None): OpenCV interpolation flag to use for mask transforms.
            If None, default interpolation for masks will be used.

    """
    self.mask_interpolation = mask_interpolation
    pipeline = self.transforms
    self._set_mask_interpolation_recursive(pipeline)
|
|
218
|
+
|
|
219
|
+
def _set_mask_interpolation_recursive(self, transforms: TransformsSeqType) -> None:
|
|
220
|
+
for transform in transforms:
|
|
221
|
+
if isinstance(transform, BasicTransform):
|
|
222
|
+
if hasattr(transform, "mask_interpolation") and self.mask_interpolation is not None:
|
|
223
|
+
transform.mask_interpolation = self.mask_interpolation
|
|
224
|
+
elif isinstance(transform, BaseCompose):
|
|
225
|
+
transform.set_mask_interpolation(self.mask_interpolation)
|
|
226
|
+
|
|
227
|
+
def __iter__(self) -> Iterator[TransformType]:
|
|
228
|
+
return iter(self.transforms)
|
|
229
|
+
|
|
230
|
+
def __len__(self) -> int:
|
|
231
|
+
return len(self.transforms)
|
|
232
|
+
|
|
233
|
+
def __call__(self, *args: Any, **data: Any) -> dict[str, Any]:
|
|
234
|
+
"""Apply transforms.
|
|
235
|
+
|
|
236
|
+
Args:
|
|
237
|
+
*args (Any): Positional arguments are not supported.
|
|
238
|
+
**data (Any): Named parameters with data to transform.
|
|
239
|
+
|
|
240
|
+
Returns:
|
|
241
|
+
dict[str, Any]: Transformed data.
|
|
242
|
+
|
|
243
|
+
Raises:
|
|
244
|
+
NotImplementedError: This method must be implemented by subclasses.
|
|
245
|
+
|
|
246
|
+
"""
|
|
247
|
+
raise NotImplementedError
|
|
248
|
+
|
|
249
|
+
def __getitem__(self, item: int) -> TransformType:
|
|
250
|
+
return self.transforms[item]
|
|
251
|
+
|
|
252
|
+
def __repr__(self) -> str:
|
|
253
|
+
return self.indented_repr()
|
|
254
|
+
|
|
255
|
+
@property
def additional_targets(self) -> dict[str, str]:
    """Mapping of additional target names registered on this composition.

    Returns:
        dict[str, str]: The internal ``_additional_targets`` dictionary itself
            (not a copy).

    """
    targets = self._additional_targets
    return targets
|
|
264
|
+
|
|
265
|
+
@property
def available_keys(self) -> set[str]:
    """Keys collected for this composition by ``_set_keys``.

    Returns:
        set[str]: The internal ``_available_keys`` set itself (not a copy).

    """
    keys = self._available_keys
    return keys
|
|
274
|
+
|
|
275
|
+
def indented_repr(self, indent: int = REPR_INDENT_STEP) -> str:
    """Build a multi-line representation with one child transform per line.

    Args:
        indent (int): Number of spaces placed before each child line.
            Default: REPR_INDENT_STEP.

    Returns:
        str: Class name, the indented children, then the remaining arguments.

    """
    # Dunder-prefixed entries and the transforms list are rendered separately.
    args = {
        key: value
        for key, value in self.to_dict_private().items()
        if not key.startswith("__") and key != "transforms"
    }

    lines = [self.__class__.__name__ + "(["]
    child_prefix = " " * indent
    for child in self.transforms:
        if hasattr(child, "indented_repr"):
            child_text = child.indented_repr(indent + REPR_INDENT_STEP)
        else:
            child_text = repr(child)
        lines.append(child_prefix + child_text + ",")

    # The closing bracket sits one indentation step shallower than the children.
    lines.append(" " * (indent - REPR_INDENT_STEP) + f"], {format_args(args)})")
    return "\n".join(lines)
|
|
293
|
+
|
|
294
|
+
@classmethod
def get_class_fullname(cls) -> str:
    """Return the class name as produced by ``get_shortest_class_fullname``.

    Returns:
        str: The shortest class fullname.

    """
    fullname = get_shortest_class_fullname(cls)
    return fullname
|
|
303
|
+
|
|
304
|
+
@classmethod
def is_serializable(cls) -> bool:
    """Report whether the class is serializable.

    Returns:
        bool: Always True in this implementation.

    """
    serializable = True
    return serializable
|
|
313
|
+
|
|
314
|
+
def to_dict_private(self) -> dict[str, Any]:
    """Serialize the composition and, recursively, each child transform.

    Returns:
        dict[str, Any]: Dictionary with the class fullname, ``p`` and the list
            of each child's own ``to_dict_private()`` result.

    """
    serialized: dict[str, Any] = {
        "__class_fullname__": self.get_class_fullname(),
        "p": self.p,
    }
    serialized["transforms"] = [child.to_dict_private() for child in self.transforms]
    return serialized
|
|
326
|
+
|
|
327
|
+
def get_dict_with_id(self) -> dict[str, Any]:
    """Describe the composition together with ``id(self)``, recursing into children.

    Returns:
        dict[str, Any]: Dictionary with the class fullname, the object's ``id``,
            a ``"params"`` entry set to None, and each child's own
            ``get_dict_with_id()`` result.

    """
    described: dict[str, Any] = {
        "__class_fullname__": self.get_class_fullname(),
        "id": id(self),
        "params": None,
    }
    described["transforms"] = [child.get_dict_with_id() for child in self.transforms]
    return described
|
|
340
|
+
|
|
341
|
+
def add_targets(self, additional_targets: dict[str, str] | None) -> None:
    """Register additional targets here, on every transform and on every processor.

    Args:
        additional_targets (dict[str, str] | None): Dict of name -> type mapping for additional targets.
            If None, no additional targets will be added.

    Raises:
        ValueError: If a name is already registered with a different type.

    """
    if additional_targets:
        registered = self._additional_targets

        # Reject every conflicting name before anything is modified.
        for name, target_type in additional_targets.items():
            conflicts = name in registered and target_type != registered[name]
            if conflicts:
                raise ValueError(
                    f"Trying to overwrite existed additional targets. "
                    f"Key={name} Exists={registered[name]} New value: {target_type}",
                )

        registered.update(additional_targets)
        for child in self.transforms:
            child.add_targets(additional_targets)
        for processor in self.processors.values():
            processor.add_targets(additional_targets)

    # The key set is refreshed even when nothing was added.
    self._set_keys()
|
|
362
|
+
|
|
363
|
+
def _set_keys(self) -> None:
|
|
364
|
+
"""Set _available_keys"""
|
|
365
|
+
self._available_keys.update(self._additional_targets.keys())
|
|
366
|
+
for t in self.transforms:
|
|
367
|
+
self._available_keys.update(t.available_keys)
|
|
368
|
+
if hasattr(t, "targets_as_params"):
|
|
369
|
+
self._available_keys.update(t.targets_as_params)
|
|
370
|
+
if self.processors:
|
|
371
|
+
self._available_keys.update(["labels"])
|
|
372
|
+
for proc in self.processors.values():
|
|
373
|
+
if proc.default_data_name not in self._available_keys: # if no transform to process this data
|
|
374
|
+
warnings.warn(
|
|
375
|
+
f"Got processor for {proc.default_data_name}, but no transform to process it.",
|
|
376
|
+
stacklevel=2,
|
|
377
|
+
)
|
|
378
|
+
self._available_keys.update(proc.data_fields)
|
|
379
|
+
if proc.params.label_fields:
|
|
380
|
+
self._available_keys.update(proc.params.label_fields)
|
|
381
|
+
|
|
382
|
+
def set_deterministic(self, flag: bool, save_key: str = "replay") -> None:
    """Forward the deterministic-mode setting to every transform in the pipeline.

    Args:
        flag (bool): Whether to enable deterministic mode.
        save_key (str): Key to save replay parameters. Default: "replay".

    """
    for child in self.transforms:
        child.set_deterministic(flag, save_key)
|
|
392
|
+
|
|
393
|
+
def check_data_post_transform(self, data: dict[str, Any]) -> dict[str, Any]:
    """Run each configured processor's ``filter`` over the entries it handles.

    An entry is handled by a processor when its key is one of the processor's
    ``data_fields``, or when the key is an additional target whose mapped type
    is one of those fields. ``data`` is modified in place.

    Args:
        data (dict[str, Any]): Dictionary containing transformed data

    Returns:
        dict[str, Any]: The same dictionary, with handled entries filtered.

    """
    processors = self.check_each_transform
    if not processors:
        return data

    shape = get_shape(data)
    aliases = self._additional_targets

    for processor in processors:
        for data_name, data_value in data.items():
            handled = data_name in processor.data_fields
            if not handled and data_name in aliases:
                handled = aliases[data_name] in processor.data_fields
            if handled:
                # Existing keys are reassigned only, so iterating the dict stays valid.
                data[data_name] = processor.filter(data_value, shape)
    return data
|
|
414
|
+
|
|
415
|
+
def _validate_transforms(self, transforms: list[Any]) -> None:
|
|
416
|
+
"""Validate that all elements are BasicTransform instances.
|
|
417
|
+
|
|
418
|
+
Args:
|
|
419
|
+
transforms: List of objects to validate
|
|
420
|
+
|
|
421
|
+
Raises:
|
|
422
|
+
TypeError: If any element is not a BasicTransform instance
|
|
423
|
+
|
|
424
|
+
"""
|
|
425
|
+
for t in transforms:
|
|
426
|
+
if not isinstance(t, BasicTransform):
|
|
427
|
+
raise TypeError(
|
|
428
|
+
f"All elements must be instances of BasicTransform, got {type(t).__name__}",
|
|
429
|
+
)
|
|
430
|
+
|
|
431
|
+
def _combine_transforms(self, other: TransformType | TransformsSeqType, *, prepend: bool = False) -> BaseCompose:
|
|
432
|
+
"""Combine transforms with the current compose.
|
|
433
|
+
|
|
434
|
+
Args:
|
|
435
|
+
other: Transform or sequence of transforms to combine
|
|
436
|
+
prepend: If True, prepend other to the beginning; if False, append to the end
|
|
437
|
+
|
|
438
|
+
Returns:
|
|
439
|
+
BaseCompose: New compose instance with combined transforms
|
|
440
|
+
|
|
441
|
+
Raises:
|
|
442
|
+
TypeError: If other is not a valid transform or sequence of transforms
|
|
443
|
+
|
|
444
|
+
"""
|
|
445
|
+
if isinstance(other, (list, tuple)):
|
|
446
|
+
self._validate_transforms(other)
|
|
447
|
+
other_list = list(other)
|
|
448
|
+
else:
|
|
449
|
+
self._validate_transforms([other])
|
|
450
|
+
other_list = [other]
|
|
451
|
+
|
|
452
|
+
new_transforms = [*other_list, *list(self.transforms)] if prepend else [*list(self.transforms), *other_list]
|
|
453
|
+
|
|
454
|
+
return self._create_new_instance(new_transforms)
|
|
455
|
+
|
|
456
|
+
def __add__(self, other: TransformType | TransformsSeqType) -> BaseCompose:
|
|
457
|
+
"""Add transform(s) to the end of this compose.
|
|
458
|
+
|
|
459
|
+
Args:
|
|
460
|
+
other: Transform or sequence of transforms to append
|
|
461
|
+
|
|
462
|
+
Returns:
|
|
463
|
+
BaseCompose: New compose instance with transforms appended
|
|
464
|
+
|
|
465
|
+
Raises:
|
|
466
|
+
TypeError: If other is not a valid transform or sequence of transforms
|
|
467
|
+
|
|
468
|
+
Examples:
|
|
469
|
+
>>> new_compose = compose + A.HorizontalFlip()
|
|
470
|
+
>>> new_compose = compose + [A.HorizontalFlip(), A.VerticalFlip()]
|
|
471
|
+
|
|
472
|
+
"""
|
|
473
|
+
return self._combine_transforms(other, prepend=False)
|
|
474
|
+
|
|
475
|
+
def __radd__(self, other: TransformType | TransformsSeqType) -> BaseCompose:
|
|
476
|
+
"""Add transform(s) to the beginning of this compose.
|
|
477
|
+
|
|
478
|
+
Args:
|
|
479
|
+
other: Transform or sequence of transforms to prepend
|
|
480
|
+
|
|
481
|
+
Returns:
|
|
482
|
+
BaseCompose: New compose instance with transforms prepended
|
|
483
|
+
|
|
484
|
+
Raises:
|
|
485
|
+
TypeError: If other is not a valid transform or sequence of transforms
|
|
486
|
+
|
|
487
|
+
Examples:
|
|
488
|
+
>>> new_compose = A.HorizontalFlip() + compose
|
|
489
|
+
>>> new_compose = [A.HorizontalFlip(), A.VerticalFlip()] + compose
|
|
490
|
+
|
|
491
|
+
"""
|
|
492
|
+
return self._combine_transforms(other, prepend=True)
|
|
493
|
+
|
|
494
|
+
def __sub__(self, other: type[BasicTransform]) -> BaseCompose:
|
|
495
|
+
"""Remove transform from this compose by class type.
|
|
496
|
+
|
|
497
|
+
Removes the first transform in the compose that matches the provided transform class.
|
|
498
|
+
|
|
499
|
+
Args:
|
|
500
|
+
other: Transform class to remove (e.g., A.HorizontalFlip)
|
|
501
|
+
|
|
502
|
+
Returns:
|
|
503
|
+
BaseCompose: New compose instance with transform removed
|
|
504
|
+
|
|
505
|
+
Raises:
|
|
506
|
+
TypeError: If other is not a BasicTransform class
|
|
507
|
+
ValueError: If no transform of that type is found in the compose
|
|
508
|
+
|
|
509
|
+
Note:
|
|
510
|
+
If multiple transforms of the same type exist in the compose,
|
|
511
|
+
only the first occurrence will be removed.
|
|
512
|
+
|
|
513
|
+
Examples:
|
|
514
|
+
>>> # Remove by transform class
|
|
515
|
+
>>> new_compose = compose - A.HorizontalFlip
|
|
516
|
+
>>>
|
|
517
|
+
>>> # With duplicates - only first occurrence removed
|
|
518
|
+
>>> compose = A.Compose([A.HorizontalFlip(p=0.5), A.VerticalFlip(), A.HorizontalFlip(p=1.0)])
|
|
519
|
+
>>> result = compose - A.HorizontalFlip # Removes first HorizontalFlip (p=0.5)
|
|
520
|
+
>>> len(result.transforms) # 2 (VerticalFlip and second HorizontalFlip remain)
|
|
521
|
+
|
|
522
|
+
"""
|
|
523
|
+
# Validate that other is a BasicTransform class
|
|
524
|
+
if not (isinstance(other, type) and issubclass(other, BasicTransform)):
|
|
525
|
+
raise TypeError(
|
|
526
|
+
f"Can only remove BasicTransform classes, got {type(other).__name__}",
|
|
527
|
+
)
|
|
528
|
+
|
|
529
|
+
# Find first transform of matching class
|
|
530
|
+
new_transforms = list(self.transforms)
|
|
531
|
+
for i, transform in enumerate(new_transforms):
|
|
532
|
+
if type(transform) is other:
|
|
533
|
+
new_transforms.pop(i)
|
|
534
|
+
return self._create_new_instance(new_transforms)
|
|
535
|
+
|
|
536
|
+
# No matching transform found
|
|
537
|
+
class_name = other.__name__
|
|
538
|
+
raise ValueError(f"No transform of type {class_name} found in the compose pipeline")
|
|
539
|
+
|
|
540
|
+
def _create_new_instance(self, new_transforms: TransformsSeqType) -> BaseCompose:
|
|
541
|
+
"""Create a new instance of the same class with new transforms.
|
|
542
|
+
|
|
543
|
+
Args:
|
|
544
|
+
new_transforms: List of transforms for the new instance
|
|
545
|
+
|
|
546
|
+
Returns:
|
|
547
|
+
BaseCompose: New instance of the same class
|
|
548
|
+
|
|
549
|
+
"""
|
|
550
|
+
# Get current instance parameters
|
|
551
|
+
init_params = self._get_init_params()
|
|
552
|
+
init_params["transforms"] = new_transforms
|
|
553
|
+
|
|
554
|
+
# Create new instance
|
|
555
|
+
new_instance = self.__class__(**init_params)
|
|
556
|
+
|
|
557
|
+
# Copy random state from original instance to new instance
|
|
558
|
+
if hasattr(self, "random_generator") and hasattr(self, "py_random"):
|
|
559
|
+
new_instance.set_random_state(self.random_generator, self.py_random)
|
|
560
|
+
|
|
561
|
+
return new_instance
|
|
562
|
+
|
|
563
|
+
def _get_init_params(self) -> dict[str, Any]:
|
|
564
|
+
"""Get parameters needed to recreate this instance.
|
|
565
|
+
|
|
566
|
+
Note:
|
|
567
|
+
Subclasses that add new initialization parameters (other than 'transforms',
|
|
568
|
+
which is set separately in _create_new_instance) should override this method
|
|
569
|
+
to include those parameters in the returned dictionary.
|
|
570
|
+
|
|
571
|
+
Returns:
|
|
572
|
+
dict[str, Any]: Dictionary of initialization parameters
|
|
573
|
+
|
|
574
|
+
"""
|
|
575
|
+
return {
|
|
576
|
+
"p": self.p,
|
|
577
|
+
}
|
|
578
|
+
|
|
579
|
+
def _get_effective_seed(self, base_seed: int | None) -> int | None:
|
|
580
|
+
"""Get effective seed considering worker context.
|
|
581
|
+
|
|
582
|
+
Args:
|
|
583
|
+
base_seed (int | None): Base seed value
|
|
584
|
+
|
|
585
|
+
Returns:
|
|
586
|
+
int | None: Effective seed after considering worker context
|
|
587
|
+
|
|
588
|
+
"""
|
|
589
|
+
if base_seed is None:
|
|
590
|
+
return base_seed
|
|
591
|
+
|
|
592
|
+
try:
|
|
593
|
+
import torch
|
|
594
|
+
import torch.utils.data
|
|
595
|
+
|
|
596
|
+
worker_info = torch.utils.data.get_worker_info()
|
|
597
|
+
if worker_info is not None:
|
|
598
|
+
# We're in a DataLoader worker process
|
|
599
|
+
# Use torch.initial_seed() which is unique per worker and changes on respawn
|
|
600
|
+
torch_seed = torch.initial_seed() % (2**32)
|
|
601
|
+
return (base_seed + torch_seed) % (2**32)
|
|
602
|
+
except (ImportError, AttributeError):
|
|
603
|
+
# PyTorch not available or not in worker context
|
|
604
|
+
pass
|
|
605
|
+
|
|
606
|
+
return base_seed
|
|
607
|
+
|
|
608
|
+
|
|
609
|
+
class Compose(BaseCompose, HubMixin):
|
|
610
|
+
"""Compose multiple transforms together and apply them sequentially to input data.
|
|
611
|
+
|
|
612
|
+
This class allows you to chain multiple image augmentation transforms and apply them
|
|
613
|
+
in a specified order. It also handles bounding box and keypoint transformations if
|
|
614
|
+
the appropriate parameters are provided.
|
|
615
|
+
|
|
616
|
+
The Compose class supports dynamic pipeline modification after initialization using
|
|
617
|
+
mathematical operators. All parameters (bbox_params, keypoint_params, additional_targets,
|
|
618
|
+
etc.) are preserved when using operators to modify the pipeline.
|
|
619
|
+
|
|
620
|
+
Args:
|
|
621
|
+
transforms (list[BasicTransform | BaseCompose]): A list of transforms to apply.
|
|
622
|
+
bbox_params (dict[str, Any] | BboxParams | None): Parameters for bounding box transforms.
|
|
623
|
+
Can be a dict of params or a BboxParams object. Default is None.
|
|
624
|
+
keypoint_params (dict[str, Any] | KeypointParams | None): Parameters for keypoint transforms.
|
|
625
|
+
Can be a dict of params or a KeypointParams object. Default is None.
|
|
626
|
+
additional_targets (dict[str, str] | None): A dictionary mapping additional target names
|
|
627
|
+
to their types. For example, {'image2': 'image'}. Default is None.
|
|
628
|
+
p (float): Probability of applying all transforms. Should be in range [0, 1]. Default is 1.0.
|
|
629
|
+
is_check_shapes (bool): If True, checks consistency of shapes for image/mask/masks on each call.
|
|
630
|
+
Disable only if you are sure about your data consistency. Default is True.
|
|
631
|
+
strict (bool): If True, enables strict mode which:
|
|
632
|
+
1. Validates that all input keys are known/expected
|
|
633
|
+
2. Validates that no transforms have invalid arguments
|
|
634
|
+
3. Raises ValueError if any validation fails
|
|
635
|
+
If False, these validations are skipped. Default is False.
|
|
636
|
+
mask_interpolation (int | None): Interpolation method for mask transforms. When defined,
|
|
637
|
+
it overrides the interpolation method specified in individual transforms. Default is None.
|
|
638
|
+
seed (int | None): Controls reproducibility of random augmentations. Compose uses
|
|
639
|
+
its own internal random state, completely independent from global random seeds.
|
|
640
|
+
|
|
641
|
+
When seed is set (int):
|
|
642
|
+
- Creates a fixed internal random state
|
|
643
|
+
- Two Compose instances with the same seed and transforms will produce identical
|
|
644
|
+
sequences of augmentations
|
|
645
|
+
- Each call to the same Compose instance still produces random augmentations,
|
|
646
|
+
but these sequences are reproducible between different Compose instances
|
|
647
|
+
- Example: transform1 = A.Compose([...], seed=137) and
|
|
648
|
+
transform2 = A.Compose([...], seed=137) will produce identical sequences
|
|
649
|
+
|
|
650
|
+
When seed is None (default):
|
|
651
|
+
- Generates a new internal random state on each Compose creation
|
|
652
|
+
- Different Compose instances will produce different sequences of augmentations
|
|
653
|
+
- Example: transform = A.Compose([...]) # random results
|
|
654
|
+
|
|
655
|
+
Important: Setting random seeds outside of Compose (like np.random.seed() or
|
|
656
|
+
random.seed()) has no effect on augmentations as Compose uses its own internal
|
|
657
|
+
random state.
|
|
658
|
+
save_applied_params (bool): If True, saves the applied parameters of each transform. Default is False.
|
|
659
|
+
You will need to use the `applied_transforms` key in the output dictionary to access the parameters.
|
|
660
|
+
|
|
661
|
+
Examples:
|
|
662
|
+
>>> # Basic usage:
|
|
663
|
+
>>> import albumentations as A
|
|
664
|
+
>>> transform = A.Compose([
|
|
665
|
+
... A.RandomCrop(width=256, height=256),
|
|
666
|
+
... A.HorizontalFlip(p=0.5),
|
|
667
|
+
... A.RandomBrightnessContrast(p=0.2),
|
|
668
|
+
... ], seed=137)
|
|
669
|
+
>>> transformed = transform(image=image)
|
|
670
|
+
|
|
671
|
+
>>> # Pipeline modification after initialization:
|
|
672
|
+
>>> # Create initial pipeline with bbox support
|
|
673
|
+
>>> base_transform = A.Compose([
|
|
674
|
+
... A.HorizontalFlip(p=0.5),
|
|
675
|
+
... A.RandomCrop(width=512, height=512)
|
|
676
|
+
... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']))
|
|
677
|
+
>>>
|
|
678
|
+
>>> # Add transforms using operators (bbox_params preserved)
|
|
679
|
+
>>> extended = base_transform + A.RandomBrightnessContrast(p=0.3)
|
|
680
|
+
>>> extended = base_transform + [A.Blur(), A.GaussNoise()]
|
|
681
|
+
>>> extended = A.Resize(height=1024, width=1024) + base_transform
|
|
682
|
+
>>>
|
|
683
|
+
>>> # Remove transforms by class
|
|
684
|
+
>>> pipeline = A.Compose([A.HorizontalFlip(p=0.5), A.VerticalFlip(), A.Rotate()])
|
|
685
|
+
>>> without_flip = pipeline - A.HorizontalFlip # Remove by class
|
|
686
|
+
|
|
687
|
+
Note:
|
|
688
|
+
- The class checks the validity of input data and shapes if is_check_args and is_check_shapes are True.
|
|
689
|
+
- When bbox_params or keypoint_params are provided, it sets up the corresponding processors.
|
|
690
|
+
- The transform can handle additional targets specified in the additional_targets dictionary.
|
|
691
|
+
- When strict mode is enabled, it performs additional validation to ensure data and transform
|
|
692
|
+
configuration correctness.
|
|
693
|
+
- Pipeline modification operators (+, -, __radd__) preserve all Compose parameters including
|
|
694
|
+
bbox_params, keypoint_params, additional_targets, and other configuration settings.
|
|
695
|
+
- All operators return new Compose instances without modifying the original pipeline.
|
|
696
|
+
|
|
697
|
+
"""
|
|
698
|
+
|
|
699
|
+
def __init__(
    self,
    transforms: TransformsSeqType,
    bbox_params: dict[str, Any] | BboxParams | None = None,
    keypoint_params: dict[str, Any] | KeypointParams | None = None,
    additional_targets: dict[str, str] | None = None,
    p: float = 1.0,
    is_check_shapes: bool = True,
    strict: bool = False,
    mask_interpolation: int | None = None,
    seed: int | None = None,
    save_applied_params: bool = False,
) -> None:
    """Build the pipeline, its bbox/keypoint processors and its validation flags.

    The arguments are described in the class docstring.

    Raises:
        ValueError: If ``bbox_params`` or ``keypoint_params`` is truthy but is
            neither a ``dict`` nor the matching params object. The ``strict``
            setter may also raise it when ``strict`` is truthy.

    """
    # Keep the user-supplied seed untouched: `_check_worker_seed` and
    # `set_random_seed` recompute the effective seed from it later on.
    self._base_seed = seed

    # Get effective seed considering worker context
    effective_seed = self._get_effective_seed(seed)

    super().__init__(
        transforms=transforms,
        p=p,
        mask_interpolation=mask_interpolation,
        seed=effective_seed,
        save_applied_params=save_applied_params,
    )

    # Falsy values (None, empty dict) skip bbox processor setup entirely.
    if bbox_params:
        if isinstance(bbox_params, dict):
            b_params = BboxParams(**bbox_params)
        elif isinstance(bbox_params, BboxParams):
            b_params = bbox_params
        else:
            msg = "unknown format of bbox_params, please use `dict` or `BboxParams`"
            raise ValueError(msg)
        self.processors["bboxes"] = BboxProcessor(b_params)

    # Same handling for keypoints: dict -> KeypointParams, object used as is.
    if keypoint_params:
        if isinstance(keypoint_params, dict):
            k_params = KeypointParams(**keypoint_params)
        elif isinstance(keypoint_params, KeypointParams):
            k_params = keypoint_params
        else:
            msg = "unknown format of keypoint_params, please use `dict` or `KeypointParams`"
            raise ValueError(msg)
        self.processors["keypoints"] = KeypointsProcessor(k_params)

    # Let every registered processor check the configured transforms.
    for proc in self.processors.values():
        proc.ensure_transforms_valid(self.transforms)

    self.add_targets(additional_targets)
    if not self.transforms:  # if no transforms -> do nothing, all keys will be available
        self._available_keys.update(AVAILABLE_KEYS)

    self.is_check_args = True
    # Assigning through the `strict` property runs `_validate_strict` when the
    # value is truthy, so this line may raise ValueError.
    self.strict = strict

    self.is_check_shapes = is_check_shapes
    self.check_each_transform = tuple(  # processors that check after each transform
        proc for proc in self.processors.values() if getattr(proc.params, "check_each_transform", False)
    )
    # Share processors / per-transform checks with nested compositions.
    self._set_check_args_for_transforms(self.transforms)

    self._set_processors_for_transforms(self.transforms)

    self.save_applied_params = save_applied_params
    # Record whether "images" / "masks" arrived as list/tuple so that
    # `postprocess` can convert them back.
    self._images_was_list = False
    self._masks_was_list = False
    # torch.initial_seed() value seen at the last worker sync; stays None until
    # `_check_worker_seed` records one.
    self._last_torch_seed: int | None = None
|
|
768
|
+
|
|
769
|
+
@property
def strict(self) -> bool:
    """Report whether strict validation is currently enabled.

    Returns:
        bool: True when strict mode is on, False otherwise.

    """
    return self._strict

@strict.setter
def strict(self, value: bool) -> None:
    """Store the strict flag, validating transforms first when it is truthy.

    Validation runs on every truthy assignment, not only on an off -> on
    transition. If validation raises, the stored flag is left unchanged.
    """
    if not value:
        self._strict = value
        return

    self._validate_strict()
    self._strict = value
|
|
786
|
+
|
|
787
|
+
def _validate_strict(self) -> None:
|
|
788
|
+
"""Validate that no transforms have invalid arguments when strict mode is enabled."""
|
|
789
|
+
|
|
790
|
+
def check_transform(transform: TransformType) -> None:
|
|
791
|
+
if hasattr(transform, "invalid_args") and transform.invalid_args:
|
|
792
|
+
message = (
|
|
793
|
+
f"Argument(s) '{', '.join(transform.invalid_args)}' "
|
|
794
|
+
f"are not valid for transform {transform.__class__.__name__}"
|
|
795
|
+
)
|
|
796
|
+
raise ValueError(message)
|
|
797
|
+
if isinstance(transform, BaseCompose):
|
|
798
|
+
for t in transform.transforms:
|
|
799
|
+
check_transform(t)
|
|
800
|
+
|
|
801
|
+
for transform in self.transforms:
|
|
802
|
+
check_transform(transform)
|
|
803
|
+
|
|
804
|
+
def _set_processors_for_transforms(self, transforms: TransformsSeqType) -> None:
|
|
805
|
+
for transform in transforms:
|
|
806
|
+
if isinstance(transform, BasicTransform):
|
|
807
|
+
if hasattr(transform, "set_processors"):
|
|
808
|
+
transform.set_processors(self.processors)
|
|
809
|
+
elif isinstance(transform, BaseCompose):
|
|
810
|
+
self._set_processors_for_transforms(transform.transforms)
|
|
811
|
+
|
|
812
|
+
def _set_check_args_for_transforms(self, transforms: TransformsSeqType) -> None:
|
|
813
|
+
for transform in transforms:
|
|
814
|
+
if isinstance(transform, BaseCompose):
|
|
815
|
+
self._set_check_args_for_transforms(transform.transforms)
|
|
816
|
+
transform.check_each_transform = self.check_each_transform
|
|
817
|
+
transform.processors = self.processors
|
|
818
|
+
if isinstance(transform, Compose):
|
|
819
|
+
transform.disable_check_args_private()
|
|
820
|
+
|
|
821
|
+
def disable_check_args_private(self) -> None:
    """Mark this composition as a nested (non-main) pipeline.

    Sets ``is_check_args``, ``strict`` and ``main_compose`` to ``False`` on this
    instance, in that order.
    """
    for flag_name in ("is_check_args", "strict", "main_compose"):
        setattr(self, flag_name, False)
|
|
829
|
+
|
|
830
|
+
def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
|
|
831
|
+
"""Apply transformations to data with automatic worker seed synchronization.
|
|
832
|
+
|
|
833
|
+
Args:
|
|
834
|
+
*args (Any): Positional arguments are not supported.
|
|
835
|
+
force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
|
|
836
|
+
**data (Any): Dict with data to transform.
|
|
837
|
+
|
|
838
|
+
Returns:
|
|
839
|
+
dict[str, Any]: Dictionary with transformed data.
|
|
840
|
+
|
|
841
|
+
Raises:
|
|
842
|
+
KeyError: If positional arguments are provided.
|
|
843
|
+
|
|
844
|
+
"""
|
|
845
|
+
# Check and sync worker seed if needed
|
|
846
|
+
self._check_worker_seed()
|
|
847
|
+
|
|
848
|
+
if args:
|
|
849
|
+
msg = "You have to pass data to augmentations as named arguments, for example: aug(image=image)"
|
|
850
|
+
raise KeyError(msg)
|
|
851
|
+
|
|
852
|
+
# Initialize applied_transforms only in top-level Compose if requested
|
|
853
|
+
if self.save_applied_params and self.main_compose:
|
|
854
|
+
data["applied_transforms"] = []
|
|
855
|
+
|
|
856
|
+
need_to_run = force_apply or self.py_random.random() < self.p
|
|
857
|
+
if not need_to_run:
|
|
858
|
+
return data
|
|
859
|
+
|
|
860
|
+
self.preprocess(data)
|
|
861
|
+
|
|
862
|
+
for t in self.transforms:
|
|
863
|
+
data = t(**data)
|
|
864
|
+
self._track_transform_params(t, data)
|
|
865
|
+
data = self.check_data_post_transform(data)
|
|
866
|
+
|
|
867
|
+
return self.postprocess(data)
|
|
868
|
+
|
|
869
|
+
def _check_worker_seed(self) -> None:
|
|
870
|
+
"""Check and update random seed if in worker context."""
|
|
871
|
+
if not hasattr(self, "_base_seed") or self._base_seed is None:
|
|
872
|
+
return
|
|
873
|
+
|
|
874
|
+
# Check if we're in a worker and need to update the seed
|
|
875
|
+
try:
|
|
876
|
+
import torch
|
|
877
|
+
import torch.utils.data
|
|
878
|
+
|
|
879
|
+
worker_info = torch.utils.data.get_worker_info()
|
|
880
|
+
if worker_info is not None:
|
|
881
|
+
# Get the current torch initial seed
|
|
882
|
+
current_torch_seed = torch.initial_seed()
|
|
883
|
+
|
|
884
|
+
# Check if we've already synchronized for this seed
|
|
885
|
+
if hasattr(self, "_last_torch_seed") and self._last_torch_seed == current_torch_seed:
|
|
886
|
+
return
|
|
887
|
+
|
|
888
|
+
# Update the seed and mark as synchronized
|
|
889
|
+
self._last_torch_seed = current_torch_seed
|
|
890
|
+
effective_seed = self._get_effective_seed(self._base_seed)
|
|
891
|
+
|
|
892
|
+
# Update our own random state
|
|
893
|
+
self.random_generator = np.random.default_rng(effective_seed)
|
|
894
|
+
self.py_random = random.Random(effective_seed)
|
|
895
|
+
|
|
896
|
+
# Propagate to all transforms
|
|
897
|
+
for transform in self.transforms:
|
|
898
|
+
if hasattr(transform, "set_random_state"):
|
|
899
|
+
transform.set_random_state(self.random_generator, self.py_random)
|
|
900
|
+
elif hasattr(transform, "set_random_seed"):
|
|
901
|
+
# For transforms that don't have set_random_state, use set_random_seed
|
|
902
|
+
transform.set_random_seed(effective_seed)
|
|
903
|
+
except (ImportError, AttributeError):
|
|
904
|
+
pass
|
|
905
|
+
|
|
906
|
+
def __setstate__(self, state: dict[str, Any]) -> None:
|
|
907
|
+
"""Set state from unpickling and handle worker seed."""
|
|
908
|
+
self.__dict__.update(state)
|
|
909
|
+
# If we have a base seed, recalculate effective seed in worker context
|
|
910
|
+
if hasattr(self, "_base_seed") and self._base_seed is not None:
|
|
911
|
+
# Reset _last_torch_seed to ensure worker-seed sync runs after unpickling
|
|
912
|
+
self._last_torch_seed = None
|
|
913
|
+
# Recalculate effective seed in worker context
|
|
914
|
+
self.set_random_seed(self._base_seed)
|
|
915
|
+
elif hasattr(self, "seed") and self.seed is not None:
|
|
916
|
+
# For backward compatibility, if no base seed but seed exists
|
|
917
|
+
self._base_seed = self.seed
|
|
918
|
+
self._last_torch_seed = None
|
|
919
|
+
self.set_random_seed(self.seed)
|
|
920
|
+
|
|
921
|
+
def set_random_seed(self, seed: int | None) -> None:
    """Re-seed this pipeline and its transforms, honouring the worker context.

    Args:
        seed (int | None): Random seed to use. It is stored as both the base
            seed and ``seed``; the generators are built from the effective seed
            returned by ``_get_effective_seed``.

    """
    # Remember the seed exactly as given.
    self._base_seed = self.seed = seed

    resolved_seed = self._get_effective_seed(seed)

    self.random_generator = np.random.default_rng(resolved_seed)
    self.py_random = random.Random(resolved_seed)

    for child in self.transforms:
        if hasattr(child, "set_random_state"):
            # Preferred path: share this pipeline's generator objects.
            child.set_random_state(self.random_generator, self.py_random)
            continue
        if hasattr(child, "set_random_seed"):
            # Fallback for children that only accept a plain seed.
            child.set_random_seed(resolved_seed)
|
|
946
|
+
|
|
947
|
+
def preprocess(self, data: Any) -> None:
    """Validate the incoming data and prepare it for the transforms.

    Shape consistency is checked whenever ``is_check_shapes`` is true,
    independently of strict mode. Key/argument validation runs only in strict
    mode. Afterwards the processors and the list-to-array conversion run.

    Args:
        data (Any): Mapping of target names to values; modified in place by the
            later preprocessing steps.

    """
    if self.is_check_shapes:
        single_3d_targets = CHECKED_VOLUME | CHECKED_MASK3D
        batched_3d_targets = {"volumes", "masks3d"}
        plane_shapes = []  # (H, W) of every checked target
        depth_shapes = []  # (D, H, W) of volumetric targets

        for target_name, target_value in data.items():
            if target_value is None:
                # Empty targets take no part in the comparison.
                continue

            internal_name = self._additional_targets.get(target_name, target_name)
            shape = self._get_data_shape(target_name, internal_name, target_value)
            if shape is None:
                continue

            if internal_name in single_3d_targets:
                plane_shapes.append(shape[1:3])  # H,W from (D,H,W)
                depth_shapes.append(shape[:3])  # D,H,W
            elif internal_name in batched_3d_targets:
                plane_shapes.append(shape[2:4])  # H,W from (N,D,H,W)
                depth_shapes.append(shape[1:4])  # D,H,W from (N,D,H,W)
            else:
                plane_shapes.append(shape[:2])  # H,W

        self._check_shape_consistency(plane_shapes, depth_shapes)

    if self.strict:
        self._validate_data(data)

    self._preprocess_processors(data)
    self._preprocess_arrays(data)
|
|
980
|
+
|
|
981
|
+
def _validate_data(self, data: dict[str, Any]) -> None:
|
|
982
|
+
"""Validate input data keys and arguments."""
|
|
983
|
+
if not self.strict:
|
|
984
|
+
return
|
|
985
|
+
|
|
986
|
+
for data_name in data:
|
|
987
|
+
if not self._is_valid_key(data_name):
|
|
988
|
+
raise ValueError(f"Key {data_name} is not in available keys.")
|
|
989
|
+
|
|
990
|
+
if self.is_check_args:
|
|
991
|
+
self._check_args(**data)
|
|
992
|
+
|
|
993
|
+
def _is_valid_key(self, key: str) -> bool:
|
|
994
|
+
"""Check if the key is valid for processing."""
|
|
995
|
+
return key in self._available_keys or key in MASK_KEYS or key in IMAGE_KEYS or key == "applied_transforms"
|
|
996
|
+
|
|
997
|
+
def _preprocess_processors(self, data: dict[str, Any]) -> None:
|
|
998
|
+
"""Run preprocessors if this is the main compose."""
|
|
999
|
+
if not self.main_compose:
|
|
1000
|
+
return
|
|
1001
|
+
|
|
1002
|
+
for processor in self.processors.values():
|
|
1003
|
+
processor.ensure_data_valid(data)
|
|
1004
|
+
for processor in self.processors.values():
|
|
1005
|
+
processor.preprocess(data)
|
|
1006
|
+
|
|
1007
|
+
def _preprocess_arrays(self, data: dict[str, Any]) -> None:
|
|
1008
|
+
"""Convert lists to numpy arrays for images and masks, and ensure contiguity."""
|
|
1009
|
+
self._preprocess_images(data)
|
|
1010
|
+
self._preprocess_masks(data)
|
|
1011
|
+
|
|
1012
|
+
def _preprocess_images(self, data: dict[str, Any]) -> None:
|
|
1013
|
+
"""Convert image lists to numpy arrays."""
|
|
1014
|
+
if "images" not in data:
|
|
1015
|
+
return
|
|
1016
|
+
|
|
1017
|
+
if isinstance(data["images"], (list, tuple)):
|
|
1018
|
+
self._images_was_list = True
|
|
1019
|
+
# Skip stacking for empty lists
|
|
1020
|
+
if not data["images"]:
|
|
1021
|
+
return
|
|
1022
|
+
data["images"] = np.stack(data["images"])
|
|
1023
|
+
else:
|
|
1024
|
+
self._images_was_list = False
|
|
1025
|
+
|
|
1026
|
+
def _preprocess_masks(self, data: dict[str, Any]) -> None:
|
|
1027
|
+
"""Convert mask lists to numpy arrays."""
|
|
1028
|
+
if "masks" not in data:
|
|
1029
|
+
return
|
|
1030
|
+
|
|
1031
|
+
if isinstance(data["masks"], (list, tuple)):
|
|
1032
|
+
self._masks_was_list = True
|
|
1033
|
+
# Skip stacking for empty lists
|
|
1034
|
+
if not data["masks"]:
|
|
1035
|
+
return
|
|
1036
|
+
data["masks"] = np.stack(data["masks"])
|
|
1037
|
+
else:
|
|
1038
|
+
self._masks_was_list = False
|
|
1039
|
+
|
|
1040
|
+
def postprocess(self, data: dict[str, Any]) -> dict[str, Any]:
    """Finish the pipeline run once every transform has been applied.

    Args:
        data (dict[str, Any]): Data after transformation.

    Returns:
        dict[str, Any]: The same dictionary, post-processed by the processors
            (main compose only) and with ``images`` / ``masks`` turned back into
            lists when they were given as a list or tuple.

    """
    if self.main_compose:
        for processor in self.processors.values():
            processor.postprocess(data)

    # Undo the stacking done during preprocessing.
    restorable = (("images", "_images_was_list"), ("masks", "_masks_was_list"))
    for key, flag_name in restorable:
        if key in data and getattr(self, flag_name):
            data[key] = list(data[key])

    return data
|
|
1062
|
+
|
|
1063
|
+
def to_dict_private(self) -> dict[str, Any]:
    """Serialize the composition, including its Compose-specific settings.

    Returns:
        dict[str, Any]: The parent's dictionary extended with ``bbox_params``,
            ``keypoint_params``, ``additional_targets``, ``is_check_shapes`` and
            ``seed`` (the base seed, or ``None`` when none is stored).

    """
    serialized = super().to_dict_private()

    bbox_proc = self.processors.get("bboxes")
    keypoint_proc = self.processors.get("keypoints")

    # Processor parameters are serialized only when the processor exists.
    serialized["bbox_params"] = bbox_proc.params.to_dict_private() if bbox_proc else None
    serialized["keypoint_params"] = keypoint_proc.params.to_dict_private() if keypoint_proc else None
    serialized["additional_targets"] = self.additional_targets
    serialized["is_check_shapes"] = self.is_check_shapes
    serialized["seed"] = getattr(self, "_base_seed", None)
    return serialized
|
|
1083
|
+
|
|
1084
|
+
def get_dict_with_id(self) -> dict[str, Any]:
    """Build the dictionary representation with object IDs used by replay mode.

    Returns:
        dict[str, Any]: The parent's dictionary extended with ``bbox_params``,
            ``keypoint_params``, ``additional_targets``, ``params`` (always
            ``None``) and ``is_check_shapes``.

    """
    described = super().get_dict_with_id()

    bbox_proc = self.processors.get("bboxes")
    keypoint_proc = self.processors.get("keypoints")

    # Processor parameters are serialized only when the processor exists.
    described["bbox_params"] = bbox_proc.params.to_dict_private() if bbox_proc else None
    described["keypoint_params"] = keypoint_proc.params.to_dict_private() if keypoint_proc else None
    described["additional_targets"] = self.additional_targets
    described["params"] = None
    described["is_check_shapes"] = self.is_check_shapes
    return described
|
|
1104
|
+
|
|
1105
|
+
@staticmethod
|
|
1106
|
+
def _check_single_data(data_name: str, data: Any) -> tuple[int, int]:
|
|
1107
|
+
if not isinstance(data, np.ndarray):
|
|
1108
|
+
raise TypeError(f"{data_name} must be numpy array type")
|
|
1109
|
+
return data.shape[:2]
|
|
1110
|
+
|
|
1111
|
+
@staticmethod
|
|
1112
|
+
def _check_masks_data(data_name: str, data: Any) -> tuple[int, int] | None:
|
|
1113
|
+
"""Check masks data format and return shape.
|
|
1114
|
+
|
|
1115
|
+
Args:
|
|
1116
|
+
data_name (str): Name of the data field being checked
|
|
1117
|
+
data (Any): Input data in one of these formats:
|
|
1118
|
+
- List of numpy arrays, each of shape (H, W) or (H, W, C)
|
|
1119
|
+
- Numpy array of shape (N, H, W) or (N, H, W, C)
|
|
1120
|
+
- Empty list for cases where no masks are present
|
|
1121
|
+
|
|
1122
|
+
Returns:
|
|
1123
|
+
tuple[int, int] | None: (height, width) of the first mask, or None if masks list is empty
|
|
1124
|
+
Raises:
|
|
1125
|
+
TypeError: If data format is invalid
|
|
1126
|
+
|
|
1127
|
+
"""
|
|
1128
|
+
if isinstance(data, np.ndarray):
|
|
1129
|
+
if data.ndim not in [3, 4]: # (N,H,W) or (N,H,W,C)
|
|
1130
|
+
raise TypeError(f"{data_name} as numpy array must be 3D or 4D")
|
|
1131
|
+
return data.shape[1:3] # Return (H,W)
|
|
1132
|
+
|
|
1133
|
+
if isinstance(data, (list, tuple)):
|
|
1134
|
+
if not data:
|
|
1135
|
+
# Allow empty list/tuple of masks
|
|
1136
|
+
return None
|
|
1137
|
+
if not all(isinstance(m, np.ndarray) for m in data):
|
|
1138
|
+
raise TypeError(f"All elements in {data_name} must be numpy arrays")
|
|
1139
|
+
if any(m.ndim not in {2, 3} for m in data):
|
|
1140
|
+
raise TypeError(f"All masks in {data_name} must be 2D or 3D numpy arrays")
|
|
1141
|
+
return data[0].shape[:2]
|
|
1142
|
+
|
|
1143
|
+
raise TypeError(f"{data_name} must be either a numpy array or a sequence of numpy arrays")
|
|
1144
|
+
|
|
1145
|
+
@staticmethod
|
|
1146
|
+
def _check_multi_data(data_name: str, data: Any) -> tuple[int, int]:
|
|
1147
|
+
"""Check multi-image data format and return shape.
|
|
1148
|
+
|
|
1149
|
+
Args:
|
|
1150
|
+
data_name (str): Name of the data field being checked
|
|
1151
|
+
data (Any): Input data in one of these formats:
|
|
1152
|
+
- List-like of numpy arrays
|
|
1153
|
+
- Numpy array of shape (N, H, W, C) or (N, H, W)
|
|
1154
|
+
|
|
1155
|
+
Returns:
|
|
1156
|
+
tuple[int, int]: (height, width) of the first image
|
|
1157
|
+
Raises:
|
|
1158
|
+
TypeError: If data format is invalid
|
|
1159
|
+
|
|
1160
|
+
"""
|
|
1161
|
+
if isinstance(data, np.ndarray):
|
|
1162
|
+
if data.ndim not in {3, 4}: # (N,H,W) or (N,H,W,C)
|
|
1163
|
+
raise TypeError(f"{data_name} as numpy array must be 3D or 4D")
|
|
1164
|
+
return data.shape[1:3] # Return (H,W)
|
|
1165
|
+
|
|
1166
|
+
if not isinstance(data, Sequence) or not isinstance(data[0], np.ndarray):
|
|
1167
|
+
raise TypeError(f"{data_name} must be either a numpy array or a list of numpy arrays")
|
|
1168
|
+
return data[0].shape[:2]
|
|
1169
|
+
|
|
1170
|
+
@staticmethod
def _check_bbox_keypoint_params(internal_data_name: str, processors: dict[str, Any]) -> None:
    """Ensure the processor required by a bbox/keypoint target is configured.

    Args:
        internal_data_name (str): Internal (resolved) target name.
        processors (dict[str, Any]): Processors registered on the pipeline.

    Raises:
        ValueError: If the target needs a bbox or keypoint processor and
            ``processors`` has no entry (or a ``None`` entry) for it.

    """
    needs_bbox_processor = internal_data_name in CHECK_BBOX_PARAM
    if needs_bbox_processor and processors.get("bboxes") is None:
        raise ValueError("bbox_params must be specified for bbox transformations")

    needs_keypoint_processor = internal_data_name in CHECK_KEYPOINTS_PARAM
    if needs_keypoint_processor and processors.get("keypoints") is None:
        raise ValueError("keypoints_params must be specified for keypoint transformations")
|
|
1176
|
+
|
|
1177
|
+
@staticmethod
|
|
1178
|
+
def _check_shapes(shapes: list[tuple[int, ...]], is_check_shapes: bool) -> None:
|
|
1179
|
+
if is_check_shapes and shapes and shapes.count(shapes[0]) != len(shapes):
|
|
1180
|
+
raise ValueError(
|
|
1181
|
+
"Height and Width of image, mask or masks should be equal. You can disable shapes check "
|
|
1182
|
+
"by setting a parameter is_check_shapes=False of Compose class (do it only if you are sure "
|
|
1183
|
+
"about your data consistency).",
|
|
1184
|
+
)
|
|
1185
|
+
|
|
1186
|
+
def _check_args(self, **kwargs: Any) -> None:
|
|
1187
|
+
shapes = [] # For H,W checks
|
|
1188
|
+
volume_shapes = [] # For D,H,W checks
|
|
1189
|
+
|
|
1190
|
+
for data_name, data in kwargs.items():
|
|
1191
|
+
internal_name = self._additional_targets.get(data_name, data_name)
|
|
1192
|
+
|
|
1193
|
+
# For CHECKED_SINGLE, we must validate even if None
|
|
1194
|
+
if internal_name in CHECKED_SINGLE:
|
|
1195
|
+
if not isinstance(data, np.ndarray):
|
|
1196
|
+
raise TypeError(f"{data_name} must be numpy array type")
|
|
1197
|
+
shapes.append(data.shape[:2])
|
|
1198
|
+
continue
|
|
1199
|
+
|
|
1200
|
+
# Skip empty data or non-array/list inputs for other types
|
|
1201
|
+
if data is None:
|
|
1202
|
+
continue
|
|
1203
|
+
if not isinstance(data, (np.ndarray, list)):
|
|
1204
|
+
continue
|
|
1205
|
+
|
|
1206
|
+
self._check_bbox_keypoint_params(internal_name, self.processors)
|
|
1207
|
+
|
|
1208
|
+
shape = self._get_data_shape(data_name, internal_name, data)
|
|
1209
|
+
if shape is None:
|
|
1210
|
+
continue
|
|
1211
|
+
|
|
1212
|
+
# Handle different shape types
|
|
1213
|
+
if internal_name in CHECKED_VOLUME | CHECKED_MASK3D:
|
|
1214
|
+
shapes.append(shape[1:3]) # H,W from (D,H,W)
|
|
1215
|
+
volume_shapes.append(shape[:3]) # D,H,W
|
|
1216
|
+
elif internal_name in {"volumes", "masks3d"}:
|
|
1217
|
+
shapes.append(shape[2:4]) # H,W from (N,D,H,W)
|
|
1218
|
+
volume_shapes.append(shape[1:4]) # D,H,W from (N,D,H,W)
|
|
1219
|
+
else:
|
|
1220
|
+
shapes.append(shape[:2]) # H,W
|
|
1221
|
+
|
|
1222
|
+
self._check_shape_consistency(shapes, volume_shapes)
|
|
1223
|
+
|
|
1224
|
+
def _get_data_shape(self, data_name: str, internal_name: str, data: Any) -> tuple[int, ...] | None:
    """Resolve the shape of ``data`` according to the kind of target it is.

    Args:
        data_name (str): Name under which the data was passed in.
        internal_name (str): Resolved target name that selects the checker.
        data (Any): Value to inspect.

    Returns:
        tuple[int, ...] | None: Shape reported by the matching checker, or None
            when the target belongs to none of the checked groups.

    """
    if internal_name in CHECKED_SINGLE:
        # Single images and masks.
        resolved = self._get_single_data_shape(data_name, data)
    elif internal_name in CHECKED_VOLUME:
        resolved = self._check_volume_data(data_name, data)
    elif internal_name in CHECKED_MASK3D:
        resolved = self._check_mask3d_data(data_name, data)
    elif internal_name in CHECKED_MULTI:
        # Multi-item targets.
        resolved = self._get_multi_data_shape(data_name, internal_name, data)
    else:
        resolved = None
    return resolved
|
|
1243
|
+
|
|
1244
|
+
def _get_single_data_shape(self, data_name: str, data: np.ndarray) -> tuple[int, ...]:
|
|
1245
|
+
"""Get shape of single image or mask."""
|
|
1246
|
+
if not isinstance(data, np.ndarray):
|
|
1247
|
+
raise TypeError(f"{data_name} must be numpy array type")
|
|
1248
|
+
return data.shape
|
|
1249
|
+
|
|
1250
|
+
def _get_multi_data_shape(self, data_name: str, internal_name: str, data: Any) -> tuple[int, ...] | None:
|
|
1251
|
+
"""Get shape of multi-item data (masks, images, volumes)."""
|
|
1252
|
+
if internal_name == "masks":
|
|
1253
|
+
shape = self._check_masks_data(data_name, data)
|
|
1254
|
+
# Skip empty masks lists when returning shape
|
|
1255
|
+
return None if shape is None else shape
|
|
1256
|
+
|
|
1257
|
+
if internal_name in {"volumes", "masks3d"}: # Group these together
|
|
1258
|
+
if not isinstance(data, np.ndarray):
|
|
1259
|
+
raise TypeError(f"{data_name} must be numpy array type")
|
|
1260
|
+
if data.ndim not in {4, 5}: # (N,D,H,W) or (N,D,H,W,C)
|
|
1261
|
+
raise TypeError(f"{data_name} must be 4D or 5D array")
|
|
1262
|
+
return data.shape # Return full shape
|
|
1263
|
+
|
|
1264
|
+
return self._check_multi_data(data_name, data)
|
|
1265
|
+
|
|
1266
|
+
def _check_shape_consistency(self, shapes: list[tuple[int, ...]], volume_shapes: list[tuple[int, ...]]) -> None:
|
|
1267
|
+
"""Check consistency of shapes."""
|
|
1268
|
+
# Check H,W consistency
|
|
1269
|
+
self._check_shapes(shapes, self.is_check_shapes)
|
|
1270
|
+
|
|
1271
|
+
# Check D,H,W consistency for volumes and 3D masks
|
|
1272
|
+
if self.is_check_shapes and volume_shapes and volume_shapes.count(volume_shapes[0]) != len(volume_shapes):
|
|
1273
|
+
raise ValueError(
|
|
1274
|
+
"Depth, Height and Width of volume, mask3d, volumes and masks3d should be equal. "
|
|
1275
|
+
"You can disable shapes check by setting is_check_shapes=False.",
|
|
1276
|
+
)
|
|
1277
|
+
|
|
1278
|
+
@staticmethod
|
|
1279
|
+
def _check_volume_data(data_name: str, data: np.ndarray) -> tuple[int, int, int]:
|
|
1280
|
+
if data.ndim not in {3, 4}: # (D,H,W) or (D,H,W,C)
|
|
1281
|
+
raise TypeError(f"{data_name} must be 3D or 4D array")
|
|
1282
|
+
return data.shape[:3] # Return (D,H,W)
|
|
1283
|
+
|
|
1284
|
+
@staticmethod
|
|
1285
|
+
def _check_volumes_data(data_name: str, data: np.ndarray) -> tuple[int, int, int]:
|
|
1286
|
+
if data.ndim not in {4, 5}: # (N,D,H,W) or (N,D,H,W,C)
|
|
1287
|
+
raise TypeError(f"{data_name} must be 4D or 5D array")
|
|
1288
|
+
return data.shape[1:4] # Return (D,H,W)
|
|
1289
|
+
|
|
1290
|
+
@staticmethod
|
|
1291
|
+
def _check_mask3d_data(data_name: str, data: np.ndarray) -> tuple[int, int, int]:
|
|
1292
|
+
"""Check single volumetric mask data format and return shape."""
|
|
1293
|
+
if data.ndim not in {3, 4}: # (D,H,W) or (D,H,W,C)
|
|
1294
|
+
raise TypeError(f"{data_name} must be 3D or 4D array")
|
|
1295
|
+
return data.shape[:3] # Return (D,H,W)
|
|
1296
|
+
|
|
1297
|
+
@staticmethod
|
|
1298
|
+
def _check_masks3d_data(data_name: str, data: np.ndarray) -> tuple[int, int, int]:
|
|
1299
|
+
"""Check multiple volumetric masks data format and return shape."""
|
|
1300
|
+
if data.ndim not in [4, 5]: # (N,D,H,W) or (N,D,H,W,C)
|
|
1301
|
+
raise TypeError(f"{data_name} must be 4D or 5D array")
|
|
1302
|
+
return data.shape[1:4] # Return (D,H,W)
|
|
1303
|
+
|
|
1304
|
+
def _get_init_params(self) -> dict[str, Any]:
|
|
1305
|
+
"""Get parameters needed to recreate this Compose instance.
|
|
1306
|
+
|
|
1307
|
+
Returns:
|
|
1308
|
+
dict[str, Any]: Dictionary of initialization parameters
|
|
1309
|
+
|
|
1310
|
+
"""
|
|
1311
|
+
bbox_processor = self.processors.get("bboxes")
|
|
1312
|
+
keypoints_processor = self.processors.get("keypoints")
|
|
1313
|
+
|
|
1314
|
+
return {
|
|
1315
|
+
"bbox_params": bbox_processor.params if bbox_processor else None,
|
|
1316
|
+
"keypoint_params": keypoints_processor.params if keypoints_processor else None,
|
|
1317
|
+
"additional_targets": self.additional_targets,
|
|
1318
|
+
"p": self.p,
|
|
1319
|
+
"is_check_shapes": self.is_check_shapes,
|
|
1320
|
+
"strict": self.strict,
|
|
1321
|
+
"mask_interpolation": getattr(self, "mask_interpolation", None),
|
|
1322
|
+
"seed": getattr(self, "_base_seed", None),
|
|
1323
|
+
"save_applied_params": getattr(self, "save_applied_params", False),
|
|
1324
|
+
}
|
|
1325
|
+
|
|
1326
|
+
|
|
1327
|
+
class OneOf(BaseCompose):
    """Pick a single transform from the list and apply it with `force_apply=True`.

    The `p` values of the child transforms are normalized so that they sum to 1 and are
    then used as selection weights.

    Args:
        transforms (list): list of transformations to compose.
        p (float): probability of applying selected transform. Default: 0.5.

    """

    def __init__(self, transforms: TransformsSeqType, p: float = 0.5):
        super().__init__(transforms=transforms, p=p)
        # Normalize child probabilities into selection weights.
        # Note: a non-empty list whose probabilities sum to zero raises ZeroDivisionError here.
        weights = [t.p for t in self.transforms]
        total = sum(weights)
        self.transforms_ps = [weight / total for weight in weights]

    def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
        """Apply the OneOf composition to the input data.

        Args:
            *args (Any): Positional arguments are accepted but not used by this method.
            force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
            **data (Any): Dict with data to transform.

        Returns:
            dict[str, Any]: Dictionary with transformed data.

        """
        if self.replay_mode:
            # During replay every child is invoked in order with its own default call.
            for child in self.transforms:
                data = child(**data)
            return data

        # Nothing to choose from.
        if not self.transforms_ps:
            return data

        # Overall gate: skipped unless forced or the random draw falls below `p`.
        if not (force_apply or self.py_random.random() < self.p):
            return data

        idx: int = self.random_generator.choice(len(self.transforms), p=self.transforms_ps)
        chosen = self.transforms[idx]
        data = chosen(force_apply=True, **data)
        self._track_transform_params(chosen, data)
        return data
|
|
1369
|
+
|
|
1370
|
+
|
|
1371
|
+
class SomeOf(BaseCompose):
    """Selects exactly `n` transforms from the given list and applies them.

    The selection of which `n` transforms to apply is done **uniformly at random**
    from the provided list. Each transform in the list has an equal chance of being selected.

    Once the `n` transforms are selected, each one is applied **based on its
    individual probability** `p`.

    Args:
        transforms (list[BasicTransform | BaseCompose]): A list of transforms to choose from.
        n (int): The exact number of transforms to select and potentially apply.
            If `replace=False` and `n` is greater than the number of available transforms,
            `n` will be capped at the number of transforms.
        replace (bool): Whether to sample transforms with replacement. If True, the same
            transform can be selected multiple times (up to `n` times).
            Default is False.
        p (float): The probability that this `SomeOf` composition will be applied.
            If applied, it will select `n` transforms and attempt to apply them.
            Default is 1.0.

    Note:
        - The overall probability `p` of the `SomeOf` block determines if *any* selection
          and application occurs. Calling the block with `force_apply=True` bypasses this
          overall check only; the selected transforms still use their own `p`.
        - The individual probability `p` of each transform inside the list determines if
          that specific transform runs *if it is selected*.
        - If `replace` is True, the same transform might be selected multiple times, and
          its individual probability `p` will be checked each time it's encountered.
        - When using pipeline modification operators (+, -, __radd__), the `n` parameter
          is preserved while the pool of available transforms changes:
          - `SomeOf([A, B], n=2) + C` → `SomeOf([A, B, C], n=2)` (selects 2 from 3 transforms)
          - This allows for dynamic adjustment of the transform pool without changing selection count.

    Examples:
        >>> import albumentations as A
        >>> transform = A.SomeOf([
        ...     A.HorizontalFlip(p=0.5),  # 50% chance to apply if selected
        ...     A.VerticalFlip(p=0.8),    # 80% chance to apply if selected
        ...     A.RandomRotate90(p=1.0),  # 100% chance to apply if selected
        ... ], n=2, replace=False, p=1.0)  # Always select 2 transforms uniformly

        # In each call, 2 transforms out of 3 are chosen uniformly.
        # For example, if HFlip and VFlip are chosen:
        # - HFlip runs if random() < 0.5
        # - VFlip runs if random() < 0.8
        # If VFlip and Rotate90 are chosen:
        # - VFlip runs if random() < 0.8
        # - Rotate90 runs if random() < 1.0 (always)

        >>> # Pipeline modification example:
        >>> # Add more transforms to the pool while keeping n=2
        >>> extended = transform + [A.Blur(p=1.0), A.RandomBrightnessContrast(p=0.7)]
        >>> # Now selects 2 transforms from 5 available transforms uniformly

    """

    def __init__(self, transforms: TransformsSeqType, n: int = 1, replace: bool = False, p: float = 1):
        super().__init__(transforms, p)
        self.n = n
        # Without replacement we cannot draw more items than exist, so cap `n` and warn.
        if not replace and n > len(self.transforms):
            self.n = len(self.transforms)
            warnings.warn(
                f"`n` is greater than number of transforms. `n` will be set to {self.n}.",
                UserWarning,
                stacklevel=2,
            )
        self.replace = replace

    def __call__(self, *arg: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
        """Apply n randomly selected transforms from the list of transforms.

        Args:
            *arg (Any): Positional arguments are accepted but not used by this method.
            force_apply (bool): When True, skip the overall `p` check of this block. The
                selected transforms are still called without `force_apply`, so each one
                keeps its individual probability. Default: False.
            **data (Any): Dict with data to transform.

        Returns:
            dict[str, Any]: Dictionary with transformed data.

        """
        if self.replay_mode:
            for t in self.transforms:
                data = t(**data)
                data = self.check_data_post_transform(data)
            return data

        # Overall gate. `force_apply` is honored here so that a parent which forces its
        # selected child (e.g. OneOf) cannot have that choice silently dropped, matching
        # the gate used by the sibling compositions in this module.
        if force_apply or self.py_random.random() < self.p:
            # Get indices uniformly
            for i in self._get_idx():
                t = self.transforms[i]
                # Apply the transform respecting its own probability `t.p`
                data = t(**data)
                self._track_transform_params(t, data)
                data = self.check_data_post_transform(data)
        return data

    def _get_idx(self) -> np.ndarray[np.int_]:
        """Draw `n` transform indices uniformly and return them in ascending order."""
        # Use uniform probability for selection, ignore individual p values here
        idx = self.random_generator.choice(
            len(self.transforms),
            size=self.n,
            replace=self.replace,
        )
        # Sorting keeps the application order equal to the declaration order.
        idx.sort()
        return idx

    def to_dict_private(self) -> dict[str, Any]:
        """Convert the SomeOf composition to a dictionary for serialization.

        Returns:
            dict[str, Any]: Dictionary representation of the composition.

        """
        dictionary = super().to_dict_private()
        dictionary.update({"n": self.n, "replace": self.replace})
        return dictionary

    def _get_init_params(self) -> dict[str, Any]:
        """Get parameters needed to recreate this SomeOf instance.

        Returns:
            dict[str, Any]: Dictionary of initialization parameters

        """
        base_params = super()._get_init_params()
        base_params.update(
            {
                "n": self.n,
                "replace": self.replace,
            },
        )
        return base_params
|
|
1504
|
+
|
|
1505
|
+
|
|
1506
|
+
class RandomOrder(SomeOf):
    """Apply a random subset of transforms from the given list in a random order.

    Exactly `n` transforms are drawn uniformly at random and then applied in the order
    in which they were drawn. Each drawn transform still runs according to its own
    probability `p`.

    Attributes:
        transforms (TransformsSeqType): A list of transformations to choose from.
        n (int): The number of transforms to apply. If `n` is greater than the number of available transforms
            and `replace` is False, `n` will be set to the number of available transforms.
        replace (bool): Whether to sample transforms with replacement. If True, the same transform can be
            selected multiple times. Default is False.
        p (float): Probability of applying the selected transforms. Should be in the range [0, 1]. Default is 1.0.

    Examples:
        >>> import albumentations as A
        >>> transform = A.RandomOrder([
        ...     A.HorizontalFlip(p=0.5),
        ...     A.VerticalFlip(p=1.0),
        ...     A.RandomBrightnessContrast(p=0.8),
        ... ], n=2, replace=False, p=1.0)
        >>> # This will uniformly select 2 transforms and apply them in a random order,
        >>> # respecting their individual probabilities (0.5, 1.0, 0.8).

    Note:
        - Inherits from SomeOf, but overrides `_get_idx` to ensure random order without sorting.
        - Selection is uniform; application depends on individual transform probabilities.

    """

    def __init__(self, transforms: TransformsSeqType, n: int = 1, replace: bool = False, p: float = 1):
        # All setup (including the capping of `n`) is delegated to SomeOf.
        super().__init__(transforms=transforms, n=n, replace=replace, p=p)

    def _get_idx(self) -> np.ndarray[np.int_]:
        """Draw `n` transform indices uniformly, keeping the order in which they were drawn."""
        # Same uniform draw as SomeOf (honoring `self.replace`), but deliberately left
        # unsorted: the draw order is the application order.
        pool_size = len(self.transforms)
        drawn = self.random_generator.choice(
            pool_size,
            size=self.n,
            replace=self.replace,
        )
        return drawn
|
|
1549
|
+
|
|
1550
|
+
|
|
1551
|
+
class OneOrOther(BaseCompose):
    """Select one or another transform to apply. Selected transform will be called with `force_apply=True`.

    With probability `p` the first transform is used, otherwise the last one.

    Args:
        first (TransformType | None): Transform used when the random draw falls below `p`.
        second (TransformType | None): Transform used otherwise.
        transforms (TransformsSeqType | None): Alternative way to pass the transforms as a
            sequence; when given, `first` and `second` are ignored.
        p (float): Probability of choosing the first transform. Default: 0.5.

    Raises:
        ValueError: If `transforms` is not given and either `first` or `second` is missing.

    """

    def __init__(
        self,
        first: TransformType | None = None,
        second: TransformType | None = None,
        transforms: TransformsSeqType | None = None,
        p: float = 0.5,
    ):
        if transforms is None:
            if first is None or second is None:
                msg = "You must set both first and second or set transforms argument."
                raise ValueError(msg)
            transforms = [first, second]
        super().__init__(transforms=transforms, p=p)
        # Any length is tolerated (first and last items are used), but it is flagged.
        if len(self.transforms) != NUM_ONEOF_TRANSFORMS:
            warnings.warn("Length of transforms is not equal to 2.", stacklevel=2)

    def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
        """Apply one or another transform to the input data.

        Args:
            *args (Any): Positional arguments are accepted but not used by this method.
            force_apply (bool): Accepted for interface compatibility; not used here because
                `p` only decides which of the two transforms runs, and the chosen one is
                always forced.
            **data (Any): Dict with data to transform.

        Returns:
            dict[str, Any]: Dictionary with transformed data.

        """
        if self.replay_mode:
            for t in self.transforms:
                data = t(**data)
                self._track_transform_params(t, data)
            return data

        # The random draw happens first, then the matching end of the list is picked.
        chosen = self.transforms[0] if self.py_random.random() < self.p else self.transforms[-1]
        data = chosen(force_apply=True, **data)
        # Record the chosen transform, as the replay branch above and OneOf both do;
        # previously this path returned without tracking.
        self._track_transform_params(chosen, data)
        return data
|
|
1592
|
+
|
|
1593
|
+
|
|
1594
|
+
class SelectiveChannelTransform(BaseCompose):
    """Apply a list of transforms to a chosen subset of image channels only.

    The selected channels are extracted into a sub-image, the transforms are run on that
    sub-image one after another, and the resulting channels are written back to their
    original positions in a copy of the input image.

    Args:
        transforms (TransformsSeqType):
            A sequence of transformations (from Albumentations) to be applied to the specified channels.
        channels (Sequence[int]):
            A sequence of integers specifying the indices of the channels to which the transforms should be applied.
        p (float): Probability that the transform will be applied; the default is 1.0 (always apply).

    Returns:
        dict[str, Any]: The transformed data dictionary, which includes the transformed 'image' key.

    Note:
        - When using pipeline modification operators (+, -, __radd__), the `channels` parameter
          is preserved in the resulting SelectiveChannelTransform instance.
        - Only the transform list is modified while maintaining the same channel selection behavior.

    """

    def __init__(
        self,
        transforms: TransformsSeqType,
        channels: Sequence[int] = (0, 1, 2),
        p: float = 1.0,
    ) -> None:
        super().__init__(transforms=transforms, p=p)
        self.channels = channels

    def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
        """Apply transforms to specific channels of the image.

        Args:
            *args (Any): Positional arguments are accepted but not used by this method.
            force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
            **data (Any): Dict with data to transform; only the "image" entry is read and replaced.

        Returns:
            dict[str, Any]: Dictionary with transformed data.

        """
        # Probability gate: return the input untouched unless forced or the draw succeeds.
        if not (force_apply or self.py_random.random() < self.p):
            return data

        image = data["image"]

        # Indexing the last axis with the channel sequence gathers the chosen channels.
        sub_image = np.ascontiguousarray(image[:, :, self.channels])

        for t in self.transforms:
            sub_data = {"image": sub_image}
            sub_image = t(**sub_data)["image"]
            # NOTE(review): tracking receives the per-call `sub_data` dict rather than the
            # caller's `data` -- confirm this is the intended destination.
            self._track_transform_params(t, sub_data)

        transformed_channels = cv2.split(sub_image)
        output_img = image.copy()

        # Write each transformed plane back to the channel index it came from.
        for channel_index, plane in zip(self.channels, transformed_channels):
            output_img[:, :, channel_index] = plane

        data["image"] = np.ascontiguousarray(output_img)
        return data

    def _get_init_params(self) -> dict[str, Any]:
        """Get parameters needed to recreate this SelectiveChannelTransform instance.

        Returns:
            dict[str, Any]: Dictionary of initialization parameters

        """
        init_params = super()._get_init_params()
        init_params.update({"channels": self.channels})
        return init_params
|
|
1674
|
+
|
|
1675
|
+
|
|
1676
|
+
class ReplayCompose(Compose):
    """Compose variant that records what was applied so it can be replayed later.

    Each call stores a serialized description of the pipeline, including the parameters
    of the transforms, under `save_key` in the result. That description can be passed to
    `ReplayCompose.replay` to run the same transformations on other data.

    Args:
        transforms (TransformsSeqType): List of transformations to compose.
        bbox_params (dict[str, Any] | BboxParams | None): Parameters for bounding box transforms.
        keypoint_params (dict[str, Any] | KeypointParams | None): Parameters for keypoint transforms.
        additional_targets (dict[str, str] | None): Dictionary of additional targets.
        p (float): Probability of applying the compose.
        is_check_shapes (bool): Whether to check shapes of different targets.
        save_key (str): Key for storing the applied transformations.

    """

    def __init__(
        self,
        transforms: TransformsSeqType,
        bbox_params: dict[str, Any] | BboxParams | None = None,
        keypoint_params: dict[str, Any] | KeypointParams | None = None,
        additional_targets: dict[str, str] | None = None,
        p: float = 1.0,
        is_check_shapes: bool = True,
        save_key: str = "replay",
    ):
        super().__init__(transforms, bbox_params, keypoint_params, additional_targets, p, is_check_shapes)
        self.set_deterministic(True, save_key=save_key)
        self.save_key = save_key
        # Register the key so the replay entry is accepted as a valid data key.
        self._available_keys.add(save_key)

    def __call__(self, *args: Any, force_apply: bool = False, **kwargs: Any) -> dict[str, Any]:
        """Apply transforms and record parameters for future replay.

        Args:
            *args (Any): Positional arguments are accepted but not used by this method.
            force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
            **kwargs (Any): Dict with data to transform.

        Returns:
            dict[str, Any]: Dictionary with transformed data and replay information.

        """
        # Fresh per-call container placed under the save key before the pipeline runs.
        kwargs[self.save_key] = defaultdict(dict)
        output = super().__call__(force_apply=force_apply, **kwargs)

        # Merge the collected parameters into the serialized pipeline description.
        record = self.get_dict_with_id()
        self.fill_with_params(record, output[self.save_key])
        self.fill_applied(record)

        output[self.save_key] = record
        return output

    @staticmethod
    def replay(saved_augmentations: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
        """Replay previously saved augmentations.

        Args:
            saved_augmentations (dict[str, Any]): Previously saved augmentation parameters.
            **kwargs (Any): Dict with data to transform.

        Returns:
            dict[str, Any]: Dictionary with transformed data using saved parameters.

        """
        pipeline = ReplayCompose._restore_for_replay(saved_augmentations)
        return pipeline(force_apply=True, **kwargs)

    @staticmethod
    def _restore_for_replay(
        transform_dict: dict[str, Any],
        lambda_transforms: dict[str, Any] | None = None,
    ) -> TransformType:
        """Rebuild a transform (recursively) from its saved replay description.

        Args:
            transform_dict (dict[str, Any]): A dictionary that contains transform data.
            lambda_transforms (dict): A dictionary that contains lambda transforms, that
                is instances of the Lambda class.
                This dictionary is required when you are restoring a pipeline that contains lambda transforms.
                Keys in that dictionary should be named same as `name` arguments in respective lambda transforms
                from a serialized pipeline.

        Returns:
            TransformType: The restored transform, switched to replay mode.

        """
        # Read the replay bookkeeping first so a malformed record fails early.
        applied = transform_dict["applied"]
        params = transform_dict["params"]

        transform = instantiate_nonserializable(transform_dict, lambda_transforms)
        if not transform:
            bookkeeping_keys = ("__class_fullname__", "applied", "params")
            class_name = transform_dict["__class_fullname__"]
            init_args = {key: value for key, value in transform_dict.items() if key not in bookkeeping_keys}
            transform_cls = SERIALIZABLE_REGISTRY[class_name]
            if "transforms" in init_args:
                # Nested compositions: restore every child before building the parent.
                init_args["transforms"] = [
                    ReplayCompose._restore_for_replay(child, lambda_transforms=lambda_transforms)
                    for child in init_args["transforms"]
                ]
            transform = transform_cls(**init_args)

        transform = cast("BasicTransform", transform)
        # Only leaf transforms receive the recorded parameters.
        if isinstance(transform, BasicTransform):
            transform.params = params
        transform.replay_mode = True
        transform.applied_in_replay = applied
        return transform

    def fill_with_params(self, serialized: dict[str, Any], all_params: Any) -> None:
        """Fill serialized transform data with parameters for replay.

        The "id" entry of each node is used to look up its parameters and is then removed.

        Args:
            serialized (dict[str, Any]): Serialized transform data.
            all_params (Any): Parameters to fill in.

        """
        serialized["params"] = all_params.get(serialized.get("id"))
        del serialized["id"]
        for child in serialized.get("transforms", []):
            self.fill_with_params(child, all_params)

    def fill_applied(self, serialized: dict[str, Any]) -> bool:
        """Set 'applied' flag for transforms based on parameters.

        Args:
            serialized (dict[str, Any]): Serialized transform data.

        Returns:
            bool: True if any transform was applied, False otherwise.

        """
        if "transforms" not in serialized:
            # Leaf: it counts as applied when parameters were recorded for it.
            was_applied = serialized.get("params") is not None
        else:
            # Build the full list first so every child gets its flag set.
            child_flags = [self.fill_applied(child) for child in serialized["transforms"]]
            was_applied = any(child_flags)
        serialized["applied"] = was_applied
        return was_applied

    def to_dict_private(self) -> dict[str, Any]:
        """Convert the ReplayCompose to a dictionary for serialization.

        Returns:
            dict[str, Any]: Dictionary representation of the composition.

        """
        state = super().to_dict_private()
        state.update({"save_key": self.save_key})
        return state

    def _get_init_params(self) -> dict[str, Any]:
        """Get parameters needed to recreate this ReplayCompose instance.

        Returns:
            dict[str, Any]: Dictionary of initialization parameters

        """
        init_params = super()._get_init_params()
        init_params.update({"save_key": self.save_key})
        return init_params
|
|
1837
|
+
|
|
1838
|
+
|
|
1839
|
+
class Sequential(BaseCompose):
|
|
1840
|
+
"""Sequentially applies all transforms to targets.
|
|
1841
|
+
|
|
1842
|
+
Note:
|
|
1843
|
+
This transform is not intended to be a replacement for `Compose`. Instead, it should be used inside `Compose`
|
|
1844
|
+
the same way `OneOf` or `OneOrOther` are used. For instance, you can combine `OneOf` with `Sequential` to
|
|
1845
|
+
create an augmentation pipeline that contains multiple sequences of augmentations and applies one randomly
|
|
1846
|
+
chose sequence to input data (see the `Example` section for an example definition of such pipeline).
|
|
1847
|
+
|
|
1848
|
+
Examples:
|
|
1849
|
+
>>> import albumentations as A
|
|
1850
|
+
>>> transform = A.Compose([
|
|
1851
|
+
>>> A.OneOf([
|
|
1852
|
+
>>> A.Sequential([
|
|
1853
|
+
>>> A.HorizontalFlip(p=0.5),
|
|
1854
|
+
>>> A.ShiftScaleRotate(p=0.5),
|
|
1855
|
+
>>> ]),
|
|
1856
|
+
>>> A.Sequential([
|
|
1857
|
+
>>> A.VerticalFlip(p=0.5),
|
|
1858
|
+
>>> A.RandomBrightnessContrast(p=0.5),
|
|
1859
|
+
>>> ]),
|
|
1860
|
+
>>> ], p=1)
|
|
1861
|
+
>>> ])
|
|
1862
|
+
|
|
1863
|
+
"""
|
|
1864
|
+
|
|
1865
|
+
def __init__(self, transforms: TransformsSeqType, p: float = 0.5):
|
|
1866
|
+
super().__init__(transforms=transforms, p=p)
|
|
1867
|
+
|
|
1868
|
+
def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
|
|
1869
|
+
"""Apply all transforms in sequential order.
|
|
1870
|
+
|
|
1871
|
+
Args:
|
|
1872
|
+
*args (Any): Positional arguments are not supported.
|
|
1873
|
+
force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
|
|
1874
|
+
**data (Any): Dict with data to transform.
|
|
1875
|
+
|
|
1876
|
+
Returns:
|
|
1877
|
+
dict[str, Any]: Dictionary with transformed data.
|
|
1878
|
+
|
|
1879
|
+
"""
|
|
1880
|
+
if self.replay_mode or force_apply or self.py_random.random() < self.p:
|
|
1881
|
+
for t in self.transforms:
|
|
1882
|
+
data = t(**data)
|
|
1883
|
+
self._track_transform_params(t, data)
|
|
1884
|
+
data = self.check_data_post_transform(data)
|
|
1885
|
+
return data
|