nrtk-albumentations 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nrtk-albumentations might be problematic. Click here for more details.
- albumentations/__init__.py +21 -0
- albumentations/augmentations/__init__.py +23 -0
- albumentations/augmentations/blur/__init__.py +0 -0
- albumentations/augmentations/blur/functional.py +438 -0
- albumentations/augmentations/blur/transforms.py +1633 -0
- albumentations/augmentations/crops/__init__.py +0 -0
- albumentations/augmentations/crops/functional.py +494 -0
- albumentations/augmentations/crops/transforms.py +3647 -0
- albumentations/augmentations/dropout/__init__.py +0 -0
- albumentations/augmentations/dropout/channel_dropout.py +134 -0
- albumentations/augmentations/dropout/coarse_dropout.py +567 -0
- albumentations/augmentations/dropout/functional.py +1017 -0
- albumentations/augmentations/dropout/grid_dropout.py +166 -0
- albumentations/augmentations/dropout/mask_dropout.py +274 -0
- albumentations/augmentations/dropout/transforms.py +461 -0
- albumentations/augmentations/dropout/xy_masking.py +186 -0
- albumentations/augmentations/geometric/__init__.py +0 -0
- albumentations/augmentations/geometric/distortion.py +1238 -0
- albumentations/augmentations/geometric/flip.py +752 -0
- albumentations/augmentations/geometric/functional.py +4151 -0
- albumentations/augmentations/geometric/pad.py +676 -0
- albumentations/augmentations/geometric/resize.py +956 -0
- albumentations/augmentations/geometric/rotate.py +864 -0
- albumentations/augmentations/geometric/transforms.py +1962 -0
- albumentations/augmentations/mixing/__init__.py +0 -0
- albumentations/augmentations/mixing/domain_adaptation.py +787 -0
- albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
- albumentations/augmentations/mixing/functional.py +878 -0
- albumentations/augmentations/mixing/transforms.py +832 -0
- albumentations/augmentations/other/__init__.py +0 -0
- albumentations/augmentations/other/lambda_transform.py +180 -0
- albumentations/augmentations/other/type_transform.py +261 -0
- albumentations/augmentations/pixel/__init__.py +0 -0
- albumentations/augmentations/pixel/functional.py +4226 -0
- albumentations/augmentations/pixel/transforms.py +7556 -0
- albumentations/augmentations/spectrogram/__init__.py +0 -0
- albumentations/augmentations/spectrogram/transform.py +220 -0
- albumentations/augmentations/text/__init__.py +0 -0
- albumentations/augmentations/text/functional.py +272 -0
- albumentations/augmentations/text/transforms.py +299 -0
- albumentations/augmentations/transforms3d/__init__.py +0 -0
- albumentations/augmentations/transforms3d/functional.py +393 -0
- albumentations/augmentations/transforms3d/transforms.py +1422 -0
- albumentations/augmentations/utils.py +249 -0
- albumentations/core/__init__.py +0 -0
- albumentations/core/bbox_utils.py +920 -0
- albumentations/core/composition.py +1885 -0
- albumentations/core/hub_mixin.py +299 -0
- albumentations/core/keypoints_utils.py +521 -0
- albumentations/core/label_manager.py +339 -0
- albumentations/core/pydantic.py +239 -0
- albumentations/core/serialization.py +352 -0
- albumentations/core/transforms_interface.py +976 -0
- albumentations/core/type_definitions.py +127 -0
- albumentations/core/utils.py +605 -0
- albumentations/core/validation.py +129 -0
- albumentations/pytorch/__init__.py +1 -0
- albumentations/pytorch/transforms.py +189 -0
- nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
- nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
- nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
- nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,832 @@
|
|
|
1
|
+
"""Transforms that combine multiple images and their associated annotations.
|
|
2
|
+
|
|
3
|
+
This module contains transformations that take multiple input sources (e.g., a primary image
|
|
4
|
+
and additional images provided via metadata) and combine them into a single output.
|
|
5
|
+
Examples include overlaying elements (`OverlayElements`) or creating complex compositions
|
|
6
|
+
like `Mosaic`.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import random
|
|
12
|
+
from copy import deepcopy
|
|
13
|
+
from typing import Annotated, Any, Literal, cast
|
|
14
|
+
|
|
15
|
+
import cv2
|
|
16
|
+
import numpy as np
|
|
17
|
+
from pydantic import AfterValidator, model_validator
|
|
18
|
+
from typing_extensions import Self
|
|
19
|
+
|
|
20
|
+
from albumentations.augmentations.mixing import functional as fmixing
|
|
21
|
+
from albumentations.core.bbox_utils import BboxProcessor, check_bboxes, denormalize_bboxes, filter_bboxes
|
|
22
|
+
from albumentations.core.keypoints_utils import KeypointsProcessor
|
|
23
|
+
from albumentations.core.pydantic import check_range_bounds, nondecreasing
|
|
24
|
+
from albumentations.core.transforms_interface import BaseTransformInitSchema, DualTransform
|
|
25
|
+
from albumentations.core.type_definitions import LENGTH_RAW_BBOX, Targets
|
|
26
|
+
|
|
27
|
+
__all__ = ["Mosaic", "OverlayElements"]
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class OverlayElements(DualTransform):
    """Apply overlay elements such as images and masks onto an input image.

    This transformation can be used to add various objects (e.g., stickers, logos) to images, with optional
    masks and bounding boxes for better placement control.

    Args:
        metadata_key (str): Additional target key for metadata. Default: `overlay_metadata`.
        p (float): Probability of applying the transformation. Default: 0.5.

    Possible Metadata Fields:
        - image (np.ndarray): The overlay image to be applied. This is a required field.
        - bbox (list[float]): The bounding box specifying the region where the overlay should be applied.
            Its first four elements are [x_min, y_min, x_max, y_max] in Albumentations format, that is
            the normalized Pascal VOC format
            [x_min / width, y_min / height, x_max / width, y_max / height].
            Any elements after the first four (e.g. a label id) are kept in the processed bbox.
            If omitted, the overlay is placed at a random position; an overlay that is larger than the
            image in either dimension is first resized to the image size.
        - bbox_id (Any): An optional identifier appended as the fifth element of the processed bbox when
            `bbox` has exactly four elements or is not provided.
        - mask (np.ndarray): An optional mask that defines the non-rectangular region of the overlay image.
            It is resized (nearest neighbor) to match the placed overlay when needed. If not provided,
            the entire overlay image is used.
        - mask_id (int): An optional identifier for the mask. If provided, the regions specified by the mask
            will be labeled with this identifier in the output mask.

    Targets:
        image, mask

    Image types:
        uint8, float32

    Note:
        The input mask array is not modified in place; a labeled copy is returned.

    References:
        doc-augmentation: https://github.com/danaaubakirova/doc-augmentation

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>> import cv2
        >>>
        >>> # Prepare primary data (base image and mask)
        >>> image = np.zeros((300, 300, 3), dtype=np.uint8)
        >>> mask = np.zeros((300, 300), dtype=np.uint8)
        >>>
        >>> # 1. Create a simple overlay image (a red square)
        >>> overlay_image1 = np.zeros((50, 50, 3), dtype=np.uint8)
        >>> overlay_image1[:, :, 0] = 255  # Red color
        >>>
        >>> # 2. Create another overlay with a mask (a blue circle with transparency)
        >>> overlay_image2 = np.zeros((80, 80, 3), dtype=np.uint8)
        >>> overlay_image2[:, :, 2] = 255  # Blue color
        >>> overlay_mask2 = np.zeros((80, 80), dtype=np.uint8)
        >>> # Create a circular mask
        >>> center = (40, 40)
        >>> radius = 30
        >>> for i in range(80):
        ...     for j in range(80):
        ...         if (i - center[0])**2 + (j - center[1])**2 < radius**2:
        ...             overlay_mask2[i, j] = 255
        >>>
        >>> # 3. Create an overlay with both bbox and mask_id
        >>> overlay_image3 = np.zeros((60, 120, 3), dtype=np.uint8)
        >>> overlay_image3[:, :, 1] = 255  # Green color
        >>> # Create a rectangular mask with rounded corners
        >>> overlay_mask3 = np.zeros((60, 120), dtype=np.uint8)
        >>> cv2.rectangle(overlay_mask3, (10, 10), (110, 50), 255, -1)
        >>>
        >>> # Create the metadata list - each item is a dictionary with overlay information
        >>> overlay_metadata = [
        ...     {
        ...         'image': overlay_image1,
        ...         # No bbox provided - will be placed randomly
        ...     },
        ...     {
        ...         'image': overlay_image2,
        ...         'bbox': [0.6, 0.1, 0.9, 0.4],  # Normalized coordinates [x_min, y_min, x_max, y_max]
        ...         'mask': overlay_mask2,
        ...         'mask_id': 1  # This overlay will update the mask with id 1
        ...     },
        ...     {
        ...         'image': overlay_image3,
        ...         'bbox': [0.1, 0.7, 0.5, 0.9],  # Bottom left placement
        ...         'mask': overlay_mask3,
        ...         'mask_id': 2  # This overlay will update the mask with id 2
        ...     }
        ... ]
        >>>
        >>> # Create the transform
        >>> transform = A.Compose([
        ...     A.OverlayElements(p=1.0),
        ... ])
        >>>
        >>> # Apply the transform
        >>> result = transform(
        ...     image=image,
        ...     mask=mask,
        ...     overlay_metadata=overlay_metadata  # Pass metadata using the default key
        ... )
        >>>
        >>> # Get results with overlays applied
        >>> result_image = result['image']  # Image with the three overlays applied
        >>> result_mask = result['mask']  # Mask with regions labeled using the mask_id values
        >>>
        >>> # Let's verify the mask contains the specified mask_id values
        >>> has_mask_id_1 = np.any(result_mask == 1)  # Should be True
        >>> has_mask_id_2 = np.any(result_mask == 2)  # Should be True

    """

    _targets = (Targets.IMAGE, Targets.MASK)

    class InitSchema(BaseTransformInitSchema):
        metadata_key: str

    def __init__(
        self,
        metadata_key: str = "overlay_metadata",
        p: float = 0.5,
    ):
        super().__init__(p=p)
        self.metadata_key = metadata_key

    @property
    def targets_as_params(self) -> list[str]:
        """Get list of targets that should be passed as parameters to transforms.

        Returns:
            list[str]: List containing the metadata key name

        """
        return [self.metadata_key]

    @staticmethod
    def preprocess_metadata(
        metadata: dict[str, Any],
        img_shape: tuple[int, int],
        random_state: random.Random,
    ) -> dict[str, Any]:
        """Process overlay metadata to prepare for application.

        With a ``bbox``, the overlay image and mask are resized to the denormalized bbox and placed at
        its top-left corner. Without a ``bbox``, the overlay is resized to the image size if it does not
        fit, and a random offset inside the image is drawn from ``random_state``.

        Args:
            metadata (dict[str, Any]): Dictionary containing overlay data such as image, mask, bbox
            img_shape (tuple[int, int]): Shape of the target image as (height, width)
            random_state (random.Random): Random state object for reproducible randomness

        Returns:
            dict[str, Any]: Processed overlay data with keys ``overlay_image``, ``overlay_mask`` (always
            matching the overlay's height and width), ``offset`` as (y, x), ``bbox`` in pixels, and
            ``mask_id`` when present in ``metadata``.

        """
        overlay_image = metadata["image"]
        overlay_height, overlay_width = overlay_image.shape[:2]
        image_height, image_width = img_shape[:2]

        if "bbox" in metadata:
            bbox = metadata["bbox"]
            bbox_np = np.array([bbox])
            check_bboxes(bbox_np)
            denormalized_bbox = denormalize_bboxes(bbox_np, img_shape[:2])[0]

            x_min, y_min, x_max, y_max = (int(x) for x in denormalized_bbox[:4])
            # cv2.resize takes the destination size as (width, height).
            placed_size = (x_max - x_min, y_max - y_min)

            if "mask" in metadata:
                mask = cv2.resize(metadata["mask"], placed_size, interpolation=cv2.INTER_NEAREST)
            else:
                mask = np.ones((y_max - y_min, x_max - x_min), dtype=np.uint8)

            overlay_image = cv2.resize(overlay_image, placed_size, interpolation=cv2.INTER_AREA)
            offset = (y_min, x_min)

            if len(bbox) == LENGTH_RAW_BBOX and "bbox_id" in metadata:
                bbox = [x_min, y_min, x_max, y_max, metadata["bbox_id"]]
            else:
                bbox = (x_min, y_min, x_max, y_max, *bbox[4:])
        else:
            if image_height < overlay_height or image_width < overlay_width:
                overlay_image = cv2.resize(overlay_image, (image_width, image_height), interpolation=cv2.INTER_AREA)
                overlay_height, overlay_width = overlay_image.shape[:2]

            if "mask" in metadata:
                mask = metadata["mask"]
                # Keep the mask aligned with the (possibly resized) overlay image; otherwise the
                # blend/label steps would index the overlay and the target with mismatched shapes.
                if mask.shape[:2] != (overlay_height, overlay_width):
                    mask = cv2.resize(mask, (overlay_width, overlay_height), interpolation=cv2.INTER_NEAREST)
            else:
                # 2D like the bbox branch: a channel axis here would break boolean indexing of a
                # single-channel target mask in apply_to_mask.
                mask = np.ones((overlay_height, overlay_width), dtype=np.uint8)

            max_x_offset = image_width - overlay_width
            max_y_offset = image_height - overlay_height

            offset_x = random_state.randint(0, max_x_offset)
            offset_y = random_state.randint(0, max_y_offset)

            offset = (offset_y, offset_x)

            bbox = [
                offset_x,
                offset_y,
                offset_x + overlay_width,
                offset_y + overlay_height,
            ]

            if "bbox_id" in metadata:
                bbox = [*bbox, metadata["bbox_id"]]

        result = {
            "overlay_image": overlay_image,
            "overlay_mask": mask,
            "offset": offset,
            "bbox": bbox,
        }

        if "mask_id" in metadata:
            result["mask_id"] = metadata["mask_id"]

        return result

    def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
        """Generate parameters for overlay transform based on input data.

        Args:
            params (dict[str, Any]): Dictionary of existing parameters; ``params["shape"]`` is used
                as the target image shape.
            data (dict[str, Any]): Dictionary containing input data with image and metadata. The value
                under ``self.metadata_key`` may be a single metadata dict or a list of them.

        Returns:
            dict[str, Any]: ``{"overlay_data": [...]}`` with one processed entry per metadata dict.

        """
        metadata = data[self.metadata_key]
        img_shape = params["shape"]

        if isinstance(metadata, list):
            overlay_data = [self.preprocess_metadata(md, img_shape, self.py_random) for md in metadata]
        else:
            overlay_data = [self.preprocess_metadata(metadata, img_shape, self.py_random)]

        return {
            "overlay_data": overlay_data,
        }

    def apply(
        self,
        img: np.ndarray,
        overlay_data: list[dict[str, Any]],
        **params: Any,
    ) -> np.ndarray:
        """Apply overlay elements to the input image.

        Args:
            img (np.ndarray): Input image
            overlay_data (list[dict[str, Any]]): List of dictionaries containing overlay information
            **params (Any): Additional parameters

        Returns:
            np.ndarray: Image with overlays applied, in list order

        """
        for data in overlay_data:
            img = fmixing.copy_and_paste_blend(img, data["overlay_image"], data["overlay_mask"], offset=data["offset"])
        return img

    def apply_to_mask(
        self,
        mask: np.ndarray,
        overlay_data: list[dict[str, Any]],
        **params: Any,
    ) -> np.ndarray:
        """Apply overlay masks to the input mask.

        Pixels covered by an overlay mask (value > 0) are set to that overlay's ``mask_id``; overlays
        without a ``mask_id`` (or with ``mask_id=None``) leave the mask unchanged.

        Args:
            mask (np.ndarray): Input mask. It is not modified in place.
            overlay_data (list[dict[str, Any]]): List of dictionaries containing overlay information
            **params (Any): Additional parameters

        Returns:
            np.ndarray: A copy of the mask with overlay regions labeled using the specified mask_id values

        """
        # Work on a copy: writing through slice views below would otherwise mutate the caller's array.
        mask = mask.copy()
        for data in overlay_data:
            mask_id = data.get("mask_id")
            if mask_id is None:
                continue

            overlay_mask = data["overlay_mask"]
            y_min, x_min = data["offset"]
            y_max = y_min + overlay_mask.shape[0]
            x_max = x_min + overlay_mask.shape[1]

            mask_section = mask[y_min:y_max, x_min:x_max]
            mask_section[overlay_mask > 0] = mask_id

        return mask
|
|
315
|
+
|
|
316
|
+
|
|
317
|
+
class Mosaic(DualTransform):
|
|
318
|
+
"""Combine multiple images and their annotations into a single image using a mosaic grid layout.
|
|
319
|
+
|
|
320
|
+
This transform takes a primary input image (and its annotations) and combines it with
|
|
321
|
+
additional images/annotations provided via metadata. It calculates the geometry for
|
|
322
|
+
a mosaic grid, selects additional items, preprocesses annotations consistently
|
|
323
|
+
(handling label encoding updates), applies geometric transformations, and assembles
|
|
324
|
+
the final output.
|
|
325
|
+
|
|
326
|
+
Args:
|
|
327
|
+
grid_yx (tuple[int, int]): The number of rows (y) and columns (x) in the mosaic grid.
|
|
328
|
+
Determines the maximum number of images involved (grid_yx[0] * grid_yx[1]).
|
|
329
|
+
Default: (2, 2).
|
|
330
|
+
target_size (tuple[int, int]): The desired output (height, width) of the final mosaic image,
|
|
331
|
+
after cropping the mosaic grid.
|
|
332
|
+
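cell_shape (tuple[int, int]): The (height, width) of each cell in the mosaic grid. Default: (512, 512).
fit_mode (Literal["cover", "contain"]): Mode used to fit each image into its cell during geometric
processing. Default: "cover".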
|
|
333
|
+
metadata_key (str): Key in the input dictionary specifying the list of additional data dictionaries
|
|
334
|
+
for the mosaic. Each dictionary in the list should represent one potential additional item.
|
|
335
|
+
Expected keys: 'image' (required, np.ndarray), and optionally 'mask' (np.ndarray),
|
|
336
|
+
'bboxes' (np.ndarray), 'keypoints' (np.ndarray), and any relevant label fields
|
|
337
|
+
(e.g., 'class_labels') corresponding to those specified in `Compose`'s `bbox_params` or
|
|
338
|
+
`keypoint_params`. Default: "mosaic_metadata".
|
|
339
|
+
center_range (tuple[float, float]): Range [0.0-1.0] to sample the center point of the mosaic view
|
|
340
|
+
relative to the valid central region of the conceptual large grid. This affects which parts
|
|
341
|
+
of the assembled grid are visible in the final crop. Default: (0.3, 0.7).
|
|
342
|
+
interpolation (int): OpenCV interpolation flag used for resizing images during geometric processing.
|
|
343
|
+
Default: cv2.INTER_LINEAR.
|
|
344
|
+
mask_interpolation (int): OpenCV interpolation flag used for resizing masks during geometric processing.
|
|
345
|
+
Default: cv2.INTER_NEAREST.
|
|
346
|
+
fill (tuple[float, ...] | float): Value used for padding images if needed during geometric processing.
|
|
347
|
+
Default: 0.
|
|
348
|
+
fill_mask (tuple[float, ...] | float): Value used for padding masks if needed during geometric processing.
|
|
349
|
+
Default: 0.
|
|
350
|
+
p (float): Probability of applying the transform. Default: 0.5.
|
|
351
|
+
|
|
352
|
+
Workflow (`get_params_dependent_on_data`):
|
|
353
|
+
1. Calculate Geometry & Visible Cells: Determine which grid cells are visible in the final
|
|
354
|
+
`target_size` crop and their placement coordinates on the output canvas.
|
|
355
|
+
2. Validate Raw Additional Metadata: Filter the list provided via `metadata_key`,
|
|
356
|
+
keeping only valid items (dicts with an 'image' key).
|
|
357
|
+
3. Select Subset of Raw Additional Metadata: Choose a subset of the valid raw items based
|
|
358
|
+
on the number of visible cells requiring additional data.
|
|
359
|
+
4. Preprocess Selected Raw Additional Items: Preprocess bboxes/keypoints for the *selected*
|
|
360
|
+
additional items *only*. This uses shared processors from `Compose`, updating their
|
|
361
|
+
internal state (e.g., `LabelEncoder`) based on labels in these selected items.
|
|
362
|
+
5. Prepare Primary Data: Extract preprocessed primary data fields from the input `data` dictionary
|
|
363
|
+
into a `primary` dictionary.
|
|
364
|
+
6. Determine & Perform Replication: If fewer additional items were selected than needed,
|
|
365
|
+
replicate the preprocessed primary data as required.
|
|
366
|
+
7. Combine Final Items: Create the list of all preprocessed items (primary, selected additional,
|
|
367
|
+
replicated primary) that will be used.
|
|
368
|
+
8. Assign Items to VISIBLE Grid Cells
|
|
369
|
+
9. Process Geometry & Shift Coordinates: For each assigned item:
|
|
370
|
+
a. Apply geometric transforms (Crop, Resize, Pad) to image/mask.
|
|
371
|
+
b. Apply geometric shift to the *preprocessed* bboxes/keypoints based on cell placement.
|
|
372
|
+
10. Return Parameters: Return the processed cell data (image, mask, shifted bboxes, shifted kps)
|
|
373
|
+
keyed by placement coordinates.
|
|
374
|
+
|
|
375
|
+
Label Handling:
|
|
376
|
+
- The transform relies on `bbox_processor` and `keypoint_processor` provided by `Compose`.
|
|
377
|
+
- `Compose.preprocess` initially fits the processors' `LabelEncoder` on the primary data.
|
|
378
|
+
- This transform (`Mosaic`) preprocesses the *selected* additional raw items using the same
|
|
379
|
+
processors. If new labels are found, the shared `LabelEncoder` state is updated via its
|
|
380
|
+
`update` method.
|
|
381
|
+
- `Compose.postprocess` uses the final updated encoder state to decode all labels present
|
|
382
|
+
in the mosaic output for the current `Compose` call.
|
|
383
|
+
- The encoder state is transient per `Compose` call.
|
|
384
|
+
|
|
385
|
+
Targets:
|
|
386
|
+
image, mask, bboxes, keypoints
|
|
387
|
+
|
|
388
|
+
Image types:
|
|
389
|
+
uint8, float32
|
|
390
|
+
|
|
391
|
+
Reference:
|
|
392
|
+
YOLOv4: Optimal Speed and Accuracy of Object Detection: https://arxiv.org/pdf/2004.10934
|
|
393
|
+
|
|
394
|
+
Examples:
|
|
395
|
+
>>> import numpy as np
|
|
396
|
+
>>> import albumentations as A
|
|
397
|
+
>>> import cv2
|
|
398
|
+
>>>
|
|
399
|
+
>>> # Prepare primary data
|
|
400
|
+
>>> primary_image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
|
|
401
|
+
>>> primary_mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
|
|
402
|
+
>>> primary_bboxes = np.array([[10, 10, 40, 40], [50, 50, 90, 90]], dtype=np.float32)
|
|
403
|
+
>>> primary_labels = [1, 2]
|
|
404
|
+
>>>
|
|
405
|
+
>>> # Prepare additional images for mosaic
|
|
406
|
+
>>> additional_image1 = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
|
|
407
|
+
>>> additional_mask1 = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
|
|
408
|
+
>>> additional_bboxes1 = np.array([[20, 20, 60, 60]], dtype=np.float32)
|
|
409
|
+
>>> additional_labels1 = [3]
|
|
410
|
+
>>>
|
|
411
|
+
>>> additional_image2 = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
|
|
412
|
+
>>> additional_mask2 = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
|
|
413
|
+
>>> additional_bboxes2 = np.array([[30, 30, 70, 70]], dtype=np.float32)
|
|
414
|
+
>>> additional_labels2 = [4]
|
|
415
|
+
>>>
|
|
416
|
+
>>> additional_image3 = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
|
|
417
|
+
>>> additional_mask3 = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
|
|
418
|
+
>>> additional_bboxes3 = np.array([[5, 5, 45, 45]], dtype=np.float32)
|
|
419
|
+
>>> additional_labels3 = [5]
|
|
420
|
+
>>>
|
|
421
|
+
>>> # Create metadata for additional images - structured as a list of dicts
|
|
422
|
+
>>> mosaic_metadata = [
|
|
423
|
+
... {
|
|
424
|
+
... 'image': additional_image1,
|
|
425
|
+
... 'mask': additional_mask1,
|
|
426
|
+
... 'bboxes': additional_bboxes1,
|
|
427
|
+
... 'labels': additional_labels1
|
|
428
|
+
... },
|
|
429
|
+
... {
|
|
430
|
+
... 'image': additional_image2,
|
|
431
|
+
... 'mask': additional_mask2,
|
|
432
|
+
... 'bboxes': additional_bboxes2,
|
|
433
|
+
... 'labels': additional_labels2
|
|
434
|
+
... },
|
|
435
|
+
... {
|
|
436
|
+
... 'image': additional_image3,
|
|
437
|
+
... 'mask': additional_mask3,
|
|
438
|
+
... 'bboxes': additional_bboxes3,
|
|
439
|
+
... 'labels': additional_labels3
|
|
440
|
+
... }
|
|
441
|
+
... ]
|
|
442
|
+
>>>
|
|
443
|
+
>>> # Create the transform with Mosaic
|
|
444
|
+
>>> transform = A.Compose([
|
|
445
|
+
... A.Mosaic(
|
|
446
|
+
... grid_yx=(2, 2),
|
|
447
|
+
... target_size=(200, 200),
|
|
448
|
+
... cell_shape=(120, 120),
|
|
449
|
+
... center_range=(0.4, 0.6),
|
|
450
|
+
... fit_mode="cover",
|
|
451
|
+
... p=1.0
|
|
452
|
+
... ),
|
|
453
|
+
... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']))
|
|
454
|
+
>>>
|
|
455
|
+
>>> # Apply the transform
|
|
456
|
+
>>> transformed = transform(
|
|
457
|
+
... image=primary_image,
|
|
458
|
+
... mask=primary_mask,
|
|
459
|
+
... bboxes=primary_bboxes,
|
|
460
|
+
... labels=primary_labels,
|
|
461
|
+
... mosaic_metadata=mosaic_metadata # Pass the metadata using the default key
|
|
462
|
+
... )
|
|
463
|
+
>>>
|
|
464
|
+
>>> # Access the transformed data
|
|
465
|
+
>>> mosaic_image = transformed['image'] # Combined mosaic image
|
|
466
|
+
>>> mosaic_mask = transformed['mask'] # Combined mosaic mask
|
|
467
|
+
>>> mosaic_bboxes = transformed['bboxes'] # Combined and repositioned bboxes
|
|
468
|
+
>>> mosaic_labels = transformed['labels'] # Combined labels from all images
|
|
469
|
+
|
|
470
|
+
"""
|
|
471
|
+
|
|
472
|
+
_targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)
|
|
473
|
+
|
|
474
|
+
class InitSchema(BaseTransformInitSchema):
    """Validation schema for ``Mosaic`` constructor arguments."""

    # Rows and columns must be positive; without this, an invalid grid was only rejected
    # indirectly by the cell-shape check below, with a misleading message.
    grid_yx: Annotated[
        tuple[int, int],
        AfterValidator(check_range_bounds(1, None)),
    ]
    target_size: Annotated[
        tuple[int, int],
        AfterValidator(check_range_bounds(1, None)),
    ]
    cell_shape: Annotated[
        tuple[int, int],
        AfterValidator(check_range_bounds(1, None)),
    ]
    metadata_key: str
    center_range: Annotated[
        tuple[float, float],
        AfterValidator(check_range_bounds(0, 1)),
        AfterValidator(nondecreasing),
    ]
    interpolation: Literal[
        cv2.INTER_NEAREST,
        cv2.INTER_NEAREST_EXACT,
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_AREA,
        cv2.INTER_LANCZOS4,
        cv2.INTER_LINEAR_EXACT,
    ]
    mask_interpolation: Literal[
        cv2.INTER_NEAREST,
        cv2.INTER_NEAREST_EXACT,
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_AREA,
        cv2.INTER_LANCZOS4,
        cv2.INTER_LINEAR_EXACT,
    ]
    fill: tuple[float, ...] | float
    fill_mask: tuple[float, ...] | float
    fit_mode: Literal["cover", "contain"]

    @model_validator(mode="after")
    def _check_cell_shape(self) -> Self:
        """Ensure the full grid (cell_shape * grid_yx) is at least as large as target_size per axis."""
        if (
            self.cell_shape[0] * self.grid_yx[0] < self.target_size[0]
            or self.cell_shape[1] * self.grid_yx[1] < self.target_size[1]
        ):
            raise ValueError("Target size should be smaller than or equal to cell_shape * grid_yx")
        return self
|
|
520
|
+
|
|
521
|
+
def __init__(
    self,
    grid_yx: tuple[int, int] = (2, 2),
    target_size: tuple[int, int] = (512, 512),
    cell_shape: tuple[int, int] = (512, 512),
    center_range: tuple[float, float] = (0.3, 0.7),
    fit_mode: Literal["cover", "contain"] = "cover",
    interpolation: Literal[
        cv2.INTER_NEAREST,
        cv2.INTER_NEAREST_EXACT,
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_AREA,
        cv2.INTER_LANCZOS4,
        cv2.INTER_LINEAR_EXACT,
    ] = cv2.INTER_LINEAR,
    mask_interpolation: Literal[
        cv2.INTER_NEAREST,
        cv2.INTER_NEAREST_EXACT,
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_AREA,
        cv2.INTER_LANCZOS4,
        cv2.INTER_LINEAR_EXACT,
    ] = cv2.INTER_NEAREST,
    fill: tuple[float, ...] | float = 0,
    fill_mask: tuple[float, ...] | float = 0,
    metadata_key: str = "mosaic_metadata",
    p: float = 0.5,
) -> None:
    """Store the mosaic configuration; values are validated by ``InitSchema``."""
    super().__init__(p=p)

    # Grid layout and output geometry.
    self.grid_yx = grid_yx
    self.cell_shape = cell_shape
    self.target_size = target_size
    self.center_range = center_range
    self.fit_mode = fit_mode

    # Resampling and padding used when fitting items into cells.
    self.interpolation = interpolation
    self.mask_interpolation = mask_interpolation
    self.fill = fill
    self.fill_mask = fill_mask

    # Where the additional mosaic items are found in the input data.
    self.metadata_key = metadata_key
|
|
563
|
+
|
|
564
|
+
@property
def targets_as_params(self) -> list[str]:
    """Names of the input targets this transform needs when computing its parameters.

    Returns:
        list[str]: A one-element list holding the configured metadata key.

    """
    key = self.metadata_key
    return [key]
|
|
573
|
+
|
|
574
|
+
def _calculate_geometry(self, data: dict[str, Any]) -> list[tuple[int, int, int, int]]:
    """Sample a mosaic center point and compute the resulting cell placements.

    Args:
        data (dict[str, Any]): Input data dictionary. Not read by this method.

    Returns:
        list[tuple[int, int, int, int]]: Placement coordinates as returned by
        ``fmixing.calculate_cell_placements``.

    """
    # Grid geometry shared by both helper calls.
    layout = {
        "grid_yx": self.grid_yx,
        "cell_shape": self.cell_shape,
        "target_size": self.target_size,
    }
    center = fmixing.calculate_mosaic_center_point(
        **layout,
        center_range=self.center_range,
        py_random=self.py_random,
    )
    return fmixing.calculate_cell_placements(**layout, center_xy=center)
|
|
590
|
+
|
|
591
|
+
def _select_additional_items(self, data: dict[str, Any], num_additional_needed: int) -> list[dict[str, Any]]:
    """Pick up to ``num_additional_needed`` valid items from the mosaic metadata.

    When more valid items exist than are needed, a random subset is drawn with
    ``self.py_random``; otherwise all valid items are returned.
    """
    candidates = fmixing.filter_valid_metadata(data.get(self.metadata_key), self.metadata_key, data)
    if len(candidates) <= num_additional_needed:
        return candidates
    return self.py_random.sample(candidates, num_additional_needed)
|
|
596
|
+
|
|
597
|
+
def _preprocess_additional_items(
|
|
598
|
+
self,
|
|
599
|
+
additional_items: list[dict[str, Any]],
|
|
600
|
+
data: dict[str, Any],
|
|
601
|
+
) -> list[fmixing.ProcessedMosaicItem]:
|
|
602
|
+
if "bboxes" in data or "keypoints" in data:
|
|
603
|
+
bbox_processor = cast("BboxProcessor", self.get_processor("bboxes"))
|
|
604
|
+
keypoint_processor = cast("KeypointsProcessor", self.get_processor("keypoints"))
|
|
605
|
+
return fmixing.preprocess_selected_mosaic_items(additional_items, bbox_processor, keypoint_processor)
|
|
606
|
+
return cast("list[fmixing.ProcessedMosaicItem]", list(additional_items))
|
|
607
|
+
|
|
608
|
+
def _prepare_final_items(
    self,
    primary: fmixing.ProcessedMosaicItem,
    additional_items: list[fmixing.ProcessedMosaicItem],
    num_needed: int,
) -> list[fmixing.ProcessedMosaicItem]:
    """Build the item list for the grid, padding with copies of the primary item.

    Args:
        primary (fmixing.ProcessedMosaicItem): The main input item; always placed first.
        additional_items (list[fmixing.ProcessedMosaicItem]): Extra items, kept in order.
        num_needed (int): Number of extra slots to fill beyond the primary.

    Returns:
        list[fmixing.ProcessedMosaicItem]: ``[primary, *additional_items]`` followed by
            deep copies of ``primary`` for any slots the additional items did not fill.

    """
    items = [primary, *additional_items]
    shortfall = num_needed - len(additional_items)
    # range() of a non-positive count is empty, so no padding happens when enough items exist.
    items.extend(deepcopy(primary) for _ in range(shortfall))
    return items
|
|
617
|
+
|
|
618
|
+
def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
    """Orchestrates the steps to calculate mosaic parameters by calling helper methods.

    The order of the helper calls matters: the geometry, item sampling and grid
    assignment steps all draw from ``self.py_random``.

    Args:
        params (dict[str, Any]): Parameters from earlier stages (unused here).
        data (dict[str, Any]): Input data with at least an "image" entry.

    Returns:
        dict[str, Any]: "processed_cells" and "target_shape", plus "target_mask_shape"
            when ``data`` contains a "mask".

    """
    placements = self._calculate_geometry(data)
    # One cell is always taken by the primary item.
    extra_needed = max(0, len(placements) - 1)

    chosen = self._select_additional_items(data, extra_needed)
    prepared = self._preprocess_additional_items(chosen, data)
    grid_items = self._prepare_final_items(self.get_primary_data(data), prepared, extra_needed)

    item_for_placement = fmixing.assign_items_to_grid_cells(
        num_items=len(grid_items),
        cell_placements=placements,
        py_random=self.py_random,
    )

    # Masks are padded with the image fill when no dedicated mask fill is set.
    mask_fill = self.fill if self.fill_mask is None else self.fill_mask
    cells = fmixing.process_all_mosaic_geometries(
        canvas_shape=self.target_size,
        cell_shape=self.cell_shape,
        placement_to_item_index=item_for_placement,
        final_items_for_grid=grid_items,
        fill=self.fill,
        fill_mask=mask_fill,
        fit_mode=self.fit_mode,
        interpolation=self.interpolation,
        mask_interpolation=self.mask_interpolation,
    )

    if "bboxes" in data or "keypoints" in data:
        cells = fmixing.shift_all_coordinates(cells, canvas_shape=self.target_size)

    computed = {
        "processed_cells": cells,
        "target_shape": self._get_target_shape(data["image"].shape),
    }
    if "mask" in data:
        computed["target_mask_shape"] = self._get_target_shape(data["mask"].shape)
    return computed
|
|
657
|
+
|
|
658
|
+
@staticmethod
def get_primary_data(data: dict[str, Any]) -> fmixing.ProcessedMosaicItem:
    """Collect the primary item from ``data``, copying its annotation entries.

    "mask", "bboxes" and "keypoints" are copied with ``.copy()`` when present (``None``
    otherwise), so later edits do not touch the caller's objects. The "image" entry is
    returned as-is, not copied.

    Args:
        data (dict[str, Any]): Dictionary containing the primary data.

    Returns:
        fmixing.ProcessedMosaicItem: Dict with keys "image", "mask", "bboxes", "keypoints".

    """

    def copied(key: str) -> Any:
        value = data.get(key)
        return None if value is None else value.copy()

    annotations = {key: copied(key) for key in ("mask", "bboxes", "keypoints")}
    return {"image": data["image"], **annotations}
|
|
684
|
+
|
|
685
|
+
def _get_target_shape(self, np_shape: tuple[int, ...]) -> list[int]:
    """Replace the two leading (height, width) dims of ``np_shape`` with ``self.target_size``.

    Args:
        np_shape (tuple[int, ...]): Array shape; any trailing dims (e.g. channels) are kept.

    Returns:
        list[int]: The shape with ``target_size[0]`` and ``target_size[1]`` in front.

    Raises:
        IndexError: If ``np_shape`` has fewer than two dimensions.

    """
    resized = list(np_shape)
    for axis, size in enumerate((self.target_size[0], self.target_size[1])):
        resized[axis] = size
    return resized
|
|
690
|
+
|
|
691
|
+
def apply(
    self,
    img: np.ndarray,
    processed_cells: dict[tuple[int, int, int, int], dict[str, Any]],
    target_shape: tuple[int, int],
    **params: Any,
) -> np.ndarray:
    """Assemble the mosaic image from the precomputed cells.

    Args:
        img (np.ndarray): Input image; only its dtype is used.
        processed_cells (dict[tuple[int, int, int, int], dict[str, Any]]): Per-cell data
            from ``get_params_dependent_on_data``.
        target_shape (tuple[int, int]): Shape of the output image.
        **params (Any): Additional parameters (unused).

    Returns:
        np.ndarray: Mosaic image, with ``self.fill`` passed as the background fill.

    """
    output_dtype = img.dtype
    return fmixing.assemble_mosaic_from_processed_cells(
        processed_cells=processed_cells,
        target_shape=target_shape,
        dtype=output_dtype,
        data_key="image",
        fill=self.fill,
    )
|
|
717
|
+
|
|
718
|
+
def apply_to_mask(
    self,
    mask: np.ndarray,
    processed_cells: dict[tuple[int, int, int, int], dict[str, Any]],
    target_mask_shape: tuple[int, int],
    **params: Any,
) -> np.ndarray:
    """Apply mosaic transformation to the input mask.

    Args:
        mask (np.ndarray): Input mask; only its dtype is used.
        processed_cells (dict): Dictionary of processed cell data containing cropped/padded mask segments.
        target_mask_shape (tuple[int, int]): Shape of the target mask.
        **params (Any): Additional parameters (unused).

    Returns:
        np.ndarray: Mosaic transformed mask.

    """
    # Use the same fallback as get_params_dependent_on_data: when no dedicated mask
    # fill is configured, the mask cells were padded with self.fill, so the canvas
    # background must use self.fill too. Passing None here would be inconsistent with
    # the padded cells and is not a valid fill value.
    mask_fill = self.fill if self.fill_mask is None else self.fill_mask
    return fmixing.assemble_mosaic_from_processed_cells(
        processed_cells=processed_cells,
        target_shape=target_mask_shape,
        dtype=mask.dtype,
        data_key="mask",
        fill=mask_fill,
    )
|
|
744
|
+
|
|
745
|
+
def apply_to_bboxes(
    self,
    bboxes: np.ndarray,  # Original bboxes - ignored
    processed_cells: dict[tuple[int, int, int, int], dict[str, Any]],
    **params: Any,
) -> np.ndarray:
    """Merge the per-cell bounding boxes and filter them against the output canvas.

    Args:
        bboxes (np.ndarray): Original bounding boxes; used only for column count and dtype
            when no cell contributes any box.
        processed_cells (dict): Mapping from placement coords to processed cell data whose
            "bboxes" entry holds the shifted boxes (absolute pixel coords).
        **params (Any): Additional parameters (unused).

    Returns:
        np.ndarray: Combined boxes filtered by ``filter_bboxes`` with the bbox processor's
            min_area / min_visibility / min_width / min_height / max_accept_ratio settings.

    """
    cell_boxes = (cell["bboxes"] for cell in processed_cells.values())
    non_empty = [boxes for boxes in cell_boxes if boxes.size > 0]

    if not non_empty:
        return np.empty((0, bboxes.shape[1]), dtype=bboxes.dtype)

    # Absolute pixel coordinates from every cell, stacked row-wise.
    merged = np.concatenate(non_empty, axis=0)

    # The bbox processor is assumed to exist whenever bboxes are being processed.
    processor = cast("BboxProcessor", self.get_processor("bboxes"))
    canvas: dict[Literal["depth", "height", "width"], int] = {
        "height": self.target_size[0],
        "width": self.target_size[1],
    }
    settings = processor.params
    return filter_bboxes(
        merged,
        canvas,
        min_area=settings.min_area,
        min_visibility=settings.min_visibility,
        min_width=settings.min_width,
        min_height=settings.min_height,
        max_accept_ratio=settings.max_accept_ratio,
    )
|
|
792
|
+
|
|
793
|
+
def apply_to_keypoints(
    self,
    keypoints: np.ndarray,  # Original keypoints - ignored
    processed_cells: dict[tuple[int, int, int, int], dict[str, Any]],
    **params: Any,
) -> np.ndarray:
    """Merge the per-cell keypoints and drop those outside the output canvas.

    Args:
        keypoints (np.ndarray): Original keypoints; used only for column count and dtype
            when no cell contributes any keypoint.
        processed_cells (dict): Mapping from placement coords to processed cell data whose
            "keypoints" entry holds the shifted keypoints.
        **params (Any): Additional parameters (unused).

    Returns:
        np.ndarray: Combined keypoints with column 0 (x) in ``[0, width)`` and column 1 (y)
            in ``[0, height)``, where ``(height, width) = self.target_size``.

    """
    cell_points = (cell["keypoints"] for cell in processed_cells.values())
    non_empty = [points for points in cell_points if points.size > 0]

    if not non_empty:
        return np.empty((0, keypoints.shape[1]), dtype=keypoints.dtype)

    merged = np.concatenate(non_empty, axis=0)

    canvas_h, canvas_w = self.target_size
    xs = merged[:, 0]
    ys = merged[:, 1]
    # Half-open bounds: a point exactly on the right/bottom edge is outside the canvas.
    inside = (xs >= 0) & (xs < canvas_w) & (ys >= 0) & (ys < canvas_h)
    return merged[inside]
|