stouputils 1.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. stouputils/__init__.py +40 -0
  2. stouputils/__main__.py +86 -0
  3. stouputils/_deprecated.py +37 -0
  4. stouputils/all_doctests.py +160 -0
  5. stouputils/applications/__init__.py +22 -0
  6. stouputils/applications/automatic_docs.py +634 -0
  7. stouputils/applications/upscaler/__init__.py +39 -0
  8. stouputils/applications/upscaler/config.py +128 -0
  9. stouputils/applications/upscaler/image.py +247 -0
  10. stouputils/applications/upscaler/video.py +287 -0
  11. stouputils/archive.py +344 -0
  12. stouputils/backup.py +488 -0
  13. stouputils/collections.py +244 -0
  14. stouputils/continuous_delivery/__init__.py +27 -0
  15. stouputils/continuous_delivery/cd_utils.py +243 -0
  16. stouputils/continuous_delivery/github.py +522 -0
  17. stouputils/continuous_delivery/pypi.py +130 -0
  18. stouputils/continuous_delivery/pyproject.py +147 -0
  19. stouputils/continuous_delivery/stubs.py +86 -0
  20. stouputils/ctx.py +408 -0
  21. stouputils/data_science/config/get.py +51 -0
  22. stouputils/data_science/config/set.py +125 -0
  23. stouputils/data_science/data_processing/image/__init__.py +66 -0
  24. stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
  25. stouputils/data_science/data_processing/image/axis_flip.py +58 -0
  26. stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
  27. stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
  28. stouputils/data_science/data_processing/image/blur.py +59 -0
  29. stouputils/data_science/data_processing/image/brightness.py +54 -0
  30. stouputils/data_science/data_processing/image/canny.py +110 -0
  31. stouputils/data_science/data_processing/image/clahe.py +92 -0
  32. stouputils/data_science/data_processing/image/common.py +30 -0
  33. stouputils/data_science/data_processing/image/contrast.py +53 -0
  34. stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
  35. stouputils/data_science/data_processing/image/denoise.py +378 -0
  36. stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
  37. stouputils/data_science/data_processing/image/invert.py +64 -0
  38. stouputils/data_science/data_processing/image/laplacian.py +60 -0
  39. stouputils/data_science/data_processing/image/median_blur.py +52 -0
  40. stouputils/data_science/data_processing/image/noise.py +59 -0
  41. stouputils/data_science/data_processing/image/normalize.py +65 -0
  42. stouputils/data_science/data_processing/image/random_erase.py +66 -0
  43. stouputils/data_science/data_processing/image/resize.py +69 -0
  44. stouputils/data_science/data_processing/image/rotation.py +80 -0
  45. stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
  46. stouputils/data_science/data_processing/image/sharpening.py +55 -0
  47. stouputils/data_science/data_processing/image/shearing.py +64 -0
  48. stouputils/data_science/data_processing/image/threshold.py +64 -0
  49. stouputils/data_science/data_processing/image/translation.py +71 -0
  50. stouputils/data_science/data_processing/image/zoom.py +83 -0
  51. stouputils/data_science/data_processing/image_augmentation.py +118 -0
  52. stouputils/data_science/data_processing/image_preprocess.py +183 -0
  53. stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
  54. stouputils/data_science/data_processing/technique.py +481 -0
  55. stouputils/data_science/dataset/__init__.py +45 -0
  56. stouputils/data_science/dataset/dataset.py +292 -0
  57. stouputils/data_science/dataset/dataset_loader.py +135 -0
  58. stouputils/data_science/dataset/grouping_strategy.py +296 -0
  59. stouputils/data_science/dataset/image_loader.py +100 -0
  60. stouputils/data_science/dataset/xy_tuple.py +696 -0
  61. stouputils/data_science/metric_dictionnary.py +106 -0
  62. stouputils/data_science/metric_utils.py +847 -0
  63. stouputils/data_science/mlflow_utils.py +206 -0
  64. stouputils/data_science/models/abstract_model.py +149 -0
  65. stouputils/data_science/models/all.py +85 -0
  66. stouputils/data_science/models/base_keras.py +765 -0
  67. stouputils/data_science/models/keras/all.py +38 -0
  68. stouputils/data_science/models/keras/convnext.py +62 -0
  69. stouputils/data_science/models/keras/densenet.py +50 -0
  70. stouputils/data_science/models/keras/efficientnet.py +60 -0
  71. stouputils/data_science/models/keras/mobilenet.py +56 -0
  72. stouputils/data_science/models/keras/resnet.py +52 -0
  73. stouputils/data_science/models/keras/squeezenet.py +233 -0
  74. stouputils/data_science/models/keras/vgg.py +42 -0
  75. stouputils/data_science/models/keras/xception.py +38 -0
  76. stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
  77. stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
  78. stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
  79. stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
  80. stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
  81. stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
  82. stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
  83. stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
  84. stouputils/data_science/models/keras_utils/visualizations.py +416 -0
  85. stouputils/data_science/models/model_interface.py +939 -0
  86. stouputils/data_science/models/sandbox.py +116 -0
  87. stouputils/data_science/range_tuple.py +234 -0
  88. stouputils/data_science/scripts/augment_dataset.py +77 -0
  89. stouputils/data_science/scripts/exhaustive_process.py +133 -0
  90. stouputils/data_science/scripts/preprocess_dataset.py +70 -0
  91. stouputils/data_science/scripts/routine.py +168 -0
  92. stouputils/data_science/utils.py +285 -0
  93. stouputils/decorators.py +605 -0
  94. stouputils/image.py +441 -0
  95. stouputils/installer/__init__.py +18 -0
  96. stouputils/installer/common.py +67 -0
  97. stouputils/installer/downloader.py +101 -0
  98. stouputils/installer/linux.py +144 -0
  99. stouputils/installer/main.py +223 -0
  100. stouputils/installer/windows.py +136 -0
  101. stouputils/io.py +486 -0
  102. stouputils/parallel.py +483 -0
  103. stouputils/print.py +482 -0
  104. stouputils/py.typed +1 -0
  105. stouputils/stouputils/__init__.pyi +15 -0
  106. stouputils/stouputils/_deprecated.pyi +12 -0
  107. stouputils/stouputils/all_doctests.pyi +46 -0
  108. stouputils/stouputils/applications/__init__.pyi +2 -0
  109. stouputils/stouputils/applications/automatic_docs.pyi +106 -0
  110. stouputils/stouputils/applications/upscaler/__init__.pyi +3 -0
  111. stouputils/stouputils/applications/upscaler/config.pyi +18 -0
  112. stouputils/stouputils/applications/upscaler/image.pyi +109 -0
  113. stouputils/stouputils/applications/upscaler/video.pyi +60 -0
  114. stouputils/stouputils/archive.pyi +67 -0
  115. stouputils/stouputils/backup.pyi +109 -0
  116. stouputils/stouputils/collections.pyi +86 -0
  117. stouputils/stouputils/continuous_delivery/__init__.pyi +5 -0
  118. stouputils/stouputils/continuous_delivery/cd_utils.pyi +129 -0
  119. stouputils/stouputils/continuous_delivery/github.pyi +162 -0
  120. stouputils/stouputils/continuous_delivery/pypi.pyi +53 -0
  121. stouputils/stouputils/continuous_delivery/pyproject.pyi +67 -0
  122. stouputils/stouputils/continuous_delivery/stubs.pyi +39 -0
  123. stouputils/stouputils/ctx.pyi +211 -0
  124. stouputils/stouputils/decorators.pyi +252 -0
  125. stouputils/stouputils/image.pyi +172 -0
  126. stouputils/stouputils/installer/__init__.pyi +5 -0
  127. stouputils/stouputils/installer/common.pyi +39 -0
  128. stouputils/stouputils/installer/downloader.pyi +24 -0
  129. stouputils/stouputils/installer/linux.pyi +39 -0
  130. stouputils/stouputils/installer/main.pyi +57 -0
  131. stouputils/stouputils/installer/windows.pyi +31 -0
  132. stouputils/stouputils/io.pyi +213 -0
  133. stouputils/stouputils/parallel.pyi +216 -0
  134. stouputils/stouputils/print.pyi +136 -0
  135. stouputils/stouputils/version_pkg.pyi +15 -0
  136. stouputils/version_pkg.py +189 -0
  137. stouputils-1.14.0.dist-info/METADATA +178 -0
  138. stouputils-1.14.0.dist-info/RECORD +140 -0
  139. stouputils-1.14.0.dist-info/WHEEL +4 -0
  140. stouputils-1.14.0.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,125 @@
1
+ """ Configuration file for the project. """
2
+
3
+ # Imports
4
+ import os
5
+ from typing import Literal
6
+
7
+ from stouputils.decorators import LogLevels
8
+ from stouputils.io import get_root_path
9
+
10
# Environment variables, set as soon as this module is imported
os.environ.update({
    "TF_CPP_MIN_LOG_LEVEL": "9",  # Suppress TensorFlow logging
    "GRPC_VERBOSITY": "ERROR",  # Suppress gRPC logging
})
13
+
14
+
15
# Configuration class
class DataScienceConfig:
    """ Configuration class for the project.

    Every setting is a class attribute, so values are shared by every user of the class
    and can be overridden by plain assignment (e.g. ``DataScienceConfig.SEED = 0``).
    The directory settings are all derived from ``ROOT``; call ``update_root`` to change
    ``ROOT`` so that every dependent path is recomputed together.
    """

    # Common
    SEED: int = 42
    """ Seed for the random number generator. """

    ERROR_LOG: LogLevels = LogLevels.WARNING_TRACEBACK
    """ Log level for errors for all functions. """

    AUGMENTED_FILE_SUFFIX: str = "_aug_"
    """ Suffix for augmented files, e.g. 'image_008_aug_1.png'. """

    AUGMENTED_DIRECTORY_PREFIX: str = "aug_"
    """ Prefix for augmented directories, e.g. 'data/hip_implant' -> 'data/aug_hip_implant'. """

    PREPROCESSED_DIRECTORY_SUFFIX: str = "_preprocessed"
    """ Suffix for preprocessed directories, e.g. 'data/hip_implant' -> 'data/hip_implant_preprocessed'. """


    # Directories (all derived from ROOT, see update_root)
    ROOT: str = get_root_path(__file__, go_up=3)
    """ Root directory of the project. """

    MLFLOW_FOLDER: str = f"{ROOT}/mlruns"
    """ Folder containing the mlflow data. """
    MLFLOW_URI: str = f"file://{MLFLOW_FOLDER}"
    """ URI to the mlflow data. """

    DATA_FOLDER: str = f"{ROOT}/data"
    """ Folder containing all the data (e.g. subfolders containing images). """

    TEMP_FOLDER: str = f"{ROOT}/temp"
    """ Folder containing temporary files (e.g. model checkpoints, plots, etc.). """

    LOGS_FOLDER: str = f"{ROOT}/logs"
    """ Folder containing the logs. """

    TENSORBOARD_FOLDER: str = f"{ROOT}/tensorboard"
    """ Folder containing the tensorboard logs. """


    # Behaviours
    TEST_SIZE: float = 0.2
    """ Size of the test set by default (0.2 means 80% training, 20% test). """

    VALIDATION_SIZE: float = 0.2
    """ Size of the validation set by default (0.2 means 80% training, 20% validation). """

    # Machine learning
    SAVE_MODEL: bool = False
    """ If the model should be saved in the mlflow folder using mlflow.*.save_model. """

    DO_SALIENCY_AND_GRADCAM: bool = True
    """ If the saliency and gradcam should be done during the run. """

    DO_LEARNING_RATE_FINDER: Literal[0, 1, 2] = 1
    """ If the learning rate finder should be done during the run.
    0: no, 1: only plot, 2: plot and use value for the remaining run
    """

    DO_UNFREEZE_FINDER: Literal[0, 1, 2] = 0
    """ If the unfreeze finder should be done during the run.
    0: no, 1: only plot, 2: plot and use value for the remaining run
    """

    DO_FIT_IN_SUBPROCESS: bool = True
    """ If the model should be fitted in a subprocess.
    It is memory efficient and more stable. Turn it off only if you are having issues.

    Note: This allows a program to make lots of runs without getting killed by the OS for using too many resources
    (e.g. LeaveOneOut Cross Validation, Grid Search, etc.).
    """

    MIXED_PRECISION_POLICY: Literal["mixed_float16", "mixed_bfloat16", "float32"] = "mixed_float16"
    """ Mixed precision policy to use. Switch back to "float32" if you are having issues with a specific model or metrics.
    See: https://www.tensorflow.org/guide/mixed_precision
    """

    TENSORFLOW_DEVICE: str = "/gpu:1"
    """ TensorFlow device to use.
    Note: "/gpu:1" designates the second visible GPU; use "/gpu:0" on a single-GPU machine.
    """



    @classmethod
    def update_root(cls, new_root: str) -> None:
        """ Update the root directory and recalculate all dependent paths.

        Every folder attribute (and MLFLOW_URI) is rebuilt from ``new_root`` exactly as the
        class body builds it from the initial ROOT, including the Windows URI fix.

        Args:
            new_root: The new root directory path.
        """
        cls.ROOT = new_root

        # Update all paths that depend on ROOT
        cls.MLFLOW_FOLDER = f"{cls.ROOT}/mlruns"
        cls.MLFLOW_URI = f"file://{cls.MLFLOW_FOLDER}"
        cls.DATA_FOLDER = f"{cls.ROOT}/data"
        cls.TEMP_FOLDER = f"{cls.ROOT}/temp"
        cls.LOGS_FOLDER = f"{cls.ROOT}/logs"
        cls.TENSORBOARD_FOLDER = f"{cls.ROOT}/tensorboard"

        # Fix MLFLOW_URI for Windows by adding a missing slash ("file://..." -> "file:///...")
        if os.name == "nt":
            cls.MLFLOW_URI = cls.MLFLOW_URI.replace("file:", "file:/")
120
+
121
+
122
# Windows only: the class body builds "file://<folder>", which is missing a slash,
# so turn it into "file:///<folder>" (update_root applies the same fix itself)
if os.name == "nt":
    DataScienceConfig.MLFLOW_URI = DataScienceConfig.MLFLOW_URI.replace(
        "file:", "file:/"
    )
125
+
@@ -0,0 +1,66 @@
1
+
2
+ # Imports
3
+ from .auto_contrast import auto_contrast_image
4
+ from .axis_flip import flip_image
5
+ from .bias_field_correction import bias_field_correction_image
6
+ from .binary_threshold import binary_threshold_image
7
+ from .blur import blur_image
8
+ from .brightness import brightness_image
9
+ from .canny import canny_image
10
+ from .clahe import clahe_image
11
+ from .contrast import contrast_image
12
+ from .curvature_flow_filter import curvature_flow_filter_image
13
+ from .denoise import (
14
+ adaptive_denoise_image,
15
+ bilateral_denoise_image,
16
+ nlm_denoise_image,
17
+ tv_denoise_image,
18
+ wavelet_denoise_image,
19
+ )
20
+ from .invert import invert_image
21
+ from .laplacian import laplacian_image
22
+ from .median_blur import median_blur_image
23
+ from .noise import noise_image
24
+ from .normalize import normalize_image
25
+ from .random_erase import random_erase_image
26
+ from .resize import resize_image
27
+ from .rotation import rotate_image
28
+ from .salt_pepper import salt_pepper_image
29
+ from .sharpening import sharpen_image
30
+ from .shearing import shear_image
31
+ from .threshold import threshold_image
32
+ from .translation import translate_image
33
+ from .zoom import zoom_image
34
+
35
+ __all__ = [
36
+ "adaptive_denoise_image",
37
+ "auto_contrast_image",
38
+ "bias_field_correction_image",
39
+ "bilateral_denoise_image",
40
+ "binary_threshold_image",
41
+ "blur_image",
42
+ "brightness_image",
43
+ "canny_image",
44
+ "clahe_image",
45
+ "contrast_image",
46
+ "curvature_flow_filter_image",
47
+ "flip_image",
48
+ "invert_image",
49
+ "laplacian_image",
50
+ "median_blur_image",
51
+ "nlm_denoise_image",
52
+ "noise_image",
53
+ "normalize_image",
54
+ "random_erase_image",
55
+ "resize_image",
56
+ "rotate_image",
57
+ "salt_pepper_image",
58
+ "sharpen_image",
59
+ "shear_image",
60
+ "threshold_image",
61
+ "translate_image",
62
+ "tv_denoise_image",
63
+ "wavelet_denoise_image",
64
+ "zoom_image",
65
+ ]
66
+
@@ -0,0 +1,79 @@
1
+
2
+ # pyright: reportUnusedImport=false
3
+ # ruff: noqa: F401
4
+
5
+ # Imports
6
+ from .common import Any, NDArray, check_image, cv2, np
7
+
8
+
9
+ # Functions
10
def auto_contrast_image(
    image: NDArray[Any], ignore_dtype: bool = False, clip_hist_percent: float = 1.0
) -> NDArray[Any]:
    """ Automatically stretch the contrast of an image using histogram clipping.

    The cumulative histogram of the first channel is computed. ``clip_hist_percent`` percent
    of the pixels are clipped (half at the dark end, half at the bright end), and the remaining
    gray range is stretched linearly to [0, 255] with ``cv2.convertScaleAbs``.

    Args:
        image (NDArray[Any]): Image to adjust contrast
        ignore_dtype (bool): Ignore the dtype check
        clip_hist_percent (float): Total percentage of pixels to clip, split evenly between
            both ends of the histogram (between 0 and 100, defaults to 1.0)
    Returns:
        NDArray[Any]: Image with adjusted contrast. The input image itself is returned when it
            is empty or when no gray range remains after clipping (e.g. a uniform image).

    >>> ## Basic tests
    >>> image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.uint8)
    >>> adjusted = auto_contrast_image(image)
    >>> adjusted.tolist()
    [[0, 36, 73], [109, 146, 182], [219, 255, 255]]
    >>> adjusted.shape == image.shape
    True
    >>> adjusted.dtype == image.dtype
    True

    >>> ## Uniform images are returned unchanged
    >>> flat = np.full((3, 3), 100, dtype=np.uint8)
    >>> auto_contrast_image(flat).tolist()
    [[100, 100, 100], [100, 100, 100], [100, 100, 100]]

    >>> ## Test invalid inputs
    >>> auto_contrast_image("not an image")
    Traceback (most recent call last):
        ...
    AssertionError: Image must be a numpy array
    """
    # Check input data
    check_image(image, ignore_dtype=ignore_dtype)
    assert isinstance(clip_hist_percent, float | int), (
        f"clip_hist_percent must be a number, got {type(clip_hist_percent)}"
    )
    assert 0 <= clip_hist_percent <= 100, (
        f"clip_hist_percent must be between 0 and 100, got {clip_hist_percent}"
    )
    if image.size == 0:
        return image

    # Cumulative histogram of the first channel; float64 keeps the pixel counts exact
    hist: NDArray[Any] = cv2.calcHist([image], [0], None, [256], [0, 256])
    accumulator: NDArray[np.float64] = np.cumsum(hist.ravel(), dtype=np.float64)
    max_value: float = float(accumulator[-1])

    # Number of pixels to clip at each end of the histogram
    clip: float = clip_hist_percent * (max_value / 100.0) / 2.0

    # The accumulator is non-decreasing, so binary search finds the clipping bounds:
    # - min_gray: first gray level whose cumulative count reaches `clip`
    # - max_gray: last gray level whose cumulative count stays below `max_value - clip` (-1 if none)
    min_gray: int = int(np.searchsorted(accumulator, clip, side="left"))
    max_gray: int = int(np.searchsorted(accumulator, max_value - clip, side="left")) - 1

    # Calculate the input range after clipping
    input_range: int = max_gray - min_gray

    # No usable range (e.g. a uniform image gives max_gray == min_gray - 1): stretching would
    # use a non-positive scale and wipe the image, so return it unchanged
    if input_range <= 0:
        return image

    # Linear mapping that sends min_gray to 0 and max_gray to 255
    alpha: float = 255 / input_range
    beta: float = -min_gray * alpha

    # Apply the contrast adjustment (convertScaleAbs saturates the result to 8-bit)
    adjusted: NDArray[Any] = cv2.convertScaleAbs(image, alpha=alpha, beta=beta)
    return adjusted
79
+
@@ -0,0 +1,58 @@
1
+
2
+ # pyright: reportUnusedImport=false
3
+ # ruff: noqa: F401
4
+
5
+ # Imports
6
+ from typing import Literal
7
+
8
+ from .common import Any, NDArray, check_image, cv2, np
9
+
10
+
11
+ # Functions
12
def flip_image(
    image: NDArray[Any], axis: Literal["horizontal", "vertical", "both"], ignore_dtype: bool = True
) -> NDArray[Any]:
    """ Mirror an image left/right, top/bottom, or both.

    Args:
        image (NDArray[Any]): Image to flip
        axis (str): Flip direction, one of "horizontal" (left/right), "vertical" (top/bottom) or "both"
        ignore_dtype (bool): Ignore the dtype check (defaults to True)
    Returns:
        NDArray[Any]: Flipped image

    >>> ## Basic tests
    >>> image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    >>> flip_image(image, "horizontal").tolist()
    [[3, 2, 1], [6, 5, 4], [9, 8, 7]]

    >>> flip_image(image, "vertical").tolist()
    [[7, 8, 9], [4, 5, 6], [1, 2, 3]]

    >>> flip_image(image, "both").tolist()
    [[9, 8, 7], [6, 5, 4], [3, 2, 1]]

    >>> ## Test invalid inputs
    >>> flip_image(image, "diagonal")
    Traceback (most recent call last):
        ...
    AssertionError: axis must be either 'horizontal' or 'vertical' or 'both', got 'diagonal'

    >>> flip_image("not an image", "horizontal")
    Traceback (most recent call last):
        ...
    AssertionError: Image must be a numpy array
    """
    # Validate the image and the requested direction
    check_image(image, ignore_dtype=ignore_dtype)
    assert axis in ("horizontal", "vertical", "both"), (
        f"axis must be either 'horizontal' or 'vertical' or 'both', got '{axis}'"
    )

    # cv2.flip codes: 1 flips around the y-axis, 0 around the x-axis, -1 around both
    # (anything other than "horizontal"/"vertical" falls back to -1)
    flip_code: int = {"horizontal": 1, "vertical": 0}.get(axis, -1)
    return cv2.flip(image, flip_code)
58
+
@@ -0,0 +1,74 @@
1
+
2
+ # pyright: reportUnknownMemberType=false
3
+ # pyright: reportUnknownArgumentType=false
4
+ # pyright: reportUnknownVariableType=false
5
+
6
+ # Imports
7
+ import SimpleITK as Sitk
8
+
9
+ from .common import Any, NDArray, check_image, np
10
+
11
+
12
+ # Functions
13
def bias_field_correction_image(image: NDArray[Any], ignore_dtype: bool = False) -> NDArray[Any]:
    """ Apply an N4 bias field correction to an image.

    A foreground mask is built with Li thresholding on the intensity-rescaled image. The bias
    field is then estimated by ``N4BiasFieldCorrectionImageFilter`` on a 4x downsampled copy.
    Finally, the full-resolution image is divided by the exponential of the estimated log bias
    field.

    Args:
        image (NDArray[Any]): Image to apply the bias field correction to. 3D images are averaged
            over their last axis first. Non-floating images (e.g. uint8 with ignore_dtype=True)
            are converted to float64, because N4 requires a real pixel type.
        ignore_dtype (bool): Ignore the dtype check
    Returns:
        NDArray[Any]: Bias-field-corrected image (floating point, same pixel type as the
            possibly converted input)

    >>> ## Basic tests
    >>> image = np.random.randint(0, 255, size=(10,10), dtype=np.uint8) / 255
    >>> corrected = bias_field_correction_image(image)
    >>> corrected.shape == image.shape
    True
    >>> corrected.dtype == np.float64
    True

    >>> ## Test invalid inputs
    >>> bias_field_correction_image("not an image")
    Traceback (most recent call last):
        ...
    AssertionError: Image must be a numpy array
    """
    # Check input data
    check_image(image, ignore_dtype=ignore_dtype)

    # If the image is 3D, convert to grayscale first
    if image.ndim == 3:
        image = np.mean(image, axis=-1)

    # N4 only works on real-valued pixel types
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float64)

    # Convert numpy array to SimpleITK image
    image_sitk: Sitk.Image = Sitk.GetImageFromArray(image)

    # Create a binary mask of the foreground region
    transformed: Sitk.Image = Sitk.RescaleIntensity(image_sitk)  # Normalize intensities
    # (0, 1) are passed as LiThreshold's inside/outside values.
    # NOTE(review): presumably this marks the brighter foreground as 1 (N4's default mask label) — confirm
    head_mask: Sitk.Image = Sitk.LiThreshold(transformed, 0, 1)

    # Downsample images to speed up bias field estimation
    shrink_factor: int = 4  # Reduce image size by a factor of 4 in every dimension
    shrink_factors: list[int] = [shrink_factor] * image_sitk.GetDimension()
    input_image: Sitk.Image = Sitk.Shrink(image_sitk, shrink_factors)
    mask_image: Sitk.Image = Sitk.Shrink(head_mask, shrink_factors)

    # Estimate the bias field with N4 on the downsampled image
    corrector = Sitk.N4BiasFieldCorrectionImageFilter()
    corrector.Execute(input_image, mask_image)

    # Get the log bias field at full resolution, cast to the input's pixel type
    # (SimpleITK arithmetic requires both operands to share the same pixel type)
    log_bias_field: Sitk.Image = Sitk.Cast(
        corrector.GetLogBiasFieldAsImage(image_sitk), image_sitk.GetPixelID()
    )
    corrected_image_full_resolution: Sitk.Image = image_sitk / Sitk.Exp(log_bias_field)

    # Convert back to numpy array and return
    return Sitk.GetArrayFromImage(corrected_image_full_resolution)
74
+
@@ -0,0 +1,73 @@
1
+
2
+ # pyright: reportUnusedImport=false
3
+ # ruff: noqa: F401
4
+
5
+ # Imports
6
+ from .common import Any, NDArray, check_image, cv2, np
7
+
8
+
9
+ # Functions
10
def binary_threshold_image(image: NDArray[Any], threshold: float, ignore_dtype: bool = False) -> NDArray[Any]:
    """ Binarize an image: pixels strictly above the threshold become 255, all others 0.

    Multi-channel images are thresholded on their grayscale version, and the binary result
    is expanded back to three identical BGR channels.

    Args:
        image (NDArray[Any]): Image to threshold
        threshold (float): Threshold value as a fraction of 255 (between 0 and 1)
        ignore_dtype (bool): Ignore the dtype check
    Returns:
        NDArray[Any]: Thresholded binary image

    >>> ## Basic tests
    >>> image = np.array([[100, 150, 200], [50, 125, 175], [25, 75, 225]])
    >>> binary_threshold_image(image.astype(np.uint8), 0.5).tolist()
    [[0, 255, 255], [0, 0, 255], [0, 0, 255]]

    >>> np.random.seed(42)
    >>> img = np.random.randint(0, 256, (4,4), dtype=np.uint8)
    >>> thresholded = binary_threshold_image(img, 0.7)
    >>> set(np.unique(thresholded).tolist()) <= {0, 255} # Should only contain 0 and 255
    True

    >>> rgb = np.random.randint(0, 256, (3,3,3), dtype=np.uint8)
    >>> thresh_rgb = binary_threshold_image(rgb, 0.5)
    >>> thresh_rgb.shape == rgb.shape
    True
    >>> set(np.unique(thresh_rgb).tolist()) <= {0, 255}
    True

    >>> ## Test invalid inputs
    >>> binary_threshold_image("not an image", 0.5)
    Traceback (most recent call last):
        ...
    AssertionError: Image must be a numpy array

    >>> binary_threshold_image(image.astype(np.uint8), "0.5")
    Traceback (most recent call last):
        ...
    AssertionError: threshold must be a number, got <class 'str'>

    >>> binary_threshold_image(image.astype(np.uint8), 1.5)
    Traceback (most recent call last):
        ...
    AssertionError: threshold must be between 0 and 1, got 1.5
    """
    # Validate inputs
    check_image(image, ignore_dtype=ignore_dtype)
    assert isinstance(threshold, float | int), f"threshold must be a number, got {type(threshold)}"
    assert 0 <= threshold <= 1, f"threshold must be between 0 and 1, got {threshold}"

    # Map the [0, 1] fraction onto the 8-bit intensity scale
    cutoff: int = int(threshold * 255)

    # Anything that is not 2D is treated as a BGR color image
    is_color: bool = image.ndim != 2
    source: NDArray[Any] = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if is_color else image

    mask: NDArray[Any]
    _, mask = cv2.threshold(source, cutoff, 255, cv2.THRESH_BINARY)
    return cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) if is_color else mask
73
+
@@ -0,0 +1,59 @@
1
+
2
+ # pyright: reportUnusedImport=false
3
+ # ruff: noqa: F401
4
+
5
+ # Imports
6
+ from .common import Any, NDArray, check_image, cv2, np
7
+
8
+
9
+ # Functions
10
def blur_image(image: NDArray[Any], blur_strength: float, ignore_dtype: bool = False) -> NDArray[Any]:
    """ Blur an image with a Gaussian kernel whose size grows with ``blur_strength``.

    Args:
        image (NDArray[Any]): Image to blur
        blur_strength (float): Strength of the blur (kernel side is about 2 * strength + 1, at least 3)
        ignore_dtype (bool): Ignore the dtype check
    Returns:
        NDArray[Any]: Blurred image

    >>> ## Basic tests
    >>> image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    >>> blurred = blur_image(image.astype(np.uint8), 1.5)
    >>> blurred.shape == image.shape
    True

    >>> img = np.zeros((5,5), dtype=np.uint8)
    >>> img[2,2] = 255 # Single bright pixel
    >>> blurred = blur_image(img, 1.0)
    >>> bool(blurred[2,2] < 255) # Center should be blurred
    True

    >>> rgb = np.full((3,3,3), 128, dtype=np.uint8)
    >>> blurred_rgb = blur_image(rgb, 1.0)
    >>> blurred_rgb.shape == (3,3,3)
    True

    >>> ## Test invalid inputs
    >>> blur_image("not an image", 1.5)
    Traceback (most recent call last):
        ...
    AssertionError: Image must be a numpy array

    >>> blur_image(image.astype(np.uint8), "1.5")
    Traceback (most recent call last):
        ...
    AssertionError: blur_strength must be a number, got <class 'str'>
    """
    # Validate inputs
    check_image(image, ignore_dtype=ignore_dtype)
    assert isinstance(blur_strength, float | int), f"blur_strength must be a number, got {type(blur_strength)}"

    # GaussianBlur needs a positive odd kernel side: take at least 3, then force the low bit on
    ksize: int = max(3, int(blur_strength * 2) + 1) | 1

    # sigma=0 lets OpenCV derive the standard deviation from the kernel size
    return cv2.GaussianBlur(image, (ksize, ksize), 0)
59
+
@@ -0,0 +1,54 @@
1
+
2
+ # pyright: reportUnusedImport=false
3
+ # ruff: noqa: F401
4
+
5
+ # Imports
6
+ from .common import Any, NDArray, check_image, cv2, np
7
+
8
+
9
+ # Functions
10
def brightness_image(image: NDArray[Any], brightness_factor: float, ignore_dtype: bool = False) -> NDArray[Any]:
    """ Scale the intensities of an image by a brightness factor.

    Args:
        image (NDArray[Any]): Image to adjust brightness
        brightness_factor (float): Multiplier applied to every pixel (> 1 brightens, < 1 darkens)
        ignore_dtype (bool): Ignore the dtype check
    Returns:
        NDArray[Any]: Image with adjusted brightness

    >>> ## Basic tests
    >>> image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    >>> brightened = brightness_image(image.astype(np.uint8), 1.5)
    >>> brightened.shape == image.shape
    True

    >>> img = np.full((3,3), 100, dtype=np.uint8)
    >>> bright = brightness_image(img, 2.0)
    >>> dark = brightness_image(img, 0.5)
    >>> bool(np.mean(bright) > np.mean(img) > np.mean(dark))
    True

    >>> rgb = np.full((3,3,3), 128, dtype=np.uint8)
    >>> bright_rgb = brightness_image(rgb, 1.5)
    >>> bright_rgb.shape == (3,3,3)
    True

    >>> ## Test invalid inputs
    >>> brightness_image("not an image", 1.5)
    Traceback (most recent call last):
        ...
    AssertionError: Image must be a numpy array

    >>> brightness_image(image.astype(np.uint8), "1.5")
    Traceback (most recent call last):
        ...
    AssertionError: brightness_factor must be a number, got <class 'str'>
    """
    # Validate inputs
    check_image(image, ignore_dtype=ignore_dtype)
    assert isinstance(brightness_factor, float | int), f"brightness_factor must be a number, got {type(brightness_factor)}"

    # convertScaleAbs computes |image * alpha + beta| and saturates the result to 8-bit
    brightened: NDArray[Any] = cv2.convertScaleAbs(image, alpha=brightness_factor, beta=0)
    return brightened
54
+