nrtk-albumentations 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nrtk-albumentations might be problematic. Click here for more details.
- albumentations/__init__.py +21 -0
- albumentations/augmentations/__init__.py +23 -0
- albumentations/augmentations/blur/__init__.py +0 -0
- albumentations/augmentations/blur/functional.py +438 -0
- albumentations/augmentations/blur/transforms.py +1633 -0
- albumentations/augmentations/crops/__init__.py +0 -0
- albumentations/augmentations/crops/functional.py +494 -0
- albumentations/augmentations/crops/transforms.py +3647 -0
- albumentations/augmentations/dropout/__init__.py +0 -0
- albumentations/augmentations/dropout/channel_dropout.py +134 -0
- albumentations/augmentations/dropout/coarse_dropout.py +567 -0
- albumentations/augmentations/dropout/functional.py +1017 -0
- albumentations/augmentations/dropout/grid_dropout.py +166 -0
- albumentations/augmentations/dropout/mask_dropout.py +274 -0
- albumentations/augmentations/dropout/transforms.py +461 -0
- albumentations/augmentations/dropout/xy_masking.py +186 -0
- albumentations/augmentations/geometric/__init__.py +0 -0
- albumentations/augmentations/geometric/distortion.py +1238 -0
- albumentations/augmentations/geometric/flip.py +752 -0
- albumentations/augmentations/geometric/functional.py +4151 -0
- albumentations/augmentations/geometric/pad.py +676 -0
- albumentations/augmentations/geometric/resize.py +956 -0
- albumentations/augmentations/geometric/rotate.py +864 -0
- albumentations/augmentations/geometric/transforms.py +1962 -0
- albumentations/augmentations/mixing/__init__.py +0 -0
- albumentations/augmentations/mixing/domain_adaptation.py +787 -0
- albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
- albumentations/augmentations/mixing/functional.py +878 -0
- albumentations/augmentations/mixing/transforms.py +832 -0
- albumentations/augmentations/other/__init__.py +0 -0
- albumentations/augmentations/other/lambda_transform.py +180 -0
- albumentations/augmentations/other/type_transform.py +261 -0
- albumentations/augmentations/pixel/__init__.py +0 -0
- albumentations/augmentations/pixel/functional.py +4226 -0
- albumentations/augmentations/pixel/transforms.py +7556 -0
- albumentations/augmentations/spectrogram/__init__.py +0 -0
- albumentations/augmentations/spectrogram/transform.py +220 -0
- albumentations/augmentations/text/__init__.py +0 -0
- albumentations/augmentations/text/functional.py +272 -0
- albumentations/augmentations/text/transforms.py +299 -0
- albumentations/augmentations/transforms3d/__init__.py +0 -0
- albumentations/augmentations/transforms3d/functional.py +393 -0
- albumentations/augmentations/transforms3d/transforms.py +1422 -0
- albumentations/augmentations/utils.py +249 -0
- albumentations/core/__init__.py +0 -0
- albumentations/core/bbox_utils.py +920 -0
- albumentations/core/composition.py +1885 -0
- albumentations/core/hub_mixin.py +299 -0
- albumentations/core/keypoints_utils.py +521 -0
- albumentations/core/label_manager.py +339 -0
- albumentations/core/pydantic.py +239 -0
- albumentations/core/serialization.py +352 -0
- albumentations/core/transforms_interface.py +976 -0
- albumentations/core/type_definitions.py +127 -0
- albumentations/core/utils.py +605 -0
- albumentations/core/validation.py +129 -0
- albumentations/pytorch/__init__.py +1 -0
- albumentations/pytorch/transforms.py +189 -0
- nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
- nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
- nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
- nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,1422 @@
|
|
|
1
|
+
"""Module containing 3D transformation classes for volumetric data augmentation.
|
|
2
|
+
|
|
3
|
+
This module provides a collection of transformation classes designed specifically for
|
|
4
|
+
3D volumetric data (such as medical CT/MRI scans). These transforms can manipulate properties
|
|
5
|
+
such as spatial dimensions, apply dropout effects, and perform symmetry operations on
|
|
6
|
+
3D volumes, masks, and keypoints. Each transformation inherits from a base transform
|
|
7
|
+
interface and implements specific 3D augmentation logic.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from typing import Annotated, Any, Literal, Union, cast
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
from pydantic import AfterValidator, field_validator, model_validator
|
|
16
|
+
from typing_extensions import Self
|
|
17
|
+
|
|
18
|
+
from albumentations.augmentations.geometric import functional as fgeometric
|
|
19
|
+
from albumentations.augmentations.transforms3d import functional as f3d
|
|
20
|
+
from albumentations.core.keypoints_utils import KeypointsProcessor
|
|
21
|
+
from albumentations.core.pydantic import check_range_bounds, nondecreasing
|
|
22
|
+
from albumentations.core.transforms_interface import BaseTransformInitSchema, Transform3D
|
|
23
|
+
from albumentations.core.type_definitions import Targets
|
|
24
|
+
|
|
25
|
+
# Public API of this module. The BasePad3D / BaseCropAndPad3D helper base
# classes are deliberately absent from this list.
__all__ = [
    "CenterCrop3D",
    "CoarseDropout3D",
    "CubicSymmetry",
    "Pad3D",
    "PadIfNeeded3D",
    "RandomCrop3D",
]

# Number of spatial axes of a volume (depth, height, width). Pad3D compares a
# padding tuple's length against this to detect the symmetric (d, h, w) form.
NUM_DIMENSIONS = 3
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class BasePad3D(Transform3D):
    """Base class for 3D padding transforms.

    This class serves as a foundation for all 3D transforms that perform padding operations
    on volumetric data. It provides common functionality for padding 3D volumes, masks,
    and processing 3D keypoints during padding operations.

    The class handles different types of padding values (scalar or per-channel) and
    provides separate fill values for volumes and masks.

    Subclasses must provide a ``padding`` parameter (typically returned from
    ``get_params_dependent_on_data``) as a 6-tuple ordered
    ``(depth_front, depth_back, height_top, height_bottom, width_left, width_right)``;
    every ``apply_to_*`` method below consumes it.

    Args:
        fill (tuple[float, ...] | float): Value to fill the padded voxels for volumes.
            Can be a single value for all channels or a tuple of values per channel.
        fill_mask (tuple[float, ...] | float): Value to fill the padded voxels for 3D masks.
            Can be a single value for all channels or a tuple of values per channel.
        p (float): Probability of applying the transform. Default: 1.0.

    Targets:
        volume, mask3d, keypoints

    Note:
        This is a base class and not intended to be used directly. Use its derivatives
        like Pad3D or PadIfNeeded3D instead, or create a custom padding transform
        by inheriting from this class. ``BasePad3D`` is not listed in this module's
        ``__all__``, so import it from its defining module when subclassing.

    Examples:
        >>> from typing import Any
        >>> import numpy as np
        >>> import albumentations as A
        >>> from albumentations.augmentations.transforms3d.transforms import BasePad3D
        >>>
        >>> # Example of a custom padding transform inheriting from BasePad3D
        >>> class CustomPad3D(BasePad3D):
        ...     def __init__(self, padding_size: tuple[int, int, int] = (5, 5, 5), *args, **kwargs):
        ...         super().__init__(*args, **kwargs)
        ...         self.padding_size = padding_size
        ...
        ...     def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
        ...         # Create symmetric padding: same amount on all sides of each dimension
        ...         pad_d, pad_h, pad_w = self.padding_size
        ...         padding = (pad_d, pad_d, pad_h, pad_h, pad_w, pad_w)
        ...         return {"padding": padding}
        >>>
        >>> # Prepare sample data
        >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> keypoints = np.array([[20, 30, 5], [60, 70, 8]], dtype=np.float32)  # (x, y, z)
        >>> keypoint_labels = [1, 2]  # Labels for each keypoint
        >>>
        >>> # Use the custom transform in a pipeline
        >>> transform = A.Compose([
        ...     CustomPad3D(
        ...         padding_size=(2, 10, 10),
        ...         fill=0,
        ...         fill_mask=1,
        ...         p=1.0
        ...     )
        ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply the transform
        >>> transformed = transform(
        ...     volume=volume,
        ...     mask3d=mask3d,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Get the transformed data
        >>> transformed_volume = transformed["volume"]  # Shape: (14, 120, 120)
        >>> transformed_mask3d = transformed["mask3d"]  # Shape: (14, 120, 120)
        >>> transformed_keypoints = transformed["keypoints"]  # Keypoints shifted by padding offsets
        >>> transformed_keypoint_labels = transformed["keypoint_labels"]  # Labels remain unchanged

    """

    _targets = (Targets.VOLUME, Targets.MASK3D, Targets.KEYPOINTS)

    class InitSchema(Transform3D.InitSchema):
        fill: tuple[float, ...] | float
        fill_mask: tuple[float, ...] | float

    def __init__(
        self,
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
        p: float = 1.0,
    ):
        super().__init__(p=p)
        self.fill = fill
        self.fill_mask = fill_mask

    def apply_to_volume(
        self,
        volume: np.ndarray,
        padding: tuple[int, int, int, int, int, int],
        **params: Any,
    ) -> np.ndarray:
        """Apply padding to a 3D volume.

        Args:
            volume (np.ndarray): Input volume with shape (depth, height, width) or (depth, height, width, channels)
            padding (tuple[int, int, int, int, int, int]): Padding values in format:
                (depth_front, depth_back, height_top, height_bottom, width_left, width_right)
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Padded volume with same number of dimensions as input.
                The input object itself is returned unchanged when all padding values are zero.

        """
        # Skip the padding call entirely when there is nothing to pad.
        if padding == (0, 0, 0, 0, 0, 0):
            return volume
        return f3d.pad_3d_with_params(
            volume=volume,
            padding=padding,
            value=self.fill,
        )

    def apply_to_mask3d(
        self,
        mask3d: np.ndarray,
        padding: tuple[int, int, int, int, int, int],
        **params: Any,
    ) -> np.ndarray:
        """Apply padding to a 3D mask.

        Args:
            mask3d (np.ndarray): Input mask with shape (depth, height, width) or (depth, height, width, channels)
            padding (tuple[int, int, int, int, int, int]): Padding values in format:
                (depth_front, depth_back, height_top, height_bottom, width_left, width_right)
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Padded mask with same number of dimensions as input.
                The input object itself is returned unchanged when all padding values are zero.

        """
        if padding == (0, 0, 0, 0, 0, 0):
            return mask3d
        return f3d.pad_3d_with_params(
            volume=mask3d,
            padding=padding,
            value=cast("Union[tuple[float, ...], float]", self.fill_mask),
        )

    def apply_to_keypoints(
        self,
        keypoints: np.ndarray,
        padding: tuple[int, int, int, int, int, int],
        **params: Any,
    ) -> np.ndarray:
        """Apply padding to keypoints.

        Args:
            keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
                The first three columns are x, y, z coordinates.
            padding (tuple[int, int, int, int, int, int]): Padding values in format:
                (depth_front, depth_back, height_top, height_bottom, width_left, width_right)
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Shifted keypoints with same shape as input

        """
        # Keypoints are ordered (x, y, z) while padding is ordered by (z, y, x) axes.
        # Only the leading pad of each axis (width_left, height_top, depth_front)
        # moves existing content, hence indices 4, 2, 0.
        shift_vector = np.array([padding[4], padding[2], padding[0]])
        return fgeometric.shift_keypoints(keypoints, shift_vector)
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
class Pad3D(BasePad3D):
    """Pad the sides of a 3D volume by specified number of voxels.

    Args:
        padding (int, tuple[int, int, int] or tuple[int, int, int, int, int, int]): Padding values. Can be:
            * int - pad all sides by this value
            * tuple[int, int, int] - symmetric padding (depth, height, width) where each value
              is applied to both sides of the corresponding dimension
            * tuple[int, int, int, int, int, int] - explicit padding per side in order:
              (depth_front, depth_back, height_top, height_bottom, width_left, width_right)

        fill (tuple[float, ...] | float): Padding value for the volume
        fill_mask (tuple[float, ...] | float): Padding value for the 3D mask
        p (float): probability of applying the transform. Default: 1.0.

    Targets:
        volume, mask3d, keypoints

    Image types:
        uint8, float32

    Note:
        Input volume should be a numpy array with dimensions ordered as (z, y, x) or (depth, height, width),
        with optional channel dimension as the last axis.

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>>
        >>> # Prepare sample data
        >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> keypoints = np.array([[20, 30, 5], [60, 70, 8]], dtype=np.float32)  # (x, y, z)
        >>> keypoint_labels = [1, 2]  # Labels for each keypoint
        >>>
        >>> # Create the transform with symmetric padding
        >>> transform = A.Compose([
        ...     A.Pad3D(
        ...         padding=(2, 5, 10),  # (depth, height, width) applied symmetrically
        ...         fill=0,
        ...         fill_mask=1,
        ...         p=1.0
        ...     )
        ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply the transform
        >>> transformed = transform(
        ...     volume=volume,
        ...     mask3d=mask3d,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Get the transformed data
        >>> padded_volume = transformed["volume"]  # Shape: (14, 110, 120)
        >>> padded_mask3d = transformed["mask3d"]  # Shape: (14, 110, 120)
        >>> padded_keypoints = transformed["keypoints"]  # Keypoints shifted by padding
        >>> padded_keypoint_labels = transformed["keypoint_labels"]  # Labels remain unchanged

    """

    class InitSchema(BasePad3D.InitSchema):
        padding: int | tuple[int, int, int] | tuple[int, int, int, int, int, int]

        @field_validator("padding")
        @classmethod
        def validate_padding(
            cls,
            v: int | tuple[int, int, int] | tuple[int, int, int, int, int, int],
        ) -> int | tuple[int, int, int] | tuple[int, int, int, int, int, int]:
            """Validate the padding parameter.

            Args:
                cls (type): The class object
                v (int | tuple[int, int, int] | tuple[int, int, int, int, int, int]): The padding value to validate,
                    can be an integer or tuple of integers

            Returns:
                int | tuple[int, int, int] | tuple[int, int, int, int, int, int]: The validated padding value

            Raises:
                ValueError: If padding is negative or contains negative values

            """
            if isinstance(v, int) and v < 0:
                raise ValueError("Padding value must be non-negative")
            if isinstance(v, tuple) and not all(isinstance(i, int) and i >= 0 for i in v):
                raise ValueError("Padding tuple must contain non-negative integers")

            return v

    def __init__(
        self,
        padding: int | tuple[int, int, int] | tuple[int, int, int, int, int, int],
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
        p: float = 1.0,
    ):
        # BasePad3D.__init__ already stores ``fill`` and ``fill_mask`` on the instance.
        super().__init__(fill=fill, fill_mask=fill_mask, p=p)
        self.padding = padding

    def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
        """Get parameters dependent on input data.

        Args:
            params (dict[str, Any]): Dictionary of existing parameters
            data (dict[str, Any]): Dictionary containing input data with volume, mask, etc.

        Returns:
            dict[str, Any]: Dictionary containing the padding parameter as a 6-tuple in format:
                (depth_front, depth_back, height_top, height_bottom, width_left, width_right)

        """
        if isinstance(self.padding, int):
            pad_d = pad_h = pad_w = self.padding
            padding = (pad_d, pad_d, pad_h, pad_h, pad_w, pad_w)
        elif len(self.padding) == NUM_DIMENSIONS:
            pad_d, pad_h, pad_w = self.padding  # type: ignore[misc]
            padding = (pad_d, pad_d, pad_h, pad_h, pad_w, pad_w)
        else:
            # Always hand a tuple downstream. BasePad3D's no-op check compares
            # against a tuple literal, so a list assigned to ``self.padding`` after
            # construction would otherwise never match it.
            padding = tuple(self.padding)  # type: ignore[assignment]

        return {"padding": padding}
|
|
313
|
+
|
|
314
|
+
|
|
315
|
+
class PadIfNeeded3D(BasePad3D):
    """Pad a 3D volume only when it is smaller than a minimum size or not aligned to a divisor.

    Each spatial axis (depth, height, width) is grown independently. It is grown until it
    reaches the minimum from ``min_zyx`` and/or until it becomes a multiple of the
    matching entry of ``pad_divisor_zyx``. At least one of these two options is required.

    Args:
        min_zyx (tuple[int, int, int] | None): Lower bound on the output size as
            (depth, height, width). If omitted, pad_divisor_zyx must be given.
        pad_divisor_zyx (tuple[int, int, int] | None): Per-axis divisor as
            (depth_div, height_div, width_div); each output dimension is padded so it is
            divisible by the matching value. If omitted, min_zyx must be given.
        position (Literal["center", "random"]): Where the original volume sits inside the
            padded result. Default is 'center'.
        fill (tuple[float, ...] | float): Fill value for padded volume voxels. Default: 0
        fill_mask (tuple[float, ...] | float): Fill value for padded mask voxels. Default: 0
        p (float): Probability of applying the transform. Default: 1.0

    Targets:
        volume, mask3d, keypoints

    Image types:
        uint8, float32

    Note:
        The volume is expected as a numpy array ordered (z, y, x), i.e. (depth, height, width),
        optionally followed by a trailing channel axis.

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>>
        >>> # Prepare sample data
        >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> keypoints = np.array([[20, 30, 5], [60, 70, 8]], dtype=np.float32)  # (x, y, z)
        >>> keypoint_labels = [1, 2]  # Labels for each keypoint
        >>>
        >>> # Create a transform with both min_zyx and pad_divisor_zyx
        >>> transform = A.Compose([
        ...     A.PadIfNeeded3D(
        ...         min_zyx=(16, 128, 128),  # Minimum size (depth, height, width)
        ...         pad_divisor_zyx=(8, 16, 16),  # Make dimensions divisible by these values
        ...         position="center",  # Center the volume in the padded space
        ...         fill=0,  # Fill value for volume
        ...         fill_mask=1,  # Fill value for mask
        ...         p=1.0
        ...     )
        ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply the transform
        >>> transformed = transform(
        ...     volume=volume,
        ...     mask3d=mask3d,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Get the transformed data
        >>> padded_volume = transformed["volume"]  # Shape: (16, 128, 128)
        >>> padded_mask3d = transformed["mask3d"]  # Shape: (16, 128, 128)
        >>> padded_keypoints = transformed["keypoints"]  # Keypoints shifted by padding
        >>> padded_keypoint_labels = transformed["keypoint_labels"]  # Labels remain unchanged

    """

    class InitSchema(BasePad3D.InitSchema):
        min_zyx: Annotated[tuple[int, int, int] | None, AfterValidator(check_range_bounds(0, None))]
        pad_divisor_zyx: Annotated[tuple[int, int, int] | None, AfterValidator(check_range_bounds(1, None))]
        position: Literal["center", "random"]

        @model_validator(mode="after")
        def validate_params(self) -> Self:
            """Reject configurations where neither sizing option is supplied.

            Returns:
                Self: The validated schema instance.

            Raises:
                ValueError: When min_zyx and pad_divisor_zyx are both None.

            """
            nothing_requested = all(option is None for option in (self.min_zyx, self.pad_divisor_zyx))
            if nothing_requested:
                raise ValueError("At least one of min_zyx or pad_divisor_zyx must be set")
            return self

    def __init__(
        self,
        min_zyx: tuple[int, int, int] | None = None,
        pad_divisor_zyx: tuple[int, int, int] | None = None,
        position: Literal["center", "random"] = "center",
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
        p: float = 1.0,
    ):
        super().__init__(fill=fill, fill_mask=fill_mask, p=p)
        self.min_zyx = min_zyx
        self.pad_divisor_zyx = pad_divisor_zyx
        self.position = position

    def get_params_dependent_on_data(
        self,
        params: dict[str, Any],
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Work out how much to pad each axis of the incoming volume.

        Args:
            params (dict[str, Any]): Parameters already computed for this call.
            data (dict[str, Any]): Input targets; the "volume" entry provides the spatial shape.

        Returns:
            dict[str, Any]: ``{"padding": ...}`` with the per-side padding placed according
                to ``self.position``.

        """
        depth, height, width = data["volume"].shape[:3]

        per_axis = []
        for axis, current in enumerate((depth, height, width)):
            # A falsy (missing) option contributes no constraint on this axis.
            axis_min = self.min_zyx[axis] if self.min_zyx else None
            axis_divisor = self.pad_divisor_zyx[axis] if self.pad_divisor_zyx else None
            per_axis.append(
                fgeometric.get_dimension_padding(
                    current_size=current,
                    min_size=axis_min,
                    divisor=axis_divisor,
                ),
            )

        placed = f3d.adjust_padding_by_position3d(
            paddings=per_axis,
            position=self.position,
            py_random=self.py_random,
        )
        return {"padding": placed}
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
class BaseCropAndPad3D(Transform3D):
|
|
453
|
+
"""Base class for 3D transforms that need both cropping and padding.
|
|
454
|
+
|
|
455
|
+
This class serves as a foundation for transforms that combine cropping and padding operations
|
|
456
|
+
on 3D volumetric data. It provides functionality for calculating padding parameters,
|
|
457
|
+
applying crop and pad operations to volumes, masks, and handling keypoint coordinate shifts.
|
|
458
|
+
|
|
459
|
+
Args:
|
|
460
|
+
pad_if_needed (bool): Whether to pad if the volume is smaller than target dimensions
|
|
461
|
+
fill (tuple[float, ...] | float): Value to fill the padded voxels for volume
|
|
462
|
+
fill_mask (tuple[float, ...] | float): Value to fill the padded voxels for mask
|
|
463
|
+
pad_position (Literal["center", "random"]): How to distribute padding when needed
|
|
464
|
+
"center" - equal amount on both sides, "random" - random distribution
|
|
465
|
+
p (float): Probability of applying the transform. Default: 1.0
|
|
466
|
+
|
|
467
|
+
Targets:
|
|
468
|
+
volume, mask3d, keypoints
|
|
469
|
+
|
|
470
|
+
Note:
|
|
471
|
+
This is a base class and not intended to be used directly. Use its derivatives
|
|
472
|
+
like CenterCrop3D or RandomCrop3D instead, or create a custom transform
|
|
473
|
+
by inheriting from this class.
|
|
474
|
+
|
|
475
|
+
Examples:
|
|
476
|
+
>>> import numpy as np
|
|
477
|
+
>>> import albumentations as A
|
|
478
|
+
>>>
|
|
479
|
+
>>> # Example of a custom crop transform inheriting from BaseCropAndPad3D
|
|
480
|
+
>>> class CustomFixedCrop3D(A.BaseCropAndPad3D):
|
|
481
|
+
... def __init__(self, crop_size: tuple[int, int, int] = (8, 64, 64), *args, **kwargs):
|
|
482
|
+
... super().__init__(
|
|
483
|
+
... pad_if_needed=True,
|
|
484
|
+
... fill=0,
|
|
485
|
+
... fill_mask=0,
|
|
486
|
+
... pad_position="center",
|
|
487
|
+
... *args,
|
|
488
|
+
... **kwargs
|
|
489
|
+
... )
|
|
490
|
+
... self.crop_size = crop_size
|
|
491
|
+
...
|
|
492
|
+
... def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
|
|
493
|
+
... # Get the volume shape
|
|
494
|
+
... volume = data["volume"]
|
|
495
|
+
... z, h, w = volume.shape[:3]
|
|
496
|
+
... target_z, target_h, target_w = self.crop_size
|
|
497
|
+
...
|
|
498
|
+
... # Check if padding is needed and calculate parameters
|
|
499
|
+
... pad_params = self._get_pad_params(
|
|
500
|
+
... image_shape=(z, h, w),
|
|
501
|
+
... target_shape=self.crop_size,
|
|
502
|
+
... )
|
|
503
|
+
...
|
|
504
|
+
... # Update dimensions if padding is applied
|
|
505
|
+
... if pad_params is not None:
|
|
506
|
+
... z = z + pad_params["pad_front"] + pad_params["pad_back"]
|
|
507
|
+
... h = h + pad_params["pad_top"] + pad_params["pad_bottom"]
|
|
508
|
+
... w = w + pad_params["pad_left"] + pad_params["pad_right"]
|
|
509
|
+
...
|
|
510
|
+
... # Calculate fixed crop coordinates - always start at position (0,0,0)
|
|
511
|
+
... crop_coords = (0, target_z, 0, target_h, 0, target_w)
|
|
512
|
+
...
|
|
513
|
+
... return {
|
|
514
|
+
... "crop_coords": crop_coords,
|
|
515
|
+
... "pad_params": pad_params,
|
|
516
|
+
... }
|
|
517
|
+
>>>
|
|
518
|
+
>>> # Prepare sample data
|
|
519
|
+
>>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8) # (D, H, W)
|
|
520
|
+
>>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8) # (D, H, W)
|
|
521
|
+
>>> keypoints = np.array([[20, 30, 5], [60, 70, 8]], dtype=np.float32) # (x, y, z)
|
|
522
|
+
>>> keypoint_labels = [1, 2] # Labels for each keypoint
|
|
523
|
+
>>>
|
|
524
|
+
>>> # Use the custom transform in a pipeline
|
|
525
|
+
>>> transform = A.Compose([
|
|
526
|
+
... CustomFixedCrop3D(
|
|
527
|
+
... crop_size=(8, 64, 64), # Crop first 8x64x64 voxels (with padding if needed)
|
|
528
|
+
... p=1.0
|
|
529
|
+
... )
|
|
530
|
+
... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
|
|
531
|
+
>>>
|
|
532
|
+
>>> # Apply the transform
|
|
533
|
+
>>> transformed = transform(
|
|
534
|
+
... volume=volume,
|
|
535
|
+
... mask3d=mask3d,
|
|
536
|
+
... keypoints=keypoints,
|
|
537
|
+
... keypoint_labels=keypoint_labels
|
|
538
|
+
... )
|
|
539
|
+
>>>
|
|
540
|
+
>>> # Get the transformed data
|
|
541
|
+
>>> cropped_volume = transformed["volume"] # Shape: (8, 64, 64)
|
|
542
|
+
>>> cropped_mask3d = transformed["mask3d"] # Shape: (8, 64, 64)
|
|
543
|
+
>>> cropped_keypoints = transformed["keypoints"] # Keypoints shifted relative to crop
|
|
544
|
+
>>> cropped_keypoint_labels = transformed["keypoint_labels"] # Labels remain unchanged
|
|
545
|
+
|
|
546
|
+
"""
|
|
547
|
+
|
|
548
|
+
_targets = (Targets.VOLUME, Targets.MASK3D, Targets.KEYPOINTS)
|
|
549
|
+
|
|
550
|
+
class InitSchema(Transform3D.InitSchema):
|
|
551
|
+
pad_if_needed: bool
|
|
552
|
+
fill: tuple[float, ...] | float
|
|
553
|
+
fill_mask: tuple[float, ...] | float
|
|
554
|
+
pad_position: Literal["center", "random"]
|
|
555
|
+
|
|
556
|
+
def __init__(
|
|
557
|
+
self,
|
|
558
|
+
pad_if_needed: bool,
|
|
559
|
+
fill: tuple[float, ...] | float,
|
|
560
|
+
fill_mask: tuple[float, ...] | float,
|
|
561
|
+
pad_position: Literal["center", "random"],
|
|
562
|
+
p: float = 1.0,
|
|
563
|
+
):
|
|
564
|
+
super().__init__(p=p)
|
|
565
|
+
self.pad_if_needed = pad_if_needed
|
|
566
|
+
self.fill = fill
|
|
567
|
+
self.fill_mask = fill_mask
|
|
568
|
+
self.pad_position = pad_position
|
|
569
|
+
|
|
570
|
+
def _random_pad(self, pad: int) -> tuple[int, int]:
|
|
571
|
+
"""Generate random padding values.
|
|
572
|
+
|
|
573
|
+
Args:
|
|
574
|
+
pad (int): Total padding value to distribute
|
|
575
|
+
|
|
576
|
+
Returns:
|
|
577
|
+
tuple[int, int]: Random padding values (front, back)
|
|
578
|
+
|
|
579
|
+
"""
|
|
580
|
+
if pad > 0:
|
|
581
|
+
pad_start = self.py_random.randint(0, pad)
|
|
582
|
+
pad_end = pad - pad_start
|
|
583
|
+
else:
|
|
584
|
+
pad_start = pad_end = 0
|
|
585
|
+
return pad_start, pad_end
|
|
586
|
+
|
|
587
|
+
def _center_pad(self, pad: int) -> tuple[int, int]:
|
|
588
|
+
"""Generate centered padding values.
|
|
589
|
+
|
|
590
|
+
Args:
|
|
591
|
+
pad (int): Total padding value to distribute
|
|
592
|
+
|
|
593
|
+
Returns:
|
|
594
|
+
tuple[int, int]: Centered padding values (front, back)
|
|
595
|
+
|
|
596
|
+
"""
|
|
597
|
+
pad_start = pad // 2
|
|
598
|
+
pad_end = pad - pad_start
|
|
599
|
+
return pad_start, pad_end
|
|
600
|
+
|
|
601
|
+
def _get_pad_params(
|
|
602
|
+
self,
|
|
603
|
+
image_shape: tuple[int, int, int],
|
|
604
|
+
target_shape: tuple[int, int, int],
|
|
605
|
+
) -> dict[str, int] | None:
|
|
606
|
+
"""Calculate padding parameters to reach target shape.
|
|
607
|
+
|
|
608
|
+
Args:
|
|
609
|
+
image_shape (tuple[int, int, int]): Current shape (depth, height, width)
|
|
610
|
+
target_shape (tuple[int, int, int]): Target shape (depth, height, width)
|
|
611
|
+
|
|
612
|
+
Returns:
|
|
613
|
+
dict[str, int] | None: Padding parameters or None if no padding needed
|
|
614
|
+
|
|
615
|
+
"""
|
|
616
|
+
if not self.pad_if_needed:
|
|
617
|
+
return None
|
|
618
|
+
|
|
619
|
+
z, h, w = image_shape
|
|
620
|
+
target_z, target_h, target_w = target_shape
|
|
621
|
+
|
|
622
|
+
# Calculate total padding needed for each dimension
|
|
623
|
+
z_pad = max(0, target_z - z)
|
|
624
|
+
h_pad = max(0, target_h - h)
|
|
625
|
+
w_pad = max(0, target_w - w)
|
|
626
|
+
|
|
627
|
+
if z_pad == 0 and h_pad == 0 and w_pad == 0:
|
|
628
|
+
return None
|
|
629
|
+
|
|
630
|
+
# For center padding, split equally
|
|
631
|
+
if self.pad_position == "center":
|
|
632
|
+
z_front, z_back = self._center_pad(z_pad)
|
|
633
|
+
h_top, h_bottom = self._center_pad(h_pad)
|
|
634
|
+
w_left, w_right = self._center_pad(w_pad)
|
|
635
|
+
# For random padding, randomly distribute the padding
|
|
636
|
+
else: # random
|
|
637
|
+
z_front, z_back = self._random_pad(z_pad)
|
|
638
|
+
h_top, h_bottom = self._random_pad(h_pad)
|
|
639
|
+
w_left, w_right = self._random_pad(w_pad)
|
|
640
|
+
|
|
641
|
+
return {
|
|
642
|
+
"pad_front": z_front,
|
|
643
|
+
"pad_back": z_back,
|
|
644
|
+
"pad_top": h_top,
|
|
645
|
+
"pad_bottom": h_bottom,
|
|
646
|
+
"pad_left": w_left,
|
|
647
|
+
"pad_right": w_right,
|
|
648
|
+
}
|
|
649
|
+
|
|
650
|
+
def apply_to_volume(
    self,
    volume: np.ndarray,
    crop_coords: tuple[int, int, int, int, int, int],
    pad_params: dict[str, int] | None,
    **params: Any,
) -> np.ndarray:
    """Crop a 3D volume and then, if padding parameters are given, pad it.

    Args:
        volume (np.ndarray): Volume laid out as (depth, height, width) or
            (depth, height, width, channels).
        crop_coords (tuple[int, int, int, int, int, int]): Crop box given as
            (z1, z2, y1, y2, x1, x2).
        pad_params (dict[str, int] | None): Per-side padding amounts, or ``None``
            to skip padding entirely.
        **params (Any): Extra parameters (unused here).

    Returns:
        np.ndarray: The cropped (and possibly padded) volume, with the same number of
        dimensions as the input.

    """
    result = f3d.crop3d(volume, crop_coords)
    if pad_params is None:
        return result

    # Padding tuple order expected by pad_3d_with_params: depth, height, width (before/after each).
    pad_order = ("pad_front", "pad_back", "pad_top", "pad_bottom", "pad_left", "pad_right")
    return f3d.pad_3d_with_params(
        result,
        padding=tuple(pad_params[key] for key in pad_order),
        value=self.fill,
    )
|
|
689
|
+
|
|
690
|
+
def apply_to_mask3d(
    self,
    mask3d: np.ndarray,
    crop_coords: tuple[int, int, int, int, int, int],
    pad_params: dict[str, int] | None,
    **params: Any,
) -> np.ndarray:
    """Crop a 3D mask and then, if padding parameters are given, pad it with ``fill_mask``.

    Args:
        mask3d (np.ndarray): Mask laid out as (depth, height, width) or
            (depth, height, width, channels).
        crop_coords (tuple[int, int, int, int, int, int]): Crop box given as
            (z1, z2, y1, y2, x1, x2).
        pad_params (dict[str, int] | None): Per-side padding amounts, or ``None``
            to skip padding entirely.
        **params (Any): Extra parameters (unused here).

    Returns:
        np.ndarray: The cropped (and possibly padded) mask, with the same number of
        dimensions as the input.

    """
    result = f3d.crop3d(mask3d, crop_coords)
    if pad_params is None:
        return result

    # Same side ordering as the volume path: depth, height, width (before/after each).
    pad_order = ("pad_front", "pad_back", "pad_top", "pad_bottom", "pad_left", "pad_right")
    return f3d.pad_3d_with_params(
        result,
        padding=tuple(pad_params[key] for key in pad_order),
        value=cast("Union[tuple[float, ...], float]", self.fill_mask),
    )
|
|
729
|
+
|
|
730
|
+
def apply_to_keypoints(
    self,
    keypoints: np.ndarray,
    crop_coords: tuple[int, int, int, int, int, int],
    pad_params: dict[str, int] | None,
    **params: Any,
) -> np.ndarray:
    """Translate keypoints into the coordinate frame of the cropped (and padded) volume.

    Args:
        keypoints (np.ndarray): Keypoints array of shape (num_keypoints, 3+) whose
            first three columns are x, y, z.
        crop_coords (tuple[int, int, int, int, int, int]): Crop box given as
            (z1, z2, y1, y2, x1, x2). Only the start coordinates are used.
        pad_params (dict[str, int] | None): Per-side padding amounts, or ``None``
            when no padding is applied.
        **params (Any): Extra parameters (unused here).

    Returns:
        np.ndarray: Shifted keypoints with the same shape as the input.

    """
    z_from, _, y_from, _, x_from, _ = crop_coords

    # Move the crop origin to zero; offset components follow the keypoint column order (x, y, z).
    offset = np.array([-x_from, -y_from, -z_from])

    if pad_params is not None:
        # Padding on the leading side of each axis pushes the content further inward.
        offset += np.array([pad_params["pad_left"], pad_params["pad_top"], pad_params["pad_front"]])

    return fgeometric.shift_keypoints(keypoints, offset)
|
|
774
|
+
|
|
775
|
+
|
|
776
|
+
class CenterCrop3D(BaseCropAndPad3D):
    """Crop the center of 3D volume.

    Args:
        size (tuple[int, int, int]): Desired output size of the crop in format (depth, height, width)
        pad_if_needed (bool): Whether to pad if the volume is smaller than desired crop size. Default: False
        fill (tuple[float, ...] | float): Padding value for image if pad_if_needed is True. Default: 0
        fill_mask (tuple[float, ...] | float): Padding value for mask if pad_if_needed is True. Default: 0
        p (float): probability of applying the transform. Default: 1.0

    Targets:
        volume, mask3d, keypoints

    Image types:
        uint8, float32

    Raises:
        ValueError: When computing parameters, if the volume is smaller than ``size``
            along any axis and ``pad_if_needed`` is False.

    Note:
        If you want to perform cropping only in the XY plane while preserving all slices along
        the Z axis, consider using CenterCrop instead. CenterCrop will apply the same XY crop
        to each slice independently, maintaining the full depth of the volume.

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>>
        >>> # Prepare sample data
        >>> volume = np.random.randint(0, 256, (20, 200, 200), dtype=np.uint8)  # (D, H, W)
        >>> mask3d = np.random.randint(0, 2, (20, 200, 200), dtype=np.uint8)    # (D, H, W)
        >>> keypoints = np.array([[100, 100, 10], [150, 150, 15]], dtype=np.float32)  # (x, y, z)
        >>> keypoint_labels = [1, 2]  # Labels for each keypoint
        >>>
        >>> # Create the transform - crop to 16x128x128 from center
        >>> transform = A.Compose([
        ...     A.CenterCrop3D(
        ...         size=(16, 128, 128),        # Output size (depth, height, width)
        ...         pad_if_needed=True,         # Pad if input is smaller than crop size
        ...         fill=0,                     # Fill value for volume padding
        ...         fill_mask=1,                # Fill value for mask padding
        ...         p=1.0
        ...     )
        ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply the transform
        >>> transformed = transform(
        ...     volume=volume,
        ...     mask3d=mask3d,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Get the transformed data
        >>> cropped_volume = transformed["volume"]  # Shape: (16, 128, 128)
        >>> cropped_mask3d = transformed["mask3d"]  # Shape: (16, 128, 128)
        >>> cropped_keypoints = transformed["keypoints"]  # Keypoints shifted relative to center crop
        >>> cropped_keypoint_labels = transformed["keypoint_labels"]  # Labels remain unchanged
        >>>
        >>> # Example with a small volume that requires padding
        >>> small_volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)
        >>> small_transform = A.Compose([
        ...     A.CenterCrop3D(
        ...         size=(16, 128, 128),
        ...         pad_if_needed=True,  # Will pad since the input is smaller
        ...         fill=0,
        ...         p=1.0
        ...     )
        ... ])
        >>> small_result = small_transform(volume=small_volume)
        >>> padded_and_cropped = small_result["volume"]  # Shape: (16, 128, 128), padded to size

    """

    class InitSchema(BaseTransformInitSchema):
        size: Annotated[tuple[int, int, int], AfterValidator(check_range_bounds(1, None))]
        pad_if_needed: bool
        fill: tuple[float, ...] | float
        fill_mask: tuple[float, ...] | float

    def __init__(
        self,
        size: tuple[int, int, int],
        pad_if_needed: bool = False,
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
        p: float = 1.0,
    ):
        super().__init__(
            pad_if_needed=pad_if_needed,
            fill=fill,
            fill_mask=fill_mask,
            pad_position="center",  # Center crop always uses center padding
            p=p,
        )
        self.size = size

    def get_params_dependent_on_data(
        self,
        params: dict[str, Any],
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Calculate crop coordinates for center cropping.

        Args:
            params (dict[str, Any]): Dictionary of existing parameters
            data (dict[str, Any]): Dictionary containing input data with volume, mask, etc.

        Returns:
            dict[str, Any]: Dictionary containing crop coordinates (z1, z2, y1, y2, x1, x2)
                and optional padding parameters

        Raises:
            ValueError: If the volume is smaller than ``size`` along any axis and
                ``pad_if_needed`` is False.

        """
        volume = data["volume"]
        z, h, w = volume.shape[:3]
        target_z, target_h, target_w = self.size

        # Padding is only produced when pad_if_needed is enabled and some axis is too small.
        pad_params = self._get_pad_params(
            image_shape=(z, h, w),
            target_shape=self.size,
        )

        # Choose the crop window using the post-padding size of each axis.
        if pad_params is not None:
            z = z + pad_params["pad_front"] + pad_params["pad_back"]
            h = h + pad_params["pad_top"] + pad_params["pad_bottom"]
            w = w + pad_params["pad_left"] + pad_params["pad_right"]

        # Validate dimensions after padding
        if z < target_z or h < target_h or w < target_w:
            if not self.pad_if_needed:
                # A too-small input without padding is a configuration issue, not an internal bug.
                msg = (
                    f"Crop size {self.size} is larger than volume size ({z}, {h}, {w}). "
                    "Set pad_if_needed=True to pad the volume before cropping."
                )
            else:
                msg = (
                    f"Crop size {self.size} is larger than padded image size ({z}, {h}, {w}). "
                    f"This should not happen - please report this as a bug."
                )
            raise ValueError(msg)

        # Center the crop window; any odd remainder goes to the trailing side.
        z_start = (z - target_z) // 2
        h_start = (h - target_h) // 2
        w_start = (w - target_w) // 2

        crop_coords = (
            z_start,
            z_start + target_z,
            h_start,
            h_start + target_h,
            w_start,
            w_start + target_w,
        )

        return {
            "crop_coords": crop_coords,
            "pad_params": pad_params,
        }
|
|
927
|
+
|
|
928
|
+
|
|
929
|
+
class RandomCrop3D(BaseCropAndPad3D):
    """Crop random part of 3D volume.

    Args:
        size (tuple[int, int, int]): Desired output size of the crop in format (depth, height, width)
        pad_if_needed (bool): Whether to pad if the volume is smaller than desired crop size. Default: False
        fill (tuple[float, ...] | float): Padding value for image if pad_if_needed is True. Default: 0
        fill_mask (tuple[float, ...] | float): Padding value for mask if pad_if_needed is True. Default: 0
        p (float): probability of applying the transform. Default: 1.0

    Targets:
        volume, mask3d, keypoints

    Image types:
        uint8, float32

    Raises:
        ValueError: When computing parameters, if the volume is smaller than ``size``
            along any axis and ``pad_if_needed`` is False.

    Note:
        If you want to perform random cropping only in the XY plane while preserving all slices along
        the Z axis, consider using RandomCrop instead. RandomCrop will apply the same XY crop
        to each slice independently, maintaining the full depth of the volume.

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>>
        >>> # Prepare sample data
        >>> volume = np.random.randint(0, 256, (20, 200, 200), dtype=np.uint8)  # (D, H, W)
        >>> mask3d = np.random.randint(0, 2, (20, 200, 200), dtype=np.uint8)    # (D, H, W)
        >>> keypoints = np.array([[100, 100, 10], [150, 150, 15]], dtype=np.float32)  # (x, y, z)
        >>> keypoint_labels = [1, 2]  # Labels for each keypoint
        >>>
        >>> # Create the transform with random crop and padding if needed
        >>> transform = A.Compose([
        ...     A.RandomCrop3D(
        ...         size=(16, 128, 128),        # Output size (depth, height, width)
        ...         pad_if_needed=True,         # Pad if input is smaller than crop size
        ...         fill=0,                     # Fill value for volume padding
        ...         fill_mask=1,                # Fill value for mask padding
        ...         p=1.0
        ...     )
        ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply the transform
        >>> transformed = transform(
        ...     volume=volume,
        ...     mask3d=mask3d,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Get the transformed data
        >>> cropped_volume = transformed["volume"]  # Shape: (16, 128, 128)
        >>> cropped_mask3d = transformed["mask3d"]  # Shape: (16, 128, 128)
        >>> cropped_keypoints = transformed["keypoints"]  # Keypoints shifted relative to random crop
        >>> cropped_keypoint_labels = transformed["keypoint_labels"]  # Labels remain unchanged

    """

    class InitSchema(BaseTransformInitSchema):
        size: Annotated[tuple[int, int, int], AfterValidator(check_range_bounds(1, None))]
        pad_if_needed: bool
        fill: tuple[float, ...] | float
        fill_mask: tuple[float, ...] | float

    def __init__(
        self,
        size: tuple[int, int, int],
        pad_if_needed: bool = False,
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
        p: float = 1.0,
    ):
        super().__init__(
            pad_if_needed=pad_if_needed,
            fill=fill,
            fill_mask=fill_mask,
            pad_position="random",  # Random crop uses random padding position
            p=p,
        )
        self.size = size

    def get_params_dependent_on_data(
        self,
        params: dict[str, Any],
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Calculate random crop coordinates.

        Args:
            params (dict[str, Any]): Dictionary of existing parameters
            data (dict[str, Any]): Dictionary containing input data with volume, mask, etc.

        Returns:
            dict[str, Any]: Dictionary containing randomly generated crop coordinates
                (z1, z2, y1, y2, x1, x2) and optional padding parameters

        Raises:
            ValueError: If the volume is smaller than ``size`` along any axis and
                ``pad_if_needed`` is False.

        """
        volume = data["volume"]
        z, h, w = volume.shape[:3]
        target_z, target_h, target_w = self.size

        # Padding is only produced when pad_if_needed is enabled and some axis is too small.
        pad_params = self._get_pad_params(
            image_shape=(z, h, w),
            target_shape=self.size,
        )

        # Choose the crop window using the post-padding size of each axis.
        if pad_params is not None:
            z = z + pad_params["pad_front"] + pad_params["pad_back"]
            h = h + pad_params["pad_top"] + pad_params["pad_bottom"]
            w = w + pad_params["pad_left"] + pad_params["pad_right"]

        # Without padding, a crop window larger than the volume cannot produce the
        # requested ``size``; fail loudly instead of returning a smaller output.
        if z < target_z or h < target_h or w < target_w:
            msg = (
                f"Crop size {self.size} is larger than volume size ({z}, {h}, {w}). "
                "Set pad_if_needed=True to pad the volume before cropping."
            )
            raise ValueError(msg)

        # random.randint is inclusive on both ends, so every valid start offset can be drawn.
        z_start = self.py_random.randint(0, z - target_z)
        h_start = self.py_random.randint(0, h - target_h)
        w_start = self.py_random.randint(0, w - target_w)

        crop_coords = (
            z_start,
            z_start + target_z,
            h_start,
            h_start + target_h,
            w_start,
            w_start + target_w,
        )

        return {
            "crop_coords": crop_coords,
            "pad_params": pad_params,
        }
|
|
1059
|
+
|
|
1060
|
+
|
|
1061
|
+
class CoarseDropout3D(Transform3D):
    """CoarseDropout3D randomly drops out cuboid regions from a 3D volume and optionally,
    the corresponding regions in an associated 3D mask, to simulate occlusion and
    varied object sizes found in real-world volumetric data.

    Args:
        num_holes_range (tuple[int, int]): Range (min, max) for the number of cuboid
            regions to drop out. Default: (1, 1)
        hole_depth_range (tuple[float, float]): Range (min, max) for the depth
            of dropout regions as a fraction of the volume depth (between 0 and 1). Default: (0.1, 0.2)
        hole_height_range (tuple[float, float]): Range (min, max) for the height
            of dropout regions as a fraction of the volume height (between 0 and 1). Default: (0.1, 0.2)
        hole_width_range (tuple[float, float]): Range (min, max) for the width
            of dropout regions as a fraction of the volume width (between 0 and 1). Default: (0.1, 0.2)
        fill (tuple[float, ...] | float): Value for the dropped voxels. Can be:
            - int or float: all channels are filled with this value
            - tuple: tuple of values for each channel
            Default: 0
        fill_mask (tuple[float, ...] | float | None): Fill value for dropout regions in the 3D mask.
            If None, mask regions corresponding to volume dropouts are unchanged. Default: None
        p (float): Probability of applying the transform. Default: 0.5

    Targets:
        volume, mask3d, keypoints

    Image types:
        uint8, float32

    Note:
        - The actual number and size of dropout regions are randomly chosen within the specified ranges.
        - All values in hole_depth_range, hole_height_range and hole_width_range must be between 0 and 1.
        - If you want to apply dropout only in the XY plane while preserving the full depth dimension,
          consider using CoarseDropout instead. CoarseDropout will apply the same rectangular dropout
          to each slice independently, effectively creating cylindrical dropout regions that extend
          through the entire depth of the volume.

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8)    # (D, H, W)
        >>> aug = A.CoarseDropout3D(
        ...     num_holes_range=(3, 6),
        ...     hole_depth_range=(0.1, 0.2),
        ...     hole_height_range=(0.1, 0.2),
        ...     hole_width_range=(0.1, 0.2),
        ...     fill=0,
        ...     p=1.0
        ... )
        >>> transformed = aug(volume=volume, mask3d=mask3d)
        >>> transformed_volume, transformed_mask3d = transformed["volume"], transformed["mask3d"]

    """

    _targets = (Targets.VOLUME, Targets.MASK3D, Targets.KEYPOINTS)

    class InitSchema(Transform3D.InitSchema):
        num_holes_range: Annotated[
            tuple[int, int],
            AfterValidator(check_range_bounds(0, None)),
            AfterValidator(nondecreasing),
        ]
        hole_depth_range: Annotated[
            tuple[float, float],
            AfterValidator(check_range_bounds(0, 1)),
            AfterValidator(nondecreasing),
        ]
        hole_height_range: Annotated[
            tuple[float, float],
            AfterValidator(check_range_bounds(0, 1)),
            AfterValidator(nondecreasing),
        ]
        hole_width_range: Annotated[
            tuple[float, float],
            AfterValidator(check_range_bounds(0, 1)),
            AfterValidator(nondecreasing),
        ]
        fill: tuple[float, ...] | float
        fill_mask: tuple[float, ...] | float | None

        @staticmethod
        def validate_range(range_value: tuple[float, float], range_name: str) -> None:
            """Validate that range values are between 0 and 1 and in non-decreasing order.

            Args:
                range_value (tuple[float, float]): Tuple of (min, max) values to check
                range_name (str): Name of the range for error reporting

            Raises:
                ValueError: If range values are invalid

            """
            if not 0 <= range_value[0] <= range_value[1] <= 1:
                raise ValueError(
                    f"All values in {range_name} should be in [0, 1] range and first value "
                    f"should be less or equal than the second value. Got: {range_value}",
                )

        @model_validator(mode="after")
        def _check_ranges(self) -> Self:
            self.validate_range(self.hole_depth_range, "hole_depth_range")
            self.validate_range(self.hole_height_range, "hole_height_range")
            self.validate_range(self.hole_width_range, "hole_width_range")
            return self

    def __init__(
        self,
        num_holes_range: tuple[int, int] = (1, 1),
        hole_depth_range: tuple[float, float] = (0.1, 0.2),
        hole_height_range: tuple[float, float] = (0.1, 0.2),
        hole_width_range: tuple[float, float] = (0.1, 0.2),
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float | None = None,
        p: float = 0.5,
    ):
        super().__init__(p=p)
        self.num_holes_range = num_holes_range
        self.hole_depth_range = hole_depth_range
        self.hole_height_range = hole_height_range
        self.hole_width_range = hole_width_range
        self.fill = fill
        self.fill_mask = fill_mask

    def calculate_hole_dimensions(
        self,
        volume_shape: tuple[int, int, int],
        depth_range: tuple[float, float],
        height_range: tuple[float, float],
        width_range: tuple[float, float],
        size: int,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Calculate dimensions for dropout holes.

        Args:
            volume_shape (tuple[int, int, int]): Shape of the volume (depth, height, width)
            depth_range (tuple[float, float]): Range for hole depth as fraction of volume depth
            height_range (tuple[float, float]): Range for hole height as fraction of volume height
            width_range (tuple[float, float]): Range for hole width as fraction of volume width
            size (int): Number of holes to generate

        Returns:
            tuple[np.ndarray, np.ndarray, np.ndarray]: Integer arrays of hole dimensions
                (depths, heights, widths); every dimension is at least 1 voxel.

        """
        depth, height, width = volume_shape[:3]

        # Fractions are scaled to voxels, rounded up, and clamped so no hole is empty.
        hole_depths = np.maximum(1, np.ceil(depth * self.random_generator.uniform(*depth_range, size=size))).astype(int)
        hole_heights = np.maximum(1, np.ceil(height * self.random_generator.uniform(*height_range, size=size))).astype(
            int,
        )
        hole_widths = np.maximum(1, np.ceil(width * self.random_generator.uniform(*width_range, size=size))).astype(int)

        return hole_depths, hole_heights, hole_widths

    def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
        """Generate parameters for coarse dropout based on input data.

        Args:
            params (dict[str, Any]): Dictionary of existing parameters
            data (dict[str, Any]): Dictionary containing input data with volume, mask, etc.

        Returns:
            dict[str, Any]: Dictionary with key ``"holes"``: an array of shape (num_holes, 6)
                whose rows are [z1, y1, x1, z2, y2, x2]

        """
        volume_shape = data["volume"].shape[:3]

        num_holes = self.py_random.randint(*self.num_holes_range)

        hole_depths, hole_heights, hole_widths = self.calculate_hole_dimensions(
            volume_shape,
            self.hole_depth_range,
            self.hole_height_range,
            self.hole_width_range,
            size=num_holes,
        )

        depth, height, width = volume_shape[:3]

        # The upper bound is exclusive, so "+ 1" lets a hole touch the far edge of the volume.
        z_min = self.random_generator.integers(0, depth - hole_depths + 1, size=num_holes)
        y_min = self.random_generator.integers(0, height - hole_heights + 1, size=num_holes)
        x_min = self.random_generator.integers(0, width - hole_widths + 1, size=num_holes)
        z_max = z_min + hole_depths
        y_max = y_min + hole_heights
        x_max = x_min + hole_widths

        holes = np.stack([z_min, y_min, x_min, z_max, y_max, x_max], axis=-1)

        return {"holes": holes}

    def apply_to_volume(self, volume: np.ndarray, holes: np.ndarray, **params: Any) -> np.ndarray:
        """Apply dropout to a 3D volume.

        Args:
            volume (np.ndarray): Input volume with shape (depth, height, width) or (depth, height, width, channels)
            holes (np.ndarray): Array of holes with shape (num_holes, 6).
                Each hole is represented as [z1, y1, x1, z2, y2, x2]
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Volume with holes filled with ``fill``

        """
        if holes.size == 0:
            return volume

        return f3d.cutout3d(volume, holes, self.fill)

    def apply_to_mask(self, mask: np.ndarray, holes: np.ndarray, **params: Any) -> np.ndarray:
        """Apply dropout to a 3D mask.

        Args:
            mask (np.ndarray): Input mask with shape (depth, height, width) or (depth, height, width, channels)
            holes (np.ndarray): Array of holes with shape (num_holes, 6).
                Each hole is represented as [z1, y1, x1, z2, y2, x2]
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Mask with holes filled with ``fill_mask``, or the mask unchanged
                when ``fill_mask`` is None or there are no holes

        """
        if self.fill_mask is None or holes.size == 0:
            return mask

        return f3d.cutout3d(mask, holes, self.fill_mask)

    def apply_to_mask3d(self, mask3d: np.ndarray, holes: np.ndarray, **params: Any) -> np.ndarray:
        """Apply dropout to the ``mask3d`` target.

        Routes ``mask3d`` explicitly through :meth:`apply_to_mask`, so ``fill_mask``
        (not ``fill``) is used and ``fill_mask=None`` leaves the mask unchanged. This
        avoids depending on the inherited ``mask3d`` handler.

        Args:
            mask3d (np.ndarray): Input mask with shape (depth, height, width) or (depth, height, width, channels)
            holes (np.ndarray): Array of holes with shape (num_holes, 6).
                Each hole is represented as [z1, y1, x1, z2, y2, x2]
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Mask with holes filled with ``fill_mask``, or the mask unchanged
                when ``fill_mask`` is None or there are no holes

        """
        return self.apply_to_mask(mask3d, holes, **params)

    def apply_to_keypoints(
        self,
        keypoints: np.ndarray,
        holes: np.ndarray,
        **params: Any,
    ) -> np.ndarray:
        """Apply dropout to keypoints.

        Args:
            keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
                The first three columns are x, y, z coordinates.
            holes (np.ndarray): Array of holes with shape (num_holes, 6).
                Each hole is represented as [z1, y1, x1, z2, y2, x2]
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Keypoints filtered by ``f3d.filter_keypoints_in_holes3d`` when the keypoint
                processor has ``remove_invisible`` enabled; otherwise the input unchanged

        """
        if holes.size == 0:
            return keypoints
        processor = cast("KeypointsProcessor", self.get_processor("keypoints"))

        # Keypoints are only dropped when the pipeline asked to remove invisible ones.
        if processor is None or not processor.params.remove_invisible:
            return keypoints
        return f3d.filter_keypoints_in_holes3d(keypoints, holes)
|
|
1313
|
+
|
|
1314
|
+
|
|
1315
|
+
class CubicSymmetry(Transform3D):
    """Applies a random cubic symmetry transformation to a 3D volume.

    This transform is a 3D extension of D4. While D4 handles the 8 symmetries
    of a square (4 rotations x 2 reflections), CubicSymmetry handles all 48 symmetries of a cube.
    Like D4, this transform does not create any interpolation artifacts as it only remaps voxels
    from one position to another without any interpolation.

    The 48 transformations consist of:
        - 24 rotations (orientation-preserving):
            * 1 identity
            * 9 rotations about the 3 axes through centers of opposite faces (90, 180, 270 degrees)
            * 8 rotations about the 4 axes through opposite vertices (120, 240 degrees)
            * 6 rotations of 180 degrees about the axes through midpoints of opposite edges
        - 24 rotoreflections (orientation-reversing):
            * A reflection through a plane combined with each of the 24 rotations

    For a cube, each of these transformations maps the set of face centers (6), the set of
    vertices (8) and the set of edge centers (12) onto itself; individual points may be
    permuted among each other.

    Works with 3D volumes and masks of the shape (D, H, W) or (D, H, W, C).

    Args:
        p (float): Probability of applying the transform. Default: 1.0

    Targets:
        volume, mask3d, keypoints

    Image types:
        uint8, float32

    Note:
        - This transform is particularly useful for data augmentation in 3D medical imaging,
          crystallography, and voxel-based 3D modeling where the object's orientation
          is arbitrary.
        - Pure rotations (indices 0-23) preserve the object's chirality (handedness), and
          rotoreflections (indices 24-47) invert it. The index-to-transformation mapping
          is defined by ``f3d.transform_cube``.
        - A symmetry that exchanges two spatial axes also exchanges their sizes. For a
          non-cubic volume (D, H and W not all equal), the output's spatial shape may
          therefore be a permutation of the input's; e.g. (10, 100, 100) may become
          (100, 10, 100).

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8)    # (D, H, W)
        >>> transform = A.CubicSymmetry(p=1.0)
        >>> transformed = transform(volume=volume, mask3d=mask3d)
        >>> transformed_volume = transformed["volume"]
        >>> transformed_mask3d = transformed["mask3d"]

    See Also:
        - D4: The 2D version that handles the 8 symmetries of a square

    """

    _targets = (Targets.VOLUME, Targets.MASK3D, Targets.KEYPOINTS)

    def __init__(
        self,
        p: float = 1.0,
    ):
        super().__init__(p=p)

    def get_params_dependent_on_data(
        self,
        params: dict[str, Any],
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Generate parameters for cubic symmetry transformation.

        Args:
            params (dict[str, Any]): Dictionary of existing parameters
            data (dict[str, Any]): Dictionary containing input data with volume, mask, etc.

        Returns:
            dict[str, Any]: ``"index"``: the randomly selected transformation index (0-47);
                ``"volume_shape"``: the full shape of the input volume, used for keypoints

        """
        # randint is inclusive on both ends, giving all 48 possible transformations.
        volume_shape = data["volume"].shape
        return {"index": self.py_random.randint(0, 47), "volume_shape": volume_shape}

    def apply_to_volume(self, volume: np.ndarray, index: int, **params: Any) -> np.ndarray:
        """Apply cubic symmetry transformation to a 3D volume.

        Args:
            volume (np.ndarray): Input volume with shape (depth, height, width) or (depth, height, width, channels)
            index (int): Index of the transformation to apply (0-47)
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Transformed volume. Its shape equals the input's for cubic volumes;
                for non-cubic volumes the spatial axes (and their sizes) may be permuted.

        """
        return f3d.transform_cube(volume, index)

    def apply_to_keypoints(self, keypoints: np.ndarray, index: int, **params: Any) -> np.ndarray:
        """Apply cubic symmetry transformation to keypoints.

        Args:
            keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
                The first three columns are x, y, z coordinates.
            index (int): Index of the transformation to apply (0-47)
            **params (Any): Additional parameters; must contain ``"volume_shape"`` as
                produced by ``get_params_dependent_on_data``

        Returns:
            np.ndarray: Transformed keypoints with same shape as input

        """
        return f3d.transform_cube_keypoints(keypoints, index, volume_shape=params["volume_shape"])
|