nrtk-albumentations 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nrtk-albumentations might be problematic. Click here for more details.
- albumentations/__init__.py +21 -0
- albumentations/augmentations/__init__.py +23 -0
- albumentations/augmentations/blur/__init__.py +0 -0
- albumentations/augmentations/blur/functional.py +438 -0
- albumentations/augmentations/blur/transforms.py +1633 -0
- albumentations/augmentations/crops/__init__.py +0 -0
- albumentations/augmentations/crops/functional.py +494 -0
- albumentations/augmentations/crops/transforms.py +3647 -0
- albumentations/augmentations/dropout/__init__.py +0 -0
- albumentations/augmentations/dropout/channel_dropout.py +134 -0
- albumentations/augmentations/dropout/coarse_dropout.py +567 -0
- albumentations/augmentations/dropout/functional.py +1017 -0
- albumentations/augmentations/dropout/grid_dropout.py +166 -0
- albumentations/augmentations/dropout/mask_dropout.py +274 -0
- albumentations/augmentations/dropout/transforms.py +461 -0
- albumentations/augmentations/dropout/xy_masking.py +186 -0
- albumentations/augmentations/geometric/__init__.py +0 -0
- albumentations/augmentations/geometric/distortion.py +1238 -0
- albumentations/augmentations/geometric/flip.py +752 -0
- albumentations/augmentations/geometric/functional.py +4151 -0
- albumentations/augmentations/geometric/pad.py +676 -0
- albumentations/augmentations/geometric/resize.py +956 -0
- albumentations/augmentations/geometric/rotate.py +864 -0
- albumentations/augmentations/geometric/transforms.py +1962 -0
- albumentations/augmentations/mixing/__init__.py +0 -0
- albumentations/augmentations/mixing/domain_adaptation.py +787 -0
- albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
- albumentations/augmentations/mixing/functional.py +878 -0
- albumentations/augmentations/mixing/transforms.py +832 -0
- albumentations/augmentations/other/__init__.py +0 -0
- albumentations/augmentations/other/lambda_transform.py +180 -0
- albumentations/augmentations/other/type_transform.py +261 -0
- albumentations/augmentations/pixel/__init__.py +0 -0
- albumentations/augmentations/pixel/functional.py +4226 -0
- albumentations/augmentations/pixel/transforms.py +7556 -0
- albumentations/augmentations/spectrogram/__init__.py +0 -0
- albumentations/augmentations/spectrogram/transform.py +220 -0
- albumentations/augmentations/text/__init__.py +0 -0
- albumentations/augmentations/text/functional.py +272 -0
- albumentations/augmentations/text/transforms.py +299 -0
- albumentations/augmentations/transforms3d/__init__.py +0 -0
- albumentations/augmentations/transforms3d/functional.py +393 -0
- albumentations/augmentations/transforms3d/transforms.py +1422 -0
- albumentations/augmentations/utils.py +249 -0
- albumentations/core/__init__.py +0 -0
- albumentations/core/bbox_utils.py +920 -0
- albumentations/core/composition.py +1885 -0
- albumentations/core/hub_mixin.py +299 -0
- albumentations/core/keypoints_utils.py +521 -0
- albumentations/core/label_manager.py +339 -0
- albumentations/core/pydantic.py +239 -0
- albumentations/core/serialization.py +352 -0
- albumentations/core/transforms_interface.py +976 -0
- albumentations/core/type_definitions.py +127 -0
- albumentations/core/utils.py +605 -0
- albumentations/core/validation.py +129 -0
- albumentations/pytorch/__init__.py +1 -0
- albumentations/pytorch/transforms.py +189 -0
- nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
- nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
- nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
- nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
"""Transforms for text rendering and augmentation on images.
|
|
2
|
+
|
|
3
|
+
This module provides transforms for adding and manipulating text on images,
|
|
4
|
+
including text augmentation techniques like word insertion, deletion, and swapping.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import re
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Annotated, Any, Literal
|
|
12
|
+
|
|
13
|
+
import numpy as np
|
|
14
|
+
from pydantic import AfterValidator
|
|
15
|
+
|
|
16
|
+
import albumentations.augmentations.text.functional as ftext
|
|
17
|
+
from albumentations.core.bbox_utils import check_bboxes, denormalize_bboxes
|
|
18
|
+
from albumentations.core.pydantic import check_range_bounds, nondecreasing
|
|
19
|
+
from albumentations.core.transforms_interface import BaseTransformInitSchema, ImageOnlyTransform
|
|
20
|
+
|
|
21
|
+
# Public API of this module: only the TextImage transform is exported.
__all__ = ["TextImage"]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TextImage(ImageOnlyTransform):
    """Apply text rendering transformations on images.

    This class supports rendering text directly onto images using a variety of configurations,
    such as custom fonts, font sizes, colors, and augmentation methods. The text can be placed
    inside specified bounding boxes.

    Args:
        font_path (str | Path): Path to the font file to use for rendering text.
        stopwords (tuple[str, ...]): Stopwords that the "insertion" augmentation draws from.
        augmentations (tuple[str | None, ...]): Text augmentations to choose from. For every
            selected bounding box one entry is picked at random:
            None: text is printed as is
            "insertion": insert random stop words into the text.
            "swap": swap random words in the text.
            "deletion": delete random words from the text.
        fraction_range (tuple[float, float]): Range for selecting a fraction of bounding boxes to modify.
        font_size_fraction_range (tuple[float, float]): Range for selecting the font size as a fraction of
            bounding box height. The resulting size is clamped to at least 1 point.
        font_color (tuple[float, ...]): Font color as RGB values (e.g., (0, 0, 0) for black).
        clear_bg (bool): Whether to clear the background before rendering text.
        metadata_key (str): Key to access metadata in the parameters.
        p (float): Probability of applying the transform.

    Targets:
        image, volume

    Image types:
        uint8, float32

    References:
        doc-augmentation: https://github.com/danaaubakirova/doc-augmentation

    Examples:
        >>> from pathlib import Path
        >>> import albumentations as A
        >>> transform = A.Compose([
            A.TextImage(
                font_path=Path("/path/to/font.ttf"),
                stopwords=("the", "is", "in"),
                augmentations=("insertion", "deletion"),
                fraction_range=(0.5, 1.0),
                font_size_fraction_range=(0.5, 0.9),
                font_color=(255, 0, 0),  # red in RGB
                metadata_key="text_metadata",
                p=0.5
            )
        ])
        >>> transformed = transform(image=my_image, text_metadata=my_metadata)
        >>> image = transformed['image']
        # This will render text on `my_image` based on the metadata provided in `my_metadata`.

    """

    class InitSchema(BaseTransformInitSchema):
        font_path: str | Path
        stopwords: tuple[str, ...]
        augmentations: tuple[str | None, ...]
        fraction_range: Annotated[
            tuple[float, float],
            AfterValidator(nondecreasing),
            AfterValidator(check_range_bounds(0, 1)),
        ]
        font_size_fraction_range: Annotated[
            tuple[float, float],
            AfterValidator(nondecreasing),
            AfterValidator(check_range_bounds(0, 1)),
        ]
        font_color: tuple[float, ...]
        clear_bg: bool
        metadata_key: str

    def __init__(
        self,
        font_path: str | Path,
        stopwords: tuple[str, ...] = ("the", "is", "in", "at", "of"),
        augmentations: tuple[Literal["insertion", "swap", "deletion"] | None, ...] = (None,),
        fraction_range: tuple[float, float] = (1.0, 1.0),
        font_size_fraction_range: tuple[float, float] = (0.8, 0.9),
        font_color: tuple[float, ...] = (0, 0, 0),  # black in RGB
        clear_bg: bool = False,
        metadata_key: str = "textimage_metadata",
        p: float = 0.5,
    ) -> None:
        super().__init__(p=p)
        self.metadata_key = metadata_key
        self.font_path = font_path
        self.fraction_range = fraction_range
        self.stopwords = stopwords
        # Stored as a list so py_random.choice can pick from it directly.
        self.augmentations = list(augmentations)
        self.font_size_fraction_range = font_size_fraction_range
        self.font_color = font_color
        self.clear_bg = clear_bg

    @property
    def targets_as_params(self) -> list[str]:
        """Get list of targets that should be passed as parameters to transforms.

        Returns:
            list[str]: List containing the metadata key name

        """
        return [self.metadata_key]

    def random_aug(
        self,
        text: str,
        fraction: float,
        choice: Literal["insertion", "swap", "deletion"],
    ) -> str:
        """Apply a random text augmentation to the input text.

        Args:
            text (str): Original text to augment
            fraction (float): Fraction of words to modify (at least one word is always targeted)
            choice (Literal["insertion", "swap", "deletion"]): Type of augmentation to apply

        Returns:
            str: Augmented text (runs of spaces collapsed, ends stripped), or an empty
            string if the result equals the original text.

        Raises:
            ValueError: If an invalid choice is provided

        """
        words = [word for word in text.strip().split() if word]
        num_words = len(words)
        num_words_to_modify = max(1, int(fraction * num_words))

        if choice == "insertion":
            result_sentence = ftext.insert_random_stopwords(words, num_words_to_modify, self.stopwords, self.py_random)
        elif choice == "swap":
            result_sentence = ftext.swap_random_words(words, num_words_to_modify, self.py_random)
        elif choice == "deletion":
            result_sentence = ftext.delete_random_words(words, num_words_to_modify, self.py_random)
        else:
            raise ValueError("Invalid choice. Choose from 'insertion', 'swap', or 'deletion'.")

        result_sentence = re.sub(" +", " ", result_sentence).strip()
        return result_sentence if result_sentence != text else ""

    def preprocess_metadata(
        self,
        image: np.ndarray,
        bbox: tuple[float, float, float, float],
        text: str,
        bbox_index: int,
    ) -> dict[str, Any]:
        """Preprocess text metadata for a single bounding box.

        Args:
            image (np.ndarray): Input image
            bbox (tuple[float, float, float, float]): Normalized bounding box coordinates
            text (str): Text to render in the bounding box
            bbox_index (int): Index of the bounding box in the original metadata

        Returns:
            dict[str, Any]: Processed metadata including font, position, and text information

        Raises:
            ImportError: If PIL.ImageFont is not installed

        """
        try:
            from PIL import ImageFont
        except ImportError as err:
            raise ImportError(
                "ImageFont from PIL is required to use TextImage transform. Install it with `pip install Pillow`.",
            ) from err
        check_bboxes(np.array([bbox]))
        denormalized_bbox = denormalize_bboxes(np.array([bbox]), image.shape[:2])[0]

        x_min, y_min, x_max, y_max = (int(x) for x in denormalized_bbox[:4])
        bbox_height = y_max - y_min

        font_size_fraction = self.py_random.uniform(*self.font_size_fraction_range)

        # Truncation to int yields 0 for very short boxes (or a zero lower bound in
        # font_size_fraction_range); Pillow rejects non-positive font sizes, so clamp to 1.
        font_size = max(1, int(font_size_fraction * bbox_height))
        font = ImageFont.truetype(str(self.font_path), font_size)

        # self.augmentations is always a list (see __init__); an empty list means "render as is".
        if not self.augmentations:
            augmented_text = text
        else:
            augmentation = self.py_random.choice(self.augmentations)

            augmented_text = text if augmentation is None else self.random_aug(text, 0.5, choice=augmentation)

        font_color = self.font_color

        return {
            "bbox_coords": (x_min, y_min, x_max, y_max),
            "bbox_index": bbox_index,
            "original_text": text,
            "text": augmented_text,
            "font": font,
            "font_color": font_color,
        }

    def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
        """Generate parameters based on input data.

        Args:
            params (dict[str, Any]): Dictionary of existing parameters
            data (dict[str, Any]): Dictionary containing input data with image and metadata

        Returns:
            dict[str, Any]: Dictionary containing the overlay data for text rendering

        """
        # Volumes/batches arrive under "images"; the first one is used for sizing.
        image = data["image"] if "image" in data else data["images"][0]

        metadata = data[self.metadata_key]

        if metadata == []:
            return {
                "overlay_data": [],
            }

        # A single metadata entry may be passed as a bare dict.
        if isinstance(metadata, dict):
            metadata = [metadata]

        fraction = self.py_random.uniform(*self.fraction_range)

        num_bboxes_to_modify = int(len(metadata) * fraction)

        bbox_indices_to_update = self.py_random.sample(range(len(metadata)), num_bboxes_to_modify)

        overlay_data = [
            self.preprocess_metadata(image, metadata[bbox_index]["bbox"], metadata[bbox_index]["text"], bbox_index)
            for bbox_index in bbox_indices_to_update
        ]

        return {
            "overlay_data": overlay_data,
        }

    def apply(
        self,
        img: np.ndarray,
        overlay_data: list[dict[str, Any]],
        **params: Any,
    ) -> np.ndarray:
        """Apply text rendering to the input image.

        Args:
            img (np.ndarray): Input image
            overlay_data (list[dict[str, Any]]): List of dictionaries containing text rendering information
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Image with rendered text

        """
        return ftext.render_text(img, overlay_data, clear_bg=self.clear_bg)

    def apply_with_params(self, params: dict[str, Any], *args: Any, **kwargs: Any) -> dict[str, Any]:
        """Apply the transform and include overlay data in the result.

        Args:
            params (dict[str, Any]): Parameters for the transform
            *args (Any): Additional positional arguments
            **kwargs (Any): Additional keyword arguments

        Returns:
            dict[str, Any]: Dictionary containing the transformed data and simplified overlay information

        """
        res = super().apply_with_params(params, *args, **kwargs)
        # Expose a summary of what was rendered; the PIL font object is deliberately omitted.
        res["overlay_data"] = [
            {
                "bbox_coords": overlay["bbox_coords"],
                "text": overlay["text"],
                "original_text": overlay["original_text"],
                "bbox_index": overlay["bbox_index"],
                "font_color": overlay["font_color"],
            }
            for overlay in params["overlay_data"]
        ]

        return res
|
|
File without changes
|
|
@@ -0,0 +1,393 @@
|
|
|
1
|
+
"""Module containing functional implementations of 3D transformations.
|
|
2
|
+
|
|
3
|
+
This module provides a collection of utility functions for manipulating and transforming
|
|
4
|
+
3D volumetric data (such as medical imaging volumes). The functions here implement the core
|
|
5
|
+
algorithms for operations like padding, cropping, rotation, and other spatial manipulations
|
|
6
|
+
specifically designed for 3D data.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import random
|
|
12
|
+
from typing import Literal
|
|
13
|
+
|
|
14
|
+
import numpy as np
|
|
15
|
+
|
|
16
|
+
from albumentations.augmentations.utils import handle_empty_array
|
|
17
|
+
from albumentations.core.type_definitions import NUM_VOLUME_DIMENSIONS
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def adjust_padding_by_position3d(
    paddings: list[tuple[int, int]],  # [(front, back), (top, bottom), (left, right)]
    position: Literal["center", "random"],
    py_random: random.Random,
) -> tuple[int, int, int, int, int, int]:
    """Adjust padding values based on desired position for 3D data.

    Args:
        paddings (list[tuple[int, int]]): Padding pairs for each spatial axis, in the order
            [(d_front, d_back), (h_top, h_bottom), (w_left, w_right)].
        position (Literal["center", "random"]): Position of the data after padding.
            "center" returns ``paddings`` unchanged; any other value redistributes each
            axis's total padding between its two sides at random.
        py_random (random.Random): Random number generator.

    Returns:
        tuple[int, int, int, int, int, int]: Final padding values
        (d_front, d_back, h_top, h_bottom, w_left, w_right). For every axis the two
        values sum to the total of the corresponding input pair, so the padded output
        shape is the same for every position.

    """
    if position == "center":
        return (
            paddings[0][0],  # d_front
            paddings[0][1],  # d_back
            paddings[1][0],  # h_top
            paddings[1][1],  # h_bottom
            paddings[2][0],  # w_left
            paddings[2][1],  # w_right
        )

    # For random position, redistribute padding for each dimension
    d_pad = sum(paddings[0])
    h_pad = sum(paddings[1])
    w_pad = sum(paddings[2])

    # Draw only the leading side of each axis and derive the trailing side from it,
    # so front + back == total holds and the padded shape stays deterministic.
    d_front = py_random.randint(0, d_pad)
    h_top = py_random.randint(0, h_pad)
    w_left = py_random.randint(0, w_pad)

    return (
        d_front,  # d_front
        d_pad - d_front,  # d_back
        h_top,  # h_top
        h_pad - h_top,  # h_bottom
        w_left,  # w_left
        w_pad - w_left,  # w_right
    )
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def pad_3d_with_params(
    volume: np.ndarray,
    padding: tuple[int, int, int, int, int, int],
    value: tuple[float, ...] | float,
) -> np.ndarray:
    """Constant-pad the three spatial axes of a volume.

    Args:
        volume (np.ndarray): Volume shaped (depth, height, width) or
            (depth, height, width, channels).
        padding (tuple[int, int, int, int, int, int]): Amounts in the order
            (depth_front, depth_back, height_top, height_bottom, width_left, width_right).
            Each front/top/left value pads the start of its axis (smaller indices);
            each back/bottom/right value pads the end (larger indices).
        value (tuple[float, ...] | float): Fill value, forwarded to ``np.pad`` as
            ``constant_values``.

    Returns:
        np.ndarray: Padded volume with the same number of dimensions as the input.
        When every padding amount is zero, the input object itself is returned.

    Note:
        When ``volume.ndim == NUM_VOLUME_DIMENSIONS`` the last axis is treated as the
        channel axis and is never padded.
        NOTE(review): ``np.pad`` interprets a sequence ``constant_values`` as per-axis
        (before, after) values rather than per-channel fills, so a per-channel tuple
        such as (r, g, b) is presumably not supported here — confirm against callers.

    """
    front, back, top, bottom, left, right = padding

    # Nothing to pad: hand back the original array untouched.
    if not any(amount != 0 for amount in padding):
        return volume

    spatial_pads = [(front, back), (top, bottom), (left, right)]
    channel_pads = [(0, 0)] if volume.ndim == NUM_VOLUME_DIMENSIONS else []

    return np.pad(
        volume,
        pad_width=spatial_pads + channel_pads,
        mode="constant",
        constant_values=value,
    )
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def crop3d(
    volume: np.ndarray,
    crop_coords: tuple[int, int, int, int, int, int],
) -> np.ndarray:
    """Return the sub-volume selected by ``crop_coords``.

    Args:
        volume (np.ndarray): Input volume with shape (z, y, x) or (z, y, x, channels).
        crop_coords (tuple[int, int, int, int, int, int]): Half-open bounds
            (z_min, z_max, y_min, y_max, x_min, x_max) along the first three axes.

    Returns:
        np.ndarray: Cropped volume (a basic-slicing view of ``volume``) with the same
        number of dimensions as the input; any channel axis is kept whole.

    """
    z_lo, z_hi, y_lo, y_hi, x_lo, x_hi = crop_coords
    window = (slice(z_lo, z_hi), slice(y_lo, y_hi), slice(x_lo, x_hi))
    return volume[window]
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def cutout3d(volume: np.ndarray, holes: np.ndarray, fill: tuple[float, ...] | float) -> np.ndarray:
|
|
134
|
+
"""Cut out holes in 3D volume and fill them with a given value.
|
|
135
|
+
|
|
136
|
+
Args:
|
|
137
|
+
volume (np.ndarray): Input volume with shape (depth, height, width) or (depth, height, width, channels)
|
|
138
|
+
holes (np.ndarray): Array of holes with shape (num_holes, 6).
|
|
139
|
+
Each hole is represented as [z1, y1, x1, z2, y2, x2]
|
|
140
|
+
fill (tuple[float, ...] | float): Value to fill the holes
|
|
141
|
+
|
|
142
|
+
Returns:
|
|
143
|
+
np.ndarray: Volume with holes filled with the given value
|
|
144
|
+
|
|
145
|
+
"""
|
|
146
|
+
volume = volume.copy()
|
|
147
|
+
for z1, y1, x1, z2, y2, x2 in holes:
|
|
148
|
+
volume[z1:z2, y1:y2, x1:x2] = fill
|
|
149
|
+
return volume
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def transform_cube(cube: np.ndarray, index: int) -> np.ndarray:
    """Apply one of 48 axis-aligned cube transformations selected by ``index``.

    The table below labels indices 0-23 as rotations and 24-47 as the reflected
    versions of those rotations. Every entry is built from ``np.rot90``, axis
    transposes and reversed slices acting on the first three axes only; any
    trailing channel axis stays in place.

    Args:
        cube (np.ndarray): Input array with shape (D, H, W) or (D, H, W, C)
        index (int): Integer from 0 to 47 specifying which transformation to apply

    Returns:
        np.ndarray: Transformed array. Transformations that swap spatial axes permute
        (D, H, W), so the output shape equals the input shape only when the swapped
        extents are equal (always true when D == H == W). The input is copied before
        transforming, so the result never shares memory with ``cube``, though it may
        be a non-contiguous view of that copy.

    Raises:
        ValueError: If ``index`` is outside [0, 47].

    """
    # Validate explicitly: a bare dict lookup would raise KeyError instead of ValueError.
    if not (0 <= index < 48):
        raise ValueError("Index must be between 0 and 47")

    # NOTE(review): this table of 48 lambdas is rebuilt on every call; it could be hoisted.
    transformations = {
        # First 4: rotate around axis 0 (indices 0-3)
        0: lambda x: x,
        1: lambda x: np.rot90(x, k=1, axes=(1, 2)),
        2: lambda x: np.rot90(x, k=2, axes=(1, 2)),
        3: lambda x: np.rot90(x, k=3, axes=(1, 2)),
        # Next 4: flip 180° about axis 1, then rotate around axis 0 (indices 4-7)
        4: lambda x: x[::-1, :, ::-1],  # was: np.flip(x, axis=(0, 2))
        5: lambda x: np.rot90(np.rot90(x, k=2, axes=(0, 2)), k=1, axes=(1, 2)),
        6: lambda x: x[::-1, ::-1, :],  # was: np.flip(x, axis=(0, 1))
        7: lambda x: np.rot90(np.rot90(x, k=2, axes=(0, 2)), k=3, axes=(1, 2)),
        # Next 8: split between 90° and 270° about axis 1, then rotate around axis 2 (indices 8-15)
        8: lambda x: np.rot90(x, k=1, axes=(0, 2)),
        9: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 2)), k=1, axes=(0, 1)),
        10: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 2)), k=2, axes=(0, 1)),
        11: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim)),
        12: lambda x: np.rot90(x, k=-1, axes=(0, 2)),
        13: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 2)), k=1, axes=(0, 1)),
        14: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 2)), k=2, axes=(0, 1)),
        15: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 2)), k=3, axes=(0, 1)),
        # Final 8: split between rotations about axis 2, then rotate around axis 1 (indices 16-23)
        16: lambda x: np.rot90(x, k=1, axes=(0, 1)),
        17: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 1)), k=1, axes=(0, 2)),
        18: lambda x: np.rot90(np.rot90(x, k=1, axes=(0, 1)), k=2, axes=(0, 2)),
        19: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim)),
        20: lambda x: np.rot90(x, k=-1, axes=(0, 1)),
        21: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 1)), k=1, axes=(0, 2)),
        22: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 1)), k=2, axes=(0, 2)),
        23: lambda x: np.rot90(np.rot90(x, k=-1, axes=(0, 1)), k=3, axes=(0, 2)),
        # Reflected versions (24-47) - same as above but with initial reflection
        24: lambda x: x[:, :, ::-1],  # was: np.flip(x, axis=2)
        25: lambda x: x.transpose(0, 2, 1, *range(3, x.ndim)),
        26: lambda x: x[:, ::-1, :],  # was: np.flip(x, axis=1)
        27: lambda x: np.rot90(x[:, :, ::-1], k=3, axes=(1, 2)),
        28: lambda x: x[::-1, :, :],  # was: np.flip(x, axis=0)
        29: lambda x: np.rot90(x[::-1, :, :], k=1, axes=(1, 2)),
        30: lambda x: x[::-1, ::-1, ::-1],  # was: np.flip(x, axis=(0, 1, 2))
        31: lambda x: np.rot90(x[::-1, :, :], k=-1, axes=(1, 2)),
        32: lambda x: x.transpose(2, 1, 0, *range(3, x.ndim)),
        33: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[::-1, :, :],
        34: lambda x: x.transpose(2, 1, 0, *range(3, x.ndim))[::-1, ::-1, :],
        35: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[:, ::-1, :],
        36: lambda x: np.rot90(x[:, :, ::-1], k=-1, axes=(0, 2)),
        37: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[::-1, ::-1, ::-1],
        38: lambda x: x.transpose(2, 1, 0, *range(3, x.ndim))[:, ::-1, ::-1],
        39: lambda x: x.transpose(1, 2, 0, *range(3, x.ndim))[:, :, ::-1],
        40: lambda x: np.rot90(x[:, :, ::-1], k=1, axes=(0, 1)),
        41: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[:, :, ::-1],
        42: lambda x: x.transpose(1, 0, 2, *range(3, x.ndim)),
        43: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[::-1, :, :],
        44: lambda x: np.rot90(x[:, :, ::-1], k=-1, axes=(0, 1)),
        45: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[:, ::-1, :],
        46: lambda x: x.transpose(1, 0, 2, *range(3, x.ndim))[::-1, ::-1, :],
        47: lambda x: x.transpose(2, 0, 1, *range(3, x.ndim))[::-1, ::-1, ::-1],
    }

    # Copy first so the returned views never alias the caller's array.
    return transformations[index](cube.copy())
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
@handle_empty_array("keypoints")
def filter_keypoints_in_holes3d(keypoints: np.ndarray, holes: np.ndarray) -> np.ndarray:
    """Drop keypoints that fall inside any 3D hole.

    A keypoint is inside a hole when, on every axis, start <= coordinate < end
    (half-open box).

    Args:
        keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
            The first three columns are x, y, z coordinates.
        holes (np.ndarray): Array of holes with shape (num_holes, 6).
            Each hole is represented as [z1, y1, x1, z2, y2, x2].

    Returns:
        np.ndarray: The rows of ``keypoints`` (all columns) that lie outside every hole,
        with the input dtype. If ``holes`` is empty the input array is returned as-is.

    """
    if holes.size == 0:
        return keypoints

    # Reorder keypoint columns to (z, y, x) so they line up with the hole layout.
    zyx = keypoints[:, [2, 1, 0]]
    hole_starts = holes[:, [0, 1, 2]]
    hole_ends = holes[:, [3, 4, 5]]

    # Broadcast (num_keypoints, 1, 3) against (num_holes, 3) -> (num_keypoints, num_holes, 3).
    candidate = zyx[:, np.newaxis, :]
    within = (candidate >= hole_starts) & (candidate < hole_ends)
    covered = within.all(axis=2).any(axis=1)

    survivors = keypoints[~covered]
    if len(survivors) == 0:
        # Keep a well-formed (0, num_columns) array of the input dtype.
        return np.array([], dtype=keypoints.dtype).reshape(0, keypoints.shape[1])
    return survivors
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def keypoints_rot90(
    keypoints: np.ndarray,
    k: int,
    axes: tuple[int, int],
    volume_shape: tuple[int, int, int],
) -> np.ndarray:
    """Rotate keypoints by ``k`` quarter turns in the plane spanned by ``axes``.

    Column ``i`` of ``keypoints`` is treated as the coordinate along volume axis ``i``,
    and the mapping matches ``np.rot90(volume, k=k, axes=axes)`` applied to a volume of
    shape ``volume_shape`` (voxel-index convention, hence the ``- 1`` terms).

    Args:
        keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
        k (int): Number of counterclockwise quarter turns; any integer, taken modulo 4.
        axes (tuple[int, int]): The two axes spanning the rotation plane.
        volume_shape (tuple[int, int, int]): Shape of the volume before rotation.

    Returns:
        np.ndarray: Rotated keypoints with the same shape as the input. When ``k == 0`` or
        there are no keypoints the input object itself is returned; otherwise a copy.

    """
    if k == 0 or len(keypoints) == 0:
        return keypoints

    turns = k % 4  # Python's modulo is already non-negative for a positive divisor
    rotated = keypoints.copy()

    ax_a, ax_b = axes
    last_a = volume_shape[ax_a] - 1
    last_b = volume_shape[ax_b] - 1
    col_a = keypoints[:, ax_a].copy()
    col_b = keypoints[:, ax_b].copy()

    if turns == 1:  # 90 degrees CCW
        new_a, new_b = last_b - col_b, col_a
    elif turns == 2:  # 180 degrees
        new_a, new_b = last_a - col_a, last_b - col_b
    elif turns == 3:  # 270 degrees CCW
        new_a, new_b = col_b, last_a - col_a
    else:  # full turns: coordinates are unchanged
        return rotated

    rotated[:, ax_a] = new_a
    rotated[:, ax_b] = new_b
    return rotated
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
@handle_empty_array("keypoints")
def transform_cube_keypoints(
    keypoints: np.ndarray,
    index: int,
    volume_shape: tuple[int, int, int],
) -> np.ndarray:
    """Map keypoints through the cube transformation selected by ``index``.

    Keypoint counterpart of ``transform_cube``. Indices 24-47 first mirror the width
    (x) axis. ``index % 24`` then selects a fixed "pre" rotation followed by 0-3
    quarter turns in a second plane, following the same rotation sequences as
    ``transform_cube``.

    Args:
        keypoints (np.ndarray): Array of keypoints with shape (num_keypoints, 3+).
            The first three columns are x, y, z coordinates; further columns are
            carried through unchanged.
        index (int): Integer from 0 to 47 specifying which transformation to apply.
        volume_shape (tuple[int, int, int]): Shape of the volume (depth, height, width).

    Returns:
        np.ndarray: Transformed keypoints with the same shape as the input.

    Raises:
        ValueError: If ``index`` is outside [0, 47].

    """
    if not (0 <= index < 48):
        raise ValueError("Index must be between 0 and 47")

    # index % 24 == 4 * group + quarter_turns. Each group is one "pre" rotation
    # (pre_k, pre_axes) followed by `quarter_turns` turns in the post_axes plane.
    rotation_plan = (
        (0, (1, 2), (1, 2)),  # 0-3: turns in the H-W plane only
        (2, (0, 2), (1, 2)),  # 4-7: half turn in D-W, then H-W turns
        (1, (0, 2), (0, 1)),  # 8-11: quarter turn in D-W, then D-H turns
        (3, (0, 2), (0, 1)),  # 12-15: three quarter turns in D-W, then D-H turns
        (1, (0, 1), (0, 2)),  # 16-19: quarter turn in D-H, then D-W turns
        (3, (0, 1), (0, 2)),  # 20-23: three quarter turns in D-H, then D-W turns
    )

    # Work on a copy whose first three columns are reordered from (x, y, z) to
    # (z, y, x), so column i indexes volume axis i (depth, height, width).
    points = keypoints.copy()
    points[:, :3] = points[:, [2, 1, 0]]

    if index >= 24:
        # Reflected half of the table: mirror the width axis before rotating.
        points[:, 2] = volume_shape[2] - 1 - points[:, 2]

    group, quarter_turns = divmod(index % 24, 4)
    pre_k, pre_axes, post_axes = rotation_plan[int(group)]

    shape = tuple(volume_shape)
    if pre_k:
        points = keypoints_rot90(points, k=pre_k, axes=pre_axes, volume_shape=shape)
        if pre_k % 2:
            # An odd number of quarter turns swaps the extents of the two rotated axes.
            dims = list(shape)
            dims[pre_axes[0]], dims[pre_axes[1]] = dims[pre_axes[1]], dims[pre_axes[0]]
            shape = tuple(dims)

    result = keypoints_rot90(points, k=quarter_turns, axes=post_axes, volume_shape=shape)

    # Restore the (x, y, z) column order for the first three columns only.
    result[:, :3] = result[:, [2, 1, 0]]
    return result
|