stouputils 1.3.5.tar.gz → 1.3.7.tar.gz

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (107)
  1. {stouputils-1.3.5 → stouputils-1.3.7}/PKG-INFO +1 -1
  2. {stouputils-1.3.5 → stouputils-1.3.7}/pyproject.toml +1 -1
  3. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/metric_dictionnary.py +13 -1
  4. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/metric_utils.py +159 -7
  5. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/model_interface.py +5 -14
  6. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/utils.py +73 -6
  7. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/decorators.py +17 -0
  8. {stouputils-1.3.5 → stouputils-1.3.7}/.gitignore +0 -0
  9. {stouputils-1.3.5 → stouputils-1.3.7}/LICENSE +0 -0
  10. {stouputils-1.3.5 → stouputils-1.3.7}/README.md +0 -0
  11. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/__init__.py +0 -0
  12. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/all_doctests.py +0 -0
  13. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/applications/__init__.py +0 -0
  14. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/applications/automatic_docs.py +0 -0
  15. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/applications/upscaler/__init__.py +0 -0
  16. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/applications/upscaler/config.py +0 -0
  17. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/applications/upscaler/image.py +0 -0
  18. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/applications/upscaler/video.py +0 -0
  19. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/archive.py +0 -0
  20. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/backup.py +0 -0
  21. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/collections.py +0 -0
  22. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/continuous_delivery/__init__.py +0 -0
  23. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/continuous_delivery/cd_utils.py +0 -0
  24. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/continuous_delivery/github.py +0 -0
  25. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/continuous_delivery/pypi.py +0 -0
  26. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/continuous_delivery/pyproject.py +0 -0
  27. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/ctx.py +0 -0
  28. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/config/get.py +0 -0
  29. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/config/set.py +0 -0
  30. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/__init__.py +0 -0
  31. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/auto_contrast.py +0 -0
  32. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/axis_flip.py +0 -0
  33. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/bias_field_correction.py +0 -0
  34. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/binary_threshold.py +0 -0
  35. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/blur.py +0 -0
  36. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/brightness.py +0 -0
  37. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/canny.py +0 -0
  38. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/clahe.py +0 -0
  39. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/common.py +0 -0
  40. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/contrast.py +0 -0
  41. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/curvature_flow_filter.py +0 -0
  42. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/denoise.py +0 -0
  43. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/histogram_equalization.py +0 -0
  44. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/invert.py +0 -0
  45. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/laplacian.py +0 -0
  46. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/median_blur.py +0 -0
  47. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/noise.py +0 -0
  48. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/normalize.py +0 -0
  49. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/random_erase.py +0 -0
  50. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/resize.py +0 -0
  51. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/rotation.py +0 -0
  52. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/salt_pepper.py +0 -0
  53. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/sharpening.py +0 -0
  54. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/shearing.py +0 -0
  55. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/threshold.py +0 -0
  56. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/translation.py +0 -0
  57. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image/zoom.py +0 -0
  58. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image_augmentation.py +0 -0
  59. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/image_preprocess.py +0 -0
  60. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/prosthesis_detection.py +0 -0
  61. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/data_processing/technique.py +0 -0
  62. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/dataset/__init__.py +0 -0
  63. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/dataset/dataset.py +0 -0
  64. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/dataset/dataset_loader.py +0 -0
  65. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/dataset/grouping_strategy.py +0 -0
  66. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/dataset/image_loader.py +0 -0
  67. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/dataset/xy_tuple.py +0 -0
  68. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/mlflow_utils.py +0 -0
  69. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/abstract_model.py +0 -0
  70. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/all.py +0 -0
  71. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/base_keras.py +0 -0
  72. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras/all.py +0 -0
  73. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras/convnext.py +0 -0
  74. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras/densenet.py +0 -0
  75. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras/efficientnet.py +0 -0
  76. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras/mobilenet.py +0 -0
  77. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras/resnet.py +0 -0
  78. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras/squeezenet.py +0 -0
  79. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras/vgg.py +0 -0
  80. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras/xception.py +0 -0
  81. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/callbacks/__init__.py +0 -0
  82. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/callbacks/colored_progress_bar.py +0 -0
  83. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/callbacks/learning_rate_finder.py +0 -0
  84. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/callbacks/model_checkpoint_v2.py +0 -0
  85. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/callbacks/progressive_unfreezing.py +0 -0
  86. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/callbacks/warmup_scheduler.py +0 -0
  87. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/losses/__init__.py +0 -0
  88. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/losses/next_generation_loss.py +0 -0
  89. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/keras_utils/visualizations.py +0 -0
  90. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/sandbox.py +0 -0
  91. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/range_tuple.py +0 -0
  92. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/scripts/augment_dataset.py +0 -0
  93. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/scripts/exhaustive_process.py +0 -0
  94. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/scripts/preprocess_dataset.py +0 -0
  95. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/scripts/routine.py +0 -0
  96. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/dont_look/zip_file_override.py +0 -0
  97. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/image.py +0 -0
  98. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/installer/__init__.py +0 -0
  99. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/installer/common.py +0 -0
  100. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/installer/downloader.py +0 -0
  101. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/installer/linux.py +0 -0
  102. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/installer/main.py +0 -0
  103. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/installer/windows.py +0 -0
  104. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/io.py +0 -0
  105. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/parallel.py +0 -0
  106. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/print.py +0 -0
  107. {stouputils-1.3.5 → stouputils-1.3.7}/stouputils/py.typed +0 -0
{stouputils-1.3.5 → stouputils-1.3.7}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: stouputils
-Version: 1.3.5
+Version: 1.3.7
 Summary: Stouputils is a collection of utility modules designed to simplify and enhance the development process. It includes a range of tools for tasks such as execution of doctests, display utilities, decorators, as well as context managers, and many more.
 Project-URL: Homepage, https://github.com/Stoupy51/stouputils
 Project-URL: Issues, https://github.com/Stoupy51/stouputils/issues
{stouputils-1.3.5 → stouputils-1.3.7}/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "hatchling.build"

 [project]
 name = "stouputils"
-version = "1.3.5"
+version = "1.3.7"
 description = "Stouputils is a collection of utility modules designed to simplify and enhance the development process. It includes a range of tools for tasks such as execution of doctests, display utilities, decorators, as well as context managers, and many more."
 readme = "README.md"
 requires-python = ">=3.10"
{stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/metric_dictionnary.py
@@ -8,12 +8,16 @@ This class contains the following metrics:
 1. Main metrics:

 - Area Under the Curve (AUC)
+- Area Under the Precision-Recall Curve (AUPRC)
+- Area Under the NPV-Specificity Curve (NEGATIVE_AUPRC)
 - Specificity (True Negative Rate)
 - Recall/Sensitivity (True Positive Rate)
 - Precision (Positive Predictive Value)
 - Negative Predictive Value (NPV)
 - Accuracy
 - F1 Score
+- Precision-Recall Average
+- Precision-Recall Average for Negative Class

 2. Confusion matrix metrics:

@@ -39,6 +43,8 @@ This class contains the following metrics:

 - Youden's J statistic
 - Cost-based threshold
+- F1 Score threshold
+- F1 Score threshold for the negative class

 6. Average metrics across folds:

@@ -52,7 +58,9 @@ This class contains the following metrics:
 class MetricDictionnary:

 	# Main metrics (starting with '1:')
-	AUC: str = "1: Area Under the Curve: AUC"
+	AUC: str = "1: Area Under the ROC Curve: AUC / AUROC"
+	AUPRC: str = "1: Area Under the Precision-Recall Curve: AUPRC / PR AUC"
+	NEGATIVE_AUPRC: str = "1: Area Under the NPV-Specificity Curve: AUNPRC / NPR AUC"
 	SPECIFICITY: str = "1: Specificity: True Negative Rate"
 	RECALL: str = "1: Recall/Sensitivity: True Positive Rate"
 	PRECISION: str = "1: Precision: Positive Predictive Value"
@@ -61,6 +69,8 @@ class MetricDictionnary:
 	BALANCED_ACCURACY: str = "1: Balanced Accuracy"
 	F1_SCORE: str = "1: F1 Score"
 	F1_SCORE_NEGATIVE: str = "1: F1 Score for Negative Class"
+	PR_AVERAGE: str = "1: Precision-Recall Average"
+	PR_AVERAGE_NEGATIVE: str = "1: Precision-Recall Average for Negative Class"

 	# Confusion matrix metrics (starting with '2:')
 	CONFUSION_MATRIX_TN: str = "2: Confusion Matrix: TN"
@@ -82,6 +92,8 @@ class MetricDictionnary:
 	# Optimal thresholds (starting with '5:')
 	OPTIMAL_THRESHOLD_YOUDEN: str = "5: Optimal Threshold: Youden"
 	OPTIMAL_THRESHOLD_COST: str = "5: Optimal Threshold: Cost"
+	OPTIMAL_THRESHOLD_F1: str = "5: Optimal Threshold: F1"
+	OPTIMAL_THRESHOLD_F1_NEGATIVE: str = "5: Optimal Threshold: F1 for Negative Class"

 	# Average metrics across folds (starting with '6:')
 	AVERAGE_METRIC: str = "6: Average METRIC_NAME across folds"
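The MetricDictionnary constants are plain strings used as keys into the metric dictionaries the library logs. Note that the value of the AUC key itself changed in this release ("Area Under the Curve: AUC" became "Area Under the ROC Curve: AUC / AUROC"), so code that looks metrics up by the literal 1.3.5 string rather than through the constant will stop matching. A minimal sketch of the safe pattern (import path taken from the file list above; the metric values are made up):

    from stouputils.data_science.metric_dictionnary import MetricDictionnary

    # A metrics dict as produced by MetricUtils (values here are invented for illustration)
    metrics: dict[str, float] = {MetricDictionnary.AUC: 0.91, MetricDictionnary.AUPRC: 0.88}

    # Key through the constants, not through literal strings, so renames stay transparent
    print(metrics[MetricDictionnary.AUC])    # 0.91
    print(metrics[MetricDictionnary.AUPRC])  # 0.88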
{stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/metric_utils.py
@@ -117,10 +117,10 @@ class MetricUtils:
 			if mcc_metric:
 				metrics.update(mcc_metric)

-			# Calculate and plot ROC/AUC
-			roc_metrics: dict[str, float] = MetricUtils.roc_and_auc(true_classes, y_pred, fold_number=-1, run_name=run_name)
-			if roc_metrics:
-				metrics.update(roc_metrics)
+			# Calculate and plot (ROC Curve / AUC) and (PR Curve / AUC, and negative one)
+			curves_metrics: dict[str, float] = MetricUtils.all_curves(true_classes, y_pred, fold_number=-1, run_name=run_name)
+			if curves_metrics:
+				metrics.update(curves_metrics)

 		# Multiclass classification
 		elif mode == "multiclass":
@@ -287,7 +287,7 @@ class MetricUtils:

 	@staticmethod
 	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
-	def roc_and_auc(
+	def roc_curve_and_auc(
 		true_classes: NDArray[np.intc] | NDArray[np.single],
 		pred_probs: NDArray[np.single],
 		fold_number: int = -1,
@@ -311,7 +311,7 @@ class MetricUtils:
 		>>> pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
 		>>> from stouputils.ctx import Muffle
 		>>> with Muffle():
-		... 	metrics = MetricUtils.roc_and_auc(true_classes, pred_probs, run_name="")
+		... 	metrics = MetricUtils.roc_curve_and_auc(true_classes, pred_probs, run_name="")

 		>>> # Check metrics
 		>>> round(float(metrics[MetricDictionnary.AUC]), 2)
@@ -377,6 +377,158 @@ class MetricUtils:

 		return metrics

+	@staticmethod
+	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
+	def pr_curve_and_auc(
+		true_classes: NDArray[np.intc] | NDArray[np.single],
+		pred_probs: NDArray[np.single],
+		fold_number: int = -1,
+		run_name: str = ""
+	) -> dict[str, float]:
+		""" Calculate Precision-Recall curve and AUC score. (and NPV-Specificity curve and AUC)
+
+		Args:
+			true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
+			pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
+			fold_number (int): Fold number, used for naming the plot file, usually
+				-1 for final model with test set,
+				0 for final model with validation set,
+				>0 for other folds with their validation set
+			run_name (str): Name for saving the plot
+		Returns:
+			dict[str, float]: Dictionary containing AUC score and optimal thresholds
+
+		Examples:
+			>>> true_classes = np.array([0, 1, 0])
+			>>> pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])
+			>>> from stouputils.ctx import Muffle
+			>>> with Muffle():
+			... 	metrics = MetricUtils.pr_curve_and_auc(true_classes, pred_probs, run_name="")

+			>>> # Check metrics
+			>>> round(float(metrics[MetricDictionnary.AUPRC]), 2)
+			0.75
+			>>> round(float(metrics[MetricDictionnary.NEGATIVE_AUPRC]), 2)
+			0.92
+			>>> round(float(metrics[MetricDictionnary.PR_AVERAGE]), 2)
+			0.5
+			>>> round(float(metrics[MetricDictionnary.PR_AVERAGE_NEGATIVE]), 2)
+			0.33
+			>>> round(float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1]), 2)
+			0.9
+			>>> round(float(metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1_NEGATIVE]), 2)
+			0.1
+		"""
+		auc_value, average_precision, precision, recall, thresholds = Utils.get_pr_curve_and_auc(true_classes, pred_probs)
+		neg_auc_value, average_precision_neg, npv, specificity, neg_thresholds = (
+			Utils.get_pr_curve_and_auc(true_classes, pred_probs, negative=True)
+		)

+		# Calculate metrics
+		metrics: dict[str, float] = {
+			MetricDictionnary.AUPRC: auc_value,
+			MetricDictionnary.NEGATIVE_AUPRC: neg_auc_value,
+			MetricDictionnary.PR_AVERAGE: average_precision,
+			MetricDictionnary.PR_AVERAGE_NEGATIVE: average_precision_neg
+		}

+		# Calculate optimal thresholds for both PR curves
+		for is_negative in (False, True):

+			# Get the right values based on positive/negative case
+			if not is_negative:
+				curr_precision = precision
+				curr_recall = recall
+				curr_thresholds = thresholds
+				curr_auc = auc_value
+				curr_ap = average_precision
+			else:
+				curr_precision = npv
+				curr_recall = specificity
+				curr_thresholds = neg_thresholds
+				curr_auc = neg_auc_value
+				curr_ap = average_precision_neg

+			# Calculate F-score for each threshold
+			fscore: NDArray[np.single] = (2 * curr_precision * curr_recall) / (curr_precision + curr_recall)
+			fscore = fscore[~np.isnan(fscore)]

+			# Get optimal threshold (maximum F-score)
+			if len(fscore) > 0:
+				optimal_idx: int = int(np.argmax(fscore))
+				optimal_threshold: float = curr_thresholds[optimal_idx]
+			else:
+				optimal_idx: int = 0
+				optimal_threshold = float('inf')

+			# Store in metrics dictionary
+			if not is_negative:
+				metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1] = optimal_threshold
+			else:
+				metrics[MetricDictionnary.OPTIMAL_THRESHOLD_F1_NEGATIVE] = optimal_threshold

+			# Plot ROC curve if not nan
+			if run_name:
+				label: str = "Precision - Recall" if not is_negative else "Negative Predictive Value - Specificity"
+				plt.figure(figsize=(12, 6))
+				plt.plot(curr_recall, curr_precision, "b", label=f"{label} curve (AUC = {curr_auc:.2f}, AP = {curr_ap:.2f})")

+				# Prepare the path
+				fold_name: str = ""
+				if fold_number > 0:
+					fold_name = f"_fold_{fold_number}_val"
+				elif fold_number == 0:
+					fold_name = "_val"
+				elif fold_number == -1:
+					fold_name = "_test"
+				elif fold_number == -2:
+					fold_name = "_train"
+				pr: str = "pr" if not is_negative else "negative_pr"
+				curve_path: str = f"{DataScienceConfig.TEMP_FOLDER}/{run_name}_{pr}_curve{fold_name}.png"

+				plt.plot(
+					curr_recall[optimal_idx], curr_precision[optimal_idx], 'go', label=f"Optimal threshold (t={optimal_threshold:.2f})"
+				)

+				plt.xlim([-0.01, 1.01])
+				plt.ylim([-0.01, 1.01])
+				plt.xlabel("Recall" if not is_negative else "Specificity")
+				plt.ylabel("Precision" if not is_negative else "Negative Predictive Value")
+				plt.title(f"{label} Curve")
+				plt.legend(loc="lower right")
+				plt.savefig(curve_path)
+				mlflow.log_artifact(curve_path)
+				os.remove(curve_path)
+				plt.close()

+		return metrics

+	@staticmethod
+	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
+	def all_curves(
+		true_classes: NDArray[np.intc] | NDArray[np.single],
+		pred_probs: NDArray[np.single],
+		fold_number: int = -1,
+		run_name: str = ""
+	) -> dict[str, float]:
+		""" Run all X_curve_and_auc functions and return a dictionary of metrics.

+		Args:
+			true_classes (NDArray[np.intc | np.single]): True class labels (one-hot encoded or class indices)
+			pred_probs (NDArray[np.single]): Predicted probabilities (must be probability scores, not class indices)
+			fold_number (int): Fold number, used for naming the plot file, usually
+				-1 for final model with test set,
+				0 for final model with validation set,
+				>0 for other folds with their validation set
+			run_name (str): Name for saving the plot
+		Returns:
+			dict[str, float]: Dictionary containing AUC score and optimal thresholds for ROC and PR curves
+		"""
+		metrics: dict[str, float] = {}
+		metrics.update(MetricUtils.roc_curve_and_auc(true_classes, pred_probs, fold_number, run_name))
+		metrics.update(MetricUtils.pr_curve_and_auc(true_classes, pred_probs, fold_number, run_name))
+		return metrics
+

 	@staticmethod
 	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
@@ -609,7 +761,7 @@ class MetricUtils:

 			# Find all local minima
 			from scipy.signal import argrelextrema
-			local_minima_idx: NDArray[np.intp] = argrelextrema(y_array, np.less)[0]
+			local_minima_idx: NDArray[np.intp] = np.array(argrelextrema(y_array, np.less)[0], dtype=np.intp)
 			distinct_candidates = np.unique(np.append(local_minima_idx, best_idx))
 		else:
 			assert 0 <= best_idx < len(x_array), "Best x index is out of bounds"
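Taken together, pr_curve_and_auc mirrors the existing ROC helper (compute the curve, optionally plot and log it to MLflow, return a metric dict), and all_curves becomes the single entry point that runs both. A minimal usage sketch built from the doctests above (module paths inferred from the file list; an empty run_name skips the plotting branch):

    import numpy as np
    from stouputils.ctx import Muffle
    from stouputils.data_science.metric_dictionnary import MetricDictionnary
    from stouputils.data_science.metric_utils import MetricUtils

    true_classes = np.array([0, 1, 0])                           # class indices
    pred_probs = np.array([[0.9, 0.1], [0.1, 0.9], [0.1, 0.9]])  # one-hot probabilities

    with Muffle():  # silence sklearn warnings, as the doctests do
        metrics = MetricUtils.all_curves(true_classes, pred_probs, fold_number=-1, run_name="")

    print(round(float(metrics[MetricDictionnary.AUC]), 2))    # ROC AUC
    print(round(float(metrics[MetricDictionnary.AUPRC]), 2))  # PR AUC (0.75 for this data, per the doctest)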
{stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/models/model_interface.py
@@ -851,22 +851,13 @@ class ModelInterface(AbstractModel):
 		self.history.append(history)
 		self.evaluation_results.append(eval_results)

-		# Generate and save ROC curve for this fold
-		MetricUtils.roc_and_auc(
-			true_classes=true_classes,
-			pred_probs=predictions,
-			fold_number=fold_number,
-			run_name=self.run_name
-		)
+		# Generate and save ROC Curve and PR Curve for this fold
+		MetricUtils.all_curves(true_classes, predictions, fold_number, run_name=self.run_name)

-		# If final model, also log the ROC curve for the train set
+		# If final model, also log the ROC curve and PR curve for the train set
 		if fold_number == 0:
-			MetricUtils.roc_and_auc(
-				true_classes=training_true_classes,
-				pred_probs=training_predictions,
-				fold_number=-2,
-				run_name=self.run_name
-			)
+			fold_number = -2 # -2 is the train set
+			MetricUtils.all_curves(training_true_classes, training_predictions, fold_number, run_name=self.run_name)

 		# Log visualization artifacts if they were generated
 		if temp_dir is not None:
{stouputils-1.3.5 → stouputils-1.3.7}/stouputils/data_science/utils.py
@@ -197,22 +197,89 @@ class Utils:
 		assert y_pred.ndim > 1 and y_pred.shape[1] > 1, "Predictions must be probability scores in one-hot format"
 		pred_probs: NDArray[np.single] = y_pred[:, 1] # Take probability of positive class only

-		# Convert true labels to class indices if they're one-hot encoded
-		true_classes: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)
-
 		# Calculate ROC curve and AUC score using probabilities
 		with Muffle(mute_stderr=True): # Suppress "UndefinedMetricWarning: No positive samples in y_true [...]"

 			# Import functions
 			try:
-				from sklearn.metrics import auc, roc_curve
+				from sklearn.metrics import roc_auc_score, roc_curve
 			except ImportError as e:
 				raise ImportError("scikit-learn is required for ROC curve calculation. Install with 'pip install scikit-learn'") from e

-			results: tuple[Any, Any, Any] = roc_curve(true_classes, pred_probs, drop_intermediate=False)
+			# Convert y_true to class indices for both functions
+			y_true_indices: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)
+
+			# Calculate AUC score directly using roc_auc_score
+			auc_value: float = float(roc_auc_score(y_true_indices, pred_probs))
+
+			# Calculate ROC curve points
+			results: tuple[Any, Any, Any] = roc_curve(y_true_indices, pred_probs, drop_intermediate=False)
 			fpr: NDArray[np.single] = results[0]
 			tpr: NDArray[np.single] = results[1]
 			thresholds: NDArray[np.single] = results[2]
-			auc_value: float = float(auc(fpr, tpr))
+
 		return auc_value, fpr, tpr, thresholds

+	@staticmethod
+	@handle_error(error_log=DataScienceConfig.ERROR_LOG)
+	def get_pr_curve_and_auc(
+		y_true: NDArray[np.intc | np.single],
+		y_pred: NDArray[np.single],
+		negative: bool = False
+	) -> tuple[float, float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
+		""" Calculate Precision-Recall Curve (or Negative Precision-Recall Curve) and AUC score.

+		Args:
+			y_true (NDArray[intc | single]): True class labels (either one-hot encoded or class indices)
+			y_pred (NDArray[single]): Predicted probabilities (must be probability scores, not class indices)
+			negative (bool): Whether to calculate the negative Precision-Recall Curve
+		Returns:
+			tuple[float, NDArray[np.single], NDArray[np.single], NDArray[np.single]]:
+				Tuple containing either:
+				- AUC score, Average Precision, Precision, Recall, and Thresholds
+				- AUC score, Average Precision, Negative Predictive Value, Specificity, and Thresholds for the negative class

+		Examples:
+			>>> # Binary classification example
+			>>> y_true = np.array([0.0, 1.0, 0.0, 1.0, 0.0])
+			>>> y_pred = np.array([[0.2, 0.8], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]])
+			>>> auc_value, average_precision, precision, recall, thresholds = Utils.get_pr_curve_and_auc(y_true, y_pred)
+			>>> round(auc_value, 2)
+			0.92
+			>>> round(average_precision, 2)
+			0.83
+			>>> [round(x, 2) for x in precision.tolist()]
+			[0.4, 0.5, 0.67, 1.0, 1.0]
+			>>> [round(x, 2) for x in recall.tolist()]
+			[1.0, 1.0, 1.0, 0.5, 0.0]
+			>>> [round(x, 2) for x in thresholds.tolist()]
+			[0.2, 0.3, 0.8, 0.9]
+		"""
+		# For predictions, assert they are probabilities (one-hot encoded)
+		assert y_pred.ndim > 1 and y_pred.shape[1] > 1, "Predictions must be probability scores in one-hot format"
+		pred_probs: NDArray[np.single] = y_pred[:, 1] if not negative else y_pred[:, 0]

+		# Calculate Precision-Recall Curve and AUC score using probabilities
+		with Muffle(mute_stderr=True): # Suppress "UndefinedMetricWarning: No positive samples in y_true [...]"

+			# Import functions
+			try:
+				from sklearn.metrics import auc, average_precision_score, precision_recall_curve
+			except ImportError as e:
+				raise ImportError("scikit-learn is required for PR Curve calculation. Install with 'pip install scikit-learn'") from e

+			# Convert y_true to class indices for both functions
+			y_true_indices: NDArray[np.intc] = Utils.convert_to_class_indices(y_true)

+			results: tuple[Any, Any, Any] = precision_recall_curve(
+				y_true=y_true_indices,
+				probas_pred=pred_probs,
+				pos_label=1 if not negative else 0
+			)
+			precision: NDArray[np.single] = results[0]
+			recall: NDArray[np.single] = results[1]
+			thresholds: NDArray[np.single] = results[2]
+			auc_value: float = float(auc(recall, precision))
+			average_precision: float = float(average_precision_score(y_true_indices, pred_probs))
+		return auc_value, average_precision, precision, recall, thresholds
+
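The negative=True branch is what backs the NPV-Specificity curve: scoring class 0 with its own probability column makes precision play the role of negative predictive value and recall the role of specificity. A standalone sketch of that idea using scikit-learn directly (illustrative only, not the package's code):

    import numpy as np
    from sklearn.metrics import auc, precision_recall_curve

    y_true = np.array([0, 1, 0, 1, 0])
    y_pred = np.array([[0.2, 0.8], [0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]])

    # Negative PR curve: treat class 0 as the positive label and score it with column 0
    npv, specificity, thresholds = precision_recall_curve(y_true, y_pred[:, 0], pos_label=0)
    neg_auc = float(auc(specificity, npv))  # area under the NPV-Specificity curve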
{stouputils-1.3.5 → stouputils-1.3.7}/stouputils/decorators.py
@@ -33,6 +33,17 @@ def get_func_name(func: Callable[..., Any]) -> str:
 	except AttributeError:
 		return "<unknown>"

+def get_wrapper_name(decorator_name: str, func: Callable[..., Any]) -> str:
+	""" Get a descriptive name for a wrapper function.
+
+	Args:
+		decorator_name (str): Name of the decorator
+		func (Callable[..., Any]): Function being decorated
+	Returns:
+		str: Combined name for the wrapper function
+	"""
+	return f"{decorator_name}_{get_func_name(func)}"
+
 # Decorator that make a function silent (disable stdout)
 def silent(
 	func: Callable[..., Any] | None = None,
@@ -66,6 +77,7 @@ def silent(
 		# Use Muffle context manager to silence output
 		with Muffle(mute_stderr=mute_stderr):
 			return func(*args, **kwargs)
+	wrapper.__name__ = get_wrapper_name("stouputils.decorators.silent", func)
 	return wrapper

 	# Handle both @silent and @silent(mute_stderr=...)
@@ -108,6 +120,7 @@ def measure_time(
 		def wrapper(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any:
 			with MeasureTime(print_func=print_func, message=new_msg, perf_counter=perf_counter):
 				return func(*args, **kwargs)
+		wrapper.__name__ = get_wrapper_name("stouputils.decorators.measure_time", func)
 		return wrapper
 	return decorator

@@ -189,6 +202,7 @@ def handle_error(
 			# Sleep for the specified time, only if the error_log is not ERROR_TRACEBACK (because it's blocking)
 			if sleep_time > 0.0 and error_log != LogLevels.ERROR_TRACEBACK:
 				time.sleep(sleep_time)
+		wrapper.__name__ = get_wrapper_name("stouputils.decorators.handle_error", func)
 		return wrapper

 	# Handle both @handle_error and @handle_error(exceptions=..., message=..., error_log=...)
@@ -255,6 +269,7 @@ def simple_cache(
 			return result

 		# Return the wrapper
+		wrapper.__name__ = get_wrapper_name("stouputils.decorators.simple_cache", func)
 		return wrapper

 	# Handle both @simple_cache and @simple_cache(method=...)
@@ -312,6 +327,7 @@ def deprecated(

 			# Call the original function
 			return func(*args, **kwargs)
+		wrapper.__name__ = get_wrapper_name("stouputils.decorators.deprecated", func)
 		return wrapper

 	# Handle both @deprecated and @deprecated(message=..., error_log=...)
@@ -362,6 +378,7 @@ def abstract(
 	@handle_error(exceptions=NotImplementedError, error_log=error_log)
 	def wrapper(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any:
 		raise NotImplementedError(message)
+	wrapper.__name__ = get_wrapper_name("stouputils.decorators.abstract", func)
 	return wrapper

 	# Handle both @abstract and @abstract(error_log=...)
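The net effect of the decorators.py changes is that every wrapper returned by silent, measure_time, handle_error, simple_cache, deprecated, and abstract now carries a descriptive __name__ instead of the generic "wrapper", which makes logs and tracebacks easier to attribute. A small sketch of the observable behavior (the exact string assumes get_func_name returns the function's plain name):

    from stouputils.decorators import silent

    @silent
    def greet() -> None:
        print("this is muted")

    greet()                # prints nothing: stdout is muffled by the decorator
    print(greet.__name__)  # e.g. "stouputils.decorators.silent_greet" as of 1.3.7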