stouputils 1.14.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140)
  1. stouputils/__init__.py +40 -0
  2. stouputils/__main__.py +86 -0
  3. stouputils/_deprecated.py +37 -0
  4. stouputils/all_doctests.py +160 -0
  5. stouputils/applications/__init__.py +22 -0
  6. stouputils/applications/automatic_docs.py +634 -0
  7. stouputils/applications/upscaler/__init__.py +39 -0
  8. stouputils/applications/upscaler/config.py +128 -0
  9. stouputils/applications/upscaler/image.py +247 -0
  10. stouputils/applications/upscaler/video.py +287 -0
  11. stouputils/archive.py +344 -0
  12. stouputils/backup.py +488 -0
  13. stouputils/collections.py +244 -0
  14. stouputils/continuous_delivery/__init__.py +27 -0
  15. stouputils/continuous_delivery/cd_utils.py +243 -0
  16. stouputils/continuous_delivery/github.py +522 -0
  17. stouputils/continuous_delivery/pypi.py +130 -0
  18. stouputils/continuous_delivery/pyproject.py +147 -0
  19. stouputils/continuous_delivery/stubs.py +86 -0
  20. stouputils/ctx.py +408 -0
  21. stouputils/data_science/config/get.py +51 -0
  22. stouputils/data_science/config/set.py +125 -0
  23. stouputils/data_science/data_processing/image/__init__.py +66 -0
  24. stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
  25. stouputils/data_science/data_processing/image/axis_flip.py +58 -0
  26. stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
  27. stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
  28. stouputils/data_science/data_processing/image/blur.py +59 -0
  29. stouputils/data_science/data_processing/image/brightness.py +54 -0
  30. stouputils/data_science/data_processing/image/canny.py +110 -0
  31. stouputils/data_science/data_processing/image/clahe.py +92 -0
  32. stouputils/data_science/data_processing/image/common.py +30 -0
  33. stouputils/data_science/data_processing/image/contrast.py +53 -0
  34. stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
  35. stouputils/data_science/data_processing/image/denoise.py +378 -0
  36. stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
  37. stouputils/data_science/data_processing/image/invert.py +64 -0
  38. stouputils/data_science/data_processing/image/laplacian.py +60 -0
  39. stouputils/data_science/data_processing/image/median_blur.py +52 -0
  40. stouputils/data_science/data_processing/image/noise.py +59 -0
  41. stouputils/data_science/data_processing/image/normalize.py +65 -0
  42. stouputils/data_science/data_processing/image/random_erase.py +66 -0
  43. stouputils/data_science/data_processing/image/resize.py +69 -0
  44. stouputils/data_science/data_processing/image/rotation.py +80 -0
  45. stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
  46. stouputils/data_science/data_processing/image/sharpening.py +55 -0
  47. stouputils/data_science/data_processing/image/shearing.py +64 -0
  48. stouputils/data_science/data_processing/image/threshold.py +64 -0
  49. stouputils/data_science/data_processing/image/translation.py +71 -0
  50. stouputils/data_science/data_processing/image/zoom.py +83 -0
  51. stouputils/data_science/data_processing/image_augmentation.py +118 -0
  52. stouputils/data_science/data_processing/image_preprocess.py +183 -0
  53. stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
  54. stouputils/data_science/data_processing/technique.py +481 -0
  55. stouputils/data_science/dataset/__init__.py +45 -0
  56. stouputils/data_science/dataset/dataset.py +292 -0
  57. stouputils/data_science/dataset/dataset_loader.py +135 -0
  58. stouputils/data_science/dataset/grouping_strategy.py +296 -0
  59. stouputils/data_science/dataset/image_loader.py +100 -0
  60. stouputils/data_science/dataset/xy_tuple.py +696 -0
  61. stouputils/data_science/metric_dictionnary.py +106 -0
  62. stouputils/data_science/metric_utils.py +847 -0
  63. stouputils/data_science/mlflow_utils.py +206 -0
  64. stouputils/data_science/models/abstract_model.py +149 -0
  65. stouputils/data_science/models/all.py +85 -0
  66. stouputils/data_science/models/base_keras.py +765 -0
  67. stouputils/data_science/models/keras/all.py +38 -0
  68. stouputils/data_science/models/keras/convnext.py +62 -0
  69. stouputils/data_science/models/keras/densenet.py +50 -0
  70. stouputils/data_science/models/keras/efficientnet.py +60 -0
  71. stouputils/data_science/models/keras/mobilenet.py +56 -0
  72. stouputils/data_science/models/keras/resnet.py +52 -0
  73. stouputils/data_science/models/keras/squeezenet.py +233 -0
  74. stouputils/data_science/models/keras/vgg.py +42 -0
  75. stouputils/data_science/models/keras/xception.py +38 -0
  76. stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
  77. stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
  78. stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
  79. stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
  80. stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
  81. stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
  82. stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
  83. stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
  84. stouputils/data_science/models/keras_utils/visualizations.py +416 -0
  85. stouputils/data_science/models/model_interface.py +939 -0
  86. stouputils/data_science/models/sandbox.py +116 -0
  87. stouputils/data_science/range_tuple.py +234 -0
  88. stouputils/data_science/scripts/augment_dataset.py +77 -0
  89. stouputils/data_science/scripts/exhaustive_process.py +133 -0
  90. stouputils/data_science/scripts/preprocess_dataset.py +70 -0
  91. stouputils/data_science/scripts/routine.py +168 -0
  92. stouputils/data_science/utils.py +285 -0
  93. stouputils/decorators.py +605 -0
  94. stouputils/image.py +441 -0
  95. stouputils/installer/__init__.py +18 -0
  96. stouputils/installer/common.py +67 -0
  97. stouputils/installer/downloader.py +101 -0
  98. stouputils/installer/linux.py +144 -0
  99. stouputils/installer/main.py +223 -0
  100. stouputils/installer/windows.py +136 -0
  101. stouputils/io.py +486 -0
  102. stouputils/parallel.py +483 -0
  103. stouputils/print.py +482 -0
  104. stouputils/py.typed +1 -0
  105. stouputils/stouputils/__init__.pyi +15 -0
  106. stouputils/stouputils/_deprecated.pyi +12 -0
  107. stouputils/stouputils/all_doctests.pyi +46 -0
  108. stouputils/stouputils/applications/__init__.pyi +2 -0
  109. stouputils/stouputils/applications/automatic_docs.pyi +106 -0
  110. stouputils/stouputils/applications/upscaler/__init__.pyi +3 -0
  111. stouputils/stouputils/applications/upscaler/config.pyi +18 -0
  112. stouputils/stouputils/applications/upscaler/image.pyi +109 -0
  113. stouputils/stouputils/applications/upscaler/video.pyi +60 -0
  114. stouputils/stouputils/archive.pyi +67 -0
  115. stouputils/stouputils/backup.pyi +109 -0
  116. stouputils/stouputils/collections.pyi +86 -0
  117. stouputils/stouputils/continuous_delivery/__init__.pyi +5 -0
  118. stouputils/stouputils/continuous_delivery/cd_utils.pyi +129 -0
  119. stouputils/stouputils/continuous_delivery/github.pyi +162 -0
  120. stouputils/stouputils/continuous_delivery/pypi.pyi +53 -0
  121. stouputils/stouputils/continuous_delivery/pyproject.pyi +67 -0
  122. stouputils/stouputils/continuous_delivery/stubs.pyi +39 -0
  123. stouputils/stouputils/ctx.pyi +211 -0
  124. stouputils/stouputils/decorators.pyi +252 -0
  125. stouputils/stouputils/image.pyi +172 -0
  126. stouputils/stouputils/installer/__init__.pyi +5 -0
  127. stouputils/stouputils/installer/common.pyi +39 -0
  128. stouputils/stouputils/installer/downloader.pyi +24 -0
  129. stouputils/stouputils/installer/linux.pyi +39 -0
  130. stouputils/stouputils/installer/main.pyi +57 -0
  131. stouputils/stouputils/installer/windows.pyi +31 -0
  132. stouputils/stouputils/io.pyi +213 -0
  133. stouputils/stouputils/parallel.pyi +216 -0
  134. stouputils/stouputils/print.pyi +136 -0
  135. stouputils/stouputils/version_pkg.pyi +15 -0
  136. stouputils/version_pkg.py +189 -0
  137. stouputils-1.14.0.dist-info/METADATA +178 -0
  138. stouputils-1.14.0.dist-info/RECORD +140 -0
  139. stouputils-1.14.0.dist-info/WHEEL +4 -0
  140. stouputils-1.14.0.dist-info/entry_points.txt +3 -0
stouputils/data_science/metric_utils.py
@@ -0,0 +1,847 @@
1
+ """
2
+ This module contains the MetricUtils class, which provides static methods for
3
+ calculating various metrics for machine learning tasks.
4
+
5
+ This class contains static methods for:
6
+
7
+ - Calculating various metrics (accuracy, precision, recall, etc.)
8
+ - Computing confusion matrix and related metrics
9
+ - Generating ROC curves and finding optimal thresholds
10
+ - Calculating F-beta scores
11
+
12
+ The metrics are calculated based on the predictions made by a model and the true labels from a dataset.
13
+ The class supports both binary and multiclass classification tasks.
14
+ """
15
+ # pyright: reportUnknownMemberType=false
16
+ # pyright: reportUnknownVariableType=false
17
+ # pyright: reportMissingTypeStubs=false
18
+
19
+ # Imports
20
+ import os
21
+ from collections.abc import Iterable
22
+ from typing import Any, Literal
23
+
24
+ import mlflow
25
+ import numpy as np
26
+ from ..decorators import handle_error, measure_time
27
+ from ..print import info, warning
28
+ from matplotlib import pyplot as plt
29
+ from numpy.typing import NDArray
30
+ from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, matthews_corrcoef
31
+
32
+ from .config.get import DataScienceConfig
33
+ from .dataset import Dataset
34
+ from .metric_dictionnary import MetricDictionnary
35
+ from .utils import Utils
36
+
37
+
38
+ # Class
39
+ class MetricUtils:
40
+ """ Class containing static methods for calculating metrics. """
41
+
42
+ @staticmethod
43
+ @measure_time(printer=info, message="Execution time of MetricUtils.metrics")
44
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
45
+ def metrics(
46
+ dataset: Dataset,
47
+ predictions: Iterable[Any],
48
+ run_name: str,
49
+ mode: Literal["binary", "multiclass", "none"] = "binary"
50
+ ) -> dict[str, float]:
51
+ """ Method to calculate as many metrics as possible for the given dataset and predictions.
52
+
53
+ Args:
54
+ dataset (Dataset): Dataset containing the true labels
55
+ predictions (Iterable): Predictions made by the model
56
+ run_name (str): Name of the run, used to save the ROC curve
57
+ mode (Literal): Mode of the classification, defaults to "binary"
58
+ Returns:
59
+ dict[str, float]: Dictionary containing the calculated metrics
60
+
61
+ Examples:
62
+ >>> # Prepare a test dataset
63
+ >>> from .dataset import XyTuple
64
+ >>> test_data = XyTuple(X=np.array([[1], [2], [3]]), y=np.array([0, 1, 0]))
65
+ >>> dataset = Dataset(training_data=test_data, test_data=test_data, name="osef")
66
+
67
+ >>> # Prepare predictions
68
+ >>> predictions = np.array([[0.9, 0.1], [0.2, 0.8], [0.2, 0.8]])
69
+
70
+ >>> # Calculate metrics
71
+ >>> from stouputils.ctx import Muffle
72
+ >>> with Muffle():
73
+ ... metrics = MetricUtils.metrics(dataset, predictions, run_name="")
74
+
75
+ >>> # Check metrics
76
+ >>> round(float(metrics[MetricDictionnary.ACCURACY]), 2)
77
+ 0.67
78
+ >>> round(float(metrics[MetricDictionnary.PRECISION]), 2)
79
+ 0.5
80
+ >>> round(float(metrics[MetricDictionnary.RECALL]), 2)
81
+ 1.0
82
+ >>> round(float(metrics[MetricDictionnary.F1_SCORE]), 2)
83
+ 0.67
84
+ >>> round(float(metrics[MetricDictionnary.AUC]), 2)
85
+ 0.75
86
+ >>> round(float(metrics[MetricDictionnary.MATTHEWS_CORRELATION_COEFFICIENT]), 2)
87
+ 0.5
88
+ """
89
+ # Initialize metrics
90
+ metrics: dict[str, float] = {}
91
+ y_true: NDArray[np.single] = dataset.test_data.ungrouped_array()[1]
92
+ y_pred: NDArray[np.single] = np.array(predictions)
93
+
94
+ # Binary classification
95
+ if mode == "binary":
96
+ true_classes: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)
97
+ pred_classes: NDArray[np.intc] = Utils.convert_to_class_indices(y_pred)
98
+
99
+ # Get confusion matrix metrics
100
+ conf_metrics: dict[str, float] = MetricUtils.confusion_matrix(
101
+ true_classes=true_classes,
102
+ pred_classes=pred_classes,
103
+ labels=dataset.labels,
104
+ run_name=run_name
105
+ )
106
+ metrics.update(conf_metrics)
107
+
108
+ # Calculate F-beta scores
109
+ precision: float = conf_metrics.get(MetricDictionnary.PRECISION, 0)
110
+ recall: float = conf_metrics.get(MetricDictionnary.RECALL, 0)
111
+ f_metrics: dict[str, float] = MetricUtils.f_scores(precision, recall)
112
+ if f_metrics:
113
+ metrics.update(f_metrics)
114
+
115
+ # Calculate Matthews Correlation Coefficient
116
+ mcc_metric: dict[str, float] = MetricUtils.matthews_correlation(true_classes, pred_classes)
117
+ if mcc_metric:
118
+ metrics.update(mcc_metric)
119
+
120
+ # Calculate and plot (ROC Curve / AUC) and (PR Curve / AUC, and negative one)
121
+ curves_metrics: dict[str, float] = MetricUtils.all_curves(true_classes, y_pred, fold_number=-1, run_name=run_name)
122
+ if curves_metrics:
123
+ metrics.update(curves_metrics)
124
+
125
+ # Multiclass classification
126
+ elif mode == "multiclass":
127
+ pass
128
+
129
+ return metrics
130
+
131
+ @staticmethod
132
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
133
+ def confusion_matrix(
134
+ true_classes: NDArray[np.intc],
135
+ pred_classes: NDArray[np.intc],
136
+ labels: tuple[str, ...],
137
+ run_name: str = ""
138
+ ) -> dict[str, float]:
139
+ """ Calculate metrics based on confusion matrix.
140
+
141
+ Args:
142
+ true_classes (NDArray[np.intc]): True class labels
143
+ pred_classes (NDArray[np.intc]): Predicted class labels
144
+ labels (tuple[str, ...]): Tuple of class labels (strings)
145
+ run_name (str): Name for saving the plot
146
+ Returns:
147
+ dict[str, float]: Dictionary of confusion matrix based metrics
148
+
149
+ Examples:
150
+ >>> # Prepare data
151
+ >>> true_classes = np.array([0, 1, 0])
152
+ >>> pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
153
+ >>> pred_classes = Utils.convert_to_class_indices(pred_probs) # [0, 1, 1]
154
+ >>> labels = ["class_0", "class_1"]
155
+
156
+ >>> # Calculate metrics
157
+ >>> from stouputils.ctx import Muffle
158
+ >>> with Muffle():
159
+ ... metrics = MetricUtils.confusion_matrix(true_classes, pred_classes, labels, run_name="")
160
+
161
+ >>> # Check metrics
162
+ >>> int(metrics[MetricDictionnary.CONFUSION_MATRIX_TN])
163
+ 1
164
+ >>> int(metrics[MetricDictionnary.CONFUSION_MATRIX_FP])
165
+ 1
166
+ >>> int(metrics[MetricDictionnary.CONFUSION_MATRIX_FN])
167
+ 0
168
+ >>> int(metrics[MetricDictionnary.CONFUSION_MATRIX_TP])
169
+ 1
170
+ >>> round(float(metrics[MetricDictionnary.FALSE_POSITIVE_RATE]), 2)
171
+ 0.5
172
+ """
173
+ metrics: dict[str, float] = {}
174
+
175
+ # Get basic confusion matrix values
176
+ conf_matrix: NDArray[np.intc] = confusion_matrix(true_classes, pred_classes)
177
+ TN: int = conf_matrix[0, 0] # True Negatives
178
+ FP: int = conf_matrix[0, 1] # False Positives
179
+ FN: int = conf_matrix[1, 0] # False Negatives
180
+ TP: int = conf_matrix[1, 1] # True Positives
181
+
182
+ # Calculate totals for each category
183
+ total_samples: int = TN + FP + FN + TP
184
+ total_actual_negatives: int = TN + FP
185
+ total_actual_positives: int = TP + FN
186
+ total_predicted_negatives: int = TN + FN
187
+ total_predicted_positives: int = TP + FP
188
+
189
+ # Calculate core metrics
190
+ specificity: float = Utils.safe_divide_float(TN, total_actual_negatives)
191
+ recall: float = Utils.safe_divide_float(TP, total_actual_positives)
192
+ precision: float = Utils.safe_divide_float(TP, total_predicted_positives)
193
+ npv: float = Utils.safe_divide_float(TN, total_predicted_negatives)
194
+ accuracy: float = Utils.safe_divide_float(TN + TP, total_samples)
195
+ balanced_accuracy: float = (specificity + recall) / 2
196
+ f1_score: float = Utils.safe_divide_float(2 * (precision * recall), precision + recall)
197
+ f1_score_negative: float = Utils.safe_divide_float(2 * (specificity * npv), specificity + npv)
198
+
199
+ # Store main metrics using MetricDictionnary
200
+ metrics[MetricDictionnary.SPECIFICITY] = specificity
201
+ metrics[MetricDictionnary.RECALL] = recall
202
+ metrics[MetricDictionnary.PRECISION] = precision
203
+ metrics[MetricDictionnary.NPV] = npv
204
+ metrics[MetricDictionnary.ACCURACY] = accuracy
205
+ metrics[MetricDictionnary.BALANCED_ACCURACY] = balanced_accuracy
206
+ metrics[MetricDictionnary.F1_SCORE] = f1_score
207
+ metrics[MetricDictionnary.F1_SCORE_NEGATIVE] = f1_score_negative
208
+
209
+ # Store confusion matrix values and derived metrics
210
+ metrics[MetricDictionnary.CONFUSION_MATRIX_TN] = TN
211
+ metrics[MetricDictionnary.CONFUSION_MATRIX_FP] = FP
212
+ metrics[MetricDictionnary.CONFUSION_MATRIX_FN] = FN
213
+ metrics[MetricDictionnary.CONFUSION_MATRIX_TP] = TP
214
+ metrics[MetricDictionnary.FALSE_POSITIVE_RATE] = Utils.safe_divide_float(FP, total_actual_negatives)
215
+ metrics[MetricDictionnary.FALSE_NEGATIVE_RATE] = Utils.safe_divide_float(FN, total_actual_positives)
216
+ metrics[MetricDictionnary.FALSE_DISCOVERY_RATE] = Utils.safe_divide_float(FP, total_predicted_positives)
217
+ metrics[MetricDictionnary.FALSE_OMISSION_RATE] = Utils.safe_divide_float(FN, total_predicted_negatives)
218
+ metrics[MetricDictionnary.CRITICAL_SUCCESS_INDEX] = Utils.safe_divide_float(TP, total_actual_positives + FP)
219
+
220
+ # Plot confusion matrix
221
+ if run_name:
222
+ confusion_matrix_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{run_name}_confusion_matrix.png"
223
+ ConfusionMatrixDisplay.from_predictions(true_classes, pred_classes, display_labels=labels)
224
+ plt.savefig(confusion_matrix_path)
225
+ mlflow.log_artifact(confusion_matrix_path)
226
+ os.remove(confusion_matrix_path)
227
+ plt.close()
228
+
229
+ return metrics
230
+
231
+ @staticmethod
232
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
233
+ def f_scores(precision: float, recall: float) -> dict[str, float]:
234
+ """ Calculate F-beta scores for different beta values.
235
+
236
+ Args:
237
+ precision (float): Precision value
238
+ recall (float): Recall value
239
+ Returns:
240
+ dict[str, float]: Dictionary of F-beta scores
241
+
242
+ Examples:
243
+ >>> from stouputils.ctx import Muffle
244
+ >>> with Muffle():
245
+ ... metrics = MetricUtils.f_scores(precision=0.5, recall=1.0)
246
+ >>> [round(float(x), 2) for x in metrics.values()]
247
+ [0.5, 0.51, 0.54, 0.58, 0.62, 0.67, 0.71, 0.75, 0.78, 0.81, 0.83]
248
+
249
+ """
250
+ # Assertions
251
+ assert precision > 0, "Precision cannot be 0"
252
+ assert recall > 0, "Recall cannot be 0"
253
+
254
+ # Calculate F-beta scores
255
+ metrics: dict[str, float] = {}
256
+ betas: Iterable[float] = np.linspace(0, 2, 11)
257
+ for beta in betas:
258
+ divider: float = (beta**2 * precision) + recall
259
+ score: float = Utils.safe_divide_float((1 + beta**2) * precision * recall, divider)
260
+ metrics[MetricDictionnary.F_SCORE_X.replace("X", f"{beta:.1f}")] = score
261
+ if score == 0:
262
+ warning(f"F-score is 0 for beta={beta:.1f}")
263
+ return metrics
264
+
265
+ @staticmethod
266
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
267
+ def matthews_correlation(true_classes: NDArray[np.intc], pred_classes: NDArray[np.intc]) -> dict[str, float]:
268
+ """ Calculate Matthews Correlation Coefficient.
269
+
270
+ Args:
271
+ true_classes (NDArray[np.intc]): True class labels
272
+ pred_classes (NDArray[np.intc]): Predicted class labels
273
+ Returns:
274
+ dict[str, float]: Dictionary containing MCC
275
+
276
+ Examples:
277
+ >>> true_classes = np.array([0, 1, 0])
278
+ >>> pred_classes = np.array([0, 1, 1])
279
+ >>> from stouputils.ctx import Muffle
280
+ >>> with Muffle():
281
+ ... metrics = MetricUtils.matthews_correlation(true_classes, pred_classes)
282
+ >>> float(metrics[MetricDictionnary.MATTHEWS_CORRELATION_COEFFICIENT])
283
+ 0.5
284
+ """
285
+ return {MetricDictionnary.MATTHEWS_CORRELATION_COEFFICIENT: matthews_corrcoef(true_classes, pred_classes)}
286
+
287
+
288
+ @staticmethod
289
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
290
+ def roc_curve_and_auc(
291
+ true_classes: NDArray[np.intc] | NDArray[np.single],
292
+ pred_probs: NDArray[np.single],
293
+ fold_number: int = -1,
294
+ run_name: str = "",
295
+ plot_if_minimum: int = 5
296
+ ) -> dict[str, float]:
297
+ """ Calculate ROC curve and AUC score.
298
+
299
+ Args:
300
+ true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
301
+ pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
302
+ fold_number (int): Fold number, used for naming the plot file, usually
303
+ -1 for final model with test set,
304
+ 0 for final model with validation set,
305
+ >0 for other folds with their validation set, -2 for final model with the train set
306
+ run_name (str): Name for saving the plot
307
+ plot_if_minimum (int): Minimum number of samples required in true_classes to plot the ROC curve
308
+ Returns:
309
+ dict[str, float]: Dictionary containing AUC score and optimal thresholds
310
+
311
+ Examples:
312
+ >>> true_classes = np.array([0, 1, 0])
313
+ >>> pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
314
+ >>> from stouputils.ctx import Muffle
315
+ >>> with Muffle():
316
+ ... metrics = MetricUtils.roc_curve_and_auc(true_classes, pred_probs, run_name="")
317
+
318
+ >>> # Check metrics
319
+ >>> round(float(metrics[MetricDictionnary.AUC]), 2)
320
+ 0.75
321
+ >>> round(float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_YOUDEN]), 2)
322
+ 0.9
323
+ >>> float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_COST])
324
+ inf
325
+ """
326
+ auc_value, fpr, tpr, thresholds = Utils.get_roc_curve_and_auc(true_classes, pred_probs)
327
+ metrics: dict[str, float] = {MetricDictionnary.AUC: auc_value}
328
+
329
+ # Find optimal threshold using different methods
330
+ # 1. Youden's method
331
+ youden_index: NDArray[np.single] = tpr - fpr
332
+ optimal_threshold_youden: float = thresholds[np.argmax(youden_index)]
333
+ metrics[MetricDictionnary.OPTIMAL_THRESHOLD_YOUDEN] = optimal_threshold_youden
334
+
335
+ # 2. Cost-based method
336
+ # Assuming false positives cost twice as much as false negatives
337
+ cost_fp: float = 2
338
+ cost_fn: float = 1
339
+ total_cost: NDArray[np.single] = cost_fp * fpr + cost_fn * (1 - tpr)
340
+ optimal_threshold_cost: float = thresholds[np.argmin(total_cost)]
341
+ metrics[MetricDictionnary.OPTIMAL_THRESHOLD_COST] = optimal_threshold_cost
342
+
343
+ # Plot ROC curve if run_name and minimum number of samples is reached
344
+ if run_name and len(true_classes) >= plot_if_minimum:
345
+ plt.figure(figsize=(12, 6))
346
+ plt.plot(fpr, tpr, "b", label=f"ROC curve (AUC = {auc_value:.2f})")
347
+ plt.plot([0, 1], [0, 1], "r--")
348
+
349
+ # Add optimal threshold points
350
+ youden_idx: int = int(np.argmax(youden_index))
351
+ cost_idx: int = int(np.argmin(total_cost))
352
+
353
+ # Prepare the path
354
+ fold_name: str = ""
355
+ if fold_number > 0:
356
+ fold_name = f"_fold_{fold_number}_val"
357
+ elif fold_number == 0:
358
+ fold_name = "_val"
359
+ elif fold_number == -1:
360
+ fold_name = "_test"
361
+ elif fold_number == -2:
362
+ fold_name = "_train"
363
+ roc_curve_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{run_name}_roc_curve{fold_name}.png"
364
+
365
+ plt.plot(fpr[youden_idx], tpr[youden_idx], 'go', label=f'Youden (t={optimal_threshold_youden:.2f})')
366
+ plt.plot(fpr[cost_idx], tpr[cost_idx], 'mo', label=f'Cost (t={optimal_threshold_cost:.2f})')
367
+
368
+ plt.xlim([-0.01, 1.01])
369
+ plt.ylim([-0.01, 1.01])
370
+ plt.xlabel("False Positive Rate")
371
+ plt.ylabel("True Positive Rate")
372
+ plt.title("Receiver Operating Characteristic (ROC)")
373
+ plt.legend(loc="lower right")
374
+ plt.savefig(roc_curve_path)
375
+ mlflow.log_artifact(roc_curve_path)
376
+ os.remove(roc_curve_path)
377
+ plt.close()
378
+
379
+ return metrics
380
+
381
+ @staticmethod
382
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
383
+ def pr_curve_and_auc(
384
+ true_classes: NDArray[np.intc] | NDArray[np.single],
385
+ pred_probs: NDArray[np.single],
386
+ fold_number: int = -1,
387
+ run_name: str = "",
388
+ plot_if_minimum: int = 5
389
+ ) -> dict[str, float]:
390
+ """ Calculate Precision-Recall curve and AUC score. (and NPV-Specificity curve and AUC)
391
+
392
+ Args:
393
+ true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
394
+ pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
395
+ fold_number (int): Fold number, used for naming the plot file, usually
396
+ -1 for final model with test set,
397
+ 0 for final model with validation set,
398
+ >0 for other folds with their validation set, -2 for final model with the train set
399
+ run_name (str): Name for saving the plot
400
+ plot_if_minimum (int): Minimum number of samples required in true_classes to plot the PR curves
401
+ Returns:
402
+ dict[str, float]: Dictionary containing AUC score and optimal thresholds
403
+
404
+ Examples:
405
+ >>> true_classes = np.array([0, 1, 0])
406
+ >>> pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
407
+ >>> from stouputils.ctx import Muffle
408
+ >>> with Muffle():
409
+ ... metrics = MetricUtils.pr_curve_and_auc(true_classes, pred_probs, run_name="")
410
+
411
+ >>> # Check metrics
412
+ >>> round(float(metrics[MetricDictionnary.AUPRC]), 2)
413
+ 0.75
414
+ >>> round(float(metrics[MetricDictionnary.NEGATIVE_AUPRC]), 2)
415
+ 0.92
416
+ >>> round(float(metrics[MetricDictionnary.PR_AVERAGE]), 2)
417
+ 0.5
418
+ >>> round(float(metrics[MetricDictionnary.PR_AVERAGE_NEGATIVE]), 2)
419
+ 0.33
420
+ >>> round(float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1]), 2)
421
+ 0.9
422
+ >>> round(float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1_NEGATIVE]), 2)
423
+ 0.1
424
+ """
425
+ auc_value, average_precision, precision, recall, thresholds = Utils.get_pr_curve_and_auc(true_classes, pred_probs)
426
+ neg_auc_value, average_precision_neg, npv, specificity, neg_thresholds = (
427
+ Utils.get_pr_curve_and_auc(true_classes, pred_probs, negative=True)
428
+ )
429
+
430
+ # Calculate metrics
431
+ metrics: dict[str, float] = {
432
+ MetricDictionnary.AUPRC: auc_value,
433
+ MetricDictionnary.NEGATIVE_AUPRC: neg_auc_value,
434
+ MetricDictionnary.PR_AVERAGE: average_precision,
435
+ MetricDictionnary.PR_AVERAGE_NEGATIVE: average_precision_neg
436
+ }
437
+
438
+ # Calculate optimal thresholds for both PR curves
439
+ for is_negative in (False, True):
440
+
441
+ # Get the right values based on positive/negative case
442
+ if not is_negative:
443
+ curr_precision = precision
444
+ curr_recall = recall
445
+ curr_thresholds = thresholds
446
+ curr_auc = auc_value
447
+ curr_ap = average_precision
448
+ else:
449
+ curr_precision = npv
450
+ curr_recall = specificity
451
+ curr_thresholds = neg_thresholds
452
+ curr_auc = neg_auc_value
453
+ curr_ap = average_precision_neg
454
+
455
+ # Calculate F-score for each threshold
456
+ fscore: NDArray[np.single] = (2 * curr_precision * curr_recall) / (curr_precision + curr_recall)
457
+ fscore = fscore[~np.isnan(fscore)]
458
+
459
+ # Get optimal threshold (maximum F-score)
460
+ if len(fscore) > 0:
461
+ optimal_idx: int = int(np.argmax(fscore))
462
+ optimal_threshold: float = curr_thresholds[optimal_idx]
463
+ else:
464
+ optimal_idx: int = 0
465
+ optimal_threshold = float('inf')
466
+
467
+ # Store in metrics dictionary
468
+ if not is_negative:
469
+ metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1] = optimal_threshold
470
+ else:
471
+ metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1_NEGATIVE] = optimal_threshold
472
+
473
+ # Plot PR curve if run_name and minimum number of samples is reached
474
+ if run_name and len(true_classes) >= plot_if_minimum:
475
+ label: str = "Precision - Recall" if not is_negative else "Negative Predictive Value - Specificity"
476
+ plt.figure(figsize=(12, 6))
477
+ plt.plot(curr_recall, curr_precision, "b", label=f"{label} curve (AUC = {curr_auc:.2f}, AP = {curr_ap:.2f})")
478
+
479
+ # Prepare the path
480
+ fold_name: str = ""
481
+ if fold_number > 0:
482
+ fold_name = f"_fold_{fold_number}_val"
483
+ elif fold_number == 0:
484
+ fold_name = "_val"
485
+ elif fold_number == -1:
486
+ fold_name = "_test"
487
+ elif fold_number == -2:
488
+ fold_name = "_train"
489
+ pr: str = "pr" if not is_negative else "negative_pr"
490
+ curve_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{run_name}_{pr}_curve{fold_name}.png"
491
+
492
+ plt.plot(
493
+ curr_recall[optimal_idx], curr_precision[optimal_idx], 'go', label=f"Optimal threshold (t={optimal_threshold:.2f})"
494
+ )
495
+
496
+ plt.xlim([-0.01, 1.01])
497
+ plt.ylim([-0.01, 1.01])
498
+ plt.xlabel("Recall" if not is_negative else "Specificity")
499
+ plt.ylabel("Precision" if not is_negative else "Negative Predictive Value")
500
+ plt.title(f"{label} Curve")
501
+ plt.legend(loc="lower right")
502
+ plt.savefig(curve_path)
503
+ mlflow.log_artifact(curve_path)
504
+ os.remove(curve_path)
505
+ plt.close()
506
+
507
+ return metrics
508
+
509
+ @staticmethod
510
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
511
+ def all_curves(
512
+ true_classes: NDArray[np.intc] | NDArray[np.single],
513
+ pred_probs: NDArray[np.single],
514
+ fold_number: int = -1,
515
+ run_name: str = ""
516
+ ) -> dict[str, float]:
517
+ """ Run all X_curve_and_auc functions and return a dictionary of metrics.
518
+
519
+ Args:
520
+ true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
521
+ pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
522
+ fold_number (int): Fold number, used for naming the plot file, usually
523
+ -1 for final model with test set,
524
+ 0 for final model with validation set,
525
+ >0 for other folds with their validation set, -2 for final model with the train set
526
+ run_name (str): Name for saving the plot
527
+ Returns:
528
+ dict[str, float]: Dictionary containing AUC score and optimal thresholds for ROC and PR curves
529
+ """
530
+ metrics: dict[str, float] = {}
531
+ metrics.update(MetricUtils.roc_curve_and_auc(true_classes, pred_probs, fold_number, run_name))
532
+ metrics.update(MetricUtils.pr_curve_and_auc(true_classes, pred_probs, fold_number, run_name))
533
+ return metrics
534
+
535
+
536
+ @staticmethod
537
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
538
+ def plot_metric_curves(
539
+ all_history: list[dict[str, list[float]]],
540
+ metric_name: str,
541
+ run_name: str = ""
542
+ ) -> None:
543
+ """ Plot training and validation curves for a specific metric.
544
+
545
+ Generates two plots for the given metric:
546
+ 1. A combined plot with both training and validation curves
547
+ 2. A validation-only plot
548
+
549
+ The plots show the metric's progression across training epochs for each fold.
550
+ Special formatting distinguishes between folds and curve types:
551
+ - Fold 0 (final model) uses thicker lines (2.0 width vs 1.0)
552
+ - Training curves use solid lines, validation uses dashed
553
+ - Each curve is clearly labeled in the legend
554
+
555
+ The plots are saved to the temp folder and logged to MLflow before cleanup.
556
+
557
+ Args:
558
+ all_history (list[dict[str, list[float]]]): List of history dictionaries for each fold
559
+ metric_name (str): Name of the metric to plot (e.g. "accuracy", "loss")
560
+ run_name (str): Name of the run
561
+
562
+ Examples:
563
+ >>> # Prepare data with 2 folds for instance
564
+ >>> all_history = [
565
+ ... {'loss': [0.1, 0.09, 0.08, 0.07, 0.06], 'val_loss': [0.11, 0.1, 0.09, 0.08, 0.07]},
566
+ ... {'loss': [0.12, 0.11, 0.1, 0.09, 0.08], 'val_loss': [0.13, 0.12, 0.11, 0.1, 0.09]}
567
+ ... ]
568
+ >>> MetricUtils.plot_metric_curves(metric_name="loss", all_history=all_history, run_name="")
569
+ """
570
+ for only_validation in (False, True):
571
+ plt.figure(figsize=(12, 6))
572
+
573
+ # Track max value for y-limit calculation
574
+ max_value: float = 0.0
575
+
576
+ for fold, history in enumerate(all_history):
577
+ # Get validation metrics for this fold
578
+ val_metric: list[float] = history[f"val_{metric_name}"]
579
+ epochs: list[int] = list(range(1, len(val_metric) + 1))
580
+
581
+ # Update max value
582
+ max_value = max(max_value, max(val_metric))
583
+
584
+ # Use thicker line for final model (fold 0)
585
+ alpha: float = 1.0 if fold == 0 else 0.5
586
+ linewidth: float = 2.0 if fold == 0 else 1.0
587
+ label: str = "Final Model" if fold == 0 else f"Fold {fold + 1}"
588
+ val_label: str = f"Validation {metric_name} ({label})"
589
+ plt.plot(epochs, val_metric, linestyle='--', linewidth=linewidth, alpha=alpha, label=val_label)
590
+
591
+ # Add training metrics if showing both curves
592
+ if not only_validation:
593
+ train_metric: list[float] = history[metric_name]
594
+ max_value = max(max_value, max(train_metric))
595
+ train_label: str = f"Training {metric_name} ({label})"
596
+ plt.plot(epochs, train_metric, linestyle='-', linewidth=linewidth, alpha=alpha, label=train_label)
597
+
598
+ # Configure plot formatting
599
+ plt.title(("Training and " if not only_validation else "") + f"Validation {metric_name} Across All Folds")
600
+ plt.xlabel("Epochs")
601
+ plt.ylabel(metric_name)
602
+
603
+ # Set y-limit for loss metric, to avoid seeing non-sense curves
604
+ if metric_name == "loss" and not only_validation:
605
+ plt.ylim(0, min(2.0, max_value * 1.1))
606
+
607
+ # Add legend
608
+ plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
609
+ plt.tight_layout()
610
+
611
+ # Save plot and log to MLflow
612
+ if run_name:
613
+ path: str = ("training_" if not only_validation else "") + f"validation_{metric_name}_curves.png"
614
+ full_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{run_name}_{path}"
615
+ plt.savefig(full_path, bbox_inches='tight')
616
+ mlflow.log_artifact(full_path)
617
+ os.remove(full_path)
618
+ plt.close()
619
+
620
+ @staticmethod
621
+ def plot_every_metric_curves(
622
+ all_history: list[dict[str, list[float]]],
623
+ metrics_names: tuple[str, ...] = (),
624
+ run_name: str = ""
625
+ ) -> None:
626
+ """ Plot and save training and validation curves for each metric.
627
+
628
+ Args:
629
+ all_history (list[dict[str, list[float]]]): List of history dictionaries for each fold
630
+ metrics_names (tuple[str, ...]): Tuple of metric names to plot, defaults to ("loss",)
631
+ run_name (str): Name of the run
632
+
633
+ Examples:
634
+ >>> # Prepare data with 2 folds for instance
635
+ >>> all_history = [
636
+ ... {'loss': [0.1, 0.09], 'val_loss': [0.11, 0.1], "accuracy": [0.9, 0.8], "val_accuracy": [0.8, 0.7]},
637
+ ... {'loss': [0.12, 0.11], 'val_loss': [0.13, 0.12], "accuracy": [0.8, 0.7], "val_accuracy": [0.7, 0.6]}
638
+ ... ]
639
+ >>> MetricUtils.plot_every_metric_curves(all_history, metrics_names=["loss", "accuracy"], run_name="")
640
+ """
641
+ # Set default metrics names to loss
642
+ if not metrics_names:
643
+ metrics_names = ("loss",)
644
+
645
+ # Plot each metric
646
+ for metric_name in metrics_names:
647
+ MetricUtils.plot_metric_curves(all_history, metric_name, run_name)
648
+
649
+ @staticmethod
650
+ @handle_error(error_log=DataScienceConfig.ERROR_LOG)
651
+ def find_best_x_and_plot(
652
+ x_values: list[float],
653
+ y_values: list[float],
654
+ best_idx: int | None = None,
655
+ smoothen: bool = True,
656
+ use_steep: bool = True,
657
+ run_name: str = "",
658
+ x_label: str = "Learning Rate",
659
+ y_label: str = "Loss",
660
+ plot_title: str = "Learning Rate Finder",
661
+ log_x: bool = True,
662
+ y_limits: tuple[float, ...] | None = None
663
+ ) -> float:
664
+ """ Find the best x value (where y is minimized) and plot the curve.
665
+
666
+ Args:
667
+ x_values (list[float]): List of x values (e.g. learning rates)
668
+ y_values (list[float]): List of corresponding y values (e.g. losses)
669
+ best_idx (int | None): Index of the best x value (if None, a robust approach is used)
670
+ smoothen (bool): Whether to apply smoothing to the y values
671
+ use_steep (bool): Whether to use steepest slope strategy to determine best index
672
+ run_name (str): Name of the run for saving the plot
673
+ x_label (str): Label for the x-axis
674
+ y_label (str): Label for the y-axis
675
+ plot_title (str): Title for the plot
676
+ log_x (bool): Whether to use a logarithmic x-axis (e.g. learning rate)
677
+ y_limits (tuple[float, ...] | None): Limit for the y-axis, defaults to None (no limit)
678
+
679
+ Returns:
680
+ float: The best x value found (where y is minimized)
681
+
682
+ This function creates a plot showing the relationship between x and y values
683
+ to help identify the optimal x (where y is minimized). The plot can use a logarithmic
684
+ x-axis for better visualization if desired.
685
+
686
+ The ideal x is typically found where y is still decreasing but before it starts to increase dramatically.
687
+
688
+ Examples:
689
+ >>> x_values = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
690
+ >>> y_values = [0.1, 0.09, 0.07, 0.06, 0.09]
691
+ >>> best_x = MetricUtils.find_best_x_and_plot(x_values, y_values, use_steep=True)
692
+ >>> print(f"Best x: {best_x:.0e}")
693
+ Best x: 1e-03
694
+ >>> best_x = MetricUtils.find_best_x_and_plot(x_values, y_values, use_steep=False)
695
+ >>> print(f"Best x: {best_x:.0e}")
696
+ Best x: 1e-02
697
+ """
698
+ # Validate input data
699
+ assert x_values, "No x data to plot"
700
+ assert y_values, "No y data to plot"
701
+
702
+ # Convert lists to numpy arrays for easier manipulation
703
+ y_array: NDArray[np.single] = np.array(y_values)
704
+ x_array: NDArray[np.single] = np.array(x_values)
705
+
706
+ # Apply smoothing to the y values if requested and if we have enough data points
707
+ if smoothen and len(y_values) > 2:
708
+
709
+ # Calculate appropriate window size based on data length
710
+ window_size: int = min(10, len(y_values) // 3)
711
+ if window_size > 1:
712
+
713
+ # Apply moving average smoothing using convolution
714
+ valid_convolution: NDArray[np.single] = np.convolve(y_array, np.ones(window_size)/window_size, mode="valid")
715
+ y_array = np.copy(y_array)
716
+
717
+ # Calculate start and end indices for replacing values with smoothed ones
718
+ start_idx: int = window_size // 2
719
+ end_idx: int = start_idx + len(valid_convolution)
720
+ y_array[start_idx:end_idx] = valid_convolution
721
+
722
+ # Replace first and last values with original values (to avoid weird effects)
723
+ y_array[0] = y_values[0]
724
+ y_array[-1] = y_values[-1]
725
+
726
+ # 1. Global minimum index between 10% and 90% (excluding borders)
727
+ window_start: int = int(0.1 * len(y_array))
728
+ window_end: int = int(0.9 * len(y_array))
729
+ global_window_min_idx: int = int(np.argmin(y_array[window_start:window_end]))
730
+ global_min_idx: int = global_window_min_idx + window_start
731
+
732
+ # Determine best index
733
+ if best_idx is None:
734
+ if use_steep:
735
+
736
+ # 2. Compute slope in loss vs log(x) for LR sensitivity
737
+ log_x_array: NDArray[np.single] = np.log(x_array)
738
+ slopes: NDArray[np.single] = np.gradient(y_array, log_x_array)
739
+
740
+ # 3. Define proximity window to the left of global minimum
741
+ proximity: int = max(1, len(y_array) // 10)
742
+ window_start = max(0, global_min_idx - proximity)
743
+
744
+ # 4. Find steepest slope within window
745
+ if window_start < global_min_idx:
746
+ local_slopes: NDArray[np.single] = slopes[window_start:global_min_idx]
747
+ relative_idx: int = int(np.argmin(local_slopes))
748
+ steep_idx: int = window_start + relative_idx
749
+ best_idx = steep_idx
750
+ else:
751
+ best_idx = global_min_idx
752
+
753
+ # 5. Top-7 most negative slopes as candidates
754
+ neg_idx: NDArray[np.intp] = np.where(slopes < 0)[0]
755
+ sorted_neg: NDArray[np.intp] = neg_idx[np.argsort(slopes[neg_idx])]
756
+ top7_fave: NDArray[np.intp] = sorted_neg[:7]
757
+
758
+ # Include best_idx and global_min_idx
759
+ candidates: set[int] = set(top7_fave.tolist())
760
+ candidates.add(best_idx)
761
+ distinct_candidates = np.array(sorted(candidates, key=int))
762
+ else:
763
+ best_idx = global_min_idx
764
+
765
+ # Find all local minima
766
+ from scipy.signal import argrelextrema
767
+ local_minima_idx: NDArray[np.intp] = np.array(argrelextrema(y_array, np.less)[0], dtype=np.intp)
768
+ distinct_candidates = np.unique(np.append(local_minima_idx, best_idx))
769
+ else:
770
+ assert 0 <= best_idx < len(x_array), "Best x index is out of bounds"
771
+ distinct_candidates = np.array([best_idx])
772
+
773
+ # Get the best x value and corresponding y value
774
+ best_x: float = x_array[best_idx]
775
+ min_y: float = y_array[best_idx]
776
+
777
+ # Create and save the plot if a run name is provided
778
+ if run_name:
779
+
780
+ # Log metrics to mlflow (e.g. 'learning_rate_finder_learning_rate', 'learning_rate_finder_loss')
781
+ log_title: str = MetricDictionnary.PARAMETER_FINDER.replace("TITLE", plot_title)
782
+ log_x_label: str = log_title.replace("PARAMETER_NAME", x_label)
783
+ log_y_label: str = log_title.replace("PARAMETER_NAME", y_label)
784
+ for i in range(len(x_values)):
785
+ mlflow.log_metric(log_x_label, x_values[i], step=i)
786
+ mlflow.log_metric(log_y_label, y_values[i], step=i)
787
+
788
+ # Prepare the plot
789
+ plt.figure(figsize=(12, 6))
790
+ plt.plot(x_array, y_array, label="Smoothed Curve", linewidth=2)
791
+ plt.plot(x_values, y_values, "-", markersize=3, alpha=0.5, label="Original Curve", color="gray")
792
+
793
+ # Use logarithmic scale for x-axis if requested
794
+ if log_x:
795
+ plt.xscale("log")
796
+
797
+ # Set labels and title
798
+ plt.xlabel(x_label)
799
+ plt.ylabel(y_label)
800
+ plt.title(plot_title)
801
+ plt.grid(True, which="both", ls="--")
802
+
803
+ # Limit y-axis to avoid extreme values
804
+ if y_limits is not None and len(y_limits) == 2:
805
+ min_y_limit: float = max(y_limits[0], min(y_values) * 0.9)
806
+ max_y_limit: float = min(y_limits[1], max(y_values) * 1.1)
807
+ plt.ylim(min_y_limit, max_y_limit)
808
+ plt.legend()
809
+
810
+ # Highlight local minima if any
811
+ if len(distinct_candidates) > 0:
812
+ candidate_xs = [x_array[idx] for idx in distinct_candidates]
813
+ candidate_ys = [y_array[idx] for idx in distinct_candidates]
814
+ candidates_label = "Possible Candidates" if use_steep else "Local Minima"
815
+ plt.scatter(candidate_xs, candidate_ys, color="orange", s=25, zorder=4, label=candidates_label)
816
+
817
+ # Highlight the best point
818
+ plt.scatter([x_array[global_min_idx]], [y_array[global_min_idx]], color="red", s=50, zorder=5, label="Global Minimum")
819
+
820
+ # Format the best x value for display
821
+ best_x_str: str = f"{best_x:.2e}" if best_x < 1e-3 else f"{best_x:.2f}"
822
+
823
+ # Add annotation pointing to the best point
824
+ plt.annotate(
825
+ f"Supposed best {x_label}: {best_x_str}",
826
+ xy=(best_x, min_y),
827
+ xytext=(best_x * 1.5, min_y * 1.1),
828
+ arrowprops={"facecolor":"black", "shrink":0.05, "width":1.2}
829
+ )
830
+ plt.legend()
831
+ plt.tight_layout()
832
+
833
+ # Save the plot to a file and log it to MLflow
834
+ flat_x_label: str = x_label.lower().replace(" ", "_")
835
+ path: str = f"{flat_x_label}_finder.png"
836
+ os.makedirs(DataScienceConfig.TEMP_FOLDER, exist_ok=True)
837
+ full_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{run_name}_{path}"
838
+ plt.savefig(full_path, bbox_inches="tight")
839
+ mlflow.log_artifact(full_path)
840
+ info(f"Saved best x plot to {full_path}")
841
+
842
+ # Clean up the temporary file
843
+ os.remove(full_path)
844
+ plt.close()
845
+
846
+ return best_x
847
+
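
A minimal end-to-end sketch of the call flow demonstrated in the doctests above (binary case). This is a sketch under assumptions: the absolute import paths are inferred from the relative imports in this module, and the expected values come from the doctest output above.

# Minimal usage sketch (assumed import paths; mirrors the MetricUtils.metrics doctest)
import numpy as np

from stouputils.ctx import Muffle
from stouputils.data_science.dataset import Dataset, XyTuple
from stouputils.data_science.metric_dictionnary import MetricDictionnary
from stouputils.data_science.metric_utils import MetricUtils

# Tiny binary dataset: three samples with labels 0/1/0 (same toy data as the doctest)
test_data = XyTuple(X=np.array([[1], [2], [3]]), y=np.array([0, 1, 0]))
dataset = Dataset(training_data=test_data, test_data=test_data, name="example")

# Model output as per-class probabilities (one column per class)
predictions = np.array([[0.9, 0.1], [0.2, 0.8], [0.2, 0.8]])

# Muffle silences the library's own logging; run_name="" skips the MLflow artifact plots
with Muffle():
    metrics = MetricUtils.metrics(dataset, predictions, run_name="", mode="binary")

print(round(float(metrics[MetricDictionnary.ACCURACY]), 2))  # 0.67 per the doctest
print(round(float(metrics[MetricDictionnary.AUC]), 2))       # 0.75 per the doctest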