stouputils 1.14.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stouputils/__init__.py +40 -0
- stouputils/__main__.py +86 -0
- stouputils/_deprecated.py +37 -0
- stouputils/all_doctests.py +160 -0
- stouputils/applications/__init__.py +22 -0
- stouputils/applications/automatic_docs.py +634 -0
- stouputils/applications/upscaler/__init__.py +39 -0
- stouputils/applications/upscaler/config.py +128 -0
- stouputils/applications/upscaler/image.py +247 -0
- stouputils/applications/upscaler/video.py +287 -0
- stouputils/archive.py +344 -0
- stouputils/backup.py +488 -0
- stouputils/collections.py +244 -0
- stouputils/continuous_delivery/__init__.py +27 -0
- stouputils/continuous_delivery/cd_utils.py +243 -0
- stouputils/continuous_delivery/github.py +522 -0
- stouputils/continuous_delivery/pypi.py +130 -0
- stouputils/continuous_delivery/pyproject.py +147 -0
- stouputils/continuous_delivery/stubs.py +86 -0
- stouputils/ctx.py +408 -0
- stouputils/data_science/config/get.py +51 -0
- stouputils/data_science/config/set.py +125 -0
- stouputils/data_science/data_processing/image/__init__.py +66 -0
- stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
- stouputils/data_science/data_processing/image/axis_flip.py +58 -0
- stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
- stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
- stouputils/data_science/data_processing/image/blur.py +59 -0
- stouputils/data_science/data_processing/image/brightness.py +54 -0
- stouputils/data_science/data_processing/image/canny.py +110 -0
- stouputils/data_science/data_processing/image/clahe.py +92 -0
- stouputils/data_science/data_processing/image/common.py +30 -0
- stouputils/data_science/data_processing/image/contrast.py +53 -0
- stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
- stouputils/data_science/data_processing/image/denoise.py +378 -0
- stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
- stouputils/data_science/data_processing/image/invert.py +64 -0
- stouputils/data_science/data_processing/image/laplacian.py +60 -0
- stouputils/data_science/data_processing/image/median_blur.py +52 -0
- stouputils/data_science/data_processing/image/noise.py +59 -0
- stouputils/data_science/data_processing/image/normalize.py +65 -0
- stouputils/data_science/data_processing/image/random_erase.py +66 -0
- stouputils/data_science/data_processing/image/resize.py +69 -0
- stouputils/data_science/data_processing/image/rotation.py +80 -0
- stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
- stouputils/data_science/data_processing/image/sharpening.py +55 -0
- stouputils/data_science/data_processing/image/shearing.py +64 -0
- stouputils/data_science/data_processing/image/threshold.py +64 -0
- stouputils/data_science/data_processing/image/translation.py +71 -0
- stouputils/data_science/data_processing/image/zoom.py +83 -0
- stouputils/data_science/data_processing/image_augmentation.py +118 -0
- stouputils/data_science/data_processing/image_preprocess.py +183 -0
- stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
- stouputils/data_science/data_processing/technique.py +481 -0
- stouputils/data_science/dataset/__init__.py +45 -0
- stouputils/data_science/dataset/dataset.py +292 -0
- stouputils/data_science/dataset/dataset_loader.py +135 -0
- stouputils/data_science/dataset/grouping_strategy.py +296 -0
- stouputils/data_science/dataset/image_loader.py +100 -0
- stouputils/data_science/dataset/xy_tuple.py +696 -0
- stouputils/data_science/metric_dictionnary.py +106 -0
- stouputils/data_science/metric_utils.py +847 -0
- stouputils/data_science/mlflow_utils.py +206 -0
- stouputils/data_science/models/abstract_model.py +149 -0
- stouputils/data_science/models/all.py +85 -0
- stouputils/data_science/models/base_keras.py +765 -0
- stouputils/data_science/models/keras/all.py +38 -0
- stouputils/data_science/models/keras/convnext.py +62 -0
- stouputils/data_science/models/keras/densenet.py +50 -0
- stouputils/data_science/models/keras/efficientnet.py +60 -0
- stouputils/data_science/models/keras/mobilenet.py +56 -0
- stouputils/data_science/models/keras/resnet.py +52 -0
- stouputils/data_science/models/keras/squeezenet.py +233 -0
- stouputils/data_science/models/keras/vgg.py +42 -0
- stouputils/data_science/models/keras/xception.py +38 -0
- stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
- stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
- stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
- stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
- stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
- stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
- stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
- stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
- stouputils/data_science/models/keras_utils/visualizations.py +416 -0
- stouputils/data_science/models/model_interface.py +939 -0
- stouputils/data_science/models/sandbox.py +116 -0
- stouputils/data_science/range_tuple.py +234 -0
- stouputils/data_science/scripts/augment_dataset.py +77 -0
- stouputils/data_science/scripts/exhaustive_process.py +133 -0
- stouputils/data_science/scripts/preprocess_dataset.py +70 -0
- stouputils/data_science/scripts/routine.py +168 -0
- stouputils/data_science/utils.py +285 -0
- stouputils/decorators.py +605 -0
- stouputils/image.py +441 -0
- stouputils/installer/__init__.py +18 -0
- stouputils/installer/common.py +67 -0
- stouputils/installer/downloader.py +101 -0
- stouputils/installer/linux.py +144 -0
- stouputils/installer/main.py +223 -0
- stouputils/installer/windows.py +136 -0
- stouputils/io.py +486 -0
- stouputils/parallel.py +483 -0
- stouputils/print.py +482 -0
- stouputils/py.typed +1 -0
- stouputils/stouputils/__init__.pyi +15 -0
- stouputils/stouputils/_deprecated.pyi +12 -0
- stouputils/stouputils/all_doctests.pyi +46 -0
- stouputils/stouputils/applications/__init__.pyi +2 -0
- stouputils/stouputils/applications/automatic_docs.pyi +106 -0
- stouputils/stouputils/applications/upscaler/__init__.pyi +3 -0
- stouputils/stouputils/applications/upscaler/config.pyi +18 -0
- stouputils/stouputils/applications/upscaler/image.pyi +109 -0
- stouputils/stouputils/applications/upscaler/video.pyi +60 -0
- stouputils/stouputils/archive.pyi +67 -0
- stouputils/stouputils/backup.pyi +109 -0
- stouputils/stouputils/collections.pyi +86 -0
- stouputils/stouputils/continuous_delivery/__init__.pyi +5 -0
- stouputils/stouputils/continuous_delivery/cd_utils.pyi +129 -0
- stouputils/stouputils/continuous_delivery/github.pyi +162 -0
- stouputils/stouputils/continuous_delivery/pypi.pyi +53 -0
- stouputils/stouputils/continuous_delivery/pyproject.pyi +67 -0
- stouputils/stouputils/continuous_delivery/stubs.pyi +39 -0
- stouputils/stouputils/ctx.pyi +211 -0
- stouputils/stouputils/decorators.pyi +252 -0
- stouputils/stouputils/image.pyi +172 -0
- stouputils/stouputils/installer/__init__.pyi +5 -0
- stouputils/stouputils/installer/common.pyi +39 -0
- stouputils/stouputils/installer/downloader.pyi +24 -0
- stouputils/stouputils/installer/linux.pyi +39 -0
- stouputils/stouputils/installer/main.pyi +57 -0
- stouputils/stouputils/installer/windows.pyi +31 -0
- stouputils/stouputils/io.pyi +213 -0
- stouputils/stouputils/parallel.pyi +216 -0
- stouputils/stouputils/print.pyi +136 -0
- stouputils/stouputils/version_pkg.pyi +15 -0
- stouputils/version_pkg.py +189 -0
- stouputils-1.14.0.dist-info/METADATA +178 -0
- stouputils-1.14.0.dist-info/RECORD +140 -0
- stouputils-1.14.0.dist-info/WHEEL +4 -0
- stouputils-1.14.0.dist-info/entry_points.txt +3 -0
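Every entry above is listed with added lines only (+N -0), and both hunks below are new-file hunks, so this diff shows full file contents for 1.14.0. To reproduce this exact version locally, a minimal sanity check (standard pip invocation; importlib.metadata is the stdlib way to confirm the installed version):

    pip install stouputils==1.14.0
    python -c "import importlib.metadata as m; print(m.version('stouputils'))"  # expect 1.14.0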
stouputils/data_science/scripts/routine.py
@@ -0,0 +1,168 @@

# Imports
import argparse
import itertools
from typing import Any, Literal

from ...decorators import handle_error, measure_time
from ...print import info, error, progress
from ...io import clean_path
from ..config.get import DataScienceConfig
from ..dataset import LOWER_GS, Dataset, DatasetLoader, GroupingStrategy, XyTuple
from ..models.all import ALL_MODELS, CLASS_MAP, ModelInterface

# Constants
MODEL_HELP: str = "Model(s) name or alias to use"
INPUT_HELP: str = "Path to the dataset, e.g. 'data/aug_hip_implant'"
BASED_OF_HELP: str = "Path to the base dataset for filtering train/test, e.g. 'data/hip_implant'"
TRANSFER_LEARNING_HELP: str = "Transfer learning source (imagenet, None, 'data/dataset_folder')"
GROUPING_HELP: str = "Grouping strategy for the dataset"
K_FOLD_HELP: str = "Number of folds for k-fold cross validation (0 = no k-fold, negative = LeavePOut)"
GRID_SEARCH_HELP: str = "If grid search should be performed on hyperparameters"
VERBOSE_HELP: str = "Verbosity level sent to functions"
PARSER_DESCRIPTION: str = "Command-line interface for training and evaluating machine learning models."


# Main function
@measure_time(printer=info, message="Total execution time of the script")
@handle_error(exceptions=(KeyboardInterrupt, Exception), error_log=DataScienceConfig.ERROR_LOG)
def routine(
	default_input: str = f"{DataScienceConfig.DATA_FOLDER}/aug_hip_implant_preprocessed",
	default_based_of: str = "auto",
	default_transfer_learning: str = "imagenet",
	default_grouping_strategy: str = "none",
	default_kfold: int = 0,
	default_verbose: int = 100,

	loading_type: Literal["image"] = "image",
	grid_search_param_grid: dict[str, list[Any]] | None = None,
	add_to_train_only: list[str] | None = None,
) -> None:
	""" Main function of the script for training and evaluating machine learning models.

	This function handles the entire workflow for model training and evaluation, including:
	- Parsing command-line arguments (default values are set in the function signature)
	- Loading and preparing datasets with configurable grouping strategies
	- Supporting transfer learning from various sources
	- Enabling K-fold cross-validation, LeavePOut, or LeaveOneOut
	- Providing grid search capabilities for hyperparameter optimization
	- Incorporating additional training data from specified paths

	Args:
		default_input (str): Default path to the dataset to use.
		default_based_of (str): Default path to the base dataset for filtering train/test data.
		default_transfer_learning (str): Default transfer learning source.
		default_grouping_strategy (str): Default grouping strategy for the dataset.
		default_kfold (int): Default number of folds for k-fold cross validation.
		default_verbose (int): Default verbosity level.
		loading_type (Literal["image"]): Type of data to load; currently only "image" is supported.
		grid_search_param_grid (dict[str, list[Any]] | None): Parameter grid for hyperparameter optimization.
		add_to_train_only (list[str] | None): List of paths to additional training datasets.

	Returns:
		None: This function does not return anything.
	"""
	if grid_search_param_grid is None:
		grid_search_param_grid = {"batch_size": [8, 16, 32, 64]}
	if add_to_train_only is None:
		add_to_train_only = []

	info("Starting the script...")

	# Parse the arguments
	parser = argparse.ArgumentParser(description=PARSER_DESCRIPTION)
	parser.add_argument("--model", type=str, choices=ALL_MODELS, required=True, help=MODEL_HELP)
	parser.add_argument("--input", type=str, default=default_input, help=INPUT_HELP)
	parser.add_argument("--based_of", type=str, default=default_based_of, help=BASED_OF_HELP)
	parser.add_argument("--transfer_learning", type=str, default=default_transfer_learning, help=TRANSFER_LEARNING_HELP)
	parser.add_argument("--grouping_strategy", type=str, default=default_grouping_strategy, choices=LOWER_GS, help=GROUPING_HELP)
	parser.add_argument("--kfold", type=int, default=default_kfold, help=K_FOLD_HELP)
	parser.add_argument("--grid_search", action="store_true", help=GRID_SEARCH_HELP)
	parser.add_argument("--verbose", type=int, default=default_verbose, help=VERBOSE_HELP)
	args: argparse.Namespace = parser.parse_args()
	model: str = args.model.lower()
	kfold: int = args.kfold
	input_path: str = clean_path(args.input, trailing_slash=False)
	based_of: str = clean_path(args.based_of, trailing_slash=False)
	transfer_learning: str = clean_path(args.transfer_learning, trailing_slash=False)
	verbose: int = args.verbose
	grouping_strategy: str = args.grouping_strategy
	grid_search: bool = args.grid_search

	# If based_of is "auto", set it to the input path without the "aug" prefix
	if based_of == "auto":
		prefix: str = "/" + DataScienceConfig.AUGMENTED_DIRECTORY_PREFIX
		if prefix in input_path:
			based_of = input_path.replace(prefix, "/")
		else:
			based_of = ""

	# Load the dataset
	kwargs: dict[str, Any] = {}
	if grouping_strategy == "concatenate":
		kwargs["color_mode"] = "grayscale"
	dataset: Dataset = DatasetLoader.from_path(
		path=input_path,
		loading_type=loading_type,
		seed=DataScienceConfig.SEED,
		test_size=DataScienceConfig.TEST_SIZE,
		grouping_strategy=next(x for x in GroupingStrategy if x.name.lower() == grouping_strategy),
		based_of=based_of,
		**kwargs
	)
	info(dataset)

	# Define parameter combinations
	param_combinations: list[dict[str, Any]] = [{}]  # Default empty params
	if grid_search:

		# Generate all parameter combinations
		param_combinations.clear()
		for values in itertools.product(*grid_search_param_grid.values()):
			param_combinations.append(dict(zip(grid_search_param_grid.keys(), values, strict=False)))

	# Load additional training data from provided paths
	additional_training_data: XyTuple = XyTuple.empty()
	for path in add_to_train_only:
		try:
			additional_dataset: Dataset = DatasetLoader.from_path(
				path=path,
				loading_type=loading_type,
				seed=DataScienceConfig.SEED,
				test_size=0,  # Use all data for training
				**kwargs
			)
			additional_training_data += additional_dataset.training_data
		except Exception as e:
			error(f"Failed to load additional training data from '{path}': {e}")

	# Prepare the initialization arguments
	# (num_classes: int, kfold: int = 0, transfer_learning: str = "imagenet", **override_params: Any)
	initialization_args: dict[str, Any] = {

		# Mandatory arguments
		"num_classes": dataset.num_classes,
		"kfold": kfold,
		"transfer_learning": transfer_learning,

		# Optional arguments (override_params)
		"additional_training_data": additional_training_data
	}

	# Collect all class routines that match the model name
	classes: list[type[ModelInterface]] = [key for key, values in CLASS_MAP.items() if model in values]

	# For each parameter combination
	for i, params in enumerate(param_combinations):
		if grid_search:
			progress(f"Grid search {i+1}/{len(param_combinations)}, Training with parameters:\n{params}")
			initialization_args["override_params"] = params

		# Launch all class routines
		for class_to_process in classes:
			model_instance: ModelInterface = class_to_process(**initialization_args)
			trained_model: ModelInterface = model_instance.routine_full(dataset, verbose)
			info(trained_model)
			del trained_model
	return
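For context, routine() is written to be the entire entry point of a training script: defaults are configured in code and the model choice arrives via the CLI. A minimal sketch of such a wrapper follows (the train.py filename, the resnet alias, and the learning_rate key are illustrative assumptions, not part of this wheel; valid model names come from ALL_MODELS):

    # train.py (hypothetical wrapper around this release's routine)
    from stouputils.data_science.scripts.routine import routine

    if __name__ == "__main__":
        routine(
            default_kfold=5,  # assumption: 5-fold cross validation by default
            grid_search_param_grid={"batch_size": [16, 32], "learning_rate": [1e-3, 1e-4]},  # hypothetical grid
        )

Invoked as, e.g., `python train.py --model resnet --grid_search`, the itertools.product expansion above would yield four parameter dicts ({'batch_size': 16, 'learning_rate': 0.001}, and so on) and the full routine would run once per combination.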

stouputils/data_science/utils.py
@@ -0,0 +1,285 @@
"""
This module contains the Utils class, which provides static methods for common operations.

This class contains static methods for:

- Safe division (with 0 as denominator or None)
- Safe multiplication (with None)
- Converting between one-hot encoding and class indices
- Calculating ROC curves and AUC scores
"""
# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false

# Imports
from typing import Any

import numpy as np
from numpy.typing import NDArray

from ..ctx import Muffle
from ..decorators import handle_error
from .config.get import DataScienceConfig


# Class
class Utils:
	""" Utility class providing common operations. """

	@staticmethod
	def safe_divide_float(a: float, b: float) -> float:
		""" Safe division of two numbers, returning 0 when the denominator is not strictly positive.

		Args:
			a (float): First number
			b (float): Second number
		Returns:
			float: Result of the division

		Examples:
			>>> Utils.safe_divide_float(10, 2)
			5.0
			>>> Utils.safe_divide_float(0, 5)
			0.0
			>>> Utils.safe_divide_float(10, 0)
			0
			>>> Utils.safe_divide_float(-10, 2)
			-5.0
		"""
		return a / b if b > 0 else 0

	@staticmethod
	def safe_divide_none(a: float | None, b: float | None) -> float | None:
		""" Safe division of two numbers, returning None if either number is None or the denominator is not strictly positive.

		Args:
			a (float | None): First number
			b (float | None): Second number
		Returns:
			float | None: Result of the division, or None if an operand is None or the denominator is not strictly positive

		Examples:
			>>> None == Utils.safe_divide_none(None, 2)
			True
			>>> None == Utils.safe_divide_none(10, None)
			True
			>>> None == Utils.safe_divide_none(10, 0)
			True
			>>> Utils.safe_divide_none(10, 2)
			5.0
		"""
		return a / b if a is not None and b is not None and b > 0 else None

	@staticmethod
	def safe_multiply_none(a: float | None, b: float | None) -> float | None:
		""" Safe multiplication of two numbers, return None if either number is None.

		Args:
			a (float | None): First number
			b (float | None): Second number
		Returns:
			float | None: Result of the multiplication or None if either number is None

		Examples:
			>>> None == Utils.safe_multiply_none(None, 2)
			True
			>>> None == Utils.safe_multiply_none(10, None)
			True
			>>> Utils.safe_multiply_none(10, 2)
			20
			>>> Utils.safe_multiply_none(-10, 2)
			-20
		"""
		return a * b if a is not None and b is not None else None

	@staticmethod
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def convert_to_class_indices(y: NDArray[np.intc | np.single] | list[NDArray[np.intc | np.single]]) -> NDArray[Any]:
		""" Convert array from one-hot encoded format to class indices.
		If the input is already class indices, it returns the same array.

		Args:
			y (NDArray[intc | single] | list[NDArray[intc | single]]): Input array (either one-hot encoded or class indices)
		Returns:
			NDArray[Any]: Array of class indices: [[0, 0, 1, 0], [1, 0, 0, 0]] -> [2, 0]

		Examples:
			>>> Utils.convert_to_class_indices(np.array([[0, 0, 1, 0], [1, 0, 0, 0]])).tolist()
			[2, 0]
			>>> Utils.convert_to_class_indices(np.array([2, 0, 1])).tolist()
			[2, 0, 1]
			>>> Utils.convert_to_class_indices(np.array([[1], [0]])).tolist()
			[[1], [0]]
			>>> Utils.convert_to_class_indices(np.array([])).tolist()
			[]
		"""
		y = np.array(y)
		if y.ndim > 1 and y.shape[1] > 1:
			return np.argmax(y, axis=1)
		return y

	@staticmethod
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def convert_to_one_hot(
		y: NDArray[np.intc | np.single] | list[NDArray[np.intc | np.single]], num_classes: int
	) -> NDArray[Any]:
		""" Convert array from class indices to one-hot encoded format.
		If the input is already one-hot encoded, it returns the same array.

		Args:
			y (NDArray[intc|single] | list[NDArray[intc|single]]): Input array (either class indices or one-hot encoded)
			num_classes (int): Total number of classes
		Returns:
			NDArray[Any]: One-hot encoded array: [2, 0] -> [[0, 0, 1, 0], [1, 0, 0, 0]]

		Examples:
			>>> Utils.convert_to_one_hot(np.array([2, 0]), 4).tolist()
			[[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0]]
			>>> Utils.convert_to_one_hot(np.array([[0, 0, 1, 0], [1, 0, 0, 0]]), 4).tolist()
			[[0, 0, 1, 0], [1, 0, 0, 0]]
			>>> Utils.convert_to_one_hot(np.array([0, 1, 2]), 3).shape
			(3, 3)
			>>> Utils.convert_to_one_hot(np.array([]), 3)
			array([], shape=(0, 3), dtype=float32)

			>>> array = np.array([[0.1, 0.9], [0.2, 0.8]])
			>>> array = Utils.convert_to_class_indices(array)
			>>> array = Utils.convert_to_one_hot(array, 2)
			>>> array.tolist()
			[[0.0, 1.0], [0.0, 1.0]]
		"""
		y = np.array(y)
		if y.ndim == 1 or y.shape[1] != num_classes:

			# Get the number of samples and create a one-hot encoded array
			n_samples: int = len(y)
			one_hot: NDArray[np.float32] = np.zeros((n_samples, num_classes), dtype=np.float32)
			if n_samples > 0:
				# Create a one-hot encoding by setting specific positions to 1.0:
				# - np.arange(n_samples) creates an array [0, 1, 2, ..., n_samples-1] for row indices
				# - y.astype(int) contains the class indices that determine which column gets the 1.0
				# - Together they form coordinate pairs (row_idx, class_idx) where we set values to 1.0
				row_indices: NDArray[np.intc] = np.arange(n_samples)
				one_hot[row_indices, y.astype(int)] = 1.0
			return one_hot
		return y

	@staticmethod
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def get_roc_curve_and_auc(
		y_true: NDArray[np.intc | np.single],
		y_pred: NDArray[np.single]
	) -> tuple[float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
		""" Calculate ROC curve and AUC score.

		Args:
			y_true (NDArray[intc | single]): True class labels (either one-hot encoded or class indices)
			y_pred (NDArray[single]): Predicted probabilities (must be probability scores, not class indices)
		Returns:
			tuple[float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
				Tuple containing AUC score, False Positive Rate, True Positive Rate, and Thresholds

		Examples:
			>>> # Binary classification example
			>>> y_true = np.array([0.0, 1.0, 0.0, 1.0, 0.0])
			>>> y_pred = np.array([[0.2, 0.8], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]])
			>>> auc_value, fpr, tpr, thresholds = Utils.get_roc_curve_and_auc(y_true, y_pred)
			>>> round(auc_value, 2)
			0.92
			>>> [round(x, 2) for x in fpr.tolist()]
			[0.0, 0.0, 0.33, 0.67, 1.0]
			>>> [round(x, 2) for x in tpr.tolist()]
			[0.0, 0.5, 1.0, 1.0, 1.0]
			>>> [round(x, 2) for x in thresholds.tolist()]
			[inf, 0.9, 0.8, 0.3, 0.2]
		"""
		# For predictions, assert they are probabilities (one-hot encoded)
		assert y_pred.ndim > 1 and y_pred.shape[1] > 1, "Predictions must be probability scores in one-hot format"
		pred_probs: NDArray[np.single] = y_pred[:, 1]  # Take probability of positive class only

		# Calculate ROC curve and AUC score using probabilities
		with Muffle(mute_stderr=True):  # Suppress "UndefinedMetricWarning: No positive samples in y_true [...]"

			# Import functions
			try:
				from sklearn.metrics import roc_auc_score, roc_curve
			except ImportError as e:
				raise ImportError("scikit-learn is required for ROC curve calculation. Install with 'pip install scikit-learn'") from e

			# Convert y_true to class indices for both functions
			y_true_indices: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)

			# Calculate AUC score directly using roc_auc_score
			auc_value: float = float(roc_auc_score(y_true_indices, pred_probs))

			# Calculate ROC curve points
			results: tuple[Any, Any, Any] = roc_curve(y_true_indices, pred_probs, drop_intermediate=False)
			fpr: NDArray[np.single] = results[0]
			tpr: NDArray[np.single] = results[1]
			thresholds: NDArray[np.single] = results[2]

		return auc_value, fpr, tpr, thresholds

	@staticmethod
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def get_pr_curve_and_auc(
		y_true: NDArray[np.intc | np.single],
		y_pred: NDArray[np.single],
		negative: bool = False
	) -> tuple[float, float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
		""" Calculate Precision-Recall Curve (or Negative Precision-Recall Curve) and AUC score.

		Args:
			y_true (NDArray[intc | single]): True class labels (either one-hot encoded or class indices)
			y_pred (NDArray[single]): Predicted probabilities (must be probability scores, not class indices)
			negative (bool): Whether to calculate the negative Precision-Recall Curve
		Returns:
			tuple[float, float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
				Tuple containing either:
				- AUC score, Average Precision, Precision, Recall, and Thresholds
				- AUC score, Average Precision, Negative Predictive Value, Specificity, and Thresholds for the negative class

		Examples:
			>>> # Binary classification example
			>>> y_true = np.array([0.0, 1.0, 0.0, 1.0, 0.0])
			>>> y_pred = np.array([[0.2, 0.8], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]])
			>>> auc_value, average_precision, precision, recall, thresholds = Utils.get_pr_curve_and_auc(y_true, y_pred)
			>>> round(auc_value, 2)
			0.92
			>>> round(average_precision, 2)
			0.83
			>>> [round(x, 2) for x in precision.tolist()]
			[0.4, 0.5, 0.67, 1.0, 1.0]
			>>> [round(x, 2) for x in recall.tolist()]
			[1.0, 1.0, 1.0, 0.5, 0.0]
			>>> [round(x, 2) for x in thresholds.tolist()]
			[0.2, 0.3, 0.8, 0.9]
		"""
		# For predictions, assert they are probabilities (one-hot encoded)
		assert y_pred.ndim > 1 and y_pred.shape[1] > 1, "Predictions must be probability scores in one-hot format"
		pred_probs: NDArray[np.single] = y_pred[:, 1] if not negative else y_pred[:, 0]

		# Calculate Precision-Recall Curve and AUC score using probabilities
		with Muffle(mute_stderr=True):  # Suppress "UndefinedMetricWarning: No positive samples in y_true [...]"

			# Import functions
			try:
				from sklearn.metrics import auc, average_precision_score, precision_recall_curve
			except ImportError as e:
				raise ImportError("scikit-learn is required for PR Curve calculation. Install with 'pip install scikit-learn'") from e

			# Convert y_true to class indices for both functions
			y_true_indices: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)

			results: tuple[Any, Any, Any] = precision_recall_curve(
				y_true_indices,
				pred_probs,
				pos_label=1 if not negative else 0
			)
			precision: NDArray[np.single] = results[0]
			recall: NDArray[np.single] = results[1]
			thresholds: NDArray[np.single] = results[2]
			auc_value: float = float(auc(recall, precision))
			average_precision: float = float(average_precision_score(y_true_indices, pred_probs))
		return auc_value, average_precision, precision, recall, thresholds
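To see these helpers working together end to end, here is a small sketch (the toy arrays are invented for illustration; it needs numpy and scikit-learn, matching the guarded imports above, and the expected values follow from this file's own doctests):

    import numpy as np
    from stouputils.data_science.utils import Utils

    y_true = np.array([0, 1, 0, 1, 0])             # class indices
    one_hot = Utils.convert_to_one_hot(y_true, 2)  # shape (5, 2), float32
    # Round-trip: one-hot back to indices recovers the original labels
    assert Utils.convert_to_class_indices(one_hot).tolist() == y_true.tolist()

    y_pred = np.array([[0.2, 0.8], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]])
    auc_value, fpr, tpr, thresholds = Utils.get_roc_curve_and_auc(y_true, y_pred)
    pr_auc, avg_prec, precision, recall, pr_thresholds = Utils.get_pr_curve_and_auc(y_true, y_pred)
    print(round(auc_value, 2), round(avg_prec, 2))  # 0.92 0.83, per the doctests above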