nrtk-albumentations 2.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of nrtk-albumentations might be problematic. Click here for more details.
- albumentations/__init__.py +21 -0
- albumentations/augmentations/__init__.py +23 -0
- albumentations/augmentations/blur/__init__.py +0 -0
- albumentations/augmentations/blur/functional.py +438 -0
- albumentations/augmentations/blur/transforms.py +1633 -0
- albumentations/augmentations/crops/__init__.py +0 -0
- albumentations/augmentations/crops/functional.py +494 -0
- albumentations/augmentations/crops/transforms.py +3647 -0
- albumentations/augmentations/dropout/__init__.py +0 -0
- albumentations/augmentations/dropout/channel_dropout.py +134 -0
- albumentations/augmentations/dropout/coarse_dropout.py +567 -0
- albumentations/augmentations/dropout/functional.py +1017 -0
- albumentations/augmentations/dropout/grid_dropout.py +166 -0
- albumentations/augmentations/dropout/mask_dropout.py +274 -0
- albumentations/augmentations/dropout/transforms.py +461 -0
- albumentations/augmentations/dropout/xy_masking.py +186 -0
- albumentations/augmentations/geometric/__init__.py +0 -0
- albumentations/augmentations/geometric/distortion.py +1238 -0
- albumentations/augmentations/geometric/flip.py +752 -0
- albumentations/augmentations/geometric/functional.py +4151 -0
- albumentations/augmentations/geometric/pad.py +676 -0
- albumentations/augmentations/geometric/resize.py +956 -0
- albumentations/augmentations/geometric/rotate.py +864 -0
- albumentations/augmentations/geometric/transforms.py +1962 -0
- albumentations/augmentations/mixing/__init__.py +0 -0
- albumentations/augmentations/mixing/domain_adaptation.py +787 -0
- albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
- albumentations/augmentations/mixing/functional.py +878 -0
- albumentations/augmentations/mixing/transforms.py +832 -0
- albumentations/augmentations/other/__init__.py +0 -0
- albumentations/augmentations/other/lambda_transform.py +180 -0
- albumentations/augmentations/other/type_transform.py +261 -0
- albumentations/augmentations/pixel/__init__.py +0 -0
- albumentations/augmentations/pixel/functional.py +4226 -0
- albumentations/augmentations/pixel/transforms.py +7556 -0
- albumentations/augmentations/spectrogram/__init__.py +0 -0
- albumentations/augmentations/spectrogram/transform.py +220 -0
- albumentations/augmentations/text/__init__.py +0 -0
- albumentations/augmentations/text/functional.py +272 -0
- albumentations/augmentations/text/transforms.py +299 -0
- albumentations/augmentations/transforms3d/__init__.py +0 -0
- albumentations/augmentations/transforms3d/functional.py +393 -0
- albumentations/augmentations/transforms3d/transforms.py +1422 -0
- albumentations/augmentations/utils.py +249 -0
- albumentations/core/__init__.py +0 -0
- albumentations/core/bbox_utils.py +920 -0
- albumentations/core/composition.py +1885 -0
- albumentations/core/hub_mixin.py +299 -0
- albumentations/core/keypoints_utils.py +521 -0
- albumentations/core/label_manager.py +339 -0
- albumentations/core/pydantic.py +239 -0
- albumentations/core/serialization.py +352 -0
- albumentations/core/transforms_interface.py +976 -0
- albumentations/core/type_definitions.py +127 -0
- albumentations/core/utils.py +605 -0
- albumentations/core/validation.py +129 -0
- albumentations/pytorch/__init__.py +1 -0
- albumentations/pytorch/transforms.py +189 -0
- nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
- nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
- nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
- nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,878 @@
|
|
|
1
|
+
"""Functional implementations for image mixing operations.
|
|
2
|
+
|
|
3
|
+
This module provides utility functions for blending and combining images,
|
|
4
|
+
such as copy-and-paste operations with masking.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import random
|
|
10
|
+
from collections.abc import Sequence
|
|
11
|
+
from typing import Any, Literal, TypedDict, cast
|
|
12
|
+
from warnings import warn
|
|
13
|
+
|
|
14
|
+
import cv2
|
|
15
|
+
import numpy as np
|
|
16
|
+
|
|
17
|
+
import albumentations.augmentations.geometric.functional as fgeometric
|
|
18
|
+
from albumentations.augmentations.crops.transforms import Crop
|
|
19
|
+
from albumentations.augmentations.geometric.resize import LongestMaxSize, SmallestMaxSize
|
|
20
|
+
from albumentations.core.bbox_utils import BboxProcessor, denormalize_bboxes, normalize_bboxes
|
|
21
|
+
from albumentations.core.composition import Compose
|
|
22
|
+
from albumentations.core.keypoints_utils import KeypointsProcessor
|
|
23
|
+
from albumentations.core.type_definitions import (
|
|
24
|
+
NUM_BBOXES_COLUMNS_IN_ALBUMENTATIONS,
|
|
25
|
+
NUM_KEYPOINTS_COLUMNS_IN_ALBUMENTATIONS,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Type definition for a processed mosaic item
|
|
30
|
+
class ProcessedMosaicItem(TypedDict):
    """Represents a single data item (primary or additional) after preprocessing.

    Includes the original image/mask and the *preprocessed* annotations.
    Every key is always present; the optional targets hold ``None`` when the
    item does not carry them.
    """

    image: np.ndarray  # Image is mandatory
    mask: np.ndarray | None  # None when the item has no mask
    bboxes: np.ndarray | None  # Preprocessed bounding boxes, or None when absent
    keypoints: np.ndarray | None  # Preprocessed keypoints, or None when absent
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def copy_and_paste_blend(
    base_image: np.ndarray,
    overlay_image: np.ndarray,
    overlay_mask: np.ndarray,
    offset: tuple[int, int],
) -> np.ndarray:
    """Blend images by copying pixels from an overlay image to a base image using a mask.

    This function copies pixels from the overlay image to the base image only where
    the mask has non-zero values. The overlay is placed at the specified offset
    from the top-left corner of the base image. Masked pixels whose destination
    falls outside the base image are skipped, so an overlay that overhangs an
    edge (or is placed with a negative offset) is clipped instead of wrapping
    around to the opposite side or raising ``IndexError``.

    Args:
        base_image (np.ndarray): The destination image. It is not modified; a copy is returned.
        overlay_image (np.ndarray): The source image containing pixels to copy.
        overlay_mask (np.ndarray): Binary mask indicating which pixels to copy from the overlay.
            Pixels are copied where mask > 0.
        offset (tuple[int, int]): The (y, x) offset specifying where to place the
            top-left corner of the overlay relative to the base image.

    Returns:
        np.ndarray: The blended image with the overlay applied to the base image.

    """
    y_offset, x_offset = offset

    blended_image = base_image.copy()
    mask_indices = np.where(overlay_mask > 0)
    src_rows, src_cols = mask_indices[0], mask_indices[1]

    dst_rows = src_rows + y_offset
    dst_cols = src_cols + x_offset

    # NumPy integer indexing treats negative indices as "count from the end",
    # so destinations must be range-checked explicitly before the assignment.
    base_h, base_w = blended_image.shape[:2]
    in_bounds = (dst_rows >= 0) & (dst_rows < base_h) & (dst_cols >= 0) & (dst_cols < base_w)

    blended_image[dst_rows[in_bounds], dst_cols[in_bounds]] = overlay_image[
        src_rows[in_bounds],
        src_cols[in_bounds],
    ]
    return blended_image
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def calculate_mosaic_center_point(
    grid_yx: tuple[int, int],
    cell_shape: tuple[int, int],
    target_size: tuple[int, int],
    center_range: tuple[float, float],
    py_random: random.Random,
) -> tuple[int, int]:
    """Pick the center of the mosaic crop window inside the conceptual large grid.

    For each axis the admissible center positions are those that keep a window
    of the target extent inside the large grid. A fraction is drawn uniformly
    from ``center_range`` and mapped proportionally onto that admissible span;
    the result is clamped to the span's inclusive bounds.

    Args:
        grid_yx (tuple[int, int]): The (rows, cols) of the mosaic grid.
        cell_shape (tuple[int, int]): The (height, width) of each grid cell.
        target_size (tuple[int, int]): The final output (height, width).
        center_range (tuple[float, float]): Bounds of the uniform draw that positions
            the center proportionally within the admissible span.
        py_random (random.Random): Random state instance. It is consumed twice,
            first for the x axis and then for the y axis.

    Returns:
        tuple[int, int]: The (x, y) center point relative to the top-left of the
            conceptual large grid.

    """
    rows, cols = grid_yx
    cell_h, cell_w = cell_shape
    target_h, target_w = target_size

    def _axis_bounds(num_cells: int, cell_len: int, window_len: int) -> tuple[int, int]:
        # Inclusive range of centers for which the window stays inside the grid.
        return window_len // 2, num_cells * cell_len - (window_len + 1) // 2

    low_x, high_x = _axis_bounds(cols, cell_w, target_w)
    low_y, high_y = _axis_bounds(rows, cell_h, target_h)

    # Draw order matters for reproducibility: x first, then y.
    fraction_x = py_random.uniform(*center_range)
    fraction_y = py_random.uniform(*center_range)

    raw_x = low_x + int((high_x - low_x + 1) * fraction_x)
    raw_y = low_y + int((high_y - low_y + 1) * fraction_y)

    # Clamp: a fraction of 1.0 (or a degenerate span) would otherwise overshoot.
    return max(low_x, min(raw_x, high_x)), max(low_y, min(raw_y, high_y))
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def calculate_cell_placements(
    grid_yx: tuple[int, int],
    cell_shape: tuple[int, int],
    target_size: tuple[int, int],
    center_xy: tuple[int, int],
) -> list[tuple[int, int, int, int]]:
    """Split the crop window into the rectangles cut out by the mosaic grid lines.

    Grid lines of the large grid are clipped to the crop window, de-duplicated
    and re-expressed relative to the window origin; consecutive lines on each
    axis then delimit one placement rectangle.

    Args:
        grid_yx (tuple[int, int]): The (rows, cols) of the mosaic grid.
        cell_shape (tuple[int, int]): The (height, width) of each grid cell.
        target_size (tuple[int, int]): The final output (height, width).
        center_xy (tuple[int, int]): The (x, y) center of the crop window,
            relative to the top-left of the conceptual large grid.

    Returns:
        list[tuple[int, int, int, int]]:
            Placement coordinates `(x_min, y_min, x_max, y_max)` on the output
            canvas, ordered row by row (y outer, x inner).

    """
    rows, cols = grid_yx
    cell_h, cell_w = cell_shape
    target_h, target_w = target_size
    center_x, center_y = center_xy

    def _edges_in_window(num_cells: int, cell_len: int, window_start: int, window_len: int) -> np.ndarray:
        grid_lines = np.arange(num_cells + 1) * cell_len
        inside = np.clip(grid_lines, window_start, window_start + window_len)
        # np.unique both sorts and collapses lines that were clipped onto the
        # same window border; subtracting the start makes the window origin 0.
        return np.unique(inside) - window_start

    y_edges = _edges_in_window(rows, cell_h, center_y - target_h // 2, target_h)
    x_edges = _edges_in_window(cols, cell_w, center_x - target_w // 2, target_w)

    return [
        (int(x_lo), int(y_lo), int(x_hi), int(y_hi))
        for y_lo, y_hi in zip(y_edges[:-1], y_edges[1:])
        for x_lo, x_hi in zip(x_edges[:-1], x_edges[1:])
    ]
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def _check_data_compatibility(
    primary_data: np.ndarray | None,
    item_data: np.ndarray | None,
    data_key: Literal["image", "mask"],
) -> tuple[bool, str | None]:
    """Check that item_data has the same number of dimensions and channels as primary_data.

    Args:
        primary_data (np.ndarray | None): Reference array from the primary data, if any.
        item_data (np.ndarray | None): Array from the additional item, if any.
        data_key (Literal["image", "mask"]): Which target is being compared.

    Returns:
        tuple[bool, str | None]: ``(is_compatible, error_message)``; the message is
            ``None`` whenever the data is compatible.

    """
    if item_data is None:
        # Only the image is mandatory; an absent mask is acceptable.
        if data_key == "image":
            return False, "Item is missing required key 'image'"
        return True, None

    if primary_data is None:
        # Nothing to compare against.
        return True, None

    def _mismatch(quantity: str, item_value: int, primary_value: int) -> tuple[bool, str]:
        return False, (
            f"Item '{data_key}' has {item_value} {quantity}, but primary has {primary_value}. "
            f"Primary shape: {primary_data.shape}, Item shape: {item_data.shape}"
        )

    if primary_data.ndim != item_data.ndim:
        return _mismatch("dimensions", item_data.ndim, primary_data.ndim)

    # Channel count is only meaningful for 3D (H, W, C)-style arrays.
    if primary_data.ndim == 3 and primary_data.shape[-1] != item_data.shape[-1]:
        return _mismatch("channels", item_data.shape[-1], primary_data.shape[-1])

    return True, None
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def filter_valid_metadata(
    metadata_input: Sequence[dict[str, Any]] | None,
    metadata_key_name: str,
    data: dict[str, Any],
) -> list[dict[str, Any]]:
    """Keep only the metadata dicts whose image/mask are compatible with the primary data.

    A warning is emitted for every rejected entry (non-dict items and items whose
    'image' or 'mask' fails the compatibility check), and once if the input is
    not a Sequence at all.

    Args:
        metadata_input (Sequence[dict[str, Any]] | None): Candidate additional items.
        metadata_key_name (str): Name of the metadata key, used in warning messages.
        data (dict[str, Any]): Primary data; its 'image' and 'mask' entries are the reference.

    Returns:
        list[dict[str, Any]]: The accepted items, in their original order.

    """
    if not isinstance(metadata_input, Sequence):
        warn(
            f"Metadata under key '{metadata_key_name}' is not a Sequence (e.g., list or tuple). "
            f"Returning empty list for additional items.",
            UserWarning,
            stacklevel=3,
        )
        return []

    # Reference targets, checked in this order for every candidate.
    primary_targets = (("image", data.get("image")), ("mask", data.get("mask")))

    accepted = []
    for index, candidate in enumerate(metadata_input):
        if not isinstance(candidate, dict):
            warn(
                f"Item at index {index} in '{metadata_key_name}' is not a dict and will be skipped.",
                UserWarning,
                stacklevel=4,
            )
            continue

        for target_key, primary_target_data in primary_targets:
            is_compatible, error_msg = _check_data_compatibility(
                primary_target_data,
                candidate.get(target_key),
                cast("Literal['image', 'mask']", target_key),
            )
            if not is_compatible:
                warn(
                    f"Item at index {index} in '{metadata_key_name}' skipped due "
                    f"to incompatibility in '{target_key}': {error_msg}",
                    UserWarning,
                    stacklevel=4,
                )
                # First failing target rejects the item; skip the remaining ones.
                break
        else:
            accepted.append(candidate)

    return accepted
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def assign_items_to_grid_cells(
    num_items: int,
    cell_placements: list[tuple[int, int, int, int]],
    py_random: random.Random,
) -> dict[tuple[int, int, int, int], int]:
    """Map placement rectangles to item indices.

    The primary item (index 0) always receives the placement with the largest
    area. Indices 1..num_items-1 are shuffled and handed out to the other
    placements in their listed order; whichever of the two runs out first ends
    the assignment.

    Args:
        num_items (int): The total number of items to assign (primary + additional + replicas).
        cell_placements (list[tuple[int, int, int, int]]): Placement coords
            (x1, y1, x2, y2) for cells to be filled.
        py_random (random.Random): Random state instance.

    Returns:
        dict[tuple[int, int, int, int], int]: Placement coords (x1, y1, x2, y2)
            mapped to the assigned item index. Empty when there are no placements.

    """
    if not cell_placements:
        return {}

    def _area(box: tuple[int, int, int, int]) -> int:
        return (box[2] - box[0]) * (box[3] - box[1])

    primary_placement = max(cell_placements, key=_area)
    other_placements = [box for box in cell_placements if box != primary_placement]

    # Additional/replicated items are numbered from 1.
    other_indices = list(range(1, num_items))
    py_random.shuffle(other_indices)

    assignment: dict[tuple[int, int, int, int], int] = {primary_placement: 0}
    # zip stops at the shorter sequence, leaving surplus cells or items unassigned.
    assignment.update(zip(other_placements, other_indices))
    return assignment
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def _preprocess_item_annotations(
    item: dict[str, Any],
    processor: BboxProcessor | KeypointsProcessor | None,
    data_key: Literal["bboxes", "keypoints"],
) -> np.ndarray | None:
    """Preprocess one kind of annotation (bboxes or keypoints) for a single item.

    Args:
        item (dict[str, Any]): Raw item; must contain 'image' when the annotation is processed.
        processor (BboxProcessor | KeypointsProcessor | None): Processor to apply, if any.
        data_key (Literal["bboxes", "keypoints"]): Which annotation of the item to process.

    Returns:
        np.ndarray | None: The value stored under ``data_key`` after ``processor.preprocess``
            has run, or the item's original value (possibly ``None``) when there is no
            processor or the item has no such annotation.

    Raises:
        ValueError: If the item has the annotation but lacks any of the label fields
            declared in the processor's params.

    """
    original_data = item.get(data_key)

    # Nothing to do without a processor or without the annotation itself.
    if not processor or original_data is None:
        return original_data

    required_labels = processor.params.label_fields

    if required_labels:
        missing_labels = [field for field in required_labels if field not in item]
        if missing_labels:
            raise ValueError(
                f"Item contains '{data_key}' but is missing required label "
                f"fields: {missing_labels}. "
                f"Ensure all label fields declared in {type(processor.params).__name__} "
                f"({required_labels}) are present in the item dictionary when '{data_key}' is present.",
            )

    # Hand the processor a minimal dict so that preprocessing cannot touch
    # unrelated entries of the caller's item.
    temp_data = {
        "image": item["image"],
        data_key: original_data,
    }
    if required_labels:
        # All declared label fields are known to be present at this point.
        temp_data.update({field: item[field] for field in required_labels})

    # Preprocess modifies temp_data in-place.
    processor.preprocess(temp_data)
    return temp_data.get(data_key)
|
|
386
|
+
|
|
387
|
+
|
|
388
|
+
def preprocess_selected_mosaic_items(
    selected_raw_items: list[dict[str, Any]],
    bbox_processor: BboxProcessor | None,  # Allow None
    keypoint_processor: KeypointsProcessor | None,  # Allow None
) -> list[ProcessedMosaicItem]:
    """Preprocess the bboxes/keypoints of each selected raw item.

    Every item is handled on its own: its bboxes and then its keypoints are run
    through the matching processor, and the results are bundled with the item's
    untouched image and mask.

    Args:
        selected_raw_items (list[dict[str, Any]]): Raw items; each must contain 'image'.
        bbox_processor (BboxProcessor | None): Processor for 'bboxes', if any.
        keypoint_processor (KeypointsProcessor | None): Processor for 'keypoints', if any.

    Returns:
        list[ProcessedMosaicItem]: One processed item per input item, in order;
            an empty list when no items are given.

    """
    if not selected_raw_items:
        return []

    def _process_one(raw_item: dict[str, Any]) -> ProcessedMosaicItem:
        # Bboxes are processed before keypoints, matching the processors' call order.
        bboxes = _preprocess_item_annotations(raw_item, bbox_processor, "bboxes")
        keypoints = _preprocess_item_annotations(raw_item, keypoint_processor, "keypoints")
        return {
            "image": raw_item["image"],
            "mask": raw_item.get("mask"),
            "bboxes": bboxes,
            "keypoints": keypoints,
        }

    return [_process_one(raw_item) for raw_item in selected_raw_items]
|
|
418
|
+
|
|
419
|
+
|
|
420
|
+
def get_opposite_crop_coords(
    cell_size: tuple[int, int],
    crop_size: tuple[int, int],
    cell_position: Literal["top_left", "top_right", "center", "bottom_left", "bottom_right"],
) -> tuple[int, int, int, int]:
    """Compute crop coordinates located opposite to ``cell_position`` inside a cell.

    The crop of ``crop_size`` is anchored to the corner of the cell diagonally
    opposite to ``cell_position`` (e.g. "top_left" yields the bottom-right
    region). For "center" the crop is centered in the cell.

    Args:
        cell_size: The (height, width) of the cell from which to crop.
        crop_size: The (height, width) of the desired crop.
        cell_position: The reference position within the cell. The crop is
            taken from the opposite position.

    Returns:
        tuple[int, int, int, int]: (x_min, y_min, x_max, y_max) representing the crop coordinates.

    Raises:
        ValueError: If crop_size is larger than cell_size in either dimension,
            or if cell_position is not one of the supported values.

    """
    cell_h, cell_w = cell_size
    crop_h, crop_w = crop_size

    if crop_h > cell_h or crop_w > cell_w:
        raise ValueError(f"Crop size {crop_size} cannot be larger than cell size {cell_size}")

    # Free space left in the cell once the crop is removed, per axis.
    slack_x = cell_w - crop_w
    slack_y = cell_h - crop_h

    if cell_position == "center":
        x_min, y_min = slack_x // 2, slack_y // 2
    elif cell_position in ("top_left", "top_right", "bottom_left", "bottom_right"):
        # A "left" cell takes the crop from the right edge; a "top" cell from the bottom edge.
        x_min = slack_x if cell_position in ("top_left", "bottom_left") else 0
        y_min = slack_y if cell_position in ("top_left", "top_right") else 0
    else:
        # Should be unreachable due to the Literal type hint.
        raise ValueError(f"Invalid cell_position: {cell_position}")

    return x_min, y_min, x_min + crop_w, y_min + crop_h
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def process_cell_geometry(
    cell_shape: tuple[int, int],
    item: ProcessedMosaicItem,
    target_shape: tuple[int, int],
    fill: float | tuple[float, ...],
    fill_mask: float | tuple[float, ...],
    fit_mode: Literal["cover", "contain"],
    interpolation: int,
    mask_interpolation: int,
    cell_position: Literal["top_left", "top_right", "center", "bottom_left", "bottom_right"],
) -> ProcessedMosaicItem:
    """Resize and crop a single mosaic item so that it fits its target cell region.

    Builds a two-step Compose pipeline: a resize driven by ``cell_shape``
    (SmallestMaxSize for "cover", LongestMaxSize for "contain") followed by a
    Crop at the coordinates returned by ``get_opposite_crop_coords``. In
    "contain" mode the Crop is additionally allowed to pad, using ``fill`` and
    ``fill_mask``.

    Args:
        cell_shape: (tuple[int, int]): Shape of the cell.
        item: (ProcessedMosaicItem): The preprocessed mosaic item dictionary.
        target_shape: (tuple[int, int]): Target shape of the cell.
        fill: (float | tuple[float, ...]): Fill value for image padding ("contain" only).
        fill_mask: (float | tuple[float, ...]): Fill value for mask padding ("contain" only).
        fit_mode: (Literal["cover", "contain"]): Fit mode for the mosaic.
        interpolation: (int): Interpolation method for image.
        mask_interpolation: (int): Interpolation method for mask.
        cell_position: (Literal["top_left", "top_right", "center", "bottom_left", "bottom_right"]): Position
            of the cell.

    Returns:
        ProcessedMosaicItem: Dictionary with the pipeline's image, mask, bboxes and
            keypoints; targets the pipeline did not return are ``None``.

    Raises:
        ValueError: If ``fit_mode`` is neither 'cover' nor 'contain', or if
            ``get_opposite_crop_coords`` rejects the shapes/position.

    """
    # Annotation params are only declared for targets the item actually carries.
    annotation_params: dict[str, Any] = {
        param_name: {"format": "albumentations"}
        for target, param_name in (("bboxes", "bbox_params"), ("keypoints", "keypoint_params"))
        if item.get(target) is not None
    }

    x_min, y_min, x_max, y_max = get_opposite_crop_coords(cell_shape, target_shape, cell_position)
    resize_kwargs: dict[str, Any] = {
        "max_size_hw": cell_shape,
        "interpolation": interpolation,
        "mask_interpolation": mask_interpolation,
        "p": 1.0,
    }
    crop_kwargs: dict[str, Any] = {"x_min": x_min, "y_min": y_min, "x_max": x_max, "y_max": y_max}

    if fit_mode == "cover":
        resize_step = SmallestMaxSize(**resize_kwargs)
        crop_step = Crop(**crop_kwargs)
    elif fit_mode == "contain":
        resize_step = LongestMaxSize(**resize_kwargs)
        # Let the crop pad where the resized item does not reach the crop window.
        crop_step = Crop(**crop_kwargs, pad_if_needed=True, fill=fill, fill_mask=fill_mask, p=1.0)
    else:
        raise ValueError(f"Invalid fit_mode: {fit_mode}. Must be 'cover' or 'contain'.")

    geom_pipeline = Compose([resize_step, crop_step], p=1.0, **annotation_params)

    # Only pass the optional targets that are present on the item.
    geom_input = {"image": item["image"]}
    geom_input.update(
        {target: item[target] for target in ("mask", "bboxes", "keypoints") if item.get(target) is not None},
    )

    processed_item = geom_pipeline(**geom_input)

    # Normalise to the full ProcessedMosaicItem key set.
    return {
        "image": processed_item["image"],
        "mask": processed_item.get("mask"),
        "bboxes": processed_item.get("bboxes"),
        "keypoints": processed_item.get("keypoints"),
    }
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
def shift_cell_coordinates(
    processed_item_geom: ProcessedMosaicItem,
    placement_coords: tuple[int, int, int, int],
) -> ProcessedMosaicItem:
    """Translate an item's bboxes and keypoints by the top-left corner of its placement.

    Args:
        processed_item_geom: (ProcessedMosaicItem): The output from process_cell_geometry.
        placement_coords: (tuple[int, int, int, int]): The (x1, y1, x2, y2) placement on the final canvas.

    Returns:
        ProcessedMosaicItem: Dictionary with the shifted 'bboxes' and 'keypoints'
            (``None`` when the corresponding input is missing or empty) together
            with the unchanged 'image' and 'mask'.

    """
    tgt_x1, tgt_y1, _, _ = placement_coords

    def _non_empty_array(values: Any) -> np.ndarray | None:
        # Missing and empty annotations are treated alike: nothing to shift.
        if values is None:
            return None
        as_array = np.asarray(values)
        return as_array if as_array.size > 0 else None

    shifted_bboxes = None
    bbox_array = _non_empty_array(processed_item_geom.get("bboxes"))
    if bbox_array is not None:
        bbox_offsets = np.array([tgt_x1, tgt_y1, tgt_x1, tgt_y1], dtype=np.int32)
        shifted_bboxes = fgeometric.shift_bboxes(bbox_array, bbox_offsets)

    shifted_keypoints = None
    keypoint_array = _non_empty_array(processed_item_geom.get("keypoints"))
    if keypoint_array is not None:
        # Only x and y are translated; the third component receives a zero shift.
        keypoint_offsets = np.array([tgt_x1, tgt_y1, 0], dtype=keypoint_array.dtype)
        shifted_keypoints = fgeometric.shift_keypoints(keypoint_array, keypoint_offsets)

    return {
        "bboxes": shifted_bboxes,
        "keypoints": shifted_keypoints,
        "image": processed_item_geom["image"],
        "mask": processed_item_geom.get("mask"),
    }
|
|
625
|
+
|
|
626
|
+
|
|
627
|
+
def assemble_mosaic_from_processed_cells(
    processed_cells: dict[tuple[int, int, int, int], dict[str, Any]],
    target_shape: tuple[int, ...],  # Use full canvas shape (H, W) or (H, W, C)
    dtype: np.dtype,
    data_key: Literal["image", "mask"],
    fill: float | tuple[float, ...] | None,  # Value for image fill or mask fill
) -> np.ndarray:
    """Build the mosaic image or mask by pasting processed cells onto a pre-filled canvas.

    The canvas is created with the fill value, then every cell that has data under
    `data_key` is written into the region given by its placement coordinates. Segments are
    assigned as-is, so each one must already match (or broadcast to) its target region.

    Args:
        processed_cells (dict[tuple[int, int, int, int], dict[str, Any]]): Mapping from placement
            coords (x1, y1, x2, y2) to processed cell data.
        target_shape (tuple[int, ...]): Shape of the output canvas, e.g. (H, W) or (H, W, C).
        dtype (np.dtype): NumPy dtype of the canvas.
        data_key (Literal["image", "mask"]): Which entry of each cell to paste.
        fill (float | tuple[float, ...] | None): Value used to initialize the canvas. A scalar or a
            tuple that broadcasts against `target_shape`. None is treated as 0.

    Returns:
        np.ndarray: The assembled mosaic canvas.

    Raises:
        ValueError: If the fill value cannot be broadcast to `target_shape` (raised by np.full),
            or a segment cannot be broadcast to its placement region.

    """
    background = np.array(0 if fill is None else fill, dtype=dtype)
    mosaic = np.full(target_shape, background, dtype=dtype)

    for coords, cell in processed_cells.items():
        patch = cell.get(data_key)
        if patch is None:
            # Nothing to paste for this cell; the fill value stays visible.
            continue

        x_start, y_start, x_stop, y_stop = coords
        rows, cols = slice(y_start, y_stop), slice(x_start, x_stop)
        mosaic[rows, cols] = patch

    return mosaic
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
def process_all_mosaic_geometries(
    canvas_shape: tuple[int, int],
    cell_shape: tuple[int, int],
    placement_to_item_index: dict[tuple[int, int, int, int], int],
    final_items_for_grid: list[ProcessedMosaicItem],
    fill: float | tuple[float, ...],
    fill_mask: float | tuple[float, ...],
    fit_mode: Literal["cover", "contain"],
    interpolation: Literal[
        cv2.INTER_NEAREST,
        cv2.INTER_NEAREST_EXACT,
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_AREA,
        cv2.INTER_LANCZOS4,
        cv2.INTER_LINEAR_EXACT,
    ],
    mask_interpolation: Literal[
        cv2.INTER_NEAREST,
        cv2.INTER_NEAREST_EXACT,
        cv2.INTER_LINEAR,
        cv2.INTER_CUBIC,
        cv2.INTER_AREA,
        cv2.INTER_LANCZOS4,
        cv2.INTER_LINEAR_EXACT,
    ],
) -> dict[tuple[int, int, int, int], ProcessedMosaicItem]:
    """Run process_cell_geometry for every placement that has an item assigned.

    For each placement, the assigned item is looked up, the placement's height and width
    become the target shape, and the cell's position relative to the canvas center is
    computed via get_cell_relative_position. Bbox/keypoint coordinates in the results are
    *not* shifted to canvas coordinates here.

    Args:
        canvas_shape (tuple[int, int]): The shape of the canvas.
        cell_shape (tuple[int, int]): Shape of each cell in the mosaic grid.
        placement_to_item_index (dict[tuple[int, int, int, int], int]): Mapping from placement
            coordinates (x1, y1, x2, y2) to the index of the assigned item.
        final_items_for_grid (list[ProcessedMosaicItem]): All preprocessed items available.
        fill (float | tuple[float, ...]): Fill value for image padding.
        fill_mask (float | tuple[float, ...]): Fill value for mask padding.
        fit_mode (Literal["cover", "contain"]): Fit mode for the mosaic.
        interpolation (int): Interpolation method for the image.
        mask_interpolation (int): Interpolation method for the mask.

    Returns:
        dict[tuple[int, int, int, int], ProcessedMosaicItem]: Mapping from placement coordinates
            (x1, y1, x2, y2) to the result of process_cell_geometry for that placement.

    """
    results: dict[tuple[int, int, int, int], ProcessedMosaicItem] = {}

    for placement, index in placement_to_item_index.items():
        assigned_item = final_items_for_grid[index]
        left, top, right, bottom = placement

        # Settings shared by every cell, plus the per-cell geometry.
        results[placement] = process_cell_geometry(
            cell_shape=cell_shape,
            item=assigned_item,
            target_shape=(bottom - top, right - left),
            fill=fill,
            fill_mask=fill_mask,
            fit_mode=fit_mode,
            interpolation=interpolation,
            mask_interpolation=mask_interpolation,
            cell_position=get_cell_relative_position(placement, canvas_shape),
        )

    return results
|
|
752
|
+
|
|
753
|
+
|
|
754
|
+
def get_cell_relative_position(
    placement_coords: tuple[int, int, int, int],
    target_shape: tuple[int, int],
) -> Literal["top_left", "top_right", "center", "bottom_left", "bottom_right"]:
    """Classify a cell by where its center lies relative to the canvas center.

    The cell's center point is compared with the canvas center on each axis. A cell whose
    center is strictly on one side of both axes gets the matching quadrant; if its center
    falls exactly on either central axis the result is "center". The comparison is exact,
    with no tolerance.

    Args:
        placement_coords (tuple[int, int, int, int]): The (x_min, y_min, x_max, y_max) coordinates
            of the cell.
        target_shape (tuple[int, int]): The (height, width) of the overall target canvas.

    Returns:
        Literal["top_left", "top_right", "center", "bottom_left", "bottom_right"]:
            The position of the cell relative to the center of the target canvas.

    """
    canvas_h, canvas_w = target_shape
    left, top, right, bottom = placement_coords

    mid_x, mid_y = (left + right) / 2.0, (top + bottom) / 2.0
    half_w, half_h = canvas_w / 2.0, canvas_h / 2.0

    # Strict comparisons: a center lying on an axis is neither side of it.
    above, below = mid_y < half_h, mid_y > half_h
    on_left, on_right = mid_x < half_w, mid_x > half_w

    if above and on_left:
        return "top_left"
    if above and on_right:
        return "top_right"
    if below and on_left:
        return "bottom_left"
    if below and on_right:
        return "bottom_right"

    # The cell center sits on at least one of the central axes.
    return "center"
|
|
812
|
+
|
|
813
|
+
|
|
814
|
+
def shift_all_coordinates(
    processed_cells_geom: dict[tuple[int, int, int, int], ProcessedMosaicItem],
    canvas_shape: tuple[int, int],
) -> dict[tuple[int, int, int, int], ProcessedMosaicItem]:  # Return type matches input, but values are updated
    """Move every cell's bboxes and keypoints from cell space into canvas space.

    For each cell, bboxes are denormalized with the cell's own width/height, offset by the
    cell's top-left corner and normalized again with the canvas dimensions. Keypoints are
    offset by the same corner, leaving their third component untouched. Image and mask are
    passed through. Missing or empty annotations become empty arrays.

    Args:
        processed_cells_geom (dict[tuple[int, int, int, int], ProcessedMosaicItem]):
            Geometrically processed cells keyed by placement coords (x1, y1, x2, y2).
        canvas_shape (tuple[int, int]): The (height, width) of the canvas.

    Returns:
        dict[tuple[int, int, int, int], ProcessedMosaicItem]: A new dictionary with the same keys
            whose values contain "image", "mask", and the shifted "bboxes" and "keypoints".

    """
    canvas_h, canvas_w = canvas_shape
    shifted_cells: dict[tuple[int, int, int, int], ProcessedMosaicItem] = {}

    for placement_coords, cell in processed_cells_geom.items():
        offset_x, offset_y = placement_coords[:2]
        cell_w = placement_coords[2] - placement_coords[0]
        cell_h = placement_coords[3] - placement_coords[1]

        bboxes = cell.get("bboxes")
        keypoints = cell.get("keypoints")
        image, mask = cell["image"], cell.get("mask")

        has_bboxes = bboxes is not None and bboxes.size > 0
        if has_bboxes:
            # Work in pixel units of the cell, then re-express relative to the canvas.
            bboxes_px = denormalize_bboxes(np.asarray(bboxes), {"height": cell_h, "width": cell_w})
            bbox_offset = np.array([offset_x, offset_y, offset_x, offset_y], dtype=np.float32)
            bboxes_on_canvas = fgeometric.shift_bboxes(bboxes_px, bbox_offset)
            out_bboxes = normalize_bboxes(bboxes_on_canvas, {"height": canvas_h, "width": canvas_w})
        else:
            out_bboxes = np.empty((0, NUM_BBOXES_COLUMNS_IN_ALBUMENTATIONS))

        has_keypoints = keypoints is not None and keypoints.size > 0
        if has_keypoints:
            keypoints_arr = np.asarray(keypoints)
            # The offset vector adopts the keypoints' own dtype.
            kp_offset = np.array([offset_x, offset_y, 0], dtype=keypoints_arr.dtype)
            out_keypoints = fgeometric.shift_keypoints(keypoints_arr, kp_offset)
        else:
            out_keypoints = np.empty((0, NUM_KEYPOINTS_COLUMNS_IN_ALBUMENTATIONS))

        shifted_cells[placement_coords] = cast(
            "ProcessedMosaicItem",
            {
                "image": image,
                "mask": mask,
                "bboxes": out_bboxes,
                "keypoints": out_keypoints,
            },
        )

    return shifted_cells
|