stouputils-1.3.6.tar.gz → stouputils-1.3.8.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. {stouputils-1.3.6 → stouputils-1.3.8}/PKG-INFO +1 -1
  2. {stouputils-1.3.6 → stouputils-1.3.8}/pyproject.toml +1 -1
  3. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/metric_dictionnary.py +13 -1
  4. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/metric_utils.py +170 -15
  5. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/model_interface.py +5 -14
  6. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/utils.py +73 -6
  7. {stouputils-1.3.6 → stouputils-1.3.8}/.gitignore +0 -0
  8. {stouputils-1.3.6 → stouputils-1.3.8}/LICENSE +0 -0
  9. {stouputils-1.3.6 → stouputils-1.3.8}/README.md +0 -0
  10. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/__init__.py +0 -0
  11. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/all_doctests.py +0 -0
  12. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/applications/__init__.py +0 -0
  13. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/applications/automatic_docs.py +0 -0
  14. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/applications/upscaler/__init__.py +0 -0
  15. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/applications/upscaler/config.py +0 -0
  16. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/applications/upscaler/image.py +0 -0
  17. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/applications/upscaler/video.py +0 -0
  18. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/archive.py +0 -0
  19. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/backup.py +0 -0
  20. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/collections.py +0 -0
  21. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/continuous_delivery/__init__.py +0 -0
  22. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/continuous_delivery/cd_utils.py +0 -0
  23. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/continuous_delivery/github.py +0 -0
  24. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/continuous_delivery/pypi.py +0 -0
  25. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/continuous_delivery/pyproject.py +0 -0
  26. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/ctx.py +0 -0
  27. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/config/get.py +0 -0
  28. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/config/set.py +0 -0
  29. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/__init__.py +0 -0
  30. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/auto_contrast.py +0 -0
  31. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/axis_flip.py +0 -0
  32. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/bias_field_correction.py +0 -0
  33. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/binary_threshold.py +0 -0
  34. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/blur.py +0 -0
  35. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/brightness.py +0 -0
  36. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/canny.py +0 -0
  37. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/clahe.py +0 -0
  38. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/common.py +0 -0
  39. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/contrast.py +0 -0
  40. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/curvature_flow_filter.py +0 -0
  41. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/denoise.py +0 -0
  42. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/histogram_equalization.py +0 -0
  43. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/invert.py +0 -0
  44. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/laplacian.py +0 -0
  45. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/median_blur.py +0 -0
  46. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/noise.py +0 -0
  47. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/normalize.py +0 -0
  48. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/random_erase.py +0 -0
  49. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/resize.py +0 -0
  50. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/rotation.py +0 -0
  51. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/salt_pepper.py +0 -0
  52. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/sharpening.py +0 -0
  53. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/shearing.py +0 -0
  54. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/threshold.py +0 -0
  55. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/translation.py +0 -0
  56. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image/zoom.py +0 -0
  57. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image_augmentation.py +0 -0
  58. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/image_preprocess.py +0 -0
  59. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/prosthesis_detection.py +0 -0
  60. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/data_processing/technique.py +0 -0
  61. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/dataset/__init__.py +0 -0
  62. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/dataset/dataset.py +0 -0
  63. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/dataset/dataset_loader.py +0 -0
  64. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/dataset/grouping_strategy.py +0 -0
  65. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/dataset/image_loader.py +0 -0
  66. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/dataset/xy_tuple.py +0 -0
  67. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/mlflow_utils.py +0 -0
  68. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/abstract_model.py +0 -0
  69. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/all.py +0 -0
  70. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/base_keras.py +0 -0
  71. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/all.py +0 -0
  72. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/convnext.py +0 -0
  73. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/densenet.py +0 -0
  74. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/efficientnet.py +0 -0
  75. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/mobilenet.py +0 -0
  76. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/resnet.py +0 -0
  77. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/squeezenet.py +0 -0
  78. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/vgg.py +0 -0
  79. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras/xception.py +0 -0
  80. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/callbacks/__init__.py +0 -0
  81. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +0 -0
  82. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +0 -0
  83. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +0 -0
  84. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +0 -0
  85. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +0 -0
  86. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/losses/__init__.py +0 -0
  87. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +0 -0
  88. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/keras_utils/visualizations.py +0 -0
  89. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/models/sandbox.py +0 -0
  90. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/range_tuple.py +0 -0
  91. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/scripts/augment_dataset.py +0 -0
  92. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/scripts/exhaustive_process.py +0 -0
  93. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/scripts/preprocess_dataset.py +0 -0
  94. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/data_science/scripts/routine.py +0 -0
  95. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/decorators.py +0 -0
  96. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/dont_look/zip_file_override.py +0 -0
  97. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/image.py +0 -0
  98. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/installer/__init__.py +0 -0
  99. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/installer/common.py +0 -0
  100. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/installer/downloader.py +0 -0
  101. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/installer/linux.py +0 -0
  102. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/installer/main.py +0 -0
  103. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/installer/windows.py +0 -0
  104. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/io.py +0 -0
  105. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/parallel.py +0 -0
  106. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/print.py +0 -0
  107. {stouputils-1.3.6 → stouputils-1.3.8}/stouputils/py.typed +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: stouputils
-Version: 1.3.6
+Version: 1.3.8
 Summary: Stouputils is a collection of utility modules designed to simplify and enhance the development process. It includes a range of tools for tasks such as execution of doctests, display utilities, decorators, as well as context managers, and many more.
 Project-URL: Homepage, https://github.com/Stoupy51/stouputils
 Project-URL: Issues, https://github.com/Stoupy51/stouputils/issues
@@ -5,7 +5,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "stouputils"
-version = "1.3.6"
+version = "1.3.8"
 description = "Stouputils is a collection of utility modules designed to simplify and enhance the development process. It includes a range of tools for tasks such as execution of doctests, display utilities, decorators, as well as context managers, and many more."
 readme = "README.md"
 requires-python = ">=3.10"
@@ -8,12 +8,16 @@ This class contains the following metrics:
 1. Main metrics:
 
 - Area Under the Curve (AUC)
+- Area Under the Precision-Recall Curve (AUPRC)
+- Area Under the NPV-Specificity Curve (NEGATIVE_AUPRC)
 - Specificity (True Negative Rate)
 - Recall/Sensitivity (True Positive Rate)
 - Precision (Positive Predictive Value)
 - Negative Predictive Value (NPV)
 - Accuracy
 - F1 Score
+- Precision-Recall Average
+- Precision-Recall Average for Negative Class
 
 2. Confusion matrix metrics:
 
@@ -39,6 +43,8 @@ This class contains the following metrics:
 
 - Youden's J statistic
 - Cost-based threshold
+- F1 Score threshold
+- F1 Score threshold for the negative class
 
 6. Average metrics across folds:
 
@@ -52,7 +58,9 @@ This class contains the following metrics:
 class MetricDictionnary:
 
 	# Main metrics (starting with '1:')
-	AUC: str = "1: Area Under the Curve: AUC"
+	AUC: str = "1: Area Under the ROC Curve: AUC / AUROC"
+	AUPRC: str = "1: Area Under the Precision-Recall Curve: AUPRC / PR AUC"
+	NEGATIVE_AUPRC: str = "1: Area Under the NPV-Specificity Curve: AUNPRC / NPR AUC"
 	SPECIFICITY: str = "1: Specificity: True Negative Rate"
 	RECALL: str = "1: Recall/Sensitivity: True Positive Rate"
 	PRECISION: str = "1: Precision: Positive Predictive Value"
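
For context: the new NEGATIVE_AUPRC entry is the area under the curve that plots Negative Predictive Value against Specificity, which is simply the precision-recall curve computed with class 0 treated as the positive class. A minimal sketch with scikit-learn (illustrative only, not the package's code):

import numpy as np
from sklearn.metrics import auc, precision_recall_curve

y_true = np.array([0, 1, 0, 1, 0])
probs = np.array([[0.2, 0.8], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]])

# AUPRC: precision-recall over the probability of class 1
precision, recall, _ = precision_recall_curve(y_true, probs[:, 1], pos_label=1)
auprc = auc(recall, precision)

# NEGATIVE_AUPRC: NPV plays the role of precision and specificity the role of
# recall once class 0 is declared positive and scored by P(class 0)
npv, specificity, _ = precision_recall_curve(y_true, probs[:, 0], pos_label=0)
negative_auprc = auc(specificity, npv)
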
@@ -61,6 +69,8 @@ class MetricDictionnary:
 	BALANCED_ACCURACY: str = "1: Balanced Accuracy"
 	F1_SCORE: str = "1: F1 Score"
 	F1_SCORE_NEGATIVE: str = "1: F1 Score for Negative Class"
+	PR_AVERAGE: str = "1: Precision-Recall Average"
+	PR_AVERAGE_NEGATIVE: str = "1: Precision-Recall Average for Negative Class"
 
 	# Confusion matrix metrics (starting with '2:')
 	CONFUSION_MATRIX_TN: str = "2: Confusion Matrix: TN"
@@ -82,6 +92,8 @@ class MetricDictionnary:
 	# Optimal thresholds (starting with '5:')
 	OPTIMAL_THRESHOLD_YOUDEN: str = "5: Optimal Threshold: Youden"
 	OPTIMAL_THRESHOLD_COST: str = "5: Optimal Threshold: Cost"
+	OPTIMAL_THRESHOLD_F1: str = "5: Optimal Threshold: F1"
+	OPTIMAL_THRESHOLD_F1_NEGATIVE: str = "5: Optimal Threshold: F1 for Negative Class"
 
 	# Average metrics across folds (starting with '6:')
 	AVERAGE_METRIC: str = "6: Average METRIC_NAME across folds"
@@ -117,10 +117,10 @@ class MetricUtils:
 			if mcc_metric:
 				metrics.update(mcc_metric)
 
-			# Calculate and plot ROC/AUC
-			roc_metrics: dict[str, float] = MetricUtils.roc_and_auc(true_classes, y_pred, fold_number=-1, run_name=run_name)
-			if roc_metrics:
-				metrics.update(roc_metrics)
+			# Calculate and plot (ROC Curve / AUC) and (PR Curve / AUC, and negative one)
+			curves_metrics: dict[str, float] = MetricUtils.all_curves(true_classes, y_pred, fold_number=-1, run_name=run_name)
+			if curves_metrics:
+				metrics.update(curves_metrics)
 
 		# Multiclass classification
 		elif mode == "multiclass":
@@ -287,22 +287,24 @@ class MetricUtils:
 
 	@staticmethod
 	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
-	def roc_and_auc(
+	def roc_curve_and_auc(
 		true_classes: NDArray[np.intc] | NDArray[np.single],
 		pred_probs: NDArray[np.single],
 		fold_number: int = -1,
-		run_name: str = ""
+		run_name: str = "",
+		plot_if_minimum: int = 5
 	) -> dict[str, float]:
 		""" Calculate ROC curve and AUC score.
 
 		Args:
-			true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
-			pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
-			fold_number (int): Fold number, used for naming the plot file, usually
+			true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
+			pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
+			fold_number (int): Fold number, used for naming the plot file, usually
 				-1 for final model with test set,
 				0 for final model with validation set,
 				>0 for other folds with their validation set
-			run_name (str): Name for saving the plot
+			run_name (str): Name for saving the plot
+			plot_if_minimum (int): Minimum number of samples required in true_classes to plot the ROC curve
 		Returns:
 			dict[str, float]: Dictionary containing AUC score and optimal thresholds
 
@@ -311,7 +313,7 @@ class MetricUtils:
 		>>> pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
 		>>> from stouputils.ctx import Muffle
 		>>> with Muffle():
-		... 	metrics = MetricUtils.roc_and_auc(true_classes, pred_probs, run_name="")
+		... 	metrics = MetricUtils.roc_curve_and_auc(true_classes, pred_probs, run_name="")
 
 		>>> # Check metrics
 		>>> round(float(metrics[MetricDictionnary.AUC]), 2)
@@ -321,7 +323,6 @@ class MetricUtils:
 		>>> float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_COST])
 		inf
 		"""
-		true_classes = Utils.convert_to_class_indices(true_classes)
 		auc_value, fpr, tpr, thresholds = Utils.get_roc_curve_and_auc(true_classes, pred_probs)
 		metrics: dict[str, float] = {MetricDictionnary.AUC: auc_value}
 
@@ -339,8 +340,8 @@ class MetricUtils:
 		optimal_threshold_cost: float = thresholds[np.argmin(total_cost)]
 		metrics[MetricDictionnary.OPTIMAL_THRESHOLD_COST] = optimal_threshold_cost
 
-		# Plot ROC curve if not nan
-		if run_name and not np.isnan(auc_value):
+		# Plot ROC curve if run_name and minimum number of samples is reached
+		if run_name and len(true_classes) >= plot_if_minimum:
 			plt.figure(figsize=(12, 6))
 			plt.plot(fpr, tpr, "b", label=f"ROC curve (AUC = {auc_value:.2f})")
 			plt.plot([0, 1], [0, 1], "r--")
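
The hunk above only shows where the cost-based threshold is read off (thresholds[np.argmin(total_cost)]); the construction of total_cost itself is outside the diff. As a hedged sketch of the two ROC operating points named in MetricDictionnary (Youden and Cost), assuming a simple weighted-error cost model:

import numpy as np

def youden_threshold(fpr, tpr, thresholds):
    # Youden's J statistic: maximize tpr - fpr
    return thresholds[np.argmax(tpr - fpr)]

def cost_threshold(fpr, tpr, thresholds, n_neg, n_pos, fp_cost=1.0, fn_cost=1.0):
    # Assumed cost model (not taken from the package): weighted count of
    # false positives and false negatives at each threshold
    total_cost = fp_cost * fpr * n_neg + fn_cost * (1 - tpr) * n_pos
    return thresholds[np.argmin(total_cost)]
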
@@ -377,6 +378,160 @@ class MetricUtils:
 
 		return metrics
 
+	@staticmethod
+	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
+	def pr_curve_and_auc(
+		true_classes: NDArray[np.intc] | NDArray[np.single],
+		pred_probs: NDArray[np.single],
+		fold_number: int = -1,
+		run_name: str = "",
+		plot_if_minimum: int = 5
+	) -> dict[str, float]:
+		""" Calculate Precision-Recall curve and AUC score. (and NPV-Specificity curve and AUC)
+
+		Args:
+			true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
+			pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
+			fold_number (int): Fold number, used for naming the plot file, usually
+				-1 for final model with test set,
+				0 for final model with validation set,
+				>0 for other folds with their validation set
+			run_name (str): Name for saving the plot
+			plot_if_minimum (int): Minimum number of samples required in true_classes to plot the PR curves
+		Returns:
+			dict[str, float]: Dictionary containing AUC score and optimal thresholds
+
+		Examples:
+			>>> true_classes = np.array([0, 1, 0])
+			>>> pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
+			>>> from stouputils.ctx import Muffle
+			>>> with Muffle():
+			... 	metrics = MetricUtils.pr_curve_and_auc(true_classes, pred_probs, run_name="")
+
+			>>> # Check metrics
+			>>> round(float(metrics[MetricDictionnary.AUPRC]), 2)
+			0.75
+			>>> round(float(metrics[MetricDictionnary.NEGATIVE_AUPRC]), 2)
+			0.92
+			>>> round(float(metrics[MetricDictionnary.PR_AVERAGE]), 2)
+			0.5
+			>>> round(float(metrics[MetricDictionnary.PR_AVERAGE_NEGATIVE]), 2)
+			0.33
+			>>> round(float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1]), 2)
+			0.9
+			>>> round(float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1_NEGATIVE]), 2)
+			0.1
+		"""
+		auc_value, average_precision, precision, recall, thresholds = Utils.get_pr_curve_and_auc(true_classes, pred_probs)
+		neg_auc_value, average_precision_neg, npv, specificity, neg_thresholds = (
+			Utils.get_pr_curve_and_auc(true_classes, pred_probs, negative=True)
+		)
+
+		# Calculate metrics
+		metrics: dict[str, float] = {
+			MetricDictionnary.AUPRC: auc_value,
+			MetricDictionnary.NEGATIVE_AUPRC: neg_auc_value,
+			MetricDictionnary.PR_AVERAGE: average_precision,
+			MetricDictionnary.PR_AVERAGE_NEGATIVE: average_precision_neg
+		}
+
+		# Calculate optimal thresholds for both PR curves
+		for is_negative in (False, True):
+
+			# Get the right values based on positive/negative case
+			if not is_negative:
+				curr_precision = precision
+				curr_recall = recall
+				curr_thresholds = thresholds
+				curr_auc = auc_value
+				curr_ap = average_precision
+			else:
+				curr_precision = npv
+				curr_recall = specificity
+				curr_thresholds = neg_thresholds
+				curr_auc = neg_auc_value
+				curr_ap = average_precision_neg
+
+			# Calculate F-score for each threshold
+			fscore: NDArray[np.single] = (2 * curr_precision * curr_recall) / (curr_precision + curr_recall)
+			fscore = fscore[~np.isnan(fscore)]
+
+			# Get optimal threshold (maximum F-score)
+			if len(fscore) > 0:
+				optimal_idx: int = int(np.argmax(fscore))
+				optimal_threshold: float = curr_thresholds[optimal_idx]
+			else:
+				optimal_idx: int = 0
+				optimal_threshold = float('inf')
+
+			# Store in metrics dictionary
+			if not is_negative:
+				metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1] = optimal_threshold
+			else:
+				metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1_NEGATIVE] = optimal_threshold
+
+			# Plot PR curve if run_name and minimum number of samples is reached
+			if run_name and len(true_classes) >= plot_if_minimum:
+				label: str = "Precision - Recall" if not is_negative else "Negative Predictive Value - Specificity"
+				plt.figure(figsize=(12, 6))
+				plt.plot(curr_recall, curr_precision, "b", label=f"{label} curve (AUC = {curr_auc:.2f}, AP = {curr_ap:.2f})")
+
+				# Prepare the path
+				fold_name: str = ""
+				if fold_number > 0:
+					fold_name = f"_fold_{fold_number}_val"
+				elif fold_number == 0:
+					fold_name = "_val"
+				elif fold_number == -1:
+					fold_name = "_test"
+				elif fold_number == -2:
+					fold_name = "_train"
+				pr: str = "pr" if not is_negative else "negative_pr"
+				curve_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{run_name}_{pr}_curve{fold_name}.png"
+
+				plt.plot(
+					curr_recall[optimal_idx], curr_precision[optimal_idx], 'go', label=f"Optimal threshold (t={optimal_threshold:.2f})"
+				)
+
+				plt.xlim([-0.01, 1.01])
+				plt.ylim([-0.01, 1.01])
+				plt.xlabel("Recall" if not is_negative else "Specificity")
+				plt.ylabel("Precision" if not is_negative else "Negative Predictive Value")
+				plt.title(f"{label} Curve")
+				plt.legend(loc="lower right")
+				plt.savefig(curve_path)
+				mlflow.log_artifact(curve_path)
+				os.remove(curve_path)
+				plt.close()
+
+		return metrics
+
+	@staticmethod
+	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
+	def all_curves(
+		true_classes: NDArray[np.intc] | NDArray[np.single],
+		pred_probs: NDArray[np.single],
+		fold_number: int = -1,
+		run_name: str = ""
+	) -> dict[str, float]:
+		""" Run all X_curve_and_auc functions and return a dictionary of metrics.
+
+		Args:
+			true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
+			pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
+			fold_number (int): Fold number, used for naming the plot file, usually
+				-1 for final model with test set,
+				0 for final model with validation set,
+				>0 for other folds with their validation set
+			run_name (str): Name for saving the plot
+		Returns:
+			dict[str, float]: Dictionary containing AUC score and optimal thresholds for ROC and PR curves
+		"""
+		metrics: dict[str, float] = {}
+		metrics.update(MetricUtils.roc_curve_and_auc(true_classes, pred_probs, fold_number, run_name))
+		metrics.update(MetricUtils.pr_curve_and_auc(true_classes, pred_probs, fold_number, run_name))
+		return metrics
+
 
 	@staticmethod
 	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
@@ -609,7 +764,7 @@ class MetricUtils:
 
 			# Find all local minima
 			from scipy.signal import argrelextrema
-			local_minima_idx: NDArray[np.intp] = argrelextrema(y_array, np.less)[0]
+			local_minima_idx: NDArray[np.intp] = np.array(argrelextrema(y_array, np.less)[0], dtype=np.intp)
 			distinct_candidates = np.unique(np.append(local_minima_idx, best_idx))
 		else:
 			assert 0 <= best_idx < len(x_array), "Best x index is out of bounds"
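
Why the np.array(..., dtype=np.intp) wrapper: scipy's argrelextrema returns a tuple of index arrays (one per axis), and re-wrapping the first element pins the dtype so the NDArray[np.intp] annotation holds up under static type-checking. A tiny sketch:

import numpy as np
from scipy.signal import argrelextrema

y = np.array([3.0, 1.0, 2.0, 0.5, 2.5])
minima = np.array(argrelextrema(y, np.less)[0], dtype=np.intp)
print(minima)  # [1 3]: indices where y is strictly smaller than both neighbours
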
@@ -851,22 +851,13 @@ class ModelInterface(AbstractModel):
 		self.history.append(history)
 		self.evaluation_results.append(eval_results)
 
-		# Generate and save ROC curve for this fold
-		MetricUtils.roc_and_auc(
-			true_classes=true_classes,
-			pred_probs=predictions,
-			fold_number=fold_number,
-			run_name=self.run_name
-		)
+		# Generate and save ROC Curve and PR Curve for this fold
+		MetricUtils.all_curves(true_classes, predictions, fold_number, run_name=self.run_name)
 
-		# If final model, also log the ROC curve for the train set
+		# If final model, also log the ROC curve and PR curve for the train set
 		if fold_number == 0:
-			MetricUtils.roc_and_auc(
-				true_classes=training_true_classes,
-				pred_probs=training_predictions,
-				fold_number=-2,
-				run_name=self.run_name
-			)
+			fold_number = -2 # -2 is the train set
+			MetricUtils.all_curves(training_true_classes, training_predictions, fold_number, run_name=self.run_name)
 
 		# Log visualization artifacts if they were generated
 		if temp_dir is not None:
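
The fold_number convention used in these calls (and in the plot filenames built by pr_curve_and_auc above) maps to filename suffixes as follows; a sketch of the mapping, restated from the fold_name logic in the diff, for reference:

def fold_suffix(fold_number: int) -> str:
    """ Suffix convention: >0 cross-validation folds on their validation set,
    0 final model on validation, -1 final model on test, -2 final model on train. """
    if fold_number > 0:
        return f"_fold_{fold_number}_val"
    if fold_number == 0:
        return "_val"
    if fold_number == -1:
        return "_test"
    return "_train" if fold_number == -2 else ""
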
@@ -197,22 +197,89 @@ class Utils:
 		assert y_pred.ndim > 1 and y_pred.shape[1] > 1, "Predictions must be probability scores in one-hot format"
 		pred_probs: NDArray[np.single] = y_pred[:, 1] # Take probability of positive class only
 
-		# Convert true labels to class indices if they're one-hot encoded
-		true_classes: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)
-
 		# Calculate ROC curve and AUC score using probabilities
 		with Muffle(mute_stderr=True): # Suppress "UndefinedMetricWarning: No positive samples in y_true [...]"
 
 			# Import functions
 			try:
-				from sklearn.metrics import auc, roc_curve
+				from sklearn.metrics import roc_auc_score, roc_curve
 			except ImportError as e:
 				raise ImportError("scikit-learn is required for ROC curve calculation. Install with 'pip install scikit-learn'") from e
 
-			results: tuple[Any, Any, Any] = roc_curve(true_classes, pred_probs, drop_intermediate=False)
+			# Convert y_true to class indices for both functions
+			y_true_indices: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)
+
+			# Calculate AUC score directly using roc_auc_score
+			auc_value: float = float(roc_auc_score(y_true_indices, pred_probs))
+
+			# Calculate ROC curve points
+			results: tuple[Any, Any, Any] = roc_curve(y_true_indices, pred_probs, drop_intermediate=False)
 			fpr: NDArray[np.single] = results[0]
 			tpr: NDArray[np.single] = results[1]
 			thresholds: NDArray[np.single] = results[2]
-			auc_value: float = float(auc(fpr, tpr))
+
 		return auc_value, fpr, tpr, thresholds
 
+	@staticmethod
+	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
+	def get_pr_curve_and_auc(
+		y_true: NDArray[np.intc | np.single],
+		y_pred: NDArray[np.single],
+		negative: bool = False
+	) -> tuple[float, float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
+		""" Calculate Precision-Recall Curve (or Negative Precision-Recall Curve) and AUC score.
+
+		Args:
+			y_true (NDArray[intc | single]): True class labels (either one-hot encoded or class indices)
+			y_pred (NDArray[single]): Predicted probabilities (must be probability scores, not class indices)
+			negative (bool): Whether to calculate the negative Precision-Recall Curve
+		Returns:
+			tuple[float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
+				Tuple containing either:
+				- AUC score, Average Precision, Precision, Recall, and Thresholds
+				- AUC score, Average Precision, Negative Predictive Value, Specificity, and Thresholds for the negative class
+
+		Examples:
+			>>> # Binary classification example
+			>>> y_true = np.array([0.0, 1.0, 0.0, 1.0, 0.0])
+			>>> y_pred = np.array([[0.2, 0.8], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]])
+			>>> auc_value, average_precision, precision, recall, thresholds = Utils.get_pr_curve_and_auc(y_true, y_pred)
+			>>> round(auc_value, 2)
+			0.92
+			>>> round(average_precision, 2)
+			0.83
+			>>> [round(x, 2) for x in precision.tolist()]
+			[0.4, 0.5, 0.67, 1.0, 1.0]
+			>>> [round(x, 2) for x in recall.tolist()]
+			[1.0, 1.0, 1.0, 0.5, 0.0]
+			>>> [round(x, 2) for x in thresholds.tolist()]
+			[0.2, 0.3, 0.8, 0.9]
+		"""
+		# For predictions, assert they are probabilities (one-hot encoded)
+		assert y_pred.ndim > 1 and y_pred.shape[1] > 1, "Predictions must be probability scores in one-hot format"
+		pred_probs: NDArray[np.single] = y_pred[:, 1] if not negative else y_pred[:, 0]
+
+		# Calculate Precision-Recall Curve and AUC score using probabilities
+		with Muffle(mute_stderr=True): # Suppress "UndefinedMetricWarning: No positive samples in y_true [...]"
+
+			# Import functions
+			try:
+				from sklearn.metrics import auc, average_precision_score, precision_recall_curve
+			except ImportError as e:
+				raise ImportError("scikit-learn is required for PR Curve calculation. Install with 'pip install scikit-learn'") from e
+
+			# Convert y_true to class indices for both functions
+			y_true_indices: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)
+
+			results: tuple[Any, Any, Any] = precision_recall_curve(
+				y_true=y_true_indices,
+				probas_pred=pred_probs,
+				pos_label=1 if not negative else 0
+			)
+			precision: NDArray[np.single] = results[0]
+			recall: NDArray[np.single] = results[1]
+			thresholds: NDArray[np.single] = results[2]
+			auc_value: float = float(auc(recall, precision))
+			average_precision: float = float(average_precision_score(y_true_indices, pred_probs))
+		return auc_value, average_precision, precision, recall, thresholds
+
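
The switch from auc(fpr, tpr) to roc_auc_score above is a refactor rather than a behaviour change: for binary labels scored on the same probabilities, the trapezoidal area under the full ROC curve equals the score returned by roc_auc_score. A quick self-contained check (illustrative only):

import numpy as np
from sklearn.metrics import auc, roc_auc_score, roc_curve

y = np.array([0, 1, 0, 1, 1])
p = np.array([0.2, 0.9, 0.4, 0.7, 0.6])

fpr, tpr, _ = roc_curve(y, p, drop_intermediate=False)
assert np.isclose(auc(fpr, tpr), roc_auc_score(y, p))
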