stouputils 1.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. stouputils/__init__.py +40 -0
  2. stouputils/__main__.py +86 -0
  3. stouputils/_deprecated.py +37 -0
  4. stouputils/all_doctests.py +160 -0
  5. stouputils/applications/__init__.py +22 -0
  6. stouputils/applications/automatic_docs.py +634 -0
  7. stouputils/applications/upscaler/__init__.py +39 -0
  8. stouputils/applications/upscaler/config.py +128 -0
  9. stouputils/applications/upscaler/image.py +247 -0
  10. stouputils/applications/upscaler/video.py +287 -0
  11. stouputils/archive.py +344 -0
  12. stouputils/backup.py +488 -0
  13. stouputils/collections.py +244 -0
  14. stouputils/continuous_delivery/__init__.py +27 -0
  15. stouputils/continuous_delivery/cd_utils.py +243 -0
  16. stouputils/continuous_delivery/github.py +522 -0
  17. stouputils/continuous_delivery/pypi.py +130 -0
  18. stouputils/continuous_delivery/pyproject.py +147 -0
  19. stouputils/continuous_delivery/stubs.py +86 -0
  20. stouputils/ctx.py +408 -0
  21. stouputils/data_science/config/get.py +51 -0
  22. stouputils/data_science/config/set.py +125 -0
  23. stouputils/data_science/data_processing/image/__init__.py +66 -0
  24. stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
  25. stouputils/data_science/data_processing/image/axis_flip.py +58 -0
  26. stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
  27. stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
  28. stouputils/data_science/data_processing/image/blur.py +59 -0
  29. stouputils/data_science/data_processing/image/brightness.py +54 -0
  30. stouputils/data_science/data_processing/image/canny.py +110 -0
  31. stouputils/data_science/data_processing/image/clahe.py +92 -0
  32. stouputils/data_science/data_processing/image/common.py +30 -0
  33. stouputils/data_science/data_processing/image/contrast.py +53 -0
  34. stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
  35. stouputils/data_science/data_processing/image/denoise.py +378 -0
  36. stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
  37. stouputils/data_science/data_processing/image/invert.py +64 -0
  38. stouputils/data_science/data_processing/image/laplacian.py +60 -0
  39. stouputils/data_science/data_processing/image/median_blur.py +52 -0
  40. stouputils/data_science/data_processing/image/noise.py +59 -0
  41. stouputils/data_science/data_processing/image/normalize.py +65 -0
  42. stouputils/data_science/data_processing/image/random_erase.py +66 -0
  43. stouputils/data_science/data_processing/image/resize.py +69 -0
  44. stouputils/data_science/data_processing/image/rotation.py +80 -0
  45. stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
  46. stouputils/data_science/data_processing/image/sharpening.py +55 -0
  47. stouputils/data_science/data_processing/image/shearing.py +64 -0
  48. stouputils/data_science/data_processing/image/threshold.py +64 -0
  49. stouputils/data_science/data_processing/image/translation.py +71 -0
  50. stouputils/data_science/data_processing/image/zoom.py +83 -0
  51. stouputils/data_science/data_processing/image_augmentation.py +118 -0
  52. stouputils/data_science/data_processing/image_preprocess.py +183 -0
  53. stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
  54. stouputils/data_science/data_processing/technique.py +481 -0
  55. stouputils/data_science/dataset/__init__.py +45 -0
  56. stouputils/data_science/dataset/dataset.py +292 -0
  57. stouputils/data_science/dataset/dataset_loader.py +135 -0
  58. stouputils/data_science/dataset/grouping_strategy.py +296 -0
  59. stouputils/data_science/dataset/image_loader.py +100 -0
  60. stouputils/data_science/dataset/xy_tuple.py +696 -0
  61. stouputils/data_science/metric_dictionnary.py +106 -0
  62. stouputils/data_science/metric_utils.py +847 -0
  63. stouputils/data_science/mlflow_utils.py +206 -0
  64. stouputils/data_science/models/abstract_model.py +149 -0
  65. stouputils/data_science/models/all.py +85 -0
  66. stouputils/data_science/models/base_keras.py +765 -0
  67. stouputils/data_science/models/keras/all.py +38 -0
  68. stouputils/data_science/models/keras/convnext.py +62 -0
  69. stouputils/data_science/models/keras/densenet.py +50 -0
  70. stouputils/data_science/models/keras/efficientnet.py +60 -0
  71. stouputils/data_science/models/keras/mobilenet.py +56 -0
  72. stouputils/data_science/models/keras/resnet.py +52 -0
  73. stouputils/data_science/models/keras/squeezenet.py +233 -0
  74. stouputils/data_science/models/keras/vgg.py +42 -0
  75. stouputils/data_science/models/keras/xception.py +38 -0
  76. stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
  77. stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
  78. stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
  79. stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
  80. stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
  81. stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
  82. stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
  83. stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
  84. stouputils/data_science/models/keras_utils/visualizations.py +416 -0
  85. stouputils/data_science/models/model_interface.py +939 -0
  86. stouputils/data_science/models/sandbox.py +116 -0
  87. stouputils/data_science/range_tuple.py +234 -0
  88. stouputils/data_science/scripts/augment_dataset.py +77 -0
  89. stouputils/data_science/scripts/exhaustive_process.py +133 -0
  90. stouputils/data_science/scripts/preprocess_dataset.py +70 -0
  91. stouputils/data_science/scripts/routine.py +168 -0
  92. stouputils/data_science/utils.py +285 -0
  93. stouputils/decorators.py +605 -0
  94. stouputils/image.py +441 -0
  95. stouputils/installer/__init__.py +18 -0
  96. stouputils/installer/common.py +67 -0
  97. stouputils/installer/downloader.py +101 -0
  98. stouputils/installer/linux.py +144 -0
  99. stouputils/installer/main.py +223 -0
  100. stouputils/installer/windows.py +136 -0
  101. stouputils/io.py +486 -0
  102. stouputils/parallel.py +483 -0
  103. stouputils/print.py +482 -0
  104. stouputils/py.typed +1 -0
  105. stouputils/stouputils/__init__.pyi +15 -0
  106. stouputils/stouputils/_deprecated.pyi +12 -0
  107. stouputils/stouputils/all_doctests.pyi +46 -0
  108. stouputils/stouputils/applications/__init__.pyi +2 -0
  109. stouputils/stouputils/applications/automatic_docs.pyi +106 -0
  110. stouputils/stouputils/applications/upscaler/__init__.pyi +3 -0
  111. stouputils/stouputils/applications/upscaler/config.pyi +18 -0
  112. stouputils/stouputils/applications/upscaler/image.pyi +109 -0
  113. stouputils/stouputils/applications/upscaler/video.pyi +60 -0
  114. stouputils/stouputils/archive.pyi +67 -0
  115. stouputils/stouputils/backup.pyi +109 -0
  116. stouputils/stouputils/collections.pyi +86 -0
  117. stouputils/stouputils/continuous_delivery/__init__.pyi +5 -0
  118. stouputils/stouputils/continuous_delivery/cd_utils.pyi +129 -0
  119. stouputils/stouputils/continuous_delivery/github.pyi +162 -0
  120. stouputils/stouputils/continuous_delivery/pypi.pyi +53 -0
  121. stouputils/stouputils/continuous_delivery/pyproject.pyi +67 -0
  122. stouputils/stouputils/continuous_delivery/stubs.pyi +39 -0
  123. stouputils/stouputils/ctx.pyi +211 -0
  124. stouputils/stouputils/decorators.pyi +252 -0
  125. stouputils/stouputils/image.pyi +172 -0
  126. stouputils/stouputils/installer/__init__.pyi +5 -0
  127. stouputils/stouputils/installer/common.pyi +39 -0
  128. stouputils/stouputils/installer/downloader.pyi +24 -0
  129. stouputils/stouputils/installer/linux.pyi +39 -0
  130. stouputils/stouputils/installer/main.pyi +57 -0
  131. stouputils/stouputils/installer/windows.pyi +31 -0
  132. stouputils/stouputils/io.pyi +213 -0
  133. stouputils/stouputils/parallel.pyi +216 -0
  134. stouputils/stouputils/print.pyi +136 -0
  135. stouputils/stouputils/version_pkg.pyi +15 -0
  136. stouputils/version_pkg.py +189 -0
  137. stouputils-1.14.0.dist-info/METADATA +178 -0
  138. stouputils-1.14.0.dist-info/RECORD +140 -0
  139. stouputils-1.14.0.dist-info/WHEEL +4 -0
  140. stouputils-1.14.0.dist-info/entry_points.txt +3 -0
stouputils/data_science/scripts/routine.py
@@ -0,0 +1,168 @@
+
+ # Imports
+ import argparse
+ import itertools
+ from typing import Any, Literal
+
+ from ...decorators import handle_error, measure_time
+ from ...print import info, error, progress
+ from ...io import clean_path
+ from ..config.get import DataScienceConfig
+ from ..dataset import LOWER_GS, Dataset, DatasetLoader, GroupingStrategy, XyTuple
+ from ..models.all import ALL_MODELS, CLASS_MAP, ModelInterface
+
+ # Constants
+ MODEL_HELP: str = "Model(s) name or alias to use"
+ INPUT_HELP: str = "Path to the dataset, e.g. 'data/aug_hip_implant'"
+ BASED_OF_HELP: str = "Path to the base dataset for filtering train/test, e.g. 'data/hip_implant'"
+ TRANSFER_LEARNING_HELP: str = "Transfer learning source (imagenet, None, 'data/dataset_folder')"
+ GROUPING_HELP: str = "Grouping strategy for the dataset"
+ K_FOLD_HELP: str = "Number of folds for k-fold cross validation (0 = no k-fold, negative = LeavePOut)"
+ GRID_SEARCH_HELP: str = "Whether grid search should be performed on hyperparameters"
+ VERBOSE_HELP: str = "Verbosity level sent to functions"
+ PARSER_DESCRIPTION: str = "Command-line interface for training and evaluating machine learning models."
+
+
+ # Main function
+ @measure_time(printer=info, message="Total execution time of the script")
+ @handle_error(exceptions=(KeyboardInterrupt, Exception), error_log=DataScienceConfig.ERROR_LOG)
+ def routine(
+     default_input: str = f"{DataScienceConfig.DATA_FOLDER}/aug_hip_implant_preprocessed",
+     default_based_of: str = "auto",
+     default_transfer_learning: str = "imagenet",
+     default_grouping_strategy: str = "none",
+     default_kfold: int = 0,
+     default_verbose: int = 100,
+
+     loading_type: Literal["image"] = "image",
+     grid_search_param_grid: dict[str, list[Any]] | None = None,
+     add_to_train_only: list[str] | None = None,
+ ) -> None:
+     """ Main function of the script for training and evaluating machine learning models.
+
+     This function handles the entire workflow for model training and evaluation, including:
+
+     - Parsing command-line arguments (default values are set in the function signature)
+     - Loading and preparing datasets with configurable grouping strategies
+     - Supporting transfer learning from various sources
+     - Enabling K-fold cross-validation, LeavePOut or LeaveOneOut
+     - Providing grid search capabilities for hyperparameter optimization
+     - Incorporating additional training data from specified paths
+
+     Args:
+         default_input (str): Default path to the dataset to use.
+         default_based_of (str): Default path to the base dataset for filtering train/test data.
+         default_transfer_learning (str): Default transfer learning source.
+         default_grouping_strategy (str): Default grouping strategy for the dataset.
+         default_kfold (int): Default number of folds for k-fold cross validation.
+         default_verbose (int): Default verbosity level.
+         loading_type (Literal["image"]): Type of data to load, currently only supports "image".
+         grid_search_param_grid (dict[str, list[Any]] | None): Parameter grid for hyperparameter optimization.
+         add_to_train_only (list[str] | None): List of paths to additional training datasets.
+
+     Returns:
+         None: This function does not return anything.
+     """
+     if grid_search_param_grid is None:
+         grid_search_param_grid = {"batch_size": [8, 16, 32, 64]}
+     if add_to_train_only is None:
+         add_to_train_only = []
+
+     info("Starting the script...")
+
+     # Parse the arguments
+     parser = argparse.ArgumentParser(description=PARSER_DESCRIPTION)
+     parser.add_argument("--model", type=str, choices=ALL_MODELS, required=True, help=MODEL_HELP)
+     parser.add_argument("--input", type=str, default=default_input, help=INPUT_HELP)
+     parser.add_argument("--based_of", type=str, default=default_based_of, help=BASED_OF_HELP)
+     parser.add_argument("--transfer_learning", type=str, default=default_transfer_learning, help=TRANSFER_LEARNING_HELP)
+     parser.add_argument("--grouping_strategy", type=str, default=default_grouping_strategy, choices=LOWER_GS, help=GROUPING_HELP)
+     parser.add_argument("--kfold", type=int, default=default_kfold, help=K_FOLD_HELP)
+     parser.add_argument("--grid_search", action="store_true", help=GRID_SEARCH_HELP)
+     parser.add_argument("--verbose", type=int, default=default_verbose, help=VERBOSE_HELP)
+     args: argparse.Namespace = parser.parse_args()
+     model: str = args.model.lower()
+     kfold: int = args.kfold
+     input_path: str = clean_path(args.input, trailing_slash=False)
+     based_of: str = clean_path(args.based_of, trailing_slash=False)
+     transfer_learning: str = clean_path(args.transfer_learning, trailing_slash=False)
+     verbose: int = args.verbose
+     grouping_strategy: str = args.grouping_strategy
+     grid_search: bool = args.grid_search
+
+     # If based_of is "auto", set it to the input path without the "aug" prefix
+     if based_of == "auto":
+         prefix: str = "/" + DataScienceConfig.AUGMENTED_DIRECTORY_PREFIX
+         if prefix in input_path:
+             based_of = input_path.replace(prefix, "/")
+         else:
+             based_of = ""
+
+     # Load the dataset
+     kwargs: dict[str, Any] = {}
+     if grouping_strategy == "concatenate":
+         kwargs["color_mode"] = "grayscale"
+     dataset: Dataset = DatasetLoader.from_path(
+         path=input_path,
+         loading_type=loading_type,
+         seed=DataScienceConfig.SEED,
+         test_size=DataScienceConfig.TEST_SIZE,
+         grouping_strategy=next(x for x in GroupingStrategy if x.name.lower() == grouping_strategy),
+         based_of=based_of,
+         **kwargs
+     )
+     info(dataset)
+
+     # Define parameter combinations
+     param_combinations: list[dict[str, Any]] = [{}]  # Default empty params
+     if grid_search:
+
+         # Generate all parameter combinations
+         param_combinations.clear()
+         for values in itertools.product(*grid_search_param_grid.values()):
+             param_combinations.append(dict(zip(grid_search_param_grid.keys(), values, strict=False)))
+
+     # Load additional training data from provided paths
+     additional_training_data: XyTuple = XyTuple.empty()
+     for path in add_to_train_only:
+         try:
+             additional_dataset: Dataset = DatasetLoader.from_path(
+                 path=path,
+                 loading_type=loading_type,
+                 seed=DataScienceConfig.SEED,
+                 test_size=0,  # Use all data for training
+                 **kwargs
+             )
+             additional_training_data += additional_dataset.training_data
+         except Exception as e:
+             error(f"Failed to load additional training data from '{path}': {e}")
+
+     # Prepare the initialization arguments
+     # (num_classes: int, kfold: int = 0, transfer_learning: str = "imagenet", **override_params: Any)
+     initialization_args: dict[str, Any] = {
+
+         # Mandatory arguments
+         "num_classes": dataset.num_classes,
+         "kfold": kfold,
+         "transfer_learning": transfer_learning,
+
+         # Optional arguments (override_params)
+         "additional_training_data": additional_training_data
+     }
+
+     # Collect all class routines that match the model name
+     classes: list[type[ModelInterface]] = [key for key, values in CLASS_MAP.items() if model in values]
+
+     # For each parameter combination
+     for i, params in enumerate(param_combinations):
+         if grid_search:
+             progress(f"Grid search {i+1}/{len(param_combinations)}, Training with parameters:\n{params}")
+             initialization_args["override_params"] = params
+
+         # Launch all class routines
+         for class_to_process in classes:
+             model_instance: ModelInterface = class_to_process(**initialization_args)
+             trained_model: ModelInterface = model_instance.routine_full(dataset, verbose)
+             info(trained_model)
+             del trained_model
+     return
+
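The routine function above is the package's generic training entry point: a project imports it, overrides whichever defaults it needs, and lets argparse pick up the rest from the command line. A minimal sketch of such a wrapper, assuming a hypothetical script name (train.py) and dataset path; only batch_size is known to appear in the default grid, so the example grid sticks to it:

# train.py -- hypothetical wrapper around routine() (not part of the package)
# Invoked as e.g.: python train.py --model resnet --kfold 5 --grid_search
from stouputils.data_science.scripts.routine import routine

if __name__ == "__main__":
    routine(
        default_input="data/aug_example_preprocessed",    # hypothetical path
        default_grouping_strategy="none",
        grid_search_param_grid={"batch_size": [16, 32]},  # only used with --grid_search
    )

With --grid_search passed, the itertools.product loop above would then train one model per combination of the grid values.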
stouputils/data_science/utils.py
@@ -0,0 +1,285 @@
+ """
+ This module contains the Utils class, which provides static methods for common operations.
+
+ This class contains static methods for:
+
+ - Safe division (with 0 as denominator or None)
+ - Safe multiplication (with None)
+ - Converting between one-hot encoding and class indices
+ - Calculating ROC curves and AUC scores
+ """
+ # pyright: reportUnknownMemberType=false
+ # pyright: reportUnknownVariableType=false
+
+ # Imports
+ from typing import Any
+
+ import numpy as np
+ from numpy.typing import NDArray
+
+ from ..ctx import Muffle
+ from ..decorators import handle_error
+ from .config.get import DataScienceConfig
+
+
+ # Class
+ class Utils:
+     """ Utility class providing common operations. """
+
+     @staticmethod
+     def safe_divide_float(a: float, b: float) -> float:
+         """ Safe division of two numbers, returning 0 if the denominator is not strictly positive.
+
+         Args:
+             a (float): First number
+             b (float): Second number
+         Returns:
+             float: Result of the division
+
+         Examples:
+             >>> Utils.safe_divide_float(10, 2)
+             5.0
+             >>> Utils.safe_divide_float(0, 5)
+             0.0
+             >>> Utils.safe_divide_float(10, 0)
+             0
+             >>> Utils.safe_divide_float(-10, 2)
+             -5.0
+         """
+         return a / b if b > 0 else 0
+
+     @staticmethod
+     def safe_divide_none(a: float | None, b: float | None) -> float | None:
+         """ Safe division of two numbers, returning None if either number is None or the denominator is not strictly positive.
+
+         Args:
+             a (float | None): First number
+             b (float | None): Second number
+         Returns:
+             float | None: Result of the division, or None
+
+         Examples:
+             >>> None == Utils.safe_divide_none(None, 2)
+             True
+             >>> None == Utils.safe_divide_none(10, None)
+             True
+             >>> None == Utils.safe_divide_none(10, 0)
+             True
+             >>> Utils.safe_divide_none(10, 2)
+             5.0
+         """
+         return a / b if a is not None and b is not None and b > 0 else None
+
+     @staticmethod
+     def safe_multiply_none(a: float | None, b: float | None) -> float | None:
+         """ Safe multiplication of two numbers, returning None if either number is None.
+
+         Args:
+             a (float | None): First number
+             b (float | None): Second number
+         Returns:
+             float | None: Result of the multiplication, or None if either number is None
+
+         Examples:
+             >>> None == Utils.safe_multiply_none(None, 2)
+             True
+             >>> None == Utils.safe_multiply_none(10, None)
+             True
+             >>> Utils.safe_multiply_none(10, 2)
+             20
+             >>> Utils.safe_multiply_none(-10, 2)
+             -20
+         """
+         return a * b if a is not None and b is not None else None
+
+     @staticmethod
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     def convert_to_class_indices(y: NDArray[np.intc | np.single] | list[NDArray[np.intc | np.single]]) -> NDArray[Any]:
+         """ Convert array from one-hot encoded format to class indices.
+         If the input is already class indices, it returns the same array.
+
+         Args:
+             y (NDArray[intc | single] | list[NDArray[intc | single]]): Input array (either one-hot encoded or class indices)
+         Returns:
+             NDArray[Any]: Array of class indices: [[0, 0, 1, 0], [1, 0, 0, 0]] -> [2, 0]
+
+         Examples:
+             >>> Utils.convert_to_class_indices(np.array([[0, 0, 1, 0], [1, 0, 0, 0]])).tolist()
+             [2, 0]
+             >>> Utils.convert_to_class_indices(np.array([2, 0, 1])).tolist()
+             [2, 0, 1]
+             >>> Utils.convert_to_class_indices(np.array([[1], [0]])).tolist()
+             [[1], [0]]
+             >>> Utils.convert_to_class_indices(np.array([])).tolist()
+             []
+         """
+         y = np.array(y)
+         if y.ndim > 1 and y.shape[1] > 1:
+             return np.argmax(y, axis=1)
+         return y
+
+     @staticmethod
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     def convert_to_one_hot(
+         y: NDArray[np.intc | np.single] | list[NDArray[np.intc | np.single]], num_classes: int
+     ) -> NDArray[Any]:
+         """ Convert array from class indices to one-hot encoded format.
+         If the input is already one-hot encoded, it returns the same array.
+
+         Args:
+             y (NDArray[intc|single] | list[NDArray[intc|single]]): Input array (either class indices or one-hot encoded)
+             num_classes (int): Total number of classes
+         Returns:
+             NDArray[Any]: One-hot encoded array: [2, 0] -> [[0, 0, 1, 0], [1, 0, 0, 0]]
+
+         Examples:
+             >>> Utils.convert_to_one_hot(np.array([2, 0]), 4).tolist()
+             [[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0]]
+             >>> Utils.convert_to_one_hot(np.array([[0, 0, 1, 0], [1, 0, 0, 0]]), 4).tolist()
+             [[0, 0, 1, 0], [1, 0, 0, 0]]
+             >>> Utils.convert_to_one_hot(np.array([0, 1, 2]), 3).shape
+             (3, 3)
+             >>> Utils.convert_to_one_hot(np.array([]), 3)
+             array([], shape=(0, 3), dtype=float32)
+
+             >>> array = np.array([[0.1, 0.9], [0.2, 0.8]])
+             >>> array = Utils.convert_to_class_indices(array)
+             >>> array = Utils.convert_to_one_hot(array, 2)
+             >>> array.tolist()
+             [[0.0, 1.0], [0.0, 1.0]]
+         """
+         y = np.array(y)
+         if y.ndim == 1 or y.shape[1] != num_classes:
+
+             # Get the number of samples and create a one-hot encoded array
+             n_samples: int = len(y)
+             one_hot: NDArray[np.float32] = np.zeros((n_samples, num_classes), dtype=np.float32)
+             if n_samples > 0:
+                 # Create a one-hot encoding by setting specific positions to 1.0:
+                 # - np.arange(n_samples) creates an array [0, 1, 2, ..., n_samples-1] for row indices
+                 # - y.astype(int) contains the class indices that determine which column gets the 1.0
+                 # - Together they form coordinate pairs (row_idx, class_idx) where we set values to 1.0
+                 row_indices: NDArray[np.intc] = np.arange(n_samples)
+                 one_hot[row_indices, y.astype(int)] = 1.0
+             return one_hot
+         return y
+
+     @staticmethod
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     def get_roc_curve_and_auc(
+         y_true: NDArray[np.intc | np.single],
+         y_pred: NDArray[np.single]
+     ) -> tuple[float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
+         """ Calculate ROC curve and AUC score.
+
+         Args:
+             y_true (NDArray[intc | single]): True class labels (either one-hot encoded or class indices)
+             y_pred (NDArray[single]): Predicted probabilities (must be probability scores, not class indices)
+         Returns:
+             tuple[float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
+                 Tuple containing AUC score, False Positive Rate, True Positive Rate, and Thresholds
+
+         Examples:
+             >>> # Binary classification example
+             >>> y_true = np.array([0.0, 1.0, 0.0, 1.0, 0.0])
+             >>> y_pred = np.array([[0.2, 0.8], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]])
+             >>> auc_value, fpr, tpr, thresholds = Utils.get_roc_curve_and_auc(y_true, y_pred)
+             >>> round(auc_value, 2)
+             0.92
+             >>> [round(x, 2) for x in fpr.tolist()]
+             [0.0, 0.0, 0.33, 0.67, 1.0]
+             >>> [round(x, 2) for x in tpr.tolist()]
+             [0.0, 0.5, 1.0, 1.0, 1.0]
+             >>> [round(x, 2) for x in thresholds.tolist()]
+             [inf, 0.9, 0.8, 0.3, 0.2]
+         """
+         # For predictions, assert they are probabilities (one-hot encoded)
+         assert y_pred.ndim > 1 and y_pred.shape[1] > 1, "Predictions must be probability scores in one-hot format"
+         pred_probs: NDArray[np.single] = y_pred[:, 1]  # Take probability of positive class only
+
+         # Calculate ROC curve and AUC score using probabilities
+         with Muffle(mute_stderr=True):  # Suppress "UndefinedMetricWarning: No positive samples in y_true [...]"
+
+             # Import functions
+             try:
+                 from sklearn.metrics import roc_auc_score, roc_curve
+             except ImportError as e:
+                 raise ImportError("scikit-learn is required for ROC curve calculation. Install with 'pip install scikit-learn'") from e
+
+             # Convert y_true to class indices for both functions
+             y_true_indices: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)
+
+             # Calculate AUC score directly using roc_auc_score
+             auc_value: float = float(roc_auc_score(y_true_indices, pred_probs))
+
+             # Calculate ROC curve points
+             results: tuple[Any, Any, Any] = roc_curve(y_true_indices, pred_probs, drop_intermediate=False)
+             fpr: NDArray[np.single] = results[0]
+             tpr: NDArray[np.single] = results[1]
+             thresholds: NDArray[np.single] = results[2]
+
+         return auc_value, fpr, tpr, thresholds
+
+     @staticmethod
+     @handle_error(error_log=DataScienceConfig.ERROR_LOG)
+     def get_pr_curve_and_auc(
+         y_true: NDArray[np.intc | np.single],
+         y_pred: NDArray[np.single],
+         negative: bool = False
+     ) -> tuple[float, float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
+         """ Calculate Precision-Recall Curve (or Negative Precision-Recall Curve) and AUC score.
+
+         Args:
+             y_true (NDArray[intc | single]): True class labels (either one-hot encoded or class indices)
+             y_pred (NDArray[single]): Predicted probabilities (must be probability scores, not class indices)
+             negative (bool): Whether to calculate the negative Precision-Recall Curve
+         Returns:
+             tuple[float, float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
+                 Tuple containing either:
+
+                 - AUC score, Average Precision, Precision, Recall, and Thresholds
+                 - AUC score, Average Precision, Negative Predictive Value, Specificity, and Thresholds for the negative class
+
+         Examples:
+             >>> # Binary classification example
+             >>> y_true = np.array([0.0, 1.0, 0.0, 1.0, 0.0])
+             >>> y_pred = np.array([[0.2, 0.8], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]])
+             >>> auc_value, average_precision, precision, recall, thresholds = Utils.get_pr_curve_and_auc(y_true, y_pred)
+             >>> round(auc_value, 2)
+             0.92
+             >>> round(average_precision, 2)
+             0.83
+             >>> [round(x, 2) for x in precision.tolist()]
+             [0.4, 0.5, 0.67, 1.0, 1.0]
+             >>> [round(x, 2) for x in recall.tolist()]
+             [1.0, 1.0, 1.0, 0.5, 0.0]
+             >>> [round(x, 2) for x in thresholds.tolist()]
+             [0.2, 0.3, 0.8, 0.9]
+         """
+         # For predictions, assert they are probabilities (one-hot encoded)
+         assert y_pred.ndim > 1 and y_pred.shape[1] > 1, "Predictions must be probability scores in one-hot format"
+         pred_probs: NDArray[np.single] = y_pred[:, 1] if not negative else y_pred[:, 0]
+
+         # Calculate Precision-Recall Curve and AUC score using probabilities
+         with Muffle(mute_stderr=True):  # Suppress "UndefinedMetricWarning: No positive samples in y_true [...]"
+
+             # Import functions
+             try:
+                 from sklearn.metrics import auc, average_precision_score, precision_recall_curve
+             except ImportError as e:
+                 raise ImportError("scikit-learn is required for PR Curve calculation. Install with 'pip install scikit-learn'") from e
+
+             # Convert y_true to class indices for both functions
+             y_true_indices: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)
+
+             results: tuple[Any, Any, Any] = precision_recall_curve(
+                 y_true_indices,
+                 pred_probs,
+                 pos_label=1 if not negative else 0
+             )
+             precision: NDArray[np.single] = results[0]
+             recall: NDArray[np.single] = results[1]
+             thresholds: NDArray[np.single] = results[2]
+             auc_value: float = float(auc(recall, precision))
+             average_precision: float = float(average_precision_score(y_true_indices, pred_probs))
+         return auc_value, average_precision, precision, recall, thresholds
+
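For readers unfamiliar with the indexing trick used in convert_to_one_hot, pairing a row-index array with the class-index array addresses exactly one cell per sample, so the whole conversion is a single vectorized assignment. A standalone sketch of the same idea in plain NumPy, independent of this package:

import numpy as np

y = np.array([2, 0, 1])  # class indices, one per sample
one_hot = np.zeros((len(y), 4), dtype=np.float32)

# Rows [0, 1, 2] paired with columns [2, 0, 1]: one cell set to 1.0 per sample
one_hot[np.arange(len(y)), y] = 1.0

print(one_hot.tolist())
# [[0.0, 0.0, 1.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]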