stouputils-1.14.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- stouputils/__init__.py +40 -0
- stouputils/__main__.py +86 -0
- stouputils/_deprecated.py +37 -0
- stouputils/all_doctests.py +160 -0
- stouputils/applications/__init__.py +22 -0
- stouputils/applications/automatic_docs.py +634 -0
- stouputils/applications/upscaler/__init__.py +39 -0
- stouputils/applications/upscaler/config.py +128 -0
- stouputils/applications/upscaler/image.py +247 -0
- stouputils/applications/upscaler/video.py +287 -0
- stouputils/archive.py +344 -0
- stouputils/backup.py +488 -0
- stouputils/collections.py +244 -0
- stouputils/continuous_delivery/__init__.py +27 -0
- stouputils/continuous_delivery/cd_utils.py +243 -0
- stouputils/continuous_delivery/github.py +522 -0
- stouputils/continuous_delivery/pypi.py +130 -0
- stouputils/continuous_delivery/pyproject.py +147 -0
- stouputils/continuous_delivery/stubs.py +86 -0
- stouputils/ctx.py +408 -0
- stouputils/data_science/config/get.py +51 -0
- stouputils/data_science/config/set.py +125 -0
- stouputils/data_science/data_processing/image/__init__.py +66 -0
- stouputils/data_science/data_processing/image/auto_contrast.py +79 -0
- stouputils/data_science/data_processing/image/axis_flip.py +58 -0
- stouputils/data_science/data_processing/image/bias_field_correction.py +74 -0
- stouputils/data_science/data_processing/image/binary_threshold.py +73 -0
- stouputils/data_science/data_processing/image/blur.py +59 -0
- stouputils/data_science/data_processing/image/brightness.py +54 -0
- stouputils/data_science/data_processing/image/canny.py +110 -0
- stouputils/data_science/data_processing/image/clahe.py +92 -0
- stouputils/data_science/data_processing/image/common.py +30 -0
- stouputils/data_science/data_processing/image/contrast.py +53 -0
- stouputils/data_science/data_processing/image/curvature_flow_filter.py +74 -0
- stouputils/data_science/data_processing/image/denoise.py +378 -0
- stouputils/data_science/data_processing/image/histogram_equalization.py +123 -0
- stouputils/data_science/data_processing/image/invert.py +64 -0
- stouputils/data_science/data_processing/image/laplacian.py +60 -0
- stouputils/data_science/data_processing/image/median_blur.py +52 -0
- stouputils/data_science/data_processing/image/noise.py +59 -0
- stouputils/data_science/data_processing/image/normalize.py +65 -0
- stouputils/data_science/data_processing/image/random_erase.py +66 -0
- stouputils/data_science/data_processing/image/resize.py +69 -0
- stouputils/data_science/data_processing/image/rotation.py +80 -0
- stouputils/data_science/data_processing/image/salt_pepper.py +68 -0
- stouputils/data_science/data_processing/image/sharpening.py +55 -0
- stouputils/data_science/data_processing/image/shearing.py +64 -0
- stouputils/data_science/data_processing/image/threshold.py +64 -0
- stouputils/data_science/data_processing/image/translation.py +71 -0
- stouputils/data_science/data_processing/image/zoom.py +83 -0
- stouputils/data_science/data_processing/image_augmentation.py +118 -0
- stouputils/data_science/data_processing/image_preprocess.py +183 -0
- stouputils/data_science/data_processing/prosthesis_detection.py +359 -0
- stouputils/data_science/data_processing/technique.py +481 -0
- stouputils/data_science/dataset/__init__.py +45 -0
- stouputils/data_science/dataset/dataset.py +292 -0
- stouputils/data_science/dataset/dataset_loader.py +135 -0
- stouputils/data_science/dataset/grouping_strategy.py +296 -0
- stouputils/data_science/dataset/image_loader.py +100 -0
- stouputils/data_science/dataset/xy_tuple.py +696 -0
- stouputils/data_science/metric_dictionnary.py +106 -0
- stouputils/data_science/metric_utils.py +847 -0
- stouputils/data_science/mlflow_utils.py +206 -0
- stouputils/data_science/models/abstract_model.py +149 -0
- stouputils/data_science/models/all.py +85 -0
- stouputils/data_science/models/base_keras.py +765 -0
- stouputils/data_science/models/keras/all.py +38 -0
- stouputils/data_science/models/keras/convnext.py +62 -0
- stouputils/data_science/models/keras/densenet.py +50 -0
- stouputils/data_science/models/keras/efficientnet.py +60 -0
- stouputils/data_science/models/keras/mobilenet.py +56 -0
- stouputils/data_science/models/keras/resnet.py +52 -0
- stouputils/data_science/models/keras/squeezenet.py +233 -0
- stouputils/data_science/models/keras/vgg.py +42 -0
- stouputils/data_science/models/keras/xception.py +38 -0
- stouputils/data_science/models/keras_utils/callbacks/__init__.py +20 -0
- stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +219 -0
- stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +148 -0
- stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +31 -0
- stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +249 -0
- stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +66 -0
- stouputils/data_science/models/keras_utils/losses/__init__.py +12 -0
- stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +56 -0
- stouputils/data_science/models/keras_utils/visualizations.py +416 -0
- stouputils/data_science/models/model_interface.py +939 -0
- stouputils/data_science/models/sandbox.py +116 -0
- stouputils/data_science/range_tuple.py +234 -0
- stouputils/data_science/scripts/augment_dataset.py +77 -0
- stouputils/data_science/scripts/exhaustive_process.py +133 -0
- stouputils/data_science/scripts/preprocess_dataset.py +70 -0
- stouputils/data_science/scripts/routine.py +168 -0
- stouputils/data_science/utils.py +285 -0
- stouputils/decorators.py +605 -0
- stouputils/image.py +441 -0
- stouputils/installer/__init__.py +18 -0
- stouputils/installer/common.py +67 -0
- stouputils/installer/downloader.py +101 -0
- stouputils/installer/linux.py +144 -0
- stouputils/installer/main.py +223 -0
- stouputils/installer/windows.py +136 -0
- stouputils/io.py +486 -0
- stouputils/parallel.py +483 -0
- stouputils/print.py +482 -0
- stouputils/py.typed +1 -0
- stouputils/stouputils/__init__.pyi +15 -0
- stouputils/stouputils/_deprecated.pyi +12 -0
- stouputils/stouputils/all_doctests.pyi +46 -0
- stouputils/stouputils/applications/__init__.pyi +2 -0
- stouputils/stouputils/applications/automatic_docs.pyi +106 -0
- stouputils/stouputils/applications/upscaler/__init__.pyi +3 -0
- stouputils/stouputils/applications/upscaler/config.pyi +18 -0
- stouputils/stouputils/applications/upscaler/image.pyi +109 -0
- stouputils/stouputils/applications/upscaler/video.pyi +60 -0
- stouputils/stouputils/archive.pyi +67 -0
- stouputils/stouputils/backup.pyi +109 -0
- stouputils/stouputils/collections.pyi +86 -0
- stouputils/stouputils/continuous_delivery/__init__.pyi +5 -0
- stouputils/stouputils/continuous_delivery/cd_utils.pyi +129 -0
- stouputils/stouputils/continuous_delivery/github.pyi +162 -0
- stouputils/stouputils/continuous_delivery/pypi.pyi +53 -0
- stouputils/stouputils/continuous_delivery/pyproject.pyi +67 -0
- stouputils/stouputils/continuous_delivery/stubs.pyi +39 -0
- stouputils/stouputils/ctx.pyi +211 -0
- stouputils/stouputils/decorators.pyi +252 -0
- stouputils/stouputils/image.pyi +172 -0
- stouputils/stouputils/installer/__init__.pyi +5 -0
- stouputils/stouputils/installer/common.pyi +39 -0
- stouputils/stouputils/installer/downloader.pyi +24 -0
- stouputils/stouputils/installer/linux.pyi +39 -0
- stouputils/stouputils/installer/main.pyi +57 -0
- stouputils/stouputils/installer/windows.pyi +31 -0
- stouputils/stouputils/io.pyi +213 -0
- stouputils/stouputils/parallel.pyi +216 -0
- stouputils/stouputils/print.pyi +136 -0
- stouputils/stouputils/version_pkg.pyi +15 -0
- stouputils/version_pkg.py +189 -0
- stouputils-1.14.0.dist-info/METADATA +178 -0
- stouputils-1.14.0.dist-info/RECORD +140 -0
- stouputils-1.14.0.dist-info/WHEEL +4 -0
- stouputils-1.14.0.dist-info/entry_points.txt +3 -0
stouputils/data_science/metric_utils.py
@@ -0,0 +1,847 @@
"""
This module contains the MetricUtils class, which provides static methods for
calculating various metrics for machine learning tasks.

This class contains static methods for:

- Calculating various metrics (accuracy, precision, recall, etc.)
- Computing confusion matrix and related metrics
- Generating ROC curves and finding optimal thresholds
- Calculating F-beta scores

The metrics are calculated based on the predictions made by a model and the true labels from a dataset.
The class supports both binary and multiclass classification tasks.
"""
# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
# pyright: reportMissingTypeStubs=false

# Imports
import os
from collections.abc import Iterable
from typing import Any, Literal

import mlflow
import numpy as np
from ..decorators import handle_error, measure_time
from ..print import info, warning
from matplotlib import pyplot as plt
from numpy.typing import NDArray
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix, matthews_corrcoef

from .config.get import DataScienceConfig
from .dataset import Dataset
from .metric_dictionnary import MetricDictionnary
from .utils import Utils


# Class
class MetricUtils:
	""" Class containing static methods for calculating metrics. """

	@staticmethod
	@measure_time(printer=info, message="Execution time of MetricUtils.metrics")
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def metrics(
		dataset: Dataset,
		predictions: Iterable[Any],
		run_name: str,
		mode: Literal["binary", "multiclass", "none"] = "binary"
	) -> dict[str, float]:
		""" Method to calculate as many metrics as possible for the given dataset and predictions.

		Args:
			dataset (Dataset): Dataset containing the true labels
			predictions (Iterable): Predictions made by the model
			run_name (str): Name of the run, used to save the ROC curve
			mode (Literal): Mode of the classification, defaults to "binary"
		Returns:
			dict[str, float]: Dictionary containing the calculated metrics

		Examples:
			>>> # Prepare a test dataset
			>>> from .dataset import XyTuple
			>>> test_data = XyTuple(X=np.array([[1], [2], [3]]), y=np.array([0, 1, 0]))
			>>> dataset = Dataset(training_data=test_data, test_data=test_data, name="osef")

			>>> # Prepare predictions
			>>> predictions = np.array([[0.9, 0.1], [0.2, 0.8], [0.2, 0.8]])

			>>> # Calculate metrics
			>>> from stouputils.ctx import Muffle
			>>> with Muffle():
			...     metrics = MetricUtils.metrics(dataset, predictions, run_name="")

			>>> # Check metrics
			>>> round(float(metrics[MetricDictionnary.ACCURACY]), 2)
			0.67
			>>> round(float(metrics[MetricDictionnary.PRECISION]), 2)
			0.5
			>>> round(float(metrics[MetricDictionnary.RECALL]), 2)
			1.0
			>>> round(float(metrics[MetricDictionnary.F1_SCORE]), 2)
			0.67
			>>> round(float(metrics[MetricDictionnary.AUC]), 2)
			0.75
			>>> round(float(metrics[MetricDictionnary.MATTHEWS_CORRELATION_COEFFICIENT]), 2)
			0.5
		"""
		# Initialize metrics
		metrics: dict[str, float] = {}
		y_true: NDArray[np.single] = dataset.test_data.ungrouped_array()[1]
		y_pred: NDArray[np.single] = np.array(predictions)
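		# NOTE: predictions are expected as per-class probability scores (one row per sample);
		# hard class indices are derived from them below whenever needed.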

		# Binary classification
		if mode == "binary":
			true_classes: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)
			pred_classes: NDArray[np.intc] = Utils.convert_to_class_indices(y_pred)

			# Get confusion matrix metrics
			conf_metrics: dict[str, float] = MetricUtils.confusion_matrix(
				true_classes=true_classes,
				pred_classes=pred_classes,
				labels=dataset.labels,
				run_name=run_name
			)
			metrics.update(conf_metrics)

			# Calculate F-beta scores
			precision: float = conf_metrics.get(MetricDictionnary.PRECISION, 0)
			recall: float = conf_metrics.get(MetricDictionnary.RECALL, 0)
			f_metrics: dict[str, float] = MetricUtils.f_scores(precision, recall)
			if f_metrics:
				metrics.update(f_metrics)

			# Calculate Matthews Correlation Coefficient
			mcc_metric: dict[str, float] = MetricUtils.matthews_correlation(true_classes, pred_classes)
			if mcc_metric:
				metrics.update(mcc_metric)

			# Calculate and plot (ROC Curve / AUC) and (PR Curve / AUC, and negative one)
			curves_metrics: dict[str, float] = MetricUtils.all_curves(true_classes, y_pred, fold_number=-1, run_name=run_name)
			if curves_metrics:
				metrics.update(curves_metrics)

		# Multiclass classification
		elif mode == "multiclass":
			pass

		return metrics

	@staticmethod
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def confusion_matrix(
		true_classes: NDArray[np.intc],
		pred_classes: NDArray[np.intc],
		labels: tuple[str, ...],
		run_name: str = ""
	) -> dict[str, float]:
		""" Calculate metrics based on confusion matrix.

		Args:
			true_classes (NDArray[np.intc]): True class labels
			pred_classes (NDArray[np.intc]): Predicted class labels
			labels (tuple[str, ...]): List of class labels (strings)
			run_name (str): Name for saving the plot
		Returns:
			dict[str, float]: Dictionary of confusion matrix based metrics

		Examples:
			>>> # Prepare data
			>>> true_classes = np.array([0, 1, 0])
			>>> pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
			>>> pred_classes = Utils.convert_to_class_indices(pred_probs) # [0, 1, 1]
			>>> labels = ["class_0", "class_1"]

			>>> # Calculate metrics
			>>> from stouputils.ctx import Muffle
			>>> with Muffle():
			...     metrics = MetricUtils.confusion_matrix(true_classes, pred_classes, labels, run_name="")

			>>> # Check metrics
			>>> int(metrics[MetricDictionnary.CONFUSION_MATRIX_TN])
			1
			>>> int(metrics[MetricDictionnary.CONFUSION_MATRIX_FP])
			1
			>>> int(metrics[MetricDictionnary.CONFUSION_MATRIX_FN])
			0
			>>> int(metrics[MetricDictionnary.CONFUSION_MATRIX_TP])
			1
			>>> round(float(metrics[MetricDictionnary.FALSE_POSITIVE_RATE]), 2)
			0.5
		"""
		metrics: dict[str, float] = {}

		# Get basic confusion matrix values
		conf_matrix: NDArray[np.intc] = confusion_matrix(true_classes, pred_classes)
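		# For binary labels ordered [0, 1], scikit-learn lays out the matrix as [[TN, FP], [FN, TP]]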
		TN: int = conf_matrix[0, 0] # True Negatives
		FP: int = conf_matrix[0, 1] # False Positives
		FN: int = conf_matrix[1, 0] # False Negatives
		TP: int = conf_matrix[1, 1] # True Positives

		# Calculate totals for each category
		total_samples: int = TN + FP + FN + TP
		total_actual_negatives: int = TN + FP
		total_actual_positives: int = TP + FN
		total_predicted_negatives: int = TN + FN
		total_predicted_positives: int = TP + FP

		# Calculate core metrics
		specificity: float = Utils.safe_divide_float(TN, total_actual_negatives)
		recall: float = Utils.safe_divide_float(TP, total_actual_positives)
		precision: float = Utils.safe_divide_float(TP, total_predicted_positives)
		npv: float = Utils.safe_divide_float(TN, total_predicted_negatives)
		accuracy: float = Utils.safe_divide_float(TN + TP, total_samples)
		balanced_accuracy: float = (specificity + recall) / 2
		f1_score: float = Utils.safe_divide_float(2 * (precision * recall), precision + recall)
		f1_score_negative: float = Utils.safe_divide_float(2 * (specificity * npv), specificity + npv)
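		# f1_score_negative is the F1 of the negative class: the harmonic mean of
		# specificity (recall on negatives) and NPV (precision on negatives)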

		# Store main metrics using MetricDictionnary
		metrics[MetricDictionnary.SPECIFICITY] = specificity
		metrics[MetricDictionnary.RECALL] = recall
		metrics[MetricDictionnary.PRECISION] = precision
		metrics[MetricDictionnary.NPV] = npv
		metrics[MetricDictionnary.ACCURACY] = accuracy
		metrics[MetricDictionnary.BALANCED_ACCURACY] = balanced_accuracy
		metrics[MetricDictionnary.F1_SCORE] = f1_score
		metrics[MetricDictionnary.F1_SCORE_NEGATIVE] = f1_score_negative

		# Store confusion matrix values and derived metrics
		metrics[MetricDictionnary.CONFUSION_MATRIX_TN] = TN
		metrics[MetricDictionnary.CONFUSION_MATRIX_FP] = FP
		metrics[MetricDictionnary.CONFUSION_MATRIX_FN] = FN
		metrics[MetricDictionnary.CONFUSION_MATRIX_TP] = TP
		metrics[MetricDictionnary.FALSE_POSITIVE_RATE] = Utils.safe_divide_float(FP, total_actual_negatives)
		metrics[MetricDictionnary.FALSE_NEGATIVE_RATE] = Utils.safe_divide_float(FN, total_actual_positives)
		metrics[MetricDictionnary.FALSE_DISCOVERY_RATE] = Utils.safe_divide_float(FP, total_predicted_positives)
		metrics[MetricDictionnary.FALSE_OMISSION_RATE] = Utils.safe_divide_float(FN, total_predicted_negatives)
		metrics[MetricDictionnary.CRITICAL_SUCCESS_INDEX] = Utils.safe_divide_float(TP, total_actual_positives + FP)

		# Plot confusion matrix
		if run_name:
			confusion_matrix_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{run_name}_confusion_matrix.png"
			ConfusionMatrixDisplay.from_predictions(true_classes, pred_classes, display_labels=labels)
			plt.savefig(confusion_matrix_path)
			mlflow.log_artifact(confusion_matrix_path)
			os.remove(confusion_matrix_path)
			plt.close()

		return metrics

	@staticmethod
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def f_scores(precision: float, recall: float) -> dict[str, float]:
		""" Calculate F-beta scores for different beta values.

		Args:
			precision (float): Precision value
			recall (float): Recall value
		Returns:
			dict[str, float]: Dictionary of F-beta scores

		Examples:
			>>> from stouputils.ctx import Muffle
			>>> with Muffle():
			...     metrics = MetricUtils.f_scores(precision=0.5, recall=1.0)
			>>> [round(float(x), 2) for x in metrics.values()]
			[0.5, 0.51, 0.54, 0.58, 0.62, 0.67, 0.71, 0.75, 0.78, 0.81, 0.83]

		"""
		# Assertions
		assert precision > 0, "Precision cannot be 0"
		assert recall > 0, "Recall cannot be 0"

		# Calculate F-beta scores
		metrics: dict[str, float] = {}
		betas: Iterable[float] = np.linspace(0, 2, 11)
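		# F-beta = (1 + beta^2) * precision * recall / (beta^2 * precision + recall);
		# beta < 1 weights precision more heavily, beta > 1 weights recall (beta = 1 gives the usual F1)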
		for beta in betas:
			divider: float = (beta**2 * precision) + recall
			score: float = Utils.safe_divide_float((1 + beta**2) * precision * recall, divider)
			metrics[MetricDictionnary.F_SCORE_X.replace("X", f"{beta:.1f}")] = score
			if score == 0:
				warning(f"F-score is 0 for beta={beta:.1f}")
		return metrics

	@staticmethod
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def matthews_correlation(true_classes: NDArray[np.intc], pred_classes: NDArray[np.intc]) -> dict[str, float]:
		""" Calculate Matthews Correlation Coefficient.

		Args:
			true_classes (NDArray[np.intc]): True class labels
			pred_classes (NDArray[np.intc]): Predicted class labels
		Returns:
			dict[str, float]: Dictionary containing MCC

		Examples:
			>>> true_classes = np.array([0, 1, 0])
			>>> pred_classes = np.array([0, 1, 1])
			>>> from stouputils.ctx import Muffle
			>>> with Muffle():
			...     metrics = MetricUtils.matthews_correlation(true_classes, pred_classes)
			>>> float(metrics[MetricDictionnary.MATTHEWS_CORRELATION_COEFFICIENT])
			0.5
		"""
		return {MetricDictionnary.MATTHEWS_CORRELATION_COEFFICIENT: matthews_corrcoef(true_classes, pred_classes)}


	@staticmethod
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def roc_curve_and_auc(
		true_classes: NDArray[np.intc] | NDArray[np.single],
		pred_probs: NDArray[np.single],
		fold_number: int = -1,
		run_name: str = "",
		plot_if_minimum: int = 5
	) -> dict[str, float]:
		""" Calculate ROC curve and AUC score.

		Args:
			true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
			pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
			fold_number (int): Fold number, used for naming the plot file, usually
				-1 for final model with test set,
				0 for final model with validation set,
				>0 for other folds with their validation set
			run_name (str): Name for saving the plot
			plot_if_minimum (int): Minimum number of samples required in true_classes to plot the ROC curve
		Returns:
			dict[str, float]: Dictionary containing AUC score and optimal thresholds

		Examples:
			>>> true_classes = np.array([0, 1, 0])
			>>> pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
			>>> from stouputils.ctx import Muffle
			>>> with Muffle():
			...     metrics = MetricUtils.roc_curve_and_auc(true_classes, pred_probs, run_name="")

			>>> # Check metrics
			>>> round(float(metrics[MetricDictionnary.AUC]), 2)
			0.75
			>>> round(float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_YOUDEN]), 2)
			0.9
			>>> float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_COST])
			inf
		"""
		auc_value, fpr, tpr, thresholds = Utils.get_roc_curve_and_auc(true_classes, pred_probs)
		metrics: dict[str, float] = {MetricDictionnary.AUC: auc_value}

		# Find optimal threshold using different methods
		# 1. Youden's method
		youden_index: NDArray[np.single] = tpr - fpr
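		# Youden's J statistic: J = sensitivity + specificity - 1 = TPR - FPR;
		# the threshold maximizing J balances both error types equally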
		optimal_threshold_youden: float = thresholds[np.argmax(youden_index)]
		metrics[MetricDictionnary.OPTIMAL_THRESHOLD_YOUDEN] = optimal_threshold_youden

		# 2. Cost-based method
		# Assuming false positives cost twice as much as false negatives
		cost_fp: float = 2
		cost_fn: float = 1
		total_cost: NDArray[np.single] = cost_fp * fpr + cost_fn * (1 - tpr)
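		# Expected cost at each threshold: FPR weighted by cost_fp plus miss rate (1 - TPR) weighted by cost_fn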
		optimal_threshold_cost: float = thresholds[np.argmin(total_cost)]
		metrics[MetricDictionnary.OPTIMAL_THRESHOLD_COST] = optimal_threshold_cost

		# Plot ROC curve if run_name and minimum number of samples is reached
		if run_name and len(true_classes) >= plot_if_minimum:
			plt.figure(figsize=(12, 6))
			plt.plot(fpr, tpr, "b", label=f"ROC curve (AUC = {auc_value:.2f})")
			plt.plot([0, 1], [0, 1], "r--")

			# Add optimal threshold points
			youden_idx: int = int(np.argmax(youden_index))
			cost_idx: int = int(np.argmin(total_cost))

			# Prepare the path
			fold_name: str = ""
			if fold_number > 0:
				fold_name = f"_fold_{fold_number}_val"
			elif fold_number == 0:
				fold_name = "_val"
			elif fold_number == -1:
				fold_name = "_test"
			elif fold_number == -2:
				fold_name = "_train"
			roc_curve_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{run_name}_roc_curve{fold_name}.png"

			plt.plot(fpr[youden_idx], tpr[youden_idx], 'go', label=f'Youden (t={optimal_threshold_youden:.2f})')
			plt.plot(fpr[cost_idx], tpr[cost_idx], 'mo', label=f'Cost (t={optimal_threshold_cost:.2f})')

			plt.xlim([-0.01, 1.01])
			plt.ylim([-0.01, 1.01])
			plt.xlabel("False Positive Rate")
			plt.ylabel("True Positive Rate")
			plt.title("Receiver Operating Characteristic (ROC)")
			plt.legend(loc="lower right")
			plt.savefig(roc_curve_path)
			mlflow.log_artifact(roc_curve_path)
			os.remove(roc_curve_path)
			plt.close()

		return metrics

	@staticmethod
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def pr_curve_and_auc(
		true_classes: NDArray[np.intc] | NDArray[np.single],
		pred_probs: NDArray[np.single],
		fold_number: int = -1,
		run_name: str = "",
		plot_if_minimum: int = 5
	) -> dict[str, float]:
		""" Calculate Precision-Recall curve and AUC score. (and NPV-Specificity curve and AUC)

		Args:
			true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
			pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
			fold_number (int): Fold number, used for naming the plot file, usually
				-1 for final model with test set,
				0 for final model with validation set,
				>0 for other folds with their validation set
			run_name (str): Name for saving the plot
			plot_if_minimum (int): Minimum number of samples required in true_classes to plot the PR curves
		Returns:
			dict[str, float]: Dictionary containing AUC score and optimal thresholds

		Examples:
			>>> true_classes = np.array([0, 1, 0])
			>>> pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
			>>> from stouputils.ctx import Muffle
			>>> with Muffle():
			...     metrics = MetricUtils.pr_curve_and_auc(true_classes, pred_probs, run_name="")

			>>> # Check metrics
			>>> round(float(metrics[MetricDictionnary.AUPRC]), 2)
			0.75
			>>> round(float(metrics[MetricDictionnary.NEGATIVE_AUPRC]), 2)
			0.92
			>>> round(float(metrics[MetricDictionnary.PR_AVERAGE]), 2)
			0.5
			>>> round(float(metrics[MetricDictionnary.PR_AVERAGE_NEGATIVE]), 2)
			0.33
			>>> round(float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1]), 2)
			0.9
			>>> round(float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1_NEGATIVE]), 2)
			0.1
		"""
		auc_value, average_precision, precision, recall, thresholds = Utils.get_pr_curve_and_auc(true_classes, pred_probs)
		neg_auc_value, average_precision_neg, npv, specificity, neg_thresholds = (
			Utils.get_pr_curve_and_auc(true_classes, pred_probs, negative=True)
		)

		# Calculate metrics
		metrics: dict[str, float] = {
			MetricDictionnary.AUPRC: auc_value,
			MetricDictionnary.NEGATIVE_AUPRC: neg_auc_value,
			MetricDictionnary.PR_AVERAGE: average_precision,
			MetricDictionnary.PR_AVERAGE_NEGATIVE: average_precision_neg
		}

		# Calculate optimal thresholds for both PR curves
		for is_negative in (False, True):

			# Get the right values based on positive/negative case
			if not is_negative:
				curr_precision = precision
				curr_recall = recall
				curr_thresholds = thresholds
				curr_auc = auc_value
				curr_ap = average_precision
			else:
				curr_precision = npv
				curr_recall = specificity
				curr_thresholds = neg_thresholds
				curr_auc = neg_auc_value
				curr_ap = average_precision_neg

			# Calculate F-score for each threshold
			fscore: NDArray[np.single] = (2 * curr_precision * curr_recall) / (curr_precision + curr_recall)
			fscore = fscore[~np.isnan(fscore)]
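			# The F-score above is the harmonic mean of precision and recall at each threshold;
			# NaN entries (thresholds where precision + recall == 0) are dropped before taking the argmax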

			# Get optimal threshold (maximum F-score)
			if len(fscore) > 0:
				optimal_idx: int = int(np.argmax(fscore))
				optimal_threshold: float = curr_thresholds[optimal_idx]
			else:
				optimal_idx: int = 0
				optimal_threshold = float('inf')

			# Store in metrics dictionary
			if not is_negative:
				metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1] = optimal_threshold
			else:
				metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1_NEGATIVE] = optimal_threshold

			# Plot PR curve if run_name and minimum number of samples is reached
			if run_name and len(true_classes) >= plot_if_minimum:
				label: str = "Precision - Recall" if not is_negative else "Negative Predictive Value - Specificity"
				plt.figure(figsize=(12, 6))
				plt.plot(curr_recall, curr_precision, "b", label=f"{label} curve (AUC = {curr_auc:.2f}, AP = {curr_ap:.2f})")

				# Prepare the path
				fold_name: str = ""
				if fold_number > 0:
					fold_name = f"_fold_{fold_number}_val"
				elif fold_number == 0:
					fold_name = "_val"
				elif fold_number == -1:
					fold_name = "_test"
				elif fold_number == -2:
					fold_name = "_train"
				pr: str = "pr" if not is_negative else "negative_pr"
				curve_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{run_name}_{pr}_curve{fold_name}.png"

				plt.plot(
					curr_recall[optimal_idx], curr_precision[optimal_idx], 'go', label=f"Optimal threshold (t={optimal_threshold:.2f})"
				)

				plt.xlim([-0.01, 1.01])
				plt.ylim([-0.01, 1.01])
				plt.xlabel("Recall" if not is_negative else "Specificity")
				plt.ylabel("Precision" if not is_negative else "Negative Predictive Value")
				plt.title(f"{label} Curve")
				plt.legend(loc="lower right")
				plt.savefig(curve_path)
				mlflow.log_artifact(curve_path)
				os.remove(curve_path)
				plt.close()

		return metrics

	@staticmethod
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def all_curves(
		true_classes: NDArray[np.intc] | NDArray[np.single],
		pred_probs: NDArray[np.single],
		fold_number: int = -1,
		run_name: str = ""
	) -> dict[str, float]:
		""" Run all X_curve_and_auc functions and return a dictionary of metrics.

		Args:
			true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
			pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
			fold_number (int): Fold number, used for naming the plot file, usually
				-1 for final model with test set,
				0 for final model with validation set,
				>0 for other folds with their validation set
			run_name (str): Name for saving the plot
		Returns:
			dict[str, float]: Dictionary containing AUC score and optimal thresholds for ROC and PR curves
		"""
		metrics: dict[str, float] = {}
		metrics.update(MetricUtils.roc_curve_and_auc(true_classes, pred_probs, fold_number, run_name))
		metrics.update(MetricUtils.pr_curve_and_auc(true_classes, pred_probs, fold_number, run_name))
		return metrics


	@staticmethod
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def plot_metric_curves(
		all_history: list[dict[str, list[float]]],
		metric_name: str,
		run_name: str = ""
	) -> None:
		""" Plot training and validation curves for a specific metric.

		Generates two plots for the given metric:
		1. A combined plot with both training and validation curves
		2. A validation-only plot

		The plots show the metric's progression across training epochs for each fold.
		Special formatting distinguishes between folds and curve types:
		- Fold 0 (final model) uses thicker lines (2.0 width vs 1.0)
		- Training curves use solid lines, validation uses dashed
		- Each curve is clearly labeled in the legend

		The plots are saved to the temp folder and logged to MLflow before cleanup.

		Args:
			all_history (list[dict[str, list[float]]]): List of history dictionaries for each fold
			metric_name (str): Name of the metric to plot (e.g. "accuracy", "loss")
			run_name (str): Name of the run

		Examples:
			>>> # Prepare data with 2 folds for instance
			>>> all_history = [
			...     {'loss': [0.1, 0.09, 0.08, 0.07, 0.06], 'val_loss': [0.11, 0.1, 0.09, 0.08, 0.07]},
			...     {'loss': [0.12, 0.11, 0.1, 0.09, 0.08], 'val_loss': [0.13, 0.12, 0.11, 0.1, 0.09]}
			... ]
			>>> MetricUtils.plot_metric_curves(metric_name="loss", all_history=all_history, run_name="")
		"""
		for only_validation in (False, True):
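			# First pass (False) plots training + validation curves, second pass (True) plots validation only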
			plt.figure(figsize=(12, 6))

			# Track max value for y-limit calculation
			max_value: float = 0.0

			for fold, history in enumerate(all_history):
				# Get validation metrics for this fold
				val_metric: list[float] = history[f"val_{metric_name}"]
				epochs: list[int] = list(range(1, len(val_metric) + 1))

				# Update max value
				max_value = max(max_value, max(val_metric))

				# Use thicker line for final model (fold 0)
				alpha: float = 1.0 if fold == 0 else 0.5
				linewidth: float = 2.0 if fold == 0 else 1.0
				label: str = "Final Model" if fold == 0 else f"Fold {fold + 1}"
				val_label: str = f"Validation {metric_name} ({label})"
				plt.plot(epochs, val_metric, linestyle='--', linewidth=linewidth, alpha=alpha, label=val_label)

				# Add training metrics if showing both curves
				if not only_validation:
					train_metric: list[float] = history[metric_name]
					max_value = max(max_value, max(train_metric))
					train_label: str = f"Training {metric_name} ({label})"
					plt.plot(epochs, train_metric, linestyle='-', linewidth=linewidth, alpha=alpha, label=train_label)

			# Configure plot formatting
			plt.title(("Training and " if not only_validation else "") + f"Validation {metric_name} Across All Folds")
			plt.xlabel("Epochs")
			plt.ylabel(metric_name)

			# Set y-limit for loss metric, to avoid seeing non-sense curves
			if metric_name == "loss" and not only_validation:
				plt.ylim(0, min(2.0, max_value * 1.1))

			# Add legend
			plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
			plt.tight_layout()

			# Save plot and log to MLflow
			if run_name:
				path: str = ("training_" if not only_validation else "") + f"validation_{metric_name}_curves.png"
				full_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{run_name}_{path}"
				plt.savefig(full_path, bbox_inches='tight')
				mlflow.log_artifact(full_path)
				os.remove(full_path)
			plt.close()

	@staticmethod
	def plot_every_metric_curves(
		all_history: list[dict[str, list[float]]],
		metrics_names: tuple[str, ...] = (),
		run_name: str = ""
	) -> None:
		""" Plot and save training and validation curves for each metric.

		Args:
			all_history (list[dict[str, list[float]]]): List of history dictionaries for each fold
			metrics_names (tuple[str, ...]): List of metric names to plot, defaults to ("loss",)
			run_name (str): Name of the run

		Examples:
			>>> # Prepare data with 2 folds for instance
			>>> all_history = [
			...     {'loss': [0.1, 0.09], 'val_loss': [0.11, 0.1], "accuracy": [0.9, 0.8], "val_accuracy": [0.8, 0.7]},
			...     {'loss': [0.12, 0.11], 'val_loss': [0.13, 0.12], "accuracy": [0.8, 0.7], "val_accuracy": [0.7, 0.6]}
			... ]
			>>> MetricUtils.plot_every_metric_curves(all_history, metrics_names=["loss", "accuracy"], run_name="")
		"""
		# Set default metrics names to loss
		if not metrics_names:
			metrics_names = ("loss",)

		# Plot each metric
		for metric_name in metrics_names:
			MetricUtils.plot_metric_curves(all_history, metric_name, run_name)

	@staticmethod
	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
	def find_best_x_and_plot(
		x_values: list[float],
		y_values: list[float],
		best_idx: int | None = None,
		smoothen: bool = True,
		use_steep: bool = True,
		run_name: str = "",
		x_label: str = "Learning Rate",
		y_label: str = "Loss",
		plot_title: str = "Learning Rate Finder",
		log_x: bool = True,
		y_limits: tuple[float, ...] | None = None
	) -> float:
		""" Find the best x value (where y is minimized) and plot the curve.

		Args:
			x_values (list[float]): List of x values (e.g. learning rates)
			y_values (list[float]): List of corresponding y values (e.g. losses)
			best_idx (int | None): Index of the best x value (if None, a robust approach is used)
			smoothen (bool): Whether to apply smoothing to the y values
			use_steep (bool): Whether to use steepest slope strategy to determine best index
			run_name (str): Name of the run for saving the plot
			x_label (str): Label for the x-axis
			y_label (str): Label for the y-axis
			plot_title (str): Title for the plot
			log_x (bool): Whether to use a logarithmic x-axis (e.g. learning rate)
			y_limits (tuple[float, ...] | None): Limit for the y-axis, defaults to None (no limit)

		Returns:
			float: The best x value found (where y is minimized)

		This function creates a plot showing the relationship between x and y values
		to help identify the optimal x (where y is minimized). The plot can use a logarithmic
		x-axis for better visualization if desired.

		The ideal x is typically found where y is still decreasing but before it starts to increase dramatically.

		Examples:
			>>> x_values = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
			>>> y_values = [0.1, 0.09, 0.07, 0.06, 0.09]
			>>> best_x = MetricUtils.find_best_x_and_plot(x_values, y_values, use_steep=True)
			>>> print(f"Best x: {best_x:.0e}")
			Best x: 1e-03
			>>> best_x = MetricUtils.find_best_x_and_plot(x_values, y_values, use_steep=False)
			>>> print(f"Best x: {best_x:.0e}")
			Best x: 1e-02
		"""
		# Validate input data
		assert x_values, "No x data to plot"
		assert y_values, "No y data to plot"

		# Convert lists to numpy arrays for easier manipulation
		y_array: NDArray[np.single] = np.array(y_values)
		x_array: NDArray[np.single] = np.array(x_values)

		# Apply smoothing to the y values if requested and if we have enough data points
		if smoothen and len(y_values) > 2:

			# Calculate appropriate window size based on data length
			window_size: int = min(10, len(y_values) // 3)
			if window_size > 1:

				# Apply moving average smoothing using convolution
				valid_convolution: NDArray[np.single] = np.convolve(y_array, np.ones(window_size)/window_size, mode="valid")
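				# A "valid" convolution yields len(y_values) - window_size + 1 averaged points,
				# which are written back into the centre of the array below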
				y_array = np.copy(y_array)

				# Calculate start and end indices for replacing values with smoothed ones
				start_idx: int = window_size // 2
				end_idx: int = start_idx + len(valid_convolution)
				y_array[start_idx:end_idx] = valid_convolution

				# Replace first and last values with original values (to avoid weird effects)
				y_array[0] = y_values[0]
				y_array[-1] = y_values[-1]

		# 1. Global minimum index between 10% and 90% (excluding borders)
		window_start: int = int(0.1 * len(y_array))
		window_end: int = int(0.9 * len(y_array))
		global_window_min_idx: int = int(np.argmin(y_array[window_start:window_end]))
		global_min_idx: int = global_window_min_idx + window_start

		# Determine best index
		if best_idx is None:
			if use_steep:

				# 2. Compute slope in loss vs log(x) for LR sensitivity
				log_x_array: NDArray[np.single] = np.log(x_array)
				slopes: NDArray[np.single] = np.gradient(y_array, log_x_array)
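				# np.gradient approximates d(y)/d(log x): the most negative slope marks where the loss
				# is falling fastest, which is the usual learning-rate-finder heuristic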

				# 3. Define proximity window to the left of global minimum
				proximity: int = max(1, len(y_array) // 10)
				window_start = max(0, global_min_idx - proximity)

				# 4. Find steepest slope within window
				if window_start < global_min_idx:
					local_slopes: NDArray[np.single] = slopes[window_start:global_min_idx]
					relative_idx: int = int(np.argmin(local_slopes))
					steep_idx: int = window_start + relative_idx
					best_idx = steep_idx
				else:
					best_idx = global_min_idx

				# 5. Top-7 most negative slopes as candidates
				neg_idx: NDArray[np.intp] = np.where(slopes < 0)[0]
				sorted_neg: NDArray[np.intp] = neg_idx[np.argsort(slopes[neg_idx])]
				top7_fave: NDArray[np.intp] = sorted_neg[:7]

				# Include best_idx and global_min_idx
				candidates: set[int] = set(top7_fave.tolist())
				candidates.add(best_idx)
				distinct_candidates = np.array(sorted(candidates, key=int))
			else:
				best_idx = global_min_idx

				# Find all local minima
				from scipy.signal import argrelextrema
				local_minima_idx: NDArray[np.intp] = np.array(argrelextrema(y_array, np.less)[0], dtype=np.intp)
				distinct_candidates = np.unique(np.append(local_minima_idx, best_idx))
		else:
			assert 0 <= best_idx < len(x_array), "Best x index is out of bounds"
			distinct_candidates = np.array([best_idx])

		# Get the best x value and corresponding y value
		best_x: float = x_array[best_idx]
		min_y: float = y_array[best_idx]

		# Create and save the plot if a run name is provided
		if run_name:

			# Log metrics to mlflow (e.g. 'learning_rate_finder_learning_rate', 'learning_rate_finder_loss')
			log_title: str = MetricDictionnary.PARAMETER_FINDER.replace("TITLE", plot_title)
			log_x_label: str = log_title.replace("PARAMETER_NAME", x_label)
			log_y_label: str = log_title.replace("PARAMETER_NAME", y_label)
			for i in range(len(x_values)):
				mlflow.log_metric(log_x_label, x_values[i], step=i)
				mlflow.log_metric(log_y_label, y_values[i], step=i)

			# Prepare the plot
			plt.figure(figsize=(12, 6))
			plt.plot(x_array, y_array, label="Smoothed Curve", linewidth=2)
			plt.plot(x_values, y_values, "-", markersize=3, alpha=0.5, label="Original Curve", color="gray")

			# Use logarithmic scale for x-axis if requested
			if log_x:
				plt.xscale("log")

			# Set labels and title
			plt.xlabel(x_label)
			plt.ylabel(y_label)
			plt.title(plot_title)
			plt.grid(True, which="both", ls="--")

			# Limit y-axis to avoid extreme values
			if y_limits is not None and len(y_limits) == 2:
				min_y_limit: float = max(y_limits[0], min(y_values) * 0.9)
				max_y_limit: float = min(y_limits[1], max(y_values) * 1.1)
				plt.ylim(min_y_limit, max_y_limit)
			plt.legend()

			# Highlight local minima if any
			if len(distinct_candidates) > 0:
				candidate_xs = [x_array[idx] for idx in distinct_candidates]
				candidate_ys = [y_array[idx] for idx in distinct_candidates]
				candidates_label = "Possible Candidates" if use_steep else "Local Minima"
				plt.scatter(candidate_xs, candidate_ys, color="orange", s=25, zorder=4, label=candidates_label)

			# Highlight the best point
			plt.scatter([x_array[global_min_idx]], [y_array[global_min_idx]], color="red", s=50, zorder=5, label="Global Minimum")

			# Format the best x value for display
			best_x_str: str = f"{best_x:.2e}" if best_x < 1e-3 else f"{best_x:.2f}"

			# Add annotation pointing to the best point
			plt.annotate(
				f"Supposed best {x_label}: {best_x_str}",
				xy=(best_x, min_y),
				xytext=(best_x * 1.5, min_y * 1.1),
				arrowprops={"facecolor":"black", "shrink":0.05, "width":1.2}
			)
			plt.legend()
			plt.tight_layout()

			# Save the plot to a file and log it to MLflow
			flat_x_label: str = x_label.lower().replace(" ", "_")
			path: str = f"{flat_x_label}_finder.png"
			os.makedirs(DataScienceConfig.TEMP_FOLDER, exist_ok=True)
			full_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{run_name}_{path}"
			plt.savefig(full_path, bbox_inches="tight")
			mlflow.log_artifact(full_path)
			info(f"Saved best x plot to {full_path}")

			# Clean up the temporary file
			os.remove(full_path)
			plt.close()

		return best_x