dragon-ml-toolbox 14.3.0__py3-none-any.whl → 14.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of dragon-ml-toolbox has been flagged as potentially problematic.

dragon_ml_toolbox-14.8.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dragon-ml-toolbox
- Version: 14.3.0
+ Version: 14.8.0
  Summary: A collection of tools for data science and machine learning projects.
  Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
  License-Expression: MIT
@@ -141,6 +141,7 @@ ETL_cleaning
  ETL_engineering
  math_utilities
  ML_callbacks
+ ML_configuration
  ML_datasetmaster
  ML_evaluation_multi
  ML_evaluation
dragon_ml_toolbox-14.8.0.dist-info/RECORD CHANGED
@@ -1,25 +1,26 @@
- dragon_ml_toolbox-14.3.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
- dragon_ml_toolbox-14.3.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=gkOdNDbKYpIJezwSo2CEnISkLeYfYHv9t8b5K2-P69A,2687
+ dragon_ml_toolbox-14.8.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+ dragon_ml_toolbox-14.8.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=gkOdNDbKYpIJezwSo2CEnISkLeYfYHv9t8b5K2-P69A,2687
  ml_tools/ETL_cleaning.py,sha256=2VBRllV8F-ZiPylPp8Az2gwn5ztgazN0BH5OKnRUhV0,20402
  ml_tools/ETL_engineering.py,sha256=KfYqgsxupAx6e_TxwO1LZXeu5mFkIhVXJrNjP3CzIZc,54927
  ml_tools/GUI_tools.py,sha256=Va6ig-dHULPVRwQYYtH3fvY5XPIoqRcJpRW8oXC55Hw,45413
  ml_tools/MICE_imputation.py,sha256=KLJXGQLKJ6AuWWttAG-LCCaxpS-ygM4dXPiguHDaL6Y,20815
  ml_tools/ML_callbacks.py,sha256=elD2Yr030sv_6gX_m9GVd6HTyrbmt34nFS8lrgS4HtM,15808
- ml_tools/ML_datasetmaster.py,sha256=rsJgZEGBJmfeKF6cR8CQZzfEx4T7Y-p1wUnR15_nNw0,28400
- ml_tools/ML_evaluation.py,sha256=4GU86rUWMIGbkXrvN6PyjfGwKtWvXKE7pMlWpWeBq14,18988
- ml_tools/ML_evaluation_multi.py,sha256=rJKdgtq-9I7oaI7PRzq7aIZ84XdNV0xzlVePZW4nj0k,16095
+ ml_tools/ML_configuration.py,sha256=tXkm2q57bl2kK0Iqpx1G7s1pEURBL_UMmqD8mlsGPs4,4689
+ ml_tools/ML_datasetmaster.py,sha256=Zi5jBnBI_U6tD8mpCVL5bQcsqsGEMAzMsCVI_wFD2QU,30175
+ ml_tools/ML_evaluation.py,sha256=EvlgFeMQeZ1RSEMtNd-nv7W0d0SVcR4n6cwW5UG16DU,25358
+ ml_tools/ML_evaluation_multi.py,sha256=bQZ2gJY-dBzKQxvtd-B6wVaGBdFpQGVBr7tQZFokp5E,17166
  ml_tools/ML_inference.py,sha256=YJ953bhNWsdlPRtJQh3h2ACfMIgp8dQ9KtL9Azar-5s,23489
  ml_tools/ML_models.py,sha256=PqOcNlws7vCJMbiVCKqlPuktxvskZVUHG3VfU-Yshf8,31415
  ml_tools/ML_models_advanced.py,sha256=vk3PZBSu3DVso2S1rKTxxdS43XG8Q5FnasIL3-rMajc,12410
  ml_tools/ML_optimization.py,sha256=P0zkhKAwTpkorIBtR0AOIDcyexo5ngmvFUzo3DfNO-E,22692
  ml_tools/ML_scaler.py,sha256=tw6onj9o8_kk3FQYb930HUzvv1zsFZe2YZJdF3LtHkU,7538
- ml_tools/ML_trainer.py,sha256=ZWI4MbUcLeBxyfoUTL96l5tjHHMp9I64h4SdXnjYmBE,49795
- ml_tools/ML_utilities.py,sha256=z6LbpbZwhn8F__fWlKi-g-cAJQXSxwg1NHfC5FBoAyc,21139
- ml_tools/ML_vision_datasetmaster.py,sha256=feFNUBjybzVJJrdyqToQ_mLV1uDJXHkNL0tmn_zofSY,56034
+ ml_tools/ML_trainer.py,sha256=salZxfv3RWRCiinp5S9xeUsHysMbMQ52EecR8GyEbaM,51461
+ ml_tools/ML_utilities.py,sha256=eYe2N-65FTzaOHF5gmiJl-HmicyzhqcdvlDiIivr5_g,22993
+ ml_tools/ML_vision_datasetmaster.py,sha256=VHZo0gzgrXrfGcHA34WKD3gGfhlxMrOXbNdYhXb6p6M,64462
  ml_tools/ML_vision_evaluation.py,sha256=t12R7i1RkOCt9zu1_lxSBr8OH6A6Get0k8ftDLctn6I,10486
  ml_tools/ML_vision_inference.py,sha256=He3KV3VJAm8PwO-fOq4b9VO8UXFr-GmpuCnoHXf4VZI,20588
- ml_tools/ML_vision_models.py,sha256=G3S4jB9AE9wMpU9ZygOgOx9q1K6t6LAXBYcJ-U2XQ1M,25600
- ml_tools/ML_vision_transformers.py,sha256=95e0aBkHY5VDGE8i5xy57COU7NvSNIgFknnhBubwE40,1832
+ ml_tools/ML_vision_models.py,sha256=WqiRN9JAjv--BcwkDrooXAs4Qo26JHPCHh3JSPm4kMI,26226
+ ml_tools/ML_vision_transformers.py,sha256=h332O9BjDMgxrBc0I-bJwJODWlcp7nJHbX1QS2etwBk,7738
  ml_tools/PSO_optimization.py,sha256=T-HWHMRJUnPvPwixdU5jif3_rnnI36TzcL8u3oSCwuA,22960
  ml_tools/RNN_forecast.py,sha256=Qa2KoZfdAvSjZ4yE78N4BFXtr3tTr0Gx7tQJZPotsh0,1967
  ml_tools/SQL.py,sha256=vXLPGfVVg8bfkbBE3HVfyEclVbdJy0TBhuQONtMwSCQ,11234
@@ -32,17 +33,17 @@ ml_tools/_script_info.py,sha256=21r83LV3RubsNZ_RTEUON6RbDf7Mh4_udweNcvdF_Fk,212
  ml_tools/constants.py,sha256=3br5Rk9cL2IUo638eJuMOGdbGQaWssaUecYEvSeRBLM,3322
  ml_tools/custom_logger.py,sha256=TGc0Ww2Xlqj2XE3q4bP43hV7T3qnb5ci9f0pYHXF5TY,11226
  ml_tools/data_exploration.py,sha256=bwHzFJ-IAo5GN3T53F-1J_pXUg8VHS91sG_90utAsfg,69911
- ml_tools/ensemble_evaluation.py,sha256=FGHSe8LBI8_w8LjNeJWOcYQ1UK_mc6fVah8gmSvNVGg,26853
+ ml_tools/ensemble_evaluation.py,sha256=2sJ3jD6yBNPRNwSokyaLKqKHi0QhF13ChoFe5yd4zwg,28368
  ml_tools/ensemble_inference.py,sha256=0yLmLNj45RVVoSCLH1ZYJG9IoAhTkWUqEZmLOQTFGTY,9348
  ml_tools/ensemble_learning.py,sha256=vsIED7nlheYI4w2SBzP6SC1AnNeMfn-2A1Gqw5EfxsM,21964
  ml_tools/handle_excel.py,sha256=pfdAPb9ywegFkM9T54bRssDOsX-K7rSeV0RaMz7lEAo,14006
- ml_tools/keys.py,sha256=wZOBuEnnHc54vlOZiimnrxfk-sZh6f6suPppJW8rbPQ,3326
+ ml_tools/keys.py,sha256=-OiL9G0RIOKQk6BwETKIP3LWz2s5-x6lZW2YitJa4mY,3330
  ml_tools/math_utilities.py,sha256=xeKq1quR_3DYLgowcp4Uam_4s3JltUyOnqMOGuAiYWU,8802
  ml_tools/optimization_tools.py,sha256=TYFQ2nSnp7xxs-VyoZISWgnGJghFbsWasHjruegyJRs,12763
  ml_tools/path_manager.py,sha256=CyDU16pOKmC82jPubqJPT6EBt-u-3rGVbxyPIZCvDDY,18432
  ml_tools/serde.py,sha256=c8uDYjYry_VrLvoG4ixqDj5pij88lVn6Tu4NHcPkwDU,6943
  ml_tools/utilities.py,sha256=aWqvYzmxlD74PD5Yqu1VuTekDJeYLQrmPIU_VeVyRp0,22526
- dragon_ml_toolbox-14.3.0.dist-info/METADATA,sha256=TeVrfmCt4AVSweSN4Ai0yyZCJMQtSD1MHsUoEQHXLg4,6475
- dragon_ml_toolbox-14.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- dragon_ml_toolbox-14.3.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
- dragon_ml_toolbox-14.3.0.dist-info/RECORD,,
+ dragon_ml_toolbox-14.8.0.dist-info/METADATA,sha256=9OndkhzBGS_XzlCPuHH88wIgndT2jhWN4fydXTGJg-8,6492
+ dragon_ml_toolbox-14.8.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ dragon_ml_toolbox-14.8.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+ dragon_ml_toolbox-14.8.0.dist-info/RECORD,,
ml_tools/ML_configuration.py ADDED
@@ -0,0 +1,116 @@
+ from typing import Optional
+ from ._script_info import _script_info
+
+
+ __all__ = [
+     "ClassificationMetricsFormat",
+     "MultiClassificationMetricsFormat"
+ ]
+
+
+ class ClassificationMetricsFormat:
+     """
+     Optional configuration for classification tasks, use in the '.evaluate()' method of the MLTrainer.
+     """
+     def __init__(self,
+                  cmap: str="Blues",
+                  class_map: Optional[dict[str,int]]=None,
+                  ROC_PR_line: str='darkorange',
+                  calibration_bins: int=15,
+                  font_size: int=16) -> None:
+         """
+         Initializes the formatting configuration for single-label classification metrics.
+
+         Args:
+             cmap (str): The matplotlib colormap name for the confusion matrix
+                 and report heatmap. Defaults to "Blues".
+                 - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
+                 - Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
+
+             class_map (dict[str,int] | None): A dictionary mapping
+                 class string names to their integer indices (e.g., {'cat': 0, 'dog': 1}).
+                 This is used to label the axes of the confusion matrix and classification
+                 report correctly. Defaults to None.
+
+             ROC_PR_line (str): The color name or hex code for the line plotted
+                 on the ROC and Precision-Recall curves. Defaults to 'darkorange'.
+                 - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
+                 - Hex codes: '#FF6347', '#4682B4'
+
+             calibration_bins (int): The number of bins to use when
+                 creating the calibration (reliability) plot. Defaults to 15.
+
+             font_size (int): The base font size to apply to the plots. Defaults to 16.
+
+         <br>
+
+         ## [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
+         """
+         self.cmap = cmap
+         self.class_map = class_map
+         self.ROC_PR_line = ROC_PR_line
+         self.calibration_bins = calibration_bins
+         self.font_size = font_size
+
+     def __repr__(self) -> str:
+         parts = [
+             f"cmap='{self.cmap}'",
+             f"class_map={self.class_map}",
+             f"ROC_PR_line='{self.ROC_PR_line}'",
+             f"calibration_bins={self.calibration_bins}",
+             f"font_size={self.font_size}"
+         ]
+         return f"ClassificationMetricsFormat({', '.join(parts)})"
+
+
+ class MultiClassificationMetricsFormat:
+     """
+     Optional configuration for multi-label classification tasks, use in the '.evaluate()' method of the MLTrainer.
+     """
+     def __init__(self,
+                  threshold: float=0.5,
+                  ROC_PR_line: str='darkorange',
+                  cmap: str = "Blues",
+                  font_size: int = 16) -> None:
+         """
+         Initializes the formatting configuration for multi-label classification metrics.
+
+         Args:
+             threshold (float): The probability threshold (0.0 to 1.0) used
+                 to convert sigmoid outputs into binary (0 or 1) predictions for
+                 calculating the confusion matrix and overall metrics. Defaults to 0.5.
+
+             ROC_PR_line (str): The color name or hex code for the line plotted
+                 on the ROC and Precision-Recall curves (one for each label).
+                 Defaults to 'darkorange'.
+                 - Common color names: 'darkorange', 'cornflowerblue', 'crimson', 'forestgreen'
+                 - Hex codes: '#FF6347', '#4682B4'
+
+             cmap (str): The matplotlib colormap name for the per-label
+                 confusion matrices. Defaults to "Blues".
+                 - Sequential options: 'Blues', 'Greens', 'Reds', 'Oranges', 'Purples'
+                 - Diverging options: 'coolwarm', 'viridis', 'plasma', 'inferno'
+
+             font_size (int): The base font size to apply to the plots. Defaults to 16.
+
+         <br>
+
+         ## [Matplotlib Colormaps](https://matplotlib.org/stable/users/explain/colors/colormaps.html)
+         """
+         self.threshold = threshold
+         self.cmap = cmap
+         self.ROC_PR_line = ROC_PR_line
+         self.font_size = font_size
+
+     def __repr__(self) -> str:
+         parts = [
+             f"threshold={self.threshold}",
+             f"ROC_PR_line='{self.ROC_PR_line}'",
+             f"cmap='{self.cmap}'",
+             f"font_size={self.font_size}"
+         ]
+         return f"MultiClassificationMetricsFormat({', '.join(parts)})"
+
+
+ def info():
+     _script_info(__all__)
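As a quick orientation for the new module, here is a hedged usage sketch. The constructor arguments come directly from the file above; the exact keyword under which `MLTrainer.evaluate()` accepts the object is not shown in this diff, so that call is left as a comment.

```python
from ml_tools.ML_configuration import ClassificationMetricsFormat

# All arguments below are documented in the __init__ docstring above.
fmt = ClassificationMetricsFormat(
    cmap="Greens",                              # colormap for confusion matrix / report heatmap
    class_map={"cat": 0, "dog": 1, "bird": 2},  # class name -> integer index
    ROC_PR_line="#4682B4",                      # line color for ROC and PR curves
    calibration_bins=10,                        # bins for the reliability plot
    font_size=14,                               # base plot font size
)
print(fmt)  # __repr__ echoes the stored settings

# Illustrative only: the parameter name that MLTrainer.evaluate() uses for
# this object is not visible in this diff.
# trainer.evaluate(..., fmt)
```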
ml_tools/ML_datasetmaster.py CHANGED
@@ -333,7 +333,20 @@ class DatasetMaker(_BaseDatasetMaker):
          # --- 5. Create Datasets ---
          self._train_ds = _PytorchDataset(X_train_final, y_train, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
          self._test_ds = _PytorchDataset(X_test_final, y_test, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
+
+     def __repr__(self) -> str:
+         s = f"<{self.__class__.__name__} (ID: '{self.id}')>\n"
+         s += f" Target: {self.target_names[0]}\n"
+         s += f" Features: {self.number_of_features}\n"
+         s += f" Scaler: {'Fitted' if self.scaler else 'None'}\n"

+         if self._train_ds:
+             s += f" Train Samples: {len(self._train_ds)}\n" # type: ignore
+         if self._test_ds:
+             s += f" Test Samples: {len(self._test_ds)}\n" # type: ignore
+
+         return s
+

  # --- Multi-Target Class ---
  class DatasetMakerMulti(_BaseDatasetMaker):
@@ -448,6 +461,19 @@ class DatasetMakerMulti(_BaseDatasetMaker):
          self._train_ds = _PytorchDataset(X_train_final, y_train, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
          self._test_ds = _PytorchDataset(X_test_final, y_test, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)

+     def __repr__(self) -> str:
+         s = f"<{self.__class__.__name__} (ID: '{self.id}')>\n"
+         s += f" Targets: {self.number_of_targets}\n"
+         s += f" Features: {self.number_of_features}\n"
+         s += f" Scaler: {'Fitted' if self.scaler else 'None'}\n"
+
+         if self._train_ds:
+             s += f" Train Samples: {len(self._train_ds)}\n" # type: ignore
+         if self._test_ds:
+             s += f" Test Samples: {len(self._test_ds)}\n" # type: ignore
+
+         return s
+

  # --- Private Base Class ---
  class _BaseMaker(ABC):
@@ -654,6 +680,22 @@ class SequenceMaker(_BaseMaker):
              _LOGGER.error("Windows have not been generated. Call .generate_windows() first.")
              raise RuntimeError()
          return self._train_dataset, self._test_dataset
+
+     def __repr__(self) -> str:
+         s = f"<{self.__class__.__name__}>:\n"
+         s += f" Sequence Length (Window): {self.sequence_length}\n"
+         s += f" Total Data Points: {len(self.sequence)}\n"
+         s += " --- Status ---\n"
+         s += f" Split: {self._is_split}\n"
+         s += f" Normalized: {self._is_normalized}\n"
+         s += f" Windows Generated: {self._are_windows_generated}\n"
+
+         if self._are_windows_generated:
+             train_len = len(self._train_dataset) if self._train_dataset else 0 # type: ignore
+             test_len = len(self._test_dataset) if self._test_dataset else 0 # type: ignore
+             s += f" Datasets (Train/Test): {train_len} / {test_len} windows\n"
+
+         return s


  def info():
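Read literally, the f-strings in these `__repr__` methods assemble a summary like the one below (the ID, names, and counts are placeholder values):

```python
# Assembled the same way as DatasetMaker.__repr__ above, with placeholder values.
s = "<DatasetMaker (ID: 'my_data')>\n"
s += " Target: price\n"
s += " Features: 12\n"
s += " Scaler: Fitted\n"
s += " Train Samples: 800\n"
s += " Test Samples: 200\n"
print(s)
```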
ml_tools/ML_evaluation.py CHANGED
@@ -21,7 +21,7 @@ from pathlib import Path
  from typing import Union, Optional, List, Literal
  import warnings

- from .path_manager import make_fullpath
+ from .path_manager import make_fullpath, sanitize_filename
  from ._logger import _LOGGER
  from ._script_info import _script_info
  from .keys import SHAPKeys, PyTorchLogKeys
@@ -35,6 +35,8 @@ __all__ = [
      "plot_attention_importance"
  ]

+ DPI_value = 250
+

  def plot_losses(history: dict, save_dir: Union[str, Path]):
      """
@@ -48,10 +50,10 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
      val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])

      if not train_loss and not val_loss:
-         print("Warning: Loss history is empty or incomplete. Cannot plot.")
+         _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
          return

-     fig, ax = plt.subplots(figsize=(10, 5), dpi=100)
+     fig, ax = plt.subplots(figsize=(10, 5), dpi=DPI_value)

      # Plot training loss only if data for it exists
      if train_loss:
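A minimal usage sketch for the updated `plot_losses()`. The `VAL_LOSS` key is read in the hunk above; the symmetric `TRAIN_LOSS` key is assumed, and the loss values are invented for illustration.

```python
from pathlib import Path
from ml_tools.ML_evaluation import plot_losses
from ml_tools.keys import PyTorchLogKeys

# Dummy loss curves; plot_losses() now warns via _LOGGER (not print)
# when both keys are missing, and renders at the module-level DPI_value.
history = {
    PyTorchLogKeys.TRAIN_LOSS: [0.92, 0.61, 0.45, 0.38],  # TRAIN_LOSS key assumed by symmetry
    PyTorchLogKeys.VAL_LOSS:   [0.95, 0.70, 0.55, 0.51],
}
plot_losses(history, save_dir=Path("eval_output"))
```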
@@ -78,8 +80,15 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
      plt.close(fig)


- def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optional[np.ndarray] = None,
-                            cmap: str = "Blues"):
+ def classification_metrics(save_dir: Union[str, Path],
+                            y_true: np.ndarray,
+                            y_pred: np.ndarray,
+                            y_prob: Optional[np.ndarray] = None,
+                            cmap: str = "Blues",
+                            class_map: Optional[dict[str,int]]=None,
+                            ROC_PR_line: str='darkorange',
+                            calibration_bins: int=15,
+                            font_size: int=16):
      """
      Saves classification metrics and plots.

@@ -89,12 +98,31 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre
          y_prob (np.ndarray, optional): Predicted probabilities for ROC curve.
          cmap (str): Colormap for the confusion matrix.
          save_dir (str | Path): Directory to save plots.
+         class_map (dict[str, int], None): A map of {class_name: index} used to order and label the confusion matrix.
      """
-     print("--- Classification Report ---")
+     original_rc_params = plt.rcParams.copy()
+     plt.rcParams.update({'font.size': font_size})
+
+     # print("--- Classification Report ---")
+
+     # --- Parse class_map ---
+     map_labels = None
+     map_display_labels = None
+     if class_map:
+         # Sort the map by its values (the indices) to ensure correct order
+         try:
+             sorted_items = sorted(class_map.items(), key=lambda item: item[1])
+             map_labels = [item[1] for item in sorted_items]
+             map_display_labels = [item[0] for item in sorted_items]
+         except Exception as e:
+             _LOGGER.warning(f"Could not parse 'class_map': {e}")
+             map_labels = None
+             map_display_labels = None
+
      # Generate report as both text and dictionary
-     report_text: str = classification_report(y_true, y_pred) # type: ignore
-     report_dict: dict = classification_report(y_true, y_pred, output_dict=True) # type: ignore
-     print(report_text)
+     report_text: str = classification_report(y_true, y_pred, labels=map_labels, target_names=map_display_labels) # type: ignore
+     report_dict: dict = classification_report(y_true, y_pred, output_dict=True, labels=map_labels, target_names=map_display_labels) # type: ignore
+     # print(report_text)

      save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
      # Save text report
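The `class_map` parsing above is just a sort by index value; a toy run with an unordered mapping:

```python
# Same sort as in the new class_map block above, on a toy mapping.
class_map = {"dog": 1, "cat": 0, "bird": 2}
sorted_items = sorted(class_map.items(), key=lambda item: item[1])
map_labels = [item[1] for item in sorted_items]          # [0, 1, 2]
map_display_labels = [item[0] for item in sorted_items]  # ['cat', 'dog', 'bird']
# These feed sklearn's classification_report(labels=..., target_names=...)
# so rows appear in index order with human-readable names.
```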
@@ -104,8 +132,15 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre

      # --- Save Classification Report Heatmap ---
      try:
-         plt.figure(figsize=(8, 6), dpi=100)
-         sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T, annot=True, cmap='viridis', fmt='.2f')
+         plt.figure(figsize=(8, 6), dpi=DPI_value)
+         sns.set_theme(font_scale=1.2) # Scale seaborn font
+         sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T,
+                     annot=True,
+                     cmap=cmap,
+                     fmt='.2f',
+                     vmin=0.0,
+                     vmax=1.0)
+         sns.set_theme(font_scale=1.0) # Reset seaborn scale
          plt.title("Classification Report")
          plt.tight_layout()
          heatmap_path = save_dir_path / "classification_report_heatmap.svg"
@@ -114,69 +149,179 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre
          plt.close()
      except Exception as e:
          _LOGGER.error(f"Could not generate classification report heatmap: {e}")
-
+
+     # --- labels for Confusion Matrix ---
+     plot_labels = map_labels
+     plot_display_labels = map_display_labels
+
      # Save Confusion Matrix
-     fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
-     ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap, ax=ax_cm)
+     fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+     disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
+                                                     y_pred,
+                                                     cmap=cmap,
+                                                     ax=ax_cm,
+                                                     normalize='true',
+                                                     labels=plot_labels,
+                                                     display_labels=plot_display_labels)
+
+     disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+     # Turn off gridlines
+     ax_cm.grid(False)
+
+     # Manually update font size of cell texts
+     for text in ax_cm.texts:
+         text.set_fontsize(font_size)
+
+     fig_cm.tight_layout()
+
      ax_cm.set_title("Confusion Matrix")
      cm_path = save_dir_path / "confusion_matrix.svg"
      plt.savefig(cm_path)
      _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
      plt.close(fig_cm)

-     # Plotting logic for ROC and PR Curves
-     if y_prob is not None and y_prob.ndim > 1 and y_prob.shape[1] >= 2:
-         # Use probabilities of the positive class
-         y_score = y_prob[:, 1]
+
+     # Plotting logic for ROC, PR, and Calibration Curves
+     if y_prob is not None and y_prob.ndim == 2:
+         num_classes = y_prob.shape[1]
+
+         # --- Determine which classes to loop over ---
+         class_indices_to_plot = []
+         plot_titles = []
+         save_suffixes = []
+
+         if num_classes == 2:
+             # Binary case: Only plot for the positive class (index 1)
+             class_indices_to_plot = [1]
+             plot_titles = [""] # No extra title
+             save_suffixes = [""] # No extra suffix
+             _LOGGER.info("Generating binary classification plots (ROC, PR, Calibration).")

-         # --- Save ROC Curve ---
-         fpr, tpr, _ = roc_curve(y_true, y_score)
-         auc = roc_auc_score(y_true, y_score)
-         fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
-         ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
-         ax_roc.plot([0, 1], [0, 1], 'k--')
-         ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
-         ax_roc.set_xlabel('False Positive Rate')
-         ax_roc.set_ylabel('True Positive Rate')
-         ax_roc.legend(loc='lower right')
-         ax_roc.grid(True)
-         roc_path = save_dir_path / "roc_curve.svg"
-         plt.savefig(roc_path)
-         _LOGGER.info(f"📈 ROC curve saved as '{roc_path.name}'")
-         plt.close(fig_roc)
-
-         # --- Save Precision-Recall Curve ---
-         precision, recall, _ = precision_recall_curve(y_true, y_score)
-         ap_score = average_precision_score(y_true, y_score)
-         fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
-         ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
-         ax_pr.set_title('Precision-Recall Curve')
-         ax_pr.set_xlabel('Recall')
-         ax_pr.set_ylabel('Precision')
-         ax_pr.legend(loc='lower left')
-         ax_pr.grid(True)
-         pr_path = save_dir_path / "pr_curve.svg"
-         plt.savefig(pr_path)
-         _LOGGER.info(f"📈 PR curve saved as '{pr_path.name}'")
-         plt.close(fig_pr)
+         elif num_classes > 2:
+             _LOGGER.info(f"Generating One-vs-Rest plots for {num_classes} classes.")
+             # Multiclass case: Plot for every class (One-vs-Rest)
+             class_indices_to_plot = list(range(num_classes))
+
+             # --- Use class_map names if available ---
+             use_generic_names = True
+             if map_display_labels and len(map_display_labels) == num_classes:
+                 try:
+                     # Ensure labels are safe for filenames
+                     safe_names = [sanitize_filename(name) for name in map_display_labels]
+                     plot_titles = [f" ({name} vs. Rest)" for name in map_display_labels]
+                     save_suffixes = [f"_{safe_names[i]}" for i in class_indices_to_plot]
+                     use_generic_names = False
+                 except Exception as e:
+                     _LOGGER.warning(f"Failed to use 'class_map' for plot titles: {e}. Reverting to generic names.")
+                     use_generic_names = True
+
+             if use_generic_names:
+                 plot_titles = [f" (Class {i} vs. Rest)" for i in class_indices_to_plot]
+                 save_suffixes = [f"_class_{i}" for i in class_indices_to_plot]

-     # --- Save Calibration Plot ---
-     if y_prob.ndim > 1 and y_prob.shape[1] >= 2:
-         y_score = y_prob[:, 1] # Use probabilities of the positive class
+         else:
+             # Should not happen, but good to check
+             _LOGGER.warning(f"Probability array has invalid shape {y_prob.shape}. Skipping ROC/PR/Calibration plots.")
+
+         # --- Loop and generate plots ---
+         for i, class_index in enumerate(class_indices_to_plot):
+             plot_title = plot_titles[i]
+             save_suffix = save_suffixes[i]
+
+             # Get scores for the current class
+             y_score = y_prob[:, class_index]
+
+             # Binarize y_true for the current class
+             y_true_binary = (y_true == class_index).astype(int)
+
+             # --- Save ROC Curve ---
+             fpr, tpr, _ = roc_curve(y_true_binary, y_score)

-         fig_cal, ax_cal = plt.subplots(figsize=(8, 8), dpi=100)
-         CalibrationDisplay.from_predictions(y_true, y_score, n_bins=15, ax=ax_cal)
+             # Calculate AUC.
+             # Note: For multiclass, roc_auc_score(y_true, y_prob, multi_class='ovr') could average, but plotting individual curves is more informative.
+             # Here we calculate the specific AUC for the binarized problem.
+             auc = roc_auc_score(y_true_binary, y_score)

-         ax_cal.set_title('Reliability Curve')
+             fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+             ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
+             ax_roc.plot([0, 1], [0, 1], 'k--')
+             ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
+             ax_roc.set_xlabel('False Positive Rate')
+             ax_roc.set_ylabel('True Positive Rate')
+             ax_roc.legend(loc='lower right')
+             ax_roc.grid(True)
+             roc_path = save_dir_path / f"roc_curve{save_suffix}.svg"
+             plt.savefig(roc_path)
+             plt.close(fig_roc)
+
+             # --- Save Precision-Recall Curve ---
+             precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
+             ap_score = average_precision_score(y_true_binary, y_score)
+             fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+             ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=ROC_PR_line)
+             ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
+             ax_pr.set_xlabel('Recall')
+             ax_pr.set_ylabel('Precision')
+             ax_pr.legend(loc='lower left')
+             ax_pr.grid(True)
+             pr_path = save_dir_path / f"pr_curve{save_suffix}.svg"
+             plt.savefig(pr_path)
+             plt.close(fig_pr)
+
+             # --- Save Calibration Plot ---
+             fig_cal, ax_cal = plt.subplots(figsize=(8, 8), dpi=DPI_value)
+
+             # --- Step 1: Get binned data *without* plotting ---
+             with plt.ioff(): # Suppress showing the temporary plot
+                 fig_temp, ax_temp = plt.subplots()
+                 cal_display_temp = CalibrationDisplay.from_predictions(
+                     y_true_binary, # Use binarized labels
+                     y_score,
+                     n_bins=calibration_bins,
+                     ax=ax_temp,
+                     name="temp" # Add a name to suppress potential warnings
+                 )
+                 # Get the x, y coordinates of the binned data
+                 line_x, line_y = cal_display_temp.line_.get_data() # type: ignore
+                 plt.close(fig_temp) # Close the temporary plot
+
+             # --- Step 2: Build the plot from scratch ---
+             ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
+
+             sns.regplot(
+                 x=line_x,
+                 y=line_y,
+                 ax=ax_cal,
+                 scatter=False,
+                 label=f"Calibration Curve ({calibration_bins} bins)",
+                 line_kws={
+                     'color': ROC_PR_line,
+                     'linestyle': '--',
+                     'linewidth': 2,
+                 }
+             )
+
+             ax_cal.set_title(f'Reliability Curve{plot_title}')
              ax_cal.set_xlabel('Mean Predicted Probability')
              ax_cal.set_ylabel('Fraction of Positives')
+
+             # --- Step 3: Set final limits *after* plotting ---
+             ax_cal.set_ylim(0.0, 1.0)
+             ax_cal.set_xlim(0.0, 1.0)
+
+             ax_cal.legend(loc='lower right')
              ax_cal.grid(True)
              plt.tight_layout()

-         cal_path = save_dir_path / "calibration_plot.svg"
+             cal_path = save_dir_path / f"calibration_plot{save_suffix}.svg"
              plt.savefig(cal_path)
-         _LOGGER.info(f"📈 Calibration plot saved as '{cal_path.name}'")
              plt.close(fig_cal)
+
+         _LOGGER.info(f"📈 Saved {len(class_indices_to_plot)} sets of ROC, Precision-Recall, and Calibration plots.")
+
+     # restore RC params
+     plt.rcParams.update(original_rc_params)


  def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[str, Path]):
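Putting the new signature together, a hedged end-to-end sketch; the arrays are synthetic, and only the keyword names are taken from the signature shown above.

```python
import numpy as np
from ml_tools.ML_evaluation import classification_metrics

# Synthetic 3-class predictions purely for illustration.
rng = np.random.default_rng(0)
y_true = rng.integers(0, 3, size=200)
y_prob = rng.dirichlet(np.ones(3), size=200)  # rows sum to 1, shape (N, 3)
y_pred = y_prob.argmax(axis=1)

# With num_classes > 2, the loop above binarizes per class,
# e.g. y_true_binary = (y_true == class_index).astype(int),
# and writes one ROC/PR/calibration set per class (One-vs-Rest).
classification_metrics(
    save_dir="eval_output",
    y_true=y_true,
    y_pred=y_pred,
    y_prob=y_prob,
    cmap="Blues",
    class_map={"cat": 0, "dog": 1, "bird": 2},
    ROC_PR_line="darkorange",
    calibration_bins=15,
    font_size=16,
)
```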
@@ -211,7 +356,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s

      # Save residual plot
      residuals = y_true - y_pred
-     fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
+     fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
      ax_res.scatter(y_pred, residuals, alpha=0.6)
      ax_res.axhline(0, color='red', linestyle='--')
      ax_res.set_xlabel("Predicted Values")
@@ -225,7 +370,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
      plt.close(fig_res)

      # Save true vs predicted plot
-     fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
+     fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
      ax_tvp.scatter(y_true, y_pred, alpha=0.6)
      ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'k--', lw=2)
      ax_tvp.set_xlabel('True Values')
@@ -239,7 +384,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
      plt.close(fig_tvp)

      # Save Histogram of Residuals
-     fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=100)
+     fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=DPI_value)
      sns.histplot(residuals, kde=True, ax=ax_hist)
      ax_hist.set_xlabel("Residual Value")
      ax_hist.set_ylabel("Frequency")
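The three regression hunks above only swap the hardcoded `dpi=100` for the module-level `DPI_value`; the call shape of `regression_metrics()` is unchanged. A sketch with synthetic arrays:

```python
import numpy as np
from ml_tools.ML_evaluation import regression_metrics

# Invented data: noisy predictions around the true values.
rng = np.random.default_rng(1)
y_true = rng.normal(size=100)
y_pred = y_true + rng.normal(scale=0.3, size=100)

# Writes residual, true-vs-predicted, and residual-histogram plots.
regression_metrics(y_true, y_pred, save_dir="eval_output")
```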
@@ -276,7 +421,7 @@ def shap_summary_plot(model,
          slow and memory-intensive.
      """

-     print(f"\n--- SHAP Value Explanation Using {explainer_type.upper()} Explainer ---")
+     _LOGGER.info(f"📊 Running SHAP Value Explanation Using {explainer_type.upper()} Explainer")

      model.eval()
      # model.cpu() # Run explanations on CPU
@@ -348,9 +493,9 @@ def shap_summary_plot(model,
          _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
          raise ValueError()

-     if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
+     if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1: # type: ignore
          # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
-         shap_values = shap_values.squeeze(-1)
+         shap_values = shap_values.squeeze(-1) # type: ignore

      # --- 3. Plotting and Saving ---
      save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
@@ -455,7 +600,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
      # --- Step 3: Create and save the plot for top N features ---
      plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)

-     plt.figure(figsize=(10, 8), dpi=100)
+     plt.figure(figsize=(10, 8), dpi=DPI_value)

      # Create horizontal bar plot with error bars
      plt.barh(