nrtk-albumentations 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nrtk-albumentations might be problematic. Click here for more details.

Files changed (62):
  1. albumentations/__init__.py +21 -0
  2. albumentations/augmentations/__init__.py +23 -0
  3. albumentations/augmentations/blur/__init__.py +0 -0
  4. albumentations/augmentations/blur/functional.py +438 -0
  5. albumentations/augmentations/blur/transforms.py +1633 -0
  6. albumentations/augmentations/crops/__init__.py +0 -0
  7. albumentations/augmentations/crops/functional.py +494 -0
  8. albumentations/augmentations/crops/transforms.py +3647 -0
  9. albumentations/augmentations/dropout/__init__.py +0 -0
  10. albumentations/augmentations/dropout/channel_dropout.py +134 -0
  11. albumentations/augmentations/dropout/coarse_dropout.py +567 -0
  12. albumentations/augmentations/dropout/functional.py +1017 -0
  13. albumentations/augmentations/dropout/grid_dropout.py +166 -0
  14. albumentations/augmentations/dropout/mask_dropout.py +274 -0
  15. albumentations/augmentations/dropout/transforms.py +461 -0
  16. albumentations/augmentations/dropout/xy_masking.py +186 -0
  17. albumentations/augmentations/geometric/__init__.py +0 -0
  18. albumentations/augmentations/geometric/distortion.py +1238 -0
  19. albumentations/augmentations/geometric/flip.py +752 -0
  20. albumentations/augmentations/geometric/functional.py +4151 -0
  21. albumentations/augmentations/geometric/pad.py +676 -0
  22. albumentations/augmentations/geometric/resize.py +956 -0
  23. albumentations/augmentations/geometric/rotate.py +864 -0
  24. albumentations/augmentations/geometric/transforms.py +1962 -0
  25. albumentations/augmentations/mixing/__init__.py +0 -0
  26. albumentations/augmentations/mixing/domain_adaptation.py +787 -0
  27. albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
  28. albumentations/augmentations/mixing/functional.py +878 -0
  29. albumentations/augmentations/mixing/transforms.py +832 -0
  30. albumentations/augmentations/other/__init__.py +0 -0
  31. albumentations/augmentations/other/lambda_transform.py +180 -0
  32. albumentations/augmentations/other/type_transform.py +261 -0
  33. albumentations/augmentations/pixel/__init__.py +0 -0
  34. albumentations/augmentations/pixel/functional.py +4226 -0
  35. albumentations/augmentations/pixel/transforms.py +7556 -0
  36. albumentations/augmentations/spectrogram/__init__.py +0 -0
  37. albumentations/augmentations/spectrogram/transform.py +220 -0
  38. albumentations/augmentations/text/__init__.py +0 -0
  39. albumentations/augmentations/text/functional.py +272 -0
  40. albumentations/augmentations/text/transforms.py +299 -0
  41. albumentations/augmentations/transforms3d/__init__.py +0 -0
  42. albumentations/augmentations/transforms3d/functional.py +393 -0
  43. albumentations/augmentations/transforms3d/transforms.py +1422 -0
  44. albumentations/augmentations/utils.py +249 -0
  45. albumentations/core/__init__.py +0 -0
  46. albumentations/core/bbox_utils.py +920 -0
  47. albumentations/core/composition.py +1885 -0
  48. albumentations/core/hub_mixin.py +299 -0
  49. albumentations/core/keypoints_utils.py +521 -0
  50. albumentations/core/label_manager.py +339 -0
  51. albumentations/core/pydantic.py +239 -0
  52. albumentations/core/serialization.py +352 -0
  53. albumentations/core/transforms_interface.py +976 -0
  54. albumentations/core/type_definitions.py +127 -0
  55. albumentations/core/utils.py +605 -0
  56. albumentations/core/validation.py +129 -0
  57. albumentations/pytorch/__init__.py +1 -0
  58. albumentations/pytorch/transforms.py +189 -0
  59. nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
  60. nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
  61. nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
  62. nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1885 @@
1
+ """Module for composing multiple transforms into augmentation pipelines.
2
+
3
+ This module provides classes for combining multiple transformations into cohesive
4
+ augmentation pipelines. It includes various composition strategies such as sequential
5
+ application, random selection, and conditional application of transforms. These
6
+ composition classes handle the coordination between different transforms, ensuring
7
+ proper data flow and maintaining consistent behavior across the augmentation pipeline.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import random
13
+ import warnings
14
+ from collections import defaultdict
15
+ from collections.abc import Iterator, Sequence
16
+ from typing import Any, Union, cast
17
+
18
+ import cv2
19
+ import numpy as np
20
+
21
+ from .bbox_utils import BboxParams, BboxProcessor
22
+ from .hub_mixin import HubMixin
23
+ from .keypoints_utils import KeypointParams, KeypointsProcessor
24
+ from .serialization import (
25
+ SERIALIZABLE_REGISTRY,
26
+ Serializable,
27
+ get_shortest_class_fullname,
28
+ instantiate_nonserializable,
29
+ )
30
+ from .transforms_interface import BasicTransform
31
+ from .utils import DataProcessor, format_args, get_shape
32
+
33
# Names exported by ``from albumentations.core.composition import *``.
__all__ = [
    "BaseCompose",
    "BboxParams",
    "Compose",
    "KeypointParams",
    "OneOf",
    "OneOrOther",
    "RandomOrder",
    "ReplayCompose",
    "SelectiveChannelTransform",
    "Sequential",
    "SomeOf",
]

NUM_ONEOF_TRANSFORMS = 2
# Number of spaces added per nesting level by ``BaseCompose.indented_repr``.
REPR_INDENT_STEP = 2

# A pipeline element is either a single transform or a nested composition.
TransformType = Union[BasicTransform, "BaseCompose"]
TransformsSeqType = list[TransformType]

# Keys that an empty ``Compose`` (no transforms) accepts.
AVAILABLE_KEYS = (
    "image",
    "mask",
    "masks",
    "bboxes",
    "keypoints",
    "volume",
    "volumes",
    "mask3d",
    "masks3d",
)

# Mask targets: single 2D, multiple 2D, single 3D, multiple 3D.
MASK_KEYS = ("mask", "masks", "mask3d", "masks3d")

# Keys related to image data.
IMAGE_KEYS = {"image", "images"}
# Keys related to volumetric data.
VOLUME_KEYS = {"volume", "volumes"}

# Key groups consulted when validating call arguments.
CHECKED_SINGLE = {"image", "mask"}
CHECKED_MULTI = {"masks", "images", "volumes", "masks3d"}
CHECK_BBOX_PARAM = {"bboxes"}
CHECK_KEYPOINTS_PARAM = {"keypoints"}
CHECKED_VOLUME = {"volume"}
CHECKED_VOLUMES = {"volumes"}
CHECKED_MASK3D = {"mask3d"}
CHECKED_MASKS3D = {"masks3d"}
73
+
74
+
75
class BaseCompose(Serializable):
    """Base class for composing multiple transforms together.

    This class serves as a foundation for creating compositions of transforms
    in the Albumentations library. It provides basic functionality for
    managing a sequence of transforms and applying them to data.

    The class supports dynamic pipeline modification after initialization using
    mathematical operators:
    - Addition (`+`): Add transforms to the end of the pipeline
    - Right addition (`__radd__`): Add transforms to the beginning of the pipeline
    - Subtraction (`-`): Remove the first transform of a given class from the pipeline

    Attributes:
        transforms (List[TransformType]): A list of transforms to be applied.
        p (float): Probability of applying the compose. Should be in the range [0, 1].
        replay_mode (bool): If True, the compose is in replay mode.
        seed (int | None): Seed last passed to `set_random_seed`.
        random_generator (np.random.Generator): NumPy generator used by the compose.
        py_random (random.Random): Python generator used by the compose.
        mask_interpolation (int | None): Interpolation flag pushed down to the transforms.
        save_applied_params (bool): Value of the `save_applied_params` constructor argument.
        _additional_targets (Dict[str, str]): Additional targets for transforms.
        _available_keys (Set[str]): Set of available keys for data.
        processors (Dict[str, Union[BboxProcessor, KeypointsProcessor]]): Processors for specific data types.

    Args:
        transforms (TransformsSeqType): A sequence of transforms to compose. A single
            transform is accepted with a warning and wrapped into a list.
        p (float): Probability of applying the compose.
        mask_interpolation (int | None): Interpolation flag forwarded to `set_mask_interpolation`.
        seed (int | None): Seed forwarded to `set_random_seed`.
        save_applied_params (bool): Stored on the instance as `save_applied_params`.
        **kwargs (Any): Accepted for signature compatibility; not used by this constructor.

    Raises:
        ValueError: If an invalid additional target is specified.

    Note:
        - Subclasses should implement the __call__ method to define how
          the composition is applied to data.
        - The class supports serialization and deserialization of transforms.
        - It provides methods for adding targets, setting deterministic behavior,
          and checking data validity post-transform.
        - All compose classes support pipeline modification operators:
          - `compose + transform` adds individual transform(s) to the end
          - `transform + compose` adds individual transform(s) to the beginning
          - `compose - TransformClass` removes the first transform whose exact type is that class
          - Only BasicTransform instances (not BaseCompose) can be added
        - All operator operations return new instances without modifying the original.
          The new instance shares the original's random generator objects.

    Examples:
        >>> import albumentations as A
        >>> # Create base pipeline
        >>> compose = A.Compose([A.HorizontalFlip(p=1.0)])
        >>>
        >>> # Add transforms using operators
        >>> extended = compose + A.VerticalFlip(p=1.0)  # Append
        >>> extended = compose + [A.Blur(), A.Rotate()]  # Append multiple
        >>> extended = A.RandomCrop(256, 256) + compose  # Prepend
        >>>
        >>> # Remove transforms by class
        >>> compose = A.Compose([A.HorizontalFlip(p=0.5), A.VerticalFlip(p=1.0)])
        >>> reduced = compose - A.HorizontalFlip  # Remove by class

    """

    _transforms_dict: dict[int, BasicTransform] | None = None
    check_each_transform: tuple[DataProcessor, ...] | None = None
    main_compose: bool = True

    def __init__(
        self,
        transforms: TransformsSeqType,
        p: float,
        mask_interpolation: int | None = None,
        seed: int | None = None,
        save_applied_params: bool = False,
        **kwargs: Any,
    ):
        if isinstance(transforms, (BaseCompose, BasicTransform)):
            warnings.warn(
                "transforms is single transform, but a sequence is expected! Transform will be wrapped into list.",
                stacklevel=2,
            )
            transforms = [transforms]

        self.transforms = transforms
        self.p = p

        self.replay_mode = False
        self._additional_targets: dict[str, str] = {}
        self._available_keys: set[str] = set()
        self.processors: dict[str, BboxProcessor | KeypointsProcessor] = {}
        # Order matters: keys, interpolation and seed all walk self.transforms.
        self._set_keys()
        self.set_mask_interpolation(mask_interpolation)
        self.set_random_seed(seed)
        self.save_applied_params = save_applied_params

    def _track_transform_params(self, transform: TransformType, data: dict[str, Any]) -> None:
        """Record a transform's params in ``data["applied_transforms"]`` if tracking is enabled.

        Nothing is recorded when the key is absent from ``data`` or when the
        transform has no non-empty ``params`` attribute.
        """
        if "applied_transforms" in data and hasattr(transform, "params") and transform.params:
            data["applied_transforms"].append((transform.__class__.__name__, transform.params.copy()))

    def set_random_state(
        self,
        random_generator: np.random.Generator,
        py_random: random.Random,
    ) -> None:
        """Set random state directly from generators.

        The same generator objects are stored on this compose and handed to
        every child transform, so they are shared rather than copied.

        Args:
            random_generator (np.random.Generator): numpy random generator to use
            py_random (random.Random): python random generator to use

        """
        self.random_generator = random_generator
        self.py_random = py_random

        # Propagate both random states to all transforms
        for transform in self.transforms:
            if isinstance(transform, (BasicTransform, BaseCompose)):
                transform.set_random_state(random_generator, py_random)

    def set_random_seed(self, seed: int | None) -> None:
        """Set random state from seed.

        Args:
            seed (int | None): Random seed to use

        """
        # Store the original seed
        self.seed = seed

        # Use base seed directly (subclasses like Compose can override this)
        self.random_generator = np.random.default_rng(seed)
        self.py_random = random.Random(seed)

        # Propagate seed to all transforms
        for transform in self.transforms:
            if isinstance(transform, (BasicTransform, BaseCompose)):
                transform.set_random_seed(seed)

    def set_mask_interpolation(self, mask_interpolation: int | None) -> None:
        """Set interpolation mode for mask resizing operations.

        Args:
            mask_interpolation (int | None): OpenCV interpolation flag to use for mask transforms.
                If None, default interpolation for masks will be used.

        """
        self.mask_interpolation = mask_interpolation
        self._set_mask_interpolation_recursive(self.transforms)

    def _set_mask_interpolation_recursive(self, transforms: TransformsSeqType) -> None:
        """Push ``self.mask_interpolation`` down into ``transforms``.

        Leaf transforms are only touched when they expose a ``mask_interpolation``
        attribute and the configured value is not None; nested composes always
        receive the value (including None) through ``set_mask_interpolation``.
        """
        for transform in transforms:
            if isinstance(transform, BasicTransform):
                if hasattr(transform, "mask_interpolation") and self.mask_interpolation is not None:
                    transform.mask_interpolation = self.mask_interpolation
            elif isinstance(transform, BaseCompose):
                transform.set_mask_interpolation(self.mask_interpolation)

    def __iter__(self) -> Iterator[TransformType]:
        return iter(self.transforms)

    def __len__(self) -> int:
        return len(self.transforms)

    def __call__(self, *args: Any, **data: Any) -> dict[str, Any]:
        """Apply transforms.

        Args:
            *args (Any): Positional arguments are not supported.
            **data (Any): Named parameters with data to transform.

        Returns:
            dict[str, Any]: Transformed data.

        Raises:
            NotImplementedError: This method must be implemented by subclasses.

        """
        raise NotImplementedError

    def __getitem__(self, item: int) -> TransformType:
        return self.transforms[item]

    def __repr__(self) -> str:
        return self.indented_repr()

    @property
    def additional_targets(self) -> dict[str, str]:
        """Get additional targets dictionary.

        Returns:
            dict[str, str]: Dictionary containing additional targets mapping.

        """
        return self._additional_targets

    @property
    def available_keys(self) -> set[str]:
        """Get set of available keys.

        Returns:
            set[str]: Set of string keys available for transforms.

        """
        return self._available_keys

    def indented_repr(self, indent: int = REPR_INDENT_STEP) -> str:
        """Get an indented string representation of the composition.

        Args:
            indent (int): Indentation level. Default: REPR_INDENT_STEP.

        Returns:
            str: Formatted string representation with proper indentation.

        """
        # Everything from the serialized form except dunder keys and the transform list
        # is rendered as trailing keyword arguments.
        args = {k: v for k, v in self.to_dict_private().items() if not (k.startswith("__") or k == "transforms")}
        pad = " " * indent
        lines = [self.__class__.__name__ + "(["]
        for t in self.transforms:
            t_repr = t.indented_repr(indent + REPR_INDENT_STEP) if hasattr(t, "indented_repr") else repr(t)
            lines.append(pad + t_repr + ",")
        lines.append(" " * (indent - REPR_INDENT_STEP) + f"], {format_args(args)})")
        # One join instead of repeated string concatenation; output is unchanged.
        return "\n".join(lines)

    @classmethod
    def get_class_fullname(cls) -> str:
        """Get the full qualified name of the class.

        Returns:
            str: The shortest class fullname.

        """
        return get_shortest_class_fullname(cls)

    @classmethod
    def is_serializable(cls) -> bool:
        """Check if the class is serializable.

        Returns:
            bool: True if the class is serializable, False otherwise.

        """
        return True

    def to_dict_private(self) -> dict[str, Any]:
        """Convert the composition to a dictionary for serialization.

        Returns:
            dict[str, Any]: Dictionary representation of the composition.

        """
        return {
            "__class_fullname__": self.get_class_fullname(),
            "p": self.p,
            "transforms": [t.to_dict_private() for t in self.transforms],
        }

    def get_dict_with_id(self) -> dict[str, Any]:
        """Get a dictionary representation with object IDs for replay mode.

        Returns:
            dict[str, Any]: Dictionary with composition data and object IDs.

        """
        return {
            "__class_fullname__": self.get_class_fullname(),
            "id": id(self),
            "params": None,
            "transforms": [t.get_dict_with_id() for t in self.transforms],
        }

    def add_targets(self, additional_targets: dict[str, str] | None) -> None:
        """Add additional targets to all transforms.

        Args:
            additional_targets (dict[str, str] | None): Dict of name -> type mapping for additional targets.
                If None, no additional targets will be added.

        Raises:
            ValueError: If a name is already registered with a different type.

        """
        if additional_targets:
            # Validate every entry before mutating anything.
            for k, v in additional_targets.items():
                if k in self._additional_targets and v != self._additional_targets[k]:
                    raise ValueError(
                        f"Trying to overwrite existed additional targets. "
                        f"Key={k} Exists={self._additional_targets[k]} New value: {v}",
                    )
            self._additional_targets.update(additional_targets)
            for t in self.transforms:
                t.add_targets(additional_targets)
            for proc in self.processors.values():
                proc.add_targets(additional_targets)
        # Keys are refreshed even when nothing was added.
        self._set_keys()

    def _set_keys(self) -> None:
        """Extend ``_available_keys`` with keys from targets, transforms and processors."""
        self._available_keys.update(self._additional_targets.keys())
        for t in self.transforms:
            self._available_keys.update(t.available_keys)
            if hasattr(t, "targets_as_params"):
                self._available_keys.update(t.targets_as_params)
        if self.processors:
            self._available_keys.add("labels")
            for proc in self.processors.values():
                if proc.default_data_name not in self._available_keys:  # if no transform to process this data
                    warnings.warn(
                        f"Got processor for {proc.default_data_name}, but no transform to process it.",
                        stacklevel=2,
                    )
                self._available_keys.update(proc.data_fields)
                if proc.params.label_fields:
                    self._available_keys.update(proc.params.label_fields)

    def set_deterministic(self, flag: bool, save_key: str = "replay") -> None:
        """Set deterministic mode for all transforms.

        Args:
            flag (bool): Whether to enable deterministic mode.
            save_key (str): Key to save replay parameters. Default: "replay".

        """
        for t in self.transforms:
            t.set_deterministic(flag, save_key)

    def check_data_post_transform(self, data: dict[str, Any]) -> dict[str, Any]:
        """Check and filter data after transformation.

        Does nothing unless ``check_each_transform`` holds at least one processor.

        Args:
            data (dict[str, Any]): Dictionary containing transformed data

        Returns:
            dict[str, Any]: Filtered data dictionary (the same object, updated in place)

        """
        if self.check_each_transform:
            shape = get_shape(data)

            for proc in self.check_each_transform:
                # Only existing keys are reassigned, so iterating items() here is safe.
                for data_name, data_value in data.items():
                    if data_name in proc.data_fields or (
                        data_name in self._additional_targets
                        and self._additional_targets[data_name] in proc.data_fields
                    ):
                        data[data_name] = proc.filter(data_value, shape)
        return data

    def _validate_transforms(self, transforms: list[Any]) -> None:
        """Validate that all elements are BasicTransform instances.

        Args:
            transforms: List of objects to validate

        Raises:
            TypeError: If any element is not a BasicTransform instance

        """
        for t in transforms:
            if not isinstance(t, BasicTransform):
                raise TypeError(
                    f"All elements must be instances of BasicTransform, got {type(t).__name__}",
                )

    def _combine_transforms(self, other: TransformType | TransformsSeqType, *, prepend: bool = False) -> BaseCompose:
        """Combine transforms with the current compose.

        Args:
            other: Transform or sequence of transforms to combine
            prepend: If True, prepend other to the beginning; if False, append to the end

        Returns:
            BaseCompose: New compose instance with combined transforms

        Raises:
            TypeError: If other is not a valid transform or sequence of transforms

        """
        # A list/tuple is treated as several transforms; anything else as a single one.
        other_list = list(other) if isinstance(other, (list, tuple)) else [other]
        self._validate_transforms(other_list)

        new_transforms = [*other_list, *self.transforms] if prepend else [*self.transforms, *other_list]

        return self._create_new_instance(new_transforms)

    def __add__(self, other: TransformType | TransformsSeqType) -> BaseCompose:
        """Add transform(s) to the end of this compose.

        Args:
            other: Transform or sequence of transforms to append

        Returns:
            BaseCompose: New compose instance with transforms appended

        Raises:
            TypeError: If other is not a valid transform or sequence of transforms

        Examples:
            >>> new_compose = compose + A.HorizontalFlip()
            >>> new_compose = compose + [A.HorizontalFlip(), A.VerticalFlip()]

        """
        return self._combine_transforms(other, prepend=False)

    def __radd__(self, other: TransformType | TransformsSeqType) -> BaseCompose:
        """Add transform(s) to the beginning of this compose.

        Args:
            other: Transform or sequence of transforms to prepend

        Returns:
            BaseCompose: New compose instance with transforms prepended

        Raises:
            TypeError: If other is not a valid transform or sequence of transforms

        Examples:
            >>> new_compose = A.HorizontalFlip() + compose
            >>> new_compose = [A.HorizontalFlip(), A.VerticalFlip()] + compose

        """
        return self._combine_transforms(other, prepend=True)

    def __sub__(self, other: type[BasicTransform]) -> BaseCompose:
        """Remove transform from this compose by class type.

        Removes the first transform in the compose whose exact type is the provided
        transform class (subclasses of that class do not match).

        Args:
            other: Transform class to remove (e.g., A.HorizontalFlip)

        Returns:
            BaseCompose: New compose instance with transform removed

        Raises:
            TypeError: If other is not a BasicTransform class
            ValueError: If no transform of that type is found in the compose

        Note:
            If multiple transforms of the same type exist in the compose,
            only the first occurrence will be removed.

        Examples:
            >>> # Remove by transform class
            >>> new_compose = compose - A.HorizontalFlip
            >>>
            >>> # With duplicates - only first occurrence removed
            >>> compose = A.Compose([A.HorizontalFlip(p=0.5), A.VerticalFlip(), A.HorizontalFlip(p=1.0)])
            >>> result = compose - A.HorizontalFlip  # Removes first HorizontalFlip (p=0.5)
            >>> len(result.transforms)  # 2 (VerticalFlip and second HorizontalFlip remain)

        """
        # Validate that other is a BasicTransform class
        if not (isinstance(other, type) and issubclass(other, BasicTransform)):
            raise TypeError(
                f"Can only remove BasicTransform classes, got {type(other).__name__}",
            )

        # Find first transform of matching class and rebuild the list without it
        for index, transform in enumerate(self.transforms):
            if type(transform) is other:
                new_transforms = list(self.transforms)
                del new_transforms[index]
                return self._create_new_instance(new_transforms)

        # No matching transform found
        class_name = other.__name__
        raise ValueError(f"No transform of type {class_name} found in the compose pipeline")

    def _create_new_instance(self, new_transforms: TransformsSeqType) -> BaseCompose:
        """Create a new instance of the same class with new transforms.

        Args:
            new_transforms: List of transforms for the new instance

        Returns:
            BaseCompose: New instance of the same class

        """
        # Get current instance parameters
        init_params = self._get_init_params()
        init_params["transforms"] = new_transforms

        # Create new instance
        new_instance = self.__class__(**init_params)

        # Hand the original's generators to the new instance. set_random_state stores
        # the same objects, so both instances draw from one shared random stream.
        if hasattr(self, "random_generator") and hasattr(self, "py_random"):
            new_instance.set_random_state(self.random_generator, self.py_random)

        return new_instance

    def _get_init_params(self) -> dict[str, Any]:
        """Get parameters needed to recreate this instance.

        Note:
            Subclasses that add new initialization parameters (other than 'transforms',
            which is set separately in _create_new_instance) should override this method
            to include those parameters in the returned dictionary.

        Returns:
            dict[str, Any]: Dictionary of initialization parameters

        """
        return {
            "p": self.p,
        }

    def _get_effective_seed(self, base_seed: int | None) -> int | None:
        """Get effective seed considering worker context.

        Args:
            base_seed (int | None): Base seed value

        Returns:
            int | None: ``base_seed`` unchanged when it is None, when PyTorch cannot be
            used, or when not inside a DataLoader worker; otherwise
            ``(base_seed + torch.initial_seed()) % 2**32``.

        """
        if base_seed is None:
            return base_seed

        try:
            import torch
            import torch.utils.data

            worker_info = torch.utils.data.get_worker_info()
            if worker_info is not None:
                # We're in a DataLoader worker process
                # Use torch.initial_seed() which is unique per worker and changes on respawn
                torch_seed = torch.initial_seed() % (2**32)
                return (base_seed + torch_seed) % (2**32)
        except (ImportError, AttributeError):
            # PyTorch not available or not in worker context
            pass

        return base_seed
607
+
608
+
609
+ class Compose(BaseCompose, HubMixin):
610
+ """Compose multiple transforms together and apply them sequentially to input data.
611
+
612
+ This class allows you to chain multiple image augmentation transforms and apply them
613
+ in a specified order. It also handles bounding box and keypoint transformations if
614
+ the appropriate parameters are provided.
615
+
616
+ The Compose class supports dynamic pipeline modification after initialization using
617
+ mathematical operators. All parameters (bbox_params, keypoint_params, additional_targets,
618
+ etc.) are preserved when using operators to modify the pipeline.
619
+
620
+ Args:
621
+ transforms (list[BasicTransform | BaseCompose]): A list of transforms to apply.
622
+ bbox_params (dict[str, Any] | BboxParams | None): Parameters for bounding box transforms.
623
+ Can be a dict of params or a BboxParams object. Default is None.
624
+ keypoint_params (dict[str, Any] | KeypointParams | None): Parameters for keypoint transforms.
625
+ Can be a dict of params or a KeypointParams object. Default is None.
626
+ additional_targets (dict[str, str] | None): A dictionary mapping additional target names
627
+ to their types. For example, {'image2': 'image'}. Default is None.
628
+ p (float): Probability of applying all transforms. Should be in range [0, 1]. Default is 1.0.
629
+ is_check_shapes (bool): If True, checks consistency of shapes for image/mask/masks on each call.
630
+ Disable only if you are sure about your data consistency. Default is True.
631
+ strict (bool): If True, enables strict mode which:
632
+ 1. Validates that all input keys are known/expected
633
+ 2. Validates that no transforms have invalid arguments
634
+ 3. Raises ValueError if any validation fails
635
+ If False, these validations are skipped. Default is False.
636
+ mask_interpolation (int | None): Interpolation method for mask transforms. When defined,
637
+ it overrides the interpolation method specified in individual transforms. Default is None.
638
+ seed (int | None): Controls reproducibility of random augmentations. Compose uses
639
+ its own internal random state, completely independent from global random seeds.
640
+
641
+ When seed is set (int):
642
+ - Creates a fixed internal random state
643
+ - Two Compose instances with the same seed and transforms will produce identical
644
+ sequences of augmentations
645
+ - Each call to the same Compose instance still produces random augmentations,
646
+ but these sequences are reproducible between different Compose instances
647
+ - Example: transform1 = A.Compose([...], seed=137) and
648
+ transform2 = A.Compose([...], seed=137) will produce identical sequences
649
+
650
+ When seed is None (default):
651
+ - Generates a new internal random state on each Compose creation
652
+ - Different Compose instances will produce different sequences of augmentations
653
+ - Example: transform = A.Compose([...]) # random results
654
+
655
+ Important: Setting random seeds outside of Compose (like np.random.seed() or
656
+ random.seed()) has no effect on augmentations as Compose uses its own internal
657
+ random state.
658
+ save_applied_params (bool): If True, saves the applied parameters of each transform. Default is False.
659
+ You will need to use the `applied_transforms` key in the output dictionary to access the parameters.
660
+
661
+ Examples:
662
+ >>> # Basic usage:
663
+ >>> import albumentations as A
664
+ >>> transform = A.Compose([
665
+ ... A.RandomCrop(width=256, height=256),
666
+ ... A.HorizontalFlip(p=0.5),
667
+ ... A.RandomBrightnessContrast(p=0.2),
668
+ ... ], seed=137)
669
+ >>> transformed = transform(image=image)
670
+
671
+ >>> # Pipeline modification after initialization:
672
+ >>> # Create initial pipeline with bbox support
673
+ >>> base_transform = A.Compose([
674
+ ... A.HorizontalFlip(p=0.5),
675
+ ... A.RandomCrop(width=512, height=512)
676
+ ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']))
677
+ >>>
678
+ >>> # Add transforms using operators (bbox_params preserved)
679
+ >>> extended = base_transform + A.RandomBrightnessContrast(p=0.3)
680
+ >>> extended = base_transform + [A.Blur(), A.GaussNoise()]
681
+ >>> extended = A.Resize(height=1024, width=1024) + base_transform
682
+ >>>
683
+ >>> # Remove transforms by class
684
+ >>> pipeline = A.Compose([A.HorizontalFlip(p=0.5), A.VerticalFlip(), A.Rotate()])
685
+ >>> without_flip = pipeline - A.HorizontalFlip # Remove by class
686
+
687
+ Note:
688
+ - The class checks the validity of input data and shapes if is_check_args and is_check_shapes are True.
689
+ - When bbox_params or keypoint_params are provided, it sets up the corresponding processors.
690
+ - The transform can handle additional targets specified in the additional_targets dictionary.
691
+ - When strict mode is enabled, it performs additional validation to ensure data and transform
692
+ configuration correctness.
693
+ - Pipeline modification operators (+, -, __radd__) preserve all Compose parameters including
694
+ bbox_params, keypoint_params, additional_targets, and other configuration settings.
695
+ - All operators return new Compose instances without modifying the original pipeline.
696
+
697
+ """
698
+
699
def __init__(
    self,
    transforms: TransformsSeqType,
    bbox_params: dict[str, Any] | BboxParams | None = None,
    keypoint_params: dict[str, Any] | KeypointParams | None = None,
    additional_targets: dict[str, str] | None = None,
    p: float = 1.0,
    is_check_shapes: bool = True,
    strict: bool = False,
    mask_interpolation: int | None = None,
    seed: int | None = None,
    save_applied_params: bool = False,
):
    """Build the pipeline, its bbox/keypoint processors and its validation flags.

    Raises:
        ValueError: If ``bbox_params`` or ``keypoint_params`` is truthy but is neither
            a dict nor the matching params object.

    """
    # Keep the caller's seed; the base class receives the worker-adjusted one.
    self._base_seed = seed

    super().__init__(
        transforms=transforms,
        p=p,
        mask_interpolation=mask_interpolation,
        seed=self._get_effective_seed(seed),
        save_applied_params=save_applied_params,
    )

    if bbox_params:
        if isinstance(bbox_params, dict):
            bbox_params = BboxParams(**bbox_params)
        elif not isinstance(bbox_params, BboxParams):
            raise ValueError("unknown format of bbox_params, please use `dict` or `BboxParams`")
        self.processors["bboxes"] = BboxProcessor(bbox_params)

    if keypoint_params:
        if isinstance(keypoint_params, dict):
            keypoint_params = KeypointParams(**keypoint_params)
        elif not isinstance(keypoint_params, KeypointParams):
            raise ValueError("unknown format of keypoint_params, please use `dict` or `KeypointParams`")
        self.processors["keypoints"] = KeypointsProcessor(keypoint_params)

    for processor in self.processors.values():
        processor.ensure_transforms_valid(self.transforms)

    self.add_targets(additional_targets)
    if not self.transforms:
        # An empty pipeline is a no-op, so every standard key is accepted.
        self._available_keys.update(AVAILABLE_KEYS)

    self.is_check_args = True
    self.strict = strict

    self.is_check_shapes = is_check_shapes
    # Processors whose params ask to be re-run after every transform.
    self.check_each_transform = tuple(
        [
            processor
            for processor in self.processors.values()
            if getattr(processor.params, "check_each_transform", False)
        ],
    )
    self._set_check_args_for_transforms(self.transforms)

    self._set_processors_for_transforms(self.transforms)

    self.save_applied_params = save_applied_params
    self._images_was_list = False
    self._masks_was_list = False
    self._last_torch_seed: int | None = None
768
+
769
+ @property
770
+ def strict(self) -> bool:
771
+ """Get the current strict mode setting.
772
+
773
+ Returns:
774
+ bool: True if strict mode is enabled, False otherwise.
775
+
776
+ """
777
+ return self._strict
778
+
779
+ @strict.setter
780
+ def strict(self, value: bool) -> None:
781
+ # if value and not self._strict:
782
+ if value:
783
+ # Only validate when enabling strict mode
784
+ self._validate_strict()
785
+ self._strict = value
786
+
787
+ def _validate_strict(self) -> None:
788
+ """Validate that no transforms have invalid arguments when strict mode is enabled."""
789
+
790
+ def check_transform(transform: TransformType) -> None:
791
+ if hasattr(transform, "invalid_args") and transform.invalid_args:
792
+ message = (
793
+ f"Argument(s) '{', '.join(transform.invalid_args)}' "
794
+ f"are not valid for transform {transform.__class__.__name__}"
795
+ )
796
+ raise ValueError(message)
797
+ if isinstance(transform, BaseCompose):
798
+ for t in transform.transforms:
799
+ check_transform(t)
800
+
801
+ for transform in self.transforms:
802
+ check_transform(transform)
803
+
804
def _set_processors_for_transforms(self, transforms: TransformsSeqType) -> None:
    """Hand this pipeline's processors to every basic transform, descending into nested composes.

    Args:
        transforms (TransformsSeqType): Transforms to walk.

    """
    for item in transforms:
        if isinstance(item, BasicTransform):
            # Only transforms that expose `set_processors` receive the processors.
            if hasattr(item, "set_processors"):
                item.set_processors(self.processors)
            continue
        if isinstance(item, BaseCompose):
            self._set_processors_for_transforms(item.transforms)
811
+
812
def _set_check_args_for_transforms(self, transforms: TransformsSeqType) -> None:
    """Share this pipeline's processors with nested composes and demote nested `Compose` objects.

    Args:
        transforms (TransformsSeqType): Transforms to walk.

    """
    for item in transforms:
        if isinstance(item, BaseCompose):
            # Children first, then the nested compose itself.
            self._set_check_args_for_transforms(item.transforms)
            item.check_each_transform = self.check_each_transform
            item.processors = self.processors
        if isinstance(item, Compose):
            item.disable_check_args_private()
820
+
821
+ def disable_check_args_private(self) -> None:
822
+ """Disable argument checking for transforms.
823
+
824
+ This method disables strict mode and argument checking for all transforms in the composition.
825
+ """
826
+ self.is_check_args = False
827
+ self.strict = False
828
+ self.main_compose = False
829
+
830
+ def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
831
+ """Apply transformations to data with automatic worker seed synchronization.
832
+
833
+ Args:
834
+ *args (Any): Positional arguments are not supported.
835
+ force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
836
+ **data (Any): Dict with data to transform.
837
+
838
+ Returns:
839
+ dict[str, Any]: Dictionary with transformed data.
840
+
841
+ Raises:
842
+ KeyError: If positional arguments are provided.
843
+
844
+ """
845
+ # Check and sync worker seed if needed
846
+ self._check_worker_seed()
847
+
848
+ if args:
849
+ msg = "You have to pass data to augmentations as named arguments, for example: aug(image=image)"
850
+ raise KeyError(msg)
851
+
852
+ # Initialize applied_transforms only in top-level Compose if requested
853
+ if self.save_applied_params and self.main_compose:
854
+ data["applied_transforms"] = []
855
+
856
+ need_to_run = force_apply or self.py_random.random() < self.p
857
+ if not need_to_run:
858
+ return data
859
+
860
+ self.preprocess(data)
861
+
862
+ for t in self.transforms:
863
+ data = t(**data)
864
+ self._track_transform_params(t, data)
865
+ data = self.check_data_post_transform(data)
866
+
867
+ return self.postprocess(data)
868
+
869
+ def _check_worker_seed(self) -> None:
870
+ """Check and update random seed if in worker context."""
871
+ if not hasattr(self, "_base_seed") or self._base_seed is None:
872
+ return
873
+
874
+ # Check if we're in a worker and need to update the seed
875
+ try:
876
+ import torch
877
+ import torch.utils.data
878
+
879
+ worker_info = torch.utils.data.get_worker_info()
880
+ if worker_info is not None:
881
+ # Get the current torch initial seed
882
+ current_torch_seed = torch.initial_seed()
883
+
884
+ # Check if we've already synchronized for this seed
885
+ if hasattr(self, "_last_torch_seed") and self._last_torch_seed == current_torch_seed:
886
+ return
887
+
888
+ # Update the seed and mark as synchronized
889
+ self._last_torch_seed = current_torch_seed
890
+ effective_seed = self._get_effective_seed(self._base_seed)
891
+
892
+ # Update our own random state
893
+ self.random_generator = np.random.default_rng(effective_seed)
894
+ self.py_random = random.Random(effective_seed)
895
+
896
+ # Propagate to all transforms
897
+ for transform in self.transforms:
898
+ if hasattr(transform, "set_random_state"):
899
+ transform.set_random_state(self.random_generator, self.py_random)
900
+ elif hasattr(transform, "set_random_seed"):
901
+ # For transforms that don't have set_random_state, use set_random_seed
902
+ transform.set_random_seed(effective_seed)
903
+ except (ImportError, AttributeError):
904
+ pass
905
+
906
+ def __setstate__(self, state: dict[str, Any]) -> None:
907
+ """Set state from unpickling and handle worker seed."""
908
+ self.__dict__.update(state)
909
+ # If we have a base seed, recalculate effective seed in worker context
910
+ if hasattr(self, "_base_seed") and self._base_seed is not None:
911
+ # Reset _last_torch_seed to ensure worker-seed sync runs after unpickling
912
+ self._last_torch_seed = None
913
+ # Recalculate effective seed in worker context
914
+ self.set_random_seed(self._base_seed)
915
+ elif hasattr(self, "seed") and self.seed is not None:
916
+ # For backward compatibility, if no base seed but seed exists
917
+ self._base_seed = self.seed
918
+ self._last_torch_seed = None
919
+ self.set_random_seed(self.seed)
920
+
921
def set_random_seed(self, seed: int | None) -> None:
    """Store `seed` and rebuild the random generators from the derived effective seed.

    Args:
        seed (int | None): Seed to store as both `_base_seed` and `seed`.

    """
    self._base_seed = seed
    self.seed = seed

    # The generators use the value derived by `_get_effective_seed`, not the raw seed.
    derived_seed = self._get_effective_seed(seed)
    self.random_generator = np.random.default_rng(derived_seed)
    self.py_random = random.Random(derived_seed)

    for child in self.transforms:
        if hasattr(child, "set_random_state"):
            child.set_random_state(self.random_generator, self.py_random)
        elif hasattr(child, "set_random_seed"):
            # Fallback for children that only accept a seed.
            child.set_random_seed(derived_seed)
946
+
947
def preprocess(self, data: Any) -> None:
    """Validate the inputs and prepare them for the transforms.

    Shape consistency is checked whenever `is_check_shapes` is set, independent of strict
    mode. Key and argument validation runs only in strict mode. Processors and array
    conversion run afterwards.

    Args:
        data (Any): Mapping of target name to value; modified in place by the later steps.

    """
    if self.is_check_shapes:
        hw_shapes = []  # (H, W) of every checked target
        dhw_shapes = []  # (D, H, W) of volumetric targets
        single_volume_names = CHECKED_VOLUME | CHECKED_MASK3D

        for name, value in data.items():
            # None values are skipped entirely.
            if value is None:
                continue

            internal = self._additional_targets.get(name, name)
            shape = self._get_data_shape(name, internal, value)
            if shape is None:
                continue

            if internal in single_volume_names:
                hw_shapes.append(shape[1:3])  # H,W from (D,H,W)
                dhw_shapes.append(shape[:3])  # D,H,W
            elif internal in {"volumes", "masks3d"}:
                hw_shapes.append(shape[2:4])  # H,W from (N,D,H,W)
                dhw_shapes.append(shape[1:4])  # D,H,W from (N,D,H,W)
            else:
                hw_shapes.append(shape[:2])  # H,W

        self._check_shape_consistency(hw_shapes, dhw_shapes)

    if self.strict:
        self._validate_data(data)

    self._preprocess_processors(data)
    self._preprocess_arrays(data)
980
+
981
+ def _validate_data(self, data: dict[str, Any]) -> None:
982
+ """Validate input data keys and arguments."""
983
+ if not self.strict:
984
+ return
985
+
986
+ for data_name in data:
987
+ if not self._is_valid_key(data_name):
988
+ raise ValueError(f"Key {data_name} is not in available keys.")
989
+
990
+ if self.is_check_args:
991
+ self._check_args(**data)
992
+
993
+ def _is_valid_key(self, key: str) -> bool:
994
+ """Check if the key is valid for processing."""
995
+ return key in self._available_keys or key in MASK_KEYS or key in IMAGE_KEYS or key == "applied_transforms"
996
+
997
+ def _preprocess_processors(self, data: dict[str, Any]) -> None:
998
+ """Run preprocessors if this is the main compose."""
999
+ if not self.main_compose:
1000
+ return
1001
+
1002
+ for processor in self.processors.values():
1003
+ processor.ensure_data_valid(data)
1004
+ for processor in self.processors.values():
1005
+ processor.preprocess(data)
1006
+
1007
+ def _preprocess_arrays(self, data: dict[str, Any]) -> None:
1008
+ """Convert lists to numpy arrays for images and masks, and ensure contiguity."""
1009
+ self._preprocess_images(data)
1010
+ self._preprocess_masks(data)
1011
+
1012
+ def _preprocess_images(self, data: dict[str, Any]) -> None:
1013
+ """Convert image lists to numpy arrays."""
1014
+ if "images" not in data:
1015
+ return
1016
+
1017
+ if isinstance(data["images"], (list, tuple)):
1018
+ self._images_was_list = True
1019
+ # Skip stacking for empty lists
1020
+ if not data["images"]:
1021
+ return
1022
+ data["images"] = np.stack(data["images"])
1023
+ else:
1024
+ self._images_was_list = False
1025
+
1026
+ def _preprocess_masks(self, data: dict[str, Any]) -> None:
1027
+ """Convert mask lists to numpy arrays."""
1028
+ if "masks" not in data:
1029
+ return
1030
+
1031
+ if isinstance(data["masks"], (list, tuple)):
1032
+ self._masks_was_list = True
1033
+ # Skip stacking for empty lists
1034
+ if not data["masks"]:
1035
+ return
1036
+ data["masks"] = np.stack(data["masks"])
1037
+ else:
1038
+ self._masks_was_list = False
1039
+
1040
+ def postprocess(self, data: dict[str, Any]) -> dict[str, Any]:
1041
+ """Apply post-processing to data after all transforms have been applied.
1042
+
1043
+ Args:
1044
+ data (dict[str, Any]): Data after transformation.
1045
+
1046
+ Returns:
1047
+ dict[str, Any]: Post-processed data.
1048
+
1049
+ """
1050
+ if self.main_compose:
1051
+ for p in self.processors.values():
1052
+ p.postprocess(data)
1053
+
1054
+ # Convert back to list if original input was a list
1055
+ if "images" in data and self._images_was_list:
1056
+ data["images"] = list(data["images"])
1057
+
1058
+ if "masks" in data and self._masks_was_list:
1059
+ data["masks"] = list(data["masks"])
1060
+
1061
+ return data
1062
+
1063
def to_dict_private(self) -> dict[str, Any]:
    """Serialize the pipeline on top of the base-class dictionary.

    Returns:
        dict[str, Any]: The base dictionary plus `bbox_params`, `keypoint_params`,
            `additional_targets`, `is_check_shapes` and `seed` (the stored base seed).

    """
    serialized = super().to_dict_private()
    bbox_proc = self.processors.get("bboxes")
    keypoint_proc = self.processors.get("keypoints")

    serialized["bbox_params"] = bbox_proc.params.to_dict_private() if bbox_proc else None
    serialized["keypoint_params"] = keypoint_proc.params.to_dict_private() if keypoint_proc else None
    serialized["additional_targets"] = self.additional_targets
    serialized["is_check_shapes"] = self.is_check_shapes
    serialized["seed"] = getattr(self, "_base_seed", None)
    return serialized
1083
+
1084
def get_dict_with_id(self) -> dict[str, Any]:
    """Extend the base-class id dictionary with this pipeline's configuration.

    Returns:
        dict[str, Any]: The base dictionary plus `bbox_params`, `keypoint_params`,
            `additional_targets`, `params` (always None) and `is_check_shapes`.

    """
    described = super().get_dict_with_id()
    bbox_proc = self.processors.get("bboxes")
    keypoint_proc = self.processors.get("keypoints")

    described["bbox_params"] = bbox_proc.params.to_dict_private() if bbox_proc else None
    described["keypoint_params"] = keypoint_proc.params.to_dict_private() if keypoint_proc else None
    described["additional_targets"] = self.additional_targets
    described["params"] = None
    described["is_check_shapes"] = self.is_check_shapes
    return described
1104
+
1105
+ @staticmethod
1106
+ def _check_single_data(data_name: str, data: Any) -> tuple[int, int]:
1107
+ if not isinstance(data, np.ndarray):
1108
+ raise TypeError(f"{data_name} must be numpy array type")
1109
+ return data.shape[:2]
1110
+
1111
+ @staticmethod
1112
+ def _check_masks_data(data_name: str, data: Any) -> tuple[int, int] | None:
1113
+ """Check masks data format and return shape.
1114
+
1115
+ Args:
1116
+ data_name (str): Name of the data field being checked
1117
+ data (Any): Input data in one of these formats:
1118
+ - List of numpy arrays, each of shape (H, W) or (H, W, C)
1119
+ - Numpy array of shape (N, H, W) or (N, H, W, C)
1120
+ - Empty list for cases where no masks are present
1121
+
1122
+ Returns:
1123
+ tuple[int, int] | None: (height, width) of the first mask, or None if masks list is empty
1124
+ Raises:
1125
+ TypeError: If data format is invalid
1126
+
1127
+ """
1128
+ if isinstance(data, np.ndarray):
1129
+ if data.ndim not in [3, 4]: # (N,H,W) or (N,H,W,C)
1130
+ raise TypeError(f"{data_name} as numpy array must be 3D or 4D")
1131
+ return data.shape[1:3] # Return (H,W)
1132
+
1133
+ if isinstance(data, (list, tuple)):
1134
+ if not data:
1135
+ # Allow empty list/tuple of masks
1136
+ return None
1137
+ if not all(isinstance(m, np.ndarray) for m in data):
1138
+ raise TypeError(f"All elements in {data_name} must be numpy arrays")
1139
+ if any(m.ndim not in {2, 3} for m in data):
1140
+ raise TypeError(f"All masks in {data_name} must be 2D or 3D numpy arrays")
1141
+ return data[0].shape[:2]
1142
+
1143
+ raise TypeError(f"{data_name} must be either a numpy array or a sequence of numpy arrays")
1144
+
1145
+ @staticmethod
1146
+ def _check_multi_data(data_name: str, data: Any) -> tuple[int, int]:
1147
+ """Check multi-image data format and return shape.
1148
+
1149
+ Args:
1150
+ data_name (str): Name of the data field being checked
1151
+ data (Any): Input data in one of these formats:
1152
+ - List-like of numpy arrays
1153
+ - Numpy array of shape (N, H, W, C) or (N, H, W)
1154
+
1155
+ Returns:
1156
+ tuple[int, int]: (height, width) of the first image
1157
+ Raises:
1158
+ TypeError: If data format is invalid
1159
+
1160
+ """
1161
+ if isinstance(data, np.ndarray):
1162
+ if data.ndim not in {3, 4}: # (N,H,W) or (N,H,W,C)
1163
+ raise TypeError(f"{data_name} as numpy array must be 3D or 4D")
1164
+ return data.shape[1:3] # Return (H,W)
1165
+
1166
+ if not isinstance(data, Sequence) or not isinstance(data[0], np.ndarray):
1167
+ raise TypeError(f"{data_name} must be either a numpy array or a list of numpy arrays")
1168
+ return data[0].shape[:2]
1169
+
1170
@staticmethod
def _check_bbox_keypoint_params(internal_data_name: str, processors: dict[str, Any]) -> None:
    """Ensure the processor required by a bbox or keypoint target has been configured.

    Args:
        internal_data_name (str): Internal target name to look up.
        processors (dict[str, Any]): Configured processors keyed by "bboxes" / "keypoints".

    Raises:
        ValueError: If the target needs a processor that is missing from `processors`.

    """
    requirements = (
        (CHECK_BBOX_PARAM, "bboxes", "bbox_params must be specified for bbox transformations"),
        (CHECK_KEYPOINTS_PARAM, "keypoints", "keypoints_params must be specified for keypoint transformations"),
    )
    for target_names, processor_key, message in requirements:
        if internal_data_name in target_names and processors.get(processor_key) is None:
            raise ValueError(message)
1176
+
1177
+ @staticmethod
1178
+ def _check_shapes(shapes: list[tuple[int, ...]], is_check_shapes: bool) -> None:
1179
+ if is_check_shapes and shapes and shapes.count(shapes[0]) != len(shapes):
1180
+ raise ValueError(
1181
+ "Height and Width of image, mask or masks should be equal. You can disable shapes check "
1182
+ "by setting a parameter is_check_shapes=False of Compose class (do it only if you are sure "
1183
+ "about your data consistency).",
1184
+ )
1185
+
1186
def _check_args(self, **kwargs: Any) -> None:
    """Validate the type and shape consistency of every supplied target.

    Args:
        **kwargs (Any): Targets as passed to the pipeline.

    Raises:
        TypeError: If a single-item target (including a None value) is not a numpy array, or
            a delegated shape check rejects the data.
        ValueError: If a required processor is missing or shapes are inconsistent.

    """
    hw_shapes = []  # (H, W) of every checked target
    dhw_shapes = []  # (D, H, W) of volumetric targets

    for name, value in kwargs.items():
        internal = self._additional_targets.get(name, name)

        # Single-item targets are validated even when the value is None.
        if internal in CHECKED_SINGLE:
            if not isinstance(value, np.ndarray):
                raise TypeError(f"{name} must be numpy array type")
            hw_shapes.append(value.shape[:2])
            continue

        # Everything else is only inspected when it is an array or a list.
        if value is None or not isinstance(value, (np.ndarray, list)):
            continue

        self._check_bbox_keypoint_params(internal, self.processors)

        shape = self._get_data_shape(name, internal, value)
        if shape is None:
            continue

        if internal in CHECKED_VOLUME | CHECKED_MASK3D:
            hw_shapes.append(shape[1:3])  # H,W from (D,H,W)
            dhw_shapes.append(shape[:3])  # D,H,W
        elif internal in {"volumes", "masks3d"}:
            hw_shapes.append(shape[2:4])  # H,W from (N,D,H,W)
            dhw_shapes.append(shape[1:4])  # D,H,W from (N,D,H,W)
        else:
            hw_shapes.append(shape[:2])  # H,W

    self._check_shape_consistency(hw_shapes, dhw_shapes)
1223
+
1224
def _get_data_shape(self, data_name: str, internal_name: str, data: Any) -> tuple[int, ...] | None:
    """Dispatch to the shape helper that matches the target's internal name.

    Args:
        data_name (str): Name the caller used, for error messages.
        internal_name (str): Internal target name used for dispatch.
        data (Any): The target value.

    Returns:
        tuple[int, ...] | None: The shape reported by the helper, or None when the internal
            name belongs to none of the checked groups.

    """
    if internal_name in CHECKED_SINGLE:  # single images and masks
        shape = self._get_single_data_shape(data_name, data)
    elif internal_name in CHECKED_VOLUME:  # single volume
        shape = self._check_volume_data(data_name, data)
    elif internal_name in CHECKED_MASK3D:  # single 3D mask
        shape = self._check_mask3d_data(data_name, data)
    elif internal_name in CHECKED_MULTI:  # multi-item targets
        shape = self._get_multi_data_shape(data_name, internal_name, data)
    else:
        shape = None
    return shape
1243
+
1244
+ def _get_single_data_shape(self, data_name: str, data: np.ndarray) -> tuple[int, ...]:
1245
+ """Get shape of single image or mask."""
1246
+ if not isinstance(data, np.ndarray):
1247
+ raise TypeError(f"{data_name} must be numpy array type")
1248
+ return data.shape
1249
+
1250
+ def _get_multi_data_shape(self, data_name: str, internal_name: str, data: Any) -> tuple[int, ...] | None:
1251
+ """Get shape of multi-item data (masks, images, volumes)."""
1252
+ if internal_name == "masks":
1253
+ shape = self._check_masks_data(data_name, data)
1254
+ # Skip empty masks lists when returning shape
1255
+ return None if shape is None else shape
1256
+
1257
+ if internal_name in {"volumes", "masks3d"}: # Group these together
1258
+ if not isinstance(data, np.ndarray):
1259
+ raise TypeError(f"{data_name} must be numpy array type")
1260
+ if data.ndim not in {4, 5}: # (N,D,H,W) or (N,D,H,W,C)
1261
+ raise TypeError(f"{data_name} must be 4D or 5D array")
1262
+ return data.shape # Return full shape
1263
+
1264
+ return self._check_multi_data(data_name, data)
1265
+
1266
+ def _check_shape_consistency(self, shapes: list[tuple[int, ...]], volume_shapes: list[tuple[int, ...]]) -> None:
1267
+ """Check consistency of shapes."""
1268
+ # Check H,W consistency
1269
+ self._check_shapes(shapes, self.is_check_shapes)
1270
+
1271
+ # Check D,H,W consistency for volumes and 3D masks
1272
+ if self.is_check_shapes and volume_shapes and volume_shapes.count(volume_shapes[0]) != len(volume_shapes):
1273
+ raise ValueError(
1274
+ "Depth, Height and Width of volume, mask3d, volumes and masks3d should be equal. "
1275
+ "You can disable shapes check by setting is_check_shapes=False.",
1276
+ )
1277
+
1278
+ @staticmethod
1279
+ def _check_volume_data(data_name: str, data: np.ndarray) -> tuple[int, int, int]:
1280
+ if data.ndim not in {3, 4}: # (D,H,W) or (D,H,W,C)
1281
+ raise TypeError(f"{data_name} must be 3D or 4D array")
1282
+ return data.shape[:3] # Return (D,H,W)
1283
+
1284
+ @staticmethod
1285
+ def _check_volumes_data(data_name: str, data: np.ndarray) -> tuple[int, int, int]:
1286
+ if data.ndim not in {4, 5}: # (N,D,H,W) or (N,D,H,W,C)
1287
+ raise TypeError(f"{data_name} must be 4D or 5D array")
1288
+ return data.shape[1:4] # Return (D,H,W)
1289
+
1290
+ @staticmethod
1291
+ def _check_mask3d_data(data_name: str, data: np.ndarray) -> tuple[int, int, int]:
1292
+ """Check single volumetric mask data format and return shape."""
1293
+ if data.ndim not in {3, 4}: # (D,H,W) or (D,H,W,C)
1294
+ raise TypeError(f"{data_name} must be 3D or 4D array")
1295
+ return data.shape[:3] # Return (D,H,W)
1296
+
1297
+ @staticmethod
1298
+ def _check_masks3d_data(data_name: str, data: np.ndarray) -> tuple[int, int, int]:
1299
+ """Check multiple volumetric masks data format and return shape."""
1300
+ if data.ndim not in [4, 5]: # (N,D,H,W) or (N,D,H,W,C)
1301
+ raise TypeError(f"{data_name} must be 4D or 5D array")
1302
+ return data.shape[1:4] # Return (D,H,W)
1303
+
1304
+ def _get_init_params(self) -> dict[str, Any]:
1305
+ """Get parameters needed to recreate this Compose instance.
1306
+
1307
+ Returns:
1308
+ dict[str, Any]: Dictionary of initialization parameters
1309
+
1310
+ """
1311
+ bbox_processor = self.processors.get("bboxes")
1312
+ keypoints_processor = self.processors.get("keypoints")
1313
+
1314
+ return {
1315
+ "bbox_params": bbox_processor.params if bbox_processor else None,
1316
+ "keypoint_params": keypoints_processor.params if keypoints_processor else None,
1317
+ "additional_targets": self.additional_targets,
1318
+ "p": self.p,
1319
+ "is_check_shapes": self.is_check_shapes,
1320
+ "strict": self.strict,
1321
+ "mask_interpolation": getattr(self, "mask_interpolation", None),
1322
+ "seed": getattr(self, "_base_seed", None),
1323
+ "save_applied_params": getattr(self, "save_applied_params", False),
1324
+ }
1325
+
1326
+
1327
class OneOf(BaseCompose):
    """Select one of the transforms to apply. The selected transform is called with `force_apply=True`.

    The transforms' own probabilities are normalized to sum to 1, so here they act as
    selection weights.

    Args:
        transforms (list): list of transformations to compose.
        p (float): probability of applying the selected transform. Default: 0.5.

    Note:
        If every transform has probability 0 there is nothing to select from, so the
        composition returns its input unchanged.

    """

    def __init__(self, transforms: TransformsSeqType, p: float = 0.5):
        super().__init__(transforms=transforms, p=p)
        weights = [t.p for t in self.transforms]
        total = sum(weights)
        # A zero total (all children have p == 0) cannot be normalized; dividing would raise
        # ZeroDivisionError. An empty weight list makes __call__ skip selection entirely.
        self.transforms_ps = [w / total for w in weights] if total else []

    def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
        """Apply the OneOf composition to the input data.

        Args:
            *args (Any): Ignored by this method.
            force_apply (bool): Skip the probability draw and always select a transform.
                Default: False.
            **data (Any): Dict with data to transform.

        Returns:
            dict[str, Any]: Dictionary with transformed data, or the input data when no
                transform was selected.

        """
        if self.replay_mode:
            # In replay mode every child is called in order.
            for t in self.transforms:
                data = t(**data)
            return data

        if self.transforms_ps and (force_apply or self.py_random.random() < self.p):
            idx: int = self.random_generator.choice(len(self.transforms), p=self.transforms_ps)
            t = self.transforms[idx]
            data = t(force_apply=True, **data)
            self._track_transform_params(t, data)
        return data
1369
+
1370
+
1371
class SomeOf(BaseCompose):
    """Selects exactly `n` transforms from the given list and applies them.

    The selection of which `n` transforms to apply is done **uniformly at random**
    from the provided list. Each transform in the list has an equal chance of being selected.

    Once the `n` transforms are selected, each one is applied **based on its
    individual probability** `p`.

    Args:
        transforms (list[BasicTransform | BaseCompose]): A list of transforms to choose from.
        n (int): The exact number of transforms to select and potentially apply.
            If `replace=False` and `n` is greater than the number of available transforms,
            `n` will be capped at the number of transforms.
        replace (bool): Whether to sample transforms with replacement. If True, the same
            transform can be selected multiple times (up to `n` times).
            Default is False.
        p (float): The probability that this `SomeOf` composition will be applied.
            If applied, it will select `n` transforms and attempt to apply them.
            Default is 1.0.

    Note:
        - The overall probability `p` of the `SomeOf` block determines if *any* selection
          and application occurs, unless the block is called with `force_apply=True`.
        - The individual probability `p` of each transform inside the list determines if
          that specific transform runs *if it is selected*.
        - If `replace` is True, the same transform might be selected multiple times, and
          its individual probability `p` will be checked each time it's encountered.
        - When using pipeline modification operators (+, -, __radd__), the `n` parameter
          is preserved while the pool of available transforms changes:
          - `SomeOf([A, B], n=2) + C` -> `SomeOf([A, B, C], n=2)` (selects 2 from 3 transforms)
          - This allows for dynamic adjustment of the transform pool without changing selection count.

    Examples:
        >>> import albumentations as A
        >>> transform = A.SomeOf([
        ...     A.HorizontalFlip(p=0.5),  # 50% chance to apply if selected
        ...     A.VerticalFlip(p=0.8),    # 80% chance to apply if selected
        ...     A.RandomRotate90(p=1.0),  # 100% chance to apply if selected
        ... ], n=2, replace=False, p=1.0)  # Always select 2 transforms uniformly

        # In each call, 2 transforms out of 3 are chosen uniformly.
        # For example, if HFlip and VFlip are chosen:
        # - HFlip runs if random() < 0.5
        # - VFlip runs if random() < 0.8
        # If VFlip and Rotate90 are chosen:
        # - VFlip runs if random() < 0.8
        # - Rotate90 runs if random() < 1.0 (always)

        >>> # Pipeline modification example:
        >>> # Add more transforms to the pool while keeping n=2
        >>> extended = transform + [A.Blur(p=1.0), A.RandomBrightnessContrast(p=0.7)]
        >>> # Now selects 2 transforms from 5 available transforms uniformly

    """

    def __init__(self, transforms: TransformsSeqType, n: int = 1, replace: bool = False, p: float = 1):
        super().__init__(transforms, p)
        self.n = n
        # Sampling without replacement cannot pick more items than exist.
        if not replace and n > len(self.transforms):
            self.n = len(self.transforms)
            warnings.warn(
                f"`n` is greater than number of transforms. `n` will be set to {self.n}.",
                UserWarning,
                stacklevel=2,
            )
        self.replace = replace

    def __call__(self, *arg: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
        """Apply n randomly selected transforms from the list of transforms.

        Args:
            *arg (Any): Ignored by this method.
            force_apply (bool): Skip the block-level probability draw and always perform the
                selection. The selected transforms still use their own probabilities.
                Default: False.
            **data (Any): Dict with data to transform.

        Returns:
            dict[str, Any]: Dictionary with transformed data.

        """
        if self.replay_mode:
            for t in self.transforms:
                data = t(**data)
                data = self.check_data_post_transform(data)
            return data

        # `force_apply` used to be accepted but ignored here, so a parent that selected this
        # block with force_apply=True (as OneOf does) could still see it skipped.
        if force_apply or self.py_random.random() < self.p:
            for i in self._get_idx():
                t = self.transforms[i]
                # Each selected transform decides for itself via its own `p`.
                data = t(**data)
                self._track_transform_params(t, data)
                data = self.check_data_post_transform(data)
        return data

    def _get_idx(self) -> np.ndarray[np.int_]:
        """Draw `n` transform indices uniformly and return them in ascending order."""
        idx = self.random_generator.choice(
            len(self.transforms),
            size=self.n,
            replace=self.replace,
        )
        idx.sort()
        return idx

    def to_dict_private(self) -> dict[str, Any]:
        """Convert the SomeOf composition to a dictionary for serialization.

        Returns:
            dict[str, Any]: The base-class dictionary plus `n` and `replace`.

        """
        dictionary = super().to_dict_private()
        dictionary.update({"n": self.n, "replace": self.replace})
        return dictionary

    def _get_init_params(self) -> dict[str, Any]:
        """Get parameters needed to recreate this SomeOf instance.

        Returns:
            dict[str, Any]: The base-class parameters plus `n` and `replace`.

        """
        base_params = super()._get_init_params()
        base_params.update(
            {
                "n": self.n,
                "replace": self.replace,
            },
        )
        return base_params
1504
+
1505
+
1506
class RandomOrder(SomeOf):
    """Pick `n` transforms uniformly at random and run them in the order they were drawn.

    Unlike `SomeOf`, the drawn indices are not sorted, so the execution order is random too.
    Each selected transform is still applied according to its own probability `p`.

    Attributes:
        transforms (TransformsSeqType): Transformations to choose from.
        n (int): Number of transforms to draw. When `replace` is False and `n` exceeds the
            number of transforms, `n` is reduced to that number.
        replace (bool): Whether the draw is done with replacement, which allows the same
            transform to be drawn more than once. Default is False.
        p (float): Probability of running the selection at all. Default is 1.0.

    Examples:
        >>> import albumentations as A
        >>> transform = A.RandomOrder([
        ...     A.HorizontalFlip(p=0.5),
        ...     A.VerticalFlip(p=1.0),
        ...     A.RandomBrightnessContrast(p=0.8),
        ... ], n=2, replace=False, p=1.0)
        >>> # Two transforms are drawn uniformly and applied in the drawn order,
        >>> # each respecting its own probability (0.5, 1.0, 0.8).

    Note:
        - Only `_get_idx` differs from `SomeOf`: the indices are returned unsorted.
        - Selection is uniform; application depends on each transform's own probability.

    """

    def __init__(self, transforms: TransformsSeqType, n: int = 1, replace: bool = False, p: float = 1):
        # All setup, including capping `n`, is delegated to SomeOf.
        super().__init__(transforms=transforms, n=n, replace=replace, p=p)

    def _get_idx(self) -> np.ndarray[np.int_]:
        """Draw `n` transform indices uniformly and return them unsorted."""
        # Replacement follows `self.replace`; skipping the sort keeps the drawn order.
        pool_size = len(self.transforms)
        return self.random_generator.choice(pool_size, size=self.n, replace=self.replace)
1549
+
1550
+
1551
class OneOrOther(BaseCompose):
    """Apply exactly one of two transforms.

    With probability `p` the first transform is used, otherwise the last one.
    The chosen transform is always called with `force_apply=True`.
    """

    def __init__(
        self,
        first: TransformType | None = None,
        second: TransformType | None = None,
        transforms: TransformsSeqType | None = None,
        p: float = 0.5,
    ):
        pair = transforms
        if pair is None:
            # No explicit list given: both individual transforms are mandatory.
            incomplete = first is None or second is None
            if incomplete:
                msg = "You must set both first and second or set transforms argument."
                raise ValueError(msg)
            pair = [first, second]
        super().__init__(transforms=pair, p=p)
        if len(self.transforms) != NUM_ONEOF_TRANSFORMS:
            warnings.warn("Length of transforms is not equal to 2.", stacklevel=2)

    def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
        """Apply one of the two transforms to the input data.

        In replay mode every contained transform is called in order instead of
        choosing between them.

        Args:
            *args (Any): Positional arguments are not supported.
            force_apply (bool): Accepted for interface compatibility; this method does not read it.
                Default: False.
            **data (Any): Dict with data to transform.

        Returns:
            dict[str, Any]: Dictionary with transformed data.

        """
        if not self.replay_mode:
            use_first = self.py_random.random() < self.p
            chosen = self.transforms[0] if use_first else self.transforms[-1]
            return chosen(force_apply=True, **data)

        for transform in self.transforms:
            data = transform(**data)
            self._track_transform_params(transform, data)
        return data
1592
+
1593
+
1594
class SelectiveChannelTransform(BaseCompose):
    """Apply a list of transforms to selected channels of an image only.

    The selected channels are extracted from the image, run through every
    transform in order, and written back to their original channel positions
    in a copy of the image.

    Args:
        transforms (TransformsSeqType):
            A sequence of transformations (from Albumentations) to be applied to the specified channels.
        channels (Sequence[int]):
            A sequence of integers specifying the indices of the channels to which the transforms should be applied.
        p (float): Probability that the transform will be applied; the default is 1.0 (always apply).

    Returns:
        dict[str, Any]: The transformed data dictionary, which includes the transformed 'image' key.

    Note:
        - When using pipeline modification operators (+, -, __radd__), the `channels` parameter
          is preserved in the resulting SelectiveChannelTransform instance.
        - Only the transform list is modified while maintaining the same channel selection behavior.

    """

    def __init__(
        self,
        transforms: TransformsSeqType,
        channels: Sequence[int] = (0, 1, 2),
        p: float = 1.0,
    ) -> None:
        super().__init__(transforms=transforms, p=p)
        self.channels = channels

    def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
        """Run the transforms on the configured channels of `data["image"]`.

        Args:
            *args (Any): Positional arguments are not supported.
            force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
            **data (Any): Dict with data to transform.

        Returns:
            dict[str, Any]: Dictionary with transformed data.

        """
        # The random draw is skipped entirely when force_apply is set.
        if not (force_apply or self.py_random.random() < self.p):
            return data

        source = data["image"]
        working = np.ascontiguousarray(source[:, :, self.channels])

        for transform in self.transforms:
            transform_input = {"image": working}
            working = transform(**transform_input)["image"]
            self._track_transform_params(transform, transform_input)

        planes = cv2.split(working)
        result = source.copy()

        # zip stops at the shorter sequence, so surplus channels or planes are ignored.
        for channel_idx, plane in zip(self.channels, planes):
            result[:, :, channel_idx] = plane

        data["image"] = np.ascontiguousarray(result)
        return data

    def _get_init_params(self) -> dict[str, Any]:
        """Collect the constructor arguments needed to rebuild this instance.

        Returns:
            dict[str, Any]: The parent class's init parameters plus "channels".

        """
        init_params = super()._get_init_params()
        init_params["channels"] = self.channels
        return init_params
1674
+
1675
+
1676
class ReplayCompose(Compose):
    """Compose variant that records applied transforms so they can be replayed.

    Each call stores, under `save_key`, a serialized description of the pipeline
    together with the parameters that were used. Passing that description to
    `ReplayCompose.replay` applies the same transformations to other data.

    Args:
        transforms (TransformsSeqType): List of transformations to compose.
        bbox_params (dict[str, Any] | BboxParams | None): Parameters for bounding box transforms.
        keypoint_params (dict[str, Any] | KeypointParams | None): Parameters for keypoint transforms.
        additional_targets (dict[str, str] | None): Dictionary of additional targets.
        p (float): Probability of applying the compose.
        is_check_shapes (bool): Whether to check shapes of different targets.
        save_key (str): Key for storing the applied transformations.

    """

    def __init__(
        self,
        transforms: TransformsSeqType,
        bbox_params: dict[str, Any] | BboxParams | None = None,
        keypoint_params: dict[str, Any] | KeypointParams | None = None,
        additional_targets: dict[str, str] | None = None,
        p: float = 1.0,
        is_check_shapes: bool = True,
        save_key: str = "replay",
    ):
        super().__init__(transforms, bbox_params, keypoint_params, additional_targets, p, is_check_shapes)
        self.set_deterministic(True, save_key=save_key)
        self.save_key = save_key
        # Register the save key alongside the other available keys.
        self._available_keys.add(save_key)

    def __call__(self, *args: Any, force_apply: bool = False, **kwargs: Any) -> dict[str, Any]:
        """Apply the pipeline and attach the replay record to the result.

        Args:
            *args (Any): Positional arguments are not supported.
            force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
            **kwargs (Any): Dict with data to transform.

        Returns:
            dict[str, Any]: Dictionary with transformed data and replay information.

        """
        # Fresh per-call container for the parameters recorded during the run.
        kwargs[self.save_key] = defaultdict(dict)
        output = super().__call__(force_apply=force_apply, **kwargs)

        record = self.get_dict_with_id()
        self.fill_with_params(record, output[self.save_key])
        self.fill_applied(record)

        output[self.save_key] = record
        return output

    @staticmethod
    def replay(saved_augmentations: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
        """Apply a previously recorded set of augmentations to new data.

        Args:
            saved_augmentations (dict[str, Any]): Previously saved augmentation parameters.
            **kwargs (Any): Dict with data to transform.

        Returns:
            dict[str, Any]: Dictionary with transformed data using saved parameters.

        """
        pipeline = ReplayCompose._restore_for_replay(saved_augmentations)
        return pipeline(force_apply=True, **kwargs)

    @staticmethod
    def _restore_for_replay(
        transform_dict: dict[str, Any],
        lambda_transforms: dict[str, Any] | None = None,
    ) -> TransformType:
        """Rebuild a transform (recursively) from its replay record.

        Args:
            transform_dict (dict[str, Any]): A dictionary that contains transform data.
            lambda_transforms (dict): A dictionary that contains lambda transforms, that
                is instances of the Lambda class.
                This dictionary is required when you are restoring a pipeline that contains lambda transforms.
                Keys in that dictionary should be named same as `name` arguments in respective lambda transforms
                from a serialized pipeline.

        Returns:
            TransformType: The rebuilt transform, switched to replay mode.

        """
        was_applied = transform_dict["applied"]
        saved_params = transform_dict["params"]

        transform = instantiate_nonserializable(transform_dict, lambda_transforms)
        if not transform:
            class_name = transform_dict["__class_fullname__"]
            bookkeeping_keys = ["__class_fullname__", "applied", "params"]
            init_args = {key: value for key, value in transform_dict.items() if key not in bookkeeping_keys}
            transform_cls = SERIALIZABLE_REGISTRY[class_name]
            if "transforms" in init_args:
                # Nested pipelines: rebuild every child before constructing the parent.
                restore = ReplayCompose._restore_for_replay
                init_args["transforms"] = [
                    restore(child, lambda_transforms=lambda_transforms) for child in init_args["transforms"]
                ]
            transform = transform_cls(**init_args)

        transform = cast("BasicTransform", transform)
        # Saved parameters are attached to BasicTransform instances only;
        # the replay flags below are set on every restored object.
        if isinstance(transform, BasicTransform):
            transform.params = saved_params
        transform.replay_mode = True
        transform.applied_in_replay = was_applied
        return transform

    def fill_with_params(self, serialized: dict[str, Any], all_params: Any) -> None:
        """Move recorded parameters into the serialized pipeline, in place.

        The entry of `all_params` keyed by the transform's "id" is stored under
        "params", the "id" key is removed, and nested transforms are processed
        the same way.

        Args:
            serialized (dict[str, Any]): Serialized transform data.
            all_params (Any): Parameters to fill in.

        """
        transform_id = serialized.get("id")
        serialized["params"] = all_params.get(transform_id)
        del serialized["id"]
        for child in serialized.get("transforms", []):
            self.fill_with_params(child, all_params)

    def fill_applied(self, serialized: dict[str, Any]) -> bool:
        """Set the "applied" flag on a serialized transform and all of its children.

        A transform without children counts as applied when its "params" entry is
        not None; a container counts as applied when any of its children does.

        Args:
            serialized (dict[str, Any]): Serialized transform data.

        Returns:
            bool: True if any transform was applied, False otherwise.

        """
        if "transforms" not in serialized:
            was_applied = serialized.get("params") is not None
        else:
            # Build the full list first so that every child gets its flag set;
            # a lazy any() would stop at the first applied child.
            child_flags = [self.fill_applied(child) for child in serialized["transforms"]]
            was_applied = any(child_flags)
        serialized["applied"] = was_applied
        return was_applied

    def to_dict_private(self) -> dict[str, Any]:
        """Serialize the ReplayCompose into a dictionary.

        Returns:
            dict[str, Any]: The parent class's dictionary representation plus "save_key".

        """
        serialized = super().to_dict_private()
        serialized["save_key"] = self.save_key
        return serialized

    def _get_init_params(self) -> dict[str, Any]:
        """Collect the constructor arguments needed to rebuild this instance.

        Returns:
            dict[str, Any]: The parent class's init parameters plus "save_key".

        """
        init_params = super()._get_init_params()
        init_params["save_key"] = self.save_key
        return init_params
1837
+
1838
+
1839
class Sequential(BaseCompose):
    """Sequentially applies all transforms to targets.

    Note:
        This transform is not intended to be a replacement for `Compose`. Instead, it should be used inside `Compose`
        the same way `OneOf` or `OneOrOther` are used. For instance, you can combine `OneOf` with `Sequential` to
        create an augmentation pipeline that contains multiple sequences of augmentations and applies one randomly
        chosen sequence to input data (see the `Examples` section for an example definition of such pipeline).

    Examples:
        >>> import albumentations as A
        >>> transform = A.Compose([
        ...     A.OneOf([
        ...         A.Sequential([
        ...             A.HorizontalFlip(p=0.5),
        ...             A.ShiftScaleRotate(p=0.5),
        ...         ]),
        ...         A.Sequential([
        ...             A.VerticalFlip(p=0.5),
        ...             A.RandomBrightnessContrast(p=0.5),
        ...         ]),
        ...     ], p=1)
        ... ])

    """

    def __init__(self, transforms: TransformsSeqType, p: float = 0.5):
        super().__init__(transforms=transforms, p=p)

    def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> dict[str, Any]:
        """Apply every transform, one after another.

        The sequence runs when the instance is in replay mode, when
        `force_apply` is set, or when a random draw falls below `p`;
        otherwise the data is returned untouched.

        Args:
            *args (Any): Positional arguments are not supported.
            force_apply (bool): Whether to apply transforms regardless of probability. Default: False.
            **data (Any): Dict with data to transform.

        Returns:
            dict[str, Any]: Dictionary with transformed data.

        """
        # Short-circuits left to right: the random draw only happens when
        # neither replay mode nor force_apply already decides the outcome.
        should_run = self.replay_mode or force_apply or self.py_random.random() < self.p
        if not should_run:
            return data

        for transform in self.transforms:
            data = transform(**data)
            self._track_transform_params(transform, data)
            data = self.check_data_post_transform(data)
        return data