stouputils 1.3.6.tar.gz → 1.3.8.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {stouputils-1.3.6 → stouputils-1.3.8}/PKG-INFO +1 -1
- {stouputils-1.3.6 → stouputils-1.3.8}/pyproject.toml +1 -1
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/metric_dictionnary.py +13 -1
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/metric_utils.py +170 -15
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/model_interface.py +5 -14
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/utils.py +73 -6
- {stouputils-1.3.6 → stouputils-1.3.8}/.gitignore +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/LICENSE +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/README.md +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/all_doctests.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/applications/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/applications/automatic_docs.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/applications/upscaler/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/applications/upscaler/config.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/applications/upscaler/image.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/applications/upscaler/video.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/archive.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/backup.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/collections.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/continuous_delivery/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/continuous_delivery/cd_utils.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/continuous_delivery/github.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/continuous_delivery/pypi.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/continuous_delivery/pyproject.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/ctx.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/config/get.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/config/set.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/auto_contrast.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/axis_flip.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/bias_field_correction.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/binary_threshold.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/blur.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/brightness.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/canny.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/clahe.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/common.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/contrast.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/curvature_flow_filter.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/denoise.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/histogram_equalization.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/invert.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/laplacian.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/median_blur.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/noise.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/normalize.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/random_erase.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/resize.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/rotation.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/salt_pepper.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/sharpening.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/shearing.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/threshold.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/translation.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/zoom.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image_augmentation.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image_preprocess.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/prosthesis_detection.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/technique.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/dataset/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/dataset/dataset.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/dataset/dataset_loader.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/dataset/grouping_strategy.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/dataset/image_loader.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/dataset/xy_tuple.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/mlflow_utils.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/abstract_model.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/all.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/base_keras.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/all.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/convnext.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/densenet.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/efficientnet.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/mobilenet.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/resnet.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/squeezenet.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/vgg.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/xception.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/callbacks/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/losses/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/visualizations.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/sandbox.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/range_tuple.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/scripts/augment_dataset.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/scripts/exhaustive_process.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/scripts/preprocess_dataset.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/scripts/routine.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/decorators.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/dont_look/zip_file_override.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/image.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/installer/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/installer/common.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/installer/downloader.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/installer/linux.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/installer/main.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/installer/windows.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/io.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/parallel.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/print.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/py.typed +0 -0
{stouputils-1.3.6 → stouputils-1.3.8}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: stouputils
-Version: 1.3.6
+Version: 1.3.8
 Summary: Stouputils is a collection of utility modules designed to simplify and enhance the development process. It includes a range of tools for tasks such as execution of doctests, display utilities, decorators, as well as context managers, and many more.
 Project-URL: Homepage, https://github.com/Stoupy51/stouputils
 Project-URL: Issues, https://github.com/Stoupy51/stouputils/issues
{stouputils-1.3.6 → stouputils-1.3.8}/pyproject.toml

@@ -5,7 +5,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "stouputils"
-version = "1.3.6"
+version = "1.3.8"
 description = "Stouputils is a collection of utility modules designed to simplify and enhance the development process. It includes a range of tools for tasks such as execution of doctests, display utilities, decorators, as well as context managers, and many more."
 readme = "README.md"
 requires-python = ">=3.10"
{stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/metric_dictionnary.py

@@ -8,12 +8,16 @@ This class contains the following metrics:
 1. Main metrics:
 
 - Area Under the Curve (AUC)
+- Area Under the Precision-Recall Curve (AUPRC)
+- Area Under the NPV-Specificity Curve (NEGATIVE_AUPRC)
 - Specificity (True Negative Rate)
 - Recall/Sensitivity (True Positive Rate)
 - Precision (Positive Predictive Value)
 - Negative Predictive Value (NPV)
 - Accuracy
 - F1 Score
+- Precision-Recall Average
+- Precision-Recall Average for Negative Class
 
 2. Confusion matrix metrics:
 
@@ -39,6 +43,8 @@ This class contains the following metrics:
 
 - Youden's J statistic
 - Cost-based threshold
+- F1 Score threshold
+- F1 Score threshold for the negative class
 
 6. Average metrics across folds:
 
@@ -52,7 +58,9 @@ This class contains the following metrics:
 class MetricDictionnary:
 
 	# Main metrics (starting with '1:')
-	AUC: str = "1: Area Under the Curve: AUC"
+	AUC: str = "1: Area Under the ROC Curve: AUC / AUROC"
+	AUPRC: str = "1: Area Under the Precision-Recall Curve: AUPRC / PR AUC"
+	NEGATIVE_AUPRC: str = "1: Area Under the NPV-Specificity Curve: AUNPRC / NPR AUC"
 	SPECIFICITY: str = "1: Specificity: True Negative Rate"
 	RECALL: str = "1: Recall/Sensitivity: True Positive Rate"
 	PRECISION: str = "1: Precision: Positive Predictive Value"
@@ -61,6 +69,8 @@ class MetricDictionnary:
 	BALANCED_ACCURACY: str = "1: Balanced Accuracy"
 	F1_SCORE: str = "1: F1 Score"
 	F1_SCORE_NEGATIVE: str = "1: F1 Score for Negative Class"
+	PR_AVERAGE: str = "1: Precision-Recall Average"
+	PR_AVERAGE_NEGATIVE: str = "1: Precision-Recall Average for Negative Class"
 
 	# Confusion matrix metrics (starting with '2:')
 	CONFUSION_MATRIX_TN: str = "2: Confusion Matrix: TN"
@@ -82,6 +92,8 @@ class MetricDictionnary:
 	# Optimal thresholds (starting with '5:')
 	OPTIMAL_THRESHOLD_YOUDEN: str = "5: Optimal Threshold: Youden"
 	OPTIMAL_THRESHOLD_COST: str = "5: Optimal Threshold: Cost"
+	OPTIMAL_THRESHOLD_F1: str = "5: Optimal Threshold: F1"
+	OPTIMAL_THRESHOLD_F1_NEGATIVE: str = "5: Optimal Threshold: F1 for Negative Class"
 
 	# Average metrics across folds (starting with '6:')
 	AVERAGE_METRIC: str = "6: Average METRIC_NAME across folds"
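The new constants above are plain string keys that label entries in the metric dictionaries built by MetricUtils. A minimal usage sketch, assuming MetricDictionnary imports from stouputils.data_science.metric_dictionnary as the file path above suggests:

    from stouputils.data_science.metric_dictionnary import MetricDictionnary

    # Each constant is just a labelled string key, e.g. for reading a value back out of a metrics dict
    print(MetricDictionnary.AUPRC)                 # "1: Area Under the Precision-Recall Curve: AUPRC / PR AUC"
    print(MetricDictionnary.OPTIMAL_THRESHOLD_F1)  # "5: Optimal Threshold: F1"
    metrics: dict[str, float] = {MetricDictionnary.AUPRC: 0.75}
    print(metrics[MetricDictionnary.AUPRC])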
{stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/metric_utils.py

@@ -117,10 +117,10 @@ class MetricUtils:
 			if mcc_metric:
 				metrics.update(mcc_metric)
 
-			# Calculate and plot ROC/AUC
-			…
-			if …
-				metrics.update(…
+			# Calculate and plot (ROC Curve / AUC) and (PR Curve / AUC, and negative one)
+			curves_metrics: dict[str, float] = MetricUtils.all_curves(true_classes, y_pred, fold_number=-1, run_name=run_name)
+			if curves_metrics:
+				metrics.update(curves_metrics)
 
 		# Multiclass classification
 		elif mode == "multiclass":
@@ -287,22 +287,24 @@ class MetricUtils:
 
 	@staticmethod
 	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
-	def …
+	def roc_curve_and_auc(
 		true_classes: NDArray[np.intc] | NDArray[np.single],
 		pred_probs: NDArray[np.single],
 		fold_number: int = -1,
-		run_name: str = ""
+		run_name: str = "",
+		plot_if_minimum: int = 5
 	) -> dict[str, float]:
 		""" Calculate ROC curve and AUC score.
 
 		Args:
-			true_classes …
-			pred_probs …
-			fold_number …
+			true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
+			pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
+			fold_number (int): Fold number, used for naming the plot file, usually
 				-1 for final model with test set,
 				0 for final model with validation set,
 				>0 for other folds with their validation set
-			run_name …
+			run_name (str): Name for saving the plot
+			plot_if_minimum (int): Minimum number of samples required in true_classes to plot the ROC curve
 		Returns:
 			dict[str, float]: Dictionary containing AUC score and optimal thresholds
 
@@ -311,7 +313,7 @@ class MetricUtils:
 		>>> pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
 		>>> from stouputils.ctx import Muffle
 		>>> with Muffle():
-		...     metrics = MetricUtils.…
+		...     metrics = MetricUtils.roc_curve_and_auc(true_classes, pred_probs, run_name="")
 
 		>>> # Check metrics
 		>>> round(float(metrics[MetricDictionnary.AUC]), 2)
@@ -321,7 +323,6 @@ class MetricUtils:
 		>>> float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_COST])
 		inf
 		"""
-		true_classes = Utils.convert_to_class_indices(true_classes)
 		auc_value, fpr, tpr, thresholds = Utils.get_roc_curve_and_auc(true_classes, pred_probs)
 		metrics: dict[str, float] = {MetricDictionnary.AUC: auc_value}
 
@@ -339,8 +340,8 @@ class MetricUtils:
 		optimal_threshold_cost: float = thresholds[np.argmin(total_cost)]
 		metrics[MetricDictionnary.OPTIMAL_THRESHOLD_COST] = optimal_threshold_cost
 
-		# Plot ROC curve if …
-		if run_name and …
+		# Plot ROC curve if run_name and minimum number of samples is reached
+		if run_name and len(true_classes) >= plot_if_minimum:
 			plt.figure(figsize=(12, 6))
 			plt.plot(fpr, tpr, "b", label=f"ROC curve (AUC = {auc_value:.2f})")
 			plt.plot([0, 1], [0, 1], "r--")
@@ -377,6 +378,160 @@ class MetricUtils:
 
 		return metrics
 
+	@staticmethod
+	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
+	def pr_curve_and_auc(
+		true_classes: NDArray[np.intc] | NDArray[np.single],
+		pred_probs: NDArray[np.single],
+		fold_number: int = -1,
+		run_name: str = "",
+		plot_if_minimum: int = 5
+	) -> dict[str, float]:
+		""" Calculate Precision-Recall curve and AUC score. (and NPV-Specificity curve and AUC)
+
+		Args:
+			true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
+			pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
+			fold_number (int): Fold number, used for naming the plot file, usually
+				-1 for final model with test set,
+				0 for final model with validation set,
+				>0 for other folds with their validation set
+			run_name (str): Name for saving the plot
+			plot_if_minimum (int): Minimum number of samples required in true_classes to plot the PR curves
+		Returns:
+			dict[str, float]: Dictionary containing AUC score and optimal thresholds
+
+		Examples:
+			>>> true_classes = np.array([0, 1, 0])
+			>>> pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
+			>>> from stouputils.ctx import Muffle
+			>>> with Muffle():
+			...     metrics = MetricUtils.pr_curve_and_auc(true_classes, pred_probs, run_name="")
+
+			>>> # Check metrics
+			>>> round(float(metrics[MetricDictionnary.AUPRC]), 2)
+			0.75
+			>>> round(float(metrics[MetricDictionnary.NEGATIVE_AUPRC]), 2)
+			0.92
+			>>> round(float(metrics[MetricDictionnary.PR_AVERAGE]), 2)
+			0.5
+			>>> round(float(metrics[MetricDictionnary.PR_AVERAGE_NEGATIVE]), 2)
+			0.33
+			>>> round(float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1]), 2)
+			0.9
+			>>> round(float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1_NEGATIVE]), 2)
+			0.1
+		"""
+		auc_value, average_precision, precision, recall, thresholds = Utils.get_pr_curve_and_auc(true_classes, pred_probs)
+		neg_auc_value, average_precision_neg, npv, specificity, neg_thresholds = (
+			Utils.get_pr_curve_and_auc(true_classes, pred_probs, negative=True)
+		)
+
+		# Calculate metrics
+		metrics: dict[str, float] = {
+			MetricDictionnary.AUPRC: auc_value,
+			MetricDictionnary.NEGATIVE_AUPRC: neg_auc_value,
+			MetricDictionnary.PR_AVERAGE: average_precision,
+			MetricDictionnary.PR_AVERAGE_NEGATIVE: average_precision_neg
+		}
+
+		# Calculate optimal thresholds for both PR curves
+		for is_negative in (False, True):
+
+			# Get the right values based on positive/negative case
+			if not is_negative:
+				curr_precision = precision
+				curr_recall = recall
+				curr_thresholds = thresholds
+				curr_auc = auc_value
+				curr_ap = average_precision
+			else:
+				curr_precision = npv
+				curr_recall = specificity
+				curr_thresholds = neg_thresholds
+				curr_auc = neg_auc_value
+				curr_ap = average_precision_neg
+
+			# Calculate F-score for each threshold
+			fscore: NDArray[np.single] = (2 * curr_precision * curr_recall) / (curr_precision + curr_recall)
+			fscore = fscore[~np.isnan(fscore)]
+
+			# Get optimal threshold (maximum F-score)
+			if len(fscore) > 0:
+				optimal_idx: int = int(np.argmax(fscore))
+				optimal_threshold: float = curr_thresholds[optimal_idx]
+			else:
+				optimal_idx: int = 0
+				optimal_threshold = float('inf')
+
+			# Store in metrics dictionary
+			if not is_negative:
+				metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1] = optimal_threshold
+			else:
+				metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1_NEGATIVE] = optimal_threshold
+
+			# Plot PR curve if run_name and minimum number of samples is reached
+			if run_name and len(true_classes) >= plot_if_minimum:
+				label: str = "Precision - Recall" if not is_negative else "Negative Predictive Value - Specificity"
+				plt.figure(figsize=(12, 6))
+				plt.plot(curr_recall, curr_precision, "b", label=f"{label} curve (AUC = {curr_auc:.2f}, AP = {curr_ap:.2f})")
+
+				# Prepare the path
+				fold_name: str = ""
+				if fold_number > 0:
+					fold_name = f"_fold_{fold_number}_val"
+				elif fold_number == 0:
+					fold_name = "_val"
+				elif fold_number == -1:
+					fold_name = "_test"
+				elif fold_number == -2:
+					fold_name = "_train"
+				pr: str = "pr" if not is_negative else "negative_pr"
+				curve_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{run_name}_{pr}_curve{fold_name}.png"
+
+				plt.plot(
+					curr_recall[optimal_idx], curr_precision[optimal_idx], 'go', label=f"Optimal threshold (t={optimal_threshold:.2f})"
+				)
+
+				plt.xlim([-0.01, 1.01])
+				plt.ylim([-0.01, 1.01])
+				plt.xlabel("Recall" if not is_negative else "Specificity")
+				plt.ylabel("Precision" if not is_negative else "Negative Predictive Value")
+				plt.title(f"{label} Curve")
+				plt.legend(loc="lower right")
+				plt.savefig(curve_path)
+				mlflow.log_artifact(curve_path)
+				os.remove(curve_path)
+				plt.close()
+
+		return metrics
+
+	@staticmethod
+	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
+	def all_curves(
+		true_classes: NDArray[np.intc] | NDArray[np.single],
+		pred_probs: NDArray[np.single],
+		fold_number: int = -1,
+		run_name: str = ""
+	) -> dict[str, float]:
+		""" Run all X_curve_and_auc functions and return a dictionary of metrics.
+
+		Args:
+			true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
+			pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
+			fold_number (int): Fold number, used for naming the plot file, usually
+				-1 for final model with test set,
+				0 for final model with validation set,
+				>0 for other folds with their validation set
+			run_name (str): Name for saving the plot
+		Returns:
+			dict[str, float]: Dictionary containing AUC score and optimal thresholds for ROC and PR curves
+		"""
+		metrics: dict[str, float] = {}
+		metrics.update(MetricUtils.roc_curve_and_auc(true_classes, pred_probs, fold_number, run_name))
+		metrics.update(MetricUtils.pr_curve_and_auc(true_classes, pred_probs, fold_number, run_name))
+		return metrics
+
 
 	@staticmethod
 	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
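The optimal-F1 thresholds added above follow the usual recipe of scoring every point on a precision-recall curve and keeping the threshold with the highest F1. A minimal standalone sketch of that idea with scikit-learn, on illustrative values rather than the package's own data:

    import numpy as np
    from sklearn.metrics import precision_recall_curve

    # Toy labels and positive-class probabilities (illustrative only)
    y_true = np.array([0, 1, 0, 1, 0])
    pos_probs = np.array([0.8, 0.9, 0.2, 0.8, 0.3])

    precision, recall, thresholds = precision_recall_curve(y_true, pos_probs)

    # F1 at every curve point; the last point (recall=0, precision=1) has no matching threshold
    with np.errstate(divide="ignore", invalid="ignore"):
        f1 = 2 * precision * recall / (precision + recall)
    f1 = np.nan_to_num(f1)

    best = int(np.argmax(f1[:-1]))  # ignore the final point, which has no threshold
    print(f"best F1 = {f1[best]:.2f} at threshold {thresholds[best]:.2f}")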
@@ -609,7 +764,7 @@ class MetricUtils:
 
 			# Find all local minima
 			from scipy.signal import argrelextrema
-			local_minima_idx: NDArray[np.intp] = argrelextrema(y_array, np.less)[0]
+			local_minima_idx: NDArray[np.intp] = np.array(argrelextrema(y_array, np.less)[0], dtype=np.intp)
 			distinct_candidates = np.unique(np.append(local_minima_idx, best_idx))
 		else:
 			assert 0 <= best_idx < len(x_array), "Best x index is out of bounds"
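The only functional content in this hunk is a dtype cast: scipy's argrelextrema already returns index arrays, and wrapping the result in np.array(..., dtype=np.intp) mainly pins the dtype for static type checkers. A small illustration with toy values:

    import numpy as np
    from scipy.signal import argrelextrema

    y_array = np.array([3.0, 1.0, 2.0, 0.5, 2.5])

    # argrelextrema returns a tuple of index arrays; the explicit np.intp cast only fixes the dtype
    local_minima_idx = np.array(argrelextrema(y_array, np.less)[0], dtype=np.intp)
    print(local_minima_idx)  # [1 3] -> positions of the local minima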
{stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/model_interface.py

@@ -851,22 +851,13 @@ class ModelInterface(AbstractModel):
 		self.history.append(history)
 		self.evaluation_results.append(eval_results)
 
-		# Generate and save ROC …
-		MetricUtils.…
-			true_classes=true_classes,
-			pred_probs=predictions,
-			fold_number=fold_number,
-			run_name=self.run_name
-		)
+		# Generate and save ROC Curve and PR Curve for this fold
+		MetricUtils.all_curves(true_classes, predictions, fold_number, run_name=self.run_name)
 
-		# If final model, also log the ROC curve for the train set
+		# If final model, also log the ROC curve and PR curve for the train set
 		if fold_number == 0:
-			…
-			…
-			pred_probs=training_predictions,
-			fold_number=-2,
-			run_name=self.run_name
-		)
+			fold_number = -2  # -2 is the train set
+			MetricUtils.all_curves(training_true_classes, training_predictions, fold_number, run_name=self.run_name)
 
 		# Log visualization artifacts if they were generated
 		if temp_dir is not None:
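Both call sites above now route through the single MetricUtils.all_curves entry point. A usage sketch, assuming MetricUtils imports from stouputils.data_science.metric_utils as the file layout suggests; an empty run_name skips the plotting/MLflow branch, as in the doctests above:

    import numpy as np
    from stouputils.data_science.metric_utils import MetricUtils

    true_classes = np.array([0, 1, 0, 1])
    pred_probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.3, 0.7]])

    # run_name="" keeps only the scores (ROC AUC, PR AUC, averages, optimal thresholds)
    metrics = MetricUtils.all_curves(true_classes, pred_probs, fold_number=-1, run_name="")
    for name, value in sorted(metrics.items()):
        print(name, round(float(value), 3))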
{stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/utils.py

@@ -197,22 +197,89 @@ class Utils:
 		assert y_pred.ndim > 1 and y_pred.shape[1] > 1, "Predictions must be probability scores in one-hot format"
 		pred_probs: NDArray[np.single] = y_pred[:, 1] # Take probability of positive class only
 
-		# Convert true labels to class indices if they're one-hot encoded
-		true_classes: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)
-
 		# Calculate ROC curve and AUC score using probabilities
 		with Muffle(mute_stderr=True): # Suppress "UndefinedMetricWarning: No positive samples in y_true [...]"
 
 			# Import functions
 			try:
-				from sklearn.metrics import …
+				from sklearn.metrics import roc_auc_score, roc_curve
 			except ImportError as e:
 				raise ImportError("scikit-learn is required for ROC curve calculation. Install with 'pip install scikit-learn'") from e
 
-			…
+			# Convert y_true to class indices for both functions
+			y_true_indices: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)
+
+			# Calculate AUC score directly using roc_auc_score
+			auc_value: float = float(roc_auc_score(y_true_indices, pred_probs))
+
+			# Calculate ROC curve points
+			results: tuple[Any, Any, Any] = roc_curve(y_true_indices, pred_probs, drop_intermediate=False)
 			fpr: NDArray[np.single] = results[0]
 			tpr: NDArray[np.single] = results[1]
 			thresholds: NDArray[np.single] = results[2]
-
+
 		return auc_value, fpr, tpr, thresholds
 
+	@staticmethod
+	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
+	def get_pr_curve_and_auc(
+		y_true: NDArray[np.intc | np.single],
+		y_pred: NDArray[np.single],
+		negative: bool = False
+	) -> tuple[float, float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
+		""" Calculate Precision-Recall Curve (or Negative Precision-Recall Curve) and AUC score.
+
+		Args:
+			y_true (NDArray[intc | single]): True class labels (either one-hot encoded or class indices)
+			y_pred (NDArray[single]): Predicted probabilities (must be probability scores, not class indices)
+			negative (bool): Whether to calculate the negative Precision-Recall Curve
+		Returns:
+			tuple[float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
+				Tuple containing either:
+				- AUC score, Average Precision, Precision, Recall, and Thresholds
+				- AUC score, Average Precision, Negative Predictive Value, Specificity, and Thresholds for the negative class
+
+		Examples:
+			>>> # Binary classification example
+			>>> y_true = np.array([0.0, 1.0, 0.0, 1.0, 0.0])
+			>>> y_pred = np.array([[0.2, 0.8], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]])
+			>>> auc_value, average_precision, precision, recall, thresholds = Utils.get_pr_curve_and_auc(y_true, y_pred)
+			>>> round(auc_value, 2)
+			0.92
+			>>> round(average_precision, 2)
+			0.83
+			>>> [round(x, 2) for x in precision.tolist()]
+			[0.4, 0.5, 0.67, 1.0, 1.0]
+			>>> [round(x, 2) for x in recall.tolist()]
+			[1.0, 1.0, 1.0, 0.5, 0.0]
+			>>> [round(x, 2) for x in thresholds.tolist()]
+			[0.2, 0.3, 0.8, 0.9]
+		"""
+		# For predictions, assert they are probabilities (one-hot encoded)
+		assert y_pred.ndim > 1 and y_pred.shape[1] > 1, "Predictions must be probability scores in one-hot format"
+		pred_probs: NDArray[np.single] = y_pred[:, 1] if not negative else y_pred[:, 0]
+
+		# Calculate Precision-Recall Curve and AUC score using probabilities
+		with Muffle(mute_stderr=True): # Suppress "UndefinedMetricWarning: No positive samples in y_true [...]"
+
+			# Import functions
+			try:
+				from sklearn.metrics import auc, average_precision_score, precision_recall_curve
+			except ImportError as e:
+				raise ImportError("scikit-learn is required for PR Curve calculation. Install with 'pip install scikit-learn'") from e
+
+			# Convert y_true to class indices for both functions
+			y_true_indices: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)
+
+			results: tuple[Any, Any, Any] = precision_recall_curve(
+				y_true=y_true_indices,
+				probas_pred=pred_probs,
+				pos_label=1 if not negative else 0
+			)
+			precision: NDArray[np.single] = results[0]
+			recall: NDArray[np.single] = results[1]
+			thresholds: NDArray[np.single] = results[2]
+			auc_value: float = float(auc(recall, precision))
+			average_precision: float = float(average_precision_score(y_true_indices, pred_probs))
+		return auc_value, average_precision, precision, recall, thresholds
+
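The negative=True branch of get_pr_curve_and_auc swaps which class counts as positive: it scores the class-0 probability column with pos_label=0, so precision plays the role of NPV and recall the role of specificity. A minimal sketch of the same idea using scikit-learn directly, with illustrative values:

    import numpy as np
    from sklearn.metrics import auc, precision_recall_curve

    y_true = np.array([0, 1, 0, 1, 0])
    y_pred = np.array([[0.8, 0.2], [0.1, 0.9], [0.7, 0.3], [0.2, 0.8], [0.6, 0.4]])

    # Treat class 0 as the "positive" label and score it with its own probability column:
    # precision then plays the role of NPV and recall the role of specificity.
    npv, specificity, thresholds = precision_recall_curve(y_true, y_pred[:, 0], pos_label=0)
    print(round(float(auc(specificity, npv)), 2))  # area under the NPV-specificity curve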