nrtk-albumentations 2.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (62)
  1. albumentations/__init__.py +21 -0
  2. albumentations/augmentations/__init__.py +23 -0
  3. albumentations/augmentations/blur/__init__.py +0 -0
  4. albumentations/augmentations/blur/functional.py +438 -0
  5. albumentations/augmentations/blur/transforms.py +1633 -0
  6. albumentations/augmentations/crops/__init__.py +0 -0
  7. albumentations/augmentations/crops/functional.py +494 -0
  8. albumentations/augmentations/crops/transforms.py +3647 -0
  9. albumentations/augmentations/dropout/__init__.py +0 -0
  10. albumentations/augmentations/dropout/channel_dropout.py +134 -0
  11. albumentations/augmentations/dropout/coarse_dropout.py +567 -0
  12. albumentations/augmentations/dropout/functional.py +1017 -0
  13. albumentations/augmentations/dropout/grid_dropout.py +166 -0
  14. albumentations/augmentations/dropout/mask_dropout.py +274 -0
  15. albumentations/augmentations/dropout/transforms.py +461 -0
  16. albumentations/augmentations/dropout/xy_masking.py +186 -0
  17. albumentations/augmentations/geometric/__init__.py +0 -0
  18. albumentations/augmentations/geometric/distortion.py +1238 -0
  19. albumentations/augmentations/geometric/flip.py +752 -0
  20. albumentations/augmentations/geometric/functional.py +4151 -0
  21. albumentations/augmentations/geometric/pad.py +676 -0
  22. albumentations/augmentations/geometric/resize.py +956 -0
  23. albumentations/augmentations/geometric/rotate.py +864 -0
  24. albumentations/augmentations/geometric/transforms.py +1962 -0
  25. albumentations/augmentations/mixing/__init__.py +0 -0
  26. albumentations/augmentations/mixing/domain_adaptation.py +787 -0
  27. albumentations/augmentations/mixing/domain_adaptation_functional.py +453 -0
  28. albumentations/augmentations/mixing/functional.py +878 -0
  29. albumentations/augmentations/mixing/transforms.py +832 -0
  30. albumentations/augmentations/other/__init__.py +0 -0
  31. albumentations/augmentations/other/lambda_transform.py +180 -0
  32. albumentations/augmentations/other/type_transform.py +261 -0
  33. albumentations/augmentations/pixel/__init__.py +0 -0
  34. albumentations/augmentations/pixel/functional.py +4226 -0
  35. albumentations/augmentations/pixel/transforms.py +7556 -0
  36. albumentations/augmentations/spectrogram/__init__.py +0 -0
  37. albumentations/augmentations/spectrogram/transform.py +220 -0
  38. albumentations/augmentations/text/__init__.py +0 -0
  39. albumentations/augmentations/text/functional.py +272 -0
  40. albumentations/augmentations/text/transforms.py +299 -0
  41. albumentations/augmentations/transforms3d/__init__.py +0 -0
  42. albumentations/augmentations/transforms3d/functional.py +393 -0
  43. albumentations/augmentations/transforms3d/transforms.py +1422 -0
  44. albumentations/augmentations/utils.py +249 -0
  45. albumentations/core/__init__.py +0 -0
  46. albumentations/core/bbox_utils.py +920 -0
  47. albumentations/core/composition.py +1885 -0
  48. albumentations/core/hub_mixin.py +299 -0
  49. albumentations/core/keypoints_utils.py +521 -0
  50. albumentations/core/label_manager.py +339 -0
  51. albumentations/core/pydantic.py +239 -0
  52. albumentations/core/serialization.py +352 -0
  53. albumentations/core/transforms_interface.py +976 -0
  54. albumentations/core/type_definitions.py +127 -0
  55. albumentations/core/utils.py +605 -0
  56. albumentations/core/validation.py +129 -0
  57. albumentations/pytorch/__init__.py +1 -0
  58. albumentations/pytorch/transforms.py +189 -0
  59. nrtk_albumentations-2.1.0.dist-info/METADATA +196 -0
  60. nrtk_albumentations-2.1.0.dist-info/RECORD +62 -0
  61. nrtk_albumentations-2.1.0.dist-info/WHEEL +4 -0
  62. nrtk_albumentations-2.1.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,4226 @@
+ """Functional implementations of image augmentation operations.
+
+ This module contains low-level functions for various image augmentation techniques including
+ color transformations, blur effects, tone curve adjustments, noise additions, and other visual
+ modifications. These functions form the foundation for the transform classes and provide
+ the core functionality for manipulating image data during the augmentation process.
+ """
+
+ from __future__ import annotations
+
+ import math
+ from collections.abc import Sequence
+ from typing import Any, Callable, Literal
+ from warnings import warn
+
+ import cv2
+ import numpy as np
+ from albucore import (
+     MAX_VALUES_BY_DTYPE,
+     add,
+     add_array,
+     add_constant,
+     add_weighted,
+     clip,
+     clipped,
+     convert_value,
+     float32_io,
+     from_float,
+     get_num_channels,
+     is_grayscale_image,
+     is_rgb_image,
+     maybe_process_in_chunks,
+     multiply,
+     multiply_add,
+     multiply_by_array,
+     multiply_by_constant,
+     normalize,
+     normalize_per_image,
+     power,
+     preserve_channel_dim,
+     reshape_for_channel,
+     restore_from_channel,
+     sz_lut,
+     uint8_io,
+ )
+
+ import albumentations.augmentations.geometric.functional as fgeometric
+ from albumentations.augmentations.utils import (
+     PCA,
+     non_rgb_error,
+ )
+ from albumentations.core.type_definitions import (
+     MONO_CHANNEL_DIMENSIONS,
+     NUM_MULTI_CHANNEL_DIMENSIONS,
+     NUM_RGB_CHANNELS,
+ )
+
+
+ @uint8_io
+ @preserve_channel_dim
+ def shift_hsv(
+     img: np.ndarray,
+     hue_shift: float,
+     sat_shift: float,
+     val_shift: float,
+ ) -> np.ndarray:
+     """Shift the hue, saturation, and value of an image.
+
+     Args:
+         img (np.ndarray): The image to shift.
+         hue_shift (float): The amount to shift the hue.
+         sat_shift (float): The amount to shift the saturation.
+         val_shift (float): The amount to shift the value.
+
+     Returns:
+         np.ndarray: The shifted image.
+
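+     Examples:
+         Illustrative usage; assumes a uint8 RGB input (shift values are arbitrary):
+
+         >>> import numpy as np
+         >>> img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
+         >>> shifted = shift_hsv(img, hue_shift=10, sat_shift=20, val_shift=-10)
+         >>> shifted.shape == img.shape
+         True
+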
+     """
+     if hue_shift == 0 and sat_shift == 0 and val_shift == 0:
+         return img
+
+     is_gray = is_grayscale_image(img)
+
+     if is_gray:
+         if hue_shift != 0 or sat_shift != 0:
+             hue_shift = 0
+             sat_shift = 0
+             warn(
+                 "HueSaturationValue: hue_shift and sat_shift are not applicable to grayscale image. "
+                 "Set them to 0 or use RGB image",
+                 stacklevel=2,
+             )
+         img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+
+     img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
+     hue, sat, val = cv2.split(img)
+
+     if hue_shift != 0:
+         lut_hue = np.arange(0, 256, dtype=np.int16)
+         lut_hue = np.mod(lut_hue + hue_shift, 180).astype(np.uint8)
+         hue = sz_lut(hue, lut_hue, inplace=False)
+
+     if sat_shift != 0:
+         # Create a mask for all grayscale pixels (S=0)
+         # These should remain grayscale regardless of saturation change
+         grayscale_mask = sat == 0
+
+         # Apply saturation shift only to non-grayscale pixels
+         sat = add_constant(sat, sat_shift, inplace=True)
+
+         # Reset saturation for originally grayscale pixels
+         sat[grayscale_mask] = 0
+
+     if val_shift != 0:
+         val = add_constant(val, val_shift, inplace=True)
+
+     img = cv2.merge((hue, sat, val))
+     img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
+
+     return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) if is_gray else img
+
+
+ @clipped
+ def solarize(img: np.ndarray, threshold: float) -> np.ndarray:
+     """Invert all pixel values above a threshold.
+
+     Args:
+         img (np.ndarray): The image to solarize. Can be uint8 or float32.
+         threshold (float): Normalized threshold value in range [0, 1].
+             For uint8 images: pixels above threshold * 255 are inverted
+             For float32 images: pixels above threshold are inverted
+
+     Returns:
+         np.ndarray: Solarized image.
+
+     Note:
+         The threshold is normalized to [0, 1] range for both uint8 and float32 images.
+         For uint8 images, the threshold is internally scaled by 255.
+
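+     Examples:
+         Illustrative usage; assumes a uint8 input (any threshold in [0, 1] works):
+
+         >>> import numpy as np
+         >>> img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
+         >>> inverted_highlights = solarize(img, threshold=0.5)
+         >>> inverted_highlights.shape == img.shape
+         True
+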
+     """
+     dtype = img.dtype
+     max_val = MAX_VALUES_BY_DTYPE[dtype]
+
+     if dtype == np.uint8:
+         lut = np.array(
+             [max_val - i if i >= threshold * max_val else i for i in range(int(max_val) + 1)],
+             dtype=dtype,
+         )
+         prev_shape = img.shape
+         img = sz_lut(img, lut, inplace=False)
+         return img if len(prev_shape) == img.ndim else np.expand_dims(img, -1)
+     return np.where(img >= threshold, max_val - img, img)
+
+
+ @uint8_io
+ @clipped
+ def posterize(img: np.ndarray, bits: Literal[1, 2, 3, 4, 5, 6, 7] | list[Literal[1, 2, 3, 4, 5, 6, 7]]) -> np.ndarray:
+     """Reduce the number of bits for each color channel by keeping only the highest N bits.
+
+     Args:
+         img (np.ndarray): Input image. Can be single or multi-channel.
+         bits (Literal[1, 2, 3, 4, 5, 6, 7] | list[Literal[1, 2, 3, 4, 5, 6, 7]]): Number of high bits to keep.
+             Can be either:
+             - A single value to apply the same bit reduction to all channels
+             - A list of values to apply different bit reduction per channel.
+               Length of list must match number of channels in image.
+
+     Returns:
+         np.ndarray: Image with reduced bit depth. Has same shape and dtype as input.
+
+     Note:
+         - The transform keeps the N highest bits and sets all other bits to 0
+         - For example, if bits=3:
+             - Original value: 11010110 (214)
+             - Keep 3 bits: 11000000 (192)
+         - The number of unique colors per channel will be 2^bits
+         - Higher bits values = more colors = more subtle effect
+         - Lower bits values = fewer colors = more dramatic posterization
+
+     Examples:
+         >>> import numpy as np
+         >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
+         >>> # Same posterization for all channels
+         >>> result = posterize(image, bits=3)
+         >>> # Different posterization per channel
+         >>> result = posterize(image, bits=[3, 4, 5])  # RGB channels
+
+     """
+     bits_array = np.uint8(bits)
+
+     if not bits_array.shape or len(bits_array) == 1:
+         lut = np.arange(0, 256, dtype=np.uint8)
+         mask = ~np.uint8(2 ** (8 - bits_array) - 1)
+         lut &= mask
+
+         return sz_lut(img, lut, inplace=False)
+
+     result_img = np.empty_like(img)
+     for i, channel_bits in enumerate(bits_array):
+         lut = np.arange(0, 256, dtype=np.uint8)
+         mask = ~np.uint8(2 ** (8 - channel_bits) - 1)
+         lut &= mask
+
+         result_img[..., i] = sz_lut(img[..., i], lut, inplace=True)
+
+     return result_img
+
+
+ def _equalize_pil(img: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
+     histogram = cv2.calcHist([img], [0], mask, [256], (0, 256)).ravel()
+     h = np.array([_f for _f in histogram if _f])
+
+     if len(h) <= 1:
+         return img.copy()
+
+     step = np.sum(h[:-1]) // 255
+     if not step:
+         return img.copy()
+
+     lut = np.minimum((np.cumsum(histogram) + step // 2) // step, 255).astype(np.uint8)
+
+     return sz_lut(img, lut, inplace=True)
+
+
+ def _equalize_cv(img: np.ndarray, mask: np.ndarray | None = None) -> np.ndarray:
+     if mask is None:
+         return cv2.equalizeHist(img)
+
+     histogram = cv2.calcHist([img], [0], mask, [256], (0, 256)).ravel()
+
+     # Find the first non-zero index with a numpy operation
+     i = np.flatnonzero(histogram)[0] if np.any(histogram) else 255
+
+     total = np.sum(histogram)
+
+     scale = 255.0 / (total - histogram[i])
+
+     # Optimize cumulative sum and scale to generate LUT
+     cumsum_histogram = np.cumsum(histogram)
+     lut = np.clip(((cumsum_histogram - cumsum_histogram[i]) * scale).round(), 0, 255).astype(np.uint8)
+
+     return sz_lut(img, lut, inplace=True)
+
+
+ def _check_preconditions(
+     img: np.ndarray,
+     mask: np.ndarray | None,
+     by_channels: bool,
+ ) -> None:
+     if mask is not None:
+         if is_rgb_image(mask) and is_grayscale_image(img):
+             raise ValueError(
+                 f"Wrong mask shape. Image shape: {img.shape}. Mask shape: {mask.shape}",
+             )
+         if not by_channels and not is_grayscale_image(mask):
+             msg = f"When by_channels=False, only a 1-channel mask is supported. Mask shape: {mask.shape}"
+             raise ValueError(msg)
+
+
+ def _handle_mask(
+     mask: np.ndarray | None,
+     i: int | None = None,
+ ) -> np.ndarray | None:
+     if mask is None:
+         return None
+     mask = mask.astype(
+         np.uint8,
+         copy=False,
+     )  # Use copy=False to avoid unnecessary copying
+     # Check for grayscale image and avoid slicing if i is None
+     if i is not None and not is_grayscale_image(mask):
+         mask = mask[..., i]
+
+     return mask
+
+
+ @uint8_io
+ @preserve_channel_dim
+ def equalize(
+     img: np.ndarray,
+     mask: np.ndarray | None = None,
+     mode: Literal["cv", "pil"] = "cv",
+     by_channels: bool = True,
+ ) -> np.ndarray:
+     """Apply histogram equalization to the input image.
+
+     This function enhances the contrast of the input image by equalizing its histogram.
+     It supports both grayscale and color images, and can operate on individual channels
+     or on the luminance channel of the image.
+
+     Args:
+         img (np.ndarray): Input image. Can be grayscale (2D array) or RGB (3D array).
+         mask (np.ndarray | None): Optional mask to apply the equalization selectively.
+             If provided, must have the same shape as the input image. Default: None.
+         mode (Literal["cv", "pil"]): The backend to use for equalization. Can be either "cv" for
+             OpenCV or "pil" for Pillow-style equalization. Default: "cv".
+         by_channels (bool): If True, applies equalization to each channel independently.
+             If False, converts the image to YCrCb color space and equalizes only the
+             luminance channel. Only applicable to color images. Default: True.
+
+     Returns:
+         np.ndarray: Equalized image. The output has the same dtype as the input.
+
+     Raises:
+         ValueError: If the input image or mask have invalid shapes or types.
+
+     Note:
+         - If the input image is not uint8, it will be temporarily converted to uint8
+           for processing and then converted back to its original dtype.
+         - For color images, when by_channels=False, the image is converted to YCrCb
+           color space, equalized on the Y channel, and then converted back to RGB.
+         - The function preserves the original number of channels in the image.
+
+     Examples:
+         >>> import numpy as np
+         >>> import albumentations as A
+         >>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
+         >>> equalized = A.equalize(image, mode="cv", by_channels=True)
+         >>> assert equalized.shape == image.shape
+         >>> assert equalized.dtype == image.dtype
+
+     """
+     _check_preconditions(img, mask, by_channels)
+     function = _equalize_pil if mode == "pil" else _equalize_cv
+
+     if is_grayscale_image(img):
+         return function(img, _handle_mask(mask))
+
+     if not by_channels:
+         result_img = cv2.cvtColor(img, cv2.COLOR_RGB2YCrCb)
+         result_img[..., 0] = function(result_img[..., 0], _handle_mask(mask))
+         return cv2.cvtColor(result_img, cv2.COLOR_YCrCb2RGB)
+
+     result_img = np.empty_like(img)
+     for i in range(NUM_RGB_CHANNELS):
+         _mask = _handle_mask(mask, i)
+         result_img[..., i] = function(img[..., i], _mask)
+
+     return result_img
+
+
+ def evaluate_bez(
+     low_y: float | np.ndarray,
+     high_y: float | np.ndarray,
+ ) -> np.ndarray:
+     """Evaluate a cubic Bezier curve at 256 evenly spaced t values in [0, 1].
+
+     Args:
+         low_y (float | np.ndarray): The low y values to evaluate the Bezier curve at.
+         high_y (float | np.ndarray): The high y values to evaluate the Bezier curve at.
+
+     Returns:
+         np.ndarray: The Bezier curve values.
+
+     """
+     t = np.linspace(0.0, 1.0, 256)[..., None]
+
+     one_minus_t = 1 - t
+     return (3 * one_minus_t**2 * t * low_y + 3 * one_minus_t * t**2 * high_y + t**3) * 255
+
+
+ @uint8_io
+ def move_tone_curve(
+     img: np.ndarray,
+     low_y: float | np.ndarray,
+     high_y: float | np.ndarray,
+     num_channels: int,
+ ) -> np.ndarray:
+     """Rescale the relationship between bright and dark areas of the image by manipulating its tone curve.
+
+     Args:
+         img (np.ndarray): Input image; any number of channels.
+         low_y (float | np.ndarray): per-channel or single y-position of a Bezier control point used
+             to adjust the tone curve, must be in range [0, 1]
+         high_y (float | np.ndarray): per-channel or single y-position of a Bezier control point used
+             to adjust image tone curve, must be in range [0, 1]
+         num_channels (int): The number of channels in the input image.
+
+     Returns:
+         np.ndarray: Image with adjusted tone curve
+
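+     Examples:
+         Illustrative usage with scalar control points; assumes a uint8 RGB input:
+
+         >>> import numpy as np
+         >>> img = np.random.randint(0, 256, (50, 50, 3), dtype=np.uint8)
+         >>> curved = move_tone_curve(img, low_y=0.3, high_y=0.8, num_channels=3)
+         >>> curved.shape == img.shape
+         True
+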
+     """
+     if np.isscalar(low_y) and np.isscalar(high_y):
+         lut = clip(np.rint(evaluate_bez(low_y, high_y)), np.uint8, inplace=False)
+         return sz_lut(img, lut, inplace=False)
+
+     if isinstance(low_y, np.ndarray) and isinstance(high_y, np.ndarray):
+         luts = clip(
+             np.rint(evaluate_bez(low_y, high_y).T),
+             np.uint8,
+             inplace=False,
+         )
+         return np.stack(
+             [sz_lut(img[..., i], np.ascontiguousarray(luts[i]), inplace=False) for i in range(num_channels)],
+             axis=-1,
+         )
+
+     raise TypeError(
+         f"low_y and high_y must both be of type float or np.ndarray. Got {type(low_y)} and {type(high_y)}",
+     )
+
+
+ @clipped
+ def linear_transformation_rgb(
+     img: np.ndarray,
+     transformation_matrix: np.ndarray,
+ ) -> np.ndarray:
+     """Apply a linear transformation to the RGB channels of an image.
+
+     This function applies a linear transformation matrix to the RGB channels of an image.
+     The transformation matrix is a 3x3 matrix that maps the RGB values to new values.
+
+     Args:
+         img (np.ndarray): Input image. Can be grayscale (2D array) or RGB (3D array).
+         transformation_matrix (np.ndarray): 3x3 transformation matrix.
+
+     Returns:
+         np.ndarray: Image with the linear transformation applied. The output has the same dtype as the input.
+
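+     Examples:
+         Illustrative usage; an identity matrix leaves pixel values unchanged:
+
+         >>> import numpy as np
+         >>> img = np.random.randint(0, 256, (20, 20, 3), dtype=np.uint8)
+         >>> out = linear_transformation_rgb(img, np.eye(3, dtype=np.float32))
+         >>> out.shape == img.shape
+         True
+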
+     """
+     return cv2.transform(img, transformation_matrix)
+
+
+ @uint8_io
+ @preserve_channel_dim
+ def clahe(
+     img: np.ndarray,
+     clip_limit: float,
+     tile_grid_size: tuple[int, int],
+ ) -> np.ndarray:
+     """Apply Contrast Limited Adaptive Histogram Equalization (CLAHE) to the input image.
+
+     This function enhances the contrast of the input image using CLAHE. For color images,
+     it converts the image to the LAB color space, applies CLAHE to the L channel, and then
+     converts the image back to RGB.
+
+     Args:
+         img (np.ndarray): Input image. Can be grayscale (2D array) or RGB (3D array).
+         clip_limit (float): Threshold for contrast limiting. Higher values give more contrast.
+         tile_grid_size (tuple[int, int]): Size of grid for histogram equalization.
+             Width and height of the grid.
+
+     Returns:
+         np.ndarray: Image with CLAHE applied. The output has the same dtype as the input.
+
+     Note:
+         - If the input image is float32, it's temporarily converted to uint8 for processing
+           and then converted back to float32.
+         - For color images, CLAHE is applied only to the luminance channel in the LAB color space.
+
+     Raises:
+         ValueError: If the input image is not 2D or 3D.
+
+     Examples:
+         >>> import numpy as np
+         >>> img = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
+         >>> result = clahe(img, clip_limit=2.0, tile_grid_size=(8, 8))
+         >>> assert result.shape == img.shape
+         >>> assert result.dtype == img.dtype
+
+     """
+     clahe_mat = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
+
+     if is_grayscale_image(img):
+         return clahe_mat.apply(img)
+
+     img_lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
+     img_lab[:, :, 0] = clahe_mat.apply(img_lab[:, :, 0])
+
+     return cv2.cvtColor(img_lab, cv2.COLOR_LAB2RGB)
+
+
+ @uint8_io
+ @preserve_channel_dim
+ def image_compression(
+     img: np.ndarray,
+     quality: int,
+     image_type: Literal[".jpg", ".webp"],
+ ) -> np.ndarray:
+     """Compress the image using JPEG or WebP compression.
+
+     Args:
+         img (np.ndarray): Input image
+         quality (int): Quality of compression in range [1, 100]
+         image_type (Literal[".jpg", ".webp"]): Type of compression to use
+
+     Returns:
+         np.ndarray: Compressed image
+
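+     Examples:
+         Illustrative usage; JPEG-compresses and decodes a uint8 RGB image in memory:
+
+         >>> import numpy as np
+         >>> img = np.random.randint(0, 256, (50, 50, 3), dtype=np.uint8)
+         >>> degraded = image_compression(img, quality=20, image_type=".jpg")
+         >>> degraded.shape == img.shape
+         True
+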
+     """
+     # Determine the quality flag for compression
+     quality_flag = cv2.IMWRITE_JPEG_QUALITY if image_type == ".jpg" else cv2.IMWRITE_WEBP_QUALITY
+     num_channels = get_num_channels(img)
+
+     # Prepare to encode and decode
+     def encode_decode(src_img: np.ndarray, read_mode: int) -> np.ndarray:
+         _, encoded_img = cv2.imencode(image_type, src_img, (int(quality_flag), quality))
+         return cv2.imdecode(encoded_img, read_mode)
+
+     if num_channels == 1:
+         # Grayscale image
+         decoded = encode_decode(img, cv2.IMREAD_GRAYSCALE)
+         return decoded[..., np.newaxis]  # Add channel dimension back
+
+     if num_channels in (2, NUM_RGB_CHANNELS):
+         # 2 channels: pad to 3, or 3 (RGB) channels
+         padded_img = np.pad(img, ((0, 0), (0, 0), (0, 1)), mode="constant") if num_channels == 2 else img
+         decoded_bgr = encode_decode(padded_img, cv2.IMREAD_UNCHANGED)
+         return decoded_bgr[..., :num_channels]  # Return only the required number of channels
+
+     # More than 3 channels
+     bgr = img[..., :NUM_RGB_CHANNELS]
+     decoded_bgr = encode_decode(bgr, cv2.IMREAD_UNCHANGED)
+
+     # Process additional channels
+     extra_channels = [
+         encode_decode(img[..., i], cv2.IMREAD_GRAYSCALE)[..., np.newaxis] for i in range(NUM_RGB_CHANNELS, num_channels)
+     ]
+     return np.dstack([decoded_bgr, *extra_channels])
+
+
+ @uint8_io
+ def add_snow_bleach(
+     img: np.ndarray,
+     snow_point: float,
+     brightness_coeff: float,
+ ) -> np.ndarray:
+     """Add a simple snow effect to the image by bleaching out pixels.
+
+     This function simulates a basic snow effect by increasing the brightness of pixels
+     whose lightness falls below a certain threshold (snow_point). It operates in the HLS
+     color space to modify the lightness channel.
+
+     Args:
+         img (np.ndarray): Input image. Can be either RGB uint8 or float32.
+         snow_point (float): A float in the range [0, 1], scaled and adjusted to determine
+             the threshold for pixel modification. Higher values bleach more pixels,
+             producing a stronger snow effect.
+         brightness_coeff (float): Coefficient applied to increase the brightness of pixels
+             below the snow_point threshold. Larger values lead to more pronounced snow effects.
+             Should be greater than 1.0 for a visible effect.
+
+     Returns:
+         np.ndarray: Image with simulated snow effect. The output has the same dtype as the input.
+
+     Note:
+         - This function converts the image to the HLS color space to modify the lightness channel.
+         - The snow effect is created by selectively increasing the brightness of pixels.
+         - This method tends to create a 'bleached' look, which may not be as realistic as more
+           advanced snow simulation techniques.
+         - The function automatically handles both uint8 and float32 input images.
+
+     The snow effect is created through the following steps:
+     1. Convert the image from RGB to HLS color space.
+     2. Adjust the snow_point threshold.
+     3. Increase the lightness of pixels below the threshold.
+     4. Convert the image back to RGB.
+
+     Mathematical Formulation:
+         Let L be the lightness channel in HLS space.
+         For each pixel (i, j):
+             If L[i, j] < snow_point:
+                 L[i, j] = L[i, j] * brightness_coeff
+
+     Examples:
+         >>> import numpy as np
+         >>> image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
+         >>> snowy_image = add_snow_bleach(image, snow_point=0.5, brightness_coeff=1.5)
+
+     References:
+         - HLS Color Space: https://en.wikipedia.org/wiki/HSL_and_HSV
+         - Original implementation: https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
+
+     """
+     max_value = MAX_VALUES_BY_DTYPE[np.uint8]
+
+     # Precompute snow_point threshold
+     snow_point = (snow_point * max_value / 2) + (max_value / 3)
+
+     # Convert image to HLS color space once and avoid repeated dtype casting
+     image_hls = cv2.cvtColor(img, cv2.COLOR_RGB2HLS)
+     lightness_channel = image_hls[:, :, 1].astype(np.float32)
+
+     # Utilize boolean indexing for efficient lightness adjustment
+     mask = lightness_channel < snow_point
+     lightness_channel[mask] *= brightness_coeff
+
+     # Clip the lightness values in place
+     lightness_channel = clip(lightness_channel, np.uint8, inplace=True)
+
+     # Update the lightness channel in the original image
+     image_hls[:, :, 1] = lightness_channel
+
+     # Convert back to RGB
+     return cv2.cvtColor(image_hls, cv2.COLOR_HLS2RGB)
+
+
+ def generate_snow_textures(
+     img_shape: tuple[int, int],
+     random_generator: np.random.Generator,
+ ) -> tuple[np.ndarray, np.ndarray]:
+     """Generate snow texture and sparkle mask.
+
+     Args:
+         img_shape (tuple[int, int]): Image shape.
+         random_generator (np.random.Generator): Random generator to use.
+
+     Returns:
+         tuple[np.ndarray, np.ndarray]: Tuple of (snow_texture, sparkle_mask) arrays.
+
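+     Examples:
+         Illustrative usage with a seeded generator:
+
+         >>> import numpy as np
+         >>> texture, sparkles = generate_snow_textures((100, 100), np.random.default_rng(0))
+         >>> texture.shape == (100, 100) and sparkles.dtype == bool
+         True
+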
+     """
+     # Generate base snow texture
+     snow_texture = random_generator.normal(size=img_shape[:2], loc=0.5, scale=0.3)
+     snow_texture = cv2.GaussianBlur(snow_texture, (0, 0), sigmaX=1, sigmaY=1)
+
+     # Generate sparkle mask
+     sparkle_mask = random_generator.random(img_shape[:2]) > 0.99
+
+     return snow_texture, sparkle_mask
+
+
+ @uint8_io
+ def add_snow_texture(
+     img: np.ndarray,
+     snow_point: float,
+     brightness_coeff: float,
+     snow_texture: np.ndarray,
+     sparkle_mask: np.ndarray,
+ ) -> np.ndarray:
+     """Add a realistic snow effect to the input image.
+
+     This function simulates snowfall by applying multiple visual effects to the image,
+     including brightness adjustment, snow texture overlay, depth simulation, and color tinting.
+     The result is a more natural-looking snow effect compared to simple pixel bleaching methods.
+
+     Args:
+         img (np.ndarray): Input image in RGB format.
+         snow_point (float): Coefficient that controls the amount and intensity of snow.
+             Should be in the range [0, 1], where 0 means no snow and 1 means maximum snow effect.
+         brightness_coeff (float): Coefficient for brightness adjustment to simulate the
+             reflective nature of snow. Should be in the range [0, 1], where higher values
+             result in a brighter image.
+         snow_texture (np.ndarray): Pre-generated snow texture, e.g. from generate_snow_textures.
+         sparkle_mask (np.ndarray): Boolean mask of sparkle locations, e.g. from generate_snow_textures.
+
+     Returns:
+         np.ndarray: Image with added snow effect. The output has the same dtype as the input.
+
+     Note:
+         - The function first converts the image to HSV color space for better control over
+           brightness and color adjustments.
+         - A snow texture is generated using Gaussian noise and then filtered for a more
+           natural appearance.
+         - A depth effect is simulated, with more snow at the top of the image and less at the bottom.
+         - A slight blue tint is added to simulate the cool color of snow.
+         - Random sparkle effects are added to simulate light reflecting off snow crystals.
+
+     The snow effect is created through the following steps:
+     1. Brightness adjustment in HSV space
+     2. Generation of a snow texture using Gaussian noise
+     3. Application of a depth effect to the snow texture
+     4. Blending of the snow texture with the original image
+     5. Addition of a cool blue tint
+     6. Addition of sparkle effects
+
+     Examples:
+         >>> import numpy as np
+         >>> image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
+         >>> texture, sparkles = generate_snow_textures(image.shape[:2], np.random.default_rng(0))
+         >>> snowy_image = add_snow_texture(
+         ...     image, snow_point=0.5, brightness_coeff=0.2, snow_texture=texture, sparkle_mask=sparkles
+         ... )
+
+     Note:
+         This function works with both uint8 and float32 image types, automatically
+         handling the conversion between them.
+
+     References:
+         - Perlin Noise: https://en.wikipedia.org/wiki/Perlin_noise
+         - HSV Color Space: https://en.wikipedia.org/wiki/HSL_and_HSV
+
+     """
+     max_value = MAX_VALUES_BY_DTYPE[np.uint8]
+
+     # Convert to HSV for better color control
+     img_hsv = cv2.cvtColor(img, cv2.COLOR_RGB2HSV).astype(np.float32)
+
+     # Increase brightness
+     img_hsv[:, :, 2] = np.clip(
+         img_hsv[:, :, 2] * (1 + brightness_coeff * snow_point),
+         0,
+         max_value,
+     )
+
+     # Smooth the provided snow texture
+     snow_texture = cv2.GaussianBlur(snow_texture, (0, 0), sigmaX=1, sigmaY=1)
+
+     # Create depth effect for snow simulation
+     # More snow accumulates at the top of the image, gradually decreasing towards the bottom
+     # This simulates natural snow distribution on surfaces
+     # The effect is achieved using a linear gradient from 1 (full snow) to 0.2 (less snow)
+     rows = img.shape[0]
+     depth_effect = np.linspace(1, 0.2, rows)[:, np.newaxis]
+     snow_texture *= depth_effect
+
+     # Apply snow texture
+     snow_layer = (np.dstack([snow_texture] * 3) * max_value * snow_point).astype(
+         np.float32,
+     )
+
+     # Blend snow with original image
+     img_with_snow = cv2.add(img_hsv, snow_layer)
+
+     # Add a slight blue tint to simulate cool snow color
+     blue_tint = np.full_like(img_with_snow, (0.6, 0.75, 1))  # Slight blue in HSV
+
+     img_with_snow = cv2.addWeighted(
+         img_with_snow,
+         0.85,
+         blue_tint,
+         0.15 * snow_point,
+         0,
+     )
+
+     # Convert back to RGB
+     img_with_snow = cv2.cvtColor(img_with_snow.astype(np.uint8), cv2.COLOR_HSV2RGB)
+
+     # Add some sparkle effects for snow glitter
+     img_with_snow[sparkle_mask] = [max_value, max_value, max_value]
+
+     return img_with_snow
+
+
+ @uint8_io
+ @preserve_channel_dim
+ def add_rain(
+     img: np.ndarray,
+     slant: float,
+     drop_length: int,
+     drop_width: int,
+     drop_color: tuple[int, int, int],
+     blur_value: int,
+     brightness_coefficient: float,
+     rain_drops: np.ndarray,
+ ) -> np.ndarray:
+     """Add rain to an image.
+
+     This function adds rain to an image by drawing rain drops on the image.
+     The rain drops are drawn using the OpenCV function cv2.polylines.
+
+     Args:
+         img (np.ndarray): The image to add rain to.
+         slant (float): The slant of the rain drops.
+         drop_length (int): The length of the rain drops.
+         drop_width (int): The width of the rain drops.
+         drop_color (tuple[int, int, int]): The color of the rain drops.
+         blur_value (int): The blur value of the rain drops.
+         brightness_coefficient (float): The brightness coefficient of the rain drops.
+         rain_drops (np.ndarray): The rain drops to draw on the image.
+
+     Returns:
+         np.ndarray: The image with rain added.
+
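+     Examples:
+         Illustrative usage; drop start points are (x, y) pixel coordinates:
+
+         >>> import numpy as np
+         >>> img = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
+         >>> drops = np.array([[10, 10], [40, 60]])
+         >>> rainy = add_rain(img, slant=5, drop_length=15, drop_width=1,
+         ...                  drop_color=(200, 200, 200), blur_value=3,
+         ...                  brightness_coefficient=0.9, rain_drops=drops)
+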
+     """
+     if not rain_drops.size:
+         return img.copy()
+
+     img = img.copy()
+
+     # Pre-allocate rain layer
+     rain_layer = np.zeros_like(img, dtype=np.uint8)
+
+     # Compute drop end points by offsetting each start point by (slant, drop_length);
+     # the (1, 2) offset broadcasts against the (N, 2) rain_drops array
+     end_points = rain_drops + np.array([[slant, drop_length]])
+
+     # Stack start and end points into line segments of shape (N, 2, 2)
+     lines = np.stack((rain_drops, end_points), axis=1)
+
+     cv2.polylines(
+         rain_layer,
+         lines.astype(np.int32),
+         False,
+         drop_color,
+         drop_width,
+         lineType=cv2.LINE_4,
+     )
+
+     if blur_value > 1:
+         cv2.blur(rain_layer, (blur_value, blur_value), dst=rain_layer)
+
+     cv2.add(img, rain_layer, dst=img)
+
+     if brightness_coefficient != 1.0:
+         cv2.multiply(img, brightness_coefficient, dst=img, dtype=cv2.CV_8U)
+
+     return img
+
+
+ def get_fog_particle_radiuses(
+     img_shape: tuple[int, int],
+     num_particles: int,
+     fog_intensity: float,
+     random_generator: np.random.Generator,
+ ) -> list[int]:
+     """Generate radiuses for fog particles.
+
+     Args:
+         img_shape (tuple[int, int]): Image shape.
+         num_particles (int): Number of fog particles.
+         fog_intensity (float): Intensity of the fog effect, between 0 and 1.
+         random_generator (np.random.Generator): Random generator to use.
+
+     Returns:
+         list[int]: List of radiuses for each fog particle.
+
+     """
+     height, width = img_shape[:2]
+     max_fog_radius = max(2, int(min(height, width) * 0.1 * fog_intensity))
+     min_radius = max(1, max_fog_radius // 2)
+
+     return [random_generator.integers(min_radius, max_fog_radius) for _ in range(num_particles)]
+
+
+ @uint8_io
+ @clipped
+ @preserve_channel_dim
+ def add_fog(
+     img: np.ndarray,
+     fog_intensity: float,
+     alpha_coef: float,
+     fog_particle_positions: list[tuple[int, int]],
+     fog_particle_radiuses: list[int],
+ ) -> np.ndarray:
+     """Add fog to an image.
+
+     This function adds fog to an image by drawing fog particles on the image.
+     The fog particles are drawn using the OpenCV function cv2.circle.
+
+     Args:
+         img (np.ndarray): The image to add fog to.
+         fog_intensity (float): The intensity of the fog effect, between 0 and 1.
+         alpha_coef (float): The coefficient for the alpha blending.
+         fog_particle_positions (list[tuple[int, int]]): The positions of the fog particles.
+         fog_particle_radiuses (list[int]): The radiuses of the fog particles.
+
+     Returns:
+         np.ndarray: The image with fog added.
+
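+     Examples:
+         Illustrative usage with a couple of hand-placed particles:
+
+         >>> import numpy as np
+         >>> img = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
+         >>> positions = [(30, 30), (70, 60)]
+         >>> radiuses = [8, 12]
+         >>> foggy = add_fog(img, fog_intensity=0.5, alpha_coef=0.1,
+         ...                 fog_particle_positions=positions, fog_particle_radiuses=radiuses)
+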
+     """
+     result = img.copy()
+
+     # Apply fog particles progressively, blending each circle into the result
+     for (x, y), radius in zip(fog_particle_positions, fog_particle_radiuses):
+         overlay = result.copy()
+         cv2.circle(
+             overlay,
+             center=(x, y),
+             radius=radius,
+             color=(255, 255, 255),
+             thickness=-1,
+         )
+
+         # Progressive blending
+         alpha = alpha_coef * fog_intensity
+         cv2.addWeighted(overlay, alpha, result, 1 - alpha, 0, dst=result)
+
+     # Final subtle blur
+     blur_size = max(3, int(min(img.shape[:2]) // 30))
+     if blur_size % 2 == 0:
+         blur_size += 1
+
+     result = cv2.GaussianBlur(result, (blur_size, blur_size), 0)
+
+     return clip(result, np.uint8, inplace=True)
+
+
+ @uint8_io
+ @preserve_channel_dim
+ @maybe_process_in_chunks
+ def add_sun_flare_overlay(
+     img: np.ndarray,
+     flare_center: tuple[float, float],
+     src_radius: int,
+     src_color: tuple[int, ...],
+     circles: list[Any],
+ ) -> np.ndarray:
+     """Add a sun flare effect to an image using a simple overlay technique.
+
+     This function creates a basic sun flare effect by overlaying multiple semi-transparent
+     circles of varying sizes and intensities on the input image. The effect simulates
+     a simple lens flare caused by bright light sources.
+
+     Args:
+         img (np.ndarray): The input image.
+         flare_center (tuple[float, float]): (x, y) coordinates of the flare center
+             in pixel coordinates.
+         src_radius (int): The radius of the main sun circle in pixels.
+         src_color (tuple[int, ...]): The color of the sun, represented as a tuple of RGB values.
+         circles (list[Any]): A list of tuples, each representing a circle that contributes
+             to the flare effect. Each tuple contains:
+             - alpha (float): The transparency of the circle (0.0 to 1.0).
+             - center (tuple[int, int]): (x, y) coordinates of the circle center.
+             - radius (int): The radius of the circle.
+             - color (tuple[int, int, int]): RGB color of the circle.
+
+     Returns:
+         np.ndarray: The output image with the sun flare effect added.
+
+     Note:
+         - This function uses a simple alpha blending technique to overlay flare elements.
+         - The main sun is created as a gradient circle, fading from the center outwards.
+         - Additional flare circles are added along an imaginary line from the sun's position.
+         - This method is computationally efficient but may produce less realistic results
+           compared to more advanced techniques.
+
+     The flare effect is created through the following steps:
+     1. Create an overlay image and output image as copies of the input.
+     2. Add smaller flare circles to the overlay.
+     3. Blend the overlay with the output image using alpha compositing.
+     4. Add the main sun circle with a radial gradient.
+
+     Examples:
+         >>> import numpy as np
+         >>> import albumentations as A
+         >>> image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
+         >>> flare_center = (50, 50)
+         >>> src_radius = 20
+         >>> src_color = (255, 255, 200)
+         >>> circles = [
+         ...     (0.1, (60, 60), 5, (255, 200, 200)),
+         ...     (0.2, (70, 70), 3, (200, 255, 200))
+         ... ]
+         >>> flared_image = A.functional.add_sun_flare_overlay(
+         ...     image, flare_center, src_radius, src_color, circles
+         ... )
+
+     References:
+         - Alpha compositing: https://en.wikipedia.org/wiki/Alpha_compositing
+         - Lens flare: https://en.wikipedia.org/wiki/Lens_flare
+
+     """
+     overlay = img.copy()
+     output = img.copy()
+
+     weighted_brightness = 0.0
+     total_radius_length = 0.0
+
+     for alpha, (x, y), rad3, circle_color in circles:
+         weighted_brightness += alpha * rad3
+         total_radius_length += rad3
+         cv2.circle(overlay, (x, y), rad3, circle_color, -1)
+         output = add_weighted(overlay, alpha, output, 1 - alpha)
+
+     point = [int(x) for x in flare_center]
+
+     overlay = output.copy()
+     num_times = src_radius // 10
+
+     # max_alpha is computed from weighted_brightness and total_radius_length, scaled by 5:
+     # the larger the alpha-weighted area, the brighter the central bright spot becomes.
+     # For alphas in the range [0.05, 0.2], max_alpha should stay below 1.
+     max_alpha = weighted_brightness / total_radius_length * 5
+     alpha = np.linspace(0.0, min(max_alpha, 1.0), num=num_times)
+
+     rad = np.linspace(1, src_radius, num=num_times)
+
+     for i in range(num_times):
+         cv2.circle(overlay, point, int(rad[i]), src_color, -1)
+         alp = alpha[num_times - i - 1] ** 3
+         output = add_weighted(overlay, alp, output, 1 - alp)
+
+     return output
+
+
+ @uint8_io
+ @clipped
+ def add_sun_flare_physics_based(
+     img: np.ndarray,
+     flare_center: tuple[int, int],
+     src_radius: int,
+     src_color: tuple[int, int, int],
+     circles: list[Any],
+ ) -> np.ndarray:
+     """Add a more realistic sun flare effect to the image.
+
+     This function creates a complex sun flare effect by simulating various optical phenomena
+     that occur in real camera lenses when capturing bright light sources. The result is a
+     more realistic and physically plausible lens flare effect.
+
+     Args:
+         img (np.ndarray): Input image.
+         flare_center (tuple[int, int]): (x, y) coordinates of the sun's center in pixels.
+         src_radius (int): Radius of the main sun circle in pixels.
+         src_color (tuple[int, int, int]): Color of the sun in RGB format.
+         circles (list[Any]): List of tuples, each representing a flare circle with parameters:
+             (alpha, center, size, color)
+             - alpha (float): Transparency of the circle (0.0 to 1.0).
+             - center (tuple[int, int]): (x, y) coordinates of the circle center.
+             - size (float): Size factor for the circle radius.
+             - color (tuple[int, int, int]): RGB color of the circle.
+
+     Returns:
+         np.ndarray: Image with added sun flare effect.
+
+     Note:
+         This function implements several techniques to create a more realistic flare:
+         1. Separate flare layer: Allows for complex manipulations of the flare effect.
+         2. Lens diffraction spikes: Simulates light diffraction in camera aperture.
+         3. Radial gradient mask: Creates natural fading of the flare from the center.
+         4. Gaussian blur: Softens the flare for a more natural glow effect.
+         5. Chromatic aberration: Simulates color fringing often seen in real lens flares.
+         6. Screen blending: Provides a more realistic blending of the flare with the image.
+
+     The flare effect is created through the following steps:
+     1. Create a separate flare layer.
+     2. Add the main sun circle and diffraction spikes to the flare layer.
+     3. Add additional flare circles based on the input parameters.
+     4. Apply Gaussian blur to soften the flare.
+     5. Create and apply a radial gradient mask for natural fading.
+     6. Simulate chromatic aberration by applying different blurs to color channels.
+     7. Blend the flare with the original image using screen blending mode.
+
+     Examples:
+         >>> import numpy as np
+         >>> import albumentations as A
+         >>> image = np.random.randint(0, 256, [1000, 1000, 3], dtype=np.uint8)
+         >>> flare_center = (500, 500)
+         >>> src_radius = 50
+         >>> src_color = (255, 255, 200)
+         >>> circles = [
+         ...     (0.1, (550, 550), 10, (255, 200, 200)),
+         ...     (0.2, (600, 600), 5, (200, 255, 200))
+         ... ]
+         >>> flared_image = A.functional.add_sun_flare_physics_based(
+         ...     image, flare_center, src_radius, src_color, circles
+         ... )
+
+     References:
+         - Lens flare: https://en.wikipedia.org/wiki/Lens_flare
+         - Diffraction: https://en.wikipedia.org/wiki/Diffraction
+         - Chromatic aberration: https://en.wikipedia.org/wiki/Chromatic_aberration
+         - Screen blending: https://en.wikipedia.org/wiki/Blend_modes#Screen
+
+     """
+     output = img.copy()
+     height, width = img.shape[:2]
+
+     # Create a separate flare layer
+     flare_layer = np.zeros_like(img, dtype=np.float32)
+
+     # Add the main sun
+     cv2.circle(flare_layer, flare_center, src_radius, src_color, -1)
+
+     # Add lens diffraction spikes
+     for angle in [0, 45, 90, 135]:
+         end_point = (
+             int(flare_center[0] + np.cos(np.radians(angle)) * max(width, height)),
+             int(flare_center[1] + np.sin(np.radians(angle)) * max(width, height)),
+         )
+         cv2.line(flare_layer, flare_center, end_point, src_color, 2)
+
+     # Add flare circles
+     for _, center, size, color in circles:
+         cv2.circle(flare_layer, center, int(size**0.33), color, -1)
+
+     # Apply Gaussian blur to soften the flare
+     flare_layer = cv2.GaussianBlur(flare_layer, (0, 0), sigmaX=15, sigmaY=15)
+
+     # Create a radial gradient mask
+     y, x = np.ogrid[:height, :width]
+     mask = np.sqrt((x - flare_center[0]) ** 2 + (y - flare_center[1]) ** 2)
+     mask = 1 - np.clip(mask / (max(width, height) * 0.7), 0, 1)
+     mask = np.dstack([mask] * 3)
+
+     # Apply the mask to the flare layer
+     flare_layer *= mask
+
+     # Add chromatic aberration
+     channels = list(cv2.split(flare_layer))
+     channels[0] = cv2.GaussianBlur(
+         channels[0],
+         (0, 0),
+         sigmaX=3,
+         sigmaY=3,
+     )  # Red channel (images are RGB in this module)
+     channels[2] = cv2.GaussianBlur(
+         channels[2],
+         (0, 0),
+         sigmaX=5,
+         sigmaY=5,
+     )  # Blue channel
+     flare_layer = cv2.merge(channels)
+
+     # Blend the flare with the original image using screen blending
+     return 255 - ((255 - output) * (255 - flare_layer) / 255)
+
+
+ @uint8_io
+ @preserve_channel_dim
+ def add_shadow(
+     img: np.ndarray,
+     vertices_list: list[np.ndarray],
+     intensities: np.ndarray,
+ ) -> np.ndarray:
+     """Add shadows to the image by reducing the intensity of the pixel values in specified regions.
+
+     Args:
+         img (np.ndarray): Input image. Multichannel images are supported.
+         vertices_list (list[np.ndarray]): List of vertices for shadow polygons.
+         intensities (np.ndarray): Array of shadow intensities. Range is [0, 1].
+
+     Returns:
+         np.ndarray: Image with shadows added.
+
+     References:
+         Automold--Road-Augmentation-Library: https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
+
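+     Examples:
+         Illustrative usage with one triangular shadow polygon:
+
+         >>> import numpy as np
+         >>> img = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
+         >>> vertices = [np.array([[10, 10], [80, 20], [40, 70]], dtype=np.int32)]
+         >>> shadowed = add_shadow(img, vertices, np.array([0.5]))
+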
+     """
+     num_channels = get_num_channels(img)
+     max_value = MAX_VALUES_BY_DTYPE[np.uint8]
+
+     img_shadowed = img.copy()
+
+     # Iterate over the vertices and intensity list
+     for vertices, shadow_intensity in zip(vertices_list, intensities):
+         # Create mask for the current shadow polygon
+         mask = np.zeros((img.shape[0], img.shape[1], 1), dtype=np.uint8)
+         cv2.fillPoly(mask, [vertices], (max_value,))
+
+         # Duplicate the mask to have the same number of channels as the image
+         mask = np.repeat(mask, num_channels, axis=2)
+
+         # Apply shadow to the channels directly
+         # It could be tempting to convert to HLS and apply the shadow to the L channel, but it creates artifacts
+         shadowed_indices = mask[:, :, 0] == max_value
+         darkness = 1 - shadow_intensity
+         img_shadowed[shadowed_indices] = clip(
+             img_shadowed[shadowed_indices] * darkness,
+             np.uint8,
+             inplace=True,
+         )
+
+     return img_shadowed
+
+
+ @uint8_io
+ @clipped
+ @preserve_channel_dim
+ def add_gravel(img: np.ndarray, gravels: list[Any]) -> np.ndarray:
+     """Add gravel to an image.
+
+     This function adds gravel to an image by drawing gravel particles on the image.
+     The gravel particles are drawn using the OpenCV function cv2.circle.
+
+     Args:
+         img (np.ndarray): The image to add gravel to.
+         gravels (list[Any]): The gravel particles to draw on the image;
+             each entry is a (min_y, max_y, min_x, max_x, sat) tuple.
+
+     Returns:
+         np.ndarray: The image with gravel added.
+
+     """
+     non_rgb_error(img)
+     image_hls = cv2.cvtColor(img, cv2.COLOR_RGB2HLS)
+
+     for gravel in gravels:
+         min_y, max_y, min_x, max_x, sat = gravel
+         image_hls[min_y:max_y, min_x:max_x, 1] = sat
+
+     return cv2.cvtColor(image_hls, cv2.COLOR_HLS2RGB)
+
+
+ def invert(img: np.ndarray) -> np.ndarray:
+     """Invert the colors of an image.
+
+     This function inverts the colors of an image by subtracting each pixel value from the maximum possible value.
+     The result is a negative of the original image.
+
+     Args:
+         img (np.ndarray): The image to invert.
+
+     Returns:
+         np.ndarray: The inverted image.
+
+     """
+     # Subtracting from the dtype's maximum value supports all valid dtypes
+     # and cannot overflow, so no extra clipping is needed.
+     return MAX_VALUES_BY_DTYPE[img.dtype] - img
+
+
+ def channel_shuffle(img: np.ndarray, channels_shuffled: list[int]) -> np.ndarray:
+     """Shuffle the channels of an image.
+
+     This function shuffles the channels of an image by using the cv2.mixChannels function.
+     The channels are shuffled according to the channels_shuffled list.
+
+     Args:
+         img (np.ndarray): The image to shuffle.
+         channels_shuffled (list[int]): The new channel order; output channel i is
+             taken from input channel channels_shuffled[i].
+
+     Returns:
+         np.ndarray: The shuffled image.
+
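+     Examples:
+         Illustrative usage; reverses the RGB channel order:
+
+         >>> import numpy as np
+         >>> img = np.random.randint(0, 256, (10, 10, 3), dtype=np.uint8)
+         >>> bgr = channel_shuffle(img, [2, 1, 0])
+         >>> np.array_equal(bgr, img[..., ::-1])
+         True
+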
+     """
+     output = np.empty_like(img)
+     from_to = []
+     for i, j in enumerate(channels_shuffled):
+         from_to.extend([j, i])  # Use [src, dst]
+     cv2.mixChannels([img], [output], from_to)
+     return output
+
+
+ def volume_channel_shuffle(volume: np.ndarray, channels_shuffled: Sequence[int]) -> np.ndarray:
+     """Shuffle channels of a single volume (D, H, W, C) or (D, H, W).
+
+     Args:
+         volume (np.ndarray): Input volume.
+         channels_shuffled (Sequence[int]): New channel order.
+
+     Returns:
+         np.ndarray: Volume with channels shuffled.
+
+     """
+     return volume.copy()[..., channels_shuffled] if volume.ndim == 4 else volume
+
+
+ def volumes_channel_shuffle(volumes: np.ndarray, channels_shuffled: Sequence[int]) -> np.ndarray:
+     """Shuffle channels of a batch of volumes (B, D, H, W, C) or (B, D, H, W).
+
+     Args:
+         volumes (np.ndarray): Input batch of volumes.
+         channels_shuffled (Sequence[int]): New channel order.
+
+     Returns:
+         np.ndarray: Batch of volumes with channels shuffled.
+
+     """
+     return volumes.copy()[..., channels_shuffled] if volumes.ndim == 5 else volumes
+
+
+ def gamma_transform(img: np.ndarray, gamma: float) -> np.ndarray:
+     """Apply gamma transformation to an image.
+
+     This function applies gamma transformation to an image by raising each pixel value to the power of gamma.
+     The result is a non-linear transformation that can enhance or reduce the contrast of the image.
+
+     Args:
+         img (np.ndarray): The image to apply gamma transformation to.
+         gamma (float): The gamma value to apply.
+
+     Returns:
+         np.ndarray: The gamma transformed image.
+
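+     Examples:
+         Illustrative usage; gamma < 1 brightens and gamma > 1 darkens a uint8 image:
+
+         >>> import numpy as np
+         >>> img = np.random.randint(0, 256, (20, 20, 3), dtype=np.uint8)
+         >>> brighter = gamma_transform(img, gamma=0.5)
+         >>> darker = gamma_transform(img, gamma=2.0)
+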
+     """
+     if img.dtype == np.uint8:
+         table = (np.arange(0, 256.0 / 255, 1.0 / 255) ** gamma) * 255
+         return sz_lut(img, table.astype(np.uint8), inplace=False)
+
+     return np.power(img, gamma)
+
+
+ @float32_io
+ @clipped
+ def iso_noise(
+     image: np.ndarray,
+     color_shift: float,
+     intensity: float,
+     random_generator: np.random.Generator,
+ ) -> np.ndarray:
+     """Apply Poisson noise to an image to simulate camera sensor noise.
+
+     Args:
+         image (np.ndarray): Input image. Currently, only RGB images are supported.
+         color_shift (float): The amount of color shift to apply.
+         intensity (float): Multiplication factor for noise values. Values of ~0.5 produce a noticeable,
+             yet acceptable level of noise.
+         random_generator (np.random.Generator): Random generator used for noise generation.
+
+     Returns:
+         np.ndarray: The noised image.
+
+     Image types:
+         uint8, float32
+
+     Number of channels:
+         3
+
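+     Examples:
+         Illustrative usage with a seeded generator:
+
+         >>> import numpy as np
+         >>> img = np.random.randint(0, 256, (50, 50, 3), dtype=np.uint8)
+         >>> noisy = iso_noise(img, color_shift=0.05, intensity=0.5,
+         ...                   random_generator=np.random.default_rng(0))
+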
+     """
+     hls = cv2.cvtColor(image, cv2.COLOR_RGB2HLS)
+     _, stddev = cv2.meanStdDev(hls)
+
+     luminance_noise = random_generator.poisson(
+         stddev[1] * intensity,
+         size=hls.shape[:2],
+     )
+     color_noise = random_generator.normal(
+         0,
+         color_shift * intensity,
+         size=hls.shape[:2],
+     )
+
+     hls[..., 0] += color_noise
+     hls[..., 1] = add_array(
+         hls[..., 1],
+         luminance_noise * intensity * (1.0 - hls[..., 1]),
+     )
+
+     noised_hls = cv2.cvtColor(hls, cv2.COLOR_HLS2RGB)
+     return np.clip(noised_hls, 0, 1, out=noised_hls)  # Ensure output is in [0, 1] range
+
+
+ def to_gray_weighted_average(img: np.ndarray) -> np.ndarray:
+     """Convert an RGB image to grayscale using the weighted average method.
+
+     This function uses OpenCV's cvtColor function with COLOR_RGB2GRAY conversion,
+     which applies the following formula:
+         Y = 0.299*R + 0.587*G + 0.114*B
+
+     Args:
+         img (np.ndarray): Input RGB image as a numpy array.
+
+     Returns:
+         np.ndarray: Grayscale image as a 2D numpy array.
+
+     Image types:
+         uint8, float32
+
+     Number of channels:
+         3
+
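+     Examples:
+         Illustrative usage:
+
+         >>> import numpy as np
+         >>> img = np.random.randint(0, 256, (40, 40, 3), dtype=np.uint8)
+         >>> gray = to_gray_weighted_average(img)
+         >>> gray.shape
+         (40, 40)
+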
1331
+ """
1332
+ return cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
1333
+
1334
+
1335
+ @uint8_io
1336
+ @clipped
1337
+ def to_gray_from_lab(img: np.ndarray) -> np.ndarray:
1338
+ """Convert an RGB image or batch of images to grayscale using LAB color space.
1339
+
1340
+ This function converts RGB images to grayscale by first converting to LAB color space
1341
+ and then extracting the L (lightness) channel. It uses albucore's reshape utilities
1342
+ to efficiently handle batches/volumes by processing them as a single tall image.
1343
+
1344
+ Implementation Details:
1345
+ The function uses albucore's reshape_for_channel and restore_from_channel functions:
1346
+ - reshape_for_channel: Flattens batches/volumes to 2D format for OpenCV processing
1347
+ - restore_from_channel: Restores the original shape after processing
1348
+
1349
+ This enables processing all images in a single OpenCV call
1350
+
1351
+ Args:
1352
+ img: Input RGB image(s) as a numpy array. Must have 3 channels in the last dimension.
1353
+ Supported shapes:
1354
+ - Single image: (H, W, 3)
1355
+ - Batch of images: (N, H, W, 3)
1356
+ - Volume: (D, H, W, 3)
1357
+ - Batch of volumes: (N, D, H, W, 3)
1358
+
1359
+ Supported dtypes:
1360
+ - np.uint8: Values in range [0, 255]
1361
+ - np.float32: Values in range [0, 1]
1362
+
1363
+ Returns:
1364
+ Grayscale image(s) with the same spatial dimensions as input but without channel dimension:
1365
+ - Single image: (H, W)
1366
+ - Batch of images: (N, H, W)
1367
+ - Volume: (D, H, W)
1368
+ - Batch of volumes: (N, D, H, W)
1369
+
1370
+ The output dtype matches the input dtype. For float inputs, the L channel
1371
+ is normalized to [0, 1] by dividing by 100.
1372
+
1373
+ Raises:
1374
+ ValueError: If the last dimension is not 3 (RGB channels)
1375
+
1376
+ Examples:
1377
+ >>> # Single image
1378
+ >>> img = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
1379
+ >>> gray = to_gray_from_lab(img)
1380
+ >>> assert gray.shape == (100, 100)
1381
+
1382
+ >>> # Batch of images - efficiently processed without loops
1383
+ >>> batch = np.random.randint(0, 256, (10, 100, 100, 3), dtype=np.uint8)
1384
+ >>> gray_batch = to_gray_from_lab(batch)
1385
+ >>> assert gray_batch.shape == (10, 100, 100)
1386
+
1387
+ >>> # Volume (e.g., video frames or 3D medical data)
1388
+ >>> volume = np.random.randint(0, 256, (16, 100, 100, 3), dtype=np.uint8)
1389
+ >>> gray_volume = to_gray_from_lab(volume)
1390
+ >>> assert gray_volume.shape == (16, 100, 100)
1391
+
1392
+ >>> # Float32 input
1393
+ >>> img_float = img.astype(np.float32) / 255.0
1394
+ >>> gray_float = to_gray_from_lab(img_float)
1395
+ >>> assert 0 <= gray_float.min() <= gray_float.max() <= 1.0
1396
+
1397
+ Note:
1398
+ The LAB color space provides perceptually uniform grayscale conversion,
1399
+ where the L (lightness) channel represents human perception of brightness
1400
+ better than simple RGB averaging or other methods.
1401
+
1402
+ """
1403
+ original_dtype = img.dtype
1404
+ ndim = img.ndim
1405
+
1406
+ # Handle single image case by adding a batch dimension
1407
+ if ndim == 3:
1408
+ # Add batch dimension to make it (1, H, W, C)
1409
+ return cv2.cvtColor(img, cv2.COLOR_RGB2LAB)[..., 0]
1410
+
1411
+ # Determine dimensions for reshape_for_channel
1412
+ if ndim == 4:
1413
+ # Batch of images (N, H, W, C) or single image with added batch dimension
1414
+ has_batch_dim = True
1415
+ has_depth_dim = False
1416
+ elif ndim == 5:
1417
+ # Batch of volumes (N, D, H, W, C)
1418
+ has_batch_dim = True
1419
+ has_depth_dim = True
1420
+
1421
+ # Use reshape utilities from albucore for efficient batch processing
1422
+ flattened, original_shape = reshape_for_channel(img, has_batch_dim=has_batch_dim, has_depth_dim=has_depth_dim)
1423
+
1424
+ lab = cv2.cvtColor(flattened, cv2.COLOR_RGB2LAB)
1425
+
1426
+ grayscale_flat = lab[..., 0]
1427
+ grayscale = restore_from_channel(
1428
+ grayscale_flat,
1429
+ original_shape,
1430
+ has_batch_dim=has_batch_dim,
1431
+ has_depth_dim=has_depth_dim,
1432
+ )
1433
+
1434
+ return grayscale / 100.0 if original_dtype == np.float32 else grayscale
1435
+
1436
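+ # Illustrative sketch (not part of the module API): the batched path above is
+ # expected to match a naive per-image loop; `_check_lab_batch_equivalence` is a
+ # hypothetical helper name used only for this example.
+ def _check_lab_batch_equivalence(batch: np.ndarray) -> bool:
+     """Compare the batched conversion against a per-image loop."""
+     batched = to_gray_from_lab(batch)
+     looped = np.stack([to_gray_from_lab(image) for image in batch])
+     return bool(np.array_equal(batched, looped))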
+
1437
+ @clipped
1438
+ def to_gray_desaturation(img: np.ndarray) -> np.ndarray:
1439
+ """Convert an image to grayscale using the desaturation method.
1440
+
1441
+ Args:
1442
+ img (np.ndarray): Input image as a numpy array.
1443
+
1444
+ Returns:
1445
+ np.ndarray: Grayscale image as a 2D numpy array.
1446
+
1447
+ Image types:
1448
+ uint8, float32
1449
+
1450
+ Number of channels:
1451
+ any
1452
+
1453
+ """
1454
+ float_image = img.astype(np.float32)
1455
+ return (np.max(float_image, axis=-1) + np.min(float_image, axis=-1)) / 2
1456
+
1457
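+ # Worked example for the desaturation formula above (hypothetical pixel): a
+ # pixel (200, 100, 50) maps to (max + min) / 2 = (200 + 50) / 2 = 125.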
+
1458
+ def to_gray_average(img: np.ndarray) -> np.ndarray:
1459
+ """Convert an image to grayscale using the average method.
1460
+
1461
+ This function computes the arithmetic mean across all channels for each pixel,
1462
+ resulting in a grayscale representation of the image.
1463
+
1464
+ Key aspects of this method:
1465
+ 1. It treats all channels equally, regardless of their perceptual importance.
1466
+ 2. Works with any number of channels, making it versatile for various image types.
1467
+ 3. Simple and fast to compute, but may not accurately represent perceived brightness.
1468
+ 4. For RGB images, the formula is: Gray = (R + G + B) / 3
1469
+
1470
+ Note: This method may produce different results compared to weighted methods
1471
+ (like RGB weighted average) which account for human perception of color brightness.
1472
+ It may also produce unexpected results for images with alpha channels or
1473
+ non-color data in additional channels.
1474
+
1475
+ Args:
1476
+ img (np.ndarray): Input image as a numpy array. Can be any number of channels.
1477
+
1478
+ Returns:
1479
+ np.ndarray: Grayscale image as a 2D numpy array. The output data type
1480
+ matches the input data type.
1481
+
1482
+ Image types:
1483
+ uint8, float32
1484
+
1485
+ Number of channels:
1486
+ any
1487
+
1488
+ """
1489
+ return np.mean(img, axis=-1).astype(img.dtype)
1490
+
1491
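+ # Worked example for the average formula above: a pixel (30, 60, 90) maps to
+ # (30 + 60 + 90) / 3 = 60, regardless of the perceptual weight of each channel.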
+
1492
+ def to_gray_max(img: np.ndarray) -> np.ndarray:
1493
+ """Convert an image to grayscale using the maximum channel value method.
1494
+
1495
+ This function takes the maximum value across all channels for each pixel,
1496
+ resulting in a grayscale image that preserves the brightest parts of the original image.
1497
+
1498
+ Key aspects of this method:
1499
+ 1. Works with any number of channels, making it versatile for various image types.
1500
+ 2. For 3-channel (e.g., RGB) images, this method is equivalent to extracting the V (Value)
1501
+ channel from the HSV color space.
1502
+ 3. Preserves the brightest parts of the image but may lose some color contrast information.
1503
+ 4. Simple and fast to compute.
1504
+
1505
+ Note:
1506
+ - This method tends to produce brighter grayscale images compared to other conversion methods,
1507
+ as it always selects the highest intensity value from the channels.
1508
+ - For RGB images, it may not accurately represent perceived brightness as it doesn't
1509
+ account for human color perception.
1510
+
1511
+ Args:
1512
+ img (np.ndarray): Input image as a numpy array. Can be any number of channels.
1513
+
1514
+ Returns:
1515
+ np.ndarray: Grayscale image as a 2D numpy array. The output data type
1516
+ matches the input data type.
1517
+
1518
+ Image types:
1519
+ uint8, float32
1520
+
1521
+ Number of channels:
1522
+ any
1523
+
1524
+ """
1525
+ return np.max(img, axis=-1)
1526
+
1527
+
1528
+ @clipped
1529
+ def to_gray_pca(img: np.ndarray) -> np.ndarray:
1530
+ """Convert an image to grayscale using Principal Component Analysis (PCA).
1531
+
1532
+ This function applies PCA to reduce a multi-channel image to a single channel,
1533
+ effectively creating a grayscale representation that captures the maximum variance
1534
+ in the color data.
1535
+
1536
+ Args:
1537
+ img (np.ndarray): Input image as a numpy array. Can be:
1538
+ - Single multi-channel image: (H, W, C)
1539
+ - Batch of multi-channel images: (N, H, W, C)
1540
+ - Single multi-channel volume: (D, H, W, C)
1541
+ - Batch of multi-channel volumes: (N, D, H, W, C)
1542
+
1543
+ Returns:
1544
+ np.ndarray: Grayscale image with the same spatial dimensions as input.
1545
+ If input is uint8, output is uint8 in range [0, 255].
1546
+ If input is float32, output is float32 in range [0, 1].
1547
+
1548
+ Note:
1549
+ This method can potentially preserve more information from the original image
1550
+ compared to standard weighted average methods, as it accounts for the
1551
+ correlations between color channels.
1552
+
1553
+ Image types:
1554
+ uint8, float32
1555
+
1556
+ Number of channels:
1557
+ any
1558
+
1559
+ """
1560
+ dtype = img.dtype
1561
+ # Reshape the image to a 2D array of pixels
1562
+ pixels = img.reshape(-1, img.shape[-1])
1563
+
1564
+ # Perform PCA
1565
+ pca = PCA(n_components=1)
1566
+ pca_result = pca.fit_transform(pixels)
1567
+
1568
+ # Reshape back to image dimensions and normalize to [0, 1]
1569
+ grayscale = pca_result.reshape(img.shape[:-1])
1570
+ grayscale = normalize_per_image(grayscale, "min_max")
1571
+
1572
+ return from_float(grayscale, target_dtype=dtype) if dtype == np.uint8 else grayscale
1573
+
1574
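+ # Minimal sketch of the projection performed above for a single multi-channel
+ # (H, W, C) image; `_pca_first_component` is a hypothetical name used only for
+ # illustration.
+ def _pca_first_component(img: np.ndarray) -> np.ndarray:
+     """Project each pixel onto the leading eigenvector of the channel covariance."""
+     pixels = img.reshape(-1, img.shape[-1]).astype(np.float32)
+     centered = pixels - pixels.mean(axis=0)
+     _, vecs = np.linalg.eigh(np.cov(centered, rowvar=False))
+     # eigh returns eigenvalues in ascending order, so the last column leads
+     return (centered @ vecs[:, -1]).reshape(img.shape[:-1])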
+
1575
+ def to_gray(
1576
+ img: np.ndarray,
1577
+ num_output_channels: int,
1578
+ method: Literal[
1579
+ "weighted_average",
1580
+ "from_lab",
1581
+ "desaturation",
1582
+ "average",
1583
+ "max",
1584
+ "pca",
1585
+ ],
1586
+ ) -> np.ndarray:
1587
+ """Convert an image to grayscale using a specified method.
1588
+
1589
+ This function dispatches to one of the implementations below based on `method`.
1590
+ The method can be one of the following:
1591
+ - "weighted_average": Use the weighted average method.
1592
+ - "from_lab": Use the L channel from the LAB color space.
1593
+ - "desaturation": Use the desaturation method.
1594
+ - "average": Use the average method.
1595
+ - "max": Use the maximum channel value method.
1596
+ - "pca": Use the Principal Component Analysis method.
1597
+
1598
+ Args:
1599
+ img (np.ndarray): Input image as a numpy array.
1600
+ num_output_channels (int): The number of channels in the output image.
1601
+ method (Literal["weighted_average", "from_lab", "desaturation", "average", "max", "pca"]):
1602
+ The method to use for grayscale conversion.
1603
+
1604
+ Returns:
1605
+ np.ndarray: Grayscale image replicated to `num_output_channels` channels (a 2D array when `num_output_channels` is 1).
1606
+
1607
+ """
1608
+ if method == "weighted_average":
1609
+ result = to_gray_weighted_average(img)
1610
+ elif method == "from_lab":
1611
+ result = to_gray_from_lab(img)
1612
+ elif method == "desaturation":
1613
+ result = to_gray_desaturation(img)
1614
+ elif method == "average":
1615
+ result = to_gray_average(img)
1616
+ elif method == "max":
1617
+ result = to_gray_max(img)
1618
+ elif method == "pca":
1619
+ result = to_gray_pca(img)
1620
+ else:
1621
+ raise ValueError(f"Unsupported method: {method}")
1622
+
1623
+ return grayscale_to_multichannel(result, num_output_channels)
1624
+
1625
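+ # Equivalent dispatch-table formulation of the if/elif chain above (a sketch,
+ # not the module's actual implementation; the name is illustrative only):
+ _GRAY_METHODS_EXAMPLE = {
+     "weighted_average": to_gray_weighted_average,
+     "from_lab": to_gray_from_lab,
+     "desaturation": to_gray_desaturation,
+     "average": to_gray_average,
+     "max": to_gray_max,
+     "pca": to_gray_pca,
+ }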
+
1626
+ def grayscale_to_multichannel(
1627
+ grayscale_image: np.ndarray,
1628
+ num_output_channels: int = 3,
1629
+ ) -> np.ndarray:
1630
+ """Convert a grayscale image to a multi-channel image.
1631
+
1632
+ This function takes a 2D grayscale image or a 3D image with a single channel
1633
+ and converts it to a multi-channel image by repeating the grayscale data
1634
+ across the specified number of channels.
1635
+
1636
+ Args:
1637
+ grayscale_image (np.ndarray): Input grayscale image. Can be 2D (height, width)
1638
+ or 3D (height, width, 1).
1639
+ num_output_channels (int, optional): Number of channels in the output image. Defaults to 3.
1640
+
1641
+ Returns:
1642
+ np.ndarray: Multi-channel image with shape (height, width, num_channels)
1643
+
1644
+ """
1645
+ # Single-channel output: return the grayscale image unchanged
1646
+ if num_output_channels == 1:
1647
+ return grayscale_image
1648
+
1649
+ squeezed = np.squeeze(grayscale_image)
1650
+ # For multi-channel output, use tile for better performance
1651
+ return np.tile(squeezed[..., np.newaxis], (1,) * squeezed.ndim + (num_output_channels,))
1652
+
1653
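+ # Example (hypothetical shapes): a (100, 100) grayscale array becomes
+ # (100, 100, 3) after replication with the default num_output_channels=3.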
+
1654
+ @preserve_channel_dim
1655
+ @uint8_io
1656
+ def downscale(
1657
+ img: np.ndarray,
1658
+ scale: float,
1659
+ down_interpolation: int,
1660
+ up_interpolation: int,
1661
+ ) -> np.ndarray:
1662
+ """Downscale and upscale an image.
1663
+
1664
+ This function downscales an image and then upscales it back to its original size,
1665
+ using cv2.resize with the specified interpolation methods.
1666
+
1667
+ Args:
1668
+ img (np.ndarray): Input image as a numpy array.
1669
+ scale (float): Downscaling factor in (0, 1); the image is upscaled back to its original size afterwards.
1670
+ down_interpolation (int): The interpolation method for the downscaling.
1671
+ up_interpolation (int): The interpolation method for the upscaling.
1672
+
1673
+ Returns:
1674
+ np.ndarray: The downscaled and upscaled image.
1675
+
1676
+ """
1677
+ height, width = img.shape[:2]
1678
+
1679
+ downscaled = cv2.resize(
1680
+ img,
1681
+ None,
1682
+ fx=scale,
1683
+ fy=scale,
1684
+ interpolation=down_interpolation,
1685
+ )
1686
+ return cv2.resize(downscaled, (width, height), interpolation=up_interpolation)
1687
+
1688
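+ # Example usage (sketch, values illustrative): with scale=0.25 a 100x100 image
+ # is resized to 25x25 and back to 100x100, which produces the characteristic
+ # blocky artifact.
+ #     downscale(img, 0.25, cv2.INTER_NEAREST, cv2.INTER_NEAREST)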
+
1689
+ def noop(input_obj: Any, **params: Any) -> Any:
1690
+ """No-op function.
1691
+
1692
+ This function is a no-op and returns the input object unchanged.
1693
+ It exists so that pipeline slots which require a typed callable can be filled without altering the data.
1694
+
1695
+ Args:
1696
+ input_obj (Any): The input object to return unchanged.
1697
+ **params (Any): Additional keyword arguments.
1698
+
1699
+ Returns:
1700
+ Any: The input object unchanged.
1701
+
1702
+ """
1703
+ return input_obj
1704
+
1705
+
1706
+ @float32_io
1707
+ @clipped
1708
+ @preserve_channel_dim
1709
+ def fancy_pca(img: np.ndarray, alpha_vector: np.ndarray) -> np.ndarray:
1710
+ """Perform 'Fancy PCA' augmentation on an image with any number of channels.
1711
+
1712
+ Args:
1713
+ img (np.ndarray): Input image
1714
+ alpha_vector (np.ndarray): Vector of scale factors for each principal component.
1715
+ Should have the same length as the number of channels in the image.
1716
+
1717
+ Returns:
1718
+ np.ndarray: Augmented image of the same shape, type, and range as the input.
1719
+
1720
+ Image types:
1721
+ uint8, float32
1722
+
1723
+ Number of channels:
1724
+ Any
1725
+
1726
+ Note:
1727
+ - This function generalizes the Fancy PCA augmentation to work with any number of channels.
1728
+ - It preserves the original range of the image ([0, 255] for uint8, [0, 1] for float32).
1729
+ - For single-channel images, the augmentation is applied as a simple scaling of pixel intensity variation.
1730
+ - For multi-channel images, PCA is performed on the entire image, treating each pixel
1731
+ as a point in N-dimensional space (where N is the number of channels).
1732
+ - The augmentation preserves the correlation between channels while adding controlled noise.
1733
+ - Computation time may increase significantly for images with a large number of channels.
1734
+
1735
+ References:
1736
+ ImageNet classification with deep convolutional neural networks: Krizhevsky, A., Sutskever, I.,
1737
+ & Hinton, G. E. (2012): In Advances in neural information processing systems (pp. 1097-1105).
1738
+
1739
+ """
1740
+ orig_shape = img.shape
1741
+ num_channels = get_num_channels(img)
1742
+
1743
+ # Reshape image to 2D array of pixels
1744
+ img_reshaped = img.reshape(-1, num_channels)
1745
+
1746
+ # Center the pixel values
1747
+ img_mean = np.mean(img_reshaped, axis=0)
1748
+ img_centered = img_reshaped - img_mean
1749
+
1750
+ if num_channels == 1:
1751
+ # For grayscale images, apply a simple scaling
1752
+ std_dev = np.std(img_centered)
1753
+ noise = alpha_vector[0] * std_dev * img_centered
1754
+ else:
1755
+ # Compute covariance matrix
1756
+ img_cov = np.cov(img_centered, rowvar=False)
1757
+
1758
+ # Compute eigenvectors & eigenvalues of the covariance matrix
1759
+ eig_vals, eig_vecs = np.linalg.eigh(img_cov)
1760
+
1761
+ # Sort eigenvectors by eigenvalues in descending order
1762
+ sort_perm = eig_vals[::-1].argsort()
1763
+ eig_vals = eig_vals[sort_perm]
1764
+ eig_vecs = eig_vecs[:, sort_perm]
1765
+
1766
+ # Create noise vector
1767
+ noise = np.dot(
1768
+ np.dot(eig_vecs, np.diag(alpha_vector * eig_vals)),
1769
+ img_centered.T,
1770
+ ).T
1771
+
1772
+ # Add noise to the image
1773
+ img_pca = img_reshaped + noise
1774
+
1775
+ # Reshape back to original shape
1776
+ img_pca = img_pca.reshape(orig_shape)
1777
+
1778
+ # Clip values to [0, 1] range
1779
+ return np.clip(img_pca, 0, 1, out=img_pca)
1780
+
1781
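+ # Sketch of typical usage (alpha values illustrative, drawn per the original
+ # Fancy PCA recipe): each entry of `alpha_vector` scales one principal
+ # component of the channel covariance.
+ #     rng = np.random.default_rng(0)
+ #     augmented = fancy_pca(img, rng.normal(0, 0.1, size=3))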
+
1782
+ @preserve_channel_dim
1783
+ def adjust_brightness_torchvision(img: np.ndarray, factor: float) -> np.ndarray:
1784
+ """Adjust the brightness of an image.
1785
+
1786
+ This function adjusts brightness by multiplying every pixel value by `factor`:
1787
+ 0 produces a black image, 1 returns the input unchanged, and larger factors brighten it.
1788
+
1789
+ Args:
1790
+ img (np.ndarray): Input image as a numpy array.
1791
+ factor (float): The multiplicative brightness factor.
1792
+
1793
+ Returns:
1794
+ np.ndarray: The adjusted image.
1795
+
1796
+ """
1797
+ if factor == 0:
1798
+ return np.zeros_like(img)
1799
+ if factor == 1:
1800
+ return img
1801
+
1802
+ return multiply(img, factor, inplace=False)
1803
+
1804
+
1805
+ @preserve_channel_dim
1806
+ def adjust_contrast_torchvision(img: np.ndarray, factor: float) -> np.ndarray:
1807
+ """Adjust the contrast of an image.
1808
+
1809
+ This function adjusts contrast by blending the image with its mean intensity:
1810
+ result = img * factor + mean * (1 - factor), so factor=0 yields a constant image at the mean and factor=1 returns the input.
1811
+
1812
+ Args:
1813
+ img (np.ndarray): Input image as a numpy array.
1814
+ factor (float): The factor to adjust the contrast by.
1815
+
1816
+ Returns:
1817
+ np.ndarray: The adjusted image.
1818
+
1819
+ """
1820
+ if factor == 1:
1821
+ return img
1822
+
1823
+ mean = img.mean() if is_grayscale_image(img) else cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).mean()
1824
+
1825
+ if factor == 0:
1826
+ if img.dtype != np.float32:
1827
+ mean = int(mean + 0.5)
1828
+ return np.full_like(img, mean, dtype=img.dtype)
1829
+
1830
+ return multiply_add(img, factor, mean * (1 - factor), inplace=False)
1831
+
1832
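+ # Worked example for the blend above: with factor=0.5 and mean=100, a pixel of
+ # 180 becomes 180 * 0.5 + 100 * (1 - 0.5) = 140, pulling values toward the
+ # mean and thereby lowering contrast.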
+
1833
+ @clipped
1834
+ @preserve_channel_dim
1835
+ def adjust_saturation_torchvision(
1836
+ img: np.ndarray,
1837
+ factor: float,
1838
+ gamma: float = 0,
1839
+ ) -> np.ndarray:
1840
+ """Adjust the saturation of an image.
1841
+
1842
+ This function adjusts saturation by blending the image with its grayscale version:
1843
+ result = img * factor + gray * (1 - factor), so factor=0 yields grayscale and factor=1 returns the input.
1844
+
1845
+ Args:
1846
+ img (np.ndarray): Input image as a numpy array.
1847
+ factor (float): The factor to adjust the saturation by.
1848
+ gamma (float): The gamma value to use for the adjustment.
1849
+
1850
+ Returns:
1851
+ np.ndarray: The adjusted image.
1852
+
1853
+ """
1854
+ if factor == 1 or is_grayscale_image(img):
1855
+ return img
1856
+
1857
+ gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
1858
+ gray = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
1859
+
1860
+ return gray if factor == 0 else cv2.addWeighted(img, factor, gray, 1 - factor, gamma=gamma)
1861
+
1862
+
1863
+ def _adjust_hue_torchvision_uint8(img: np.ndarray, factor: float) -> np.ndarray:
1864
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
1865
+
1866
+ lut = np.arange(0, 256, dtype=np.int16)
1867
+ lut = np.mod(lut + 180 * factor, 180).astype(np.uint8)
1868
+ img[..., 0] = sz_lut(img[..., 0], lut, inplace=False)
1869
+
1870
+ return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
1871
+
1872
+
1873
+ def adjust_hue_torchvision(img: np.ndarray, factor: float) -> np.ndarray:
1874
+ """Adjust the hue of an image.
1875
+
1876
+ This function shifts the hue channel by `factor` of a full turn (factor * 360 degrees), wrapping around.
1877
+
1878
+ Args:
1879
+ img (np.ndarray): Input image.
1880
+ factor (float): The factor to adjust the hue by.
1881
+
1882
+ Returns:
1883
+ np.ndarray: The adjusted image.
1884
+
1885
+ """
1886
+ if is_grayscale_image(img) or factor == 0:
1887
+ return img
1888
+
1889
+ if img.dtype == np.uint8:
1890
+ return _adjust_hue_torchvision_uint8(img, factor)
1891
+
1892
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
1893
+ img[..., 0] = np.mod(img[..., 0] + factor * 360, 360)
1894
+ return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
1895
+
1896
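+ # Sketch of the uint8 LUT path above: OpenCV stores uint8 hue in [0, 180), so
+ # factor=0.5 (half a turn) shifts hue by 90; e.g. hue 10 -> (10 + 90) % 180 = 100.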
+
1897
+ @uint8_io
1898
+ @preserve_channel_dim
1899
+ def superpixels(
1900
+ image: np.ndarray,
1901
+ n_segments: int,
1902
+ replace_samples: Sequence[bool],
1903
+ max_size: int | None,
1904
+ interpolation: int,
1905
+ ) -> np.ndarray:
1906
+ """Apply superpixels to an image.
1907
+
1908
+ This function applies superpixels to an image using the SLIC algorithm.
1909
+ Selected segments are replaced by their per-channel mean intensity, producing a flat, mosaic-like effect.
1910
+
1911
+ Args:
1912
+ image (np.ndarray): Input image as a numpy array.
1913
+ n_segments (int): The number of segments to use for the superpixels.
1914
+ replace_samples (Sequence[bool]): Flags indicating which segments to replace with their mean (cycled if shorter than the number of segments).
1915
+ max_size (int | None): Maximum size of the longer image side; larger images are downscaled before segmentation and restored afterwards.
1916
+ interpolation (int): The interpolation method to use.
1917
+
1918
+ Returns:
1919
+ np.ndarray: The image with superpixels applied.
1920
+
1921
+ """
1922
+ if not np.any(replace_samples):
1923
+ return image
1924
+
1925
+ orig_shape = image.shape
1926
+ if max_size is not None:
1927
+ size = max(image.shape[:2])
1928
+ if size > max_size:
1929
+ scale = max_size / size
1930
+ height, width = image.shape[:2]
1931
+ new_height, new_width = int(height * scale), int(width * scale)
1932
+ image = fgeometric.resize(image, (new_height, new_width), interpolation)
1933
+
1934
+ segments = slic(
1935
+ image,
1936
+ n_segments=n_segments,
1937
+ compactness=10,
1938
+ )
1939
+
1940
+ min_value = 0
1941
+ max_value = MAX_VALUES_BY_DTYPE[image.dtype]
1942
+ image = np.copy(image)
1943
+
1944
+ if image.ndim == MONO_CHANNEL_DIMENSIONS:
1945
+ image = np.expand_dims(image, axis=-1)
1946
+
1947
+ num_channels = get_num_channels(image)
1948
+
1949
+ for c in range(num_channels):
1950
+ image_sp_c = image[..., c]
1951
+ # Get unique segment labels (skip 0 if it exists as it's typically background)
1952
+ unique_labels = np.unique(segments)
1953
+ if unique_labels[0] == 0:
1954
+ unique_labels = unique_labels[1:]
1955
+
1956
+ # Calculate mean intensity for each segment
1957
+ for idx, label in enumerate(unique_labels):
1958
+ # Use mod here because slic can sometimes create more superpixels than requested;
1959
+ # replace_samples then does not have enough values, so we just start over with the first one again.
1960
+ if replace_samples[idx % len(replace_samples)]:
1961
+ mask = segments == label
1962
+ mean_intensity = np.mean(image_sp_c[mask])
1963
+
1964
+ if image_sp_c.dtype.kind in ["i", "u", "b"]:
1965
+ # After rounding the value can end up slightly outside of the value_range. Hence, we need to clip.
1966
+ # We do clip via min(max(...)) instead of np.clip because
1967
+ # the latter one does not seem to keep dtypes for dtypes with large itemsizes (e.g. uint64).
1968
+ value: int | float
1969
+ value = int(np.round(mean_intensity))
1970
+ value = min(max(value, min_value), max_value)
1971
+ else:
1972
+ value = mean_intensity
1973
+
1974
+ image_sp_c[mask] = value
1975
+
1976
+ return fgeometric.resize(image, orig_shape[:2], interpolation) if orig_shape != image.shape else image
1977
+
1978
+
1979
+ @float32_io
1980
+ @clipped
1981
+ @preserve_channel_dim
1982
+ def unsharp_mask(
1983
+ image: np.ndarray,
1984
+ ksize: int,
1985
+ sigma: float,
1986
+ alpha: float,
1987
+ threshold: int,
1988
+ ) -> np.ndarray:
1989
+ """Apply an unsharp mask to an image.
1990
+ This function applies an unsharp mask to an image using the Gaussian blur function.
1991
+ The unsharp mask is applied by subtracting the blurred image from the original image and
1992
+ then adding the result to the original image.
1993
+
1994
+ Args:
1995
+ image (np.ndarray): Input image as a numpy array.
1996
+ ksize (int): The kernel size to use for the Gaussian blur.
1997
+ sigma (float): The sigma value to use for the Gaussian blur.
1998
+ alpha (float): The alpha value to use for the unsharp mask.
1999
+ threshold (int): The threshold value to use for the unsharp mask.
2000
+
2001
+ Returns:
2002
+ np.ndarray: The unsharp mask applied to the image.
2003
+
2004
+ """
2005
+ blur_fn = maybe_process_in_chunks(
2006
+ cv2.GaussianBlur,
2007
+ ksize=(ksize, ksize),
2008
+ sigmaX=sigma,
2009
+ )
2010
+
2011
+ if image.ndim == NUM_MULTI_CHANNEL_DIMENSIONS and get_num_channels(image) == 1:
2012
+ image = np.squeeze(image, axis=-1)
2013
+
2014
+ blur = blur_fn(image)
2015
+ residual = image - blur
2016
+
2017
+ # Do not sharpen noise
2018
+ mask = np.abs(residual) * 255 > threshold
2019
+ mask = mask.astype(np.float32)
2020
+
2021
+ sharp = image + alpha * residual
2022
+ # Avoid color noise artefacts.
2023
+ sharp = np.clip(sharp, 0, 1, out=sharp)
2024
+
2025
+ soft_mask = blur_fn(mask)
2026
+
2027
+ return add_array(
2028
+ multiply(sharp, soft_mask),
2029
+ multiply(image, 1 - soft_mask),
2030
+ inplace=True,
2031
+ )
2032
+
2033
+
2034
+ @preserve_channel_dim
2035
+ def pixel_dropout(
2036
+ image: np.ndarray,
2037
+ drop_mask: np.ndarray,
2038
+ drop_values: np.ndarray,
2039
+ ) -> np.ndarray:
2040
+ """Apply pixel dropout to the image.
2041
+
2042
+ Args:
2043
+ image (np.ndarray): Input image
2044
+ drop_mask (np.ndarray): Boolean mask indicating which pixels to drop
2045
+ drop_values (np.ndarray): Values to replace dropped pixels with
2046
+
2047
+ Returns:
2048
+ np.ndarray: Image with dropped pixels
2049
+
2050
+ """
2051
+ return np.where(drop_mask, drop_values, image)
2052
+
2053
+
2054
+ @float32_io
2055
+ @clipped
2056
+ @preserve_channel_dim
2057
+ def spatter_rain(img: np.ndarray, rain: np.ndarray) -> np.ndarray:
2058
+ """Apply spatter rain to an image.
2059
+
2060
+ This function overlays a rain layer on the image by adding the rain array to it.
2061
+
2062
+ Args:
2063
+ img (np.ndarray): Input image as a numpy array.
2064
+ rain (np.ndarray): Rain image as a numpy array.
2065
+
2066
+ Returns:
2067
+ np.ndarray: The image with rain spatter applied.
2068
+
2069
+ """
2070
+ return add(img, rain, inplace=False)
2071
+
2072
+
2073
+ @float32_io
2074
+ @clipped
2075
+ @preserve_channel_dim
2076
+ def spatter_mud(img: np.ndarray, non_mud: np.ndarray, mud: np.ndarray) -> np.ndarray:
2077
+ """Apply spatter mud to an image.
2078
+
2079
+ This function overlays mud on the image: the image is first masked by `non_mud` and the mud layer is then added.
2080
+
2081
+ Args:
2082
+ img (np.ndarray): Input image as a numpy array.
2083
+ non_mud (np.ndarray): Non-mud image as a numpy array.
2084
+ mud (np.ndarray): Mud image as a numpy array.
2085
+
2086
+ Returns:
2087
+ np.ndarray: The image with mud spatter applied.
2088
+
2089
+ """
2090
+ return add(img * non_mud, mud, inplace=False)
2091
+
2092
+
2093
+ @uint8_io
2094
+ @clipped
2095
+ def chromatic_aberration(
2096
+ img: np.ndarray,
2097
+ primary_distortion_red: float,
2098
+ secondary_distortion_red: float,
2099
+ primary_distortion_blue: float,
2100
+ secondary_distortion_blue: float,
2101
+ interpolation: int,
2102
+ ) -> np.ndarray:
2103
+ """Apply chromatic aberration to an image.
2104
+
2105
+ This function applies chromatic aberration to an image by distorting the red and blue channels.
2106
+
2107
+ Args:
2108
+ img (np.ndarray): Input image as a numpy array.
2109
+ primary_distortion_red (float): The primary distortion of the red channel.
2110
+ secondary_distortion_red (float): The secondary distortion of the red channel.
2111
+ primary_distortion_blue (float): The primary distortion of the blue channel.
2112
+ secondary_distortion_blue (float): The secondary distortion of the blue channel.
2113
+ interpolation (int): The interpolation method to use.
2114
+
2115
+ Returns:
2116
+ np.ndarray: The image with chromatic aberration applied.
2117
+
2118
+ """
2119
+ height, width = img.shape[:2]
2120
+
2121
+ # Build camera matrix
2122
+ camera_mat = np.eye(3, dtype=np.float32)
2123
+ camera_mat[0, 0] = width
2124
+ camera_mat[1, 1] = height
2125
+ camera_mat[0, 2] = width / 2.0
2126
+ camera_mat[1, 2] = height / 2.0
2127
+
2128
+ # Build distortion coefficients
2129
+ distortion_coeffs_red = np.array(
2130
+ [primary_distortion_red, secondary_distortion_red, 0, 0],
2131
+ dtype=np.float32,
2132
+ )
2133
+ distortion_coeffs_blue = np.array(
2134
+ [primary_distortion_blue, secondary_distortion_blue, 0, 0],
2135
+ dtype=np.float32,
2136
+ )
2137
+
2138
+ # Distort the red and blue channels
2139
+ red_distorted = _distort_channel(
2140
+ img[..., 0],
2141
+ camera_mat,
2142
+ distortion_coeffs_red,
2143
+ height,
2144
+ width,
2145
+ interpolation,
2146
+ )
2147
+ blue_distorted = _distort_channel(
2148
+ img[..., 2],
2149
+ camera_mat,
2150
+ distortion_coeffs_blue,
2151
+ height,
2152
+ width,
2153
+ interpolation,
2154
+ )
2155
+
2156
+ return np.dstack([red_distorted, img[..., 1], blue_distorted])
2157
+
2158
+
2159
+ def _distort_channel(
2160
+ channel: np.ndarray,
2161
+ camera_mat: np.ndarray,
2162
+ distortion_coeffs: np.ndarray,
2163
+ height: int,
2164
+ width: int,
2165
+ interpolation: int,
2166
+ ) -> np.ndarray:
2167
+ map_x, map_y = cv2.initUndistortRectifyMap(
2168
+ cameraMatrix=camera_mat,
2169
+ distCoeffs=distortion_coeffs,
2170
+ R=None,
2171
+ newCameraMatrix=camera_mat,
2172
+ size=(width, height),
2173
+ m1type=cv2.CV_32FC1,
2174
+ )
2175
+ return cv2.remap(
2176
+ channel,
2177
+ map_x,
2178
+ map_y,
2179
+ interpolation=interpolation,
2180
+ borderMode=cv2.BORDER_REPLICATE,
2181
+ )
2182
+
2183
+
2184
+ PLANCKIAN_COEFFS: dict[str, dict[int, list[float]]] = {
2185
+ "blackbody": {
2186
+ 3_000: [0.6743, 0.4029, 0.0013],
2187
+ 3_500: [0.6281, 0.4241, 0.1665],
2188
+ 4_000: [0.5919, 0.4372, 0.2513],
2189
+ 4_500: [0.5623, 0.4457, 0.3154],
2190
+ 5_000: [0.5376, 0.4515, 0.3672],
2191
+ 5_500: [0.5163, 0.4555, 0.4103],
2192
+ 6_000: [0.4979, 0.4584, 0.4468],
2193
+ 6_500: [0.4816, 0.4604, 0.4782],
2194
+ 7_000: [0.4672, 0.4619, 0.5053],
2195
+ 7_500: [0.4542, 0.4630, 0.5289],
2196
+ 8_000: [0.4426, 0.4638, 0.5497],
2197
+ 8_500: [0.4320, 0.4644, 0.5681],
2198
+ 9_000: [0.4223, 0.4648, 0.5844],
2199
+ 9_500: [0.4135, 0.4651, 0.5990],
2200
+ 10_000: [0.4054, 0.4653, 0.6121],
2201
+ 10_500: [0.3980, 0.4654, 0.6239],
2202
+ 11_000: [0.3911, 0.4655, 0.6346],
2203
+ 11_500: [0.3847, 0.4656, 0.6444],
2204
+ 12_000: [0.3787, 0.4656, 0.6532],
2205
+ 12_500: [0.3732, 0.4656, 0.6613],
2206
+ 13_000: [0.3680, 0.4655, 0.6688],
2207
+ 13_500: [0.3632, 0.4655, 0.6756],
2208
+ 14_000: [0.3586, 0.4655, 0.6820],
2209
+ 14_500: [0.3544, 0.4654, 0.6878],
2210
+ 15_000: [0.3503, 0.4653, 0.6933],
2211
+ },
2212
+ "cied": {
2213
+ 4_000: [0.5829, 0.4421, 0.2288],
2214
+ 4_500: [0.5510, 0.4514, 0.2948],
2215
+ 5_000: [0.5246, 0.4576, 0.3488],
2216
+ 5_500: [0.5021, 0.4618, 0.3941],
2217
+ 6_000: [0.4826, 0.4646, 0.4325],
2218
+ 6_500: [0.4654, 0.4667, 0.4654],
2219
+ 7_000: [0.4502, 0.4681, 0.4938],
2220
+ 7_500: [0.4364, 0.4692, 0.5186],
2221
+ 8_000: [0.4240, 0.4700, 0.5403],
2222
+ 8_500: [0.4127, 0.4705, 0.5594],
2223
+ 9_000: [0.4023, 0.4709, 0.5763],
2224
+ 9_500: [0.3928, 0.4713, 0.5914],
2225
+ 10_000: [0.3839, 0.4715, 0.6049],
2226
+ 10_500: [0.3757, 0.4716, 0.6171],
2227
+ 11_000: [0.3681, 0.4717, 0.6281],
2228
+ 11_500: [0.3609, 0.4718, 0.6380],
2229
+ 12_000: [0.3543, 0.4719, 0.6472],
2230
+ 12_500: [0.3480, 0.4719, 0.6555],
2231
+ 13_000: [0.3421, 0.4719, 0.6631],
2232
+ 13_500: [0.3365, 0.4719, 0.6702],
2233
+ 14_000: [0.3313, 0.4719, 0.6766],
2234
+ 14_500: [0.3263, 0.4719, 0.6826],
2235
+ 15_000: [0.3217, 0.4719, 0.6882],
2236
+ },
2237
+ }
2238
+
2239
+
2240
+ @clipped
2241
+ def planckian_jitter(
2242
+ img: np.ndarray,
2243
+ temperature: int,
2244
+ mode: Literal["blackbody", "cied"],
2245
+ ) -> np.ndarray:
2246
+ """Apply Planckian jitter to an image.
2247
+
2248
+ This function applies Planckian jitter to an image by linearly interpolating
2249
+ between the two closest temperatures in the PLANCKIAN_COEFFS dictionary.
2250
+
2251
+ Args:
2252
+ img (np.ndarray): Input image as a numpy array.
2253
+ temperature (int): The temperature to apply.
2254
+ mode (Literal["blackbody", "cied"]): The mode to use.
2255
+
2256
+ Returns:
2257
+ np.ndarray: The image with Planckian jitter applied.
2258
+
2259
+ """
2260
+ img = img.copy()
2261
+ # Get the min and max temperatures for the given mode
2262
+ min_temp = min(PLANCKIAN_COEFFS[mode].keys())
2263
+ max_temp = max(PLANCKIAN_COEFFS[mode].keys())
2264
+
2265
+ # Clamp the temperature to the available range
2266
+ temperature = np.clip(temperature, min_temp, max_temp)
2267
+
2268
+ # Linearly interpolate between 2 closest temperatures
2269
+ step = 500
2270
+ t_left = max(
2271
+ (temperature // step) * step,
2272
+ min_temp,
2273
+ ) # Ensure t_left doesn't go below min_temp
2274
+ t_right = min(
2275
+ (temperature // step + 1) * step,
2276
+ max_temp,
2277
+ ) # Ensure t_right doesn't exceed max_temp
2278
+
2279
+ # Handle the case where temperature is at or near min_temp or max_temp
2280
+ if t_left == t_right:
2281
+ coeffs = np.array(PLANCKIAN_COEFFS[mode][t_left])
2282
+ else:
2283
+ w_right = (temperature - t_left) / (t_right - t_left)
2284
+ w_left = 1 - w_right
2285
+ coeffs = w_left * np.array(PLANCKIAN_COEFFS[mode][t_left]) + w_right * np.array(
2286
+ PLANCKIAN_COEFFS[mode][t_right],
2287
+ )
2288
+
2289
+ img[:, :, 0] = multiply_by_constant(
2290
+ img[:, :, 0],
2291
+ coeffs[0] / coeffs[1],
2292
+ inplace=True,
2293
+ )
2294
+ img[:, :, 2] = multiply_by_constant(
2295
+ img[:, :, 2],
2296
+ coeffs[2] / coeffs[1],
2297
+ inplace=True,
2298
+ )
2299
+
2300
+ return img
2301
+
2302
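+ # Minimal sketch of the interpolation weights used above (hypothetical helper,
+ # not part of the module): for temperature=5200, t_left=5000 and t_right=5500,
+ # so w_right=0.4 and w_left=0.6.
+ def _planckian_weights_example(temperature: int, step: int = 500) -> tuple[float, float]:
+     """Return (w_left, w_right) for a temperature between two table entries."""
+     t_left = (temperature // step) * step
+     w_right = (temperature - t_left) / step
+     return 1 - w_right, w_right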
+
2303
+ @clipped
2304
+ def add_noise(img: np.ndarray, noise: np.ndarray) -> np.ndarray:
2305
+ """Add noise to an image.
2306
+
2307
+ This function tiles the noise array to match the image shape and adds it to the image.
2308
+
2309
+ Args:
2310
+ img (np.ndarray): Input image as a numpy array.
2311
+ noise (np.ndarray): Noise as a numpy array.
2312
+
2313
+ Returns:
2314
+ np.ndarray: The image with noise added.
2315
+
2316
+ """
2317
+ n_tiles = np.prod(img.shape) // np.prod(noise.shape)
2318
+ noise = np.tile(noise, (n_tiles,) + (1,) * noise.ndim).reshape(img.shape)
2319
+
2320
+ return add_array(img, noise, inplace=False)
2321
+
2322
+
2323
+ def slic(
2324
+ image: np.ndarray,
2325
+ n_segments: int,
2326
+ compactness: float = 10.0,
2327
+ max_iterations: int = 10,
2328
+ ) -> np.ndarray:
2329
+ """Simple Linear Iterative Clustering (SLIC) superpixel segmentation using OpenCV and NumPy.
2330
+
2331
+ Args:
2332
+ image (np.ndarray): Input image (2D or 3D numpy array).
2333
+ n_segments (int): Approximate number of superpixels to generate.
2334
+ compactness (float): Balance between color proximity and space proximity.
2335
+ max_iterations (int): Maximum number of iterations for k-means.
2336
+
2337
+ Returns:
2338
+ np.ndarray: Segmentation mask where each superpixel has a unique label.
2339
+
2340
+ """
2341
+ if image.ndim == MONO_CHANNEL_DIMENSIONS:
2342
+ image = image[..., np.newaxis]
2343
+
2344
+ height, width = image.shape[:2]
2345
+ num_pixels = height * width
2346
+
2347
+ # Normalize image to [0, 1] range
2348
+ image_normalized = image.astype(np.float32) / (np.max(image) + 1e-6)
2349
+
2350
+ # Initialize cluster centers
2351
+ grid_step = int((num_pixels / n_segments) ** 0.5)
2352
+ x_range = np.arange(grid_step // 2, width, grid_step)
2353
+ y_range = np.arange(grid_step // 2, height, grid_step)
2354
+ centers = np.array(
2355
+ [(x, y) for y in y_range for x in x_range if x < width and y < height],
2356
+ )
2357
+
2358
+ # Initialize labels and distances
2359
+ labels = -1 * np.ones((height, width), dtype=np.int32)
2360
+ distances = np.full((height, width), np.inf)
2361
+
2362
+ for _ in range(max_iterations):
2363
+ for i, center in enumerate(centers):
2364
+ y, x = int(center[1]), int(center[0])
2365
+
2366
+ # Define the neighborhood
2367
+ y_low, y_high = max(0, y - grid_step), min(height, y + grid_step + 1)
2368
+ x_low, x_high = max(0, x - grid_step), min(width, x + grid_step + 1)
2369
+
2370
+ # Compute distances
2371
+ crop = image_normalized[y_low:y_high, x_low:x_high]
2372
+ color_diff = crop - image_normalized[y, x]
2373
+ color_distance = np.sum(color_diff**2, axis=-1)
2374
+
2375
+ yy, xx = np.ogrid[y_low:y_high, x_low:x_high]
2376
+ spatial_distance = ((yy - y) ** 2 + (xx - x) ** 2) / (grid_step**2)
2377
+
2378
+ distance = color_distance + compactness * spatial_distance
2379
+
2380
+ mask = distance < distances[y_low:y_high, x_low:x_high]
2381
+ distances[y_low:y_high, x_low:x_high][mask] = distance[mask]
2382
+ labels[y_low:y_high, x_low:x_high][mask] = i
2383
+
2384
+ # Update centers
2385
+ for i in range(len(centers)):
2386
+ mask = labels == i
2387
+ if np.any(mask):
2388
+ centers[i] = np.mean(np.argwhere(mask), axis=0)[::-1]
2389
+
2390
+ return labels
2391
+
2392
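+ # Example usage (sketch): segment an image into roughly 100 superpixels; the
+ # number of labels actually produced can differ from n_segments.
+ #     labels = slic(img, n_segments=100)
+ #     num_clusters = int(labels.max()) + 1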
+
2393
+ @preserve_channel_dim
2394
+ @float32_io
2395
+ def shot_noise(
2396
+ img: np.ndarray,
2397
+ scale: float,
2398
+ random_generator: np.random.Generator,
2399
+ ) -> np.ndarray:
2400
+ """Apply shot noise to the image.
2401
+
2402
+ Args:
2403
+ img (np.ndarray): Input image
2404
+ scale (float): Scale factor for the noise
2405
+ random_generator (np.random.Generator): Random number generator
2406
+
2407
+ Returns:
2408
+ np.ndarray: Image with shot noise
2409
+
2410
+ """
2411
+ # Apply inverse gamma correction to work in linear space
2412
+ img_linear = cv2.pow(img, 2.2)
2413
+
2414
+ # Scale image values and add small constant to avoid zero values
2415
+ scaled_img = (img_linear + scale * 1e-6) / scale
2416
+
2417
+ # Generate Poisson noise
2418
+ noisy_img = multiply_by_constant(
2419
+ random_generator.poisson(scaled_img).astype(np.float32),
2420
+ scale,
2421
+ inplace=True,
2422
+ )
2423
+
2424
+ # Scale back and apply gamma correction
2425
+ return power(np.clip(noisy_img, 0, 1, out=noisy_img), 1 / 2.2)
2426
+
2427
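+ # Why the scaling above works (sketch): Poisson noise has variance equal to
+ # its mean, so sampling poisson(x / scale) * scale keeps the mean near x while
+ # giving variance proportional to x * scale, i.e. larger `scale` means more noise.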
+
2428
+ def get_safe_brightness_contrast_params(
2429
+ alpha: float,
2430
+ beta: float,
2431
+ max_value: float,
2432
+ ) -> tuple[float, float]:
2433
+ """Get safe brightness and contrast parameters.
2434
+
2435
+ Args:
2436
+ alpha (float): Contrast factor
2437
+ beta (float): Brightness factor
2438
+ max_value (float): Maximum pixel value
2439
+
2440
+ Returns:
2441
+ tuple[float, float]: Safe alpha and beta values
2442
+
2443
+ """
2444
+ if alpha > 0:
2445
+ # For x = max_value: alpha * max_value + beta <= max_value
2446
+ # For x = 0: beta >= 0
2447
+ safe_beta = np.clip(beta, 0, max_value)
2448
+ # From alpha * max_value + safe_beta <= max_value
2449
+ safe_alpha = min(alpha, (max_value - safe_beta) / max_value)
2450
+ else:
2451
+ # For x = 0: beta <= max_value
2452
+ # For x = max_value: alpha * max_value + beta >= 0
2453
+ safe_beta = min(beta, max_value)
2454
+ # From alpha * max_value + safe_beta >= 0
2455
+ safe_alpha = max(alpha, -safe_beta / max_value)
2456
+
2457
+ return safe_alpha, safe_beta
2458
+
2459
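+ # Worked example (hypothetical values): for alpha=1.5, beta=50, max_value=255,
+ # safe_beta=50 and safe_alpha=min(1.5, (255 - 50) / 255) ~= 0.804, so
+ # alpha * x + beta stays inside [0, 255] for all x in [0, 255].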
+
2460
+ def generate_noise(
2461
+ noise_type: Literal["uniform", "gaussian", "laplace", "beta"],
2462
+ spatial_mode: Literal["constant", "per_pixel", "shared"],
2463
+ shape: tuple[int, ...],
2464
+ params: dict[str, Any] | None,
2465
+ max_value: float,
2466
+ approximation: float,
2467
+ random_generator: np.random.Generator,
2468
+ ) -> np.ndarray:
2469
+ """Generate noise with optional approximation for speed.
2470
+
2471
+ When `approximation` < 1.0, the noise is generated at a reduced resolution and resized back to full size, trading fidelity for speed.
2472
+
2473
+ Args:
2474
+ noise_type (Literal["uniform", "gaussian", "laplace", "beta"]): The type of noise to generate.
2475
+ spatial_mode (Literal["constant", "per_pixel", "shared"]): The spatial mode to use.
2476
+ shape (tuple[int, ...]): The shape of the noise to generate.
2477
+ params (dict[str, Any] | None): The parameters of the noise to generate.
2478
+ max_value (float): The maximum value of the noise to generate.
2479
+ approximation (float): Fraction of the spatial resolution at which to generate the noise before resizing back; 1.0 disables the approximation.
2480
+ random_generator (np.random.Generator): The random number generator to use.
2481
+
2482
+ Returns:
2483
+ np.ndarray: The noise generated.
2484
+
2485
+ """
2486
+ if params is None:
2487
+ return np.zeros(shape, dtype=np.float32)
2488
+
2489
+ cv2_seed = random_generator.integers(0, 2**16)
2490
+ cv2.setRNGSeed(cv2_seed)
2491
+
2492
+ if spatial_mode == "constant":
2493
+ return generate_constant_noise(
2494
+ noise_type,
2495
+ shape,
2496
+ params,
2497
+ max_value,
2498
+ random_generator,
2499
+ )
2500
+
2501
+ if approximation == 1.0:
2502
+ if spatial_mode == "shared":
2503
+ return generate_shared_noise(
2504
+ noise_type,
2505
+ shape,
2506
+ params,
2507
+ max_value,
2508
+ random_generator,
2509
+ )
2510
+ return generate_per_pixel_noise(
2511
+ noise_type,
2512
+ shape,
2513
+ params,
2514
+ max_value,
2515
+ random_generator,
2516
+ )
2517
+
2518
+ # Calculate reduced size for noise generation
2519
+ height, width = shape[:2]
2520
+ reduced_height = max(1, int(height * approximation))
2521
+ reduced_width = max(1, int(width * approximation))
2522
+ reduced_shape = (reduced_height, reduced_width, *shape[2:])
2523
+
2524
+ # Generate noise at reduced resolution
2525
+ if spatial_mode == "shared":
2526
+ noise = generate_shared_noise(
2527
+ noise_type,
2528
+ reduced_shape,
2529
+ params,
2530
+ max_value,
2531
+ random_generator,
2532
+ )
2533
+ else: # per_pixel
2534
+ noise = generate_per_pixel_noise(
2535
+ noise_type,
2536
+ reduced_shape,
2537
+ params,
2538
+ max_value,
2539
+ random_generator,
2540
+ )
2541
+
2542
+ # Resize noise to original size using existing resize function
2543
+ return fgeometric.resize(noise, (height, width), interpolation=cv2.INTER_LINEAR)
2544
+
2545
+
2546
+ def generate_constant_noise(
2547
+ noise_type: Literal["uniform", "gaussian", "laplace", "beta"],
2548
+ shape: tuple[int, ...],
2549
+ params: dict[str, Any],
2550
+ max_value: float,
2551
+ random_generator: np.random.Generator,
2552
+ ) -> np.ndarray:
2553
+ """Generate constant noise.
2554
+
2555
+ This function samples one noise value per channel, shared by every pixel in the image.
2556
+
2557
+ Args:
2558
+ noise_type (Literal["uniform", "gaussian", "laplace", "beta"]): The type of noise to generate.
2559
+ shape (tuple[int, ...]): The shape of the noise to generate.
2560
+ params (dict[str, Any]): The parameters of the noise to generate.
2561
+ max_value (float): The maximum value of the noise to generate.
2562
+ random_generator (np.random.Generator): The random number generator to use.
2563
+
2564
+ Returns:
2565
+ np.ndarray: The constant noise generated.
2566
+
2567
+ """
2568
+ num_channels = shape[-1] if len(shape) > MONO_CHANNEL_DIMENSIONS else 1
2569
+ return sample_noise(
2570
+ noise_type,
2571
+ (num_channels,),
2572
+ params,
2573
+ max_value,
2574
+ random_generator,
2575
+ )
2576
+
2577
+
2578
+ def generate_per_pixel_noise(
2579
+ noise_type: Literal["uniform", "gaussian", "laplace", "beta"],
2580
+ shape: tuple[int, ...],
2581
+ params: dict[str, Any],
2582
+ max_value: float,
2583
+ random_generator: np.random.Generator,
2584
+ ) -> np.ndarray:
2585
+ """Generate per-pixel noise.
2586
+
2587
+ This function generates per-pixel noise by sampling from the noise distribution.
2588
+
2589
+ Args:
2590
+ noise_type (Literal["uniform", "gaussian", "laplace", "beta"]): The type of noise to generate.
2591
+ shape (tuple[int, ...]): The shape of the noise to generate.
2592
+ params (dict[str, Any]): The parameters of the noise to generate.
2593
+ max_value (float): The maximum value of the noise to generate.
2594
+ random_generator (np.random.Generator): The random number generator to use.
2595
+
2596
+ Returns:
2597
+ np.ndarray: The per-pixel noise generated.
2598
+
2599
+ """
2600
+ return sample_noise(noise_type, shape, params, max_value, random_generator)
2601
+
2602
+
2603
+ def sample_noise(
2604
+ noise_type: Literal["uniform", "gaussian", "laplace", "beta"],
2605
+ size: tuple[int, ...],
2606
+ params: dict[str, Any],
2607
+ max_value: float,
2608
+ random_generator: np.random.Generator,
2609
+ ) -> np.ndarray:
2610
+ """Sample from specific noise distribution.
2611
+
2612
+ This function samples from a specific noise distribution.
2613
+
2614
+ Args:
2615
+ noise_type (Literal["uniform", "gaussian", "laplace", "beta"]): The type of noise to generate.
2616
+ size (tuple[int, ...]): The size of the noise to generate.
2617
+ params (dict[str, Any]): The parameters of the noise to generate.
2618
+ max_value (float): The maximum value of the noise to generate.
2619
+ random_generator (np.random.Generator): The random number generator to use.
2620
+
2621
+ Returns:
2622
+ np.ndarray: The noise sampled.
2623
+
2624
+ """
2625
+ if noise_type == "uniform":
2626
+ return sample_uniform(size, params, random_generator) * max_value
2627
+ if noise_type == "gaussian":
2628
+ return sample_gaussian(size, params, random_generator) * max_value
2629
+ if noise_type == "laplace":
2630
+ return sample_laplace(size, params, random_generator) * max_value
2631
+ if noise_type == "beta":
2632
+ return sample_beta(size, params, random_generator) * max_value
2633
+
2634
+ raise ValueError(f"Unknown noise type: {noise_type}")
2635
+
2636
+
2637
+ def sample_uniform(
2638
+ size: tuple[int, ...],
2639
+ params: dict[str, Any],
2640
+ random_generator: np.random.Generator,
2641
+ ) -> np.ndarray | float:
2642
+ """Sample from uniform distribution.
2643
+
2644
+ Args:
2645
+ size (tuple[int, ...]): Size of the output array
2646
+ params (dict[str, Any]): Distribution parameters
2647
+ random_generator (np.random.Generator): Random number generator
2648
+
2649
+ Returns:
2650
+ np.ndarray | float: Sampled values
2651
+
2652
+ """
2653
+ if len(size) == 1: # constant mode
2654
+ ranges = params["ranges"]
2655
+ num_channels = size[0]
2656
+
2657
+ if len(ranges) == 1:
2658
+ ranges = ranges * num_channels
2659
+ elif len(ranges) < num_channels:
2660
+ raise ValueError(
2661
+ f"Not enough ranges provided. Expected {num_channels}, got {len(ranges)}",
2662
+ )
2663
+
2664
+ return np.array(
2665
+ [random_generator.uniform(low, high) for low, high in ranges[:num_channels]],
2666
+ )
2667
+
2668
+ # use first range for spatial noise
2669
+ low, high = params["ranges"][0]
2670
+ return random_generator.uniform(low, high, size=size)
2671
+
2672
+
2673
+ def sample_gaussian(
2674
+ size: tuple[int, ...],
2675
+ params: dict[str, Any],
2676
+ random_generator: np.random.Generator,
2677
+ ) -> np.ndarray:
2678
+ """Sample from Gaussian distribution.
2679
+
2680
+ This function samples from a Gaussian distribution.
2681
+
2682
+ Args:
2683
+ size (tuple[int, ...]): The size of the noise to generate.
2684
+ params (dict[str, Any]): The parameters of the noise to generate.
2685
+ random_generator (np.random.Generator): The random number generator to use.
2686
+
2687
+ Returns:
2688
+ np.ndarray: The Gaussian noise sampled.
2689
+
2690
+ """
2691
+ mean = (
2692
+ params["mean_range"][0]
2693
+ if params["mean_range"][0] == params["mean_range"][1]
2694
+ else random_generator.uniform(*params["mean_range"])
2695
+ )
2696
+ std = (
2697
+ params["std_range"][0]
2698
+ if params["std_range"][0] == params["std_range"][1]
2699
+ else random_generator.uniform(*params["std_range"])
2700
+ )
2701
+ num_channels = size[2] if len(size) > MONO_CHANNEL_DIMENSIONS else 1
2702
+ mean_vector = mean * np.ones(shape=(num_channels,), dtype=np.float32)
2703
+ std_dev_vector = std * np.ones(shape=(num_channels,), dtype=np.float32)
2704
+ gaussian_sampled_arr = np.zeros(shape=size)
2705
+
2706
+ cv2.randn(dst=gaussian_sampled_arr, mean=mean_vector, stddev=std_dev_vector)
2707
+ return gaussian_sampled_arr.astype(np.float32)
2708
+
2709
+
2710
+ def sample_laplace(
2711
+ size: tuple[int, ...],
2712
+ params: dict[str, Any],
2713
+ random_generator: np.random.Generator,
2714
+ ) -> np.ndarray:
2715
+ """Sample from Laplace distribution.
2716
+
2717
+ This function samples from a Laplace distribution.
2718
+
2719
+ Args:
2720
+ size (tuple[int, ...]): The size of the noise to generate.
2721
+ params (dict[str, Any]): The parameters of the noise to generate.
2722
+ random_generator (np.random.Generator): The random number generator to use.
2723
+
2724
+ Returns:
2725
+ np.ndarray: The Laplace noise sampled.
2726
+
2727
+ """
2728
+ loc = random_generator.uniform(*params["mean_range"])
2729
+ scale = random_generator.uniform(*params["scale_range"])
2730
+ return random_generator.laplace(loc=loc, scale=scale, size=size)
2731
+
2732
+
2733
+ def sample_beta(
2734
+ size: tuple[int, ...],
2735
+ params: dict[str, Any],
2736
+ random_generator: np.random.Generator,
2737
+ ) -> np.ndarray:
2738
+ """Sample from Beta distribution.
2739
+
2740
+ This function samples from a Beta distribution.
2741
+
2742
+ Args:
2743
+ size (tuple[int, ...]): The size of the noise to generate.
2744
+ params (dict[str, Any]): The parameters of the noise to generate.
2745
+ random_generator (np.random.Generator): The random number generator to use.
2746
+
2747
+ Returns:
2748
+ np.ndarray: The Beta noise sampled.
2749
+
2750
+ """
2751
+ alpha = random_generator.uniform(*params["alpha_range"])
2752
+ beta = random_generator.uniform(*params["beta_range"])
2753
+ scale = random_generator.uniform(*params["scale_range"])
2754
+
2755
+ # Sample from Beta on [0, 1] and map affinely to [-scale, scale]
2756
+ samples = random_generator.beta(alpha, beta, size=size)
2757
+ return (2 * samples - 1) * scale
2758
+
2759
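+ # The affine map above sends Beta samples from [0, 1] to [-scale, scale]:
+ # s=0 -> -scale, s=0.5 -> 0, s=1 -> +scale.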
+
2760
+ def generate_shared_noise(
2761
+ noise_type: Literal["uniform", "gaussian", "laplace", "beta"],
2762
+ shape: tuple[int, ...],
2763
+ params: dict[str, Any],
2764
+ max_value: float,
2765
+ random_generator: np.random.Generator,
2766
+ ) -> np.ndarray:
2767
+ """Generate shared noise.
2768
+
2769
+ Args:
2770
+ noise_type (Literal["uniform", "gaussian", "laplace", "beta"]): Type of noise to generate
2771
+ shape (tuple[int, ...]): Shape of the output array
2772
+ params (dict[str, Any]): Distribution parameters
2773
+ max_value (float): Maximum value for the noise
2774
+ random_generator (np.random.Generator): Random number generator
2775
+
2776
+ Returns:
2777
+ np.ndarray: Generated noise
2778
+
2779
+ """
2780
+ # Generate noise for (H, W)
2781
+ height, width = shape[:2]
2782
+ noise_map = sample_noise(
2783
+ noise_type,
2784
+ (height, width),
2785
+ params,
2786
+ max_value,
2787
+ random_generator,
2788
+ )
2789
+
2790
+ # If input is multichannel, broadcast noise to all channels
2791
+ if len(shape) > MONO_CHANNEL_DIMENSIONS:
2792
+ return np.broadcast_to(noise_map[..., None], shape)
2793
+ return noise_map
2794
+
2795
+
2796
+ @clipped
2797
+ @preserve_channel_dim
2798
+ def sharpen_gaussian(
2799
+ img: np.ndarray,
2800
+ alpha: float,
2801
+ kernel_size: int,
2802
+ sigma: float,
2803
+ ) -> np.ndarray:
2804
+ """Sharpen image using Gaussian blur.
2805
+
2806
+ This function sharpens an image using a Gaussian blur.
2807
+
2808
+ Args:
2809
+ img (np.ndarray): The image to sharpen.
2810
+ alpha (float): The alpha value to use for the sharpening.
2811
+ kernel_size (int): The kernel size to use for the Gaussian blur.
2812
+ sigma (float): The sigma value to use for the Gaussian blur.
2813
+
2814
+ Returns:
2815
+ np.ndarray: The sharpened image.
2816
+
2817
+ """
2818
+ blurred = cv2.GaussianBlur(
2819
+ img,
2820
+ ksize=(kernel_size, kernel_size),
2821
+ sigmaX=sigma,
2822
+ sigmaY=sigma,
2823
+ )
2824
+ # Unsharp mask formula: original + alpha * (original - blurred)
2825
+ # This is equivalent to: original * (1 + alpha) - alpha * blurred
2826
+ return img + alpha * (img - blurred)
2827
+
2828
+
2829
+ def apply_salt_and_pepper(
2830
+ img: np.ndarray,
2831
+ salt_mask: np.ndarray,
2832
+ pepper_mask: np.ndarray,
2833
+ ) -> np.ndarray:
2834
+ """Apply salt and pepper noise to an image.
2835
+
2836
+ This function applies salt and pepper noise to an image using pre-computed masks.
2837
+
2838
+ Args:
2839
+ img (np.ndarray): The image to apply salt and pepper noise to.
2840
+ salt_mask (np.ndarray): The salt mask to use for the salt and pepper noise.
2841
+ pepper_mask (np.ndarray): The pepper mask to use for the salt and pepper noise.
2842
+
2843
+ Returns:
2844
+ np.ndarray: The image with salt and pepper noise applied.
2845
+
2846
+ """
2847
+ # Add channel dimension to masks if image is 3D
2848
+ if img.ndim == 3:
2849
+ salt_mask = salt_mask[..., None]
2850
+ pepper_mask = pepper_mask[..., None]
2851
+
2852
+ max_value = MAX_VALUES_BY_DTYPE[img.dtype]
2853
+ return np.where(salt_mask, max_value, np.where(pepper_mask, 0, img))
2854
+
2855
+
2856
+ # Pre-compute constant kernels
2857
+ DIAMOND_KERNEL = np.array(
2858
+ [
2859
+ [0.25, 0.0, 0.25],
2860
+ [0.0, 0.0, 0.0],
2861
+ [0.25, 0.0, 0.25],
2862
+ ],
2863
+ dtype=np.float32,
2864
+ )
2865
+
2866
+ SQUARE_KERNEL = np.array(
2867
+ [
2868
+ [0.0, 0.25, 0.0],
2869
+ [0.25, 0.0, 0.25],
2870
+ [0.0, 0.25, 0.0],
2871
+ ],
2872
+ dtype=np.float32,
2873
+ )
2874
+
2875
+ # Pre-compute initial grid
2876
+ INITIAL_GRID_SIZE = (3, 3)
2877
+
2878
+
2879
+ def generate_plasma_pattern(
2880
+ target_shape: tuple[int, int],
2881
+ roughness: float,
2882
+ random_generator: np.random.Generator,
2883
+ ) -> np.ndarray:
2884
+ """Generate a plasma pattern.
2885
+
2886
+ This function generates a plasma pattern using the diamond-square algorithm.
2887
+
2888
+ Args:
2889
+ target_shape (tuple[int, int]): The shape of the plasma pattern to generate.
2890
+ roughness (float): The roughness of the plasma pattern.
2891
+ random_generator (np.random.Generator): The random number generator to use.
2892
+
2893
+ Returns:
2894
+ np.ndarray: The plasma pattern generated.
2895
+
2896
+ """
2897
+
2898
+ def one_diamond_square_step(current_grid: np.ndarray, noise_scale: float) -> np.ndarray:
2899
+ next_height = (current_grid.shape[0] - 1) * 2 + 1
2900
+ next_width = (current_grid.shape[1] - 1) * 2 + 1
2901
+
2902
+ # Pre-allocate expanded grid
2903
+ expanded_grid = np.zeros((next_height, next_width), dtype=np.float32)
2904
+
2905
+ # Generate all noise at once for both steps (already scaled by noise_scale)
2906
+ all_noise = random_generator.uniform(-noise_scale, noise_scale, (next_height, next_width)).astype(np.float32)
2907
+
2908
+ # Copy existing points with noise
2909
+ expanded_grid[::2, ::2] = current_grid + all_noise[::2, ::2]
2910
+
2911
+ # Diamond step - keep separate for natural look
2912
+ diamond_interpolation = cv2.filter2D(expanded_grid, -1, DIAMOND_KERNEL, borderType=cv2.BORDER_CONSTANT)
2913
+ diamond_mask = diamond_interpolation > 0
2914
+ expanded_grid += (diamond_interpolation + all_noise) * diamond_mask
2915
+
2916
+ # Square step - keep separate for natural look
2917
+ square_interpolation = cv2.filter2D(expanded_grid, -1, SQUARE_KERNEL, borderType=cv2.BORDER_CONSTANT)
2918
+ square_mask = square_interpolation > 0
2919
+ expanded_grid += (square_interpolation + all_noise) * square_mask
2920
+
2921
+ # Normalize after each step to prevent value drift
2922
+ return cv2.normalize(expanded_grid, None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
2923
+
2924
+ # Pre-compute noise scales
2925
+ max_dimension = max(target_shape)
2926
+ power_of_two_size = 2 ** np.ceil(np.log2(max_dimension - 1)) + 1
2927
+ total_steps = int(np.log2(power_of_two_size - 1) - 1)
2928
+ noise_scales = np.float32([roughness**i for i in range(total_steps)])
2929
+
2930
+ # Initialize with small random grid
2931
+ plasma_grid = random_generator.uniform(-1, 1, INITIAL_GRID_SIZE).astype(np.float32)
2932
+
2933
+ # Iteratively apply diamond-square steps
2934
+ for noise_scale in noise_scales:
2935
+ plasma_grid = one_diamond_square_step(plasma_grid, noise_scale)
2936
+
2937
+ return np.clip(
2938
+ cv2.normalize(plasma_grid[: target_shape[0], : target_shape[1]], None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F),
2939
+ 0,
2940
+ 1,
2941
+ )
2942
+
2943
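+ # Example usage (sketch): generate a 256x384 plasma field in [0, 1]; lower
+ # roughness values yield smoother patterns.
+ #     rng = np.random.default_rng(42)
+ #     pattern = generate_plasma_pattern((256, 384), roughness=0.5, random_generator=rng)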
+
2944
+ @clipped
2945
+ @float32_io
2946
+ def apply_plasma_brightness_contrast(
2947
+ img: np.ndarray,
2948
+ brightness_factor: float,
2949
+ contrast_factor: float,
2950
+ plasma_pattern: np.ndarray,
2951
+ ) -> np.ndarray:
2952
+ """Apply plasma-based brightness and contrast adjustments.
2953
+
2954
+ This function applies plasma-based brightness and contrast adjustments to an image.
2955
+
2956
+ Args:
2957
+ img (np.ndarray): The image to apply the brightness and contrast adjustments to.
2958
+ brightness_factor (float): The brightness factor to apply.
2959
+ contrast_factor (float): The contrast factor to apply.
2960
+ plasma_pattern (np.ndarray): The plasma pattern to use for the brightness and contrast adjustments.
2961
+
2962
+ Returns:
2963
+ np.ndarray: The image with the brightness and contrast adjustments applied.
2964
+
2965
+ """
2966
+ # Early return if no adjustments needed
2967
+ if brightness_factor == 0 and contrast_factor == 0:
2968
+ return img
2969
+
2970
+ img = img.copy()
2971
+
2972
+ # Expand plasma pattern once if needed
2973
+ if img.ndim > MONO_CHANNEL_DIMENSIONS:
2974
+ plasma_pattern = np.tile(plasma_pattern[..., np.newaxis], (1, 1, img.shape[-1]))
2975
+
2976
+ # Apply brightness adjustment
2977
+ if brightness_factor != 0:
2978
+ brightness_adjustment = multiply(plasma_pattern, brightness_factor, inplace=False)
2979
+ img = add(img, brightness_adjustment, inplace=True)
2980
+
2981
+ # Apply contrast adjustment
2982
+ if contrast_factor != 0:
2983
+ mean = img.mean()
2984
+ contrast_weights = multiply(plasma_pattern, contrast_factor, inplace=False) + 1
2985
+
2986
+ img = multiply(img, contrast_weights, inplace=True)
2987
+
2988
+ mean_factor = mean * (1.0 - contrast_weights)
2989
+ return add(img, mean_factor, inplace=True)
2990
+
2991
+ return img
2992
+
2993
+
2994
+ @clipped
2995
+ def apply_plasma_shadow(
2996
+ img: np.ndarray,
2997
+ intensity: float,
2998
+ plasma_pattern: np.ndarray,
2999
+ ) -> np.ndarray:
3000
+ """Apply plasma shadow to the image.
3001
+
3002
+ Args:
3003
+ img (np.ndarray): Input image
3004
+ intensity (float): Shadow intensity
3005
+ plasma_pattern (np.ndarray): Plasma pattern to use
3006
+
3007
+ Returns:
3008
+ np.ndarray: Image with plasma shadow
3009
+
3010
+ """
3011
+ # Scale plasma pattern by intensity first (scalar operation)
3012
+ scaled_pattern = plasma_pattern * intensity
3013
+
3014
+ # Expand dimensions only once if needed
3015
+ if img.ndim > MONO_CHANNEL_DIMENSIONS:
3016
+ scaled_pattern = scaled_pattern[..., np.newaxis]
3017
+
3018
+ # Single multiply operation
3019
+ return img * (1 - scaled_pattern)
3020
+
3021
+
3022
+ def create_directional_gradient(height: int, width: int, angle: float) -> np.ndarray:
3023
+ """Create a directional gradient in [0, 1] range.
3024
+
3025
+ This function creates a directional gradient in the [0, 1] range.
3026
+
3027
+ Args:
3028
+ height (int): The height of the image.
3029
+ width (int): The width of the image.
3030
+ angle (float): The angle of the gradient in degrees.
3031
+
3032
+ Returns:
3033
+ np.ndarray: The directional gradient.
3034
+
3035
+ """
3036
+ # Fast path for horizontal gradients
3037
+ if angle == 0:
3038
+ return np.linspace(0, 1, width, dtype=np.float32)[None, :] * np.ones((height, 1), dtype=np.float32)
3039
+ if angle == 180:
3040
+ return np.linspace(1, 0, width, dtype=np.float32)[None, :] * np.ones((height, 1), dtype=np.float32)
3041
+
3042
+ # Fast path for vertical gradients
3043
+ if angle == 90:
3044
+ return np.linspace(0, 1, height, dtype=np.float32)[:, None] * np.ones((1, width), dtype=np.float32)
3045
+ if angle == 270:
3046
+ return np.linspace(1, 0, height, dtype=np.float32)[:, None] * np.ones((1, width), dtype=np.float32)
3047
+
3048
+ # Fast path for diagonal gradients using broadcasting
3049
+ if angle in (45, 135, 225, 315):
3050
+ x = np.linspace(0, 1, width, dtype=np.float32)[None, :] # Horizontal
3051
+ y = np.linspace(0, 1, height, dtype=np.float32)[:, None] # Vertical
3052
+
3053
+ if angle == 45: # Bottom-left to top-right
3054
+ return cv2.normalize(x + y, None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
3055
+ if angle == 135: # Bottom-right to top-left
3056
+ return cv2.normalize((1 - x) + y, None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
3057
+ if angle == 225: # Top-right to bottom-left
3058
+ return cv2.normalize((1 - x) + (1 - y), None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
3059
+ # angle == 315: # Top-left to bottom-right
3060
+ return cv2.normalize(x + (1 - y), None, 0, 1, cv2.NORM_MINMAX, dtype=cv2.CV_32F)
3061
+
3062
+ # General case for arbitrary angles using broadcasting
3063
+ y = np.linspace(0, 1, height, dtype=np.float32)[:, None] # Column vector
3064
+ x = np.linspace(0, 1, width, dtype=np.float32)[None, :] # Row vector
3065
+
3066
+ angle_rad = np.deg2rad(angle)
3067
+ cos_a = math.cos(angle_rad)
3068
+ sin_a = math.sin(angle_rad)
3069
+
3070
+ cv2.multiply(x, cos_a, dst=x)
3071
+ cv2.multiply(y, sin_a, dst=y)
3072
+
3073
+ return x + y
3074
+
3075
+
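A quick sketch of the fast paths; the angle is in degrees, as the `np.deg2rad` call above implies:

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

g = F.create_directional_gradient(height=4, width=5, angle=0)
assert g.shape == (4, 5)
assert g[0, 0] == 0.0 and g[0, -1] == 1.0  # left-to-right ramp

g45 = F.create_directional_gradient(height=4, width=5, angle=45)
assert g45.min() == 0.0 and g45.max() == 1.0  # diagonal path is min-max normalized
```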
3076
+ @float32_io
3077
+ def apply_linear_illumination(img: np.ndarray, intensity: float, angle: float) -> np.ndarray:
3078
+ """Apply linear illumination to the image.
3079
+
3080
+ Args:
3081
+ img (np.ndarray): Input image
3082
+ intensity (float): Illumination intensity
3083
+ angle (float): Illumination angle in degrees
3084
+
3085
+ Returns:
3086
+ np.ndarray: Image with linear illumination
3087
+
3088
+ """
3089
+ height, width = img.shape[:2]
3090
+ abs_intensity = abs(intensity)
3091
+
3092
+ # Create gradient; a negative intensity reverses its direction
3093
+ gradient = create_directional_gradient(height, width, angle)
3094
+
3095
+ if intensity < 0:
3096
+ cv2.subtract(1, gradient, dst=gradient)
3097
+
3098
+ cv2.multiply(gradient, 2 * abs_intensity, dst=gradient)
3099
+ cv2.add(gradient, 1 - abs_intensity, dst=gradient)
3100
+
3101
+ # Add channel dimension if needed
3102
+ if img.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
3103
+ gradient = gradient[..., np.newaxis]
3104
+
3105
+ return multiply_by_array(img, gradient)
3106
+
3107
+
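A sketch of the resulting brightness ramp on a flat gray image (values here are illustrative):

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

img = np.full((32, 32, 3), 0.5, dtype=np.float32)
lit = F.apply_linear_illumination(img, intensity=0.2, angle=0)

# The multiplier ramps from (1 - 0.2) on the left to (1 + 0.2) on the right
assert lit[:, 0].mean() < lit[:, -1].mean()
```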
3108
+ @clipped
3109
+ def apply_corner_illumination(
3110
+ img: np.ndarray,
3111
+ intensity: float,
3112
+ corner: Literal[0, 1, 2, 3],
3113
+ ) -> np.ndarray:
3114
+ """Apply corner illumination to the image.
3115
+
3116
+ Args:
3117
+ img (np.ndarray): Input image
3118
+ intensity (float): Illumination intensity
3119
+ corner (Literal[0, 1, 2, 3]): The corner to apply the illumination to.
3120
+
3121
+ Returns:
3122
+ np.ndarray: Image with corner illumination applied.
3123
+
3124
+ """
3125
+ if intensity == 0:
3126
+ return img.copy()
3127
+
3128
+ height, width = img.shape[:2]
3129
+
3130
+ # Pre-compute diagonal length once
3131
+ diagonal_length = math.sqrt(height * height + width * width)
3132
+
3133
+ # Create inverted distance map mask directly
3134
+ # Use uint8 for distanceTransform regardless of input dtype
3135
+ mask = np.full((height, width), 255, dtype=np.uint8)
3136
+
3137
+ # Use array indexing instead of conditionals
3138
+ corners = [(0, 0), (0, width - 1), (height - 1, width - 1), (height - 1, 0)]
3139
+ mask[corners[corner]] = 0
3140
+
3141
+ # Calculate distance transform
3142
+ pattern = cv2.distanceTransform(
3143
+ mask,
3144
+ distanceType=cv2.DIST_L2,
3145
+ maskSize=cv2.DIST_MASK_PRECISE,
3146
+ dstType=cv2.CV_32F, # Specify float output directly
3147
+ )
3148
+
3149
+ # Combine operations to reduce array copies
3150
+ cv2.multiply(pattern, -intensity / diagonal_length, dst=pattern)
3151
+ cv2.add(pattern, 1, dst=pattern)
3152
+
3153
+ if img.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
3154
+ pattern = cv2.merge([pattern] * img.shape[2])
3155
+
3156
+ return multiply_by_array(img, pattern)
3157
+
3158
+
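A sketch of the distance-based falloff; per the `corners` list above, corner index 0 is the top-left:

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

img = np.full((64, 64), 0.5, dtype=np.float32)
lit = F.apply_corner_illumination(img, intensity=-0.4, corner=0)

# The pattern equals 1 at the chosen corner and changes with distance
# from it, so the opposite corner is altered the most
assert abs(lit[0, 0] - 0.5) < abs(lit[-1, -1] - 0.5)
```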
3159
+ @clipped
3160
+ def apply_gaussian_illumination(
3161
+ img: np.ndarray,
3162
+ intensity: float,
3163
+ center: tuple[float, float],
3164
+ sigma: float,
3165
+ ) -> np.ndarray:
3166
+ """Apply gaussian illumination to the image.
3167
+
3168
+ Args:
3169
+ img (np.ndarray): Input image
3170
+ intensity (float): Illumination intensity
3171
+ center (tuple[float, float]): Center of the illumination as (x, y) fractions of the image width and height.
3172
+ sigma (float): The sigma of the illumination.
3173
+
+ Returns:
+ np.ndarray: Image with gaussian illumination applied.
+
3174
+ """
3175
+ if intensity == 0:
3176
+ return img.copy()
3177
+
3178
+ height, width = img.shape[:2]
3179
+
3180
+ # Pre-compute constants
3181
+ center_x = width * center[0]
3182
+ center_y = height * center[1]
3183
+ sigma2 = 2 * (max(height, width) * sigma) ** 2 # Pre-compute denominator
3184
+
3185
+ # Create coordinate grid and calculate distances in-place
3186
+ y, x = np.ogrid[:height, :width]
3187
+ x = x.astype(np.float32)
3188
+ y = y.astype(np.float32)
3189
+ x -= center_x
3190
+ y -= center_y
3191
+
3192
+ # Calculate squared distances in-place
3193
+ cv2.multiply(x, x, dst=x)
3194
+ cv2.multiply(y, y, dst=y)
3195
+
3196
+ x = x + y
3197
+
3198
+ # Calculate gaussian directly into x array
3199
+ cv2.multiply(x, -1 / sigma2, dst=x)
3200
+ cv2.exp(x, dst=x)
3201
+
3202
+ # Scale by intensity
3203
+ cv2.multiply(x, intensity, dst=x)
3204
+ cv2.add(x, 1, dst=x)
3205
+
3206
+ if img.ndim == NUM_MULTI_CHANNEL_DIMENSIONS:
3207
+ x = cv2.merge([x] * img.shape[2])
3208
+
3209
+ return multiply_by_array(img, x)
3210
+
3211
+
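A sketch of the hotspot this produces; `center` is given as fractions of the image size:

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

img = np.full((64, 64, 3), 0.5, dtype=np.float32)
lit = F.apply_gaussian_illumination(img, intensity=0.3, center=(0.5, 0.5), sigma=0.25)

# Positive intensity brightens the center most, with Gaussian falloff
assert lit[32, 32].mean() > lit[0, 0].mean()
```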
3212
+ @uint8_io
3213
+ def auto_contrast(
3214
+ img: np.ndarray,
3215
+ cutoff: float,
3216
+ ignore: int | None,
3217
+ method: Literal["cdf", "pil"],
3218
+ ) -> np.ndarray:
3219
+ """Apply automatic contrast enhancement.
3220
+
3221
+ Args:
3222
+ img (np.ndarray): Input image
3223
+ cutoff (float): Cutoff percentage for histogram
3224
+ ignore (int | None): Value to ignore in histogram
3225
+ method (Literal["cdf", "pil"]): Method to use for contrast enhancement
3226
+
3227
+ Returns:
3228
+ np.ndarray: Image with enhanced contrast
3229
+
3230
+ """
3231
+ result = img.copy()
3232
+ num_channels = get_num_channels(img)
3233
+ max_value = MAX_VALUES_BY_DTYPE[img.dtype]
3234
+
3235
+ # Pre-compute histograms using cv2.calcHist - much faster than np.histogram
3236
+ if img.ndim > MONO_CHANNEL_DIMENSIONS:
3237
+ channels = cv2.split(img)
3238
+ hists: list[np.ndarray | None] = []  # None placeholder for an ignored channel
3239
+ for i, channel in enumerate(channels):
3240
+ if ignore is not None and i == ignore:
3241
+ hists.append(None)
3242
+ continue
3243
+ mask = None if ignore is None else (channel != ignore).astype(np.uint8)  # cv2.calcHist requires a uint8 mask
3244
+ hist = cv2.calcHist([channel], [0], mask, [256], [0, max_value])
3245
+ hists.append(hist.ravel())
3246
+
3247
+ for i in range(num_channels):
3248
+ if ignore is not None and i == ignore:
3249
+ continue
3250
+
3251
+ if img.ndim > MONO_CHANNEL_DIMENSIONS:
3252
+ hist = hists[i]
3253
+ channel = channels[i]
3254
+ else:
3255
+ mask = None if ignore is None else (img != ignore).astype(np.uint8)  # cv2.calcHist requires a uint8 mask
3256
+ hist = cv2.calcHist([img], [0], mask, [256], [0, max_value]).ravel()
3257
+ channel = img
3258
+
3259
+ lo, hi = get_histogram_bounds(hist, cutoff)
3260
+ if hi <= lo:
3261
+ continue
3262
+
3263
+ lut = create_contrast_lut(hist, lo, hi, max_value, method)
3264
+ if ignore is not None:
3265
+ lut[ignore] = ignore
3266
+
3267
+ if img.ndim > MONO_CHANNEL_DIMENSIONS:
3268
+ result[..., i] = sz_lut(channel, lut)
3269
+ else:
3270
+ result = sz_lut(channel, lut)
3271
+
3272
+ return result
3273
+
3274
+
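A sketch on a deliberately low-contrast image; both methods stretch the occupied intensity range toward the full [0, 255] span:

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

rng = np.random.default_rng(0)
img = rng.integers(64, 128, (32, 32, 3), dtype=np.uint8)  # squeezed into [64, 128)

out = F.auto_contrast(img, cutoff=0, ignore=None, method="pil")
assert out.min() < img.min() and out.max() > img.max()  # range stretched
```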
3275
+ def create_contrast_lut(
3276
+ hist: np.ndarray,
3277
+ min_intensity: int,
3278
+ max_intensity: int,
3279
+ max_value: int,
3280
+ method: Literal["cdf", "pil"],
3281
+ ) -> np.ndarray:
3282
+ """Create lookup table for contrast adjustment.
3283
+
3284
+ This function creates a lookup table for contrast adjustment.
3285
+
3286
+ Args:
3287
+ hist (np.ndarray): Histogram of the image.
3288
+ min_intensity (int): Minimum intensity of the histogram.
3289
+ max_intensity (int): Maximum intensity of the histogram.
3290
+ max_value (int): Maximum value of the lookup table.
3291
+ method (Literal["cdf", "pil"]): Method to use for contrast enhancement.
3292
+
3293
+ Returns:
3294
+ np.ndarray: Lookup table for contrast adjustment.
3295
+
3296
+ """
3297
+ if min_intensity >= max_intensity:
3298
+ return np.zeros(256, dtype=np.uint8)
3299
+
3300
+ if method == "cdf":
3301
+ hist_range = hist[min_intensity : max_intensity + 1]
3302
+ cdf = hist_range.cumsum()
3303
+
3304
+ if cdf[-1] == 0: # No valid pixels
3305
+ return np.arange(256, dtype=np.uint8)
3306
+
3307
+ # Normalize CDF to full range
3308
+ cdf = (cdf - cdf[0]) * max_value / (cdf[-1] - cdf[0])
3309
+
3310
+ # Create lookup table
3311
+ lut = np.zeros(256, dtype=np.uint8)
3312
+ lut[min_intensity : max_intensity + 1] = np.clip(np.round(cdf), 0, max_value).astype(np.uint8)
3313
+ lut[max_intensity + 1 :] = max_value
3314
+ return lut
3315
+
3316
+ # "pil" method
3317
+ scale = max_value / (max_intensity - min_intensity)
3318
+ indices = np.arange(256, dtype=float)
3319
+ # np.round maps the midpoint of the input range to the midpoint of the
+ # output (e.g., range [0, 2] yields [0, 128, 255])
3321
+ lut = np.clip(np.round((indices - min_intensity) * scale), 0, max_value).astype(np.uint8)
3322
+ lut[:min_intensity] = 0
3323
+ lut[max_intensity + 1 :] = max_value
3324
+ return lut
3325
+
3326
+
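A toy check of the "pil" branch, matching the rounding behavior noted in the comment above:

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

hist = np.zeros(256)
hist[[0, 1, 2]] = 1  # all mass at intensities 0, 1, 2

lut = F.create_contrast_lut(hist, min_intensity=0, max_intensity=2, max_value=255, method="pil")
assert list(lut[:3]) == [0, 128, 255]  # linear stretch with rounding
```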
3327
+ def get_histogram_bounds(hist: np.ndarray, cutoff: float) -> tuple[int, int]:
3328
+ """Get the low and high bounds of the histogram.
3329
+
3330
+ This function gets the low and high bounds of the histogram.
3331
+
3332
+ Args:
3333
+ hist (np.ndarray): Histogram of the image.
3334
+ cutoff (float): Cutoff percentage for histogram.
3335
+
3336
+ Returns:
3337
+ tuple[int, int]: Low and high bounds of the histogram.
3338
+
3339
+ """
3340
+ if not cutoff:
3341
+ non_zero_intensities = np.nonzero(hist)[0]
3342
+ if len(non_zero_intensities) == 0:
3343
+ return 0, 0
3344
+ return int(non_zero_intensities[0]), int(non_zero_intensities[-1])
3345
+
3346
+ total_pixels = float(hist.sum())
3347
+ if total_pixels == 0:
3348
+ return 0, 0
3349
+
3350
+ pixels_to_cut = total_pixels * cutoff / 100.0
3351
+
3352
+ # Special case for uniform 256-bin histogram
3353
+ if len(hist) == 256 and np.all(hist == hist[0]):
3354
+ min_intensity = int(len(hist) * cutoff / 100)  # truncation (floor for non-negative values)
3355
+ max_intensity = len(hist) - min_intensity - 1
3356
+ return min_intensity, max_intensity
3357
+
3358
+ # Find minimum intensity
3359
+ cumsum = 0.0
3360
+ min_intensity = 0
3361
+ for i in range(len(hist)):
3362
+ cumsum += hist[i]
3363
+ if cumsum >= pixels_to_cut: # Use >= for left bound
3364
+ min_intensity = i + 1
3365
+ break
3366
+ min_intensity = min(min_intensity, len(hist) - 1)
3367
+
3368
+ # Find maximum intensity
3369
+ cumsum = 0.0
3370
+ max_intensity = len(hist) - 1
3371
+ for i in range(len(hist) - 1, -1, -1):
3372
+ cumsum += hist[i]
3373
+ if cumsum >= pixels_to_cut: # Use >= for right bound
3374
+ max_intensity = i
3375
+ break
3376
+
3377
+ # Handle edge cases
3378
+ if min_intensity > max_intensity:
3379
+ mid_point = (len(hist) - 1) // 2
3380
+ return mid_point, mid_point
3381
+
3382
+ return min_intensity, max_intensity
3383
+
3384
+
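A toy illustration of both paths: with no cutoff the bounds are simply the occupied range; with a cutoff, each tail is trimmed by that percentage of the total mass:

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

hist = np.zeros(256)
hist[10:200] = 1  # uniform mass between intensities 10 and 199

assert F.get_histogram_bounds(hist, cutoff=0) == (10, 199)

lo, hi = F.get_histogram_bounds(hist, cutoff=10)  # trim 10% from each tail
assert 10 < lo and hi < 199
```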
3385
+ def get_drop_mask(
3386
+ shape: tuple[int, ...],
3387
+ per_channel: bool,
3388
+ dropout_prob: float,
3389
+ random_generator: np.random.Generator,
3390
+ ) -> np.ndarray:
3391
+ """Generate dropout mask.
3392
+
3393
+ This function generates a dropout mask.
3394
+
3395
+ Args:
3396
+ shape (tuple[int, ...]): Shape of the output mask
3397
+ per_channel (bool): Whether to apply dropout per channel
3398
+ dropout_prob (float): Dropout probability
3399
+ random_generator (np.random.Generator): Random number generator
3400
+
3401
+ Returns:
3402
+ np.ndarray: Dropout mask
3403
+
3404
+ """
3405
+ if per_channel or len(shape) == 2:
3406
+ return random_generator.choice(
3407
+ [True, False],
3408
+ shape,
3409
+ p=[dropout_prob, 1 - dropout_prob],
3410
+ )
3411
+
3412
+ # Generate 2D mask and expand to match channels
3413
+ mask_2d = random_generator.choice(
3414
+ [True, False],
3415
+ shape[:2],
3416
+ p=[dropout_prob, 1 - dropout_prob],
3417
+ )
3418
+
3423
+ # For 3D input, expand and repeat across channels
3424
+ return np.repeat(mask_2d[..., None], shape[2], axis=2)
3425
+
3426
+
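A sketch of the two modes: a shared spatial mask versus independent per-channel sampling:

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

rng = np.random.default_rng(0)

# Shared mask: the same pixels drop in every channel
mask = F.get_drop_mask((8, 8, 3), per_channel=False, dropout_prob=0.5, random_generator=rng)
assert mask.shape == (8, 8, 3)
assert np.array_equal(mask[..., 0], mask[..., 1])

# Independent sampling per channel
mask_pc = F.get_drop_mask((8, 8, 3), per_channel=True, dropout_prob=0.5, random_generator=rng)
assert mask_pc.shape == (8, 8, 3)
```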
3427
+ def generate_random_values(
3428
+ channels: int,
3429
+ dtype: np.dtype,
3430
+ random_generator: np.random.Generator,
3431
+ ) -> np.ndarray:
3432
+ """Generate random values.
3433
+
3434
+ Args:
3435
+ channels (int): Number of channels
3436
+ dtype (np.dtype): Data type of the output array
3437
+ random_generator (np.random.Generator): Random number generator
3438
+
3439
+ Returns:
3440
+ np.ndarray: Random values
3441
+
3442
+ """
3443
+ if dtype == np.uint8:
3444
+ return random_generator.integers(
3445
+ 0,
3446
+ int(MAX_VALUES_BY_DTYPE[dtype]),
3447
+ size=channels,
3448
+ dtype=dtype,
3449
+ )
3450
+ if dtype == np.float32:
3451
+ return random_generator.uniform(0, 1, size=channels).astype(dtype)
3452
+
3453
+ raise ValueError(f"Unsupported dtype: {dtype}")
3454
+
3455
+
3456
+ def prepare_drop_values(
3457
+ array: np.ndarray,
3458
+ value: float | Sequence[float] | np.ndarray | None,
3459
+ random_generator: np.random.Generator,
3460
+ ) -> np.ndarray:
3461
+ """Prepare values to fill dropped pixels.
3462
+
3463
+ Args:
3464
+ array (np.ndarray): Input array to determine shape and dtype
3465
+ value (float | Sequence[float] | np.ndarray | None): User-specified drop values or None for random
3466
+ random_generator (np.random.Generator): Random number generator
3467
+
3468
+ Returns:
3469
+ np.ndarray: Array of values matching input shape
3470
+
3471
+ """
3472
+ if value is None:
3473
+ channels = get_num_channels(array)
3474
+ values = generate_random_values(channels, array.dtype, random_generator)
3475
+ elif isinstance(value, (int, float)):
3476
+ return np.full(array.shape, value, dtype=array.dtype)
3477
+ else:
3478
+ values = np.array(value, dtype=array.dtype).reshape(-1)
3479
+
3480
+ # For monochannel input, fill the full shape with the first value
3481
+ if array.ndim == 2:
3482
+ return np.full(array.shape, values[0], dtype=array.dtype)
3483
+
3484
+ # For multichannel input, broadcast values to full shape
3485
+ return np.full((*array.shape[:2], len(values)), values, dtype=array.dtype)
3486
+
3487
+
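A sketch of the value modes (a scalar is broadcast; a sequence supplies one value per channel; None draws random values):

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

rng = np.random.default_rng(0)
img = np.zeros((4, 4, 3), dtype=np.uint8)

# Scalar: broadcast to the full image shape
vals = F.prepare_drop_values(img, value=128, random_generator=rng)
assert vals.shape == img.shape and vals[0, 0, 0] == 128

# Sequence: one value per channel
vals = F.prepare_drop_values(img, value=(255, 0, 0), random_generator=rng)
assert tuple(vals[0, 0]) == (255, 0, 0)
```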
3488
+ def get_mask_array(data: dict[str, Any]) -> np.ndarray | None:
3489
+ """Get mask array from input data if it exists."""
3490
+ if "mask" in data:
3491
+ return data["mask"]
3492
+ return data["masks"][0] if "masks" in data else None
3493
+
3494
+
3495
+ def get_rain_params(
3496
+ liquid_layer: np.ndarray,
3497
+ color: np.ndarray,
3498
+ intensity: float,
3499
+ ) -> dict[str, Any]:
3500
+ """Generate parameters for rain effect.
3501
+
3502
+ This function generates parameters for a rain effect.
3503
+
3504
+ Args:
3505
+ liquid_layer (np.ndarray): Liquid layer of the image.
3506
+ color (np.ndarray): Color of the rain.
3507
+ intensity (float): Intensity of the rain.
3508
+
3509
+ Returns:
3510
+ dict[str, Any]: Parameters for the rain effect.
3511
+
3512
+ """
3513
+ liquid_layer = clip(liquid_layer * 255, np.uint8, inplace=False)
3514
+
3515
+ # Generate distance transform with more defined edges
3516
+ dist = 255 - cv2.Canny(liquid_layer, 50, 150)
3517
+ dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5)
3518
+ _, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC)
3519
+
3520
+ # Use separate blur operations for better drop formation
3521
+ dist = cv2.GaussianBlur(
3522
+ dist,
3523
+ ksize=(3, 3),
3524
+ sigmaX=1, # Add slight sigma for smoother drops
3525
+ sigmaY=1,
3526
+ borderType=cv2.BORDER_REPLICATE,
3527
+ )
3528
+ dist = clip(dist, np.uint8, inplace=True)
3529
+
3530
+ # Enhance contrast in the distance map
3531
+ dist = equalize(dist)
3532
+
3533
+ # Modified kernel for more natural drop shapes
3534
+ ker = np.array(
3535
+ [
3536
+ [-2, -1, 0],
3537
+ [-1, 1, 1],
3538
+ [0, 1, 2],
3539
+ ],
3540
+ dtype=np.float32,
3541
+ )
3542
+
3543
+ # Apply convolution with better precision
3544
+ dist = convolve(dist, ker)
3545
+
3546
+ # Final blur with larger kernel for smoother drops
3547
+ dist = cv2.GaussianBlur(
3548
+ dist,
3549
+ ksize=(5, 5), # Increased kernel size
3550
+ sigmaX=1.5, # Adjusted sigma
3551
+ sigmaY=1.5,
3552
+ borderType=cv2.BORDER_REPLICATE,
3553
+ ).astype(np.float32)
3554
+
3555
+ # Calculate final rain mask with better blending
3556
+ m = liquid_layer.astype(np.float32) * dist
3557
+
3558
+ # Normalize with better handling of edge cases
3559
+ m_max = np.max(m, axis=(0, 1))
3560
+ if m_max > 0:
3561
+ m *= 1 / m_max
3562
+ else:
3563
+ m = np.zeros_like(m)
3564
+
3565
+ # Apply color with adjusted intensity for more natural look
3566
+ drops = m[:, :, None] * color * (intensity * 0.9) # Slightly reduced intensity
3567
+
3568
+ return {
3569
+ "drops": drops,
3570
+ }
3571
+
3572
+
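A rough sketch; the caller normally supplies `liquid_layer` from Gaussian noise, so any float array in [0, 1] works as a stand-in, and the color here is illustrative:

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

rng = np.random.default_rng(0)
liquid_layer = rng.random((64, 64)).astype(np.float32)
color = np.array([200, 200, 255])  # pale blue drops

params = F.get_rain_params(liquid_layer, color=color, intensity=0.8)
assert params["drops"].shape == (64, 64, 3)
```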
3573
+ def get_mud_params(
3574
+ liquid_layer: np.ndarray,
3575
+ color: np.ndarray,
3576
+ cutout_threshold: float,
3577
+ sigma: float,
3578
+ intensity: float,
3579
+ random_generator: np.random.Generator,
3580
+ ) -> dict[str, Any]:
3581
+ """Generate parameters for mud effect.
3582
+
3583
+ This function generates parameters for a mud effect.
3584
+
3585
+ Args:
3586
+ liquid_layer (np.ndarray): Liquid layer of the image.
3587
+ color (np.ndarray): Color of the mud.
3588
+ cutout_threshold (float): Cutout threshold for the mud.
3589
+ sigma (float): Sigma for the Gaussian blur.
3590
+ intensity (float): Intensity of the mud.
3591
+ random_generator (np.random.Generator): Random number generator.
3592
+
3593
+ Returns:
3594
+ dict[str, Any]: Parameters for the mud effect.
3595
+
3596
+ """
3597
+ height, width = liquid_layer.shape
3598
+
3599
+ # Create initial mask (ensure we have some non-zero values)
3600
+ mask = (liquid_layer > cutout_threshold).astype(np.float32)
3601
+ if np.sum(mask) == 0: # If mask is all zeros
3602
+ # Force minimum coverage of 10%
3603
+ num_pixels = height * width
3604
+ num_needed = max(1, int(0.1 * num_pixels)) # At least 1 pixel
3605
+ flat_indices = random_generator.choice(num_pixels, num_needed, replace=False)
3606
+ mask = np.zeros_like(liquid_layer, dtype=np.float32)
3607
+ mask.flat[flat_indices] = 1.0
3608
+
3609
+ # Apply Gaussian blur if sigma > 0
3610
+ if sigma > 0:
3611
+ mask = cv2.GaussianBlur(
3612
+ mask,
3613
+ ksize=(0, 0),
3614
+ sigmaX=sigma,
3615
+ sigmaY=sigma,
3616
+ borderType=cv2.BORDER_REPLICATE,
3617
+ )
3618
+
3619
+ # Safe normalization (avoid division by zero)
3620
+ mask_max = np.max(mask)
3621
+ if mask_max > 0:
3622
+ mask = mask / mask_max
3623
+ else:
3624
+ # If mask is somehow all zeros after blur, force some effect
3625
+ mask[0, 0] = 1.0
3626
+
3627
+ # Scale by intensity directly (no minimum)
3628
+ mask = mask * intensity
3629
+
3630
+ # Create mud effect array
3631
+ mud = np.zeros((height, width, 3), dtype=np.float32)
3632
+
3633
+ # Apply color directly - the intensity scaling is already handled
3634
+ for i in range(3):
3635
+ mud[..., i] = mask * color[i]
3636
+
3637
+ # Create complementary non-mud array
3638
+ non_mud = np.ones_like(mud)
3639
+ for i in range(3):
3640
+ if color[i] > 0:
3641
+ non_mud[..., i] = np.clip((color[i] - mud[..., i]) / color[i], 0, 1)
3642
+ else:
3643
+ non_mud[..., i] = 1.0 - mask
3644
+
3645
+ return {
3646
+ "mud": mud.astype(np.float32),
3647
+ "non_mud": non_mud.astype(np.float32),
3648
+ }
3649
+
3650
+
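A rough sketch with an arbitrary brown color; the parameter values are illustrative, not the transform's defaults:

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

rng = np.random.default_rng(0)
liquid_layer = rng.random((64, 64)).astype(np.float32)

params = F.get_mud_params(
    liquid_layer,
    color=np.array([20, 42, 63]),
    cutout_threshold=0.68,
    sigma=2.0,
    intensity=0.6,
    random_generator=rng,
)
assert params["mud"].shape == (64, 64, 3)
assert params["non_mud"].shape == (64, 64, 3)
```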
3651
+ # Standard reference H&E stain matrices
3652
+ STAIN_MATRICES = {
3653
+ "ruifrok": np.array(
3654
+ [ # Ruifrok & Johnston standard reference
3655
+ [0.644211, 0.716556, 0.266844], # Hematoxylin
3656
+ [0.092789, 0.954111, 0.283111], # Eosin
3657
+ ],
3658
+ ),
3659
+ "macenko": np.array(
3660
+ [ # Macenko's reference
3661
+ [0.5626, 0.7201, 0.4062],
3662
+ [0.2159, 0.8012, 0.5581],
3663
+ ],
3664
+ ),
3665
+ "standard": np.array(
3666
+ [ # Standard bright-field microscopy
3667
+ [0.65, 0.70, 0.29],
3668
+ [0.07, 0.99, 0.11],
3669
+ ],
3670
+ ),
3671
+ "high_contrast": np.array(
3672
+ [ # Enhanced contrast
3673
+ [0.55, 0.88, 0.11],
3674
+ [0.12, 0.86, 0.49],
3675
+ ],
3676
+ ),
3677
+ "h_heavy": np.array(
3678
+ [ # Hematoxylin dominant
3679
+ [0.75, 0.61, 0.32],
3680
+ [0.04, 0.93, 0.36],
3681
+ ],
3682
+ ),
3683
+ "e_heavy": np.array(
3684
+ [ # Eosin dominant
3685
+ [0.60, 0.75, 0.28],
3686
+ [0.17, 0.95, 0.25],
3687
+ ],
3688
+ ),
3689
+ "dark": np.array(
3690
+ [ # Darker staining
3691
+ [0.78, 0.55, 0.28],
3692
+ [0.09, 0.97, 0.21],
3693
+ ],
3694
+ ),
3695
+ "light": np.array(
3696
+ [ # Lighter staining
3697
+ [0.57, 0.71, 0.38],
3698
+ [0.15, 0.89, 0.42],
3699
+ ],
3700
+ ),
3701
+ }
3702
+
3703
+
3704
+ def rgb_to_optical_density(img: np.ndarray, eps: float = 1e-6) -> np.ndarray:
3705
+ """Convert RGB image to optical density.
3706
+
3707
+ This function converts an RGB image to optical density.
3708
+
3709
+ Args:
3710
+ img (np.ndarray): Input image.
3711
+ eps (float): Epsilon value.
3712
+
3713
+ Returns:
3714
+ np.ndarray: Optical density image.
3715
+
3716
+ """
3717
+ max_value = MAX_VALUES_BY_DTYPE[img.dtype]
3718
+ pixel_matrix = img.reshape(-1, 3).astype(np.float32)
3719
+ pixel_matrix = np.maximum(pixel_matrix / max_value, eps)
3720
+ return -np.log(pixel_matrix)
3721
+
3722
+
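The conversion is the Beer-Lambert relation OD = -log(I / I_max), so a round trip through exp(-OD) recovers the pixel values:

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

img = np.array([[[255, 128, 64]]], dtype=np.uint8)  # one RGB pixel

od = F.rgb_to_optical_density(img)
assert od.shape == (1, 3)  # pixels are flattened to rows

recovered = np.exp(-od) * 255  # invert: intensity = exp(-OD) * max value
assert np.allclose(recovered, [255, 128, 64], atol=1)
```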
3723
+ def normalize_vectors(vectors: np.ndarray) -> np.ndarray:
3724
+ """Normalize vectors.
3725
+
3726
+ This function normalizes vectors.
3727
+
3728
+ Args:
3729
+ vectors (np.ndarray): Vectors to normalize.
3730
+
3731
+ Returns:
3732
+ np.ndarray: Normalized vectors.
3733
+
3734
+ """
3735
+ norms = np.sqrt(np.sum(vectors**2, axis=1, keepdims=True))
3736
+ return vectors / norms
3737
+
3738
+
3739
+ def get_normalizer(method: Literal["vahadane", "macenko"]) -> StainNormalizer:
3740
+ """Get stain normalizer based on method.
3741
+
3742
+ This function gets a stain normalizer based on a method.
3743
+
3744
+ Args:
3745
+ method (Literal["vahadane", "macenko"]): Method to use for stain normalization.
3746
+
3747
+ Returns:
3748
+ StainNormalizer: Stain normalizer.
3749
+
3750
+ """
3751
+ return VahadaneNormalizer() if method == "vahadane" else MacenkoNormalizer()
3752
+
3753
+
3754
+ class StainNormalizer:
3755
+ """Base class for stain normalizers."""
3756
+
3757
+ def __init__(self) -> None:
3758
+ self.stain_matrix_target = None
3759
+
3760
+ def fit(self, img: np.ndarray) -> None:
3761
+ """Fit the stain normalizer to an image.
3762
+
3763
+ This function fits the stain normalizer to an image.
3764
+
3765
+ Args:
3766
+ img (np.ndarray): Input image.
3767
+
3768
+ """
3769
+ raise NotImplementedError
3770
+
3771
+
3772
+ class SimpleNMF:
3773
+ """Simple Non-negative Matrix Factorization (NMF) for histology stain separation.
3774
+
3775
+ This class implements a simplified version of the Non-negative Matrix Factorization algorithm
3776
+ specifically designed for separating Hematoxylin and Eosin (H&E) stains in histopathology images.
3777
+ It is used as part of the Vahadane stain normalization method.
3778
+
3779
+ The algorithm decomposes optical density values of H&E stained images into stain color appearances
3780
+ (the stain color vectors) and stain concentrations (the density of each stain at each pixel).
3781
+
3782
+ The implementation uses an iterative multiplicative update approach that preserves non-negativity
3783
+ constraints, which are physically meaningful for stain separation as concentrations and
3784
+ absorption coefficients cannot be negative.
3785
+
3786
+ This implementation is optimized for stability by:
3787
+ 1. Initializing with standard H&E reference colors from Ruifrok
3788
+ 2. Using normalized projection for initial concentrations
3789
+ 3. Applying careful normalization to avoid numerical issues
3790
+
3791
+ Args:
3792
+ n_iter (int): Number of iterations for the NMF algorithm. Default: 100
3793
+
3794
+ References:
3795
+ - Vahadane, A., et al. (2016): Structure-preserving color normalization and
3796
+ sparse stain separation for histological images. IEEE Transactions on
3797
+ Medical Imaging, 35(8), 1962-1971.
3798
+ - Ruifrok, A. C., & Johnston, D. A. (2001): Quantification of histochemical
3799
+ staining by color deconvolution. Analytical and Quantitative Cytology and
3800
+ Histology, 23(4), 291-299.
3801
+
3802
+ """
3803
+
3804
+ def __init__(self, n_iter: int = 100):
3805
+ self.n_iter = n_iter
3806
+ # Initialize with standard H&E colors from Ruifrok
3807
+ self.initial_colors = np.array(
3808
+ [
3809
+ [0.644211, 0.716556, 0.266844], # Hematoxylin
3810
+ [0.092789, 0.954111, 0.283111], # Eosin
3811
+ ],
3812
+ dtype=np.float32,
3813
+ )
3814
+
3815
+ def fit_transform(self, optical_density: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
3816
+ """Fit the NMF model to optical density.
3817
+
3818
+ This function fits the NMF model to optical density.
3819
+
3820
+ Args:
3821
+ optical_density (np.ndarray): Optical density image.
3822
+
3823
+ Returns:
3824
+ tuple[np.ndarray, np.ndarray]: Stain concentrations and stain colors.
3825
+
3826
+ """
3827
+ # Start with known H&E colors
3828
+ stain_colors = self.initial_colors.copy()
3829
+
3830
+ # Initialize concentrations based on projection onto initial colors
3831
+ # This gives us a physically meaningful starting point
3832
+ stain_colors_normalized = normalize_vectors(stain_colors)
3833
+ stain_concentrations = np.maximum(optical_density @ stain_colors_normalized.T, 0)
3834
+
3835
+ # Iterative updates with careful normalization
3836
+ eps = 1e-6
3837
+ for _ in range(self.n_iter):
3838
+ # Update concentrations
3839
+ numerator = optical_density @ stain_colors.T
3840
+ denominator = stain_concentrations @ (stain_colors @ stain_colors.T)
3841
+ stain_concentrations *= numerator / (denominator + eps)
3842
+
3843
+ # Ensure non-negativity
3844
+ stain_concentrations = np.maximum(stain_concentrations, 0)
3845
+
3846
+ # Update colors
3847
+ numerator = stain_concentrations.T @ optical_density
3848
+ denominator = (stain_concentrations.T @ stain_concentrations) @ stain_colors
3849
+ stain_colors *= numerator / (denominator + eps)
3850
+
3851
+ # Ensure non-negativity and normalize
3852
+ stain_colors = np.maximum(stain_colors, 0)
3853
+ stain_colors = normalize_vectors(stain_colors)
3854
+
3855
+ return stain_concentrations, stain_colors
3856
+
3857
+
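A shape-level sketch on synthetic non-negative optical-density data:

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

rng = np.random.default_rng(0)
optical_density = rng.random((500, 3)).astype(np.float32)  # N pixels x 3 channels

nmf = F.SimpleNMF(n_iter=50)
concentrations, colors = nmf.fit_transform(optical_density)
assert concentrations.shape == (500, 2)  # per-pixel H and E densities
assert colors.shape == (2, 3)            # stain color vectors
assert np.allclose(np.linalg.norm(colors, axis=1), 1.0, atol=1e-4)
```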
3858
+ def order_stains_combined(stain_colors: np.ndarray) -> tuple[int, int]:
3859
+ """Order stains using a combination of methods.
3860
+
3861
+ This combines both angular information and spectral characteristics
3862
+ for more robust identification.
3863
+
3864
+ Args:
3865
+ stain_colors (np.ndarray): Stain colors.
3866
+
3867
+ Returns:
3868
+ tuple[int, int]: Hematoxylin and eosin indices.
3869
+
3870
+ """
3871
+ # Normalize stain vectors
3872
+ stain_colors = normalize_vectors(stain_colors)
3873
+
3874
+ # Calculate angles (Macenko)
3875
+ angles = np.mod(np.arctan2(stain_colors[:, 1], stain_colors[:, 0]), np.pi)
3876
+
3877
+ # Calculate spectral ratios (Ruifrok)
3878
+ blue_ratio = stain_colors[:, 2] / (np.sum(stain_colors, axis=1) + 1e-6)
3879
+ red_ratio = stain_colors[:, 0] / (np.sum(stain_colors, axis=1) + 1e-6)
3880
+
3881
+ # Combine scores
3882
+ # High angle and high blue ratio indicates Hematoxylin
3883
+ # Low angle and high red ratio indicates Eosin
3884
+ scores = angles * blue_ratio - red_ratio
3885
+
3886
+ hematoxylin_idx = np.argmax(scores)
3887
+ eosin_idx = 1 - hematoxylin_idx
3888
+
3889
+ return hematoxylin_idx, eosin_idx
3890
+
3891
+
3892
+ class VahadaneNormalizer(StainNormalizer):
3893
+ """A stain normalizer implementation based on Vahadane's method for histopathology images.
3894
+
3895
+ This class implements the "Structure-Preserving Color Normalization and Sparse Stain Separation
3896
+ for Histological Images" method proposed by Vahadane et al. The technique uses Non-negative
3897
+ Matrix Factorization (NMF) to separate Hematoxylin and Eosin (H&E) stains in histopathology
3898
+ images and then normalizes them to a target standard.
3899
+
3900
+ The Vahadane method is particularly effective for histology image normalization because:
3901
+ 1. It maintains tissue structure during color normalization
3902
+ 2. It performs sparse stain separation, reducing color bleeding
3903
+ 3. It adaptively estimates stain vectors from each image
3904
+ 4. It preserves biologically relevant information
3905
+
3906
+ This implementation uses SimpleNMF as its core matrix factorization algorithm to extract
3907
+ stain color vectors (appearance matrix) and concentration matrices from optical
3908
+ density-transformed images. It identifies the Hematoxylin and Eosin stains by their
3909
+ characteristic color profiles and spatial distribution.
3910
+
3911
+ References:
3912
+ Vahadane, et al., 2016: Structure-preserving color normalization
3913
+ and sparse stain separation for histological images. IEEE transactions on medical imaging,
3914
+ 35(8), pp.1962-1971.
3915
+
3916
+ Examples:
3917
+ >>> import numpy as np
3918
+ >>> import albumentations as A
3919
+ >>> from albumentations.augmentations.pixel import functional as F
3920
+ >>> import cv2
3921
+ >>>
3922
+ >>> # Load source and target images (H&E stained histopathology)
3923
+ >>> source_img = cv2.imread('source_image.png')
3924
+ >>> source_img = cv2.cvtColor(source_img, cv2.COLOR_BGR2RGB)
3925
+ >>> target_img = cv2.imread('target_image.png')
3926
+ >>> target_img = cv2.cvtColor(target_img, cv2.COLOR_BGR2RGB)
3927
+ >>>
3928
+ >>> # Create and fit the normalizer to the target image
3929
+ >>> normalizer = F.VahadaneNormalizer()
3930
+ >>> normalizer.fit(target_img)
3931
+ >>>
3932
+ >>> # Normalize the source image to match the target's stain characteristics
3933
+ >>> normalized_img = normalizer.transform(source_img)
3934
+
3935
+ """
3936
+
3937
+ def fit(self, img: np.ndarray) -> None:
3938
+ """Fit the Vahadane stain normalizer to an image.
3939
+
3940
+ This function fits the Vahadane stain normalizer to an image.
3941
+
3942
+ Args:
3943
+ img (np.ndarray): Input image.
3944
+
3945
+ """
3946
+ optical_density = rgb_to_optical_density(img)
3947
+
3948
+ nmf = SimpleNMF(n_iter=100)
3949
+ _, stain_colors = nmf.fit_transform(optical_density)
3950
+
3951
+ # Use combined method for robust stain ordering
3952
+ hematoxylin_idx, eosin_idx = order_stains_combined(stain_colors)
3953
+
3954
+ self.stain_matrix_target = np.array(
3955
+ [
3956
+ stain_colors[hematoxylin_idx],
3957
+ stain_colors[eosin_idx],
3958
+ ],
3959
+ )
3960
+
3961
+
3962
+ class MacenkoNormalizer(StainNormalizer):
3963
+ """Macenko stain normalizer with optimized computations."""
3964
+
3965
+ def __init__(self, angular_percentile: float = 99):
3966
+ super().__init__()
3967
+ self.angular_percentile = angular_percentile
3968
+
3969
+ def fit(self, img: np.ndarray, angular_percentile: float = 99) -> None:
3970
+ """Fit the Macenko stain normalizer to an image.
3971
+
3972
+ This function fits the Macenko stain normalizer to an image.
3973
+
3974
+ Args:
3975
+ img (np.ndarray): Input image.
3976
+ angular_percentile (float): Angular percentile.
3977
+
3978
+ """
3979
+ # Step 1: Convert RGB to optical density (OD) space
3980
+ optical_density = rgb_to_optical_density(img)
3981
+
3982
+ # Step 2: Remove background pixels
3983
+ od_threshold = 0.05
3984
+ threshold_mask = (optical_density > od_threshold).any(axis=1)
3985
+ tissue_density = optical_density[threshold_mask]
3986
+
3987
+ if len(tissue_density) < 1:
3988
+ raise ValueError(f"No tissue pixels found (threshold={od_threshold})")
3989
+
3990
+ # Step 3: Compute covariance matrix
3991
+ tissue_density = np.ascontiguousarray(tissue_density, dtype=np.float32)
3992
+ od_covariance = cv2.calcCovarMatrix(
3993
+ tissue_density,
3994
+ None,
3995
+ cv2.COVAR_NORMAL | cv2.COVAR_ROWS | cv2.COVAR_SCALE,
3996
+ )[0]
3997
+
3998
+ # Step 4: Get principal components
3999
+ eigenvalues, eigenvectors = cv2.eigen(od_covariance)[1:]
4000
+ idx = np.argsort(eigenvalues.ravel())[-2:]
4001
+ principal_eigenvectors = np.ascontiguousarray(eigenvectors[:, idx], dtype=np.float32)
4002
+
4003
+ # Step 5: Project onto eigenvector plane
4004
+ # Add small epsilon to avoid numerical instability
4005
+ epsilon = 1e-8
4006
+ if np.any(np.abs(principal_eigenvectors) < epsilon):
4007
+ # Regularize near-zero entries by assigning ±ε based on original sign
4008
+ principal_eigenvectors = np.where(
4009
+ np.abs(principal_eigenvectors) < epsilon,
4010
+ np.where(principal_eigenvectors < 0, -epsilon, epsilon),
4011
+ principal_eigenvectors,
4012
+ )
4013
+
4014
+ # Add small epsilon to tissue_density to avoid numerical issues
4015
+ safe_tissue_density = tissue_density + epsilon
4016
+ plane_coordinates = safe_tissue_density @ principal_eigenvectors
4017
+
4018
+ # Step 6: Find angles of extreme points
4019
+ polar_angles = np.arctan2(
4020
+ plane_coordinates[:, 1],
4021
+ plane_coordinates[:, 0],
4022
+ )
4023
+
4024
+ # Get robust angle estimates
4025
+ hematoxylin_angle = np.percentile(polar_angles, 100 - angular_percentile)
4026
+ eosin_angle = np.percentile(polar_angles, angular_percentile)
4027
+
4028
+ # Step 7: Convert angles back to RGB space
4029
+ hem_cos, hem_sin = np.cos(hematoxylin_angle), np.sin(hematoxylin_angle)
4030
+ eos_cos, eos_sin = np.cos(eosin_angle), np.sin(eosin_angle)
4031
+
4032
+ angle_to_vector = np.array(
4033
+ [[hem_cos, hem_sin], [eos_cos, eos_sin]],
4034
+ dtype=np.float32,
4035
+ )
4036
+
4037
+ # Ensure both matrices have the same data type for cv2.gemm
4038
+ principal_eigenvectors_t = np.ascontiguousarray(principal_eigenvectors.T, dtype=np.float32)
4039
+ stain_vectors = cv2.gemm(
4040
+ angle_to_vector,
4041
+ principal_eigenvectors_t,
4042
+ 1,
4043
+ None,
4044
+ 0,
4045
+ )
4046
+
4047
+ # Step 8: Ensure non-negativity by taking absolute values
4048
+ stain_vectors = np.abs(stain_vectors)
4049
+
4050
+ # Step 9: Normalize vectors to unit length
4051
+ stain_vectors = stain_vectors / np.sqrt(np.sum(stain_vectors**2, axis=1, keepdims=True) + epsilon)
4052
+
4053
+ # Step 10: Order vectors as [hematoxylin, eosin]
4054
+ self.stain_matrix_target = stain_vectors if stain_vectors[0, 0] > stain_vectors[1, 0] else stain_vectors[::-1]
4055
+
4056
+
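Usage mirrors the Vahadane example above; a synthetic tile is enough to exercise `fit` (a real H&E image would be used in practice):

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

rng = np.random.default_rng(0)
img = rng.integers(80, 240, (64, 64, 3), dtype=np.uint8)  # stand-in for an H&E tile

normalizer = F.MacenkoNormalizer(angular_percentile=99)
normalizer.fit(img)
assert normalizer.stain_matrix_target.shape == (2, 3)  # [hematoxylin, eosin] rows
```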
4057
+ def get_tissue_mask(img: np.ndarray, threshold: float = 0.85) -> np.ndarray:
4058
+ """Get tissue mask from image.
4059
+
4060
+ Args:
4061
+ img (np.ndarray): Input image
4062
+ threshold (float): Luminosity threshold on the image's value scale; the default assumes a float image in [0, 1]. Default: 0.85
4063
+
4064
+ Returns:
4065
+ np.ndarray: Flattened 1D boolean mask where True marks tissue pixels
4066
+
4067
+ """
4068
+ # Convert to grayscale using RGB weights: R*0.299 + G*0.587 + B*0.114
4069
+ luminosity = img[..., 0] * 0.299 + img[..., 1] * 0.587 + img[..., 2] * 0.114
4070
+
4071
+ # Tissue is darker, so we want pixels below threshold
4072
+ mask = luminosity < threshold
4073
+
4074
+ return mask.reshape(-1)
4075
+
4076
+
4077
+ @clipped
4078
+ @float32_io
4079
+ def apply_he_stain_augmentation(
4080
+ img: np.ndarray,
4081
+ stain_matrix: np.ndarray,
4082
+ scale_factors: np.ndarray,
4083
+ shift_values: np.ndarray,
4084
+ augment_background: bool,
4085
+ ) -> np.ndarray:
4086
+ """Apply HE stain augmentation to an image.
4087
+
4088
+ This function applies HE stain augmentation to an image.
4089
+
4090
+ Args:
4091
+ img (np.ndarray): Input image.
4092
+ stain_matrix (np.ndarray): Stain matrix.
4093
+ scale_factors (np.ndarray): Scale factors.
4094
+ shift_values (np.ndarray): Shift values.
4095
+ augment_background (bool): Whether to augment the background.
4096
+
4097
+ Returns:
4098
+ np.ndarray: Augmented image.
4099
+
4100
+ """
4101
+ # Step 1: Convert RGB to optical density space
4102
+ optical_density = rgb_to_optical_density(img)
4103
+
4104
+ # Step 2: Calculate stain concentrations using regularized pseudo-inverse
4105
+ stain_matrix = np.ascontiguousarray(stain_matrix, dtype=np.float32)
4106
+
4107
+ # Add small regularization term for numerical stability
4108
+ regularization = 1e-6
4109
+ stain_correlation = stain_matrix @ stain_matrix.T + regularization * np.eye(2)
4110
+ density_projection = stain_matrix @ optical_density.T
4111
+
4112
+ try:
4113
+ # Solve for stain concentrations
4114
+ stain_concentrations = np.linalg.solve(stain_correlation, density_projection).T
4115
+ except np.linalg.LinAlgError:
4116
+ # Fallback to pseudo-inverse if direct solve fails
4117
+ stain_concentrations = np.linalg.lstsq(
4118
+ stain_matrix.T,
4119
+ optical_density.T,  # transpose to (3, N) so dimensions match stain_matrix.T
4120
+ rcond=regularization,
4121
+ )[0].T
4122
+
4123
+ # Step 3: Apply concentration adjustments
4124
+ if not augment_background:
4125
+ # Only modify tissue regions
4126
+ tissue_mask = get_tissue_mask(img)  # already flattened to 1D
4127
+ stain_concentrations[tissue_mask] = stain_concentrations[tissue_mask] * scale_factors + shift_values
4128
+ else:
4129
+ # Modify all pixels
4130
+ stain_concentrations = stain_concentrations * scale_factors + shift_values
4131
+
4132
+ # Step 4: Reconstruct RGB image
4133
+ optical_density_result = stain_concentrations @ stain_matrix
4134
+ rgb_result = np.exp(-optical_density_result)
4135
+
4136
+ return rgb_result.reshape(img.shape)
4137
+
4138
+
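A sketch using one of the reference matrices defined above; the scale and shift values are illustrative:

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

rng = np.random.default_rng(0)
img = rng.integers(60, 220, (32, 32, 3), dtype=np.uint8)

augmented = F.apply_he_stain_augmentation(
    img,
    stain_matrix=F.STAIN_MATRICES["ruifrok"],
    scale_factors=np.array([1.1, 0.9]),    # scale hematoxylin up, eosin down
    shift_values=np.array([0.02, -0.02]),  # small concentration shifts
    augment_background=True,
)
assert augmented.shape == img.shape
```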
4139
+ @clipped
4140
+ @preserve_channel_dim
4141
+ def convolve(img: np.ndarray, kernel: np.ndarray) -> np.ndarray:
4142
+ """Convolve an image with a kernel.
4143
+
4144
+ This function convolves an image with a kernel.
4145
+
4146
+ Args:
4147
+ img (np.ndarray): Input image.
4148
+ kernel (np.ndarray): Kernel.
4149
+
4150
+ Returns:
4151
+ np.ndarray: Convolved image.
4152
+
4153
+ """
4154
+ conv_fn = maybe_process_in_chunks(cv2.filter2D, ddepth=-1, kernel=kernel)
4155
+ return conv_fn(img)
4156
+
4157
+
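For instance, a classic 3x3 sharpening kernel:

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

rng = np.random.default_rng(0)
img = rng.integers(0, 256, (32, 32, 3), dtype=np.uint8)

kernel = np.array(
    [[0, -1, 0],
     [-1, 5, -1],
     [0, -1, 0]],
    dtype=np.float32,
)  # unsharp-style sharpening

sharpened = F.convolve(img, kernel)
assert sharpened.shape == img.shape
```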
4158
+ @clipped
4159
+ @preserve_channel_dim
4160
+ def separable_convolve(img: np.ndarray, kernel: np.ndarray) -> np.ndarray:
4161
+ """Convolve an image with a separable kernel.
4162
+
4163
+ This function convolves an image with a separable kernel.
4164
+
4165
+ Args:
4166
+ img (np.ndarray): Input image.
4167
+ kernel (np.ndarray): Kernel.
4168
+
4169
+ Returns:
4170
+ np.ndarray: Convolved image.
4171
+
4172
+ """
4173
+ conv_fn = maybe_process_in_chunks(cv2.sepFilter2D, ddepth=-1, kernelX=kernel, kernelY=kernel)
4174
+ return conv_fn(img)
4175
+
4176
+
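Because the same 1-D kernel is applied along both axes, passing a 1-D Gaussian reproduces a full 2-D Gaussian blur at lower cost:

```python
import cv2
import numpy as np
from albumentations.augmentations.pixel import functional as F

rng = np.random.default_rng(0)
img = rng.integers(0, 256, (32, 32, 3), dtype=np.uint8)

kernel_1d = cv2.getGaussianKernel(ksize=5, sigma=1.0).astype(np.float32)
blurred = F.separable_convolve(img, kernel=kernel_1d)
assert blurred.shape == img.shape
```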
4177
+ def normalize_dispatch(
4178
+ img: np.ndarray,
4179
+ normalization: Literal["standard", "image", "image_per_channel", "min_max", "min_max_per_channel"],
4180
+ normalize_fn: Callable[[np.ndarray, str], np.ndarray],
4181
+ mean: np.ndarray | None = None,
4182
+ denominator: np.ndarray | None = None,
4183
+ **params: Any,
4184
+ ) -> np.ndarray:
4185
+ """Dispatch normalization to the appropriate method based on normalization type.
4186
+
4187
+ This function acts as a dispatcher that either applies standard normalization using
4188
+ provided mean and standard deviation values, or delegates to a specific normalization
4189
+ function for other normalization types.
4190
+
4191
+ Args:
4192
+ img (np.ndarray): Input data to normalize. Can be:
4193
+ - Single image: (H, W) or (H, W, C)
4194
+ - Batch of images: (N, H, W) or (N, H, W, C)
4195
+ - Single volume: (D, H, W) or (D, H, W, C)
4196
+ - Batch of volumes: (N, D, H, W) or (N, D, H, W, C)
4197
+ normalization (Literal["standard", "image", "image_per_channel", "min_max", "min_max_per_channel"]):
4198
+ Type of normalization to apply:
4199
+ - "standard": Use provided mean and std values
4200
+ - "image": Normalize using global image statistics
4201
+ - "image_per_channel": Normalize each channel separately
4202
+ - "min_max": Scale to [0, 1] using global min/max
4203
+ - "min_max_per_channel": Scale each channel to [0, 1]
4204
+ normalize_fn (Callable[[np.ndarray, str], np.ndarray]): Function to use for non-standard normalization.
4205
+ Should accept (img, normalization_type) as arguments and return normalized array.
4206
+ mean (np.ndarray | None): Mean values for standard normalization.
4207
+ Required when normalization="standard", ignored otherwise.
4208
+ denominator (np.ndarray | None): Reciprocal of standard deviation for standard normalization.
4209
+ Required when normalization="standard", ignored otherwise.
4210
+ **params (Any): Additional parameters passed to the normalization function.
4211
+
4212
+ Returns:
4213
+ np.ndarray: Normalized data with the same shape as input.
4214
+
4215
+ Note:
4216
+ - For standard normalization, the formula used is: (img - mean) * denominator
4217
+ - The denominator is the reciprocal of std for computational efficiency
4218
+ - Channel conversion is handled automatically based on the number of channels
4219
+
4220
+ """
4221
+ if normalization == "standard":
4222
+ num_channels = get_num_channels(img)
4223
+ denominator = convert_value(denominator, num_channels)
4224
+ mean = convert_value(mean, num_channels)
4225
+ return normalize(img, mean, denominator)
4226
+ return normalize_fn(img, normalization)
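A sketch of the dispatch contract, with a toy `normalize_fn` standing in for the library's own fallback (hypothetical, for illustration only):

```python
import numpy as np
from albumentations.augmentations.pixel import functional as F

img = np.random.default_rng(0).random((8, 8, 3)).astype(np.float32)

def normalize_fn(data: np.ndarray, normalization: str) -> np.ndarray:
    # Toy stand-in: only implements the "min_max" strategy
    if normalization == "min_max":
        return (data - data.min()) / (data.max() - data.min())
    raise NotImplementedError(normalization)

# Every strategy except "standard" is delegated to normalize_fn
out = F.normalize_dispatch(img, "min_max", normalize_fn)
assert out.min() == 0.0 and out.max() == 1.0
```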