nrtk-albumentations 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nrtk-albumentations might be problematic. Click here for more details.

Files changed (62) hide show
  1. albumentations/__init__.py +21 -0
  2. albumentations/augmentations/__init__.py +23 -0
  3. albumentations/augmentations/blur/__init__.py +0 -0
  4. albumentations/augmentations/blur/functional.py +438 -0
  5. albumentations/augmentations/blur/transforms.py +1633 -0
  6. albumentations/augmentations/crops/__init__.py +0 -0
  7. albumentations/augmentations/crops/functional.py +494 -0
  8. albumentations/augmentations/crops/transforms.py +3647 -0
  9. albumentations/augmentations/dropout/__init__.py +0 -0
  10. albumentations/augmentations/dropout/channel_dropout.py +134 -0
  11. albumentations/augmentations/dropout/coarse_dropout.py +567 -0
  12. albumentations/augmentations/dropout/functional.py +1017 -0
  13. albumentations/augmentations/dropout/grid_dropout.py +166 -0
  14. albumentations/augmentations/dropout/mask_dropout.py +274 -0
  15. albumentations/augmentations/dropout/transforms.py +461 -0
  16. albumentations/augmentations/dropout/xy_masking.py +186 -0
  17. albumentations/augmentations/geometric/__init__.py +0 -0
  18. albumentations/augmentations/geometric/distortion.py +1238 -0
  19. albumentations/augmentations/geometric/flip.py +752 -0
  20. albumentations/augmentations/geometric/functional.py +4151 -0
  21. albumentations/augmentations/geometric/pad.py +676 -0
  22. albumentations/augmentations/geometric/resize.py +956 -0
  23. albumentations/augmentations/geometric/rotate.py +864 -0
  24. albumentations/augmentations/geometric/transforms.py +1962 -0
  25. albumentations/augmentations/mixing/__init__.py +0 -0
  26. albumentations/augmentations/mixing/domain_adaptation.py +787 -0
  27. albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
  28. albumentations/augmentations/mixing/functional.py +878 -0
  29. albumentations/augmentations/mixing/transforms.py +832 -0
  30. albumentations/augmentations/other/__init__.py +0 -0
  31. albumentations/augmentations/other/lambda_transform.py +180 -0
  32. albumentations/augmentations/other/type_transform.py +261 -0
  33. albumentations/augmentations/pixel/__init__.py +0 -0
  34. albumentations/augmentations/pixel/functional.py +4226 -0
  35. albumentations/augmentations/pixel/transforms.py +7556 -0
  36. albumentations/augmentations/spectrogram/__init__.py +0 -0
  37. albumentations/augmentations/spectrogram/transform.py +220 -0
  38. albumentations/augmentations/text/__init__.py +0 -0
  39. albumentations/augmentations/text/functional.py +272 -0
  40. albumentations/augmentations/text/transforms.py +299 -0
  41. albumentations/augmentations/transforms3d/__init__.py +0 -0
  42. albumentations/augmentations/transforms3d/functional.py +393 -0
  43. albumentations/augmentations/transforms3d/transforms.py +1422 -0
  44. albumentations/augmentations/utils.py +249 -0
  45. albumentations/core/__init__.py +0 -0
  46. albumentations/core/bbox_utils.py +920 -0
  47. albumentations/core/composition.py +1885 -0
  48. albumentations/core/hub_mixin.py +299 -0
  49. albumentations/core/keypoints_utils.py +521 -0
  50. albumentations/core/label_manager.py +339 -0
  51. albumentations/core/pydantic.py +239 -0
  52. albumentations/core/serialization.py +352 -0
  53. albumentations/core/transforms_interface.py +976 -0
  54. albumentations/core/type_definitions.py +127 -0
  55. albumentations/core/utils.py +605 -0
  56. albumentations/core/validation.py +129 -0
  57. albumentations/pytorch/__init__.py +1 -0
  58. albumentations/pytorch/transforms.py +189 -0
  59. nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
  60. nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
  61. nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
  62. nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,832 @@
1
+ """Transforms that combine multiple images and their associated annotations.
2
+
3
+ This module contains transformations that take multiple input sources (e.g., a primary image
4
+ and additional images provided via metadata) and combine them into a single output.
5
+ Examples include overlaying elements (`OverlayElements`) or creating complex compositions
6
+ like `Mosaic`.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import random
12
+ from copy import deepcopy
13
+ from typing import Annotated, Any, Literal, cast
14
+
15
+ import cv2
16
+ import numpy as np
17
+ from pydantic import AfterValidator, model_validator
18
+ from typing_extensions import Self
19
+
20
+ from albumentations.augmentations.mixing import functional as fmixing
21
+ from albumentations.core.bbox_utils import BboxProcessor, check_bboxes, denormalize_bboxes, filter_bboxes
22
+ from albumentations.core.keypoints_utils import KeypointsProcessor
23
+ from albumentations.core.pydantic import check_range_bounds, nondecreasing
24
+ from albumentations.core.transforms_interface import BaseTransformInitSchema, DualTransform
25
+ from albumentations.core.type_definitions import LENGTH_RAW_BBOX, Targets
26
+
27
+ __all__ = ["Mosaic", "OverlayElements"]
28
+
29
+
30
+ class OverlayElements(DualTransform):
31
+ """Apply overlay elements such as images and masks onto an input image. This transformation can be used to add
32
+ various objects (e.g., stickers, logos) to images with optional masks and bounding boxes for better placement
33
+ control.
34
+
35
+ Args:
36
+ metadata_key (str): Additional target key for metadata. Default `overlay_metadata`.
37
+ p (float): Probability of applying the transformation. Default: 0.5.
38
+
39
+ Possible Metadata Fields:
40
+ - image (np.ndarray): The overlay image to be applied. This is a required field.
41
+ - bbox (list[float]): The bounding box specifying the region where the overlay should be applied. It should
42
+ contain four floats: [x_min, y_min, x_max, y_max]. If `label_id` is provided, it should
43
+ be appended as the fifth element in the bbox. BBox should be in Albumentations format,
44
+ that is the same as normalized Pascal VOC format
45
+ [x_min / width, y_min / height, x_max / width, y_max / height]
46
+ - mask (np.ndarray): An optional mask that defines the non-rectangular region of the overlay image. If not
47
+ provided, the entire overlay image is used.
48
+ - mask_id (int): An optional identifier for the mask. If provided, the regions specified by the mask will
49
+ be labeled with this identifier in the output mask.
50
+
51
+ Targets:
52
+ image, mask
53
+
54
+ Image types:
55
+ uint8, float32
56
+
57
+ References:
58
+ doc-augmentation: https://github.com/danaaubakirova/doc-augmentation
59
+
60
+ Examples:
61
+ >>> import numpy as np
62
+ >>> import albumentations as A
63
+ >>> import cv2
64
+ >>>
65
+ >>> # Prepare primary data (base image and mask)
66
+ >>> image = np.zeros((300, 300, 3), dtype=np.uint8)
67
+ >>> mask = np.zeros((300, 300), dtype=np.uint8)
68
+ >>>
69
+ >>> # 1. Create a simple overlay image (a red square)
70
+ >>> overlay_image1 = np.zeros((50, 50, 3), dtype=np.uint8)
71
+ >>> overlay_image1[:, :, 0] = 255 # Red color
72
+ >>>
73
+ >>> # 2. Create another overlay with a mask (a blue circle with transparency)
74
+ >>> overlay_image2 = np.zeros((80, 80, 3), dtype=np.uint8)
75
+ >>> overlay_image2[:, :, 2] = 255 # Blue color
76
+ >>> overlay_mask2 = np.zeros((80, 80), dtype=np.uint8)
77
+ >>> # Create a circular mask
78
+ >>> center = (40, 40)
79
+ >>> radius = 30
80
+ >>> for i in range(80):
81
+ ... for j in range(80):
82
+ ... if (i - center[0])**2 + (j - center[1])**2 < radius**2:
83
+ ... overlay_mask2[i, j] = 255
84
+ >>>
85
+ >>> # 3. Create an overlay with both bbox and mask_id
86
+ >>> overlay_image3 = np.zeros((60, 120, 3), dtype=np.uint8)
87
+ >>> overlay_image3[:, :, 1] = 255 # Green color
88
+ >>> # Create a rectangular mask with rounded corners
89
+ >>> overlay_mask3 = np.zeros((60, 120), dtype=np.uint8)
90
+ >>> cv2.rectangle(overlay_mask3, (10, 10), (110, 50), 255, -1)
91
+ >>>
92
+ >>> # Create the metadata list - each item is a dictionary with overlay information
93
+ >>> overlay_metadata = [
94
+ ... {
95
+ ... 'image': overlay_image1,
96
+ ... # No bbox provided - will be placed randomly
97
+ ... },
98
+ ... {
99
+ ... 'image': overlay_image2,
100
+ ... 'bbox': [0.6, 0.1, 0.9, 0.4], # Normalized coordinates [x_min, y_min, x_max, y_max]
101
+ ... 'mask': overlay_mask2,
102
+ ... 'mask_id': 1 # This overlay will update the mask with id 1
103
+ ... },
104
+ ... {
105
+ ... 'image': overlay_image3,
106
+ ... 'bbox': [0.1, 0.7, 0.5, 0.9], # Bottom left placement
107
+ ... 'mask': overlay_mask3,
108
+ ... 'mask_id': 2 # This overlay will update the mask with id 2
109
+ ... }
110
+ ... ]
111
+ >>>
112
+ >>> # Create the transform
113
+ >>> transform = A.Compose([
114
+ ... A.OverlayElements(p=1.0),
115
+ ... ])
116
+ >>>
117
+ >>> # Apply the transform
118
+ >>> result = transform(
119
+ ... image=image,
120
+ ... mask=mask,
121
+ ... overlay_metadata=overlay_metadata # Pass metadata using the default key
122
+ ... )
123
+ >>>
124
+ >>> # Get results with overlays applied
125
+ >>> result_image = result['image'] # Image with the three overlays applied
126
+ >>> result_mask = result['mask'] # Mask with regions labeled using the mask_id values
127
+ >>>
128
+ >>> # Let's verify the mask contains the specified mask_id values
129
+ >>> has_mask_id_1 = np.any(result_mask == 1) # Should be True
130
+ >>> has_mask_id_2 = np.any(result_mask == 2) # Should be True
131
+
132
+ """
133
+
134
    # Targets declared for this transform; matches the "Targets: image, mask" section of the class docstring.
    _targets = (Targets.IMAGE, Targets.MASK)

    class InitSchema(BaseTransformInitSchema):
        # Name of the input-data key that carries the overlay metadata (see `targets_as_params`).
        metadata_key: str
138
+
139
    def __init__(
        self,
        metadata_key: str = "overlay_metadata",
        p: float = 0.5,
    ):
        """Initialize the transform.

        Args:
            metadata_key (str): Key of the input-data entry that holds the overlay metadata
                (a single dict or a list of dicts). Default: "overlay_metadata".
            p (float): Probability of applying the transform. Default: 0.5.

        """
        super().__init__(p=p)
        # Stored so `targets_as_params` and `get_params_dependent_on_data` can look the metadata up.
        self.metadata_key = metadata_key
146
+
147
+ @property
148
+ def targets_as_params(self) -> list[str]:
149
+ """Get list of targets that should be passed as parameters to transforms.
150
+
151
+ Returns:
152
+ list[str]: List containing the metadata key name
153
+
154
+ """
155
+ return [self.metadata_key]
156
+
157
+ @staticmethod
158
+ def preprocess_metadata(
159
+ metadata: dict[str, Any],
160
+ img_shape: tuple[int, int],
161
+ random_state: random.Random,
162
+ ) -> dict[str, Any]:
163
+ """Process overlay metadata to prepare for application.
164
+
165
+ Args:
166
+ metadata (dict[str, Any]): Dictionary containing overlay data such as image, mask, bbox
167
+ img_shape (tuple[int, int]): Shape of the target image as (height, width)
168
+ random_state (random.Random): Random state object for reproducible randomness
169
+
170
+ Returns:
171
+ dict[str, Any]: Processed overlay data including resized overlay image, mask,
172
+ offset coordinates, and bounding box information
173
+
174
+ """
175
+ overlay_image = metadata["image"]
176
+ overlay_height, overlay_width = overlay_image.shape[:2]
177
+ image_height, image_width = img_shape[:2]
178
+
179
+ if "bbox" in metadata:
180
+ bbox = metadata["bbox"]
181
+ bbox_np = np.array([bbox])
182
+ check_bboxes(bbox_np)
183
+ denormalized_bbox = denormalize_bboxes(bbox_np, img_shape[:2])[0]
184
+
185
+ x_min, y_min, x_max, y_max = (int(x) for x in denormalized_bbox[:4])
186
+
187
+ if "mask" in metadata:
188
+ mask = metadata["mask"]
189
+ mask = cv2.resize(mask, (x_max - x_min, y_max - y_min), interpolation=cv2.INTER_NEAREST)
190
+ else:
191
+ mask = np.ones((y_max - y_min, x_max - x_min), dtype=np.uint8)
192
+
193
+ overlay_image = cv2.resize(overlay_image, (x_max - x_min, y_max - y_min), interpolation=cv2.INTER_AREA)
194
+ offset = (y_min, x_min)
195
+
196
+ if len(bbox) == LENGTH_RAW_BBOX and "bbox_id" in metadata:
197
+ bbox = [x_min, y_min, x_max, y_max, metadata["bbox_id"]]
198
+ else:
199
+ bbox = (x_min, y_min, x_max, y_max, *bbox[4:])
200
+ else:
201
+ if image_height < overlay_height or image_width < overlay_width:
202
+ overlay_image = cv2.resize(overlay_image, (image_width, image_height), interpolation=cv2.INTER_AREA)
203
+ overlay_height, overlay_width = overlay_image.shape[:2]
204
+
205
+ mask = metadata["mask"] if "mask" in metadata else np.ones_like(overlay_image, dtype=np.uint8)
206
+
207
+ max_x_offset = image_width - overlay_width
208
+ max_y_offset = image_height - overlay_height
209
+
210
+ offset_x = random_state.randint(0, max_x_offset)
211
+ offset_y = random_state.randint(0, max_y_offset)
212
+
213
+ offset = (offset_y, offset_x)
214
+
215
+ bbox = [
216
+ offset_x,
217
+ offset_y,
218
+ offset_x + overlay_width,
219
+ offset_y + overlay_height,
220
+ ]
221
+
222
+ if "bbox_id" in metadata:
223
+ bbox = [*bbox, metadata["bbox_id"]]
224
+
225
+ result = {
226
+ "overlay_image": overlay_image,
227
+ "overlay_mask": mask,
228
+ "offset": offset,
229
+ "bbox": bbox,
230
+ }
231
+
232
+ if "mask_id" in metadata:
233
+ result["mask_id"] = metadata["mask_id"]
234
+
235
+ return result
236
+
237
+ def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
238
+ """Generate parameters for overlay transform based on input data.
239
+
240
+ Args:
241
+ params (dict[str, Any]): Dictionary of existing parameters
242
+ data (dict[str, Any]): Dictionary containing input data with image and metadata
243
+
244
+ Returns:
245
+ dict[str, Any]: Dictionary containing processed overlay data ready for application
246
+
247
+ """
248
+ metadata = data[self.metadata_key]
249
+ img_shape = params["shape"]
250
+
251
+ if isinstance(metadata, list):
252
+ overlay_data = [self.preprocess_metadata(md, img_shape, self.py_random) for md in metadata]
253
+ else:
254
+ overlay_data = [self.preprocess_metadata(metadata, img_shape, self.py_random)]
255
+
256
+ return {
257
+ "overlay_data": overlay_data,
258
+ }
259
+
260
+ def apply(
261
+ self,
262
+ img: np.ndarray,
263
+ overlay_data: list[dict[str, Any]],
264
+ **params: Any,
265
+ ) -> np.ndarray:
266
+ """Apply overlay elements to the input image.
267
+
268
+ Args:
269
+ img (np.ndarray): Input image
270
+ overlay_data (list[dict[str, Any]]): List of dictionaries containing overlay information
271
+ **params (Any): Additional parameters
272
+
273
+ Returns:
274
+ np.ndarray: Image with overlays applied
275
+
276
+ """
277
+ for data in overlay_data:
278
+ overlay_image = data["overlay_image"]
279
+ overlay_mask = data["overlay_mask"]
280
+ offset = data["offset"]
281
+ img = fmixing.copy_and_paste_blend(img, overlay_image, overlay_mask, offset=offset)
282
+ return img
283
+
284
+ def apply_to_mask(
285
+ self,
286
+ mask: np.ndarray,
287
+ overlay_data: list[dict[str, Any]],
288
+ **params: Any,
289
+ ) -> np.ndarray:
290
+ """Apply overlay masks to the input mask.
291
+
292
+ Args:
293
+ mask (np.ndarray): Input mask
294
+ overlay_data (list[dict[str, Any]]): List of dictionaries containing overlay information
295
+ **params (Any): Additional parameters
296
+
297
+ Returns:
298
+ np.ndarray: Mask with overlay masks applied using the specified mask_id values
299
+
300
+ """
301
+ for data in overlay_data:
302
+ if "mask_id" in data and data["mask_id"] is not None:
303
+ overlay_mask = data["overlay_mask"]
304
+ offset = data["offset"]
305
+ mask_id = data["mask_id"]
306
+
307
+ y_min, x_min = offset
308
+ y_max = y_min + overlay_mask.shape[0]
309
+ x_max = x_min + overlay_mask.shape[1]
310
+
311
+ mask_section = mask[y_min:y_max, x_min:x_max]
312
+ mask_section[overlay_mask > 0] = mask_id
313
+
314
+ return mask
315
+
316
+
317
+ class Mosaic(DualTransform):
318
+ """Combine multiple images and their annotations into a single image using a mosaic grid layout.
319
+
320
+ This transform takes a primary input image (and its annotations) and combines it with
321
+ additional images/annotations provided via metadata. It calculates the geometry for
322
+ a mosaic grid, selects additional items, preprocesses annotations consistently
323
+ (handling label encoding updates), applies geometric transformations, and assembles
324
+ the final output.
325
+
326
+ Args:
327
+ grid_yx (tuple[int, int]): The number of rows (y) and columns (x) in the mosaic grid.
328
+ Determines the maximum number of images involved (grid_yx[0] * grid_yx[1]).
329
+ Default: (2, 2).
330
+ target_size (tuple[int, int]): The desired output (height, width) of the final mosaic image,
331
+ obtained by cropping the assembled mosaic grid. Default: (512, 512).
332
+ cell_shape (tuple[int, int]): The (height, width) of each cell in the mosaic grid. Default: (512, 512).
333
+ metadata_key (str): Key in the input dictionary specifying the list of additional data dictionaries
334
+ for the mosaic. Each dictionary in the list should represent one potential additional item.
335
+ Expected keys: 'image' (required, np.ndarray), and optionally 'mask' (np.ndarray),
336
+ 'bboxes' (np.ndarray), 'keypoints' (np.ndarray), and any relevant label fields
337
+ (e.g., 'class_labels') corresponding to those specified in `Compose`'s `bbox_params` or
338
+ `keypoint_params`. Default: "mosaic_metadata".
339
+ center_range (tuple[float, float]): Range [0.0-1.0] to sample the center point of the mosaic view
340
+ relative to the valid central region of the conceptual large grid. This affects which parts
341
+ of the assembled grid are visible in the final crop. Default: (0.3, 0.7).
342
+ interpolation (int): OpenCV interpolation flag used for resizing images during geometric processing.
343
+ Default: cv2.INTER_LINEAR.
344
+ mask_interpolation (int): OpenCV interpolation flag used for resizing masks during geometric processing.
345
+ Default: cv2.INTER_NEAREST.
346
+ fill (tuple[float, ...] | float): Value used for padding images if needed during geometric processing.
347
+ Default: 0.
348
+ fill_mask (tuple[float, ...] | float): Value used for padding masks if needed during geometric processing.
349
+ Default: 0.
350
+ p (float): Probability of applying the transform. Default: 0.5.
351
+
352
+ Workflow (`get_params_dependent_on_data`):
353
+ 1. Calculate Geometry & Visible Cells: Determine which grid cells are visible in the final
354
+ `target_size` crop and their placement coordinates on the output canvas.
355
+ 2. Validate Raw Additional Metadata: Filter the list provided via `metadata_key`,
356
+ keeping only valid items (dicts with an 'image' key).
357
+ 3. Select Subset of Raw Additional Metadata: Choose a subset of the valid raw items based
358
+ on the number of visible cells requiring additional data.
359
+ 4. Preprocess Selected Raw Additional Items: Preprocess bboxes/keypoints for the *selected*
360
+ additional items *only*. This uses shared processors from `Compose`, updating their
361
+ internal state (e.g., `LabelEncoder`) based on labels in these selected items.
362
+ 5. Prepare Primary Data: Extract preprocessed primary data fields from the input `data` dictionary
363
+ into a `primary` dictionary.
364
+ 6. Determine & Perform Replication: If fewer additional items were selected than needed,
365
+ replicate the preprocessed primary data as required.
366
+ 7. Combine Final Items: Create the list of all preprocessed items (primary, selected additional,
367
+ replicated primary) that will be used.
368
+ 8. Assign Items to VISIBLE Grid Cells
369
+ 9. Process Geometry & Shift Coordinates: For each assigned item:
370
+ a. Apply geometric transforms (Crop, Resize, Pad) to image/mask.
371
+ b. Apply geometric shift to the *preprocessed* bboxes/keypoints based on cell placement.
372
+ 10. Return Parameters: Return the processed cell data (image, mask, shifted bboxes, shifted kps)
373
+ keyed by placement coordinates.
374
+
375
+ Label Handling:
376
+ - The transform relies on `bbox_processor` and `keypoint_processor` provided by `Compose`.
377
+ - `Compose.preprocess` initially fits the processors' `LabelEncoder` on the primary data.
378
+ - This transform (`Mosaic`) preprocesses the *selected* additional raw items using the same
379
+ processors. If new labels are found, the shared `LabelEncoder` state is updated via its
380
+ `update` method.
381
+ - `Compose.postprocess` uses the final updated encoder state to decode all labels present
382
+ in the mosaic output for the current `Compose` call.
383
+ - The encoder state is transient per `Compose` call.
384
+
385
+ Targets:
386
+ image, mask, bboxes, keypoints
387
+
388
+ Image types:
389
+ uint8, float32
390
+
391
+ Reference:
392
+ YOLOv4: Optimal Speed and Accuracy of Object Detection: https://arxiv.org/pdf/2004.10934
393
+
394
+ Examples:
395
+ >>> import numpy as np
396
+ >>> import albumentations as A
397
+ >>> import cv2
398
+ >>>
399
+ >>> # Prepare primary data
400
+ >>> primary_image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
401
+ >>> primary_mask = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
402
+ >>> primary_bboxes = np.array([[10, 10, 40, 40], [50, 50, 90, 90]], dtype=np.float32)
403
+ >>> primary_labels = [1, 2]
404
+ >>>
405
+ >>> # Prepare additional images for mosaic
406
+ >>> additional_image1 = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
407
+ >>> additional_mask1 = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
408
+ >>> additional_bboxes1 = np.array([[20, 20, 60, 60]], dtype=np.float32)
409
+ >>> additional_labels1 = [3]
410
+ >>>
411
+ >>> additional_image2 = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
412
+ >>> additional_mask2 = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
413
+ >>> additional_bboxes2 = np.array([[30, 30, 70, 70]], dtype=np.float32)
414
+ >>> additional_labels2 = [4]
415
+ >>>
416
+ >>> additional_image3 = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
417
+ >>> additional_mask3 = np.random.randint(0, 2, (100, 100), dtype=np.uint8)
418
+ >>> additional_bboxes3 = np.array([[5, 5, 45, 45]], dtype=np.float32)
419
+ >>> additional_labels3 = [5]
420
+ >>>
421
+ >>> # Create metadata for additional images - structured as a list of dicts
422
+ >>> mosaic_metadata = [
423
+ ... {
424
+ ... 'image': additional_image1,
425
+ ... 'mask': additional_mask1,
426
+ ... 'bboxes': additional_bboxes1,
427
+ ... 'labels': additional_labels1
428
+ ... },
429
+ ... {
430
+ ... 'image': additional_image2,
431
+ ... 'mask': additional_mask2,
432
+ ... 'bboxes': additional_bboxes2,
433
+ ... 'labels': additional_labels2
434
+ ... },
435
+ ... {
436
+ ... 'image': additional_image3,
437
+ ... 'mask': additional_mask3,
438
+ ... 'bboxes': additional_bboxes3,
439
+ ... 'labels': additional_labels3
440
+ ... }
441
+ ... ]
442
+ >>>
443
+ >>> # Create the transform with Mosaic
444
+ >>> transform = A.Compose([
445
+ ... A.Mosaic(
446
+ ... grid_yx=(2, 2),
447
+ ... target_size=(200, 200),
448
+ ... cell_shape=(120, 120),
449
+ ... center_range=(0.4, 0.6),
450
+ ... fit_mode="cover",
451
+ ... p=1.0
452
+ ... ),
453
+ ... ], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['labels']))
454
+ >>>
455
+ >>> # Apply the transform
456
+ >>> transformed = transform(
457
+ ... image=primary_image,
458
+ ... mask=primary_mask,
459
+ ... bboxes=primary_bboxes,
460
+ ... labels=primary_labels,
461
+ ... mosaic_metadata=mosaic_metadata # Pass the metadata using the default key
462
+ ... )
463
+ >>>
464
+ >>> # Access the transformed data
465
+ >>> mosaic_image = transformed['image'] # Combined mosaic image
466
+ >>> mosaic_mask = transformed['mask'] # Combined mosaic mask
467
+ >>> mosaic_bboxes = transformed['bboxes'] # Combined and repositioned bboxes
468
+ >>> mosaic_labels = transformed['labels'] # Combined labels from all images
469
+
470
+ """
471
+
472
+ _targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)
473
+
474
+ class InitSchema(BaseTransformInitSchema):
475
+ grid_yx: tuple[int, int]
476
+ target_size: Annotated[
477
+ tuple[int, int],
478
+ AfterValidator(check_range_bounds(1, None)),
479
+ ]
480
+ cell_shape: Annotated[
481
+ tuple[int, int],
482
+ AfterValidator(check_range_bounds(1, None)),
483
+ ]
484
+ metadata_key: str
485
+ center_range: Annotated[
486
+ tuple[float, float],
487
+ AfterValidator(check_range_bounds(0, 1)),
488
+ AfterValidator(nondecreasing),
489
+ ]
490
+ interpolation: Literal[
491
+ cv2.INTER_NEAREST,
492
+ cv2.INTER_NEAREST_EXACT,
493
+ cv2.INTER_LINEAR,
494
+ cv2.INTER_CUBIC,
495
+ cv2.INTER_AREA,
496
+ cv2.INTER_LANCZOS4,
497
+ cv2.INTER_LINEAR_EXACT,
498
+ ]
499
+ mask_interpolation: Literal[
500
+ cv2.INTER_NEAREST,
501
+ cv2.INTER_NEAREST_EXACT,
502
+ cv2.INTER_LINEAR,
503
+ cv2.INTER_CUBIC,
504
+ cv2.INTER_AREA,
505
+ cv2.INTER_LANCZOS4,
506
+ cv2.INTER_LINEAR_EXACT,
507
+ ]
508
+ fill: tuple[float, ...] | float
509
+ fill_mask: tuple[float, ...] | float
510
+ fit_mode: Literal["cover", "contain"]
511
+
512
+ @model_validator(mode="after")
513
+ def _check_cell_shape(self) -> Self:
514
+ if (
515
+ self.cell_shape[0] * self.grid_yx[0] < self.target_size[0]
516
+ or self.cell_shape[1] * self.grid_yx[1] < self.target_size[1]
517
+ ):
518
+ raise ValueError("Target size should be smaller than cell cell_size * grid_yx")
519
+ return self
520
+
521
+ def __init__(
522
+ self,
523
+ grid_yx: tuple[int, int] = (2, 2),
524
+ target_size: tuple[int, int] = (512, 512),
525
+ cell_shape: tuple[int, int] = (512, 512),
526
+ center_range: tuple[float, float] = (0.3, 0.7),
527
+ fit_mode: Literal["cover", "contain"] = "cover",
528
+ interpolation: Literal[
529
+ cv2.INTER_NEAREST,
530
+ cv2.INTER_NEAREST_EXACT,
531
+ cv2.INTER_LINEAR,
532
+ cv2.INTER_CUBIC,
533
+ cv2.INTER_AREA,
534
+ cv2.INTER_LANCZOS4,
535
+ cv2.INTER_LINEAR_EXACT,
536
+ ] = cv2.INTER_LINEAR,
537
+ mask_interpolation: Literal[
538
+ cv2.INTER_NEAREST,
539
+ cv2.INTER_NEAREST_EXACT,
540
+ cv2.INTER_LINEAR,
541
+ cv2.INTER_CUBIC,
542
+ cv2.INTER_AREA,
543
+ cv2.INTER_LANCZOS4,
544
+ cv2.INTER_LINEAR_EXACT,
545
+ ] = cv2.INTER_NEAREST,
546
+ fill: tuple[float, ...] | float = 0,
547
+ fill_mask: tuple[float, ...] | float = 0,
548
+ metadata_key: str = "mosaic_metadata",
549
+ p: float = 0.5,
550
+ ) -> None:
551
+ super().__init__(p=p)
552
+ self.grid_yx = grid_yx
553
+ self.target_size = target_size
554
+
555
+ self.metadata_key = metadata_key
556
+ self.center_range = center_range
557
+ self.interpolation = interpolation
558
+ self.mask_interpolation = mask_interpolation
559
+ self.fill = fill
560
+ self.fill_mask = fill_mask
561
+ self.fit_mode = fit_mode
562
+ self.cell_shape = cell_shape
563
+
564
+ @property
565
+ def targets_as_params(self) -> list[str]:
566
+ """Get list of targets that should be passed as parameters to transforms.
567
+
568
+ Returns:
569
+ list[str]: List containing the metadata key name
570
+
571
+ """
572
+ return [self.metadata_key]
573
+
574
+ def _calculate_geometry(self, data: dict[str, Any]) -> list[tuple[int, int, int, int]]:
575
+ # Step 1: Calculate Geometry & Cell Placements
576
+ center_xy = fmixing.calculate_mosaic_center_point(
577
+ grid_yx=self.grid_yx,
578
+ cell_shape=self.cell_shape,
579
+ target_size=self.target_size,
580
+ center_range=self.center_range,
581
+ py_random=self.py_random,
582
+ )
583
+
584
+ return fmixing.calculate_cell_placements(
585
+ grid_yx=self.grid_yx,
586
+ cell_shape=self.cell_shape,
587
+ target_size=self.target_size,
588
+ center_xy=center_xy,
589
+ )
590
+
591
+ def _select_additional_items(self, data: dict[str, Any], num_additional_needed: int) -> list[dict[str, Any]]:
592
+ valid_items = fmixing.filter_valid_metadata(data.get(self.metadata_key), self.metadata_key, data)
593
+ if len(valid_items) > num_additional_needed:
594
+ return self.py_random.sample(valid_items, num_additional_needed)
595
+ return valid_items
596
+
597
+ def _preprocess_additional_items(
598
+ self,
599
+ additional_items: list[dict[str, Any]],
600
+ data: dict[str, Any],
601
+ ) -> list[fmixing.ProcessedMosaicItem]:
602
+ if "bboxes" in data or "keypoints" in data:
603
+ bbox_processor = cast("BboxProcessor", self.get_processor("bboxes"))
604
+ keypoint_processor = cast("KeypointsProcessor", self.get_processor("keypoints"))
605
+ return fmixing.preprocess_selected_mosaic_items(additional_items, bbox_processor, keypoint_processor)
606
+ return cast("list[fmixing.ProcessedMosaicItem]", list(additional_items))
607
+
608
+ def _prepare_final_items(
609
+ self,
610
+ primary: fmixing.ProcessedMosaicItem,
611
+ additional_items: list[fmixing.ProcessedMosaicItem],
612
+ num_needed: int,
613
+ ) -> list[fmixing.ProcessedMosaicItem]:
614
+ num_replications = max(0, num_needed - len(additional_items))
615
+ replicated = [deepcopy(primary) for _ in range(num_replications)]
616
+ return [primary, *additional_items, *replicated]
617
+
618
+ def get_params_dependent_on_data(self, params: dict[str, Any], data: dict[str, Any]) -> dict[str, Any]:
619
+ """Orchestrates the steps to calculate mosaic parameters by calling helper methods."""
620
+ cell_placements = self._calculate_geometry(data)
621
+
622
+ num_cells = len(cell_placements)
623
+ num_additional_needed = max(0, num_cells - 1)
624
+
625
+ additional_items = self._select_additional_items(data, num_additional_needed)
626
+
627
+ preprocessed_additional = self._preprocess_additional_items(additional_items, data)
628
+
629
+ primary = self.get_primary_data(data)
630
+ final_items = self._prepare_final_items(primary, preprocessed_additional, num_additional_needed)
631
+
632
+ placement_to_item_index = fmixing.assign_items_to_grid_cells(
633
+ num_items=len(final_items),
634
+ cell_placements=cell_placements,
635
+ py_random=self.py_random,
636
+ )
637
+
638
+ processed_cells = fmixing.process_all_mosaic_geometries(
639
+ canvas_shape=self.target_size,
640
+ cell_shape=self.cell_shape,
641
+ placement_to_item_index=placement_to_item_index,
642
+ final_items_for_grid=final_items,
643
+ fill=self.fill,
644
+ fill_mask=self.fill_mask if self.fill_mask is not None else self.fill,
645
+ fit_mode=self.fit_mode,
646
+ interpolation=self.interpolation,
647
+ mask_interpolation=self.mask_interpolation,
648
+ )
649
+
650
+ if "bboxes" in data or "keypoints" in data:
651
+ processed_cells = fmixing.shift_all_coordinates(processed_cells, canvas_shape=self.target_size)
652
+
653
+ result = {"processed_cells": processed_cells, "target_shape": self._get_target_shape(data["image"].shape)}
654
+ if "mask" in data:
655
+ result["target_mask_shape"] = self._get_target_shape(data["mask"].shape)
656
+ return result
657
+
658
+ @staticmethod
659
+ def get_primary_data(data: dict[str, Any]) -> fmixing.ProcessedMosaicItem:
660
+ """Get a copy of the primary data (data passed in `data` parameter) to avoid modifying the original data.
661
+
662
+ Args:
663
+ data (dict[str, Any]): Dictionary containing the primary data.
664
+
665
+ Returns:
666
+ fmixing.ProcessedMosaicItem: A copy of the primary data.
667
+
668
+ """
669
+ mask = data.get("mask")
670
+ if mask is not None:
671
+ mask = mask.copy()
672
+ bboxes = data.get("bboxes")
673
+ if bboxes is not None:
674
+ bboxes = bboxes.copy()
675
+ keypoints = data.get("keypoints")
676
+ if keypoints is not None:
677
+ keypoints = keypoints.copy()
678
+ return {
679
+ "image": data["image"],
680
+ "mask": mask,
681
+ "bboxes": bboxes,
682
+ "keypoints": keypoints,
683
+ }
684
+
685
+ def _get_target_shape(self, np_shape: tuple[int, ...]) -> list[int]:
686
+ target_shape = list(np_shape)
687
+ target_shape[0] = self.target_size[0]
688
+ target_shape[1] = self.target_size[1]
689
+ return target_shape
690
+
691
def apply(
    self,
    img: np.ndarray,
    processed_cells: dict[tuple[int, int, int, int], dict[str, Any]],
    target_shape: tuple[int, int],
    **params: Any,
) -> np.ndarray:
    """Assemble the mosaic image from the processed cells.

    Args:
        img (np.ndarray): Input image; only its dtype is used here.
        processed_cells (dict[tuple[int, int, int, int], dict[str, Any]]): Processed cell data,
            keyed by placement coordinates.
        target_shape (tuple[int, int]): Shape of the target image.
        **params (Any): Additional parameters (unused).

    Returns:
        np.ndarray: Mosaic image built by `fmixing.assemble_mosaic_from_processed_cells`.

    """
    # The output keeps the input image's dtype; uncovered areas use `self.fill`.
    output_dtype = img.dtype
    fill_value = self.fill
    return fmixing.assemble_mosaic_from_processed_cells(
        processed_cells=processed_cells,
        target_shape=target_shape,
        dtype=output_dtype,
        data_key="image",
        fill=fill_value,
    )
717
+
718
def apply_to_mask(
    self,
    mask: np.ndarray,
    processed_cells: dict[tuple[int, int, int, int], dict[str, Any]],
    target_mask_shape: tuple[int, int],
    **params: Any,
) -> np.ndarray:
    """Assemble the mosaic mask from the processed cells.

    Args:
        mask (np.ndarray): Input mask; only its dtype is used here.
        processed_cells (dict): Processed cell data, keyed by placement coordinates,
            containing the mask segments under the "mask" key.
        target_mask_shape (tuple[int, int]): Shape of the target mask.
        **params (Any): Additional parameters (unused).

    Returns:
        np.ndarray: Mosaic mask built by `fmixing.assemble_mosaic_from_processed_cells`.

    """
    # The output keeps the input mask's dtype; uncovered areas use `self.fill_mask`.
    output_dtype = mask.dtype
    fill_value = self.fill_mask
    return fmixing.assemble_mosaic_from_processed_cells(
        processed_cells=processed_cells,
        target_shape=target_mask_shape,
        dtype=output_dtype,
        data_key="mask",
        fill=fill_value,
    )
744
+
745
def apply_to_bboxes(
    self,
    bboxes: np.ndarray,  # Original bboxes - only shape[1] and dtype are used
    processed_cells: dict[tuple[int, int, int, int], dict[str, Any]],
    **params: Any,
) -> np.ndarray:
    """Combine the per-cell bounding boxes and filter them against the target canvas.

    Cells whose "bboxes" entry is missing, None or empty are skipped, so a
    cell without annotations no longer raises.

    Args:
        bboxes (np.ndarray): Original bounding boxes. Their values are ignored; only the
            column count and dtype are used to build an empty result.
        processed_cells (dict): Dictionary mapping placement coords to processed cell data
            (containing the already shifted bboxes under the "bboxes" key).
        **params (Any): Additional parameters (unused).

    Returns:
        np.ndarray: Final combined, filtered bounding boxes, or an empty array of shape
            `(0, bboxes.shape[1])` when no cell contributes any box.

    """
    per_cell_bboxes = (cell_data.get("bboxes") for cell_data in processed_cells.values())
    all_shifted_bboxes = [
        cell_bboxes for cell_bboxes in per_cell_bboxes if cell_bboxes is not None and cell_bboxes.size > 0
    ]

    if not all_shifted_bboxes:
        return np.empty((0, bboxes.shape[1]), dtype=bboxes.dtype)

    # NOTE(review): the original comment describes these as absolute pixel coordinates;
    # confirm this matches the coordinate format `filter_bboxes` expects.
    combined_bboxes = np.concatenate(all_shifted_bboxes, axis=0)

    # Apply filtering using processor parameters.
    # Assume processor exists if bboxes are being processed.
    bbox_processor = cast("BboxProcessor", self.get_processor("bboxes"))
    filter_params = bbox_processor.params
    shape_dict: dict[Literal["depth", "height", "width"], int] = {
        "height": self.target_size[0],
        "width": self.target_size[1],
    }
    return filter_bboxes(
        combined_bboxes,
        shape_dict,
        min_area=filter_params.min_area,
        min_visibility=filter_params.min_visibility,
        min_width=filter_params.min_width,
        min_height=filter_params.min_height,
        max_accept_ratio=filter_params.max_accept_ratio,
    )
792
+
793
def apply_to_keypoints(
    self,
    keypoints: np.ndarray,  # Original keypoints - only shape[1] and dtype are used
    processed_cells: dict[tuple[int, int, int, int], dict[str, Any]],
    **params: Any,
) -> np.ndarray:
    """Combine the per-cell keypoints and drop those outside the target canvas.

    Cells whose "keypoints" entry is missing, None or empty are skipped, so a
    cell without annotations no longer raises.

    Args:
        keypoints (np.ndarray): Original keypoints. Their values are ignored; only the
            column count and dtype are used to build an empty result.
        processed_cells (dict): Dictionary mapping placement coords to processed cell data
            (containing the already shifted keypoints under the "keypoints" key).
        **params (Any): Additional parameters (unused).

    Returns:
        np.ndarray: Combined keypoints whose column 0 lies in `[0, target_w)` and whose
            column 1 lies in `[0, target_h)`, or an empty array of shape
            `(0, keypoints.shape[1])` when no cell contributes any keypoint.

    """
    per_cell_keypoints = (cell_data.get("keypoints") for cell_data in processed_cells.values())
    all_shifted_keypoints = [
        cell_keypoints
        for cell_keypoints in per_cell_keypoints
        if cell_keypoints is not None and cell_keypoints.size > 0
    ]

    if not all_shifted_keypoints:
        return np.empty((0, keypoints.shape[1]), dtype=keypoints.dtype)

    combined_keypoints = np.concatenate(all_shifted_keypoints, axis=0)

    # Filter out keypoints outside the target canvas boundaries.
    target_h, target_w = self.target_size
    x_coords = combined_keypoints[:, 0]
    y_coords = combined_keypoints[:, 1]
    valid_indices = (x_coords >= 0) & (x_coords < target_w) & (y_coords >= 0) & (y_coords < target_h)

    return combined_keypoints[valid_indices]