nrtk_albumentations-2.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nrtk-albumentations might be problematic.

Files changed (62)
  1. albumentations/__init__.py +21 -0
  2. albumentations/augmentations/__init__.py +23 -0
  3. albumentations/augmentations/blur/__init__.py +0 -0
  4. albumentations/augmentations/blur/functional.py +438 -0
  5. albumentations/augmentations/blur/transforms.py +1633 -0
  6. albumentations/augmentations/crops/__init__.py +0 -0
  7. albumentations/augmentations/crops/functional.py +494 -0
  8. albumentations/augmentations/crops/transforms.py +3647 -0
  9. albumentations/augmentations/dropout/__init__.py +0 -0
  10. albumentations/augmentations/dropout/channel_dropout.py +134 -0
  11. albumentations/augmentations/dropout/coarse_dropout.py +567 -0
  12. albumentations/augmentations/dropout/functional.py +1017 -0
  13. albumentations/augmentations/dropout/grid_dropout.py +166 -0
  14. albumentations/augmentations/dropout/mask_dropout.py +274 -0
  15. albumentations/augmentations/dropout/transforms.py +461 -0
  16. albumentations/augmentations/dropout/xy_masking.py +186 -0
  17. albumentations/augmentations/geometric/__init__.py +0 -0
  18. albumentations/augmentations/geometric/distortion.py +1238 -0
  19. albumentations/augmentations/geometric/flip.py +752 -0
  20. albumentations/augmentations/geometric/functional.py +4151 -0
  21. albumentations/augmentations/geometric/pad.py +676 -0
  22. albumentations/augmentations/geometric/resize.py +956 -0
  23. albumentations/augmentations/geometric/rotate.py +864 -0
  24. albumentations/augmentations/geometric/transforms.py +1962 -0
  25. albumentations/augmentations/mixing/__init__.py +0 -0
  26. albumentations/augmentations/mixing/domain_adaptation.py +787 -0
  27. albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
  28. albumentations/augmentations/mixing/functional.py +878 -0
  29. albumentations/augmentations/mixing/transforms.py +832 -0
  30. albumentations/augmentations/other/__init__.py +0 -0
  31. albumentations/augmentations/other/lambda_transform.py +180 -0
  32. albumentations/augmentations/other/type_transform.py +261 -0
  33. albumentations/augmentations/pixel/__init__.py +0 -0
  34. albumentations/augmentations/pixel/functional.py +4226 -0
  35. albumentations/augmentations/pixel/transforms.py +7556 -0
  36. albumentations/augmentations/spectrogram/__init__.py +0 -0
  37. albumentations/augmentations/spectrogram/transform.py +220 -0
  38. albumentations/augmentations/text/__init__.py +0 -0
  39. albumentations/augmentations/text/functional.py +272 -0
  40. albumentations/augmentations/text/transforms.py +299 -0
  41. albumentations/augmentations/transforms3d/__init__.py +0 -0
  42. albumentations/augmentations/transforms3d/functional.py +393 -0
  43. albumentations/augmentations/transforms3d/transforms.py +1422 -0
  44. albumentations/augmentations/utils.py +249 -0
  45. albumentations/core/__init__.py +0 -0
  46. albumentations/core/bbox_utils.py +920 -0
  47. albumentations/core/composition.py +1885 -0
  48. albumentations/core/hub_mixin.py +299 -0
  49. albumentations/core/keypoints_utils.py +521 -0
  50. albumentations/core/label_manager.py +339 -0
  51. albumentations/core/pydantic.py +239 -0
  52. albumentations/core/serialization.py +352 -0
  53. albumentations/core/transforms_interface.py +976 -0
  54. albumentations/core/type_definitions.py +127 -0
  55. albumentations/core/utils.py +605 -0
  56. albumentations/core/validation.py +129 -0
  57. albumentations/pytorch/__init__.py +1 -0
  58. albumentations/pytorch/transforms.py +189 -0
  59. nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
  60. nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
  61. nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
  62. nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,878 @@
1
+ """Functional implementations for image mixing operations.
2
+
3
+ This module provides utility functions for blending and combining images,
4
+ such as copy-and-paste operations with masking.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import random
10
+ from collections.abc import Sequence
11
+ from typing import Any, Literal, TypedDict, cast
12
+ from warnings import warn
13
+
14
+ import cv2
15
+ import numpy as np
16
+
17
+ import albumentations.augmentations.geometric.functional as fgeometric
18
+ from albumentations.augmentations.crops.transforms import Crop
19
+ from albumentations.augmentations.geometric.resize import LongestMaxSize, SmallestMaxSize
20
+ from albumentations.core.bbox_utils import BboxProcessor, denormalize_bboxes, normalize_bboxes
21
+ from albumentations.core.composition import Compose
22
+ from albumentations.core.keypoints_utils import KeypointsProcessor
23
+ from albumentations.core.type_definitions import (
24
+ NUM_BBOXES_COLUMNS_IN_ALBUMENTATIONS,
25
+ NUM_KEYPOINTS_COLUMNS_IN_ALBUMENTATIONS,
26
+ )
27
+
28
+
29
+ # Type definition for a processed mosaic item
30
+ class ProcessedMosaicItem(TypedDict):
31
+ """Represents a single data item (primary or additional) after preprocessing.
32
+
33
+ Includes the original image/mask and the *preprocessed* annotations.
34
+ """
35
+
36
+ image: np.ndarray # Image is mandatory
37
+ mask: np.ndarray | None
38
+ bboxes: np.ndarray | None
39
+ keypoints: np.ndarray | None
40
+
41
+
42
+ def copy_and_paste_blend(
43
+ base_image: np.ndarray,
44
+ overlay_image: np.ndarray,
45
+ overlay_mask: np.ndarray,
46
+ offset: tuple[int, int],
47
+ ) -> np.ndarray:
48
+ """Blend images by copying pixels from an overlay image to a base image using a mask.
49
+
50
+ This function copies pixels from the overlay image to the base image only where
51
+ the mask has non-zero values. The overlay is placed at the specified offset
52
+ from the top-left corner of the base image.
53
+
54
+ Args:
55
+ base_image (np.ndarray): The destination image that will be modified.
56
+ overlay_image (np.ndarray): The source image containing pixels to copy.
57
+ overlay_mask (np.ndarray): Binary mask indicating which pixels to copy from the overlay.
58
+ Pixels are copied where mask > 0.
59
+ offset (tuple[int, int]): The (y, x) offset specifying where to place the
60
+ top-left corner of the overlay relative to the base image.
61
+
62
+ Returns:
63
+ np.ndarray: The blended image with the overlay applied to the base image.
64
+
65
+ """
66
+ y_offset, x_offset = offset
67
+
68
+ blended_image = base_image.copy()
69
+ mask_indices = np.where(overlay_mask > 0)
70
+ blended_image[mask_indices[0] + y_offset, mask_indices[1] + x_offset] = overlay_image[
71
+ mask_indices[0],
72
+ mask_indices[1],
73
+ ]
74
+ return blended_image
75
+
76
+
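A minimal usage sketch of copy_and_paste_blend (shapes chosen only for illustration); the non-zero mask positions select which overlay pixels land in the base image at the given (y, x) offset:

import numpy as np

from albumentations.augmentations.mixing.functional import copy_and_paste_blend

base = np.zeros((6, 6, 3), dtype=np.uint8)
overlay = np.full((3, 3, 3), 255, dtype=np.uint8)
mask = np.ones((3, 3), dtype=np.uint8)  # mask > 0 everywhere, so every overlay pixel is copied

out = copy_and_paste_blend(base, overlay, mask, offset=(2, 1))
# The white 3x3 patch now occupies rows 2-4 and columns 1-3 of `out`;
# `base` itself is untouched because the function works on a copy.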
77
+ def calculate_mosaic_center_point(
78
+ grid_yx: tuple[int, int],
79
+ cell_shape: tuple[int, int],
80
+ target_size: tuple[int, int],
81
+ center_range: tuple[float, float],
82
+ py_random: random.Random,
83
+ ) -> tuple[int, int]:
84
+ """Calculates the center point for the mosaic crop using proportional sampling within the valid zone.
85
+
86
+ Ensures the center point allows a crop of target_size to overlap
87
+ all grid cells, applying randomness based on center_range proportionally
88
+ within the valid region where the center can lie.
89
+
90
+ Args:
91
+ grid_yx (tuple[int, int]): The (rows, cols) of the mosaic grid.
92
+ cell_shape (tuple[int, int]): Shape of each cell in the mosaic grid.
93
+ target_size (tuple[int, int]): The final output (height, width).
94
+ center_range (tuple[float, float]): Range [0.0-1.0] for sampling center proportionally
95
+ within the valid zone.
96
+ py_random (random.Random): Random state instance.
97
+
98
+ Returns:
99
+ tuple[int, int]: The calculated (x, y) center point relative to the
100
+ top-left of the conceptual large grid.
101
+
102
+ """
103
+ rows, cols = grid_yx
104
+ cell_h, cell_w = cell_shape
105
+ target_h, target_w = target_size
106
+
107
+ large_grid_h = rows * cell_h
108
+ large_grid_w = cols * cell_w
109
+
110
+ # Define valid center range bounds (inclusive)
111
+ # The center must be far enough from edges so the crop window fits
112
+ min_cx = target_w // 2
113
+ max_cx = large_grid_w - (target_w + 1) // 2
114
+ min_cy = target_h // 2
115
+ max_cy = large_grid_h - (target_h + 1) // 2
116
+
117
+ # Calculate valid range dimensions (size of the safe zone)
118
+ valid_w = max_cx - min_cx + 1
119
+ valid_h = max_cy - min_cy + 1
120
+
121
+ # Sample relative position within the valid range using center_range
122
+ rel_x = py_random.uniform(*center_range)
123
+ rel_y = py_random.uniform(*center_range)
124
+
125
+ # Calculate center coordinates by scaling relative position within valid range
126
+ # Add the minimum bound to shift the range start
127
+ center_x = min_cx + int(valid_w * rel_x)
128
+ center_y = min_cy + int(valid_h * rel_y)
129
+
130
+ # Ensure the result is strictly within the calculated bounds after int conversion
131
+ # (This clip is mostly a safety measure, shouldn't be needed with correct int conversion)
132
+ center_x = max(min_cx, min(center_x, max_cx))
133
+ center_y = max(min_cy, min(center_y, max_cy))
134
+
135
+ return center_x, center_y
136
+
137
+
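A worked example for calculate_mosaic_center_point, assuming a 2x2 grid of 100x100 cells and a 100x100 output; pinning center_range to (0.5, 0.5) makes the sampled center deterministic:

import random

from albumentations.augmentations.mixing.functional import calculate_mosaic_center_point

center = calculate_mosaic_center_point(
    grid_yx=(2, 2),
    cell_shape=(100, 100),
    target_size=(100, 100),
    center_range=(0.5, 0.5),
    py_random=random.Random(0),
)
# Valid zone for the center: min_cx = 50, max_cx = 150 (same for y), so the
# midpoint of that zone is chosen: center == (100, 100).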
138
+ def calculate_cell_placements(
139
+ grid_yx: tuple[int, int],
140
+ cell_shape: tuple[int, int],
141
+ target_size: tuple[int, int],
142
+ center_xy: tuple[int, int],
143
+ ) -> list[tuple[int, int, int, int]]:
144
+ """Calculates placements by clipping arange-defined grid lines to the crop window.
145
+
146
+ Args:
147
+ grid_yx (tuple[int, int]): The (rows, cols) of the mosaic grid.
148
+ cell_shape (tuple[int, int]): Shape of each cell in the mosaic grid.
149
+ target_size (tuple[int, int]): The final output (height, width).
150
+ center_xy (tuple[int, int]): The calculated (x, y) center of the final crop window,
151
+ relative to the top-left of the conceptual large grid.
152
+
153
+ Returns:
154
+ list[tuple[int, int, int, int]]:
155
+ A list containing placement coordinates `(x_min, y_min, x_max, y_max)`
156
+ for each resulting cell part on the final output canvas.
157
+
158
+ """
159
+ rows, cols = grid_yx
160
+ cell_h, cell_w = cell_shape
161
+ target_h, target_w = target_size
162
+ center_x, center_y = center_xy
163
+
164
+ # 1. Generate grid line coordinates using arange for the large grid
165
+ y_coords_large = np.arange(rows + 1) * cell_h
166
+ x_coords_large = np.arange(cols + 1) * cell_w
167
+
168
+ # 2. Calculate Crop Window boundaries
169
+ crop_x_min = center_x - target_w // 2
170
+ crop_y_min = center_y - target_h // 2
171
+ crop_x_max = crop_x_min + target_w
172
+ crop_y_max = crop_y_min + target_h
173
+
174
+ def _clip_coords(coords: np.ndarray, min_val: int, max_val: int) -> np.ndarray:
175
+ clipped_coords = np.clip(coords, min_val, max_val)
176
+ # Subtract min_val to convert absolute clipped coordinates
177
+ # into coordinates relative to the crop window's origin (min_val becomes 0).
178
+ return np.unique(clipped_coords) - min_val
179
+
180
+ y_coords_clipped = _clip_coords(y_coords_large, crop_y_min, crop_y_max)
181
+ x_coords_clipped = _clip_coords(x_coords_large, crop_x_min, crop_x_max)
182
+
183
+ # 4. Form all cell coordinates efficiently
184
+ num_x_intervals = len(x_coords_clipped) - 1
185
+ num_y_intervals = len(y_coords_clipped) - 1
186
+ result = []
187
+
188
+ for y_idx in range(num_y_intervals):
189
+ y_min = y_coords_clipped[y_idx]
190
+ y_max = y_coords_clipped[y_idx + 1]
191
+ for x_idx in range(num_x_intervals):
192
+ x_min = x_coords_clipped[x_idx]
193
+ x_max = x_coords_clipped[x_idx + 1]
194
+ result.append((int(x_min), int(y_min), int(x_max), int(y_max)))
195
+
196
+ return result
197
+
198
+
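Continuing that example, a center of (100, 100) puts the crop window at (50, 50)-(150, 150) on the 200x200 conceptual grid, so the clipped grid lines split the output into four equal quadrants:

from albumentations.augmentations.mixing.functional import calculate_cell_placements

placements = calculate_cell_placements(
    grid_yx=(2, 2),
    cell_shape=(100, 100),
    target_size=(100, 100),
    center_xy=(100, 100),
)
# [(0, 0, 50, 50), (50, 0, 100, 50), (0, 50, 50, 100), (50, 50, 100, 100)]
print(placements)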
199
+ def _check_data_compatibility(
200
+ primary_data: np.ndarray | None,
201
+ item_data: np.ndarray | None,
202
+ data_key: Literal["image", "mask"],
203
+ ) -> tuple[bool, str | None]: # Returns (is_compatible, error_message)
204
+ """Checks if the dimensions and channels of item_data match primary_data."""
205
+ # 1. Check if item has the required data (image is always required)
206
+ if item_data is None:
207
+ if data_key == "image":
208
+ return False, "Item is missing required key 'image'"
209
+ # Mask is optional, missing is compatible
210
+ return True, None
211
+
212
+ # 2. If item data exists, check against primary data (if primary data exists)
213
+ if primary_data is None: # No primary data to compare against
214
+ return True, None
215
+
216
+ # Both primary and item data exist, compare them
217
+ primary_ndim = primary_data.ndim
218
+ item_ndim = item_data.ndim
219
+
220
+ if primary_ndim != item_ndim:
221
+ return False, (
222
+ f"Item '{data_key}' has {item_ndim} dimensions, but primary has {primary_ndim}. "
223
+ f"Primary shape: {primary_data.shape}, Item shape: {item_data.shape}"
224
+ )
225
+
226
+ if primary_ndim == 3:
227
+ primary_channels = primary_data.shape[-1]
228
+ item_channels = item_data.shape[-1]
229
+ if primary_channels != item_channels:
230
+ return False, (
231
+ f"Item '{data_key}' has {item_channels} channels, but primary has {primary_channels}. "
232
+ f"Primary shape: {primary_data.shape}, Item shape: {item_data.shape}"
233
+ )
234
+
235
+ # Dimensions match (either both 2D or both 3D with same channels)
236
+ return True, None
237
+
238
+
239
+ def filter_valid_metadata(
240
+ metadata_input: Sequence[dict[str, Any]] | None,
241
+ metadata_key_name: str,
242
+ data: dict[str, Any],
243
+ ) -> list[dict[str, Any]]:
244
+ """Filters a list of metadata dicts, keeping only valid ones based on data compatibility."""
245
+ if not isinstance(metadata_input, Sequence):
246
+ warn(
247
+ f"Metadata under key '{metadata_key_name}' is not a Sequence (e.g., list or tuple). "
248
+ f"Returning empty list for additional items.",
249
+ UserWarning,
250
+ stacklevel=3,
251
+ )
252
+ return []
253
+
254
+ valid_items = []
255
+ primary_image = data.get("image")
256
+ primary_mask = data.get("mask")
257
+
258
+ for i, item in enumerate(metadata_input):
259
+ if not isinstance(item, dict):
260
+ warn(
261
+ f"Item at index {i} in '{metadata_key_name}' is not a dict and will be skipped.",
262
+ UserWarning,
263
+ stacklevel=4,
264
+ )
265
+ continue
266
+
267
+ item_is_valid = True # Assume valid initially
268
+ for target_key, primary_target_data in [
269
+ ("image", primary_image),
270
+ ("mask", primary_mask),
271
+ ]:
272
+ item_target_data = item.get(target_key)
273
+
274
+ is_compatible, error_msg = _check_data_compatibility(
275
+ primary_target_data,
276
+ item_target_data,
277
+ cast("Literal['image', 'mask']", target_key),
278
+ )
279
+
280
+ if not is_compatible:
281
+ msg = (
282
+ f"Item at index {i} in '{metadata_key_name}' skipped due "
283
+ f"to incompatibility in '{target_key}': {error_msg}"
284
+ )
285
+ warn(msg, UserWarning, stacklevel=4)
286
+ item_is_valid = False
287
+ break # Stop checking other targets for this item
288
+
289
+ if item_is_valid:
290
+ valid_items.append(item)
291
+
292
+ return valid_items
293
+
294
+
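A sketch of the filtering behaviour (the metadata key name is illustrative): an item whose image matches the primary image's dimensionality and channel count is kept, while an item missing the required "image" key is dropped with a UserWarning:

import numpy as np

from albumentations.augmentations.mixing.functional import filter_valid_metadata

data = {"image": np.zeros((8, 8, 3), dtype=np.uint8)}
metadata = [
    {"image": np.zeros((4, 4, 3), dtype=np.uint8)},  # kept: 3 dimensions and 3 channels, like the primary
    {"mask": np.zeros((4, 4), dtype=np.uint8)},      # skipped: required key "image" is missing
]
valid = filter_valid_metadata(metadata, "mosaic_metadata", data)
assert len(valid) == 1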
295
+ def assign_items_to_grid_cells(
296
+ num_items: int,
297
+ cell_placements: list[tuple[int, int, int, int]],
298
+ py_random: random.Random,
299
+ ) -> dict[tuple[int, int, int, int], int]:
300
+ """Assigns item indices to placement coordinate tuples.
301
+
302
+ Assigns the primary item (index 0) to the placement with the largest area,
303
+ and assigns the remaining items (indices 1 to num_items-1) randomly to the
304
+ remaining placements.
305
+
306
+ Args:
307
+ num_items (int): The total number of items to assign (primary + additional + replicas).
308
+ cell_placements (list[tuple[int, int, int, int]]): List of placement
309
+ coords (x1, y1, x2, y2) for cells to be filled.
310
+ py_random (random.Random): Random state instance.
311
+
312
+ Returns:
313
+ dict[tuple[int, int, int, int], int]: Dict mapping placement coords (x1, y1, x2, y2)
314
+ to assigned item index.
315
+
316
+ """
317
+ if not cell_placements:
318
+ return {}
319
+
320
+ # Find the placement tuple with the largest area for primary assignment
321
+ primary_placement = max(
322
+ cell_placements,
323
+ key=lambda coords: (coords[2] - coords[0]) * (coords[3] - coords[1]),
324
+ )
325
+
326
+ placement_to_item_index: dict[tuple[int, int, int, int], int] = {
327
+ primary_placement: 0,
328
+ }
329
+
330
+ # Use list comprehension for potentially better performance
331
+ remaining_placements = [coords for coords in cell_placements if coords != primary_placement]
332
+
333
+ # Indices for additional/replicated items start from 1
334
+ remaining_item_indices = list(range(1, num_items))
335
+ py_random.shuffle(remaining_item_indices)
336
+
337
+ num_to_assign = min(len(remaining_placements), len(remaining_item_indices))
338
+ for i in range(num_to_assign):
339
+ placement_to_item_index[remaining_placements[i]] = remaining_item_indices[i]
340
+
341
+ return placement_to_item_index
342
+
343
+
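A small sketch with a hand-picked, uneven set of placements, showing that the largest cell always receives the primary item while the remaining indices are shuffled:

import random

from albumentations.augmentations.mixing.functional import assign_items_to_grid_cells

placements = [(0, 0, 60, 60), (60, 0, 100, 60), (0, 60, 60, 100), (60, 60, 100, 100)]
mapping = assign_items_to_grid_cells(num_items=4, cell_placements=placements, py_random=random.Random(0))
# The 60x60 cell has the largest area, so it always holds item index 0 (the primary);
# indices 1-3 are distributed randomly over the remaining cells.
assert mapping[(0, 0, 60, 60)] == 0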
344
+ def _preprocess_item_annotations(
345
+ item: dict[str, Any],
346
+ processor: BboxProcessor | KeypointsProcessor | None,
347
+ data_key: Literal["bboxes", "keypoints"],
348
+ ) -> np.ndarray | None:
349
+ """Helper to preprocess annotations (bboxes or keypoints) for a single item."""
350
+ original_data = item.get(data_key)
351
+
352
+ # Check if processor exists and the relevant data key is in the item
353
+ if processor and data_key in item and item.get(data_key) is not None:
354
+ # === Add validation for required label fields ===
355
+ required_labels = processor.params.label_fields
356
+
357
+ if required_labels and [field for field in required_labels if field not in item]:
358
+ raise ValueError(
359
+ f"Item contains '{data_key}' but is missing required label "
360
+ "fields: {[field for field in required_labels if field not in item]}. "
361
+ f"Ensure all label fields declared in {type(processor.params).__name__} "
362
+ f"({required_labels}) are present in the item dictionary when '{data_key}' is present.",
363
+ )
364
+ # === End validation ===
365
+
366
+ # Create a temporary minimal dict for the processor
367
+ temp_data = {
368
+ "image": item["image"],
369
+ data_key: item[data_key],
370
+ }
371
+
372
+ # Add declared label fields if they exist in the item (already validated above)
373
+ if required_labels:
374
+ for field in required_labels:
375
+ # Check again just in case validation logic changes, avoids KeyError
376
+ if field in item:
377
+ temp_data[field] = item[field]
378
+
379
+ # Preprocess modifies temp_data in-place
380
+ processor.preprocess(temp_data)
381
+ # Return the potentially modified data from the temp dict
382
+ return temp_data.get(data_key)
383
+
384
+ # Return original data if no processor or data key wasn't in item
385
+ return original_data
386
+
387
+
388
+ def preprocess_selected_mosaic_items(
389
+ selected_raw_items: list[dict[str, Any]],
390
+ bbox_processor: BboxProcessor | None, # Allow None
391
+ keypoint_processor: KeypointsProcessor | None, # Allow None
392
+ ) -> list[ProcessedMosaicItem]:
393
+ """Preprocesses bboxes/keypoints for selected raw additional items.
394
+
395
+ Iterates through items, preprocesses annotations individually using processors
396
+ (updating label encoders), and returns a list of dicts with original image/mask
397
+ and the corresponding preprocessed bboxes/keypoints.
398
+ """
399
+ if not selected_raw_items:
400
+ return []
401
+
402
+ result_data_items: list[ProcessedMosaicItem] = []
403
+
404
+ for item in selected_raw_items:
405
+ processed_bboxes = _preprocess_item_annotations(item, bbox_processor, "bboxes")
406
+ processed_keypoints = _preprocess_item_annotations(item, keypoint_processor, "keypoints")
407
+
408
+ # Construct the final processed item dict
409
+ processed_item_dict: ProcessedMosaicItem = {
410
+ "image": item["image"],
411
+ "mask": item.get("mask"),
412
+ "bboxes": processed_bboxes, # Already np.ndarray or None
413
+ "keypoints": processed_keypoints, # Already np.ndarray or None
414
+ }
415
+ result_data_items.append(processed_item_dict)
416
+
417
+ return result_data_items
418
+
419
+
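A minimal pass-through sketch: with no bbox or keypoint processors supplied, the helper simply repackages each raw item into the ProcessedMosaicItem layout, filling absent keys with None (the example data is illustrative):

import numpy as np

from albumentations.augmentations.mixing.functional import preprocess_selected_mosaic_items

raw_items = [{
    "image": np.zeros((10, 10, 3), dtype=np.uint8),
    "bboxes": np.array([[0.1, 0.1, 0.4, 0.4, 0]], dtype=np.float32),
}]
items = preprocess_selected_mosaic_items(raw_items, bbox_processor=None, keypoint_processor=None)
# items[0] == {"image": <same array>, "mask": None, "bboxes": <unchanged array>, "keypoints": None}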
420
+ def get_opposite_crop_coords(
421
+ cell_size: tuple[int, int],
422
+ crop_size: tuple[int, int],
423
+ cell_position: Literal["top_left", "top_right", "center", "bottom_left", "bottom_right"],
424
+ ) -> tuple[int, int, int, int]:
425
+ """Calculates crop coordinates positioned opposite to the specified cell_position.
426
+
427
+ Given a cell of `cell_size`, this function determines the top-left (x_min, y_min)
428
+ and bottom-right (x_max, y_max) coordinates for a crop of `crop_size`, such
429
+ that the crop is located in the corner or center opposite to `cell_position`.
430
+
431
+ For example, if `cell_position` is "top_left", the crop coordinates will
432
+ correspond to the bottom-right region of the cell.
433
+
434
+ Args:
435
+ cell_size: The (height, width) of the cell from which to crop.
436
+ crop_size: The (height, width) of the desired crop.
437
+ cell_position: The reference position within the cell. The crop will be
438
+ taken from the opposite position.
439
+
440
+ Returns:
441
+ tuple[int, int, int, int]: (x_min, y_min, x_max, y_max) representing the crop coordinates.
442
+
443
+ Raises:
444
+ ValueError: If crop_size is larger than cell_size in either dimension.
445
+
446
+ """
447
+ cell_h, cell_w = cell_size
448
+ crop_h, crop_w = crop_size
449
+
450
+ if crop_h > cell_h or crop_w > cell_w:
451
+ raise ValueError(f"Crop size {crop_size} cannot be larger than cell size {cell_size}")
452
+
453
+ # Determine top-left corner (x_min, y_min) based on the OPPOSITE position
454
+ if cell_position == "top_left": # Crop from bottom_right
455
+ x_min = cell_w - crop_w
456
+ y_min = cell_h - crop_h
457
+ elif cell_position == "top_right": # Crop from bottom_left
458
+ x_min = 0
459
+ y_min = cell_h - crop_h
460
+ elif cell_position == "bottom_left": # Crop from top_right
461
+ x_min = cell_w - crop_w
462
+ y_min = 0
463
+ elif cell_position == "bottom_right": # Crop from top_left
464
+ x_min = 0
465
+ y_min = 0
466
+ elif cell_position == "center": # Crop from center
467
+ x_min = (cell_w - crop_w) // 2
468
+ y_min = (cell_h - crop_h) // 2
469
+ else:
470
+ # Should be unreachable due to Literal type hint, but good practice
471
+ raise ValueError(f"Invalid cell_position: {cell_position}")
472
+
473
+ # Calculate bottom-right corner
474
+ x_max = x_min + crop_w
475
+ y_max = y_min + crop_h
476
+
477
+ return x_min, y_min, x_max, y_max
478
+
479
+
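A worked example: for a 100x100 cell and a 60x60 crop, a cell sitting at the top-left of the canvas is cropped from its bottom-right corner:

from albumentations.augmentations.mixing.functional import get_opposite_crop_coords

coords = get_opposite_crop_coords(
    cell_size=(100, 100),
    crop_size=(60, 60),
    cell_position="top_left",
)
# (x_min, y_min, x_max, y_max) == (40, 40, 100, 100)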
480
+ def process_cell_geometry(
481
+ cell_shape: tuple[int, int],
482
+ item: ProcessedMosaicItem,
483
+ target_shape: tuple[int, int],
484
+ fill: float | tuple[float, ...],
485
+ fill_mask: float | tuple[float, ...],
486
+ fit_mode: Literal["cover", "contain"],
487
+ interpolation: int,
488
+ mask_interpolation: int,
489
+ cell_position: Literal["top_left", "top_right", "center", "bottom_left", "bottom_right"],
490
+ ) -> ProcessedMosaicItem:
491
+ """Applies geometric transformations (padding and/or cropping) to a single mosaic item.
492
+
493
+ Uses a Compose pipeline with PadIfNeeded and Crop to ensure the output
494
+ matches the target cell dimensions exactly, handling both padding and cropping cases.
495
+
496
+ Args:
497
+ cell_shape (tuple[int, int]): Shape of the cell.
498
+ item (ProcessedMosaicItem): The preprocessed mosaic item dictionary.
499
+ target_shape (tuple[int, int]): Target shape of the cell.
500
+ fill (float | tuple[float, ...]): Fill value for image padding.
501
+ fill_mask (float | tuple[float, ...]): Fill value for mask padding.
502
+ fit_mode (Literal["cover", "contain"]): Fit mode for the mosaic.
503
+ interpolation (int): Interpolation method for the image.
504
+ mask_interpolation (int): Interpolation method for the mask.
505
+ cell_position (Literal["top_left", "top_right", "center", "bottom_left", "bottom_right"]): Position
506
+ of the cell.
507
+
508
+ Returns:
509
+ ProcessedMosaicItem: Dictionary containing the geometrically processed image,
+ mask, bboxes, and keypoints, fitting the target dimensions.
510
+
511
+ """
512
+ # Define the pipeline: PadIfNeeded first, then Crop
513
+ compose_kwargs: dict[str, Any] = {"p": 1.0}
514
+ if item.get("bboxes") is not None:
515
+ compose_kwargs["bbox_params"] = {"format": "albumentations"}
516
+ if item.get("keypoints") is not None:
517
+ compose_kwargs["keypoint_params"] = {"format": "albumentations"}
518
+
519
+ crop_coords = get_opposite_crop_coords(cell_shape, target_shape, cell_position)
520
+
521
+ if fit_mode == "cover":
522
+ geom_pipeline = Compose(
523
+ [
524
+ SmallestMaxSize(
525
+ max_size_hw=cell_shape,
526
+ interpolation=interpolation,
527
+ mask_interpolation=mask_interpolation,
528
+ p=1.0,
529
+ ),
530
+ Crop(
531
+ x_min=crop_coords[0],
532
+ y_min=crop_coords[1],
533
+ x_max=crop_coords[2],
534
+ y_max=crop_coords[3],
535
+ ),
536
+ ],
537
+ **compose_kwargs,
538
+ )
539
+ elif fit_mode == "contain":
540
+ geom_pipeline = Compose(
541
+ [
542
+ LongestMaxSize(
543
+ max_size_hw=cell_shape,
544
+ interpolation=interpolation,
545
+ mask_interpolation=mask_interpolation,
546
+ p=1.0,
547
+ ),
548
+ Crop(
549
+ x_min=crop_coords[0],
550
+ y_min=crop_coords[1],
551
+ x_max=crop_coords[2],
552
+ y_max=crop_coords[3],
553
+ pad_if_needed=True,
554
+ fill=fill,
555
+ fill_mask=fill_mask,
556
+ p=1.0,
557
+ ),
558
+ ],
559
+ **compose_kwargs,
560
+ )
561
+ else:
562
+ raise ValueError(f"Invalid fit_mode: {fit_mode}. Must be 'cover' or 'contain'.")
563
+
564
+ # Prepare input data for the pipeline
565
+ geom_input = {"image": item["image"]}
566
+ if item.get("mask") is not None:
567
+ geom_input["mask"] = item["mask"]
568
+ if item.get("bboxes") is not None:
569
+ # Compose expects bboxes in a specific format, ensure it's compatible
570
+ # Assuming item['bboxes'] is already preprocessed correctly
571
+ geom_input["bboxes"] = item["bboxes"]
572
+ if item.get("keypoints") is not None:
573
+ geom_input["keypoints"] = item["keypoints"]
574
+
575
+ # Apply the pipeline
576
+ processed_item = geom_pipeline(**geom_input)
577
+
578
+ # Ensure output dict has the same structure as ProcessedMosaicItem
579
+ # Compose might not return None for missing keys, handle explicitly
580
+ return {
581
+ "image": processed_item["image"],
582
+ "mask": processed_item.get("mask"),
583
+ "bboxes": processed_item.get("bboxes"),
584
+ "keypoints": processed_item.get("keypoints"),
585
+ }
586
+
587
+
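A rough usage sketch of process_cell_geometry in "cover" mode with an image-only item (shapes chosen for illustration): the 80x120 image is scaled up until it covers the 100x100 cell, then the crop takes the 60x60 window opposite the cell's position:

import cv2
import numpy as np

from albumentations.augmentations.mixing.functional import process_cell_geometry

item = {
    "image": np.zeros((80, 120, 3), dtype=np.uint8),
    "mask": None,
    "bboxes": None,
    "keypoints": None,
}
out = process_cell_geometry(
    cell_shape=(100, 100),
    item=item,
    target_shape=(60, 60),
    fill=0,
    fill_mask=0,
    fit_mode="cover",
    interpolation=cv2.INTER_LINEAR,
    mask_interpolation=cv2.INTER_NEAREST,
    cell_position="top_left",
)
# out["image"].shape == (60, 60, 3); the crop comes from the bottom-right of the
# resized image because "top_left" maps to the opposite corner.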
588
+ def shift_cell_coordinates(
589
+ processed_item_geom: ProcessedMosaicItem,
590
+ placement_coords: tuple[int, int, int, int],
591
+ ) -> ProcessedMosaicItem:
592
+ """Shifts the coordinates of geometrically processed bboxes and keypoints.
593
+
594
+ Args:
595
+ processed_item_geom (ProcessedMosaicItem): The output from process_cell_geometry.
596
+ placement_coords (tuple[int, int, int, int]): The (x1, y1, x2, y2) placement on the final canvas.
597
+
598
+ Returns:
599
+ ProcessedMosaicItem: The cell data with 'bboxes' and 'keypoints' shifted by the placement
+ offset (None when absent); 'image' and 'mask' are passed through unchanged.
600
+
601
+ """
602
+ tgt_x1, tgt_y1, _, _ = placement_coords
603
+
604
+ shifted_bboxes = None
605
+ shifted_keypoints = None
606
+
607
+ bboxes_geom = processed_item_geom.get("bboxes")
608
+ if bboxes_geom is not None and np.asarray(bboxes_geom).size > 0:
609
+ bboxes_geom_arr = np.asarray(bboxes_geom) # Ensure it's an array
610
+ bbox_shift_vector = np.array([tgt_x1, tgt_y1, tgt_x1, tgt_y1], dtype=np.int32)
611
+ shifted_bboxes = fgeometric.shift_bboxes(bboxes_geom_arr, bbox_shift_vector)
612
+
613
+ keypoints_geom = processed_item_geom.get("keypoints")
614
+ if keypoints_geom is not None and np.asarray(keypoints_geom).size > 0:
615
+ keypoints_geom_arr = np.asarray(keypoints_geom) # Ensure it's an array
616
+ kp_shift_vector = np.array([tgt_x1, tgt_y1, 0], dtype=keypoints_geom_arr.dtype)
617
+ shifted_keypoints = fgeometric.shift_keypoints(keypoints_geom_arr, kp_shift_vector)
618
+
619
+ return {
620
+ "bboxes": shifted_bboxes,
621
+ "keypoints": shifted_keypoints,
622
+ "image": processed_item_geom["image"],
623
+ "mask": processed_item_geom.get("mask"),
624
+ }
625
+
626
+
627
+ def assemble_mosaic_from_processed_cells(
628
+ processed_cells: dict[tuple[int, int, int, int], dict[str, Any]],
629
+ target_shape: tuple[int, ...], # Use full canvas shape (H, W) or (H, W, C)
630
+ dtype: np.dtype,
631
+ data_key: Literal["image", "mask"],
632
+ fill: float | tuple[float, ...] | None, # Value for image fill or mask fill
633
+ ) -> np.ndarray:
634
+ """Assembles the final mosaic image or mask from processed cell data onto a canvas.
635
+
636
+ Initializes the canvas with the fill value and overwrites with processed segments.
637
+ Handles potentially multi-channel masks.
638
+ Addresses potential broadcasting errors if mask segments have unexpected dimensions.
639
+ Assumes input data is valid and correctly sized.
640
+
641
+ Args:
642
+ processed_cells (dict[tuple[int, int, int, int], dict[str, Any]]): Dictionary mapping
643
+ placement coords to processed cell data.
644
+ target_shape (tuple[int, ...]): The target shape of the output canvas (e.g., (H, W) or (H, W, C)).
645
+ dtype (np.dtype): NumPy dtype for the canvas.
646
+ data_key (Literal["image", "mask"]): Specifies whether to assemble 'image' or 'mask'.
647
+ fill (float | tuple[float, ...] | None): Value used to initialize the canvas (image fill or mask fill).
648
+ Should be a float/int or a tuple matching the number of channels.
649
+ If None, defaults to 0.
650
+
651
+ Returns:
652
+ np.ndarray: The assembled mosaic canvas.
653
+
654
+ """
655
+ # Use 0 as default fill if None is provided
656
+ actual_fill = fill if fill is not None else 0
657
+
658
+ # Convert fill to numpy array to handle broadcasting in np.full
659
+ fill_value = np.array(actual_fill, dtype=dtype)
660
+ # Initialize canvas with the fill value.
661
+ # If fill_value shape is incompatible with target_shape, np.full will raise ValueError.
662
+ canvas = np.full(target_shape, fill_value=fill_value, dtype=dtype)
663
+
664
+ # Iterate and paste segments onto the pre-filled canvas
665
+ for placement_coords, cell_data in processed_cells.items():
666
+ segment = cell_data.get(data_key)
667
+
668
+ # If segment exists, paste it over the filled background
669
+ if segment is not None:
670
+ tgt_x1, tgt_y1, tgt_x2, tgt_y2 = placement_coords
671
+
672
+ canvas[tgt_y1:tgt_y2, tgt_x1:tgt_x2] = segment
673
+
674
+ return canvas
675
+
676
+
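A toy sketch of the assembly step: two processed cells are pasted onto a 2x4 canvas that was pre-filled with the fill value:

import numpy as np

from albumentations.augmentations.mixing.functional import assemble_mosaic_from_processed_cells

cells = {
    (0, 0, 2, 2): {"image": np.full((2, 2, 3), 255, dtype=np.uint8)},  # left half, white
    (2, 0, 4, 2): {"image": np.zeros((2, 2, 3), dtype=np.uint8)},      # right half, black
}
canvas = assemble_mosaic_from_processed_cells(
    cells,
    target_shape=(2, 4, 3),
    dtype=np.uint8,
    data_key="image",
    fill=128,
)
# canvas[:, :2] is all 255 and canvas[:, 2:] is all 0; the fill of 128 would only
# show where a placement had no segment for the requested key.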
677
+ def process_all_mosaic_geometries(
678
+ canvas_shape: tuple[int, int],
679
+ cell_shape: tuple[int, int],
680
+ placement_to_item_index: dict[tuple[int, int, int, int], int],
681
+ final_items_for_grid: list[ProcessedMosaicItem],
682
+ fill: float | tuple[float, ...],
683
+ fill_mask: float | tuple[float, ...],
684
+ fit_mode: Literal["cover", "contain"],
685
+ interpolation: Literal[
686
+ cv2.INTER_NEAREST,
687
+ cv2.INTER_NEAREST_EXACT,
688
+ cv2.INTER_LINEAR,
689
+ cv2.INTER_CUBIC,
690
+ cv2.INTER_AREA,
691
+ cv2.INTER_LANCZOS4,
692
+ cv2.INTER_LINEAR_EXACT,
693
+ ],
694
+ mask_interpolation: Literal[
695
+ cv2.INTER_NEAREST,
696
+ cv2.INTER_NEAREST_EXACT,
697
+ cv2.INTER_LINEAR,
698
+ cv2.INTER_CUBIC,
699
+ cv2.INTER_AREA,
700
+ cv2.INTER_LANCZOS4,
701
+ cv2.INTER_LINEAR_EXACT,
702
+ ],
703
+ ) -> dict[tuple[int, int, int, int], ProcessedMosaicItem]:
704
+ """Processes the geometry (cropping/padding) for all assigned mosaic cells.
705
+
706
+ Iterates through assigned placements, applies geometric transforms via process_cell_geometry,
707
+ and returns a dictionary mapping final placement coordinates to the processed item data.
708
+ The bbox/keypoint coordinates in the returned dict are *not* shifted yet.
709
+
710
+ Args:
711
+ canvas_shape (tuple[int, int]): The shape of the canvas.
712
+ cell_shape (tuple[int, int]): Shape of each cell in the mosaic grid.
713
+ placement_to_item_index (dict[tuple[int, int, int, int], int]): Mapping from placement
714
+ coordinates (x1, y1, x2, y2) to assigned item index.
715
+ final_items_for_grid (list[ProcessedMosaicItem]): List of all preprocessed items available.
716
+ fill (float | tuple[float, ...]): Fill value for image padding.
717
+ fill_mask (float | tuple[float, ...]): Fill value for mask padding.
718
+ fit_mode (Literal["cover", "contain"]): Fit mode for the mosaic.
719
+ interpolation (int): Interpolation method for image.
720
+ mask_interpolation (int): Interpolation method for mask.
721
+
722
+ Returns:
723
+ dict[tuple[int, int, int, int], ProcessedMosaicItem]: Dictionary mapping final placement
724
+ coordinates (x1, y1, x2, y2) to the geometrically processed item data (image, mask, un-shifted bboxes/kps).
725
+
726
+ """
727
+ processed_cells_geom: dict[tuple[int, int, int, int], ProcessedMosaicItem] = {}
728
+
729
+ # Iterate directly over placements and their assigned item indices
730
+ for placement_coords, item_idx in placement_to_item_index.items():
731
+ item = final_items_for_grid[item_idx]
732
+ tgt_x1, tgt_y1, tgt_x2, tgt_y2 = placement_coords
733
+ target_h = tgt_y2 - tgt_y1
734
+ target_w = tgt_x2 - tgt_x1
735
+
736
+ cell_position = get_cell_relative_position(placement_coords, canvas_shape)
737
+
738
+ # Apply geometric processing (crop/pad)
739
+ processed_cells_geom[placement_coords] = process_cell_geometry(
740
+ cell_shape=cell_shape,
741
+ item=item,
742
+ target_shape=(target_h, target_w),
743
+ fill=fill,
744
+ fill_mask=fill_mask,
745
+ fit_mode=fit_mode,
746
+ interpolation=interpolation,
747
+ mask_interpolation=mask_interpolation,
748
+ cell_position=cell_position,
749
+ )
750
+
751
+ return processed_cells_geom
752
+
753
+
754
+ def get_cell_relative_position(
755
+ placement_coords: tuple[int, int, int, int],
756
+ target_shape: tuple[int, int],
757
+ ) -> Literal["top_left", "top_right", "center", "bottom_left", "bottom_right"]:
758
+ """Determines the position of a cell relative to the center of the target canvas.
759
+
760
+ Compares the cell center to the canvas center and returns its quadrant
761
+ or "center" if it lies on or very close to a central axis.
762
+
763
+ Args:
764
+ placement_coords (tuple[int, int, int, int]): The (x_min, y_min, x_max, y_max) coordinates
765
+ of the cell.
766
+ target_shape (tuple[int, int]): The (height, width) of the overall target canvas.
767
+
768
+ Returns:
769
+ Literal["top_left", "top_right", "center", "bottom_left", "bottom_right"]:
770
+ The position of the cell relative to the center of the target canvas.
771
+
772
+ """
773
+ target_h, target_w = target_shape
774
+ x1, y1, x2, y2 = placement_coords
775
+
776
+ canvas_center_x = target_w / 2.0
777
+ canvas_center_y = target_h / 2.0
778
+
779
+ cell_center_x = (x1 + x2) / 2.0
780
+ cell_center_y = (y1 + y2) / 2.0
781
+
782
+ # Determine vertical position
783
+ if cell_center_y < canvas_center_y:
784
+ v_pos = "top"
785
+ elif cell_center_y > canvas_center_y:
786
+ v_pos = "bottom"
787
+ else: # Exactly on the horizontal center line
788
+ v_pos = "center"
789
+
790
+ # Determine horizontal position
791
+ if cell_center_x < canvas_center_x:
792
+ h_pos = "left"
793
+ elif cell_center_x > canvas_center_x:
794
+ h_pos = "right"
795
+ else: # Exactly on the vertical center line
796
+ h_pos = "center"
797
+
798
+ # Map positions to the final string
799
+ position_map = {
800
+ ("top", "left"): "top_left",
801
+ ("top", "right"): "top_right",
802
+ ("bottom", "left"): "bottom_left",
803
+ ("bottom", "right"): "bottom_right",
804
+ }
805
+
806
+ # Default to "center" if the combination is not in the map
807
+ # (which happens if either v_pos or h_pos is "center")
808
+ return cast(
809
+ "Literal['top_left', 'top_right', 'center', 'bottom_left', 'bottom_right']",
810
+ position_map.get((v_pos, h_pos), "center"),
811
+ )
812
+
813
+
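For instance, a 40x40 cell anchored at the canvas origin lies above and to the left of the centre of a 100x100 canvas:

from albumentations.augmentations.mixing.functional import get_cell_relative_position

pos = get_cell_relative_position((0, 0, 40, 40), target_shape=(100, 100))
# Cell centre (20, 20) vs canvas centre (50, 50) -> "top_left"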
814
+ def shift_all_coordinates(
815
+ processed_cells_geom: dict[tuple[int, int, int, int], ProcessedMosaicItem],
816
+ canvas_shape: tuple[int, int],
817
+ ) -> dict[tuple[int, int, int, int], ProcessedMosaicItem]: # Return type matches input, but values are updated
818
+ """Shifts coordinates for all geometrically processed cells.
819
+
820
+ Iterates through the processed cells (keyed by placement coords), applies coordinate
821
+ shifting to bboxes/keypoints, and returns a new dictionary with the same keys
822
+ but updated ProcessedMosaicItem values containing the *shifted* coordinates.
823
+
824
+ Args:
825
+ processed_cells_geom (dict[tuple[int, int, int, int], ProcessedMosaicItem]):
826
+ Output from process_all_mosaic_geometries (keyed by placement coords).
827
+ canvas_shape (tuple[int, int]): The shape of the canvas.
828
+
829
+ Returns:
830
+ dict[tuple[int, int, int, int], ProcessedMosaicItem]: Final dictionary mapping
831
+ placement coords (x1, y1, x2, y2) to processed cell data with shifted coordinates.
832
+
833
+ """
834
+ final_processed_cells: dict[tuple[int, int, int, int], ProcessedMosaicItem] = {}
835
+ canvas_h, canvas_w = canvas_shape
836
+
837
+ for placement_coords, cell_data_geom in processed_cells_geom.items():
838
+ tgt_x1, tgt_y1 = placement_coords[:2]
839
+
840
+ cell_width = placement_coords[2] - placement_coords[0]
841
+ cell_height = placement_coords[3] - placement_coords[1]
842
+
843
+ # Extract geometrically processed bboxes/keypoints
844
+ bboxes_geom = cell_data_geom.get("bboxes")
845
+ keypoints_geom = cell_data_geom.get("keypoints")
846
+
847
+ final_cell_data = {
848
+ "image": cell_data_geom["image"],
849
+ "mask": cell_data_geom.get("mask"),
850
+ }
851
+
852
+ # Perform shifting if data exists
853
+ if bboxes_geom is not None and bboxes_geom.size > 0:
854
+ bboxes_geom_arr = np.asarray(bboxes_geom)
855
+ bbox_denormalized = denormalize_bboxes(bboxes_geom_arr, {"height": cell_height, "width": cell_width})
856
+ bbox_shift_vector = np.array([tgt_x1, tgt_y1, tgt_x1, tgt_y1], dtype=np.float32)
857
+
858
+ shifted_bboxes_denormalized = fgeometric.shift_bboxes(bbox_denormalized, bbox_shift_vector)
859
+ shifted_bboxes = normalize_bboxes(shifted_bboxes_denormalized, {"height": canvas_h, "width": canvas_w})
860
+ final_cell_data["bboxes"] = shifted_bboxes
861
+ else:
862
+ final_cell_data["bboxes"] = np.empty((0, NUM_BBOXES_COLUMNS_IN_ALBUMENTATIONS))
863
+
864
+ if keypoints_geom is not None and keypoints_geom.size > 0:
865
+ keypoints_geom_arr = np.asarray(keypoints_geom)
866
+
867
+ # Ensure shift vector matches keypoint dtype (usually float)
868
+ kp_shift_vector = np.array([tgt_x1, tgt_y1, 0], dtype=keypoints_geom_arr.dtype)
869
+
870
+ shifted_keypoints = fgeometric.shift_keypoints(keypoints_geom_arr, kp_shift_vector)
871
+
872
+ final_cell_data["keypoints"] = shifted_keypoints
873
+ else:
874
+ final_cell_data["keypoints"] = np.empty((0, NUM_KEYPOINTS_COLUMNS_IN_ALBUMENTATIONS))
875
+
876
+ final_processed_cells[placement_coords] = cast("ProcessedMosaicItem", final_cell_data)
877
+
878
+ return final_processed_cells
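Putting the pieces together, a minimal end-to-end sketch of the mosaic flow this module implements (image-only items, no annotations), assuming the module is importable as albumentations.augmentations.mixing.functional from this wheel:

import random

import cv2
import numpy as np

from albumentations.augmentations.mixing import functional as fmixing

rng = random.Random(0)
grid_yx, cell_shape, target_size = (2, 2), (100, 100), (100, 100)

# 1. Sample the mosaic center and derive the cell placements on the output canvas.
center = fmixing.calculate_mosaic_center_point(grid_yx, cell_shape, target_size, (0.4, 0.6), rng)
placements = fmixing.calculate_cell_placements(grid_yx, cell_shape, target_size, center)

# 2. Build image-only items and assign them to the placements (primary gets the largest cell).
items = [
    {"image": np.full((80, 120, 3), 60 * i, dtype=np.uint8), "mask": None, "bboxes": None, "keypoints": None}
    for i in range(4)
]
assignment = fmixing.assign_items_to_grid_cells(len(items), placements, rng)

# 3. Resize/crop each item to its cell, shift annotations (none here), and assemble the canvas.
cells_geom = fmixing.process_all_mosaic_geometries(
    canvas_shape=target_size,
    cell_shape=cell_shape,
    placement_to_item_index=assignment,
    final_items_for_grid=items,
    fill=0,
    fill_mask=0,
    fit_mode="cover",
    interpolation=cv2.INTER_LINEAR,
    mask_interpolation=cv2.INTER_NEAREST,
)
cells = fmixing.shift_all_coordinates(cells_geom, canvas_shape=target_size)
mosaic = fmixing.assemble_mosaic_from_processed_cells(
    cells, target_shape=(*target_size, 3), dtype=np.uint8, data_key="image", fill=0,
)
assert mosaic.shape == (100, 100, 3)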