dragon-ml-toolbox 19.13.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.13.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1901
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.13.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
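Nearly every change above follows one pattern: a flat module such as ml_tools/ETL_cleaning.py becomes a package, its old filename turning into __init__.py and its implementation moving into private submodules (_basic_clean.py, _dragon_cleaner.py, _imprimir.py, ...), while the old ml_tools/_core/ internals are deleted or relocated. A minimal sketch of what that means for callers, assuming the new __init__.py files re-export the former public names (the rename pattern suggests this, but the export lists are not part of this diff):

# Sketch only -- assumes each 20.0.0 package __init__.py re-exports its
# former module's public API.

# 19.13.0: ml_tools/ETL_cleaning.py was a single module.
# 20.0.0:  ml_tools/ETL_cleaning/ is a package with the same import path.
from ml_tools import ETL_cleaning  # works in both versions under this assumption

# Code that reached into the (always-private) _core internals breaks, e.g.:
# from ml_tools._core import _ML_evaluation   # deleted in 20.0.0 (see diff below)

The deletions at the bottom of the list (items 173 to 215) are the other half of the same move: each _core/_*.py implementation either migrated into one of the new packages or was split across several of their private submodules.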
ml_tools/_core/_ML_evaluation.py (deleted; item 203 above)
@@ -1,867 +0,0 @@
- import numpy as np
- import pandas as pd
- import matplotlib.pyplot as plt
- import seaborn as sns
- from sklearn.calibration import CalibrationDisplay
- from sklearn.metrics import (
-     classification_report,
-     ConfusionMatrixDisplay,
-     roc_curve,
-     roc_auc_score,
-     mean_squared_error,
-     mean_absolute_error,
-     r2_score,
-     median_absolute_error,
-     precision_recall_curve,
-     average_precision_score
- )
- import torch
- import shap
- from pathlib import Path
- from typing import Union, Optional, List, Literal
- import warnings
-
- from ._path_manager import make_fullpath, sanitize_filename
- from ._logger import get_logger
- from ._script_info import _script_info
- from ._keys import SHAPKeys, PyTorchLogKeys, _EvaluationConfig
- from ._ML_configuration import (RegressionMetricsFormat,
-                                 BinaryClassificationMetricsFormat,
-                                 MultiClassClassificationMetricsFormat,
-                                 BinaryImageClassificationMetricsFormat,
-                                 MultiClassImageClassificationMetricsFormat,
-                                 _BaseClassificationFormat,
-                                 _BaseRegressionFormat)
-
-
- _LOGGER = get_logger("Evaluation")
-
-
- __all__ = [
-     "plot_losses",
-     "classification_metrics",
-     "regression_metrics",
-     "shap_summary_plot",
-     "plot_attention_importance"
- ]
-
-
- DPI_value = _EvaluationConfig.DPI
- REGRESSION_PLOT_SIZE = _EvaluationConfig.REGRESSION_PLOT_SIZE
- CLASSIFICATION_PLOT_SIZE = _EvaluationConfig.CLASSIFICATION_PLOT_SIZE
-
-
- def plot_losses(history: dict, save_dir: Union[str, Path]):
-     """
-     Plots training & validation loss curves from a history object.
-     Also plots the learning rate if available in the history.
-
-     Args:
-         history (dict): A dictionary containing 'train_loss' and 'val_loss'.
-         save_dir (str | Path): Directory to save the plot image.
-     """
-     train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
-     val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
-     lr_history = history.get(PyTorchLogKeys.LEARNING_RATE, [])
-
-     if not train_loss and not val_loss:
-         _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
-         return
-
-     fig, ax = plt.subplots(figsize=_EvaluationConfig.LOSS_PLOT_SIZE, dpi=DPI_value)
-
-     # --- Plot Losses (Left Y-axis) ---
-     line_handles = []  # To store line objects for the legend
-
-     # Plot training loss only if data for it exists
-     if train_loss:
-         epochs = range(1, len(train_loss) + 1)
-         line1, = ax.plot(epochs, train_loss, 'o-', label='Training Loss', color='tab:blue')
-         line_handles.append(line1)
-
-     # Plot validation loss only if data for it exists
-     if val_loss:
-         epochs = range(1, len(val_loss) + 1)
-         line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
-         line_handles.append(line2)
-
-     ax.set_title('Training and Validation Loss', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE + 2, pad=_EvaluationConfig.LABEL_PADDING)
-     ax.set_xlabel('Epochs', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
-     ax.set_ylabel('Loss', color='tab:blue', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
-     ax.tick_params(axis='y', labelcolor='tab:blue', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
-     ax.tick_params(axis='x', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
-     ax.grid(True, linestyle='--')
-
-     # --- Plot Learning Rate (Right Y-axis) ---
-     if lr_history:
-         ax2 = ax.twinx()  # Create a second y-axis
-         epochs = range(1, len(lr_history) + 1)
-         line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
-         line_handles.append(line3)
-
-         ax2.set_ylabel('Learning Rate', color='g', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
-         ax2.tick_params(axis='y', labelcolor='g', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
-         # Use scientific notation if the LR is very small
-         ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
-         # increase the size of the scientific notation
-         ax2.yaxis.get_offset_text().set_fontsize(_EvaluationConfig.LOSS_PLOT_TICK_SIZE - 2)
-         # remove grid from second y-axis
-         ax2.grid(False)
-
-     # Combine legends from both axes
-     ax.legend(handles=line_handles, loc='best', fontsize=_EvaluationConfig.LOSS_PLOT_LEGEND_SIZE)
-
-     # ax.grid(True)
-     plt.tight_layout()
-
-     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
-     save_path = save_dir_path / "loss_plot.svg"
-     plt.savefig(save_path)
-     _LOGGER.info(f"📉 Loss plot saved as '{save_path.name}'")
-
-     plt.close(fig)
-
-
- def classification_metrics(save_dir: Union[str, Path],
-                            y_true: np.ndarray,
-                            y_pred: np.ndarray,
-                            y_prob: Optional[np.ndarray] = None,
-                            class_map: Optional[dict[str,int]] = None,
-                            config: Optional[Union[BinaryClassificationMetricsFormat,
-                                                   MultiClassClassificationMetricsFormat,
-                                                   BinaryImageClassificationMetricsFormat,
-                                                   MultiClassImageClassificationMetricsFormat]] = None):
-     """
-     Saves classification metrics and plots.
-
-     Args:
-         y_true (np.ndarray): Ground truth labels.
-         y_pred (np.ndarray): Predicted labels.
-         y_prob (np.ndarray): Predicted probabilities for ROC curve.
-         config (object): Formatting configuration object.
-         save_dir (str | Path): Directory to save plots.
-     """
-     # --- Parse Config or use defaults ---
-     if config is None:
-         # Create a default config if one wasn't provided
-         format_config = _BaseClassificationFormat()
-     else:
-         format_config = config
-
-     # original_rc_params = plt.rcParams.copy()
-     # plt.rcParams.update({'font.size': format_config.font_size})
-
-     # --- Set Font Sizes ---
-     xtick_size = format_config.xtick_size
-     ytick_size = format_config.ytick_size
-     legend_size = format_config.legend_size
-
-     # config font size for heatmap
-     cm_font_size = format_config.cm_font_size
-     cm_tick_size = cm_font_size - 4
-
-     # --- Parse class_map ---
-     map_labels = None
-     map_display_labels = None
-     if class_map:
-         # Sort the map by its values (the indices) to ensure correct order
-         try:
-             sorted_items = sorted(class_map.items(), key=lambda item: item[1])
-             map_labels = [item[1] for item in sorted_items]
-             map_display_labels = [item[0] for item in sorted_items]
-         except Exception as e:
-             _LOGGER.warning(f"Could not parse 'class_map': {e}")
-             map_labels = None
-             map_display_labels = None
-
-     # Generate report as both text and dictionary
-     report_text: str = classification_report(y_true, y_pred, labels=map_labels, target_names=map_display_labels)  # type: ignore
-     report_dict: dict = classification_report(y_true, y_pred, output_dict=True, labels=map_labels, target_names=map_display_labels)  # type: ignore
-     # print(report_text)
-
-     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
-     # Save text report
-     report_path = save_dir_path / "classification_report.txt"
-     report_path.write_text(report_text, encoding="utf-8")
-     _LOGGER.info(f"📝 Classification report saved as '{report_path.name}'")
-
-     # --- Save Classification Report Heatmap ---
-     try:
-         # Create DataFrame from report
-         report_df = pd.DataFrame(report_dict)
-
-         # 1. Robust Cleanup: Drop by name, not position
-         # Remove 'accuracy' column if it exists (handles the scalar value issue)
-         report_df = report_df.drop(columns=['accuracy'], errors='ignore')
-
-         # Remove 'support' row explicitly (safer than iloc[:-1])
-         if 'support' in report_df.index:
-             report_df = report_df.drop(index='support')
-
-         # 2. Transpose: Rows = Classes, Cols = Metrics
-         plot_df = report_df.T
-
-         # 3. Dynamic Height Calculation
-         # (Base height of 4 + 0.5 inches per class row)
-         fig_height = max(5.0, len(plot_df.index) * 0.5 + 4.0)
-         fig_width = 8.0  # Set a fixed width
-
-         # --- Use calculated dimensions, not the config constant ---
-         fig_heat, ax_heat = plt.subplots(figsize=(fig_width, fig_height), dpi=_EvaluationConfig.DPI)
-
-         # sns.set_theme(font_scale=1.4)
-         sns.heatmap(plot_df,
-                     annot=True,
-                     cmap=format_config.cmap,
-                     fmt='.2f',
-                     vmin=0.0,
-                     vmax=1.0,
-                     cbar_kws={'shrink': 0.9})  # Shrink colorbar slightly to fit better
-
-         # sns.set_theme(font_scale=1.0)
-
-         ax_heat.set_title("Classification Report Heatmap", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
-
-         # manually increase the font size of the elements
-         for text in ax_heat.texts:
-             text.set_fontsize(cm_tick_size)
-
-         # manually increase the size of the colorbar ticks
-         cbar = ax_heat.collections[0].colorbar
-         cbar.ax.tick_params(labelsize=cm_tick_size - 4)  # type: ignore
-
-         # Update Ticks
-         ax_heat.tick_params(axis='x', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING)
-         ax_heat.tick_params(axis='y', labelsize=cm_tick_size, pad=_EvaluationConfig.LABEL_PADDING, rotation=0)  # Ensure Y labels are horizontal
-
-         plt.tight_layout()
-
-         heatmap_path = save_dir_path / "classification_report_heatmap.svg"
-         plt.savefig(heatmap_path)
-         _LOGGER.info(f"📊 Report heatmap saved as '{heatmap_path.name}'")
-         plt.close(fig_heat)
-
-     except Exception as e:
-         _LOGGER.error(f"Could not generate classification report heatmap: {e}")
-
-     # --- labels for Confusion Matrix ---
-     plot_labels = map_labels
-     plot_display_labels = map_display_labels
-
-     # 1. DYNAMIC SIZE CALCULATION
-     # Calculate figure size based on number of classes.
-     n_classes = len(plot_labels) if plot_labels is not None else len(np.unique(y_true))
-     # Ensure a minimum size so very small matrices aren't tiny
-     fig_w = max(9, n_classes * 0.8 + 3)
-     fig_h = max(8, n_classes * 0.8 + 2)
-
-     # Use the calculated size instead of CLASSIFICATION_PLOT_SIZE
-     fig_cm, ax_cm = plt.subplots(figsize=(fig_w, fig_h), dpi=DPI_value)
-     disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
-                                                     y_pred,
-                                                     cmap=format_config.cmap,
-                                                     ax=ax_cm,
-                                                     normalize='true',
-                                                     labels=plot_labels,
-                                                     display_labels=plot_display_labels,
-                                                     colorbar=False)
-
-     disp_.im_.set_clim(vmin=0.0, vmax=1.0)
-
-     # Turn off gridlines
-     ax_cm.grid(False)
-
-     # 2. CHECK FOR FONT CLASH
-     # If matrix is huge, force text smaller. If small, allow user config.
-     final_font_size = cm_font_size + 2
-     if n_classes > 2:
-         final_font_size = cm_font_size - n_classes  # Decrease font size for larger matrices
-
-     for text in ax_cm.texts:
-         text.set_fontsize(final_font_size)
-
-     # Update Ticks for Confusion Matrix
-     ax_cm.tick_params(axis='x', labelsize=cm_tick_size)
-     ax_cm.tick_params(axis='y', labelsize=cm_tick_size)
-
-     # if more than 3 classes, rotate x ticks
-     if n_classes > 3:
-         plt.setp(ax_cm.get_xticklabels(), rotation=45, ha='right', rotation_mode="anchor")
-
-     # Set titles and labels with padding
-     ax_cm.set_title("Confusion Matrix", pad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size + 2)
-     ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
-     ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=cm_font_size)
-
-     # --- ADJUST COLORBAR FONT & SIZE ---
-     # Manually add the colorbar with the 'shrink' parameter
-     cbar = fig_cm.colorbar(disp_.im_, ax=ax_cm, shrink=0.8)
-
-     # Update the tick size on the new cbar object
-     cbar.ax.tick_params(labelsize=cm_tick_size)
-
-     # (Optional) add a label to the bar itself (e.g. "Probability")
-     # cbar.set_label('Probability', fontsize=12)
-
-     fig_cm.tight_layout()
-
-     cm_path = save_dir_path / "confusion_matrix.svg"
-     plt.savefig(cm_path)
-     _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
-     plt.close(fig_cm)
-
-
-     # Plotting logic for ROC, PR, and Calibration Curves
-     if y_prob is not None and y_prob.ndim == 2:
-         num_classes = y_prob.shape[1]
-
-         # --- Determine which classes to loop over ---
-         class_indices_to_plot = []
-         plot_titles = []
-         save_suffixes = []
-
-         if num_classes == 2:
-             # Binary case: Only plot for the positive class (index 1)
-             class_indices_to_plot = [1]
-             plot_titles = [""]  # No extra title
-             save_suffixes = [""]  # No extra suffix
-             _LOGGER.debug("Generating binary classification plots (ROC, PR, Calibration).")
-
-         elif num_classes > 2:
-             _LOGGER.debug(f"Generating One-vs-Rest plots for {num_classes} classes.")
-             # Multiclass case: Plot for every class (One-vs-Rest)
-             class_indices_to_plot = list(range(num_classes))
-
-             # --- Use class_map names if available ---
-             use_generic_names = True
-             if map_display_labels and len(map_display_labels) == num_classes:
-                 try:
-                     # Ensure labels are safe for filenames
-                     safe_names = [sanitize_filename(name) for name in map_display_labels]
-                     plot_titles = [f" ({name} vs. Rest)" for name in map_display_labels]
-                     save_suffixes = [f"_{safe_names[i]}" for i in class_indices_to_plot]
-                     use_generic_names = False
-                 except Exception as e:
-                     _LOGGER.warning(f"Failed to use 'class_map' for plot titles: {e}. Reverting to generic names.")
-                     use_generic_names = True
-
-             if use_generic_names:
-                 plot_titles = [f" (Class {i} vs. Rest)" for i in class_indices_to_plot]
-                 save_suffixes = [f"_class_{i}" for i in class_indices_to_plot]
-
-         else:
-             # Should not happen, but good to check
-             _LOGGER.warning(f"Probability array has invalid shape {y_prob.shape}. Skipping ROC/PR/Calibration plots.")
-
-         # --- Loop and generate plots ---
-         for i, class_index in enumerate(class_indices_to_plot):
-             plot_title = plot_titles[i]
-             save_suffix = save_suffixes[i]
-
-             # Get scores for the current class
-             y_score = y_prob[:, class_index]
-
-             # Binarize y_true for the current class
-             y_true_binary = (y_true == class_index).astype(int)
-
-             # --- Save ROC Curve ---
-             fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
-
-             try:
-                 # Calculate Youden's J statistic (tpr - fpr)
-                 J = tpr - fpr
-                 # Find the index of the best threshold
-                 best_index = np.argmax(J)
-                 optimal_threshold = thresholds[best_index]
-
-                 # Define the filename
-                 threshold_filename = f"best_threshold{save_suffix}.txt"
-                 threshold_path = save_dir_path / threshold_filename
-
-                 # Get the class name for the report
-                 class_name = ""
-                 # Check if we have display labels and the current index is valid
-                 if map_display_labels and class_index < len(map_display_labels):
-                     class_name = map_display_labels[class_index]
-                     if num_classes > 2:
-                         # Add 'vs. Rest' for multiclass one-vs-rest plots
-                         class_name += " (vs. Rest)"
-                 else:
-                     # Fallback to the generic title or default binary name
-                     class_name = plot_title.strip() or "Binary Positive Class"
-
-                 # Create content for the file
-                 file_content = (
-                     f"Optimal Classification Threshold (Youden's J Statistic)\n"
-                     f"Class: {class_name}\n"
-                     f"--------------------------------------------------\n"
-                     f"Threshold: {optimal_threshold:.6f}\n"
-                     f"True Positive Rate (TPR): {tpr[best_index]:.6f}\n"
-                     f"False Positive Rate (FPR): {fpr[best_index]:.6f}\n"
-                 )
-
-                 threshold_path.write_text(file_content, encoding="utf-8")
-                 _LOGGER.info(f"💾 Optimal threshold saved as '{threshold_path.name}'")
-
-             except Exception as e:
-                 _LOGGER.warning(f"Could not calculate or save optimal threshold: {e}")
-
-             # Calculate AUC.
-             auc = roc_auc_score(y_true_binary, y_score)
-
-             fig_roc, ax_roc = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
-             ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)
-             ax_roc.plot([0, 1], [0, 1], 'k--')
-             ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
-             ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
-             ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
-
-             # Apply Ticks and Legend sizing
-             ax_roc.tick_params(axis='x', labelsize=xtick_size)
-             ax_roc.tick_params(axis='y', labelsize=ytick_size)
-             ax_roc.legend(loc='lower right', fontsize=legend_size)
-
-             ax_roc.grid(True)
-             roc_path = save_dir_path / f"roc_curve{save_suffix}.svg"
-
-             plt.tight_layout()
-
-             plt.savefig(roc_path)
-             plt.close(fig_roc)
-
-             # --- Save Precision-Recall Curve ---
-             precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
-             ap_score = average_precision_score(y_true_binary, y_score)
-             fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
-             ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=format_config.ROC_PR_line)
-             ax_pr.set_title(f'Precision-Recall Curve{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
-             ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
-             ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
-
-             # Apply Ticks and Legend sizing
-             ax_pr.tick_params(axis='x', labelsize=xtick_size)
-             ax_pr.tick_params(axis='y', labelsize=ytick_size)
-             ax_pr.legend(loc='lower left', fontsize=legend_size)
-
-             ax_pr.grid(True)
-             pr_path = save_dir_path / f"pr_curve{save_suffix}.svg"
-
-             plt.tight_layout()
-
-             plt.savefig(pr_path)
-             plt.close(fig_pr)
-
-             # --- Save Calibration Plot ---
-             fig_cal, ax_cal = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
-
-             # --- Step 1: Get binned data *without* plotting ---
-             with plt.ioff():  # Suppress showing the temporary plot
-                 fig_temp, ax_temp = plt.subplots()
-                 cal_display_temp = CalibrationDisplay.from_predictions(
-                     y_true_binary,  # Use binarized labels
-                     y_score,
-                     n_bins=format_config.calibration_bins,
-                     ax=ax_temp,
-                     name="temp"  # Add a name to suppress potential warnings
-                 )
-                 # Get the x, y coordinates of the binned data
-                 line_x, line_y = cal_display_temp.line_.get_data()  # type: ignore
-                 plt.close(fig_temp)  # Close the temporary plot
-
-             # --- Step 2: Build the plot from scratch ---
-             ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
-
-             sns.regplot(
-                 x=line_x,
-                 y=line_y,
-                 ax=ax_cal,
-                 scatter=False,
-                 label=f"Model calibration",
-                 line_kws={
-                     'color': format_config.ROC_PR_line,
-                     'linestyle': '--',
-                     'linewidth': 2,
-                 }
-             )
-
-             ax_cal.set_title(f'Reliability Curve{plot_title}', pad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size + 2)
-             ax_cal.set_xlabel('Mean Predicted Probability', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
-             ax_cal.set_ylabel('Fraction of Positives', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=format_config.font_size)
-
-             # --- Step 3: Set final limits *after* plotting ---
-             ax_cal.set_ylim(0.0, 1.0)
-             ax_cal.set_xlim(0.0, 1.0)
-
-             # Apply Ticks and Legend sizing
-             ax_cal.tick_params(axis='x', labelsize=xtick_size)
-             ax_cal.tick_params(axis='y', labelsize=ytick_size)
-             ax_cal.legend(loc='lower right', fontsize=legend_size)
-
-             ax_cal.grid(True)
-             plt.tight_layout()
-
-             cal_path = save_dir_path / f"calibration_plot{save_suffix}.svg"
-             plt.savefig(cal_path)
-             plt.close(fig_cal)
-
-         _LOGGER.info(f"📈 Saved {len(class_indices_to_plot)} sets of ROC, Precision-Recall, and Calibration plots.")
-
-     # restore RC params
-     # plt.rcParams.update(original_rc_params)
-
-
- def regression_metrics(
-     y_true: np.ndarray,
-     y_pred: np.ndarray,
-     save_dir: Union[str, Path],
-     config: Optional[RegressionMetricsFormat] = None
- ):
-     """
-     Saves regression metrics and plots.
-
-     Args:
-         y_true (np.ndarray): Ground truth values.
-         y_pred (np.ndarray): Predicted values.
-         save_dir (str | Path): Directory to save plots and report.
-         config (RegressionMetricsFormat, optional): Formatting configuration object.
-     """
-
-     # --- Parse Config or use defaults ---
-     if config is None:
-         # Create a default config if one wasn't provided
-         format_config = _BaseRegressionFormat()
-     else:
-         format_config = config
-
-     # --- Set Matplotlib font size ---
-     # original_rc_params = plt.rcParams.copy()
-     # plt.rcParams.update({'font.size': format_config.font_size})
-
-     # --- Resolve Font Sizes ---
-     xtick_size = format_config.xtick_size
-     ytick_size = format_config.ytick_size
-     base_font_size = format_config.font_size
-
-     # --- Calculate Metrics ---
-     rmse = np.sqrt(mean_squared_error(y_true, y_pred))
-     mae = mean_absolute_error(y_true, y_pred)
-     r2 = r2_score(y_true, y_pred)
-     medae = median_absolute_error(y_true, y_pred)
-
-     report_lines = [
-         "--- Regression Report ---",
-         f" Root Mean Squared Error (RMSE): {rmse:.4f}",
-         f" Mean Absolute Error (MAE): {mae:.4f}",
-         f" Median Absolute Error (MedAE): {medae:.4f}",
-         f" Coefficient of Determination (R²): {r2:.4f}"
-     ]
-     report_string = "\n".join(report_lines)
-     # print(report_string)
-
-     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
-     # Save text report
-     report_path = save_dir_path / "regression_report.txt"
-     report_path.write_text(report_string)
-     _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")
-
-     # --- Save residual plot ---
-     residuals = y_true - y_pred
-     fig_res, ax_res = plt.subplots(figsize=REGRESSION_PLOT_SIZE, dpi=DPI_value)
-     ax_res.scatter(y_pred, residuals,
-                    alpha=format_config.scatter_alpha,
-                    color=format_config.scatter_color)
-     ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--')
-     ax_res.set_xlabel("Predicted Values", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
-     ax_res.set_ylabel("Residuals", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
-     ax_res.set_title("Residual Plot", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
-
-     # Apply Ticks
-     ax_res.tick_params(axis='x', labelsize=xtick_size)
-     ax_res.tick_params(axis='y', labelsize=ytick_size)
-
-     ax_res.grid(True)
-     plt.tight_layout()
-     res_path = save_dir_path / "residual_plot.svg"
-     plt.savefig(res_path)
-     _LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
-     plt.close(fig_res)
-
-     # --- Save true vs predicted plot ---
-     fig_tvp, ax_tvp = plt.subplots(figsize=REGRESSION_PLOT_SIZE, dpi=DPI_value)
-     ax_tvp.scatter(y_true, y_pred,
-                    alpha=format_config.scatter_alpha,
-                    color=format_config.scatter_color)
-     ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()],
-                 linestyle='--',
-                 lw=2,
-                 color=format_config.ideal_line_color)
-     ax_tvp.set_xlabel('True Values', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
-     ax_tvp.set_ylabel('Predictions', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
-     ax_tvp.set_title('True vs. Predicted Values', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
-
-     # Apply Ticks
-     ax_tvp.tick_params(axis='x', labelsize=xtick_size)
-     ax_tvp.tick_params(axis='y', labelsize=ytick_size)
-
-     ax_tvp.grid(True)
-     plt.tight_layout()
-     tvp_path = save_dir_path / "true_vs_predicted_plot.svg"
-     plt.savefig(tvp_path)
-     _LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
-     plt.close(fig_tvp)
-
-     # --- Save Histogram of Residuals ---
-     fig_hist, ax_hist = plt.subplots(figsize=REGRESSION_PLOT_SIZE, dpi=DPI_value)
-     sns.histplot(residuals, kde=True, ax=ax_hist,
-                  bins=format_config.hist_bins,
-                  color=format_config.scatter_color)
-     ax_hist.set_xlabel("Residual Value", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
-     ax_hist.set_ylabel("Frequency", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
-     ax_hist.set_title("Distribution of Residuals", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
-
-     # Apply Ticks
-     ax_hist.tick_params(axis='x', labelsize=xtick_size)
-     ax_hist.tick_params(axis='y', labelsize=ytick_size)
-
-     ax_hist.grid(True)
-     plt.tight_layout()
-     hist_path = save_dir_path / "residuals_histogram.svg"
-     plt.savefig(hist_path)
-     _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
-     plt.close(fig_hist)
-
-     # --- Restore RC params ---
-     # plt.rcParams.update(original_rc_params)
-
-
- def shap_summary_plot(model,
-                       background_data: Union[torch.Tensor,np.ndarray],
-                       instances_to_explain: Union[torch.Tensor,np.ndarray],
-                       feature_names: Optional[list[str]],
-                       save_dir: Union[str, Path],
-                       device: torch.device = torch.device('cpu'),
-                       explainer_type: Literal['deep', 'kernel'] = 'kernel'):
-     """
-     Calculates SHAP values and saves summary plots and data.
-
-     Args:
-         model (nn.Module): The trained PyTorch model.
-         background_data (torch.Tensor): A sample of data for the explainer background.
-         instances_to_explain (torch.Tensor): The specific data instances to explain.
-         feature_names (list of str | None): Names of the features for plot labeling.
-         save_dir (str | Path): Directory to save SHAP artifacts.
-         device (torch.device): The torch device for SHAP calculations.
-         explainer_type (Literal['deep', 'kernel']): The explainer to use.
-             - 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
-             - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive.
-     """
-
-     _LOGGER.info(f"📊 Running SHAP Value Explanation Using {explainer_type.upper()} Explainer")
-
-     model.eval()
-     # model.cpu()  # Run explanations on CPU
-
-     shap_values = None
-     instances_to_explain_np = None
-
-     if explainer_type == 'deep':
-         # --- 1. Use DeepExplainer ---
-
-         # Ensure data is torch.Tensor
-         if isinstance(background_data, np.ndarray):
-             background_data = torch.from_numpy(background_data).float()
-         if isinstance(instances_to_explain, np.ndarray):
-             instances_to_explain = torch.from_numpy(instances_to_explain).float()
-
-         if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
-             _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
-             return
-
-         background_data = background_data.to(device)
-         instances_to_explain = instances_to_explain.to(device)
-
-         with warnings.catch_warnings():
-             warnings.simplefilter("ignore", category=UserWarning)
-             explainer = shap.DeepExplainer(model, background_data)
-
-         # print("Calculating SHAP values with DeepExplainer...")
-         shap_values = explainer.shap_values(instances_to_explain)
-         instances_to_explain_np = instances_to_explain.cpu().numpy()
-
-     elif explainer_type == 'kernel':
-         # --- 2. Use KernelExplainer ---
-         _LOGGER.warning(
-             "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
-         )
-
-         # Ensure data is np.ndarray
-         if isinstance(background_data, torch.Tensor):
-             background_data_np = background_data.cpu().numpy()
-         else:
-             background_data_np = background_data
-
-         if isinstance(instances_to_explain, torch.Tensor):
-             instances_to_explain_np = instances_to_explain.cpu().numpy()
-         else:
-             instances_to_explain_np = instances_to_explain
-
-         if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
-             _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
-             return
-
-         # Summarize background data
-         background_summary = shap.kmeans(background_data_np, 30)
-
-         def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
-             x_torch = torch.from_numpy(x_np).float().to(device)
-             with torch.no_grad():
-                 output = model(x_torch)
-             # Return as numpy array
-             return output.cpu().numpy()
-
-         explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
-         # print("Calculating SHAP values with KernelExplainer...")
-         shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
-         # instances_to_explain_np is already set
-
-     else:
-         _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
-         raise ValueError()
-
-     if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:  # type: ignore
-         # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
-         shap_values = shap_values.squeeze(-1)  # type: ignore
-
-     # --- 3. Plotting and Saving ---
-     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
-     plt.ioff()
-
-     # Convert instances to a DataFrame. A robust way to ensure SHAP correctly maps values to feature names.
-     if feature_names is None:
-         # Create generic names if none were provided
-         num_features = instances_to_explain_np.shape[1]
-         feature_names = [f'feature_{i}' for i in range(num_features)]
-
-     instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)
-
-     # Save Bar Plot
-     bar_path = save_dir_path / "shap_bar_plot.svg"
-     shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
-     ax = plt.gca()
-     ax.set_xlabel("SHAP Value Impact", labelpad=10)
-     plt.title("SHAP Feature Importance")
-     plt.tight_layout()
-     plt.savefig(bar_path)
-     _LOGGER.info(f"📊 SHAP bar plot saved as '{bar_path.name}'")
-     plt.close()
-
-     # Save Dot Plot
-     dot_path = save_dir_path / "shap_dot_plot.svg"
-     shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
-     ax = plt.gca()
-     ax.set_xlabel("SHAP Value Impact", labelpad=10)
-     if plt.gcf().axes and len(plt.gcf().axes) > 1:
-         cb = plt.gcf().axes[-1]
-         cb.set_ylabel("", size=1)
-     plt.title("SHAP Feature Importance")
-     plt.tight_layout()
-     plt.savefig(dot_path)
-     _LOGGER.info(f"📊 SHAP dot plot saved as '{dot_path.name}'")
-     plt.close()
-
-     # Save Summary Data to CSV
-     shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
-     summary_path = save_dir_path / shap_summary_filename
-
-     # Handle multi-class (list of arrays) vs. regression (single array)
-     if isinstance(shap_values, list):
-         mean_abs_shap = np.abs(np.stack(shap_values)).mean(axis=0).mean(axis=0)
-     else:
-         mean_abs_shap = np.abs(shap_values).mean(axis=0)
-
-     mean_abs_shap = mean_abs_shap.flatten()
-
-     summary_df = pd.DataFrame({
-         SHAPKeys.FEATURE_COLUMN: feature_names,
-         SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
-     }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
-
-     summary_df.to_csv(summary_path, index=False)
-
-     _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
-     plt.ion()
-
-
- def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
-     """
-     Aggregates attention weights and plots global feature importance.
-
-     The plot shows the mean attention for each feature as a bar, with the
-     standard deviation represented by error bars.
-
-     Args:
-         weights (List[torch.Tensor]): A list of attention weight tensors from each batch.
-         feature_names (List[str] | None): Names of the features for plot labeling.
-         save_dir (str | Path): Directory to save the plot and summary CSV.
-         top_n (int): The number of top features to display in the plot.
-     """
-     if not weights:
-         _LOGGER.error("Attention weights list is empty. Skipping importance plot.")
-         return
-
-     # --- Step 1: Aggregate data ---
-     # Concatenate the list of tensors into a single large tensor
-     full_weights_tensor = torch.cat(weights, dim=0)
-
-     # Calculate mean and std dev across the batch dimension (dim=0)
-     mean_weights = full_weights_tensor.mean(dim=0)
-     std_weights = full_weights_tensor.std(dim=0)
-
-     # --- Step 2: Create and save summary DataFrame ---
-     if feature_names is None:
-         feature_names = [f'feature_{i}' for i in range(len(mean_weights))]
-
-     summary_df = pd.DataFrame({
-         'feature': feature_names,
-         'mean_attention': mean_weights.numpy(),
-         'std_attention': std_weights.numpy()
-     }).sort_values('mean_attention', ascending=False)
-
-     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
-     summary_path = save_dir_path / "attention_summary.csv"
-     summary_df.to_csv(summary_path, index=False)
-     _LOGGER.info(f"📝 Attention summary data saved as '{summary_path.name}'")
-
-     # --- Step 3: Create and save the plot for top N features ---
-     plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)
-
-     plt.figure(figsize=(10, 8), dpi=DPI_value)
-
-     # Create horizontal bar plot with error bars
-     plt.barh(
-         y=plot_df['feature'],
-         width=plot_df['mean_attention'],
-         xerr=plot_df['std_attention'],
-         align='center',
-         alpha=0.7,
-         ecolor='grey',
-         capsize=3,
-         color='cornflowerblue'
-     )
-
-     plt.title('Top Features by Attention')
-     plt.xlabel('Average Attention Weight')
-     plt.ylabel('Feature')
-     plt.grid(axis='x', linestyle='--', alpha=0.6)
-     plt.tight_layout()
-
-     plot_path = save_dir_path / "attention_importance.svg"
-     plt.savefig(plot_path)
-     _LOGGER.info(f"📊 Attention importance plot saved as '{plot_path.name}'")
-     plt.close()
-
-
- def info():
-     _script_info(__all__)
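
For reference, the deleted module's public surface was the five functions in its __all__. A minimal usage sketch based only on the signatures visible above; in 19.13.0 these were re-exported through ml_tools/ML_evaluation.py, and 20.0.0 ships a ml_tools/ML_evaluation/ package in its place (items 47 to 52 in the file list), so the import path below is an assumption based on those filenames, and whether the 20.0.0 replacements keep these exact signatures is not shown in this diff:

import numpy as np

# Assumed import path; see hedging note above.
from ml_tools.ML_evaluation import regression_metrics, classification_metrics

out_dir = "eval_artifacts"  # hypothetical; both functions accept str or Path

# regression_metrics(y_true, y_pred, save_dir, config=None) writes
# regression_report.txt plus residual, true-vs-predicted, and histogram SVGs.
y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.1, 1.9, 3.2, 3.8])
regression_metrics(y_true, y_pred, out_dir)

# classification_metrics(save_dir, y_true, y_pred, y_prob=None, class_map=None,
# config=None) -- note save_dir is the *first* parameter here, unlike
# regression_metrics. class_map maps display names to integer class indices
# and drives the axis labels and per-class plot filenames.
classification_metrics(out_dir,
                       np.array([0, 1, 1, 0]),
                       np.array([0, 1, 0, 0]),
                       class_map={"negative": 0, "positive": 1})

Each 19.13.0 module also exposed info(), which printed its __all__ via _script_info(), as seen at the end of the deleted source.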