nrtk-albumentations 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nrtk-albumentations might be problematic. Click here for more details.
- albumentations/__init__.py +21 -0
- albumentations/augmentations/__init__.py +23 -0
- albumentations/augmentations/blur/__init__.py +0 -0
- albumentations/augmentations/blur/functional.py +438 -0
- albumentations/augmentations/blur/transforms.py +1633 -0
- albumentations/augmentations/crops/__init__.py +0 -0
- albumentations/augmentations/crops/functional.py +494 -0
- albumentations/augmentations/crops/transforms.py +3647 -0
- albumentations/augmentations/dropout/__init__.py +0 -0
- albumentations/augmentations/dropout/channel_dropout.py +134 -0
- albumentations/augmentations/dropout/coarse_dropout.py +567 -0
- albumentations/augmentations/dropout/functional.py +1017 -0
- albumentations/augmentations/dropout/grid_dropout.py +166 -0
- albumentations/augmentations/dropout/mask_dropout.py +274 -0
- albumentations/augmentations/dropout/transforms.py +461 -0
- albumentations/augmentations/dropout/xy_masking.py +186 -0
- albumentations/augmentations/geometric/__init__.py +0 -0
- albumentations/augmentations/geometric/distortion.py +1238 -0
- albumentations/augmentations/geometric/flip.py +752 -0
- albumentations/augmentations/geometric/functional.py +4151 -0
- albumentations/augmentations/geometric/pad.py +676 -0
- albumentations/augmentations/geometric/resize.py +956 -0
- albumentations/augmentations/geometric/rotate.py +864 -0
- albumentations/augmentations/geometric/transforms.py +1962 -0
- albumentations/augmentations/mixing/__init__.py +0 -0
- albumentations/augmentations/mixing/domain_adaptation.py +787 -0
- albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
- albumentations/augmentations/mixing/functional.py +878 -0
- albumentations/augmentations/mixing/transforms.py +832 -0
- albumentations/augmentations/other/__init__.py +0 -0
- albumentations/augmentations/other/lambda_transform.py +180 -0
- albumentations/augmentations/other/type_transform.py +261 -0
- albumentations/augmentations/pixel/__init__.py +0 -0
- albumentations/augmentations/pixel/functional.py +4226 -0
- albumentations/augmentations/pixel/transforms.py +7556 -0
- albumentations/augmentations/spectrogram/__init__.py +0 -0
- albumentations/augmentations/spectrogram/transform.py +220 -0
- albumentations/augmentations/text/__init__.py +0 -0
- albumentations/augmentations/text/functional.py +272 -0
- albumentations/augmentations/text/transforms.py +299 -0
- albumentations/augmentations/transforms3d/__init__.py +0 -0
- albumentations/augmentations/transforms3d/functional.py +393 -0
- albumentations/augmentations/transforms3d/transforms.py +1422 -0
- albumentations/augmentations/utils.py +249 -0
- albumentations/core/__init__.py +0 -0
- albumentations/core/bbox_utils.py +920 -0
- albumentations/core/composition.py +1885 -0
- albumentations/core/hub_mixin.py +299 -0
- albumentations/core/keypoints_utils.py +521 -0
- albumentations/core/label_manager.py +339 -0
- albumentations/core/pydantic.py +239 -0
- albumentations/core/serialization.py +352 -0
- albumentations/core/transforms_interface.py +976 -0
- albumentations/core/type_definitions.py +127 -0
- albumentations/core/utils.py +605 -0
- albumentations/core/validation.py +129 -0
- albumentations/pytorch/__init__.py +1 -0
- albumentations/pytorch/transforms.py +189 -0
- nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
- nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
- nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
- nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,976 @@
|
|
|
1
|
+
"""Module containing base interfaces for all transform implementations.
|
|
2
|
+
|
|
3
|
+
This module defines the fundamental transform interfaces that form the base hierarchy for
|
|
4
|
+
all transformation classes in Albumentations. It provides abstract classes and mixins that
|
|
5
|
+
define common behavior for image, keypoint, bounding box, and volumetric transformations.
|
|
6
|
+
The interfaces handle parameter validation, random state management, target type checking,
|
|
7
|
+
and serialization capabilities that are inherited by concrete transform implementations.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
import random
|
|
13
|
+
from copy import deepcopy
|
|
14
|
+
from typing import Any, Callable
|
|
15
|
+
from warnings import warn
|
|
16
|
+
|
|
17
|
+
import cv2
|
|
18
|
+
import numpy as np
|
|
19
|
+
from albucore import batch_transform
|
|
20
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
21
|
+
|
|
22
|
+
from albumentations.core.bbox_utils import BboxProcessor
|
|
23
|
+
from albumentations.core.keypoints_utils import KeypointsProcessor
|
|
24
|
+
from albumentations.core.validation import ValidatedTransformMeta
|
|
25
|
+
|
|
26
|
+
from .serialization import Serializable, SerializableMeta, get_shortest_class_fullname
|
|
27
|
+
from .type_definitions import ALL_TARGETS, Targets
|
|
28
|
+
from .utils import ensure_contiguous_output, format_args
|
|
29
|
+
|
|
30
|
+
__all__ = ["BasicTransform", "DualTransform", "ImageOnlyTransform", "NoOp", "Transform3D"]
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class Interpolation:
    """Holder for a pair of interpolation flags.

    Args:
        downscale (int): Flag stored as ``self.downscale``. Defaults to ``cv2.INTER_NEAREST``.
        upscale (int): Flag stored as ``self.upscale``. Defaults to ``cv2.INTER_NEAREST``.

    """

    def __init__(self, downscale: int = cv2.INTER_NEAREST, upscale: int = cv2.INTER_NEAREST):
        # Pure value object: both flags are kept exactly as given.
        self.downscale, self.upscale = downscale, upscale
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class BaseTransformInitSchema(BaseModel):
    # Base pydantic schema for transform constructor arguments; it declares
    # only the two fields below.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # `p` must lie in the closed interval [0, 1].
    p: float = Field(ge=0, le=1)
    # No default is declared, so a value for `strict` has to be supplied.
    strict: bool
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class CombinedMeta(SerializableMeta, ValidatedTransformMeta):
    """Metaclass that merges ``SerializableMeta`` and ``ValidatedTransformMeta``; it adds nothing of its own."""
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class BasicTransform(Serializable, metaclass=CombinedMeta):
|
|
50
|
+
"""Base class for all transforms in Albumentations.
|
|
51
|
+
|
|
52
|
+
This class provides core functionality for transform application, serialization,
|
|
53
|
+
and parameter handling. It defines the interface that all transforms must follow
|
|
54
|
+
and implements common methods used across different transform types.
|
|
55
|
+
|
|
56
|
+
Class Attributes:
|
|
57
|
+
_targets (tuple[Targets, ...] | Targets): Target types this transform can work with.
|
|
58
|
+
_available_keys (set[str]): String representations of valid target keys.
|
|
59
|
+
_key2func (dict[str, Callable[..., Any]]): Mapping between target keys and their processing functions.
|
|
60
|
+
|
|
61
|
+
Args:
|
|
62
|
+
interpolation (int): Interpolation method for image transforms.
|
|
63
|
+
fill (int | float | list[int] | list[float]): Fill value for image padding.
|
|
64
|
+
fill_mask (int | float | list[int] | list[float]): Fill value for mask padding.
|
|
65
|
+
deterministic (bool, optional): Whether the transform is deterministic.
|
|
66
|
+
save_key (str, optional): Key for saving transform parameters.
|
|
67
|
+
replay_mode (bool, optional): Whether the transform is in replay mode.
|
|
68
|
+
applied_in_replay (bool, optional): Whether the transform was applied in replay.
|
|
69
|
+
p (float): Probability of applying the transform.
|
|
70
|
+
|
|
71
|
+
"""
|
|
72
|
+
|
|
73
|
+
_targets: tuple[Targets, ...] | Targets # targets that this transform can work on
|
|
74
|
+
_available_keys: set[str] # targets that this transform, as string, lower-cased
|
|
75
|
+
_key2func: dict[
|
|
76
|
+
str,
|
|
77
|
+
Callable[..., Any],
|
|
78
|
+
] # mapping for targets (plus additional targets) and methods for which they depend
|
|
79
|
+
call_backup = None
|
|
80
|
+
interpolation: int
|
|
81
|
+
fill: tuple[float, ...] | float
|
|
82
|
+
fill_mask: tuple[float, ...] | float | None
|
|
83
|
+
# replay mode params
|
|
84
|
+
deterministic: bool = False
|
|
85
|
+
save_key = "replay"
|
|
86
|
+
replay_mode = False
|
|
87
|
+
applied_in_replay = False
|
|
88
|
+
|
|
89
|
+
class InitSchema(BaseTransformInitSchema):
|
|
90
|
+
pass
|
|
91
|
+
|
|
92
|
+
def __init__(self, p: float = 0.5):
    """Initialise the per-instance state of a transform.

    Args:
        p (float): Probability of applying the transform. Defaults to 0.5.

    """
    # Plain state first.
    self.p = p
    self.params: dict[Any, Any] = {}
    self._additional_targets: dict[str, str] = {}

    # Start from an empty mapping, then let _set_keys() fill in the target keys.
    self._key2func = {}
    self._set_keys()

    self.processors: dict[str, BboxProcessor | KeypointsProcessor] = {}

    # No seed yet: set_random_seed(None) still creates both random generators.
    self.seed: int | None = None
    self.set_random_seed(self.seed)

    # Written to the private attribute directly, bypassing the `strict` property setter.
    self._strict = False
    # Invalid argument names found during init are stored here.
    self.invalid_args: list[str] = []
|
|
103
|
+
|
|
104
|
+
@property
def strict(self) -> bool:
    """Get the current strict mode setting.

    Returns:
        bool: True if strict mode is enabled, False otherwise.

    """
    return self._strict


@strict.setter
def strict(self, value: bool) -> None:
    """Set strict mode, rejecting unknown constructor arguments when it is switched on.

    Validation runs only when ``value`` is truthy and the instance has an
    ``_init_args`` attribute (presumably recorded at construction time; it is
    not set anywhere in this block). Names in ``_init_args`` are checked
    against ``{"p", "strict"}`` plus the fields of ``InitSchema`` when that
    attribute exists.

    Args:
        value (bool): New strict mode setting.

    Raises:
        ValueError: If strict mode is being enabled and ``_init_args`` contains
            names that are not valid for this transform. ``_strict`` is left
            unchanged in that case.

    """
    if value == self._strict:
        return  # No change needed

    if value and hasattr(self, "_init_args"):
        valid_args = {"p", "strict"}  # Base valid args
        if hasattr(self, "InitSchema"):
            valid_args.update(self.InitSchema.model_fields.keys())

        invalid_args = [name_arg for name_arg in self._init_args if name_arg not in valid_args]
        if invalid_args:
            # This branch is only reachable with a truthy `value`, so the error
            # is raised unconditionally; the former warn() fallback could never run.
            raise ValueError(
                f"Argument(s) '{', '.join(invalid_args)}' are not valid for transform {self.__class__.__name__}",
            )

    self._strict = value
|
|
139
|
+
|
|
140
|
+
def set_random_state(
    self,
    random_generator: np.random.Generator,
    py_random: random.Random,
) -> None:
    """Adopt externally created random generators.

    Args:
        random_generator (np.random.Generator): NumPy generator stored on the instance.
        py_random (random.Random): Python generator stored on the instance.

    """
    # Both objects are stored by reference; nothing is copied or reseeded.
    self.random_generator, self.py_random = random_generator, py_random
|
|
154
|
+
|
|
155
|
+
def set_random_seed(self, seed: int | None) -> None:
    """Remember ``seed`` and rebuild both random generators from it.

    Args:
        seed (int | None): Seed passed to ``random.Random`` and ``np.random.default_rng``.

    """
    self.seed = seed
    # The two generators are independent objects built from the same seed.
    self.py_random = random.Random(seed)
    self.random_generator = np.random.default_rng(seed)
|
|
165
|
+
|
|
166
|
+
def get_dict_with_id(self) -> dict[str, Any]:
    """Return ``to_dict_private()`` extended with the instance identity.

    Returns:
        dict[str, Any]: The private dict plus an ``"id"`` entry holding ``id(self)``.

    """
    state = self.to_dict_private()
    state["id"] = id(self)
    return state
|
|
176
|
+
|
|
177
|
+
def get_transform_init_args_names(self) -> tuple[str, ...]:
    """Collect the named ``__init__`` parameters declared anywhere in the class hierarchy.

    Every class in the MRO that defines its own ``__init__`` contributes its
    positional-or-keyword and keyword-only parameter names. ``*args`` and
    ``**kwargs`` style parameters are ignored, and ``self`` and ``strict``
    are removed from the result.

    Returns:
        tuple[str, ...]: Alphabetically sorted parameter names.

    """
    import inspect

    named_kinds = {inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY}
    collected: set[str] = set()

    for klass in self.__class__.__mro__:
        own_attrs = klass.__dict__
        # Only classes that define __init__ themselves are inspected.
        if klass is object or "__init__" not in own_attrs:
            continue

        try:
            parameters = inspect.signature(own_attrs["__init__"]).parameters
        except (ValueError, TypeError):
            # Signature not introspectable; skip this class.
            continue

        collected.update(name for name, spec in parameters.items() if spec.kind in named_kinds)

    return tuple(sorted(collected - {"self", "strict"}))
|
|
206
|
+
|
|
207
|
+
def set_processors(self, processors: dict[str, BboxProcessor | KeypointsProcessor]) -> None:
    """Replace the instance's processors mapping.

    Args:
        processors (dict[str, BboxProcessor | KeypointsProcessor]): Mapping of processor
            name to processor instance. It is stored by reference, not copied.

    """
    self.processors = processors
|
|
216
|
+
|
|
217
|
+
def get_processor(self, key: str) -> BboxProcessor | KeypointsProcessor | None:
    """Look up a processor by name.

    Args:
        key (str): Processor name.

    Returns:
        BboxProcessor | KeypointsProcessor | None: The registered processor, or None
            when no processor is stored under ``key``.

    """
    processor = self.processors.get(key)
    return processor
|
|
228
|
+
|
|
229
|
+
def __call__(self, *args: Any, force_apply: bool = False, **kwargs: Any) -> Any:
    """Apply the transform to data passed as named arguments.

    Args:
        *args (Any): Not supported; any positional argument raises an error.
        force_apply (bool, optional): Forwarded to ``should_apply``.
        **kwargs (Any): Input data keyed by target name (for example ``image=...``).

    Returns:
        dict[str, Any]: Result of ``apply_with_params`` when the transform runs,
            otherwise ``kwargs`` unchanged.

    Raises:
        KeyError: If positional arguments are provided.
        ValueError: If a key listed in ``targets_as_params`` is missing from the input.
            A missing ``"image"`` alone is accepted when ``"images"`` is present.

    """
    if args:
        msg = "You have to pass data to augmentations as named arguments, for example: aug(image=image)"
        raise KeyError(msg)

    if self.replay_mode:
        # Replay never samples: it reuses the stored params or passes the data through.
        return self.apply_with_params(self.params, **kwargs) if self.applied_in_replay else kwargs

    # Forget the previous call's params before deciding whether to run.
    self.params = {}
    if not self.should_apply(force_apply=force_apply):
        return kwargs

    sampled = self.update_transform_params(params=self.get_params(), data=kwargs)

    required = self.targets_as_params
    if required:  # check if all required targets are in kwargs.
        absent = set(required).difference(kwargs.keys())
        images_stand_in = absent == {"image"} and "images" in kwargs
        if absent and not images_stand_in:
            msg = f"{self.__class__.__name__} requires {required} missing keys: {absent}"
            raise ValueError(msg)

    sampled.update(self.get_params_dependent_on_data(params=sampled, data=kwargs))

    # Store the final params so get_applied_params() can report them.
    self.params = sampled

    if self.deterministic:
        # A dict must already exist under save_key in the input; a copy of the params is stored in it.
        kwargs[self.save_key][id(self)] = deepcopy(sampled)

    return self.apply_with_params(sampled, **kwargs)
|
|
276
|
+
|
|
277
|
+
def get_applied_params(self) -> dict[str, Any]:
    """Return the parameters stored by the most recent call.

    Returns:
        dict[str, Any]: ``self.params`` itself (not a copy). It is empty when the
            last call did not apply the transform.

    """
    applied = self.params
    return applied
|
|
282
|
+
|
|
283
|
+
def should_apply(self, force_apply: bool = False) -> bool:
    """Decide whether the transform runs on this call.

    Args:
        force_apply (bool, optional): Skip the random draw and apply the transform.
            It does not override ``p <= 0``: such a transform is never applied.

    Returns:
        bool: True if the transform should be applied, False otherwise.

    """
    probability = self.p

    # A non-positive probability disables the transform, even when forced.
    if probability <= 0.0:
        return False

    if force_apply or probability >= 1.0:
        return True

    # A random number is drawn only for probabilities strictly between 0 and 1.
    return self.py_random.random() < probability
|
|
298
|
+
|
|
299
|
+
def apply_with_params(self, params: dict[str, Any], *args: Any, **kwargs: Any) -> dict[str, Any]:
    """Run every known target through its handler using the given params.

    Args:
        params (dict[str, Any]): Keyword arguments forwarded to each target function.
        *args (Any): Accepted but unused.
        **kwargs (Any): Input data keyed by target name.

    Returns:
        dict[str, Any]: A new dict with the same keys as ``kwargs``. Values whose key
            is not in ``_key2func``, or that are None, are passed through unchanged.

    """
    result = {}
    for name, value in kwargs.items():
        # Unknown keys and None values are passed through untouched.
        if name not in self._key2func or value is None:
            result[name] = value
            continue

        # An empty list/tuple of masks is returned as is, without calling the handler.
        if name in {"masks", "masks3d"} and isinstance(value, (list, tuple)) and not value:
            result[name] = value
            continue

        handler = self._key2func[name]
        transformed = handler(ensure_contiguous_output(value), **params)
        result[name] = ensure_contiguous_output(transformed)
    return result
|
|
315
|
+
|
|
316
|
+
def set_deterministic(self, flag: bool, save_key: str = "replay") -> BasicTransform:
    """Switch deterministic mode on or off and choose the key used to save params.

    Args:
        flag (bool): New value for ``self.deterministic``.
        save_key (str): Key under which ``__call__`` stores params. ``"params"`` is reserved.

    Returns:
        BasicTransform: ``self``, so calls can be chained.

    Raises:
        KeyError: If ``save_key`` is ``"params"``. Nothing is modified in that case.

    """
    if save_key == "params":
        raise KeyError("params save_key is reserved")

    self.deterministic = flag

    # Data-dependent params may not fit other inputs when replayed.
    if self.deterministic and self.targets_as_params:
        message = (
            f"{self.get_class_fullname()} could work incorrectly in ReplayMode for other input data"
            " because its' params depend on targets."
        )
        warn(message, stacklevel=2)

    self.save_key = save_key
    return self
|
|
331
|
+
|
|
332
|
+
def __repr__(self) -> str:
    """Render the transform as ``ClassName(<formatted init args>)``."""
    # Base args come first; transform-specific args override duplicate keys.
    state = {**self.get_base_init_args(), **self.get_transform_init_args()}
    return f"{self.__class__.__name__}({format_args(state)})"
|
|
336
|
+
|
|
337
|
+
def apply(self, img: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
    """Apply the transform to a single image.

    The base implementation is a placeholder that always raises.

    Raises:
        NotImplementedError: Always.

    """
    raise NotImplementedError
|
|
340
|
+
|
|
341
|
+
def apply_to_images(self, images: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
    """Apply ``self.apply`` to every image of a batch and restack the results.

    Args:
        images (np.ndarray): Batch of images with shape:
            - (num_images, height, width, channels)
            - (num_images, height, width) for grayscale
        *args (Any): Accepted but not forwarded to ``apply``.
        **params (Any): Keyword parameters forwarded to ``apply`` for every image.

    Returns:
        np.ndarray: C-contiguous array stacking the per-image results along axis 0.

    """
    # Iterating the array yields one image per step along the first axis.
    per_image = []
    for single in images:
        per_image.append(self.apply(single, **params))

    stacked = np.stack(per_image)
    return np.require(stacked, requirements=["C_CONTIGUOUS"])
|
|
358
|
+
|
|
359
|
+
def apply_to_volume(self, volume: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
    """Transform a volume by delegating to ``apply_to_images``.

    Args:
        volume (np.ndarray): Volume of shape (depth, height, width) or
            (depth, height, width, channels).
        *args (Any): Positional arguments forwarded to ``apply_to_images``.
        **params (Any): Keyword parameters forwarded to ``apply_to_images``.

    Returns:
        np.ndarray: Whatever ``apply_to_images`` returns for the volume.

    """
    # The depth axis is passed where apply_to_images expects the batch axis.
    transformed_volume = self.apply_to_images(volume, *args, **params)
    return transformed_volume
|
|
372
|
+
|
|
373
|
+
def apply_to_volumes(self, volumes: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
    """Apply ``apply_to_volume`` to each volume and stack the results along axis 0.

    Args:
        volumes (np.ndarray): Batch of volumes, iterated along its first axis.
        *args (Any): Positional arguments forwarded to ``apply_to_volume``.
        **params (Any): Keyword parameters forwarded to ``apply_to_volume``.

    Returns:
        np.ndarray: Stacked transformed volumes.

    """
    transformed = []
    for single_volume in volumes:
        transformed.append(self.apply_to_volume(single_volume, *args, **params))
    return np.stack(transformed)
|
|
376
|
+
|
|
377
|
+
def get_params(self) -> dict[str, Any]:
    """Return parameters that do not depend on the input data.

    Returns:
        dict[str, Any]: A new empty dict in this base implementation.

    """
    return dict()
|
|
380
|
+
|
|
381
|
+
def update_transform_params(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
    """Add the input shape and selected instance attributes to ``params``.

    Args:
        params (dict[str, Any]): Parameters to extend; the dict is modified in place.
        data (dict[str, Any]): Input data. The shape is read from the first key found
            among ``"volume"``, ``"volumes"``, ``"image"`` and, as the fallback, ``"images"``.

    Returns:
        dict[str, Any]: The same ``params`` dict, now containing ``"shape"`` and, when
            the instance has them, ``"interpolation"``, ``"fill"`` and ``"fill_mask"``.

    Raises:
        KeyError: If none of the four data keys is present (from the ``"images"`` fallback).

    """
    # Pick one 2D slice/image to take the shape from: (H, W) or (H, W, C).
    if "volume" in data:
        reference = data["volume"][0]  # first slice of the volume
    elif "volumes" in data:
        reference = data["volumes"][0][0]  # first slice of the first volume
    elif "image" in data:
        reference = data["image"]
    else:
        reference = data["images"][0]
    params["shape"] = reference.shape

    # Copy over the transform-specific settings the instance actually defines.
    for attribute in ("interpolation", "fill", "fill_mask"):
        if hasattr(self, attribute):
            params[attribute] = getattr(self, attribute)

    return params
|
|
414
|
+
|
|
415
|
+
def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
    """Return parameters that depend on the input data.

    Args:
        params (dict[str, Any]): Parameters computed so far.
        data (dict[str, Any]): Input data; unused in this base implementation.

    Returns:
        dict[str, Any]: The ``params`` object itself, unchanged, in this base implementation.

    """
    unchanged = params
    return unchanged
|
|
418
|
+
|
|
419
|
+
@property
def targets(self) -> dict[str, Callable[..., Any]]:
    """Mapping from target key to the function that processes that target.

    The base implementation is a placeholder that always raises. A concrete
    mapping looks like ``{"image": self.apply}`` or
    ``{"masks": self.apply_to_masks}``.

    Raises:
        NotImplementedError: Always.

    """
    raise NotImplementedError
|
|
432
|
+
|
|
433
|
+
def _set_keys(self) -> None:
    """Populate ``_available_keys`` and ``_key2func`` from ``_targets`` and ``targets``.

    ``_available_keys`` receives the lower-cased ``value`` of every entry in
    ``_targets`` (a single target or a tuple of them; none when the attribute
    is absent) plus every key of ``self.targets``. ``_key2func`` maps each
    available key that ``targets`` knows about to its function.
    """
    if not hasattr(self, "_targets"):
        self._available_keys = set()
    else:
        declared = self._targets if isinstance(self._targets, tuple) else [self._targets]
        self._available_keys = {target.value.lower() for target in declared}

    # `targets` is a property that may build a fresh dict on every access, so it is
    # evaluated once here rather than once per key.
    target_functions = self.targets
    self._available_keys.update(target_functions.keys())
    self._key2func = {key: target_functions[key] for key in self._available_keys if key in target_functions}
|
|
444
|
+
|
|
445
|
+
@property
def available_keys(self) -> set[str]:
    """Target keys this transform currently accepts.

    Returns:
        set[str]: The internal ``_available_keys`` set itself (not a copy).

    """
    keys = self._available_keys
    return keys
|
|
449
|
+
|
|
450
|
+
def add_targets(self, additional_targets: dict[str, str]) -> None:
|
|
451
|
+
"""Add targets to transform them the same way as one of existing targets.
|
|
452
|
+
ex: {'target_image': 'image'}
|
|
453
|
+
ex: {'obj1_mask': 'mask', 'obj2_mask': 'mask'}
|
|
454
|
+
by the way you must have at least one object with key 'image'
|
|
455
|
+
|
|
456
|
+
Args:
|
|
457
|
+
additional_targets (dict[str, str]): keys - new target name, values
|
|
458
|
+
- old target name. ex: {'image2': 'image'}
|
|
459
|
+
|
|
460
|
+
"""
|
|
461
|
+
for k, v in additional_targets.items():
|
|
462
|
+
if k in self._additional_targets and v != self._additional_targets[k]:
|
|
463
|
+
raise ValueError(
|
|
464
|
+
f"Trying to overwrite existed additional targets. "
|
|
465
|
+
f"Key={k} Exists={self._additional_targets[k]} New value: {v}",
|
|
466
|
+
)
|
|
467
|
+
if v in self._available_keys:
|
|
468
|
+
self._additional_targets[k] = v
|
|
469
|
+
self._key2func[k] = self.targets[v]
|
|
470
|
+
self._available_keys.add(k)
|
|
471
|
+
|
|
472
|
+
@property
def targets_as_params(self) -> list[str]:
    """Names of the targets needed to compute data-dependent params.

    ``__call__`` uses this list to check that the input contains every
    required target.

    Returns:
        list[str]: A new empty list in this base implementation.

    """
    return list()
|
|
478
|
+
|
|
479
|
+
@classmethod
def get_class_fullname(cls) -> str:
    """Return the class name as computed by ``get_shortest_class_fullname``.

    Returns:
        str: The shortest class fullname.

    """
    fullname = get_shortest_class_fullname(cls)
    return fullname
|
|
488
|
+
|
|
489
|
+
@classmethod
def is_serializable(cls) -> bool:
    """Report whether the transform class is serializable.

    Returns:
        bool: Always True in this base implementation.

    """
    serializable = True
    return serializable
|
|
498
|
+
|
|
499
|
+
def get_base_init_args(self) -> dict[str, Any]:
    """Return the base init args, i.e. only ``p``.

    Returns:
        dict[str, Any]: A new dict of the form ``{"p": self.p}``.

    """
    return dict(p=self.p)
|
|
502
|
+
|
|
503
|
+
def get_transform_init_args(self) -> dict[str, Any]:
    """Collect the current values of the transform's init arguments for serialization.

    Names come from ``get_transform_init_args_names()``. A name is included when
    the instance has an attribute of that name, unless the value is an empty
    ``list``, ``dict``, ``tuple`` or ``set``. ``seed`` is never included.

    Returns:
        dict[str, Any]: Mapping of argument name to the value currently on the instance.

    """
    empty_skipped_types = (list, dict, tuple, set)
    collected = {}

    for name in self.get_transform_init_args_names():
        # Parameters that were never stored on the instance are left out.
        if not hasattr(self, name):
            continue

        value = getattr(self, name)
        # Basic containers with no content are left out as well.
        if isinstance(value, empty_skipped_types) and len(value) == 0:
            continue

        collected[name] = value

    # Remove seed explicitly (it's not meant to be serialized).
    collected.pop("seed", None)
    return collected
|
|
527
|
+
|
|
528
|
+
def to_dict_private(self) -> dict[str, Any]:
    """Build the dictionary representation of the transform.

    Returns:
        dict[str, Any]: ``"__class_fullname__"`` followed by the base init args and the
            transform init args (later entries override earlier ones), with any
            ``"strict"`` entry removed.

    """
    state = {
        "__class_fullname__": self.get_class_fullname(),
        **self.get_base_init_args(),
        **self.get_transform_init_args(),
    }
    # `strict` is not part of the serialized form.
    state.pop("strict", None)
    return state
|
|
543
|
+
|
|
544
|
+
|
|
545
|
+
class DualTransform(BasicTransform):
|
|
546
|
+
"""A base class for transformations that should be applied both to an image and its corresponding properties
|
|
547
|
+
such as masks, bounding boxes, and keypoints. This class ensures that when a transform is applied to an image,
|
|
548
|
+
all associated entities are transformed accordingly to maintain consistency between the image and its annotations.
|
|
549
|
+
|
|
550
|
+
Methods:
|
|
551
|
+
apply(img: np.ndarray, **params: Any) -> np.ndarray:
|
|
552
|
+
Apply the transform to the image.
|
|
553
|
+
|
|
554
|
+
img: Input image of shape (H, W, C) or (H, W) for grayscale.
|
|
555
|
+
**params: Additional parameters specific to the transform.
|
|
556
|
+
|
|
557
|
+
Returns Transformed image of the same shape as input.
|
|
558
|
+
|
|
559
|
+
apply_to_images(images: np.ndarray, **params: Any) -> np.ndarray:
|
|
560
|
+
Apply the transform to multiple images.
|
|
561
|
+
|
|
562
|
+
images: Input images of shape (N, H, W, C) or (N, H, W) for grayscale.
|
|
563
|
+
**params: Additional parameters specific to the transform.
|
|
564
|
+
|
|
565
|
+
Returns Transformed images in the same format as input.
|
|
566
|
+
|
|
567
|
+
apply_to_mask(mask: np.ndarray, **params: Any) -> np.ndarray:
|
|
568
|
+
Apply the transform to a mask.
|
|
569
|
+
|
|
570
|
+
mask: Input mask of shape (H, W), (H, W, C) for multi-channel masks
|
|
571
|
+
**params: Additional parameters specific to the transform.
|
|
572
|
+
|
|
573
|
+
Returns Transformed mask in the same format as input.
|
|
574
|
+
|
|
575
|
+
apply_to_masks(masks: np.ndarray, **params: Any) -> np.ndarray | list[np.ndarray]:
|
|
576
|
+
Apply the transform to multiple masks.
|
|
577
|
+
|
|
578
|
+
masks: Array of shape (N, H, W) or (N, H, W, C) where N is number of masks
|
|
579
|
+
**params: Additional parameters specific to the transform.
|
|
580
|
+
Returns Transformed masks in the same format as input.
|
|
581
|
+
|
|
582
|
+
apply_to_keypoints(keypoints: np.ndarray, **params: Any) -> np.ndarray:
|
|
583
|
+
Apply the transform to keypoints.
|
|
584
|
+
|
|
585
|
+
keypoints: Array of shape (N, 2+) where N is the number of keypoints.
|
|
586
|
+
**params: Additional parameters specific to the transform.
|
|
587
|
+
Returns Transformed keypoints array of shape (N, 2+).
|
|
588
|
+
|
|
589
|
+
apply_to_bboxes(bboxes: np.ndarray, **params: Any) -> np.ndarray:
|
|
590
|
+
Apply the transform to bounding boxes.
|
|
591
|
+
|
|
592
|
+
bboxes: Array of shape (N, 4+) where N is the number of bounding boxes,
|
|
593
|
+
and each row is in the format [x_min, y_min, x_max, y_max].
|
|
594
|
+
**params: Additional parameters specific to the transform.
|
|
595
|
+
|
|
596
|
+
Returns Transformed bounding boxes array of shape (N, 4+).
|
|
597
|
+
|
|
598
|
+
apply_to_volume(volume: np.ndarray, **params: Any) -> np.ndarray:
|
|
599
|
+
Apply the transform to a volume.
|
|
600
|
+
|
|
601
|
+
volume: Input volume of shape (D, H, W) or (D, H, W, C).
|
|
602
|
+
**params: Additional parameters specific to the transform.
|
|
603
|
+
|
|
604
|
+
Returns Transformed volume of the same shape as input.
|
|
605
|
+
|
|
606
|
+
apply_to_volumes(volumes: np.ndarray, **params: Any) -> np.ndarray:
|
|
607
|
+
Apply the transform to multiple volumes.
|
|
608
|
+
|
|
609
|
+
volumes: Input volumes of shape (N, D, H, W) or (N, D, H, W, C).
|
|
610
|
+
**params: Additional parameters specific to the transform.
|
|
611
|
+
|
|
612
|
+
Returns Transformed volumes in the same format as input.
|
|
613
|
+
|
|
614
|
+
apply_to_mask3d(mask: np.ndarray, **params: Any) -> np.ndarray:
|
|
615
|
+
Apply the transform to a 3D mask.
|
|
616
|
+
|
|
617
|
+
mask: Input 3D mask of shape (D, H, W) or (D, H, W, C)
|
|
618
|
+
**params: Additional parameters specific to the transform.
|
|
619
|
+
|
|
620
|
+
Returns Transformed 3D mask in the same format as input.
|
|
621
|
+
|
|
622
|
+
apply_to_masks3d(masks: np.ndarray, **params: Any) -> np.ndarray:
|
|
623
|
+
Apply the transform to multiple 3D masks.
|
|
624
|
+
|
|
625
|
+
masks: Input 3D masks of shape (N, D, H, W) or (N, D, H, W, C)
|
|
626
|
+
**params: Additional parameters specific to the transform.
|
|
627
|
+
|
|
628
|
+
Returns Transformed 3D masks in the same format as input.
|
|
629
|
+
|
|
630
|
+
Note:
|
|
631
|
+
- All `apply_*` methods should maintain the input shape and format of the data.
|
|
632
|
+
- When applying transforms to masks, ensure that discrete values (e.g., class labels) are preserved.
|
|
633
|
+
- For keypoints and bounding boxes, the transformation should maintain their relative positions
|
|
634
|
+
with respect to the transformed image.
|
|
635
|
+
- The difference between `apply_to_mask` and `apply_to_masks` is mainly in how they handle 3D arrays:
|
|
636
|
+
`apply_to_mask` treats a 3D array as a multi-channel mask, while `apply_to_masks` treats it as
|
|
637
|
+
multiple single-channel masks.
|
|
638
|
+
|
|
639
|
+
"""
|
|
640
|
+
|
|
641
|
+
@property
def targets(self) -> dict[str, Callable[..., Any]]:
    """Return the dispatch table from target key to the method that processes it.

    Returns:
        dict[str, Callable[..., Any]]: Target key -> bound processing method of this
            DualTransform. Keys are inserted in the order image, images, mask, masks,
            mask3d, masks3d, bboxes, keypoints, volume, volumes.

    """
    handlers: dict[str, Callable[..., Any]] = {}
    # 2D image-like inputs
    handlers["image"] = self.apply
    handlers["images"] = self.apply_to_images
    # 2D and 3D masks
    handlers["mask"] = self.apply_to_mask
    handlers["masks"] = self.apply_to_masks
    handlers["mask3d"] = self.apply_to_mask3d
    handlers["masks3d"] = self.apply_to_masks3d
    # Annotations
    handlers["bboxes"] = self.apply_to_bboxes
    handlers["keypoints"] = self.apply_to_keypoints
    # Volumetric inputs
    handlers["volume"] = self.apply_to_volume
    handlers["volumes"] = self.apply_to_volumes
    return handlers
|
|
661
|
+
|
|
662
|
+
def apply_to_keypoints(self, keypoints: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
    """Abstract hook for transforming keypoints; always raises in this base implementation.

    Args:
        keypoints (np.ndarray): Keypoints array of shape (N, 2+).
        *args (Any): Extra positional arguments (unused here).
        **params (Any): Extra keyword parameters (unused here).

    Raises:
        NotImplementedError: Always. Subclasses that support keypoints must override
            this method.

    Returns:
        np.ndarray: Transformed keypoints (only from overriding implementations).

    """
    raise NotImplementedError(
        f"Method apply_to_keypoints is not implemented in class {self.__class__.__name__}",
    )
|
|
679
|
+
|
|
680
|
+
def apply_to_bboxes(self, bboxes: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
    """Abstract hook for transforming bounding boxes; always raises in this base implementation.

    Args:
        bboxes (np.ndarray): Bounding-box array of shape (N, 4+).
        *args (Any): Extra positional arguments (unused here).
        **params (Any): Extra keyword parameters (unused here).

    Raises:
        NotImplementedError: Always. Subclasses that support bounding boxes must
            override this method.

    Returns:
        np.ndarray: Transformed bounding boxes (only from overriding implementations).

    """
    owner = self.__class__.__name__
    message = f"BBoxes not implemented for {owner}"
    raise NotImplementedError(message)
|
|
696
|
+
|
|
697
|
+
def apply_to_mask(self, mask: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
    """Transform a mask by delegating to ``self.apply``.

    Args:
        mask (np.ndarray): Mask to transform.
        *args (Any): Positional arguments passed through to ``apply``.
        **params (Any): Keyword parameters passed through to ``apply``.

    Returns:
        np.ndarray: Whatever ``self.apply`` returns for the mask.

    """
    # The default treats a mask exactly like an image; subclasses override this
    # method when masks need different handling.
    transformed = self.apply(mask, *args, **params)
    return transformed
|
|
710
|
+
|
|
711
|
+
def apply_to_masks(self, masks: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
    """Apply the transform to every mask in a batch and stack the results.

    Iterates over the first axis of ``masks``, runs ``apply_to_mask`` on each item and
    joins the outputs along a new leading axis with ``np.stack``.

    Args:
        masks (np.ndarray): Input masks as numpy array; the first axis indexes the masks.
        *args (Any): Additional positional arguments, forwarded to ``apply_to_mask``.
        **params (Any): Additional parameters specific to the transform, forwarded to
            ``apply_to_mask``.

    Raises:
        ValueError: If ``masks`` is empty (``np.stack`` needs at least one array) or the
            per-mask outputs do not all have the same shape.

    Returns:
        np.ndarray: Transformed masks as numpy array.

    """
    # Forward *args as well as **params so this method honours its own signature,
    # the same way apply_to_mask forwards both to apply.
    return np.stack([self.apply_to_mask(mask, *args, **params) for mask in masks])
|
|
724
|
+
|
|
725
|
+
@batch_transform("spatial", has_batch_dim=False, has_depth_dim=True)
def apply_to_mask3d(self, mask3d: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
    """Apply transform to a 3D mask by delegating to ``apply_to_mask``.

    Args:
        mask3d (np.ndarray): Input 3D mask as numpy array
        *args (Any): Additional positional arguments, forwarded to ``apply_to_mask``
        **params (Any): Additional parameters specific to the transform, forwarded to
            ``apply_to_mask``

    Returns:
        np.ndarray: Transformed 3D mask as numpy array

    """
    # NOTE(review): batch_transform is defined outside this block; presumably it
    # reshapes the depth dimension before/after this call -- confirm against its docs.
    # Forward *args too, so the declared signature is honoured (apply_to_mask
    # itself forwards both *args and **params to apply).
    return self.apply_to_mask(mask3d, *args, **params)
|
|
739
|
+
|
|
740
|
+
@batch_transform("spatial", has_batch_dim=True, has_depth_dim=True)
def apply_to_masks3d(self, masks3d: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
    """Apply transform to multiple 3D masks.

    Iterates over the first axis of ``masks3d``, runs ``apply_to_mask3d`` on each item
    and stacks the outputs along a new leading axis.

    Args:
        masks3d (np.ndarray): Input 3D masks as numpy array
        *args (Any): Additional positional arguments, forwarded to ``apply_to_mask3d``
        **params (Any): Additional parameters specific to the transform, forwarded to
            ``apply_to_mask3d``

    Returns:
        np.ndarray: Transformed 3D masks as numpy array

    """
    # NOTE(review): this method is wrapped by batch_transform AND loops over the
    # first axis itself, calling the (also decorated) apply_to_mask3d. Whether the
    # array received here still has the batch axis first depends on the decorator,
    # which is defined outside this block -- confirm the two do not double-handle
    # the batch/depth dimensions.
    # Forward *args too, so the declared signature is honoured.
    return np.stack([self.apply_to_mask3d(mask3d, *args, **params) for mask3d in masks3d])
|
|
754
|
+
|
|
755
|
+
|
|
756
|
+
class ImageOnlyTransform(BasicTransform):
    """Transform that processes image-like targets only (images and volumes)."""

    _targets = (Targets.IMAGE, Targets.VOLUME)

    @property
    def targets(self) -> dict[str, Callable[..., Any]]:
        """Return the dispatch table from target key to processing method.

        Only image-like keys are present: ``image``, ``images``, ``volume`` and
        ``volumes``.

        Returns:
            dict[str, Callable[..., Any]]: Target key -> bound processing method.

        """
        handlers: dict[str, Callable[..., Any]] = {}
        handlers["image"] = self.apply
        handlers["images"] = self.apply_to_images
        handlers["volume"] = self.apply_to_volume
        handlers["volumes"] = self.apply_to_volumes
        return handlers
|
|
775
|
+
|
|
776
|
+
|
|
777
|
+
class NoOp(DualTransform):
    """Identity transform (does nothing).

    Every ``apply*`` override returns its input object unchanged. Extra positional
    and keyword arguments are accepted and ignored, so the overrides stay
    call-compatible with the ``DualTransform`` methods they replace.

    Targets:
        image, mask, bboxes, keypoints, volume, mask3d

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>>
        >>> # Prepare sample data
        >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
        >>> mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
        >>> bboxes = np.array([[10, 10, 50, 50], [40, 40, 80, 80]], dtype=np.float32)
        >>> bbox_labels = [1, 2]
        >>> keypoints = np.array([[20, 30], [60, 70]], dtype=np.float32)
        >>> keypoint_labels = [0, 1]
        >>>
        >>> # Create transform pipeline with NoOp
        >>> transform = A.Compose([
        ...     A.NoOp(p=1.0),  # Always applied, but does nothing
        ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['bbox_labels']),
        ...    keypoint_params=A.KeypointParams(format='xy', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply the transform
        >>> transformed = transform(
        ...     image=image,
        ...     mask=mask,
        ...     bboxes=bboxes,
        ...     bbox_labels=bbox_labels,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Verify nothing has changed
        >>> np.array_equal(image, transformed['image'])  # True
        >>> np.array_equal(mask, transformed['mask'])  # True
        >>> np.array_equal(bboxes, transformed['bboxes'])  # True
        >>> np.array_equal(keypoints, transformed['keypoints'])  # True
        >>> bbox_labels == transformed['bbox_labels']  # True
        >>> keypoint_labels == transformed['keypoint_labels']  # True
        >>>
        >>> # NoOp is often used as a placeholder or for testing
        >>> # For example, in conditional transforms:
        >>> condition = False  # Some condition
        >>> transform = A.Compose([
        ...     A.HorizontalFlip(p=1.0) if condition else A.NoOp(p=1.0)
        ... ])

    """

    _targets = ALL_TARGETS

    def apply_to_keypoints(self, keypoints: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
        """Apply transform to keypoints (identity operation).

        Args:
            keypoints (np.ndarray): Array of keypoints.
            *args (Any): Additional positional arguments (ignored).
            **params (Any): Additional parameters (ignored).

        Returns:
            np.ndarray: Unchanged keypoints array.

        """
        return keypoints

    def apply_to_bboxes(self, bboxes: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
        """Apply transform to bounding boxes (identity operation).

        Args:
            bboxes (np.ndarray): Array of bounding boxes.
            *args (Any): Additional positional arguments (ignored).
            **params (Any): Additional parameters (ignored).

        Returns:
            np.ndarray: Unchanged bounding boxes array.

        """
        return bboxes

    def apply(self, img: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
        """Apply transform to image (identity operation).

        Args:
            img (np.ndarray): Input image.
            *args (Any): Additional positional arguments (ignored).
            **params (Any): Additional parameters (ignored).

        Returns:
            np.ndarray: Unchanged image.

        """
        return img

    def apply_to_mask(self, mask: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
        """Apply transform to mask (identity operation).

        Args:
            mask (np.ndarray): Input mask.
            *args (Any): Additional positional arguments (ignored).
            **params (Any): Additional parameters (ignored).

        Returns:
            np.ndarray: Unchanged mask.

        """
        return mask

    def apply_to_volume(self, volume: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
        """Apply transform to volume (identity operation).

        Args:
            volume (np.ndarray): Input volume.
            *args (Any): Additional positional arguments (ignored).
            **params (Any): Additional parameters (ignored).

        Returns:
            np.ndarray: Unchanged volume.

        """
        return volume

    def apply_to_volumes(self, volumes: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
        """Apply transform to multiple volumes (identity operation).

        Args:
            volumes (np.ndarray): Input volumes.
            *args (Any): Additional positional arguments (ignored).
            **params (Any): Additional parameters (ignored).

        Returns:
            np.ndarray: Unchanged volumes.

        """
        return volumes

    def apply_to_mask3d(self, mask3d: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
        """Apply transform to 3D mask (identity operation).

        Args:
            mask3d (np.ndarray): Input 3D mask.
            *args (Any): Additional positional arguments (ignored).
            **params (Any): Additional parameters (ignored).

        Returns:
            np.ndarray: Unchanged 3D mask.

        """
        return mask3d

    def apply_to_masks3d(self, masks3d: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
        """Apply transform to multiple 3D masks (identity operation).

        Args:
            masks3d (np.ndarray): Input 3D masks.
            *args (Any): Additional positional arguments (ignored).
            **params (Any): Additional parameters (ignored).

        Returns:
            np.ndarray: Unchanged 3D masks.

        """
        return masks3d
|
|
933
|
+
|
|
934
|
+
|
|
935
|
+
class Transform3D(DualTransform):
    """Base class for all 3D transforms.

    Transform3D inherits from DualTransform because 3D transforms can be applied to both
    volumes and masks, similar to how 2D DualTransforms work with images and masks.

    Subclasses must implement ``apply_to_volume``; the 3D-mask and batch variants
    defined here all delegate to it.

    Targets:
        volume: 3D numpy array of shape (D, H, W) or (D, H, W, C)
        volumes: Batch of 3D arrays of shape (N, D, H, W) or (N, D, H, W, C)
        mask3d: 3D numpy array of shape (D, H, W)
        masks3d: Batch of 3D arrays of shape (N, D, H, W)
        keypoints: Numpy array of 3D keypoints of shape (N, 3)
    """

    def apply_to_volume(self, volume: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
        """Apply transform to single 3D volume.

        Args:
            volume (np.ndarray): Input volume.
            *args (Any): Additional positional arguments.
            **params (Any): Additional parameters specific to the transform.

        Raises:
            NotImplementedError: Always. This method must be implemented by subclass.

        Returns:
            np.ndarray: Transformed volume (only from overriding implementations).

        """
        # Name the missing method and the concrete class, matching the message
        # style used by DualTransform.apply_to_keypoints.
        msg = f"Method apply_to_volume is not implemented in class {self.__class__.__name__}"
        raise NotImplementedError(msg)

    @batch_transform("spatial", keep_depth_dim=True, has_batch_dim=True, has_depth_dim=True)
    def apply_to_volumes(self, volumes: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
        """Apply transform to batch of 3D volumes by delegating to ``apply_to_volume``.

        Args:
            volumes (np.ndarray): Input volumes.
            *args (Any): Additional positional arguments, forwarded unchanged.
            **params (Any): Additional parameters, forwarded unchanged.

        Returns:
            np.ndarray: Result of ``apply_to_volume``.

        """
        # NOTE(review): batch handling is presumably done by the batch_transform
        # decorator (defined elsewhere) -- confirm against its documentation.
        return self.apply_to_volume(volumes, *args, **params)

    def apply_to_mask3d(self, mask3d: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
        """Apply transform to single 3D mask by delegating to ``apply_to_volume``.

        Args:
            mask3d (np.ndarray): Input 3D mask.
            *args (Any): Additional positional arguments, forwarded unchanged.
            **params (Any): Additional parameters, forwarded unchanged.

        Returns:
            np.ndarray: Result of ``apply_to_volume``.

        """
        return self.apply_to_volume(mask3d, *args, **params)

    @batch_transform("spatial", keep_depth_dim=True, has_batch_dim=True, has_depth_dim=True)
    def apply_to_masks3d(self, masks3d: np.ndarray, *args: Any, **params: Any) -> np.ndarray:
        """Apply transform to batch of 3D masks by delegating to ``apply_to_mask3d``.

        Args:
            masks3d (np.ndarray): Input 3D masks.
            *args (Any): Additional positional arguments, forwarded unchanged.
            **params (Any): Additional parameters, forwarded unchanged.

        Returns:
            np.ndarray: Result of ``apply_to_mask3d``.

        """
        return self.apply_to_mask3d(masks3d, *args, **params)

    @property
    def targets(self) -> dict[str, Callable[..., Any]]:
        """Define valid targets for 3D transforms.

        Returns:
            dict[str, Callable[..., Any]]: Mapping of the target keys ``volume``,
                ``volumes``, ``mask3d``, ``masks3d`` and ``keypoints`` to the bound
                methods that process them.

        """
        return {
            "volume": self.apply_to_volume,
            "volumes": self.apply_to_volumes,
            "mask3d": self.apply_to_mask3d,
            "masks3d": self.apply_to_masks3d,
            "keypoints": self.apply_to_keypoints,
        }
|