stouputils 1.3.6__tar.gz → 1.3.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {stouputils-1.3.6 → stouputils-1.3.7}/PKG-INFO +1 -1
- {stouputils-1.3.6 → stouputils-1.3.7}/pyproject.toml +1 -1
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/metric_dictionnary.py +13 -1
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/metric_utils.py +159 -7
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/model_interface.py +5 -14
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/utils.py +73 -6
- {stouputils-1.3.6 → stouputils-1.3.7}/.gitignore +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/LICENSE +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/README.md +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/all_doctests.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/applications/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/applications/automatic_docs.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/applications/upscaler/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/applications/upscaler/config.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/applications/upscaler/image.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/applications/upscaler/video.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/archive.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/backup.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/collections.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/continuous_delivery/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/continuous_delivery/cd_utils.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/continuous_delivery/github.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/continuous_delivery/pypi.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/continuous_delivery/pyproject.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/ctx.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/config/get.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/config/set.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/auto_contrast.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/axis_flip.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/bias_field_correction.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/binary_threshold.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/blur.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/brightness.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/canny.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/clahe.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/common.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/contrast.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/curvature_flow_filter.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/denoise.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/histogram_equalization.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/invert.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/laplacian.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/median_blur.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/noise.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/normalize.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/random_erase.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/resize.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/rotation.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/salt_pepper.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/sharpening.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/shearing.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/threshold.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/translation.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/zoom.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image_augmentation.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image_preprocess.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/prosthesis_detection.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/technique.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/dataset/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/dataset/dataset.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/dataset/dataset_loader.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/dataset/grouping_strategy.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/dataset/image_loader.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/dataset/xy_tuple.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/mlflow_utils.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/abstract_model.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/all.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/base_keras.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras/all.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras/convnext.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras/densenet.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras/efficientnet.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras/mobilenet.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras/resnet.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras/squeezenet.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras/vgg.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras/xception.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/callbacks/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/losses/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/visualizations.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/sandbox.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/range_tuple.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/scripts/augment_dataset.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/scripts/exhaustive_process.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/scripts/preprocess_dataset.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/scripts/routine.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/decorators.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/dont_look/zip_file_override.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/image.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/installer/__init__.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/installer/common.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/installer/downloader.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/installer/linux.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/installer/main.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/installer/windows.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/io.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/parallel.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/print.py +0 -0
- {stouputils-1.3.6 → stouputils-1.3.7}/stouputils/py.typed +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: stouputils
|
|
3
|
-
Version: 1.3.6
|
|
3
|
+
Version: 1.3.7
|
|
4
4
|
Summary: Stouputils is a collection of utility modules designed to simplify and enhance the development process. It includes a range of tools for tasks such as execution of doctests, display utilities, decorators, as well as context managers, and many more.
|
|
5
5
|
Project-URL: Homepage, https://github.com/Stoupy51/stouputils
|
|
6
6
|
Project-URL: Issues, https://github.com/Stoupy51/stouputils/issues
|
|
@@ -5,7 +5,7 @@ build-backend = "hatchling.build"
|
|
|
5
5
|
|
|
6
6
|
[project]
|
|
7
7
|
name = "stouputils"
|
|
8
|
-
version = "1.3.6"
|
|
8
|
+
version = "1.3.7"
|
|
9
9
|
description = "Stouputils is a collection of utility modules designed to simplify and enhance the development process. It includes a range of tools for tasks such as execution of doctests, display utilities, decorators, as well as context managers, and many more."
|
|
10
10
|
readme = "README.md"
|
|
11
11
|
requires-python = ">=3.10"
|
|
@@ -8,12 +8,16 @@ This class contains the following metrics:
|
|
|
8
8
|
1. Main metrics:
|
|
9
9
|
|
|
10
10
|
- Area Under the Curve (AUC)
|
|
11
|
+
- Area Under the Precision-Recall Curve (AUPRC)
|
|
12
|
+
- Area Under the NPV-Specificity Curve (NEGATIVE_AUPRC)
|
|
11
13
|
- Specificity (True Negative Rate)
|
|
12
14
|
- Recall/Sensitivity (True Positive Rate)
|
|
13
15
|
- Precision (Positive Predictive Value)
|
|
14
16
|
- Negative Predictive Value (NPV)
|
|
15
17
|
- Accuracy
|
|
16
18
|
- F1 Score
|
|
19
|
+
- Precision-Recall Average
|
|
20
|
+
- Precision-Recall Average for Negative Class
|
|
17
21
|
|
|
18
22
|
2. Confusion matrix metrics:
|
|
19
23
|
|
|
@@ -39,6 +43,8 @@ This class contains the following metrics:
|
|
|
39
43
|
|
|
40
44
|
- Youden's J statistic
|
|
41
45
|
- Cost-based threshold
|
|
46
|
+
- F1 Score threshold
|
|
47
|
+
- F1 Score threshold for the negative class
|
|
42
48
|
|
|
43
49
|
6. Average metrics across folds:
|
|
44
50
|
|
|
@@ -52,7 +58,9 @@ This class contains the following metrics:
|
|
|
52
58
|
class MetricDictionnary:
|
|
53
59
|
|
|
54
60
|
# Main metrics (starting with '1:')
|
|
55
|
-
AUC: str = "1: Area Under the Curve: AUC"
|
|
61
|
+
AUC: str = "1: Area Under the ROC Curve: AUC / AUROC"
|
|
62
|
+
AUPRC: str = "1: Area Under the Precision-Recall Curve: AUPRC / PR AUC"
|
|
63
|
+
NEGATIVE_AUPRC: str = "1: Area Under the NPV-Specificity Curve: AUNPRC / NPR AUC"
|
|
56
64
|
SPECIFICITY: str = "1: Specificity: True Negative Rate"
|
|
57
65
|
RECALL: str = "1: Recall/Sensitivity: True Positive Rate"
|
|
58
66
|
PRECISION: str = "1: Precision: Positive Predictive Value"
|
|
@@ -61,6 +69,8 @@ class MetricDictionnary:
|
|
|
61
69
|
BALANCED_ACCURACY: str = "1: Balanced Accuracy"
|
|
62
70
|
F1_SCORE: str = "1: F1 Score"
|
|
63
71
|
F1_SCORE_NEGATIVE: str = "1: F1 Score for Negative Class"
|
|
72
|
+
PR_AVERAGE: str = "1: Precision-Recall Average"
|
|
73
|
+
PR_AVERAGE_NEGATIVE: str = "1: Precision-Recall Average for Negative Class"
|
|
64
74
|
|
|
65
75
|
# Confusion matrix metrics (starting with '2:')
|
|
66
76
|
CONFUSION_MATRIX_TN: str = "2: Confusion Matrix: TN"
|
|
@@ -82,6 +92,8 @@ class MetricDictionnary:
|
|
|
82
92
|
# Optimal thresholds (starting with '5:')
|
|
83
93
|
OPTIMAL_THRESHOLD_YOUDEN: str = "5: Optimal Threshold: Youden"
|
|
84
94
|
OPTIMAL_THRESHOLD_COST: str = "5: Optimal Threshold: Cost"
|
|
95
|
+
OPTIMAL_THRESHOLD_F1: str = "5: Optimal Threshold: F1"
|
|
96
|
+
OPTIMAL_THRESHOLD_F1_NEGATIVE: str = "5: Optimal Threshold: F1 for Negative Class"
|
|
85
97
|
|
|
86
98
|
# Average metrics across folds (starting with '6:')
|
|
87
99
|
AVERAGE_METRIC: str = "6: Average METRIC_NAME across folds"
|
|
@@ -117,10 +117,10 @@ class MetricUtils:
|
|
|
117
117
|
if mcc_metric:
|
|
118
118
|
metrics.update(mcc_metric)
|
|
119
119
|
|
|
120
|
-
# Calculate and plot ROC/AUC
|
|
121
|
-
|
|
122
|
-
if
|
|
123
|
-
metrics.update(
|
|
120
|
+
# Calculate and plot (ROC Curve / AUC) and (PR Curve / AUC, and negative one)
|
|
121
|
+
curves_metrics: dict[str, float] = MetricUtils.all_curves(true_classes, y_pred, fold_number=-1, run_name=run_name)
|
|
122
|
+
if curves_metrics:
|
|
123
|
+
metrics.update(curves_metrics)
|
|
124
124
|
|
|
125
125
|
# Multiclass classification
|
|
126
126
|
elif mode == "multiclass":
|
|
@@ -287,7 +287,7 @@ class MetricUtils:
|
|
|
287
287
|
|
|
288
288
|
@staticmethod
|
|
289
289
|
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
|
|
290
|
-
def
|
|
290
|
+
def roc_curve_and_auc(
|
|
291
291
|
true_classes: NDArray[np.intc] | NDArray[np.single],
|
|
292
292
|
pred_probs: NDArray[np.single],
|
|
293
293
|
fold_number: int = -1,
|
|
@@ -311,7 +311,7 @@ class MetricUtils:
|
|
|
311
311
|
>>> pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
|
|
312
312
|
>>> from stouputils.ctx import Muffle
|
|
313
313
|
>>> with Muffle():
|
|
314
|
-
... metrics = MetricUtils.
|
|
314
|
+
... metrics = MetricUtils.roc_curve_and_auc(true_classes, pred_probs, run_name="")
|
|
315
315
|
|
|
316
316
|
>>> # Check metrics
|
|
317
317
|
>>> round(float(metrics[MetricDictionnary.AUC]), 2)
|
|
@@ -377,6 +377,158 @@ class MetricUtils:
|
|
|
377
377
|
|
|
378
378
|
return metrics
|
|
379
379
|
|
|
380
|
+
@staticmethod
|
|
381
|
+
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
|
|
382
|
+
def pr_curve_and_auc(
|
|
383
|
+
true_classes: NDArray[np.intc] | NDArray[np.single],
|
|
384
|
+
pred_probs: NDArray[np.single],
|
|
385
|
+
fold_number: int = -1,
|
|
386
|
+
run_name: str = ""
|
|
387
|
+
) -> dict[str, float]:
|
|
388
|
+
""" Calculate Precision-Recall curve and AUC score. (and NPV-Specificity curve and AUC)
|
|
389
|
+
|
|
390
|
+
Args:
|
|
391
|
+
true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
|
|
392
|
+
pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
|
|
393
|
+
fold_number (int): Fold number, used for naming the plot file, usually
|
|
394
|
+
-1 for final model with test set,
|
|
395
|
+
0 for final model with validation set,
|
|
396
|
+
>0 for other folds with their validation set
|
|
397
|
+
run_name (str): Name for saving the plot
|
|
398
|
+
Returns:
|
|
399
|
+
dict[str, float]: Dictionary containing AUC score and optimal thresholds
|
|
400
|
+
|
|
401
|
+
Examples:
|
|
402
|
+
>>> true_classes = np.array([0, 1, 0])
|
|
403
|
+
>>> pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
|
|
404
|
+
>>> from stouputils.ctx import Muffle
|
|
405
|
+
>>> with Muffle():
|
|
406
|
+
... metrics = MetricUtils.pr_curve_and_auc(true_classes, pred_probs, run_name="")
|
|
407
|
+
|
|
408
|
+
>>> # Check metrics
|
|
409
|
+
>>> round(float(metrics[MetricDictionnary.AUPRC]), 2)
|
|
410
|
+
0.75
|
|
411
|
+
>>> round(float(metrics[MetricDictionnary.NEGATIVE_AUPRC]), 2)
|
|
412
|
+
0.92
|
|
413
|
+
>>> round(float(metrics[MetricDictionnary.PR_AVERAGE]), 2)
|
|
414
|
+
0.5
|
|
415
|
+
>>> round(float(metrics[MetricDictionnary.PR_AVERAGE_NEGATIVE]), 2)
|
|
416
|
+
0.33
|
|
417
|
+
>>> round(float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1]), 2)
|
|
418
|
+
0.9
|
|
419
|
+
>>> round(float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1_NEGATIVE]), 2)
|
|
420
|
+
0.1
|
|
421
|
+
"""
|
|
422
|
+
auc_value, average_precision, precision, recall, thresholds = Utils.get_pr_curve_and_auc(true_classes, pred_probs)
|
|
423
|
+
neg_auc_value, average_precision_neg, npv, specificity, neg_thresholds = (
|
|
424
|
+
Utils.get_pr_curve_and_auc(true_classes, pred_probs, negative=True)
|
|
425
|
+
)
|
|
426
|
+
|
|
427
|
+
# Calculate metrics
|
|
428
|
+
metrics: dict[str, float] = {
|
|
429
|
+
MetricDictionnary.AUPRC: auc_value,
|
|
430
|
+
MetricDictionnary.NEGATIVE_AUPRC: neg_auc_value,
|
|
431
|
+
MetricDictionnary.PR_AVERAGE: average_precision,
|
|
432
|
+
MetricDictionnary.PR_AVERAGE_NEGATIVE: average_precision_neg
|
|
433
|
+
}
|
|
434
|
+
|
|
435
|
+
# Calculate optimal thresholds for both PR curves
|
|
436
|
+
for is_negative in (False, True):
|
|
437
|
+
|
|
438
|
+
# Get the right values based on positive/negative case
|
|
439
|
+
if not is_negative:
|
|
440
|
+
curr_precision = precision
|
|
441
|
+
curr_recall = recall
|
|
442
|
+
curr_thresholds = thresholds
|
|
443
|
+
curr_auc = auc_value
|
|
444
|
+
curr_ap = average_precision
|
|
445
|
+
else:
|
|
446
|
+
curr_precision = npv
|
|
447
|
+
curr_recall = specificity
|
|
448
|
+
curr_thresholds = neg_thresholds
|
|
449
|
+
curr_auc = neg_auc_value
|
|
450
|
+
curr_ap = average_precision_neg
|
|
451
|
+
|
|
452
|
+
# Calculate F-score for each threshold
|
|
453
|
+
fscore: NDArray[np.single] = (2 * curr_precision * curr_recall) / (curr_precision + curr_recall)
|
|
454
|
+
fscore = fscore[~np.isnan(fscore)]
|
|
455
|
+
|
|
456
|
+
# Get optimal threshold (maximum F-score)
|
|
457
|
+
if len(fscore) > 0:
|
|
458
|
+
optimal_idx: int = int(np.argmax(fscore))
|
|
459
|
+
optimal_threshold: float = curr_thresholds[optimal_idx]
|
|
460
|
+
else:
|
|
461
|
+
optimal_idx: int = 0
|
|
462
|
+
optimal_threshold = float('inf')
|
|
463
|
+
|
|
464
|
+
# Store in metrics dictionary
|
|
465
|
+
if not is_negative:
|
|
466
|
+
metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1] = optimal_threshold
|
|
467
|
+
else:
|
|
468
|
+
metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1_NEGATIVE] = optimal_threshold
|
|
469
|
+
|
|
470
|
+
# Plot ROC curve if not nan
|
|
471
|
+
if run_name:
|
|
472
|
+
label: str = "Precision - Recall" if not is_negative else "Negative Predictive Value - Specificity"
|
|
473
|
+
plt.figure(figsize=(12, 6))
|
|
474
|
+
plt.plot(curr_recall, curr_precision, "b", label=f"{label} curve (AUC = {curr_auc:.2f}, AP = {curr_ap:.2f})")
|
|
475
|
+
|
|
476
|
+
# Prepare the path
|
|
477
|
+
fold_name: str = ""
|
|
478
|
+
if fold_number > 0:
|
|
479
|
+
fold_name = f"_fold_{fold_number}_val"
|
|
480
|
+
elif fold_number == 0:
|
|
481
|
+
fold_name = "_val"
|
|
482
|
+
elif fold_number == -1:
|
|
483
|
+
fold_name = "_test"
|
|
484
|
+
elif fold_number == -2:
|
|
485
|
+
fold_name = "_train"
|
|
486
|
+
pr: str = "pr" if not is_negative else "negative_pr"
|
|
487
|
+
curve_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{run_name}_{pr}_curve{fold_name}.png"
|
|
488
|
+
|
|
489
|
+
plt.plot(
|
|
490
|
+
curr_recall[optimal_idx], curr_precision[optimal_idx], 'go', label=f"Optimal threshold (t={optimal_threshold:.2f})"
|
|
491
|
+
)
|
|
492
|
+
|
|
493
|
+
plt.xlim([-0.01, 1.01])
|
|
494
|
+
plt.ylim([-0.01, 1.01])
|
|
495
|
+
plt.xlabel("Recall" if not is_negative else "Specificity")
|
|
496
|
+
plt.ylabel("Precision" if not is_negative else "Negative Predictive Value")
|
|
497
|
+
plt.title(f"{label} Curve")
|
|
498
|
+
plt.legend(loc="lower right")
|
|
499
|
+
plt.savefig(curve_path)
|
|
500
|
+
mlflow.log_artifact(curve_path)
|
|
501
|
+
os.remove(curve_path)
|
|
502
|
+
plt.close()
|
|
503
|
+
|
|
504
|
+
return metrics
|
|
505
|
+
|
|
506
|
+
@staticmethod
|
|
507
|
+
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
|
|
508
|
+
def all_curves(
|
|
509
|
+
true_classes: NDArray[np.intc] | NDArray[np.single],
|
|
510
|
+
pred_probs: NDArray[np.single],
|
|
511
|
+
fold_number: int = -1,
|
|
512
|
+
run_name: str = ""
|
|
513
|
+
) -> dict[str, float]:
|
|
514
|
+
""" Run all X_curve_and_auc functions and return a dictionary of metrics.
|
|
515
|
+
|
|
516
|
+
Args:
|
|
517
|
+
true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
|
|
518
|
+
pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
|
|
519
|
+
fold_number (int): Fold number, used for naming the plot file, usually
|
|
520
|
+
-1 for final model with test set,
|
|
521
|
+
0 for final model with validation set,
|
|
522
|
+
>0 for other folds with their validation set
|
|
523
|
+
run_name (str): Name for saving the plot
|
|
524
|
+
Returns:
|
|
525
|
+
dict[str, float]: Dictionary containing AUC score and optimal thresholds for ROC and PR curves
|
|
526
|
+
"""
|
|
527
|
+
metrics: dict[str, float] = {}
|
|
528
|
+
metrics.update(MetricUtils.roc_curve_and_auc(true_classes, pred_probs, fold_number, run_name))
|
|
529
|
+
metrics.update(MetricUtils.pr_curve_and_auc(true_classes, pred_probs, fold_number, run_name))
|
|
530
|
+
return metrics
|
|
531
|
+
|
|
380
532
|
|
|
381
533
|
@staticmethod
|
|
382
534
|
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
|
|
@@ -609,7 +761,7 @@ class MetricUtils:
|
|
|
609
761
|
|
|
610
762
|
# Find all local minima
|
|
611
763
|
from scipy.signal import argrelextrema
|
|
612
|
-
local_minima_idx: NDArray[np.intp] = argrelextrema(y_array, np.less)[0]
|
|
764
|
+
local_minima_idx: NDArray[np.intp] = np.array(argrelextrema(y_array, np.less)[0], dtype=np.intp)
|
|
613
765
|
distinct_candidates = np.unique(np.append(local_minima_idx, best_idx))
|
|
614
766
|
else:
|
|
615
767
|
assert 0 <= best_idx < len(x_array), "Best x index is out of bounds"
|
|
@@ -851,22 +851,13 @@ class ModelInterface(AbstractModel):
|
|
|
851
851
|
self.history.append(history)
|
|
852
852
|
self.evaluation_results.append(eval_results)
|
|
853
853
|
|
|
854
|
-
# Generate and save ROC
|
|
855
|
-
MetricUtils.
|
|
856
|
-
true_classes=true_classes,
|
|
857
|
-
pred_probs=predictions,
|
|
858
|
-
fold_number=fold_number,
|
|
859
|
-
run_name=self.run_name
|
|
860
|
-
)
|
|
854
|
+
# Generate and save ROC Curve and PR Curve for this fold
|
|
855
|
+
MetricUtils.all_curves(true_classes, predictions, fold_number, run_name=self.run_name)
|
|
861
856
|
|
|
862
|
-
# If final model, also log the ROC curve for the train set
|
|
857
|
+
# If final model, also log the ROC curve and PR curve for the train set
|
|
863
858
|
if fold_number == 0:
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
pred_probs=training_predictions,
|
|
867
|
-
fold_number=-2,
|
|
868
|
-
run_name=self.run_name
|
|
869
|
-
)
|
|
859
|
+
fold_number = -2 # -2 is the train set
|
|
860
|
+
MetricUtils.all_curves(training_true_classes, training_predictions, fold_number, run_name=self.run_name)
|
|
870
861
|
|
|
871
862
|
# Log visualization artifacts if they were generated
|
|
872
863
|
if temp_dir is not None:
|
|
@@ -197,22 +197,89 @@ class Utils:
|
|
|
197
197
|
assert y_pred.ndim > 1 and y_pred.shape[1] > 1, "Predictions must be probability scores in one-hot format"
|
|
198
198
|
pred_probs: NDArray[np.single] = y_pred[:, 1] # Take probability of positive class only
|
|
199
199
|
|
|
200
|
-
# Convert true labels to class indices if they're one-hot encoded
|
|
201
|
-
true_classes: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)
|
|
202
|
-
|
|
203
200
|
# Calculate ROC curve and AUC score using probabilities
|
|
204
201
|
with Muffle(mute_stderr=True): # Suppress "UndefinedMetricWarning: No positive samples in y_true [...]"
|
|
205
202
|
|
|
206
203
|
# Import functions
|
|
207
204
|
try:
|
|
208
|
-
from sklearn.metrics import
|
|
205
|
+
from sklearn.metrics import roc_auc_score, roc_curve
|
|
209
206
|
except ImportError as e:
|
|
210
207
|
raise ImportError("scikit-learn is required for ROC curve calculation. Install with 'pip install scikit-learn'") from e
|
|
211
208
|
|
|
212
|
-
|
|
209
|
+
# Convert y_true to class indices for both functions
|
|
210
|
+
y_true_indices: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)
|
|
211
|
+
|
|
212
|
+
# Calculate AUC score directly using roc_auc_score
|
|
213
|
+
auc_value: float = float(roc_auc_score(y_true_indices, pred_probs))
|
|
214
|
+
|
|
215
|
+
# Calculate ROC curve points
|
|
216
|
+
results: tuple[Any, Any, Any] = roc_curve(y_true_indices, pred_probs, drop_intermediate=False)
|
|
213
217
|
fpr: NDArray[np.single] = results[0]
|
|
214
218
|
tpr: NDArray[np.single] = results[1]
|
|
215
219
|
thresholds: NDArray[np.single] = results[2]
|
|
216
|
-
|
|
220
|
+
|
|
217
221
|
return auc_value, fpr, tpr, thresholds
|
|
218
222
|
|
|
223
|
+
@staticmethod
|
|
224
|
+
@handle_error(error_log=DataScienceConfig.ERROR_LOG)
|
|
225
|
+
def get_pr_curve_and_auc(
|
|
226
|
+
y_true: NDArray[np.intc | np.single],
|
|
227
|
+
y_pred: NDArray[np.single],
|
|
228
|
+
negative: bool = False
|
|
229
|
+
) -> tuple[float, float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
|
|
230
|
+
""" Calculate Precision-Recall Curve (or Negative Precision-Recall Curve) and AUC score.
|
|
231
|
+
|
|
232
|
+
Args:
|
|
233
|
+
y_true (NDArray[intc | single]): True class labels (either one-hot encoded or class indices)
|
|
234
|
+
y_pred (NDArray[single]): Predicted probabilities (must be probability scores, not class indices)
|
|
235
|
+
negative (bool): Whether to calculate the negative Precision-Recall Curve
|
|
236
|
+
Returns:
|
|
237
|
+
tuple[float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
|
|
238
|
+
Tuple containing either:
|
|
239
|
+
- AUC score, Average Precision, Precision, Recall, and Thresholds
|
|
240
|
+
- AUC score, Average Precision, Negative Predictive Value, Specificity, and Thresholds for the negative class
|
|
241
|
+
|
|
242
|
+
Examples:
|
|
243
|
+
>>> # Binary classification example
|
|
244
|
+
>>> y_true = np.array([0.0, 1.0, 0.0, 1.0, 0.0])
|
|
245
|
+
>>> y_pred = np.array([[0.2, 0.8], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]])
|
|
246
|
+
>>> auc_value, average_precision, precision, recall, thresholds = Utils.get_pr_curve_and_auc(y_true, y_pred)
|
|
247
|
+
>>> round(auc_value, 2)
|
|
248
|
+
0.92
|
|
249
|
+
>>> round(average_precision, 2)
|
|
250
|
+
0.83
|
|
251
|
+
>>> [round(x, 2) for x in precision.tolist()]
|
|
252
|
+
[0.4, 0.5, 0.67, 1.0, 1.0]
|
|
253
|
+
>>> [round(x, 2) for x in recall.tolist()]
|
|
254
|
+
[1.0, 1.0, 1.0, 0.5, 0.0]
|
|
255
|
+
>>> [round(x, 2) for x in thresholds.tolist()]
|
|
256
|
+
[0.2, 0.3, 0.8, 0.9]
|
|
257
|
+
"""
|
|
258
|
+
# For predictions, assert they are probabilities (one-hot encoded)
|
|
259
|
+
assert y_pred.ndim > 1 and y_pred.shape[1] > 1, "Predictions must be probability scores in one-hot format"
|
|
260
|
+
pred_probs: NDArray[np.single] = y_pred[:, 1] if not negative else y_pred[:, 0]
|
|
261
|
+
|
|
262
|
+
# Calculate Precision-Recall Curve and AUC score using probabilities
|
|
263
|
+
with Muffle(mute_stderr=True): # Suppress "UndefinedMetricWarning: No positive samples in y_true [...]"
|
|
264
|
+
|
|
265
|
+
# Import functions
|
|
266
|
+
try:
|
|
267
|
+
from sklearn.metrics import auc, average_precision_score, precision_recall_curve
|
|
268
|
+
except ImportError as e:
|
|
269
|
+
raise ImportError("scikit-learn is required for PR Curve calculation. Install with 'pip install scikit-learn'") from e
|
|
270
|
+
|
|
271
|
+
# Convert y_true to class indices for both functions
|
|
272
|
+
y_true_indices: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)
|
|
273
|
+
|
|
274
|
+
results: tuple[Any, Any, Any] = precision_recall_curve(
|
|
275
|
+
y_true=y_true_indices,
|
|
276
|
+
probas_pred=pred_probs,
|
|
277
|
+
pos_label=1 if not negative else 0
|
|
278
|
+
)
|
|
279
|
+
precision: NDArray[np.single] = results[0]
|
|
280
|
+
recall: NDArray[np.single] = results[1]
|
|
281
|
+
thresholds: NDArray[np.single] = results[2]
|
|
282
|
+
auc_value: float = float(auc(recall, precision))
|
|
283
|
+
average_precision: float = float(average_precision_score(y_true_indices, pred_probs))
|
|
284
|
+
return auc_value, average_precision, precision, recall, thresholds
|
|
285
|
+
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/__init__.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/auto_contrast.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/axis_flip.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/brightness.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/canny.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/clahe.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/common.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/contrast.py
RENAMED
|
File without changes
|
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/denoise.py
RENAMED
|
File without changes
|
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/invert.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/laplacian.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/median_blur.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/noise.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/normalize.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/random_erase.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/resize.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/rotation.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/salt_pepper.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/sharpening.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/shearing.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/threshold.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/translation.py
RENAMED
|
File without changes
|
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image_augmentation.py
RENAMED
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/data_processing/image_preprocess.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/losses/__init__.py
RENAMED
|
File without changes
|
|
File without changes
|
{stouputils-1.3.6 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/visualizations.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|