nrtk-albumentations 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nrtk-albumentations might be problematic. See the registry's advisory page for this release for more details.

Files changed (62)
  1. albumentations/__init__.py +21 -0
  2. albumentations/augmentations/__init__.py +23 -0
  3. albumentations/augmentations/blur/__init__.py +0 -0
  4. albumentations/augmentations/blur/functional.py +438 -0
  5. albumentations/augmentations/blur/transforms.py +1633 -0
  6. albumentations/augmentations/crops/__init__.py +0 -0
  7. albumentations/augmentations/crops/functional.py +494 -0
  8. albumentations/augmentations/crops/transforms.py +3647 -0
  9. albumentations/augmentations/dropout/__init__.py +0 -0
  10. albumentations/augmentations/dropout/channel_dropout.py +134 -0
  11. albumentations/augmentations/dropout/coarse_dropout.py +567 -0
  12. albumentations/augmentations/dropout/functional.py +1017 -0
  13. albumentations/augmentations/dropout/grid_dropout.py +166 -0
  14. albumentations/augmentations/dropout/mask_dropout.py +274 -0
  15. albumentations/augmentations/dropout/transforms.py +461 -0
  16. albumentations/augmentations/dropout/xy_masking.py +186 -0
  17. albumentations/augmentations/geometric/__init__.py +0 -0
  18. albumentations/augmentations/geometric/distortion.py +1238 -0
  19. albumentations/augmentations/geometric/flip.py +752 -0
  20. albumentations/augmentations/geometric/functional.py +4151 -0
  21. albumentations/augmentations/geometric/pad.py +676 -0
  22. albumentations/augmentations/geometric/resize.py +956 -0
  23. albumentations/augmentations/geometric/rotate.py +864 -0
  24. albumentations/augmentations/geometric/transforms.py +1962 -0
  25. albumentations/augmentations/mixing/__init__.py +0 -0
  26. albumentations/augmentations/mixing/domain_adaptation.py +787 -0
  27. albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
  28. albumentations/augmentations/mixing/functional.py +878 -0
  29. albumentations/augmentations/mixing/transforms.py +832 -0
  30. albumentations/augmentations/other/__init__.py +0 -0
  31. albumentations/augmentations/other/lambda_transform.py +180 -0
  32. albumentations/augmentations/other/type_transform.py +261 -0
  33. albumentations/augmentations/pixel/__init__.py +0 -0
  34. albumentations/augmentations/pixel/functional.py +4226 -0
  35. albumentations/augmentations/pixel/transforms.py +7556 -0
  36. albumentations/augmentations/spectrogram/__init__.py +0 -0
  37. albumentations/augmentations/spectrogram/transform.py +220 -0
  38. albumentations/augmentations/text/__init__.py +0 -0
  39. albumentations/augmentations/text/functional.py +272 -0
  40. albumentations/augmentations/text/transforms.py +299 -0
  41. albumentations/augmentations/transforms3d/__init__.py +0 -0
  42. albumentations/augmentations/transforms3d/functional.py +393 -0
  43. albumentations/augmentations/transforms3d/transforms.py +1422 -0
  44. albumentations/augmentations/utils.py +249 -0
  45. albumentations/core/__init__.py +0 -0
  46. albumentations/core/bbox_utils.py +920 -0
  47. albumentations/core/composition.py +1885 -0
  48. albumentations/core/hub_mixin.py +299 -0
  49. albumentations/core/keypoints_utils.py +521 -0
  50. albumentations/core/label_manager.py +339 -0
  51. albumentations/core/pydantic.py +239 -0
  52. albumentations/core/serialization.py +352 -0
  53. albumentations/core/transforms_interface.py +976 -0
  54. albumentations/core/type_definitions.py +127 -0
  55. albumentations/core/utils.py +605 -0
  56. albumentations/core/validation.py +129 -0
  57. albumentations/pytorch/__init__.py +1 -0
  58. albumentations/pytorch/transforms.py +189 -0
  59. nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
  60. nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
  61. nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
  62. nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,1422 @@
1
+ """Module containing 3D transformation classes for volumetric data augmentation.
2
+
3
+ This module provides a collection of transformation classes designed specifically for
4
+ 3D volumetric data (such as medical CT/MRI scans). These transforms can manipulate properties
5
+ such as spatial dimensions, apply dropout effects, and perform symmetry operations on
6
+ 3D volumes, masks, and keypoints. Each transformation inherits from a base transform
7
+ interface and implements specific 3D augmentation logic.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ from typing import Annotated, Any, Literal, Union, cast
13
+
14
+ import numpy as np
15
+ from pydantic import AfterValidator, field_validator, model_validator
16
+ from typing_extensions import Self
17
+
18
+ from albumentations.augmentations.geometric import functional as fgeometric
19
+ from albumentations.augmentations.transforms3d import functional as f3d
20
+ from albumentations.core.keypoints_utils import KeypointsProcessor
21
+ from albumentations.core.pydantic import check_range_bounds, nondecreasing
22
+ from albumentations.core.transforms_interface import BaseTransformInitSchema, Transform3D
23
+ from albumentations.core.type_definitions import Targets
24
+
25
# Public API of this module. The base classes BasePad3D and BaseCropAndPad3D
# are not listed here, so a star-import of this module does not expose them.
__all__ = [
    "CenterCrop3D",
    "CoarseDropout3D",
    "CubicSymmetry",
    "Pad3D",
    "PadIfNeeded3D",
    "RandomCrop3D",
]

# Number of spatial axes of a volume: (depth, height, width).
NUM_DIMENSIONS = 3
28
+
29
+
30
class BasePad3D(Transform3D):
    """Base class for 3D padding transforms.

    This class serves as a foundation for all 3D transforms that perform padding operations
    on volumetric data. It provides common functionality for padding 3D volumes, masks,
    and processing 3D keypoints during padding operations.

    The class handles different types of padding values (scalar or per-channel) and
    provides separate fill values for volumes and masks.

    Subclasses are expected to implement ``get_params_dependent_on_data`` and return a dict
    with a ``"padding"`` entry laid out as
    ``(depth_front, depth_back, height_top, height_bottom, width_left, width_right)``.
    That entry is consumed by ``apply_to_volume``, ``apply_to_mask3d`` and ``apply_to_keypoints``.

    Args:
        fill (tuple[float, ...] | float): Value to fill the padded voxels for volumes.
            Can be a single value for all channels or a tuple of values per channel.
        fill_mask (tuple[float, ...] | float): Value to fill the padded voxels for 3D masks.
            Can be a single value for all channels or a tuple of values per channel.
        p (float): Probability of applying the transform. Default: 1.0.

    Targets:
        volume, mask3d, keypoints

    Note:
        This is a base class and not intended to be used directly. Use its derivatives
        like Pad3D or PadIfNeeded3D instead, or create a custom padding transform
        by inheriting from this class. ``BasePad3D`` is not listed in this module's
        ``__all__``, so import it from ``albumentations.augmentations.transforms3d.transforms``
        rather than accessing it as ``A.BasePad3D``.

    Examples:
        >>> from typing import Any
        >>> import numpy as np
        >>> import albumentations as A
        >>> from albumentations.augmentations.transforms3d.transforms import BasePad3D
        >>>
        >>> # Example of a custom padding transform inheriting from BasePad3D
        >>> class CustomPad3D(BasePad3D):
        ...     def __init__(self, padding_size: tuple[int, int, int] = (5, 5, 5), *args, **kwargs):
        ...         super().__init__(*args, **kwargs)
        ...         self.padding_size = padding_size
        ...
        ...     def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
        ...         # Create symmetric padding: same amount on all sides of each dimension
        ...         pad_d, pad_h, pad_w = self.padding_size
        ...         padding = (pad_d, pad_d, pad_h, pad_h, pad_w, pad_w)
        ...         return {"padding": padding}
        >>>
        >>> # Prepare sample data
        >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> keypoints = np.array([[20, 30, 5], [60, 70, 8]], dtype=np.float32)  # (x, y, z)
        >>> keypoint_labels = [1, 2]  # Labels for each keypoint
        >>>
        >>> # Use the custom transform in a pipeline
        >>> transform = A.Compose([
        ...     CustomPad3D(
        ...         padding_size=(2, 10, 10),
        ...         fill=0,
        ...         fill_mask=1,
        ...         p=1.0
        ...     )
        ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply the transform
        >>> transformed = transform(
        ...     volume=volume,
        ...     mask3d=mask3d,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Get the transformed data
        >>> transformed_volume = transformed["volume"]  # Shape: (14, 120, 120)
        >>> transformed_mask3d = transformed["mask3d"]  # Shape: (14, 120, 120)
        >>> transformed_keypoints = transformed["keypoints"]  # Keypoints shifted by padding offsets
        >>> transformed_keypoint_labels = transformed["keypoint_labels"]  # Labels remain unchanged

    """

    _targets = (Targets.VOLUME, Targets.MASK3D, Targets.KEYPOINTS)

    class InitSchema(Transform3D.InitSchema):
        fill: tuple[float, ...] | float
        fill_mask: tuple[float, ...] | float

    def __init__(
        self,
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
        p: float = 1.0,
    ):
        super().__init__(p=p)
        self.fill = fill
        self.fill_mask = fill_mask

    def apply_to_volume(
        self,
        volume: np.ndarray,
        padding: tuple[int, int, int, int, int, int],
        **params: Any,
    ) -> np.ndarray:
        """Apply padding to a 3D volume.

        Args:
            volume (np.ndarray): Input volume with shape (depth, height, width) or (depth, height, width, channels)
            padding (tuple[int, int, int, int, int, int]): Padding values in format:
                (depth_front, depth_back, height_top, height_bottom, width_left, width_right)
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Padded volume with same number of dimensions as input.
                When every padding value is zero the input array is returned unchanged.

        """
        # `not any(...)` treats any all-zero sequence (tuple or list) as "no padding",
        # unlike an equality check against a literal tuple.
        if not any(padding):
            return volume
        return f3d.pad_3d_with_params(
            volume=volume,
            padding=padding,
            value=self.fill,
        )

    def apply_to_mask3d(
        self,
        mask3d: np.ndarray,
        padding: tuple[int, int, int, int, int, int],
        **params: Any,
    ) -> np.ndarray:
        """Apply padding to a 3D mask.

        Args:
            mask3d (np.ndarray): Input mask with shape (depth, height, width) or (depth, height, width, channels)
            padding (tuple[int, int, int, int, int, int]): Padding values in format:
                (depth_front, depth_back, height_top, height_bottom, width_left, width_right)
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Padded mask with same number of dimensions as input.
                When every padding value is zero the input array is returned unchanged.

        """
        if not any(padding):
            return mask3d
        return f3d.pad_3d_with_params(
            volume=mask3d,
            padding=padding,
            value=cast("Union[tuple[float, ...], float]", self.fill_mask),
        )

    def apply_to_keypoints(
        self,
        keypoints: np.ndarray,
        padding: tuple[int, int, int, int, int, int],
        **params: Any,
    ) -> np.ndarray:
        """Apply padding to keypoints.

        Args:
            keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
                The first three columns are x, y, z coordinates.
            padding (tuple[int, int, int, int, int, int]): Padding values in format:
                (depth_front, depth_back, height_top, height_bottom, width_left, width_right)
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Shifted keypoints with same shape as input

        """
        # Keypoint columns are ordered (x, y, z) while `padding` is ordered by axis
        # (depth, height, width), so the leading pad of each axis is taken in reverse order.
        shift_vector = np.array([padding[4], padding[2], padding[0]])
        return fgeometric.shift_keypoints(keypoints, shift_vector)
186
+
187
+
188
class Pad3D(BasePad3D):
    """Pad the sides of a 3D volume by specified number of voxels.

    Args:
        padding (int, tuple[int, int, int] or tuple[int, int, int, int, int, int]): Padding values. Can be:
            * int - pad all sides by this value
            * tuple[int, int, int] - symmetric padding (depth, height, width) where each value
              is applied to both sides of the corresponding dimension
            * tuple[int, int, int, int, int, int] - explicit padding per side in order:
              (depth_front, depth_back, height_top, height_bottom, width_left, width_right)

        fill (tuple[float, ...] | float): Padding value for the volume
        fill_mask (tuple[float, ...] | float): Padding value for the 3D mask
        p (float): probability of applying the transform. Default: 1.0.

    Targets:
        volume, mask3d, keypoints

    Image types:
        uint8, float32

    Note:
        Input volume should be a numpy array with dimensions ordered as (z, y, x) or (depth, height, width),
        with optional channel dimension as the last axis.

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>>
        >>> # Prepare sample data
        >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> keypoints = np.array([[20, 30, 5], [60, 70, 8]], dtype=np.float32)  # (x, y, z)
        >>> keypoint_labels = [1, 2]  # Labels for each keypoint
        >>>
        >>> # Create the transform with symmetric padding
        >>> transform = A.Compose([
        ...     A.Pad3D(
        ...         padding=(2, 5, 10),  # (depth, height, width) applied symmetrically
        ...         fill=0,
        ...         fill_mask=1,
        ...         p=1.0
        ...     )
        ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply the transform
        >>> transformed = transform(
        ...     volume=volume,
        ...     mask3d=mask3d,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Get the transformed data
        >>> padded_volume = transformed["volume"]  # Shape: (14, 110, 120)
        >>> padded_mask3d = transformed["mask3d"]  # Shape: (14, 110, 120)
        >>> padded_keypoints = transformed["keypoints"]  # Keypoints shifted by padding
        >>> padded_keypoint_labels = transformed["keypoint_labels"]  # Labels remain unchanged

    """

    class InitSchema(BasePad3D.InitSchema):
        padding: int | tuple[int, int, int] | tuple[int, int, int, int, int, int]

        @field_validator("padding")
        @classmethod
        def validate_padding(
            cls,
            v: int | tuple[int, int, int] | tuple[int, int, int, int, int, int],
        ) -> int | tuple[int, int, int] | tuple[int, int, int, int, int, int]:
            """Validate the padding parameter.

            Args:
                cls (type): The class object
                v (int | tuple[int, int, int] | tuple[int, int, int, int, int, int]): The padding value to validate,
                    can be an integer or tuple of integers

            Returns:
                int | tuple[int, int, int] | tuple[int, int, int, int, int, int]: The validated padding value

            Raises:
                ValueError: If padding is negative or contains negative values

            """
            if isinstance(v, int) and v < 0:
                raise ValueError("Padding value must be non-negative")
            if isinstance(v, tuple) and not all(isinstance(i, int) and i >= 0 for i in v):
                raise ValueError("Padding tuple must contain non-negative integers")

            return v

    def __init__(
        self,
        padding: int | tuple[int, int, int] | tuple[int, int, int, int, int, int],
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
        p: float = 1.0,
    ):
        # BasePad3D.__init__ already stores `fill` and `fill_mask` on the instance.
        super().__init__(fill=fill, fill_mask=fill_mask, p=p)
        self.padding = padding

    def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
        """Get parameters dependent on input data.

        Args:
            params (dict[str, Any]): Dictionary of existing parameters
            data (dict[str, Any]): Dictionary containing input data with volume, mask, etc.

        Returns:
            dict[str, Any]: Dictionary containing the padding parameter tuple in format:
                (depth_front, depth_back, height_top, height_bottom, width_left, width_right)

        """
        if isinstance(self.padding, int):
            pad_d = pad_h = pad_w = self.padding
        elif len(self.padding) == NUM_DIMENSIONS:
            pad_d, pad_h, pad_w = self.padding  # type: ignore[misc]
        else:
            # Explicit per-side padding; always hand a tuple downstream, even if the
            # attribute was assigned a list after construction.
            return {"padding": tuple(self.padding)}

        return {"padding": (pad_d, pad_d, pad_h, pad_h, pad_w, pad_w)}
313
+
314
+
315
class PadIfNeeded3D(BasePad3D):
    """Pad a 3D volume up to a minimum size and/or to a multiple of given divisors.

    Each spatial axis (depth, height, width) is handled independently. When ``min_zyx`` is
    given, an axis shorter than the requested minimum is padded up to it. When
    ``pad_divisor_zyx`` is given, the volume dimensions are additionally padded so they are
    divisible by the corresponding divisor. At least one of the two must be provided.

    Args:
        min_zyx (tuple[int, int, int] | None): Minimum desired size as (depth, height, width).
            If omitted, pad_divisor_zyx is required.
        pad_divisor_zyx (tuple[int, int, int] | None): Divisors as (depth_div, height_div, width_div)
            that each padded dimension should be divisible by. If omitted, min_zyx is required.
        position (Literal["center", "random"]): Where the original volume is placed inside the
            padded result. Default: "center".
        fill (tuple[float, ...] | float): Fill value for padded volume voxels. Default: 0
        fill_mask (tuple[float, ...] | float): Fill value for padded mask voxels. Default: 0
        p (float): Probability of applying the transform. Default: 1.0

    Targets:
        volume, mask3d, keypoints

    Image types:
        uint8, float32

    Note:
        Input volume should be a numpy array with dimensions ordered as (z, y, x) or (depth, height, width),
        with optional channel dimension as the last axis.

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>>
        >>> # Prepare sample data
        >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> keypoints = np.array([[20, 30, 5], [60, 70, 8]], dtype=np.float32)  # (x, y, z)
        >>> keypoint_labels = [1, 2]  # Labels for each keypoint
        >>>
        >>> # Create a transform with both min_zyx and pad_divisor_zyx
        >>> transform = A.Compose([
        ...     A.PadIfNeeded3D(
        ...         min_zyx=(16, 128, 128),  # Minimum size (depth, height, width)
        ...         pad_divisor_zyx=(8, 16, 16),  # Make dimensions divisible by these values
        ...         position="center",  # Center the volume in the padded space
        ...         fill=0,  # Fill value for volume
        ...         fill_mask=1,  # Fill value for mask
        ...         p=1.0
        ...     )
        ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply the transform
        >>> transformed = transform(
        ...     volume=volume,
        ...     mask3d=mask3d,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Get the transformed data
        >>> padded_volume = transformed["volume"]  # Shape: (16, 128, 128)
        >>> padded_mask3d = transformed["mask3d"]  # Shape: (16, 128, 128)
        >>> padded_keypoints = transformed["keypoints"]  # Keypoints shifted by padding
        >>> padded_keypoint_labels = transformed["keypoint_labels"]  # Labels remain unchanged

    """

    class InitSchema(BasePad3D.InitSchema):
        min_zyx: Annotated[tuple[int, int, int] | None, AfterValidator(check_range_bounds(0, None))]
        pad_divisor_zyx: Annotated[tuple[int, int, int] | None, AfterValidator(check_range_bounds(1, None))]
        position: Literal["center", "random"]

        @model_validator(mode="after")
        def validate_params(self) -> Self:
            """Ensure that at least one sizing constraint is configured.

            Returns:
                Self: The validated schema instance.

            Raises:
                ValueError: If neither min_zyx nor pad_divisor_zyx is set.

            """
            has_min_size = self.min_zyx is not None
            has_divisor = self.pad_divisor_zyx is not None
            if not (has_min_size or has_divisor):
                msg = "At least one of min_zyx or pad_divisor_zyx must be set"
                raise ValueError(msg)
            return self

    def __init__(
        self,
        min_zyx: tuple[int, int, int] | None = None,
        pad_divisor_zyx: tuple[int, int, int] | None = None,
        position: Literal["center", "random"] = "center",
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
        p: float = 1.0,
    ):
        super().__init__(fill=fill, fill_mask=fill_mask, p=p)
        self.position = position
        self.min_zyx = min_zyx
        self.pad_divisor_zyx = pad_divisor_zyx

    def get_params_dependent_on_data(
        self,
        params: dict[str, Any],
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Compute the per-side padding required for the incoming volume.

        Args:
            params (dict[str, Any]): Parameters computed so far (unused here).
            data (dict[str, Any]): Input data; the spatial size is read from ``data["volume"]``.

        Returns:
            dict[str, Any]: ``{"padding": ...}`` as produced by
                ``f3d.adjust_padding_by_position3d`` for the configured position.

        """
        depth, height, width = data["volume"].shape[:3]

        # An unset constraint contributes `None` on every axis.
        unset = (None, None, None)
        min_sizes = self.min_zyx if self.min_zyx else unset
        divisors = self.pad_divisor_zyx if self.pad_divisor_zyx else unset

        per_axis_padding = [
            fgeometric.get_dimension_padding(
                current_size=current,
                min_size=min_sizes[axis],
                divisor=divisors[axis],
            )
            for axis, current in enumerate((depth, height, width))
        ]

        return {
            "padding": f3d.adjust_padding_by_position3d(
                paddings=per_axis_padding,
                position=self.position,
                py_random=self.py_random,
            ),
        }
450
+
451
+
452
class BaseCropAndPad3D(Transform3D):
    """Base class for 3D transforms that need both cropping and padding.

    This class serves as a foundation for transforms that combine cropping and padding operations
    on 3D volumetric data. It provides functionality for calculating padding parameters,
    applying crop and pad operations to volumes, masks, and handling keypoint coordinate shifts.

    Subclasses are expected to implement ``get_params_dependent_on_data`` and return a dict with
    ``"crop_coords"`` as ``(z1, z2, y1, y2, x1, x2)`` and ``"pad_params"`` (the dict produced by
    ``_get_pad_params`` or ``None``). The volume/mask are cropped first and then padded.

    Args:
        pad_if_needed (bool): Whether to pad if the volume is smaller than target dimensions
        fill (tuple[float, ...] | float): Value to fill the padded voxels for volume
        fill_mask (tuple[float, ...] | float): Value to fill the padded voxels for mask
        pad_position (Literal["center", "random"]): How to distribute padding when needed
            "center" - equal amount on both sides, "random" - random distribution
        p (float): Probability of applying the transform. Default: 1.0

    Targets:
        volume, mask3d, keypoints

    Note:
        This is a base class and not intended to be used directly. Use its derivatives
        like CenterCrop3D or RandomCrop3D instead, or create a custom transform
        by inheriting from this class. ``BaseCropAndPad3D`` is not listed in this module's
        ``__all__``, so import it from ``albumentations.augmentations.transforms3d.transforms``
        rather than accessing it as ``A.BaseCropAndPad3D``.

    Examples:
        >>> from typing import Any
        >>> import numpy as np
        >>> import albumentations as A
        >>> from albumentations.augmentations.transforms3d.transforms import BaseCropAndPad3D
        >>>
        >>> # Example of a custom crop transform inheriting from BaseCropAndPad3D
        >>> class CustomFixedCrop3D(BaseCropAndPad3D):
        ...     def __init__(self, crop_size: tuple[int, int, int] = (8, 64, 64), **kwargs):
        ...         super().__init__(
        ...             pad_if_needed=True,
        ...             fill=0,
        ...             fill_mask=0,
        ...             pad_position="center",
        ...             **kwargs
        ...         )
        ...         self.crop_size = crop_size
        ...
        ...     def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
        ...         # Get the volume shape
        ...         volume = data["volume"]
        ...         z, h, w = volume.shape[:3]
        ...         target_z, target_h, target_w = self.crop_size
        ...
        ...         # Check if padding is needed and calculate parameters
        ...         pad_params = self._get_pad_params(
        ...             image_shape=(z, h, w),
        ...             target_shape=self.crop_size,
        ...         )
        ...
        ...         # Calculate fixed crop coordinates - always start at position (0,0,0)
        ...         crop_coords = (0, target_z, 0, target_h, 0, target_w)
        ...
        ...         return {
        ...             "crop_coords": crop_coords,
        ...             "pad_params": pad_params,
        ...         }
        >>>
        >>> # Prepare sample data
        >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> keypoints = np.array([[20, 30, 5], [60, 70, 8]], dtype=np.float32)  # (x, y, z)
        >>> keypoint_labels = [1, 2]  # Labels for each keypoint
        >>>
        >>> # Use the custom transform in a pipeline
        >>> transform = A.Compose([
        ...     CustomFixedCrop3D(
        ...         crop_size=(8, 64, 64),  # Crop first 8x64x64 voxels (with padding if needed)
        ...         p=1.0
        ...     )
        ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply the transform
        >>> transformed = transform(
        ...     volume=volume,
        ...     mask3d=mask3d,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Get the transformed data
        >>> cropped_volume = transformed["volume"]  # Shape: (8, 64, 64)
        >>> cropped_mask3d = transformed["mask3d"]  # Shape: (8, 64, 64)
        >>> cropped_keypoints = transformed["keypoints"]  # Keypoints shifted relative to crop
        >>> cropped_keypoint_labels = transformed["keypoint_labels"]  # Labels remain unchanged

    """

    _targets = (Targets.VOLUME, Targets.MASK3D, Targets.KEYPOINTS)

    class InitSchema(Transform3D.InitSchema):
        pad_if_needed: bool
        fill: tuple[float, ...] | float
        fill_mask: tuple[float, ...] | float
        pad_position: Literal["center", "random"]

    def __init__(
        self,
        pad_if_needed: bool,
        fill: tuple[float, ...] | float,
        fill_mask: tuple[float, ...] | float,
        pad_position: Literal["center", "random"],
        p: float = 1.0,
    ):
        super().__init__(p=p)
        self.pad_if_needed = pad_if_needed
        self.fill = fill
        self.fill_mask = fill_mask
        self.pad_position = pad_position

    def _random_pad(self, pad: int) -> tuple[int, int]:
        """Generate random padding values.

        Args:
            pad (int): Total padding value to distribute

        Returns:
            tuple[int, int]: Random padding values (front, back)

        """
        if pad > 0:
            pad_start = self.py_random.randint(0, pad)
            pad_end = pad - pad_start
        else:
            pad_start = pad_end = 0
        return pad_start, pad_end

    def _center_pad(self, pad: int) -> tuple[int, int]:
        """Generate centered padding values.

        Args:
            pad (int): Total padding value to distribute

        Returns:
            tuple[int, int]: Centered padding values (front, back); for an odd total the
                extra voxel goes to the back.

        """
        pad_start = pad // 2
        pad_end = pad - pad_start
        return pad_start, pad_end

    def _get_pad_params(
        self,
        image_shape: tuple[int, int, int],
        target_shape: tuple[int, int, int],
    ) -> dict[str, int] | None:
        """Calculate padding parameters to reach target shape.

        Args:
            image_shape (tuple[int, int, int]): Current shape (depth, height, width)
            target_shape (tuple[int, int, int]): Target shape (depth, height, width)

        Returns:
            dict[str, int] | None: Padding parameters (keys ``pad_front``, ``pad_back``,
                ``pad_top``, ``pad_bottom``, ``pad_left``, ``pad_right``), or None when
                ``pad_if_needed`` is False or no axis needs padding.

        """
        if not self.pad_if_needed:
            return None

        z, h, w = image_shape
        target_z, target_h, target_w = target_shape

        # Calculate total padding needed for each dimension
        z_pad = max(0, target_z - z)
        h_pad = max(0, target_h - h)
        w_pad = max(0, target_w - w)

        if z_pad == 0 and h_pad == 0 and w_pad == 0:
            return None

        # "center" splits each total evenly; anything else ("random") draws the split.
        # Axes are processed in z, h, w order so random draws stay in a fixed sequence.
        split = self._center_pad if self.pad_position == "center" else self._random_pad
        z_front, z_back = split(z_pad)
        h_top, h_bottom = split(h_pad)
        w_left, w_right = split(w_pad)

        return {
            "pad_front": z_front,
            "pad_back": z_back,
            "pad_top": h_top,
            "pad_bottom": h_bottom,
            "pad_left": w_left,
            "pad_right": w_right,
        }

    @staticmethod
    def _as_padding_tuple(pad_params: dict[str, int]) -> tuple[int, int, int, int, int, int]:
        """Convert a pad-params dict into the 6-tuple passed to ``f3d.pad_3d_with_params``.

        Args:
            pad_params (dict[str, int]): Dict as returned by ``_get_pad_params``.

        Returns:
            tuple[int, int, int, int, int, int]:
                (depth_front, depth_back, height_top, height_bottom, width_left, width_right)

        """
        return (
            pad_params["pad_front"],
            pad_params["pad_back"],
            pad_params["pad_top"],
            pad_params["pad_bottom"],
            pad_params["pad_left"],
            pad_params["pad_right"],
        )

    def apply_to_volume(
        self,
        volume: np.ndarray,
        crop_coords: tuple[int, int, int, int, int, int],
        pad_params: dict[str, int] | None,
        **params: Any,
    ) -> np.ndarray:
        """Apply cropping and padding to a 3D volume.

        Args:
            volume (np.ndarray): Input volume with shape (depth, height, width) or (depth, height, width, channels)
            crop_coords (tuple[int, int, int, int, int, int]): Crop coordinates (z1, z2, y1, y2, x1, x2)
            pad_params (dict[str, int] | None): Padding parameters or None if no padding needed
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Cropped and padded volume with same number of dimensions as input

        """
        # First crop
        cropped = f3d.crop3d(volume, crop_coords)

        # Then pad if needed
        if pad_params is None:
            return cropped
        return f3d.pad_3d_with_params(
            cropped,
            padding=self._as_padding_tuple(pad_params),
            value=self.fill,
        )

    def apply_to_mask3d(
        self,
        mask3d: np.ndarray,
        crop_coords: tuple[int, int, int, int, int, int],
        pad_params: dict[str, int] | None,
        **params: Any,
    ) -> np.ndarray:
        """Apply cropping and padding to a 3D mask.

        Args:
            mask3d (np.ndarray): Input mask with shape (depth, height, width) or (depth, height, width, channels)
            crop_coords (tuple[int, int, int, int, int, int]): Crop coordinates (z1, z2, y1, y2, x1, x2)
            pad_params (dict[str, int] | None): Padding parameters or None if no padding needed
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Cropped and padded mask with same number of dimensions as input

        """
        # First crop
        cropped = f3d.crop3d(mask3d, crop_coords)

        # Then pad if needed
        if pad_params is None:
            return cropped
        return f3d.pad_3d_with_params(
            cropped,
            padding=self._as_padding_tuple(pad_params),
            value=cast("Union[tuple[float, ...], float]", self.fill_mask),
        )

    def apply_to_keypoints(
        self,
        keypoints: np.ndarray,
        crop_coords: tuple[int, int, int, int, int, int],
        pad_params: dict[str, int] | None,
        **params: Any,
    ) -> np.ndarray:
        """Apply cropping and padding to keypoints.

        Args:
            keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
                The first three columns are x, y, z coordinates.
            crop_coords (tuple[int, int, int, int, int, int]): Crop coordinates (z1, z2, y1, y2, x1, x2)
            pad_params (dict[str, int] | None): Padding parameters or None if no padding needed
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Shifted keypoints with same shape as input

        """
        # Extract crop start coordinates (z1, y1, x1)
        crop_z1, _, crop_y1, _, crop_x1, _ = crop_coords

        # Cropping moves the origin forward, so keypoints shift by the negative crop start.
        # Order is (x, y, z) to match the keypoint columns.
        shift = np.array(
            [
                -crop_x1,  # X shift
                -crop_y1,  # Y shift
                -crop_z1,  # Z shift
            ],
        )

        # Padding applied after the crop pushes content away from the origin by the leading pad.
        if pad_params is not None:
            shift += np.array(
                [
                    pad_params["pad_left"],  # X shift
                    pad_params["pad_top"],  # Y shift
                    pad_params["pad_front"],  # Z shift
                ],
            )

        # Apply combined shift
        return fgeometric.shift_keypoints(keypoints, shift)
774
+
775
+
776
class CenterCrop3D(BaseCropAndPad3D):
    """Crop the center of 3D volume.

    Args:
        size (tuple[int, int, int]): Desired output size of the crop in format (depth, height, width)
        pad_if_needed (bool): Whether to pad if the volume is smaller than desired crop size. Default: False
        fill (tuple[float, ...] | float): Padding value for image if pad_if_needed is True. Default: 0
        fill_mask (tuple[float, ...] | float): Padding value for mask if pad_if_needed is True. Default: 0
        p (float): probability of applying the transform. Default: 1.0

    Targets:
        volume, mask3d, keypoints

    Image types:
        uint8, float32

    Raises:
        ValueError: When the transform is called and the (possibly padded) volume is smaller than
            ``size`` along any axis, e.g. when the input is too small and ``pad_if_needed`` is False.

    Note:
        If you want to perform cropping only in the XY plane while preserving all slices along
        the Z axis, consider using CenterCrop instead. CenterCrop will apply the same XY crop
        to each slice independently, maintaining the full depth of the volume.

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>>
        >>> # Prepare sample data
        >>> volume = np.random.randint(0, 256, (20, 200, 200), dtype=np.uint8)  # (D, H, W)
        >>> mask3d = np.random.randint(0, 2, (20, 200, 200), dtype=np.uint8)  # (D, H, W)
        >>> keypoints = np.array([[100, 100, 10], [150, 150, 15]], dtype=np.float32)  # (x, y, z)
        >>> keypoint_labels = [1, 2]  # Labels for each keypoint
        >>>
        >>> # Create the transform - crop to 16x128x128 from center
        >>> transform = A.Compose([
        ...     A.CenterCrop3D(
        ...         size=(16, 128, 128),        # Output size (depth, height, width)
        ...         pad_if_needed=True,         # Pad if input is smaller than crop size
        ...         fill=0,                     # Fill value for volume padding
        ...         fill_mask=1,                # Fill value for mask padding
        ...         p=1.0
        ...     )
        ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply the transform
        >>> transformed = transform(
        ...     volume=volume,
        ...     mask3d=mask3d,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Get the transformed data
        >>> cropped_volume = transformed["volume"]  # Shape: (16, 128, 128)
        >>> cropped_mask3d = transformed["mask3d"]  # Shape: (16, 128, 128)
        >>> cropped_keypoints = transformed["keypoints"]  # Keypoints shifted relative to center crop
        >>> cropped_keypoint_labels = transformed["keypoint_labels"]  # Labels remain unchanged
        >>>
        >>> # Example with a small volume that requires padding
        >>> small_volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)
        >>> small_transform = A.Compose([
        ...     A.CenterCrop3D(
        ...         size=(16, 128, 128),
        ...         pad_if_needed=True,  # Will pad since the input is smaller
        ...         fill=0,
        ...         p=1.0
        ...     )
        ... ])
        >>> small_result = small_transform(volume=small_volume)
        >>> padded_and_cropped = small_result["volume"]  # Shape: (16, 128, 128), padded to size

    """

    class InitSchema(BaseTransformInitSchema):
        size: Annotated[tuple[int, int, int], AfterValidator(check_range_bounds(1, None))]
        pad_if_needed: bool
        fill: tuple[float, ...] | float
        fill_mask: tuple[float, ...] | float

    def __init__(
        self,
        size: tuple[int, int, int],
        pad_if_needed: bool = False,
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
        p: float = 1.0,
    ):
        super().__init__(
            pad_if_needed=pad_if_needed,
            fill=fill,
            fill_mask=fill_mask,
            pad_position="center",  # Center crop always uses center padding
            p=p,
        )
        self.size = size

    def get_params_dependent_on_data(
        self,
        params: dict[str, Any],
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Calculate crop coordinates for center cropping.

        Args:
            params (dict[str, Any]): Dictionary of existing parameters
            data (dict[str, Any]): Dictionary containing input data with volume, mask, etc.

        Returns:
            dict[str, Any]: Dictionary with ``crop_coords`` as (z1, z2, y1, y2, x1, x2) and
            ``pad_params`` (None when no padding is applied).

        Raises:
            ValueError: If the (possibly padded) volume is smaller than ``self.size`` along any axis.

        """
        volume = data["volume"]
        z, h, w = volume.shape[:3]
        target_z, target_h, target_w = self.size

        # Get padding params if needed
        pad_params = self._get_pad_params(
            image_shape=(z, h, w),
            target_shape=self.size,
        )

        # Update dimensions if padding is applied
        if pad_params is not None:
            z = z + pad_params["pad_front"] + pad_params["pad_back"]
            h = h + pad_params["pad_top"] + pad_params["pad_bottom"]
            w = w + pad_params["pad_left"] + pad_params["pad_right"]

        # Validate dimensions after padding
        if z < target_z or h < target_h or w < target_w:
            if pad_params is None:
                # Nothing was padded, so the input itself is too small for the requested crop.
                # That is a caller configuration issue (typically pad_if_needed=False), not an
                # internal bug, so the message should tell the user how to fix it.
                msg = (
                    f"Crop size {self.size} is larger than volume size ({z}, {h}, {w}) "
                    "and no padding was applied. Set pad_if_needed=True to pad the volume before cropping."
                )
            else:
                msg = (
                    f"Crop size {self.size} is larger than padded image size ({z}, {h}, {w}). "
                    f"This should not happen - please report this as a bug."
                )
            raise ValueError(msg)

        # Integer division places any odd leftover voxel on the far side of each axis.
        z_start = (z - target_z) // 2
        h_start = (h - target_h) // 2
        w_start = (w - target_w) // 2

        crop_coords = (
            z_start,
            z_start + target_z,
            h_start,
            h_start + target_h,
            w_start,
            w_start + target_w,
        )

        return {
            "crop_coords": crop_coords,
            "pad_params": pad_params,
        }
927
+
928
+
929
class RandomCrop3D(BaseCropAndPad3D):
    """Crop random part of 3D volume.

    Args:
        size (tuple[int, int, int]): Desired output size of the crop in format (depth, height, width)
        pad_if_needed (bool): Whether to pad if the volume is smaller than desired crop size. Default: False
        fill (tuple[float, ...] | float): Padding value for image if pad_if_needed is True. Default: 0
        fill_mask (tuple[float, ...] | float): Padding value for mask if pad_if_needed is True. Default: 0
        p (float): probability of applying the transform. Default: 1.0

    Targets:
        volume, mask3d, keypoints

    Image types:
        uint8, float32

    Note:
        If you want to perform random cropping only in the XY plane while preserving all slices along
        the Z axis, consider using RandomCrop instead. RandomCrop will apply the same XY crop
        to each slice independently, maintaining the full depth of the volume.

        When no padding is applied and the volume is smaller than ``size`` along an axis, the crop
        start on that axis is clamped to 0 rather than raising. The crop box then extends past the
        volume; the resulting output size depends on ``f3d.crop3d``
        (NOTE(review): with plain slicing it would be smaller than ``size``).

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>>
        >>> # Prepare sample data
        >>> volume = np.random.randint(0, 256, (20, 200, 200), dtype=np.uint8)  # (D, H, W)
        >>> mask3d = np.random.randint(0, 2, (20, 200, 200), dtype=np.uint8)  # (D, H, W)
        >>> keypoints = np.array([[100, 100, 10], [150, 150, 15]], dtype=np.float32)  # (x, y, z)
        >>> keypoint_labels = [1, 2]  # Labels for each keypoint
        >>>
        >>> # Create the transform with random crop and padding if needed
        >>> transform = A.Compose([
        ...     A.RandomCrop3D(
        ...         size=(16, 128, 128),        # Output size (depth, height, width)
        ...         pad_if_needed=True,         # Pad if input is smaller than crop size
        ...         fill=0,                     # Fill value for volume padding
        ...         fill_mask=1,                # Fill value for mask padding
        ...         p=1.0
        ...     )
        ... ], keypoint_params=A.KeypointParams(format='xyz', label_fields=['keypoint_labels']))
        >>>
        >>> # Apply the transform
        >>> transformed = transform(
        ...     volume=volume,
        ...     mask3d=mask3d,
        ...     keypoints=keypoints,
        ...     keypoint_labels=keypoint_labels
        ... )
        >>>
        >>> # Get the transformed data
        >>> cropped_volume = transformed["volume"]  # Shape: (16, 128, 128)
        >>> cropped_mask3d = transformed["mask3d"]  # Shape: (16, 128, 128)
        >>> cropped_keypoints = transformed["keypoints"]  # Keypoints shifted relative to random crop
        >>> cropped_keypoint_labels = transformed["keypoint_labels"]  # Labels remain unchanged

    """

    class InitSchema(BaseTransformInitSchema):
        size: Annotated[tuple[int, int, int], AfterValidator(check_range_bounds(1, None))]
        pad_if_needed: bool
        fill: tuple[float, ...] | float
        fill_mask: tuple[float, ...] | float

    def __init__(
        self,
        size: tuple[int, int, int],
        pad_if_needed: bool = False,
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float = 0,
        p: float = 1.0,
    ):
        # Padding (when enabled) is placed at a random position, matching the random crop.
        super().__init__(
            pad_if_needed=pad_if_needed,
            fill=fill,
            fill_mask=fill_mask,
            pad_position="random",
            p=p,
        )
        self.size = size

    def get_params_dependent_on_data(
        self,
        params: dict[str, Any],
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Pick a random crop box, padding the volume first if configured to do so.

        Args:
            params (dict[str, Any]): Parameters computed so far (unused here).
            data (dict[str, Any]): Input data; ``data["volume"]`` supplies the spatial extent.

        Returns:
            dict[str, Any]: ``crop_coords`` as (z1, z2, y1, y2, x1, x2) and ``pad_params``
            (None when no padding is applied).

        """
        depth, height, width = data["volume"].shape[:3]
        target_z, target_h, target_w = self.size

        pad_params = self._get_pad_params(
            image_shape=(depth, height, width),
            target_shape=self.size,
        )
        if pad_params is not None:
            # Crop offsets are expressed in the padded volume's coordinates.
            depth += pad_params["pad_front"] + pad_params["pad_back"]
            height += pad_params["pad_top"] + pad_params["pad_bottom"]
            width += pad_params["pad_left"] + pad_params["pad_right"]

        # Offsets are drawn in z, y, x order; max(0, ...) clamps axes where the volume is too small.
        z0, y0, x0 = (
            self.py_random.randint(0, max(0, extent - crop))
            for extent, crop in ((depth, target_z), (height, target_h), (width, target_w))
        )

        return {
            "crop_coords": (z0, z0 + target_z, y0, y0 + target_h, x0, x0 + target_w),
            "pad_params": pad_params,
        }
1059
+
1060
+
1061
class CoarseDropout3D(Transform3D):
    """CoarseDropout3D randomly drops out cuboid regions from a 3D volume and optionally,
    the corresponding regions in an associated 3D mask, to simulate occlusion and
    varied object sizes found in real-world volumetric data.

    Args:
        num_holes_range (tuple[int, int]): Range (min, max) for the number of cuboid
            regions to drop out. Default: (1, 1)
        hole_depth_range (tuple[float, float]): Range (min, max) for the depth
            of dropout regions as a fraction of the volume depth (between 0 and 1). Default: (0.1, 0.2)
        hole_height_range (tuple[float, float]): Range (min, max) for the height
            of dropout regions as a fraction of the volume height (between 0 and 1). Default: (0.1, 0.2)
        hole_width_range (tuple[float, float]): Range (min, max) for the width
            of dropout regions as a fraction of the volume width (between 0 and 1). Default: (0.1, 0.2)
        fill (tuple[float, ...] | float): Value for the dropped voxels. Can be:
            - int or float: all channels are filled with this value
            - tuple: tuple of values for each channel
            Default: 0
        fill_mask (tuple[float, ...] | float | None): Fill value for dropout regions in the 3D mask.
            If None, mask regions corresponding to volume dropouts are unchanged. Default: None
        p (float): Probability of applying the transform. Default: 0.5

    Targets:
        volume, mask3d, keypoints

    Image types:
        uint8, float32

    Note:
        - The actual number and size of dropout regions are randomly chosen within the specified ranges.
        - All values in hole_depth_range, hole_height_range and hole_width_range must be between 0 and 1.
        - Hole sizes are rounded up and clamped to at least one voxel along every axis.
        - Keypoints are filtered against the holes only when the keypoint processor has
          ``remove_invisible`` enabled; otherwise they are returned unchanged.
        - If you want to apply dropout only in the XY plane while preserving the full depth dimension,
          consider using CoarseDropout instead. CoarseDropout applies the same rectangular holes
          to every slice, so each dropout region extends through the entire depth of the volume.

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8)  # (D, H, W)
        >>> aug = A.CoarseDropout3D(
        ...     num_holes_range=(3, 6),
        ...     hole_depth_range=(0.1, 0.2),
        ...     hole_height_range=(0.1, 0.2),
        ...     hole_width_range=(0.1, 0.2),
        ...     fill=0,
        ...     fill_mask=0,  # also zero out the dropped regions in the mask
        ...     p=1.0
        ... )
        >>> transformed = aug(volume=volume, mask3d=mask3d)
        >>> transformed_volume, transformed_mask3d = transformed["volume"], transformed["mask3d"]

    """

    _targets = (Targets.VOLUME, Targets.MASK3D, Targets.KEYPOINTS)

    class InitSchema(Transform3D.InitSchema):
        num_holes_range: Annotated[
            tuple[int, int],
            AfterValidator(check_range_bounds(0, None)),
            AfterValidator(nondecreasing),
        ]
        hole_depth_range: Annotated[
            tuple[float, float],
            AfterValidator(check_range_bounds(0, 1)),
            AfterValidator(nondecreasing),
        ]
        hole_height_range: Annotated[
            tuple[float, float],
            AfterValidator(check_range_bounds(0, 1)),
            AfterValidator(nondecreasing),
        ]
        hole_width_range: Annotated[
            tuple[float, float],
            AfterValidator(check_range_bounds(0, 1)),
            AfterValidator(nondecreasing),
        ]
        fill: tuple[float, ...] | float
        fill_mask: tuple[float, ...] | float | None

        @staticmethod
        def validate_range(range_value: tuple[float, float], range_name: str) -> None:
            """Validate that range values are between 0 and 1 and in non-decreasing order.

            Args:
                range_value (tuple[float, float]): Tuple of (min, max) values to check
                range_name (str): Name of the range for error reporting

            Raises:
                ValueError: If range values are invalid

            """
            if not 0 <= range_value[0] <= range_value[1] <= 1:
                raise ValueError(
                    f"All values in {range_name} should be in [0, 1] range and first value "
                    f"should be less or equal than the second value. Got: {range_value}",
                )

        @model_validator(mode="after")
        def _check_ranges(self) -> Self:
            self.validate_range(self.hole_depth_range, "hole_depth_range")
            self.validate_range(self.hole_height_range, "hole_height_range")
            self.validate_range(self.hole_width_range, "hole_width_range")
            return self

    def __init__(
        self,
        num_holes_range: tuple[int, int] = (1, 1),
        hole_depth_range: tuple[float, float] = (0.1, 0.2),
        hole_height_range: tuple[float, float] = (0.1, 0.2),
        hole_width_range: tuple[float, float] = (0.1, 0.2),
        fill: tuple[float, ...] | float = 0,
        fill_mask: tuple[float, ...] | float | None = None,
        p: float = 0.5,
    ):
        super().__init__(p=p)
        self.num_holes_range = num_holes_range
        self.hole_depth_range = hole_depth_range
        self.hole_height_range = hole_height_range
        self.hole_width_range = hole_width_range
        self.fill = fill
        self.fill_mask = fill_mask

    def calculate_hole_dimensions(
        self,
        volume_shape: tuple[int, int, int],
        depth_range: tuple[float, float],
        height_range: tuple[float, float],
        width_range: tuple[float, float],
        size: int,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Calculate dimensions for dropout holes.

        Args:
            volume_shape (tuple[int, int, int]): Shape of the volume (depth, height, width)
            depth_range (tuple[float, float]): Range for hole depth as fraction of volume depth
            height_range (tuple[float, float]): Range for hole height as fraction of volume height
            width_range (tuple[float, float]): Range for hole width as fraction of volume width
            size (int): Number of holes to generate

        Returns:
            tuple[np.ndarray, np.ndarray, np.ndarray]: Integer arrays of hole dimensions
            (depths, heights, widths), each of length ``size`` and each entry >= 1.

        """
        depth, height, width = volume_shape[:3]

        # ceil + maximum(1, ...) guarantees every hole covers at least one voxel per axis.
        hole_depths = np.maximum(1, np.ceil(depth * self.random_generator.uniform(*depth_range, size=size))).astype(int)
        hole_heights = np.maximum(1, np.ceil(height * self.random_generator.uniform(*height_range, size=size))).astype(
            int,
        )
        hole_widths = np.maximum(1, np.ceil(width * self.random_generator.uniform(*width_range, size=size))).astype(int)

        return hole_depths, hole_heights, hole_widths

    def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
        """Generate parameters for coarse dropout based on input data.

        Args:
            params (dict[str, Any]): Dictionary of existing parameters
            data (dict[str, Any]): Dictionary containing input data with volume, mask, etc.

        Returns:
            dict[str, Any]: ``{"holes": holes}`` where ``holes`` has shape (num_holes, 6) and
            each row is [z1, y1, x1, z2, y2, x2] (end coordinates exclusive).

        """
        volume_shape = data["volume"].shape[:3]

        num_holes = self.py_random.randint(*self.num_holes_range)

        hole_depths, hole_heights, hole_widths = self.calculate_hole_dimensions(
            volume_shape,
            self.hole_depth_range,
            self.hole_height_range,
            self.hole_width_range,
            size=num_holes,
        )

        depth, height, width = volume_shape[:3]

        # Upper bounds are exclusive, so "+ 1" lets a hole end exactly at the volume border.
        z_min = self.random_generator.integers(0, depth - hole_depths + 1, size=num_holes)
        y_min = self.random_generator.integers(0, height - hole_heights + 1, size=num_holes)
        x_min = self.random_generator.integers(0, width - hole_widths + 1, size=num_holes)
        z_max = z_min + hole_depths
        y_max = y_min + hole_heights
        x_max = x_min + hole_widths

        holes = np.stack([z_min, y_min, x_min, z_max, y_max, x_max], axis=-1)

        return {"holes": holes}

    def apply_to_volume(self, volume: np.ndarray, holes: np.ndarray, **params: Any) -> np.ndarray:
        """Apply dropout to a 3D volume.

        Args:
            volume (np.ndarray): Input volume with shape (depth, height, width) or (depth, height, width, channels)
            holes (np.ndarray): Array of holes with shape (num_holes, 6).
                Each hole is represented as [z1, y1, x1, z2, y2, x2]
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Volume with holes filled with ``self.fill``

        """
        if holes.size == 0:
            return volume

        return f3d.cutout3d(volume, holes, self.fill)

    def apply_to_mask3d(self, mask3d: np.ndarray, holes: np.ndarray, **params: Any) -> np.ndarray:
        """Apply dropout to a 3D mask using ``self.fill_mask``.

        The ``mask3d`` target is handled through ``apply_to_mask3d`` (the same hook that
        BaseCropAndPad3D overrides in this module). Previously this class defined only
        ``apply_to_mask``, so 3D masks presumably fell back to the base-class default and
        ignored ``fill_mask``, including its documented ``None`` = "leave mask unchanged"
        behaviour. NOTE(review): confirm against Transform3D's default ``apply_to_mask3d``.

        Args:
            mask3d (np.ndarray): Input mask with shape (depth, height, width) or (depth, height, width, channels)
            holes (np.ndarray): Array of holes with shape (num_holes, 6).
                Each hole is represented as [z1, y1, x1, z2, y2, x2]
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Mask with holes filled with ``self.fill_mask``, or the input mask
            unchanged when ``fill_mask`` is None or there are no holes.

        """
        if self.fill_mask is None or holes.size == 0:
            return mask3d

        return f3d.cutout3d(mask3d, holes, self.fill_mask)

    def apply_to_mask(self, mask: np.ndarray, holes: np.ndarray, **params: Any) -> np.ndarray:
        """Apply dropout to a 3D mask (kept for backward compatibility).

        Delegates to ``apply_to_mask3d``, so both entry points honour ``fill_mask`` identically.

        Args:
            mask (np.ndarray): Input mask with shape (depth, height, width) or (depth, height, width, channels)
            holes (np.ndarray): Array of holes with shape (num_holes, 6).
                Each hole is represented as [z1, y1, x1, z2, y2, x2]
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Mask with holes filled with the given value

        """
        return self.apply_to_mask3d(mask, holes, **params)

    def apply_to_keypoints(
        self,
        keypoints: np.ndarray,
        holes: np.ndarray,
        **params: Any,
    ) -> np.ndarray:
        """Apply dropout to keypoints.

        Args:
            keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
                The first three columns are x, y, z coordinates.
            holes (np.ndarray): Array of holes with shape (num_holes, 6).
                Each hole is represented as [z1, y1, x1, z2, y2, x2]
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Keypoints filtered against the holes when the keypoint processor has
            ``remove_invisible`` enabled; otherwise the input keypoints unchanged.

        """
        if holes.size == 0:
            return keypoints
        processor = cast("KeypointsProcessor", self.get_processor("keypoints"))

        # Keypoints are only dropped when the pipeline asked for invisible ones to be removed.
        if processor is None or not processor.params.remove_invisible:
            return keypoints
        return f3d.filter_keypoints_in_holes3d(keypoints, holes)
1313
+
1314
+
1315
+ class CubicSymmetry(Transform3D):
1316
+ """Applies a random cubic symmetry transformation to a 3D volume.
1317
+
1318
+ This transform is a 3D extension of D4. While D4 handles the 8 symmetries
1319
+ of a square (4 rotations x 2 reflections), CubicSymmetry handles all 48 symmetries of a cube.
1320
+ Like D4, this transform does not create any interpolation artifacts as it only remaps voxels
1321
+ from one position to another without any interpolation.
1322
+
1323
+ The 48 transformations consist of:
1324
+ - 24 rotations (orientation-preserving):
1325
+ * 4 rotations around each face diagonal (6 face diagonals x 4 rotations = 24)
1326
+ - 24 rotoreflections (orientation-reversing):
1327
+ * Reflection through a plane followed by any of the 24 rotations
1328
+
1329
+ For a cube, these transformations preserve:
1330
+ - All face centers (6)
1331
+ - All vertex positions (8)
1332
+ - All edge centers (12)
1333
+
1334
+ works with 3D volumes and masks of the shape (D, H, W) or (D, H, W, C)
1335
+
1336
+ Args:
1337
+ p (float): Probability of applying the transform. Default: 1.0
1338
+
1339
+ Targets:
1340
+ volume, mask3d, keypoints
1341
+
1342
+ Image types:
1343
+ uint8, float32
1344
+
1345
+ Note:
1346
+ - This transform is particularly useful for data augmentation in 3D medical imaging,
1347
+ crystallography, and voxel-based 3D modeling where the object's orientation
1348
+ is arbitrary.
1349
+ - All transformations preserve the object's chirality (handedness) when using
1350
+ pure rotations (indices 0-23) and invert it when using rotoreflections
1351
+ (indices 24-47).
1352
+
1353
+ Examples:
1354
+ >>> import numpy as np
1355
+ >>> import albumentations as A
1356
+ >>> volume = np.random.randint(0, 256, (10, 100, 100), dtype=np.uint8) # (D, H, W)
1357
+ >>> mask3d = np.random.randint(0, 2, (10, 100, 100), dtype=np.uint8) # (D, H, W)
1358
+ >>> transform = A.CubicSymmetry(p=1.0)
1359
+ >>> transformed = transform(volume=volume, mask3d=mask3d)
1360
+ >>> transformed_volume = transformed["volume"]
1361
+ >>> transformed_mask3d = transformed["mask3d"]
1362
+
1363
+ See Also:
1364
+ - D4: The 2D version that handles the 8 symmetries of a square
1365
+
1366
+ """
1367
+
1368
+ _targets = (Targets.VOLUME, Targets.MASK3D, Targets.KEYPOINTS)
1369
+
1370
def __init__(self, p: float = 1.0):
    """Create a cubic-symmetry transform.

    Args:
        p (float): Probability of applying the transform. Default: 1.0

    """
    super().__init__(p=p)
1375
+
1376
def get_params_dependent_on_data(
    self,
    params: dict[str, Any],
    data: dict[str, Any],
) -> dict[str, Any]:
    """Choose which of the 48 cube symmetries to apply and record the input shape.

    Args:
        params (dict[str, Any]): Parameters computed so far (unused here).
        data (dict[str, Any]): Input data; only ``data["volume"]`` is read.

    Returns:
        dict[str, Any]: ``index`` - an integer in [0, 47] drawn from ``self.py_random``;
        ``volume_shape`` - the full shape of the input volume, which ``apply_to_keypoints``
        reads from its params.

    """
    shape_before = data["volume"].shape
    # random.randint is inclusive on both ends, so this yields one of 48 indices.
    symmetry_index = self.py_random.randint(0, 47)
    return {"index": symmetry_index, "volume_shape": shape_before}
1394
+
1395
def apply_to_volume(self, volume: np.ndarray, index: int, **params: Any) -> np.ndarray:
    """Remap ``volume`` with the cube symmetry selected by ``index``.

    Args:
        volume (np.ndarray): Volume of shape (depth, height, width) or (depth, height, width, channels).
        index (int): Which of the 48 cube symmetries to apply (0-47).
        **params (Any): Unused extra parameters.

    Returns:
        np.ndarray: The remapped volume. Symmetries that exchange spatial axes permute the
        (depth, height, width) dimensions, so for non-cubic volumes the output shape may differ
        from the input shape. NOTE(review): confirm against ``f3d.transform_cube``.

    """
    remapped = f3d.transform_cube(volume, index)
    return remapped
1408
+
1409
def apply_to_keypoints(self, keypoints: np.ndarray, index: int, **params: Any) -> np.ndarray:
    """Map keypoints through the cube symmetry selected by ``index``.

    Args:
        keypoints (np.ndarray): Keypoints of shape (num_keypoints, 3+); the first three
            columns are x, y, z.
        index (int): Which of the 48 cube symmetries to apply (0-47).
        **params (Any): Must contain ``volume_shape``, the pre-transform volume shape
            recorded by ``get_params_dependent_on_data``.

    Returns:
        np.ndarray: Transformed keypoints with the same array shape as the input.

    Raises:
        KeyError: If ``volume_shape`` is missing from ``params``.

    """
    original_shape = params["volume_shape"]
    return f3d.transform_cube_keypoints(keypoints, index, volume_shape=original_shape)