stouputils 1.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stouputils/__init__.py +40 -0
- stouputils/__main__.py +86 -0
- stouputils/_deprecated.py +37 -0
- stouputils/all_doctests.py +160 -0
- stouputils/applications/__init__.py +22 -0
- stouputils/applications/automatic_docs.py +634 -0
- stouputils/applications/upscaler/__init__.py +39 -0
- stouputils/applications/upscaler/config.py +128 -0
- stouputils/applications/upscaler/image.py +247 -0
- stouputils/applications/upscaler/video.py +287 -0
- stouputils/archive.py +344 -0
- stouputils/backup.py +488 -0
- stouputils/collections.py +244 -0
- stouputils/continuous_delivery/__init__.py +27 -0
- stouputils/continuous_delivery/cd_utils.py +243 -0
- stouputils/continuous_delivery/github.py +522 -0
- stouputils/continuous_delivery/pypi.py +130 -0
- stouputils/continuous_delivery/pyproject.py +147 -0
- stouputils/continuous_delivery/stubs.py +86 -0
- stouputils/ctx.py +408 -0
- stouputils/data_science/config/get.py +51 -0
- stouputils/data_science/config/set.py +125 -0
- stouputils/data_science/data_processing/image/__init__.py +66 -0
- stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
- stouputils/data_science/data_processing/image/axis_flip.py +58 -0
- stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
- stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
- stouputils/data_science/data_processing/image/blur.py +59 -0
- stouputils/data_science/data_processing/image/brightness.py +54 -0
- stouputils/data_science/data_processing/image/canny.py +110 -0
- stouputils/data_science/data_processing/image/clahe.py +92 -0
- stouputils/data_science/data_processing/image/common.py +30 -0
- stouputils/data_science/data_processing/image/contrast.py +53 -0
- stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
- stouputils/data_science/data_processing/image/denoise.py +378 -0
- stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
- stouputils/data_science/data_processing/image/invert.py +64 -0
- stouputils/data_science/data_processing/image/laplacian.py +60 -0
- stouputils/data_science/data_processing/image/median_blur.py +52 -0
- stouputils/data_science/data_processing/image/noise.py +59 -0
- stouputils/data_science/data_processing/image/normalize.py +65 -0
- stouputils/data_science/data_processing/image/random_erase.py +66 -0
- stouputils/data_science/data_processing/image/resize.py +69 -0
- stouputils/data_science/data_processing/image/rotation.py +80 -0
- stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
- stouputils/data_science/data_processing/image/sharpening.py +55 -0
- stouputils/data_science/data_processing/image/shearing.py +64 -0
- stouputils/data_science/data_processing/image/threshold.py +64 -0
- stouputils/data_science/data_processing/image/translation.py +71 -0
- stouputils/data_science/data_processing/image/zoom.py +83 -0
- stouputils/data_science/data_processing/image_augmentation.py +118 -0
- stouputils/data_science/data_processing/image_preprocess.py +183 -0
- stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
- stouputils/data_science/data_processing/technique.py +481 -0
- stouputils/data_science/dataset/__init__.py +45 -0
- stouputils/data_science/dataset/dataset.py +292 -0
- stouputils/data_science/dataset/dataset_loader.py +135 -0
- stouputils/data_science/dataset/grouping_strategy.py +296 -0
- stouputils/data_science/dataset/image_loader.py +100 -0
- stouputils/data_science/dataset/xy_tuple.py +696 -0
- stouputils/data_science/metric_dictionnary.py +106 -0
- stouputils/data_science/metric_utils.py +847 -0
- stouputils/data_science/mlflow_utils.py +206 -0
- stouputils/data_science/models/abstract_model.py +149 -0
- stouputils/data_science/models/all.py +85 -0
- stouputils/data_science/models/base_keras.py +765 -0
- stouputils/data_science/models/keras/all.py +38 -0
- stouputils/data_science/models/keras/convnext.py +62 -0
- stouputils/data_science/models/keras/densenet.py +50 -0
- stouputils/data_science/models/keras/efficientnet.py +60 -0
- stouputils/data_science/models/keras/mobilenet.py +56 -0
- stouputils/data_science/models/keras/resnet.py +52 -0
- stouputils/data_science/models/keras/squeezenet.py +233 -0
- stouputils/data_science/models/keras/vgg.py +42 -0
- stouputils/data_science/models/keras/xception.py +38 -0
- stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
- stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
- stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
- stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
- stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
- stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
- stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
- stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
- stouputils/data_science/models/keras_utils/visualizations.py +416 -0
- stouputils/data_science/models/model_interface.py +939 -0
- stouputils/data_science/models/sandbox.py +116 -0
- stouputils/data_science/range_tuple.py +234 -0
- stouputils/data_science/scripts/augment_dataset.py +77 -0
- stouputils/data_science/scripts/exhaustive_process.py +133 -0
- stouputils/data_science/scripts/preprocess_dataset.py +70 -0
- stouputils/data_science/scripts/routine.py +168 -0
- stouputils/data_science/utils.py +285 -0
- stouputils/decorators.py +605 -0
- stouputils/image.py +441 -0
- stouputils/installer/__init__.py +18 -0
- stouputils/installer/common.py +67 -0
- stouputils/installer/downloader.py +101 -0
- stouputils/installer/linux.py +144 -0
- stouputils/installer/main.py +223 -0
- stouputils/installer/windows.py +136 -0
- stouputils/io.py +486 -0
- stouputils/parallel.py +483 -0
- stouputils/print.py +482 -0
- stouputils/py.typed +1 -0
- stouputils/stouputils/__init__.pyi +15 -0
- stouputils/stouputils/_deprecated.pyi +12 -0
- stouputils/stouputils/all_doctests.pyi +46 -0
- stouputils/stouputils/applications/__init__.pyi +2 -0
- stouputils/stouputils/applications/automatic_docs.pyi +106 -0
- stouputils/stouputils/applications/upscaler/__init__.pyi +3 -0
- stouputils/stouputils/applications/upscaler/config.pyi +18 -0
- stouputils/stouputils/applications/upscaler/image.pyi +109 -0
- stouputils/stouputils/applications/upscaler/video.pyi +60 -0
- stouputils/stouputils/archive.pyi +67 -0
- stouputils/stouputils/backup.pyi +109 -0
- stouputils/stouputils/collections.pyi +86 -0
- stouputils/stouputils/continuous_delivery/__init__.pyi +5 -0
- stouputils/stouputils/continuous_delivery/cd_utils.pyi +129 -0
- stouputils/stouputils/continuous_delivery/github.pyi +162 -0
- stouputils/stouputils/continuous_delivery/pypi.pyi +53 -0
- stouputils/stouputils/continuous_delivery/pyproject.pyi +67 -0
- stouputils/stouputils/continuous_delivery/stubs.pyi +39 -0
- stouputils/stouputils/ctx.pyi +211 -0
- stouputils/stouputils/decorators.pyi +252 -0
- stouputils/stouputils/image.pyi +172 -0
- stouputils/stouputils/installer/__init__.pyi +5 -0
- stouputils/stouputils/installer/common.pyi +39 -0
- stouputils/stouputils/installer/downloader.pyi +24 -0
- stouputils/stouputils/installer/linux.pyi +39 -0
- stouputils/stouputils/installer/main.pyi +57 -0
- stouputils/stouputils/installer/windows.pyi +31 -0
- stouputils/stouputils/io.pyi +213 -0
- stouputils/stouputils/parallel.pyi +216 -0
- stouputils/stouputils/print.pyi +136 -0
- stouputils/stouputils/version_pkg.pyi +15 -0
- stouputils/version_pkg.py +189 -0
- stouputils-1.14.0.dist-info/METADATA +178 -0
- stouputils-1.14.0.dist-info/RECORD +140 -0
- stouputils-1.14.0.dist-info/WHEEL +4 -0
- stouputils-1.14.0.dist-info/entry_points.txt +3 -0

stouputils/data_science/config/set.py
@@ -0,0 +1,125 @@
""" Configuration file for the project. """

# Imports
import os
from typing import Literal

from stouputils.decorators import LogLevels
from stouputils.io import get_root_path

# Environment variables
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "9"  # Suppress TensorFlow logging
os.environ["GRPC_VERBOSITY"] = "ERROR"  # Suppress gRPC logging


# Configuration class
class DataScienceConfig:
    """ Configuration class for the project. """

    # Common
    SEED: int = 42
    """ Seed for the random number generator. """

    ERROR_LOG: LogLevels = LogLevels.WARNING_TRACEBACK
    """ Log level for errors for all functions. """

    AUGMENTED_FILE_SUFFIX: str = "_aug_"
    """ Suffix for augmented files, e.g. 'image_008_aug_1.png'. """

    AUGMENTED_DIRECTORY_PREFIX: str = "aug_"
    """ Prefix for augmented directories, e.g. 'data/hip_implant' -> 'data/aug_hip_implant'. """

    PREPROCESSED_DIRECTORY_SUFFIX: str = "_preprocessed"
    """ Suffix for preprocessed directories, e.g. 'data/hip_implant' -> 'data/hip_implant_preprocessed'. """


    # Directories
    ROOT: str = get_root_path(__file__, go_up=3)
    """ Root directory of the project. """

    MLFLOW_FOLDER: str = f"{ROOT}/mlruns"
    """ Folder containing the mlflow data. """
    MLFLOW_URI: str = f"file://{MLFLOW_FOLDER}"
    """ URI to the mlflow data. """

    DATA_FOLDER: str = f"{ROOT}/data"
    """ Folder containing all the data (e.g. subfolders containing images). """

    TEMP_FOLDER: str = f"{ROOT}/temp"
    """ Folder containing temporary files (e.g. model checkpoints, plots, etc.). """

    LOGS_FOLDER: str = f"{ROOT}/logs"
    """ Folder containing the logs. """

    TENSORBOARD_FOLDER: str = f"{ROOT}/tensorboard"
    """ Folder containing the tensorboard logs. """


    # Behaviours
    TEST_SIZE: float = 0.2
    """ Size of the test set by default (0.2 means 80% training, 20% test). """

    VALIDATION_SIZE: float = 0.2
    """ Size of the validation set by default (0.2 means 80% training, 20% validation). """

    # Machine learning
    SAVE_MODEL: bool = False
    """ If the model should be saved in the mlflow folder using mlflow.*.save_model. """

    DO_SALIENCY_AND_GRADCAM: bool = True
    """ If the saliency and gradcam should be done during the run. """

    DO_LEARNING_RATE_FINDER: Literal[0, 1, 2] = 1
    """ If the learning rate finder should be done during the run.
    0: no, 1: only plot, 2: plot and use value for the remaining run
    """

    DO_UNFREEZE_FINDER: Literal[0, 1, 2] = 0
    """ If the unfreeze finder should be done during the run.
    0: no, 1: only plot, 2: plot and use value for the remaining run
    """

    DO_FIT_IN_SUBPROCESS: bool = True
    """ If the model should be fitted in a subprocess.
    Is memory efficient, and more stable. Turn it off only if you are having issues.

    Note: This allows a program to make lots of runs without getting killed by the OS for using too many resources.
    (e.g. LeaveOneOut Cross Validation, Grid Search, etc.)
    """

    MIXED_PRECISION_POLICY: Literal["mixed_float16", "mixed_bfloat16", "float32"] = "mixed_float16"
    """ Mixed precision policy to use. Turn back to "float32" if you are having issues with a specific model or metrics.
    See: https://www.tensorflow.org/guide/mixed_precision
    """

    TENSORFLOW_DEVICE: str = "/gpu:1"
    """ TensorFlow device to use. """


    @classmethod
    def update_root(cls, new_root: str) -> None:
        """ Update the root directory and recalculate all dependent paths.

        Args:
            new_root: The new root directory path.
        """
        cls.ROOT = new_root

        # Update all paths that depend on ROOT
        cls.MLFLOW_FOLDER = f"{cls.ROOT}/mlruns"
        cls.MLFLOW_URI = f"file://{cls.MLFLOW_FOLDER}"
        cls.DATA_FOLDER = f"{cls.ROOT}/data"
        cls.TEMP_FOLDER = f"{cls.ROOT}/temp"
        cls.LOGS_FOLDER = f"{cls.ROOT}/logs"
        cls.TENSORBOARD_FOLDER = f"{cls.ROOT}/tensorboard"

        # Fix MLFLOW_URI for Windows by adding a missing slash
        if os.name == "nt":
            cls.MLFLOW_URI = cls.MLFLOW_URI.replace("file:", "file:/")


# Fix MLFLOW_URI for Windows by adding a missing slash
if os.name == "nt":
    DataScienceConfig.MLFLOW_URI = DataScienceConfig.MLFLOW_URI.replace("file:", "file:/")
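
A note on using the configuration above: every directory constant is derived from ROOT at class-definition time, which is why update_root re-derives them all instead of leaving that to callers. A minimal usage sketch, assuming the module path matches the file layout (the project path below is made up for illustration):

from stouputils.data_science.config.set import DataScienceConfig

# Repoint every derived folder (mlruns, data, temp, logs, tensorboard) at a new project root
DataScienceConfig.update_root("/home/user/my_project")
print(DataScienceConfig.DATA_FOLDER)  # /home/user/my_project/data
print(DataScienceConfig.MLFLOW_URI)   # file:///home/user/my_project/mlruns on POSIX; gains an extra slash on Windows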

stouputils/data_science/data_processing/image/__init__.py
@@ -0,0 +1,66 @@

# Imports
from .auto_contrast import auto_contrast_image
from .axis_flip import flip_image
from .bias_field_correction import bias_field_correction_image
from .binary_threshold import binary_threshold_image
from .blur import blur_image
from .brightness import brightness_image
from .canny import canny_image
from .clahe import clahe_image
from .contrast import contrast_image
from .curvature_flow_filter import curvature_flow_filter_image
from .denoise import (
    adaptive_denoise_image,
    bilateral_denoise_image,
    nlm_denoise_image,
    tv_denoise_image,
    wavelet_denoise_image,
)
from .invert import invert_image
from .laplacian import laplacian_image
from .median_blur import median_blur_image
from .noise import noise_image
from .normalize import normalize_image
from .random_erase import random_erase_image
from .resize import resize_image
from .rotation import rotate_image
from .salt_pepper import salt_pepper_image
from .sharpening import sharpen_image
from .shearing import shear_image
from .threshold import threshold_image
from .translation import translate_image
from .zoom import zoom_image

__all__ = [
    "adaptive_denoise_image",
    "auto_contrast_image",
    "bias_field_correction_image",
    "bilateral_denoise_image",
    "binary_threshold_image",
    "blur_image",
    "brightness_image",
    "canny_image",
    "clahe_image",
    "contrast_image",
    "curvature_flow_filter_image",
    "flip_image",
    "invert_image",
    "laplacian_image",
    "median_blur_image",
    "nlm_denoise_image",
    "noise_image",
    "normalize_image",
    "random_erase_image",
    "resize_image",
    "rotate_image",
    "salt_pepper_image",
    "sharpen_image",
    "shear_image",
    "threshold_image",
    "translate_image",
    "tv_denoise_image",
    "wavelet_denoise_image",
    "zoom_image",
]
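
Since the __init__ above re-exports every technique under one flat namespace, augmentation pipelines can be composed from a single import. A rough sketch of that usage, chaining only functions whose signatures appear elsewhere in this diff (the chaining order is illustrative, not prescribed by the package):

import numpy as np
from stouputils.data_science.data_processing.image import blur_image, brightness_image, flip_image

image = np.random.randint(0, 256, (64, 64), dtype=np.uint8)  # dummy grayscale image
augmented = flip_image(blur_image(brightness_image(image, 1.2), 1.0), "horizontal")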

stouputils/data_science/data_processing/image/auto_contrast.py
@@ -0,0 +1,79 @@

# pyright: reportUnusedImport=false
# ruff: noqa: F401

# Imports
from .common import Any, NDArray, check_image, cv2, np


# Functions
def auto_contrast_image(image: NDArray[Any], ignore_dtype: bool = False) -> NDArray[Any]:
    """ Adjust the contrast of an image.

    Args:
        image (NDArray[Any]): Image to adjust contrast
        ignore_dtype (bool): Ignore the dtype check
    Returns:
        NDArray[Any]: Image with adjusted contrast

    >>> ## Basic tests
    >>> image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.uint8)
    >>> adjusted = auto_contrast_image(image)
    >>> adjusted.tolist()
    [[0, 36, 73], [109, 146, 182], [219, 255, 255]]
    >>> adjusted.shape == image.shape
    True
    >>> adjusted.dtype == image.dtype
    True

    >>> ## Test invalid inputs
    >>> auto_contrast_image("not an image")
    Traceback (most recent call last):
        ...
    AssertionError: Image must be a numpy array
    """
    # Check input data
    check_image(image, ignore_dtype=ignore_dtype)

    # Perform histogram clipping
    clip_hist_percent: float = 1.0

    # Calculate the histogram of the image
    hist: NDArray[Any] = cv2.calcHist([image], [0], None, [256], [0, 256])

    # Create an accumulator list to store the cumulative histogram
    accumulator: list[float] = []
    accumulator.append(hist[0])
    for i in range(1, 256):
        accumulator.append(accumulator[i - 1] + hist[i])

    # Find the maximum value in the accumulator
    max_value: float = accumulator[-1]

    # Calculate the clipping threshold
    clip_hist_percent = clip_hist_percent * (max_value / 100.0)
    clip_hist_percent = clip_hist_percent / 2.0

    # Find the minimum and maximum gray levels after clipping
    min_gray: int = 0
    while accumulator[min_gray] < clip_hist_percent:
        min_gray = min_gray + 1
    max_gray: int = 256 - 1
    while (max_gray >= 0 and accumulator[max_gray] >= (max_value - clip_hist_percent)):
        max_gray = max_gray - 1

    # Calculate the input range after clipping
    input_range: int = max_gray - min_gray

    # If the input range is 0, return the original image
    if input_range == 0:
        return image

    # Calculate the scaling factors for contrast adjustment
    alpha: float = (256 - 1) / input_range
    beta: float = -min_gray * alpha

    # Apply the contrast adjustment
    adjusted: NDArray[Any] = cv2.convertScaleAbs(image, alpha=alpha, beta=beta)
    return adjusted
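
The cumulative-histogram walk in auto_contrast_image amounts to clipping roughly clip_hist_percent/2 of the pixel mass from each tail and then applying a linear stretch (convertScaleAbs computes round(alpha * x + beta) saturated to uint8). A NumPy-only sketch of that reading, for orientation only; the exact cut points can differ slightly from the package's integer histogram walk:

import numpy as np

def auto_contrast_sketch(image: np.ndarray, clip_percent: float = 1.0) -> np.ndarray:
    # Drop clip_percent/2 of the histogram mass from each tail, then stretch to [0, 255]
    low, high = np.percentile(image, [clip_percent / 2, 100 - clip_percent / 2])
    if high <= low:
        return image
    stretched = (image.astype(np.float64) - low) * 255.0 / (high - low)
    return np.clip(np.round(stretched), 0, 255).astype(np.uint8)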

stouputils/data_science/data_processing/image/axis_flip.py
@@ -0,0 +1,58 @@

# pyright: reportUnusedImport=false
# ruff: noqa: F401

# Imports
from typing import Literal

from .common import Any, NDArray, check_image, cv2, np


# Functions
def flip_image(
    image: NDArray[Any], axis: Literal["horizontal", "vertical", "both"], ignore_dtype: bool = True
) -> NDArray[Any]:
    """ Flip an image along the specified axis.

    Args:
        image (NDArray[Any]): Image to flip
        axis (str): Axis along which to flip ("horizontal" or "vertical" or "both")
        ignore_dtype (bool): Ignore the dtype check
    Returns:
        NDArray[Any]: Flipped image

    >>> ## Basic tests
    >>> image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    >>> flip_image(image, "horizontal").tolist()
    [[3, 2, 1], [6, 5, 4], [9, 8, 7]]

    >>> flip_image(image, "vertical").tolist()
    [[7, 8, 9], [4, 5, 6], [1, 2, 3]]

    >>> flip_image(image, "both").tolist()
    [[9, 8, 7], [6, 5, 4], [3, 2, 1]]

    >>> ## Test invalid inputs
    >>> flip_image(image, "diagonal")
    Traceback (most recent call last):
    AssertionError: axis must be either 'horizontal' or 'vertical' or 'both', got 'diagonal'

    >>> flip_image("not an image", "horizontal")
    Traceback (most recent call last):
    ...
    AssertionError: Image must be a numpy array
    """
    # Check input data
    check_image(image, ignore_dtype=ignore_dtype)
    assert axis in ("horizontal", "vertical", "both"), (
        f"axis must be either 'horizontal' or 'vertical' or 'both', got '{axis}'"
    )

    # Apply the flip
    if axis == "horizontal":
        return cv2.flip(image, 1)  # 1 for horizontal flip
    elif axis == "vertical":
        return cv2.flip(image, 0)  # 0 for vertical flip
    else:
        return cv2.flip(image, -1)  # -1 for both flips

stouputils/data_science/data_processing/image/bias_field_correction.py
@@ -0,0 +1,74 @@

# pyright: reportUnknownMemberType=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownVariableType=false

# Imports
import SimpleITK as Sitk

from .common import Any, NDArray, check_image, np


# Functions
def bias_field_correction_image(image: NDArray[Any], ignore_dtype: bool = False) -> NDArray[Any]:
    """ Apply a bias field correction to an image (N4 filter).

    Args:
        image (NDArray[Any]): Image to apply the bias field correction (can't be 8-bit unsigned integer)
        ignore_dtype (bool): Ignore the dtype check
    Returns:
        NDArray[Any]: Image with the bias field correction applied

    >>> ## Basic tests
    >>> image = np.random.randint(0, 255, size=(10,10), dtype=np.uint8) / 255
    >>> corrected = bias_field_correction_image(image)
    >>> corrected.shape == image.shape
    True
    >>> corrected.dtype == np.float64
    True

    >>> ## Test invalid inputs
    >>> bias_field_correction_image("not an image")
    Traceback (most recent call last):
    ...
    AssertionError: Image must be a numpy array
    """
    # Check input data
    check_image(image, ignore_dtype=ignore_dtype)

    # If the image is 3D, convert to grayscale first
    if image.ndim == 3:
        image = np.mean(image, axis=-1)

    # Convert numpy array to SimpleITK image
    image_sitk: Sitk.Image = Sitk.GetImageFromArray(image)

    # Create binary mask of the head region
    transformed: Sitk.Image = Sitk.RescaleIntensity(image_sitk)  # Normalize intensities
    transformed = Sitk.LiThreshold(transformed, 0, 1)  # Apply Li thresholding
    head_mask: Sitk.Image = transformed

    # Downsample images to speed up bias field estimation
    shrink_factor: int = 4  # Reduce image size by factor of 4
    input_image: Sitk.Image = Sitk.Shrink(
        image_sitk,
        [shrink_factor] * image_sitk.GetDimension()  # Apply shrink factor to all dimensions
    )
    mask_image: Sitk.Image = Sitk.Shrink(
        head_mask,
        [shrink_factor] * image_sitk.GetDimension()
    )

    # Apply N4 bias field correction
    corrector = Sitk.N4BiasFieldCorrectionImageFilter()
    corrector.Execute(input_image, mask_image)

    # Get estimated bias field and apply correction
    log_bias_field: Sitk.Image = Sitk.Cast(
        corrector.GetLogBiasFieldAsImage(image_sitk), Sitk.sitkFloat64
    )
    corrected_image_full_resolution: Sitk.Image = image_sitk / Sitk.Exp(log_bias_field)

    # Convert back to numpy array and return
    return Sitk.GetArrayFromImage(corrected_image_full_resolution)
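
Two practical points about the N4 hunk above: the bias field is estimated on a 4x shrunk copy, but GetLogBiasFieldAsImage(image_sitk) resamples it back onto the original grid, so the output keeps the input shape; and the docstring rules out 8-bit input, so callers need a float image. A usage sketch under those constraints ("radiograph.png" is a hypothetical file, and reading it with OpenCV is an assumption, not something the package requires):

import cv2
import numpy as np
from stouputils.data_science.data_processing.image import bias_field_correction_image

xray = cv2.imread("radiograph.png", cv2.IMREAD_GRAYSCALE).astype(np.float64) / 255.0
corrected = bias_field_correction_image(xray)  # float64 output, same shape as the input
corrected_u8 = cv2.normalize(corrected, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)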

stouputils/data_science/data_processing/image/binary_threshold.py
@@ -0,0 +1,73 @@

# pyright: reportUnusedImport=false
# ruff: noqa: F401

# Imports
from .common import Any, NDArray, check_image, cv2, np


# Functions
def binary_threshold_image(image: NDArray[Any], threshold: float, ignore_dtype: bool = False) -> NDArray[Any]:
    """ Apply binary threshold to an image.

    Args:
        image (NDArray[Any]): Image to threshold
        threshold (float): Threshold value (between 0 and 1)
        ignore_dtype (bool): Ignore the dtype check
    Returns:
        NDArray[Any]: Thresholded binary image

    >>> ## Basic tests
    >>> image = np.array([[100, 150, 200], [50, 125, 175], [25, 75, 225]])
    >>> binary_threshold_image(image.astype(np.uint8), 0.5).tolist()
    [[0, 255, 255], [0, 0, 255], [0, 0, 255]]

    >>> np.random.seed(42)
    >>> img = np.random.randint(0, 256, (4,4), dtype=np.uint8)
    >>> thresholded = binary_threshold_image(img, 0.7)
    >>> set(np.unique(thresholded).tolist()) <= {0, 255}  # Should only contain 0 and 255
    True

    >>> rgb = np.random.randint(0, 256, (3,3,3), dtype=np.uint8)
    >>> thresh_rgb = binary_threshold_image(rgb, 0.5)
    >>> thresh_rgb.shape == rgb.shape
    True
    >>> set(np.unique(thresh_rgb).tolist()) <= {0, 255}
    True

    >>> ## Test invalid inputs
    >>> binary_threshold_image("not an image", 0.5)
    Traceback (most recent call last):
    ...
    AssertionError: Image must be a numpy array

    >>> binary_threshold_image(image.astype(np.uint8), "0.5")
    Traceback (most recent call last):
    ...
    AssertionError: threshold must be a number, got <class 'str'>

    >>> binary_threshold_image(image.astype(np.uint8), 1.5)
    Traceback (most recent call last):
    ...
    AssertionError: threshold must be between 0 and 1, got 1.5
    """
    # Check input data
    check_image(image, ignore_dtype=ignore_dtype)
    assert isinstance(threshold, float | int), f"threshold must be a number, got {type(threshold)}"
    assert 0 <= threshold <= 1, f"threshold must be between 0 and 1, got {threshold}"

    # Convert threshold from 0-1 range to 0-255 range
    threshold_value: int = int(threshold * 255)

    # Apply threshold
    if len(image.shape) == 2:
        # Grayscale image
        binary: NDArray[Any] = cv2.threshold(image, threshold_value, 255, cv2.THRESH_BINARY)[1]
    else:
        # Color image - convert to grayscale first, then back to color
        gray: NDArray[Any] = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        binary: NDArray[Any] = cv2.threshold(gray, threshold_value, 255, cv2.THRESH_BINARY)[1]
        binary = cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)

    return binary
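
One detail of the thresholding hunk above worth spelling out: the 0-1 threshold argument is rescaled to the uint8 range before cv2.threshold runs, which is what the doctest values follow from:

# threshold = 0.5  ->  threshold_value = int(0.5 * 255) = 127
# cv2.THRESH_BINARY keeps pixels strictly greater than 127, so 125 -> 0 and 150 -> 255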

stouputils/data_science/data_processing/image/blur.py
@@ -0,0 +1,59 @@

# pyright: reportUnusedImport=false
# ruff: noqa: F401

# Imports
from .common import Any, NDArray, check_image, cv2, np


# Functions
def blur_image(image: NDArray[Any], blur_strength: float, ignore_dtype: bool = False) -> NDArray[Any]:
    """ Apply Gaussian blur to an image.

    Args:
        image (NDArray[Any]): Image to blur
        blur_strength (float): Strength of the blur
        ignore_dtype (bool): Ignore the dtype check
    Returns:
        NDArray[Any]: Blurred image

    >>> ## Basic tests
    >>> image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    >>> blurred = blur_image(image.astype(np.uint8), 1.5)
    >>> blurred.shape == image.shape
    True

    >>> img = np.zeros((5,5), dtype=np.uint8)
    >>> img[2,2] = 255  # Single bright pixel
    >>> blurred = blur_image(img, 1.0)
    >>> bool(blurred[2,2] < 255)  # Center should be blurred
    True

    >>> rgb = np.full((3,3,3), 128, dtype=np.uint8)
    >>> blurred_rgb = blur_image(rgb, 1.0)
    >>> blurred_rgb.shape == (3,3,3)
    True

    >>> ## Test invalid inputs
    >>> blur_image("not an image", 1.5)
    Traceback (most recent call last):
    ...
    AssertionError: Image must be a numpy array

    >>> blur_image(image.astype(np.uint8), "1.5")
    Traceback (most recent call last):
    ...
    AssertionError: blur_strength must be a number, got <class 'str'>
    """
    # Check input data
    check_image(image, ignore_dtype=ignore_dtype)
    assert isinstance(blur_strength, float | int), f"blur_strength must be a number, got {type(blur_strength)}"

    # Apply Gaussian blur
    kernel_size: int = max(3, int(blur_strength * 2) + 1)
    if kernel_size % 2 == 0:
        kernel_size += 1
    blurred_image: NDArray[Any] = cv2.GaussianBlur(image, (kernel_size, kernel_size), 0)

    return blurred_image
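
The blur strength in the hunk above is not a sigma but a kernel-size knob: cv2.GaussianBlur is called with sigma 0, so OpenCV derives the sigma from the kernel size. A tiny helper (illustration only, not part of the package) that reproduces the kernel-size rule, with a few checked values:

def kernel_size_for(blur_strength: float) -> int:
    # Mirrors blur_image above: at least 3, and always odd, since cv2.GaussianBlur requires odd kernel dimensions
    size = max(3, int(blur_strength * 2) + 1)
    return size + 1 if size % 2 == 0 else size

assert [kernel_size_for(s) for s in (0.5, 1.5, 3.0)] == [3, 5, 7]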

stouputils/data_science/data_processing/image/brightness.py
@@ -0,0 +1,54 @@

# pyright: reportUnusedImport=false
# ruff: noqa: F401

# Imports
from .common import Any, NDArray, check_image, cv2, np


# Functions
def brightness_image(image: NDArray[Any], brightness_factor: float, ignore_dtype: bool = False) -> NDArray[Any]:
    """ Adjust the brightness of an image.

    Args:
        image (NDArray[Any]): Image to adjust brightness
        brightness_factor (float): Brightness adjustment factor
        ignore_dtype (bool): Ignore the dtype check
    Returns:
        NDArray[Any]: Image with adjusted brightness

    >>> ## Basic tests
    >>> image = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    >>> brightened = brightness_image(image.astype(np.uint8), 1.5)
    >>> brightened.shape == image.shape
    True

    >>> img = np.full((3,3), 100, dtype=np.uint8)
    >>> bright = brightness_image(img, 2.0)
    >>> dark = brightness_image(img, 0.5)
    >>> bool(np.mean(bright) > np.mean(img) > np.mean(dark))
    True

    >>> rgb = np.full((3,3,3), 128, dtype=np.uint8)
    >>> bright_rgb = brightness_image(rgb, 1.5)
    >>> bright_rgb.shape == (3,3,3)
    True

    >>> ## Test invalid inputs
    >>> brightness_image("not an image", 1.5)
    Traceback (most recent call last):
    ...
    AssertionError: Image must be a numpy array

    >>> brightness_image(image.astype(np.uint8), "1.5")
    Traceback (most recent call last):
    ...
    AssertionError: brightness_factor must be a number, got <class 'str'>
    """
    # Check input data
    check_image(image, ignore_dtype=ignore_dtype)
    assert isinstance(brightness_factor, float | int), f"brightness_factor must be a number, got {type(brightness_factor)}"

    # Apply brightness adjustment
    return cv2.convertScaleAbs(image, alpha=brightness_factor, beta=0)