nrtk-albumentations 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nrtk-albumentations might be problematic. Click here for more details.

Files changed (62) hide show
  1. albumentations/__init__.py +21 -0
  2. albumentations/augmentations/__init__.py +23 -0
  3. albumentations/augmentations/blur/__init__.py +0 -0
  4. albumentations/augmentations/blur/functional.py +438 -0
  5. albumentations/augmentations/blur/transforms.py +1633 -0
  6. albumentations/augmentations/crops/__init__.py +0 -0
  7. albumentations/augmentations/crops/functional.py +494 -0
  8. albumentations/augmentations/crops/transforms.py +3647 -0
  9. albumentations/augmentations/dropout/__init__.py +0 -0
  10. albumentations/augmentations/dropout/channel_dropout.py +134 -0
  11. albumentations/augmentations/dropout/coarse_dropout.py +567 -0
  12. albumentations/augmentations/dropout/functional.py +1017 -0
  13. albumentations/augmentations/dropout/grid_dropout.py +166 -0
  14. albumentations/augmentations/dropout/mask_dropout.py +274 -0
  15. albumentations/augmentations/dropout/transforms.py +461 -0
  16. albumentations/augmentations/dropout/xy_masking.py +186 -0
  17. albumentations/augmentations/geometric/__init__.py +0 -0
  18. albumentations/augmentations/geometric/distortion.py +1238 -0
  19. albumentations/augmentations/geometric/flip.py +752 -0
  20. albumentations/augmentations/geometric/functional.py +4151 -0
  21. albumentations/augmentations/geometric/pad.py +676 -0
  22. albumentations/augmentations/geometric/resize.py +956 -0
  23. albumentations/augmentations/geometric/rotate.py +864 -0
  24. albumentations/augmentations/geometric/transforms.py +1962 -0
  25. albumentations/augmentations/mixing/__init__.py +0 -0
  26. albumentations/augmentations/mixing/domain_adaptation.py +787 -0
  27. albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
  28. albumentations/augmentations/mixing/functional.py +878 -0
  29. albumentations/augmentations/mixing/transforms.py +832 -0
  30. albumentations/augmentations/other/__init__.py +0 -0
  31. albumentations/augmentations/other/lambda_transform.py +180 -0
  32. albumentations/augmentations/other/type_transform.py +261 -0
  33. albumentations/augmentations/pixel/__init__.py +0 -0
  34. albumentations/augmentations/pixel/functional.py +4226 -0
  35. albumentations/augmentations/pixel/transforms.py +7556 -0
  36. albumentations/augmentations/spectrogram/__init__.py +0 -0
  37. albumentations/augmentations/spectrogram/transform.py +220 -0
  38. albumentations/augmentations/text/__init__.py +0 -0
  39. albumentations/augmentations/text/functional.py +272 -0
  40. albumentations/augmentations/text/transforms.py +299 -0
  41. albumentations/augmentations/transforms3d/__init__.py +0 -0
  42. albumentations/augmentations/transforms3d/functional.py +393 -0
  43. albumentations/augmentations/transforms3d/transforms.py +1422 -0
  44. albumentations/augmentations/utils.py +249 -0
  45. albumentations/core/__init__.py +0 -0
  46. albumentations/core/bbox_utils.py +920 -0
  47. albumentations/core/composition.py +1885 -0
  48. albumentations/core/hub_mixin.py +299 -0
  49. albumentations/core/keypoints_utils.py +521 -0
  50. albumentations/core/label_manager.py +339 -0
  51. albumentations/core/pydantic.py +239 -0
  52. albumentations/core/serialization.py +352 -0
  53. albumentations/core/transforms_interface.py +976 -0
  54. albumentations/core/type_definitions.py +127 -0
  55. albumentations/core/utils.py +605 -0
  56. albumentations/core/validation.py +129 -0
  57. albumentations/pytorch/__init__.py +1 -0
  58. albumentations/pytorch/transforms.py +189 -0
  59. nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
  60. nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
  61. nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
  62. nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,453 @@
1
+ """Functional implementations for domain adaptation image transformations.
2
+
3
+ This module provides low-level functions and classes for performing domain adaptation
4
+ between images. It includes implementations for histogram matching, Fourier domain adaptation,
5
+ and pixel distribution matching with various normalization techniques.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import abc
11
+ from copy import deepcopy
12
+ from typing import Literal
13
+
14
+ import cv2
15
+ import numpy as np
16
+ from albucore import add_weighted, clip, clipped, from_float, get_num_channels, preserve_channel_dim, to_float, uint8_io
17
+ from typing_extensions import Protocol
18
+
19
+ import albumentations.augmentations.geometric.functional as fgeometric
20
+ from albumentations.augmentations.utils import PCA
21
+ from albumentations.core.type_definitions import MONO_CHANNEL_DIMENSIONS
22
+
23
# Names exported by ``from <this module> import *``. The scaler classes,
# DomainAdapter, low_freq_mutate, match_histograms and _match_cumulative_cdf
# are defined in this module but are not part of this list.
__all__ = ["adapt_pixel_distribution", "apply_histogram", "fourier_domain_adaptation"]
28
+
29
+
30
class BaseScaler:
    """Minimal scaler interface: ``fit``, ``transform``, ``fit_transform``, ``inverse_transform``.

    Subclasses must override ``fit``, ``transform`` and ``inverse_transform``;
    ``fit_transform`` is implemented here as ``fit`` followed by ``transform``.
    Every statistic attribute starts as ``None`` until a subclass's ``fit`` sets it.
    """

    def __init__(self) -> None:
        # Statistics slots; which of them a subclass fills in depends on the subclass.
        self.data_min: np.ndarray | None = None
        self.data_max: np.ndarray | None = None
        self.mean: np.ndarray | None = None
        self.var: np.ndarray | None = None
        self.scale: np.ndarray | None = None

    def fit(self, x: np.ndarray) -> None:
        """Learn scaling statistics from ``x``. Must be overridden by subclasses."""
        raise NotImplementedError

    def transform(self, x: np.ndarray) -> np.ndarray:
        """Scale ``x`` with previously fitted statistics. Must be overridden by subclasses."""
        raise NotImplementedError

    def fit_transform(self, x: np.ndarray) -> np.ndarray:
        """Fit on ``x``, then return ``x`` transformed with the freshly fitted statistics."""
        self.fit(x)
        transformed = self.transform(x)
        return transformed

    def inverse_transform(self, x: np.ndarray) -> np.ndarray:
        """Map scaled values back to the original space. Must be overridden by subclasses."""
        raise NotImplementedError
50
+
51
+
52
class MinMaxScaler(BaseScaler):
    """Linearly rescale each feature of the data into ``feature_range``.

    Statistics are reduced over axis 0, so for 2-D ``(n_samples, n_features)`` input
    each column is scaled independently.

    Args:
        feature_range: Desired ``(min, max)`` of the transformed data.
    """

    def __init__(self, feature_range: tuple[float, float] = (0.0, 1.0)) -> None:
        super().__init__()
        self.min: float = feature_range[0]
        self.max: float = feature_range[1]
        self.data_range: np.ndarray | None = None

    def _check_is_fitted(self) -> None:
        """Raise ``ValueError`` unless ``fit`` has populated the statistics."""
        if self.data_min is None or self.data_max is None or self.data_range is None:
            raise ValueError(
                "This MinMaxScaler instance is not fitted yet. "
                "Call 'fit' with appropriate arguments before using this estimator.",
            )

    def fit(self, x: np.ndarray) -> None:
        """Record the per-feature minimum, maximum and range of ``x`` (reduced over axis 0)."""
        self.data_min = np.min(x, axis=0)
        self.data_max = np.max(x, axis=0)
        data_range = self.data_max - self.data_min
        # A constant feature has zero range; substitute 1 so transform() never divides
        # by zero. np.where is used instead of masked item assignment because for 1-D
        # input np.min/np.max return NumPy scalars, which do not support item assignment.
        self.data_range = np.where(data_range == 0, 1, data_range)

    def transform(self, x: np.ndarray) -> np.ndarray:
        """Map ``x`` into ``feature_range`` using the fitted statistics.

        Raises:
            ValueError: If ``fit`` has not been called.
        """
        self._check_is_fitted()
        x_std = np.subtract(x, self.data_min).astype(float)
        np.divide(x_std, self.data_range, out=x_std)
        np.multiply(x_std, (self.max - self.min), out=x_std)
        np.add(x_std, self.min, out=x_std)

        return x_std

    def inverse_transform(self, x: np.ndarray) -> np.ndarray:
        """Map values from ``feature_range`` back to the fitted data's original scale.

        Raises:
            ValueError: If ``fit`` has not been called.
        """
        self._check_is_fitted()
        x_std = ((x - self.min) / (self.max - self.min)).astype(float)
        return x_std * self.data_range + self.data_min
88
+
89
+
90
class StandardScaler(BaseScaler):
    """Standardize each feature to zero mean and unit variance.

    Statistics are reduced over axis 0, so for 2-D ``(n_samples, n_features)`` input
    each column is standardized independently.
    """

    def __init__(self) -> None:
        super().__init__()

    def _check_is_fitted(self) -> None:
        """Raise ``ValueError`` unless ``fit`` has populated the statistics."""
        if self.mean is None or self.scale is None:
            raise ValueError(
                "This StandardScaler instance is not fitted yet. "
                "Call 'fit' with appropriate arguments before using this estimator.",
            )

    def fit(self, x: np.ndarray) -> None:
        """Record the per-feature mean, variance and standard deviation of ``x``."""
        self.mean = np.mean(x, axis=0)
        self.var = np.var(x, axis=0)
        scale = np.sqrt(self.var)
        # Zero variance would divide by zero; use 1 instead. np.where (not masked item
        # assignment) also works when the reductions return NumPy scalars (1-D input).
        self.scale = np.where(scale == 0, 1, scale)

    def transform(self, x: np.ndarray) -> np.ndarray:
        """Return ``(x - mean) / scale``.

        Raises:
            ValueError: If ``fit`` has not been called.
        """
        self._check_is_fitted()
        return (x - self.mean) / self.scale

    def inverse_transform(self, x: np.ndarray) -> np.ndarray:
        """Return ``x * scale + mean``, undoing ``transform``.

        Raises:
            ValueError: If ``fit`` has not been called.
        """
        self._check_is_fitted()
        return (x * self.scale) + self.mean
116
+
117
+
118
class TransformerInterface(Protocol):
    """Structural type for the transformer objects accepted by ``DomainAdapter``.

    Any object providing ``fit``, ``transform`` and ``inverse_transform`` can be used.
    """

    # NOTE(review): BaseScaler.fit is annotated to return None and BaseScaler.transform
    # takes no ``y``, so those classes do not literally match the annotations below —
    # confirm whether the protocol or the implementations should be adjusted.

    @abc.abstractmethod
    def fit(self, x: np.ndarray, y: np.ndarray | None = None) -> np.ndarray:
        """Learn the transformation parameters from samples ``x``."""

    @abc.abstractmethod
    def transform(self, x: np.ndarray, y: np.ndarray | None = None) -> np.ndarray:
        """Map samples ``x`` into the learned representation."""

    @abc.abstractmethod
    def inverse_transform(self, x: np.ndarray) -> np.ndarray:
        """Map a representation back into the original sample space."""
127
+
128
+
129
class DomainAdapter:
    """Map the pixel distribution of images onto that of a fixed reference image.

    Pixels are treated as ``num_channels``-dimensional samples. ``transformer`` itself
    is fitted once to ``ref_img`` (the target); a copy taken before that fit is refitted
    to every image passed to ``__call__`` (the source). Calling the adapter encodes the
    image with the source statistics and decodes it with the target statistics.

    Args:
        transformer: Object with ``fit``, ``transform`` and ``inverse_transform``. It is
            fitted to ``ref_img`` in place (the caller's object is mutated).
        ref_img: Reference image whose pixel distribution is the target.
        color_conversions: Optional pair of OpenCV color-conversion codes applied before
            flattening and after reconstruction; ``(None, None)`` disables conversion.
    """

    def __init__(
        self,
        transformer: TransformerInterface,
        ref_img: np.ndarray,
        color_conversions: tuple[int | None, int | None] = (None, None),
    ):
        self.color_in, self.color_out = color_conversions
        # Copy taken before fitting; refitted to each source image in __call__.
        self.source_transformer = deepcopy(transformer)
        self.target_transformer = transformer
        self.num_channels = get_num_channels(ref_img)
        self.target_transformer.fit(self.flatten(ref_img))

    def to_colorspace(self, img: np.ndarray) -> np.ndarray:
        """Apply the input color conversion, if one was configured."""
        return img if self.color_in is None else cv2.cvtColor(img, self.color_in)

    def from_colorspace(self, img: np.ndarray) -> np.ndarray:
        """Apply the output color conversion, if one was configured."""
        if self.color_out is None:
            return img
        return cv2.cvtColor(clip(img, np.uint8, inplace=True), self.color_out)

    def flatten(self, img: np.ndarray) -> np.ndarray:
        """Convert color space, normalize with ``to_float``, reshape to ``(-1, num_channels)``."""
        img = self.to_colorspace(img)
        img = to_float(img)
        return img.reshape(-1, self.num_channels)

    def reconstruct(self, pixels: np.ndarray, height: int, width: int) -> np.ndarray:
        """Turn flattened, normalized pixels back into a uint8 image of the given size.

        ``flatten`` passed the pixels through ``to_float``, which (albucore convention)
        scales uint8 data into [0, 1]. The inverse-transformed pixels are therefore on
        that [0, 1] scale and must be multiplied by 255 before clipping to the uint8
        range; clipping alone would leave a nearly black image.
        """
        pixels = np.clip(np.rint(pixels * 255.0), 0, 255).astype(np.uint8)
        if self.num_channels == 1:
            return self.from_colorspace(pixels.reshape(height, width))
        return self.from_colorspace(pixels.reshape(height, width, self.num_channels))

    @staticmethod
    def _pca_sign(x: TransformerInterface) -> np.ndarray:
        """Sign of the trace of a fitted transformer's ``components_`` matrix."""
        return np.sign(np.trace(x.components_))

    def __call__(self, image: np.ndarray) -> np.ndarray:
        """Return ``image`` with its pixel distribution mapped onto the reference's."""
        height, width = image.shape[:2]
        pixels = self.flatten(image)
        self.source_transformer.fit(pixels)

        # Only transformers exposing ``components_`` (e.g. PCA) take this path. If the
        # source and target bases disagree in sign, flip the target basis so the
        # source -> target mapping does not invert the image (presumably compensating
        # for the sign ambiguity of principal components).
        if (
            hasattr(self.target_transformer, "components_")
            and hasattr(self.source_transformer, "components_")
            and self._pca_sign(self.target_transformer) != self._pca_sign(self.source_transformer)
        ):
            self.target_transformer.components_ *= -1

        representation = self.source_transformer.transform(pixels)
        result = self.target_transformer.inverse_transform(representation)
        return self.reconstruct(result, height, width)
180
+
181
+
182
@clipped
@preserve_channel_dim
def adapt_pixel_distribution(
    img: np.ndarray,
    ref: np.ndarray,
    transform_type: Literal["pca", "standard", "minmax"],
    weight: float,
) -> np.ndarray:
    """Adapt the pixel distribution of an image to match a reference image.

    Each pixel is treated as a ``num_channels``-dimensional sample. A transformer of the
    requested type is fitted to the reference image and to the input image; the input is
    encoded with its own statistics, decoded with the reference statistics, and blended
    with the original image using ``weight``.

    Args:
        img (np.ndarray): The input image to be adapted. uint8, or float32 in [0, 1].
        ref (np.ndarray): The reference image. Must have the same dtype and number of
            channels as ``img``; it is resized to ``img``'s height and width if needed.
        transform_type (Literal["pca", "standard", "minmax"]): The type of transformation to use.
        weight (float): Blend weight of the adapted image: 0 keeps the input unchanged,
            1 returns the fully adapted image.

    Returns:
        np.ndarray: The adapted image, on the same value range as ``img``.

    Raises:
        ValueError: If the input image and reference image have different dtypes or numbers of channels.

    """
    if img.dtype != ref.dtype:
        raise ValueError("Input image and reference image must have the same dtype.")
    img_num_channels = get_num_channels(img)
    ref_num_channels = get_num_channels(ref)

    if img_num_channels != ref_num_channels:
        raise ValueError("Input image and reference image must have the same number of channels.")

    if img_num_channels == 1:
        # Drop only a trailing singleton channel axis. A bare np.squeeze() would also
        # collapse a height or width of 1 and break the (height, width) layout used later.
        if img.ndim == 3 and img.shape[-1] == 1:
            img = np.squeeze(img, axis=-1)
        if ref.ndim == 3 and ref.shape[-1] == 1:
            ref = np.squeeze(ref, axis=-1)

    if img.shape != ref.shape:
        # cv2.resize expects dsize as (width, height) — the reverse of numpy's shape order.
        ref = cv2.resize(ref, dsize=(img.shape[1], img.shape[0]), interpolation=cv2.INTER_AREA)

    original_dtype = img.dtype

    if original_dtype == np.float32:
        img = from_float(img, np.uint8)
        ref = from_float(ref, np.uint8)

    transformer = {"pca": PCA, "standard": StandardScaler, "minmax": MinMaxScaler}[transform_type]()
    adapter = DomainAdapter(transformer=transformer, ref_img=ref)
    transformed = adapter(img).astype(np.float32)

    # img is uint8 at this point and the adapter output is clipped to the uint8 range,
    # so the blend is carried out on the [0, 255] scale.
    result = img.astype(np.float32) * (1 - weight) + transformed * weight

    if original_dtype == np.float32:
        # The float input was mapped to [0, 255] above, so map the blend back to [0, 1].
        # to_float() is not used here: albucore treats float32 input as already
        # normalized and would leave the [0, 255] values in place.
        return (result / 255.0).astype(np.float32)
    return result if original_dtype == np.uint8 else to_float(result)
236
+
237
+
238
def low_freq_mutate(amp_src: np.ndarray, amp_trg: np.ndarray, beta: float) -> np.ndarray:
    """Copy a centered window of ``amp_trg`` into ``amp_src`` and return ``amp_src``.

    The window extends ``floor(min(height, width) * beta)`` pixels on each side of the
    point returned by ``fgeometric.center`` for the first two dimensions, clamped to
    the array bounds. ``amp_src`` is modified in place. When the half-size rounds down
    to 0 the window is empty and nothing is copied.

    NOTE(review): with ``np.fft.fftshift`` the zero-frequency bin sits at index
    ``size // 2``; whether this window is symmetric around it depends on the
    convention of ``fgeometric.center`` — confirm.
    """
    height, width = amp_src.shape[:2]
    half_size = int(np.floor(min(height, width) * beta))
    center_x, center_y = fgeometric.center((height, width))

    top = max(0, int(center_y - half_size))
    bottom = min(int(center_y + half_size), height)
    left = max(0, int(center_x - half_size))
    right = min(int(center_x + half_size), width)

    window = (slice(top, bottom), slice(left, right))
    amp_src[window] = amp_trg[window]
    return amp_src
251
+
252
+
253
@clipped
@preserve_channel_dim
def fourier_domain_adaptation(img: np.ndarray, target_img: np.ndarray, beta: float) -> np.ndarray:
    """Apply Fourier Domain Adaptation to the input image using a target image.

    This function performs domain adaptation in the frequency domain by modifying the amplitude
    spectrum of the source image based on the target image's amplitude spectrum. It preserves
    the phase information of the source image, which helps maintain its content while adapting
    its style to match the target image.

    Args:
        img (np.ndarray): The source image to be adapted. Can be grayscale or multi-channel.
        target_img (np.ndarray): The target image used as a reference for adaptation.
            Must have the same height and width as the source image and at least as many
            channels; extra target channels are ignored.
        beta (float): The adaptation strength, typically in the range [0, 1].
            Higher values result in stronger adaptation towards the target image's style.

    Returns:
        np.ndarray: The adapted image with the same shape and type as the input image.

    Raises:
        ValueError: If the source and target images differ in height or width, or the
            target image has fewer channels than the source image.

    Note:
        - Both input images are converted to float32 for processing.
        - The function handles both grayscale (2D) and color (3D) images.
        - For grayscale images, an extra dimension is added to facilitate uniform processing.
        - The adaptation is performed channel-wise for color images.
        - The output is clipped to the valid range and preserves the original number of channels.

        The adaptation process involves the following steps for each channel:
        1. Compute the 2D Fourier Transform of both source and target images.
        2. Shift the zero frequency component to the center of the spectrum.
        3. Extract amplitude and phase information from the source image's spectrum.
        4. Mutate the source amplitude using the target amplitude and the beta parameter.
        5. Combine the mutated amplitude with the original phase.
        6. Perform the inverse Fourier Transform to obtain the adapted channel.

        Step 4 is done by ``low_freq_mutate`` (defined in this module), which copies a
        ``beta``-sized window around the center of the shifted target amplitude spectrum,
        i.e. its low-frequency components, into the source amplitude spectrum.

    Examples:
        >>> import numpy as np
        >>> import albumentations as A
        >>> source_img = np.random.rand(100, 100, 3).astype(np.float32)
        >>> target_img = np.random.rand(100, 100, 3).astype(np.float32)
        >>> adapted_img = A.fourier_domain_adaptation(source_img, target_img, beta=0.5)
        >>> assert adapted_img.shape == source_img.shape

    References:
        FDA: Fourier Domain Adaptation for Semantic Segmentation: Yang and Soatto, 2020, CVPR
            https://openaccess.thecvf.com/content_CVPR_2020/papers/Yang_FDA_Fourier_Domain_Adaptation_for_Semantic_Segmentation_CVPR_2020_paper.pdf

    """
    src_img = img.astype(np.float32)
    trg_img = target_img.astype(np.float32)

    if src_img.ndim == MONO_CHANNEL_DIMENSIONS:
        src_img = np.expand_dims(src_img, axis=-1)
    if trg_img.ndim == MONO_CHANNEL_DIMENSIONS:
        trg_img = np.expand_dims(trg_img, axis=-1)

    num_channels = src_img.shape[-1]

    # Validate up front: the amplitude swap indexes both spectra with the same window,
    # which is only meaningful on frequency grids of equal size, and the loop below
    # reads one target channel for every source channel.
    if src_img.shape[:2] != trg_img.shape[:2]:
        raise ValueError(
            "Source and target images must have the same height and width, "
            f"got {src_img.shape[:2]} and {trg_img.shape[:2]}.",
        )
    if trg_img.shape[-1] < num_channels:
        raise ValueError(
            f"Target image has {trg_img.shape[-1]} channel(s), but the source image has {num_channels}.",
        )

    # Prepare container for the output image
    src_in_trg = np.zeros_like(src_img)

    for channel_id in range(num_channels):
        # Perform FFT on each channel
        fft_src = np.fft.fft2(src_img[:, :, channel_id])
        fft_trg = np.fft.fft2(trg_img[:, :, channel_id])

        # Shift the zero frequency component to the center
        fft_src_shifted = np.fft.fftshift(fft_src)
        fft_trg_shifted = np.fft.fftshift(fft_trg)

        # Extract amplitude and phase
        amp_src, pha_src = np.abs(fft_src_shifted), np.angle(fft_src_shifted)
        amp_trg = np.abs(fft_trg_shifted)

        # Mutate the amplitude part of the source with the target
        mutated_amp = low_freq_mutate(amp_src.copy(), amp_trg, beta)

        # Combine the mutated amplitude with the original phase
        fft_src_mutated = np.fft.ifftshift(mutated_amp * np.exp(1j * pha_src))

        # Perform inverse FFT
        src_in_trg_channel = np.fft.ifft2(fft_src_mutated)

        # Store the result in the corresponding channel of the output image
        src_in_trg[:, :, channel_id] = np.real(src_in_trg_channel)

    return src_in_trg
346
+
347
+
348
@clipped
@preserve_channel_dim
def apply_histogram(img: np.ndarray, reference_image: np.ndarray, blend_ratio: float) -> np.ndarray:
    """Apply histogram matching to an input image using a reference image and blend the result.

    This function performs histogram matching between the input image and a reference image,
    then blends the result with the original input image based on the specified blend ratio.

    Args:
        img (np.ndarray): The input image to be transformed. Can be either grayscale or RGB.
            Supported dtypes: uint8, float32 (values should be in [0, 1] range).
        reference_image (np.ndarray): The reference image used for histogram matching.
            Should have the same number of channels as the input image.
            Supported dtypes: uint8, float32 (values should be in [0, 1] range).
        blend_ratio (float): The ratio for blending the matched image with the original image.
            Should be in the range [0, 1], where 0 means no change and 1 means full histogram matching.

    Returns:
        np.ndarray: The transformed image after histogram matching and blending.
            The output will have the same shape and dtype as the input image.

    Supported image types:
        - Grayscale images: 2D arrays
        - RGB images: 3D arrays with 3 channels
        - Multispectral images: 3D arrays with more than 3 channels

    Note:
        - If the input and reference images have different sizes, the reference image
          will be resized to match the input image's dimensions.
        - Only a trailing channel axis of size 1 is removed before matching, so images
          with a height or width of 1 keep their layout.
        - The function uses a custom implementation of histogram matching based on OpenCV and NumPy.
        - The @clipped and @preserve_channel_dim decorators ensure the output is within
          the valid range and maintains the original number of dimensions.

    """
    # Resize reference image only if necessary (cv2 dsize is (width, height))
    if img.shape[:2] != reference_image.shape[:2]:
        reference_image = cv2.resize(reference_image, dsize=(img.shape[1], img.shape[0]))

    # Drop only a trailing singleton channel axis: (H, W, 1) -> (H, W). A bare
    # np.squeeze() would also remove a height or width of 1, e.g. turning a
    # (1, W, 3) image into (W, 3), which mixes channels and breaks the blend below.
    if img.ndim == 3 and img.shape[-1] == 1:
        img = np.squeeze(img, axis=-1)
    if reference_image.ndim == 3 and reference_image.shape[-1] == 1:
        reference_image = np.squeeze(reference_image, axis=-1)

    # Match histograms between the images
    matched = match_histograms(img, reference_image)

    # Blend the original image and the matched image
    return add_weighted(matched, blend_ratio, img, 1 - blend_ratio)
394
+
395
+
396
@uint8_io
@preserve_channel_dim
def match_histograms(image: np.ndarray, reference: np.ndarray) -> np.ndarray:
    """Adjust an image so that its cumulative histogram matches that of another.

    The adjustment is applied separately for each channel (last axis).

    Args:
        image (np.ndarray): Input image, grayscale ``(H, W)`` or channel-last ``(H, W, C)``.
        reference (np.ndarray): Image to match the histogram of. Must have the same number
            of dimensions as ``image`` and at least as many channels; extra reference
            channels are ignored. Non-uint8 references are converted with ``from_float``.

    Returns:
        np.ndarray: Transformed input image.

    Raises:
        ValueError: If ``image`` and ``reference`` differ in number of dimensions, or
            ``reference`` has fewer channels than ``image``.

    """
    if reference.dtype != np.uint8:
        reference = from_float(reference, np.uint8)

    if image.ndim != reference.ndim:
        raise ValueError("Image and reference must have the same number of dimensions.")

    # Expand dimensions for grayscale images
    if image.ndim == 2:
        image = np.expand_dims(image, axis=-1)
    if reference.ndim == 2:
        reference = np.expand_dims(reference, axis=-1)

    num_channels = image.shape[-1]

    # Fail with a clear message instead of an IndexError inside the loop below.
    if reference.shape[-1] < num_channels:
        raise ValueError(
            f"Reference has {reference.shape[-1]} channel(s), but the image has {num_channels}.",
        )

    matched = np.empty(image.shape, dtype=np.uint8)

    for channel in range(num_channels):
        matched[..., channel] = _match_cumulative_cdf(image[..., channel], reference[..., channel])

    return matched
437
+
438
+
439
+ def _match_cumulative_cdf(source: np.ndarray, template: np.ndarray) -> np.ndarray:
440
+ src_lookup = source.reshape(-1)
441
+ src_counts = np.bincount(src_lookup)
442
+ tmpl_counts = np.bincount(template.reshape(-1))
443
+
444
+ # omit values where the count was 0
445
+ tmpl_values = np.nonzero(tmpl_counts)[0]
446
+ tmpl_counts = tmpl_counts[tmpl_values]
447
+
448
+ # calculate normalized quantiles for each array
449
+ src_quantiles = np.cumsum(src_counts) / source.size
450
+ tmpl_quantiles = np.cumsum(tmpl_counts) / template.size
451
+
452
+ interp_a_values = np.interp(src_quantiles, tmpl_quantiles, tmpl_values)
453
+ return interp_a_values[src_lookup].reshape(source.shape).astype(np.uint8)