dragon-ml-toolbox 19.14.0__py3-none-any.whl → 20.0.0__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
Files changed (219)
  1. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
  2. dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
  3. ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
  4. ml_tools/ETL_cleaning/_basic_clean.py +351 -0
  5. ml_tools/ETL_cleaning/_clean_tools.py +128 -0
  6. ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
  7. ml_tools/ETL_cleaning/_imprimir.py +13 -0
  8. ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
  9. ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
  10. ml_tools/ETL_engineering/_imprimir.py +24 -0
  11. ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
  12. ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
  13. ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
  14. ml_tools/GUI_tools/_imprimir.py +12 -0
  15. ml_tools/IO_tools/_IO_loggers.py +235 -0
  16. ml_tools/IO_tools/_IO_save_load.py +151 -0
  17. ml_tools/IO_tools/_IO_utils.py +140 -0
  18. ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
  19. ml_tools/IO_tools/_imprimir.py +14 -0
  20. ml_tools/MICE/_MICE_imputation.py +132 -0
  21. ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
  22. ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
  23. ml_tools/MICE/_imprimir.py +11 -0
  24. ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
  25. ml_tools/ML_callbacks/_base.py +101 -0
  26. ml_tools/ML_callbacks/_checkpoint.py +232 -0
  27. ml_tools/ML_callbacks/_early_stop.py +208 -0
  28. ml_tools/ML_callbacks/_imprimir.py +12 -0
  29. ml_tools/ML_callbacks/_scheduler.py +197 -0
  30. ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
  31. ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
  32. ml_tools/ML_chain/_dragon_chain.py +140 -0
  33. ml_tools/ML_chain/_imprimir.py +11 -0
  34. ml_tools/ML_configuration/__init__.py +90 -0
  35. ml_tools/ML_configuration/_base_model_config.py +69 -0
  36. ml_tools/ML_configuration/_finalize.py +366 -0
  37. ml_tools/ML_configuration/_imprimir.py +47 -0
  38. ml_tools/ML_configuration/_metrics.py +593 -0
  39. ml_tools/ML_configuration/_models.py +206 -0
  40. ml_tools/ML_configuration/_training.py +124 -0
  41. ml_tools/ML_datasetmaster/__init__.py +28 -0
  42. ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
  43. ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
  44. ml_tools/ML_datasetmaster/_imprimir.py +15 -0
  45. ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
  46. ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
  47. ml_tools/ML_evaluation/__init__.py +53 -0
  48. ml_tools/ML_evaluation/_classification.py +629 -0
  49. ml_tools/ML_evaluation/_feature_importance.py +409 -0
  50. ml_tools/ML_evaluation/_imprimir.py +25 -0
  51. ml_tools/ML_evaluation/_loss.py +92 -0
  52. ml_tools/ML_evaluation/_regression.py +273 -0
  53. ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
  54. ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
  55. ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
  56. ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
  57. ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
  58. ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
  59. ml_tools/ML_finalize_handler/__init__.py +10 -0
  60. ml_tools/ML_finalize_handler/_imprimir.py +8 -0
  61. ml_tools/ML_inference/__init__.py +22 -0
  62. ml_tools/ML_inference/_base_inference.py +166 -0
  63. ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
  64. ml_tools/ML_inference/_dragon_inference.py +332 -0
  65. ml_tools/ML_inference/_imprimir.py +11 -0
  66. ml_tools/ML_inference/_multi_inference.py +180 -0
  67. ml_tools/ML_inference_sequence/__init__.py +10 -0
  68. ml_tools/ML_inference_sequence/_imprimir.py +8 -0
  69. ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
  70. ml_tools/ML_inference_vision/__init__.py +10 -0
  71. ml_tools/ML_inference_vision/_imprimir.py +8 -0
  72. ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
  73. ml_tools/ML_models/__init__.py +32 -0
  74. ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
  75. ml_tools/ML_models/_base_mlp_attention.py +198 -0
  76. ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
  77. ml_tools/ML_models/_dragon_tabular.py +248 -0
  78. ml_tools/ML_models/_imprimir.py +18 -0
  79. ml_tools/ML_models/_mlp_attention.py +134 -0
  80. ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
  81. ml_tools/ML_models_sequence/__init__.py +10 -0
  82. ml_tools/ML_models_sequence/_imprimir.py +8 -0
  83. ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
  84. ml_tools/ML_models_vision/__init__.py +29 -0
  85. ml_tools/ML_models_vision/_base_wrapper.py +254 -0
  86. ml_tools/ML_models_vision/_image_classification.py +182 -0
  87. ml_tools/ML_models_vision/_image_segmentation.py +108 -0
  88. ml_tools/ML_models_vision/_imprimir.py +16 -0
  89. ml_tools/ML_models_vision/_object_detection.py +135 -0
  90. ml_tools/ML_optimization/__init__.py +21 -0
  91. ml_tools/ML_optimization/_imprimir.py +13 -0
  92. ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
  93. ml_tools/ML_optimization/_single_dragon.py +203 -0
  94. ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
  95. ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
  96. ml_tools/ML_scaler/__init__.py +10 -0
  97. ml_tools/ML_scaler/_imprimir.py +8 -0
  98. ml_tools/ML_trainer/__init__.py +20 -0
  99. ml_tools/ML_trainer/_base_trainer.py +297 -0
  100. ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
  101. ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
  102. ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
  103. ml_tools/ML_trainer/_imprimir.py +10 -0
  104. ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
  105. ml_tools/ML_utilities/_artifact_finder.py +382 -0
  106. ml_tools/ML_utilities/_imprimir.py +16 -0
  107. ml_tools/ML_utilities/_inspection.py +325 -0
  108. ml_tools/ML_utilities/_train_tools.py +205 -0
  109. ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
  110. ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
  111. ml_tools/ML_vision_transformers/_imprimir.py +14 -0
  112. ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
  113. ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
  114. ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
  115. ml_tools/PSO_optimization/_imprimir.py +10 -0
  116. ml_tools/SQL/__init__.py +7 -0
  117. ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
  118. ml_tools/SQL/_imprimir.py +8 -0
  119. ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
  120. ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
  121. ml_tools/VIF/_imprimir.py +10 -0
  122. ml_tools/_core/__init__.py +7 -1
  123. ml_tools/_core/_logger.py +8 -18
  124. ml_tools/_core/_schema_load_ops.py +43 -0
  125. ml_tools/_core/_script_info.py +2 -2
  126. ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
  127. ml_tools/data_exploration/_analysis.py +214 -0
  128. ml_tools/data_exploration/_cleaning.py +566 -0
  129. ml_tools/data_exploration/_features.py +583 -0
  130. ml_tools/data_exploration/_imprimir.py +32 -0
  131. ml_tools/data_exploration/_plotting.py +487 -0
  132. ml_tools/data_exploration/_schema_ops.py +176 -0
  133. ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
  134. ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
  135. ml_tools/ensemble_evaluation/_imprimir.py +14 -0
  136. ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
  137. ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
  138. ml_tools/ensemble_inference/_imprimir.py +9 -0
  139. ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
  140. ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
  141. ml_tools/ensemble_learning/_imprimir.py +10 -0
  142. ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
  143. ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
  144. ml_tools/excel_handler/_imprimir.py +13 -0
  145. ml_tools/{keys.py → keys/__init__.py} +4 -1
  146. ml_tools/keys/_imprimir.py +11 -0
  147. ml_tools/{_core → keys}/_keys.py +2 -0
  148. ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
  149. ml_tools/math_utilities/_imprimir.py +11 -0
  150. ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
  151. ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
  152. ml_tools/optimization_tools/_imprimir.py +13 -0
  153. ml_tools/optimization_tools/_optimization_bounds.py +236 -0
  154. ml_tools/optimization_tools/_optimization_plots.py +218 -0
  155. ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
  156. ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
  157. ml_tools/path_manager/_imprimir.py +15 -0
  158. ml_tools/path_manager/_path_tools.py +346 -0
  159. ml_tools/plot_fonts/__init__.py +8 -0
  160. ml_tools/plot_fonts/_imprimir.py +8 -0
  161. ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
  162. ml_tools/schema/__init__.py +15 -0
  163. ml_tools/schema/_feature_schema.py +223 -0
  164. ml_tools/schema/_gui_schema.py +191 -0
  165. ml_tools/schema/_imprimir.py +10 -0
  166. ml_tools/{serde.py → serde/__init__.py} +4 -2
  167. ml_tools/serde/_imprimir.py +10 -0
  168. ml_tools/{_core → serde}/_serde.py +3 -8
  169. ml_tools/{utilities.py → utilities/__init__.py} +11 -6
  170. ml_tools/utilities/_imprimir.py +18 -0
  171. ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
  172. ml_tools/utilities/_utility_tools.py +192 -0
  173. dragon_ml_toolbox-19.14.0.dist-info/RECORD +0 -111
  174. ml_tools/ML_chaining_inference.py +0 -8
  175. ml_tools/ML_configuration.py +0 -86
  176. ml_tools/ML_configuration_pytab.py +0 -14
  177. ml_tools/ML_datasetmaster.py +0 -10
  178. ml_tools/ML_evaluation.py +0 -16
  179. ml_tools/ML_evaluation_multi.py +0 -12
  180. ml_tools/ML_finalize_handler.py +0 -8
  181. ml_tools/ML_inference.py +0 -12
  182. ml_tools/ML_models.py +0 -14
  183. ml_tools/ML_models_advanced.py +0 -14
  184. ml_tools/ML_models_pytab.py +0 -14
  185. ml_tools/ML_optimization.py +0 -14
  186. ml_tools/ML_optimization_pareto.py +0 -8
  187. ml_tools/ML_scaler.py +0 -8
  188. ml_tools/ML_sequence_datasetmaster.py +0 -8
  189. ml_tools/ML_sequence_evaluation.py +0 -10
  190. ml_tools/ML_sequence_inference.py +0 -8
  191. ml_tools/ML_sequence_models.py +0 -8
  192. ml_tools/ML_trainer.py +0 -12
  193. ml_tools/ML_vision_datasetmaster.py +0 -12
  194. ml_tools/ML_vision_evaluation.py +0 -10
  195. ml_tools/ML_vision_inference.py +0 -8
  196. ml_tools/ML_vision_models.py +0 -18
  197. ml_tools/SQL.py +0 -8
  198. ml_tools/_core/_ETL_cleaning.py +0 -694
  199. ml_tools/_core/_IO_tools.py +0 -498
  200. ml_tools/_core/_ML_callbacks.py +0 -702
  201. ml_tools/_core/_ML_configuration.py +0 -1332
  202. ml_tools/_core/_ML_configuration_pytab.py +0 -102
  203. ml_tools/_core/_ML_evaluation.py +0 -867
  204. ml_tools/_core/_ML_evaluation_multi.py +0 -544
  205. ml_tools/_core/_ML_inference.py +0 -646
  206. ml_tools/_core/_ML_models.py +0 -668
  207. ml_tools/_core/_ML_models_pytab.py +0 -693
  208. ml_tools/_core/_ML_trainer.py +0 -2323
  209. ml_tools/_core/_ML_utilities.py +0 -886
  210. ml_tools/_core/_ML_vision_models.py +0 -644
  211. ml_tools/_core/_data_exploration.py +0 -1909
  212. ml_tools/_core/_optimization_tools.py +0 -493
  213. ml_tools/_core/_schema.py +0 -359
  214. ml_tools/plot_fonts.py +0 -8
  215. ml_tools/schema.py +0 -12
  216. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
  217. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
  218. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  219. {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
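
The pattern behind most of these renames is a single structural change: each flat ml_tools/<name>.py module becomes a ml_tools/<name>/ package whose __init__.py takes over the public surface, the old ml_tools/_core/_*.py implementation files move into the matching packages as private _*.py submodules, and every package gains an _imprimir.py discovery helper. The three hunks below show the new ml_tools/ML_evaluation files in full. As an import sketch only: the paths are taken from the list above, while the assumption that each __init__.py re-exports these names is inferred from the renames, not confirmed by this listing.

# Import surface implied by the 20.0.0 layout (re-exports are assumed).
from ml_tools import ETL_cleaning                     # module in 19.14.0, package in 20.0.0
from ml_tools.ML_evaluation import plot_losses        # assumed re-export from _loss.py
from ml_tools.ML_evaluation import shap_summary_plot  # assumed re-export from _feature_importance.py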
ml_tools/ML_evaluation/_feature_importance.py
@@ -0,0 +1,409 @@
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+import torch
+import shap
+from pathlib import Path
+from typing import Union, Optional, Literal
+import warnings
+
+from ..path_manager import make_fullpath, sanitize_filename
+from .._core import get_logger
+from ..keys._keys import SHAPKeys, _EvaluationConfig
+
+
+_LOGGER = get_logger("Feature Importance")
+
+
+__all__ = [
+    "shap_summary_plot",
+    "plot_attention_importance",
+    "multi_target_shap_summary_plot",
+]
+
+
+DPI_value = _EvaluationConfig.DPI
+
+
+def shap_summary_plot(model,
+                      background_data: Union[torch.Tensor, np.ndarray],
+                      instances_to_explain: Union[torch.Tensor, np.ndarray],
+                      feature_names: Optional[list[str]],
+                      save_dir: Union[str, Path],
+                      device: torch.device = torch.device('cpu'),
+                      explainer_type: Literal['deep', 'kernel'] = 'kernel'):
+    """
+    Calculates SHAP values and saves summary plots and data.
+
+    Args:
+        model (nn.Module): The trained PyTorch model.
+        background_data (torch.Tensor): A sample of data for the explainer background.
+        instances_to_explain (torch.Tensor): The specific data instances to explain.
+        feature_names (list of str | None): Names of the features for plot labeling.
+        save_dir (str | Path): Directory to save SHAP artifacts.
+        device (torch.device): The torch device for SHAP calculations.
+        explainer_type (Literal['deep', 'kernel']): The explainer to use.
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient for
+              PyTorch models.
+            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
+              slow and memory-intensive.
+    """
+
+    _LOGGER.info(f"📊 Running SHAP Value Explanation Using {explainer_type.upper()} Explainer")
+
+    model.eval()
+    # model.cpu()  # Run explanations on CPU
+
+    shap_values = None
+    instances_to_explain_np = None
+
+    if explainer_type == 'deep':
+        # --- 1. Use DeepExplainer ---
+
+        # Ensure data is torch.Tensor
+        if isinstance(background_data, np.ndarray):
+            background_data = torch.from_numpy(background_data).float()
+        if isinstance(instances_to_explain, np.ndarray):
+            instances_to_explain = torch.from_numpy(instances_to_explain).float()
+
+        if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        background_data = background_data.to(device)
+        instances_to_explain = instances_to_explain.to(device)
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=UserWarning)
+            explainer = shap.DeepExplainer(model, background_data)
+
+        # print("Calculating SHAP values with DeepExplainer...")
+        shap_values = explainer.shap_values(instances_to_explain)
+        instances_to_explain_np = instances_to_explain.cpu().numpy()
+
+    elif explainer_type == 'kernel':
+        # --- 2. Use KernelExplainer ---
+        _LOGGER.warning(
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
+        )
+
+        # Ensure data is np.ndarray
+        if isinstance(background_data, torch.Tensor):
+            background_data_np = background_data.cpu().numpy()
+        else:
+            background_data_np = background_data
+
+        if isinstance(instances_to_explain, torch.Tensor):
+            instances_to_explain_np = instances_to_explain.cpu().numpy()
+        else:
+            instances_to_explain_np = instances_to_explain
+
+        if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        # Summarize background data
+        background_summary = shap.kmeans(background_data_np, 30)
+
+        def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
+            x_torch = torch.from_numpy(x_np).float().to(device)
+            with torch.no_grad():
+                output = model(x_torch)
+            # Return as numpy array
+            return output.cpu().numpy()
+
+        explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
+        # print("Calculating SHAP values with KernelExplainer...")
+        shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
+        # instances_to_explain_np is already set
+
+    else:
+        _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
+        raise ValueError("Invalid explainer_type")
+
+    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:  # type: ignore
+        # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
+        shap_values = shap_values.squeeze(-1)  # type: ignore
+
+    # --- 3. Plotting and Saving ---
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    plt.ioff()
+
+    # Convert instances to a DataFrame; a robust way to ensure SHAP correctly maps values to feature names.
+    if feature_names is None:
+        # Create generic names if none were provided
+        num_features = instances_to_explain_np.shape[1]
+        feature_names = [f'feature_{i}' for i in range(num_features)]
+
+    instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)
+
+    # Save Bar Plot
+    bar_path = save_dir_path / "shap_bar_plot.svg"
+    shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
+    ax = plt.gca()
+    ax.set_xlabel("SHAP Value Impact", labelpad=10)
+    plt.title("SHAP Feature Importance")
+    plt.tight_layout()
+    plt.savefig(bar_path)
+    _LOGGER.info(f"📊 SHAP bar plot saved as '{bar_path.name}'")
+    plt.close()
+
+    # Save Dot Plot
+    dot_path = save_dir_path / "shap_dot_plot.svg"
+    shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
+    ax = plt.gca()
+    ax.set_xlabel("SHAP Value Impact", labelpad=10)
+    if plt.gcf().axes and len(plt.gcf().axes) > 1:
+        cb = plt.gcf().axes[-1]
+        cb.set_ylabel("", size=1)
+    plt.title("SHAP Feature Importance")
+    plt.tight_layout()
+    plt.savefig(dot_path)
+    _LOGGER.info(f"📊 SHAP dot plot saved as '{dot_path.name}'")
+    plt.close()
+
+    # Save Summary Data to CSV
+    shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
+    summary_path = save_dir_path / shap_summary_filename
+
+    # Handle multi-class (list of arrays) vs. regression (single array)
+    if isinstance(shap_values, list):
+        mean_abs_shap = np.abs(np.stack(shap_values)).mean(axis=0).mean(axis=0)
+    else:
+        mean_abs_shap = np.abs(shap_values).mean(axis=0)
+
+    mean_abs_shap = mean_abs_shap.flatten()
+
+    summary_df = pd.DataFrame({
+        SHAPKeys.FEATURE_COLUMN: feature_names,
+        SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
+    }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
+
+    summary_df.to_csv(summary_path, index=False)
+
+    _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
+    plt.ion()
+
+
+def plot_attention_importance(weights: list[torch.Tensor], feature_names: Optional[list[str]], save_dir: Union[str, Path], top_n: int = 10):
+    """
+    Aggregates attention weights and plots global feature importance.
+
+    The plot shows the mean attention for each feature as a bar, with the
+    standard deviation represented by error bars.
+
+    Args:
+        weights (List[torch.Tensor]): A list of attention weight tensors from each batch.
+        feature_names (List[str] | None): Names of the features for plot labeling.
+        save_dir (str | Path): Directory to save the plot and summary CSV.
+        top_n (int): The number of top features to display in the plot.
+    """
+    if not weights:
+        _LOGGER.error("Attention weights list is empty. Skipping importance plot.")
+        return
+
+    # --- Step 1: Aggregate data ---
+    # Concatenate the list of tensors into a single large tensor
+    full_weights_tensor = torch.cat(weights, dim=0)
+
+    # Calculate mean and std dev across the batch dimension (dim=0)
+    mean_weights = full_weights_tensor.mean(dim=0)
+    std_weights = full_weights_tensor.std(dim=0)
+
+    # --- Step 2: Create and save summary DataFrame ---
+    if feature_names is None:
+        feature_names = [f'feature_{i}' for i in range(len(mean_weights))]
+
+    summary_df = pd.DataFrame({
+        'feature': feature_names,
+        'mean_attention': mean_weights.numpy(),
+        'std_attention': std_weights.numpy()
+    }).sort_values('mean_attention', ascending=False)
+
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    summary_path = save_dir_path / "attention_summary.csv"
+    summary_df.to_csv(summary_path, index=False)
+    _LOGGER.info(f"📝 Attention summary data saved as '{summary_path.name}'")
+
+    # --- Step 3: Create and save the plot for top N features ---
+    plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)
+
+    plt.figure(figsize=(10, 8), dpi=DPI_value)
+
+    # Create horizontal bar plot with error bars
+    plt.barh(
+        y=plot_df['feature'],
+        width=plot_df['mean_attention'],
+        xerr=plot_df['std_attention'],
+        align='center',
+        alpha=0.7,
+        ecolor='grey',
+        capsize=3,
+        color='cornflowerblue'
+    )
+
+    plt.title('Top Features by Attention')
+    plt.xlabel('Average Attention Weight')
+    plt.ylabel('Feature')
+    plt.grid(axis='x', linestyle='--', alpha=0.6)
+    plt.tight_layout()
+
+    plot_path = save_dir_path / "attention_importance.svg"
+    plt.savefig(plot_path)
+    _LOGGER.info(f"📊 Attention importance plot saved as '{plot_path.name}'")
+    plt.close()
+
+
+def multi_target_shap_summary_plot(
+    model: torch.nn.Module,
+    background_data: Union[torch.Tensor, np.ndarray],
+    instances_to_explain: Union[torch.Tensor, np.ndarray],
+    feature_names: list[str],
+    target_names: list[str],
+    save_dir: Union[str, Path],
+    device: torch.device = torch.device('cpu'),
+    explainer_type: Literal['deep', 'kernel'] = 'kernel'
+    ):
+    """
+    DEPRECATED
+
+    Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
+
+    Args:
+        model (torch.nn.Module): The trained PyTorch model.
+        background_data (torch.Tensor | np.ndarray): A sample of data for the explainer background.
+        instances_to_explain (torch.Tensor | np.ndarray): The specific data instances to explain.
+        feature_names (List[str]): Names of the features for plot labeling.
+        target_names (List[str]): Names of the output targets.
+        save_dir (str | Path): Directory to save SHAP artifacts.
+        device (torch.device): The torch device for SHAP calculations.
+        explainer_type (Literal['deep', 'kernel']): The explainer to use.
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient.
+            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
+    """
+    _LOGGER.warning("This function is deprecated and may be removed in future versions. Use the Captum module instead.")
+
+    _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
+    model.eval()
+    # model.cpu()
+
+    shap_values_list = None
+    instances_to_explain_np = None
+
+    if explainer_type == 'deep':
+        # --- 1. Use DeepExplainer ---
+
+        # Ensure data is torch.Tensor
+        if isinstance(background_data, np.ndarray):
+            background_data = torch.from_numpy(background_data).float()
+        if isinstance(instances_to_explain, np.ndarray):
+            instances_to_explain = torch.from_numpy(instances_to_explain).float()
+
+        if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        background_data = background_data.to(device)
+        instances_to_explain = instances_to_explain.to(device)
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=UserWarning)
+            explainer = shap.DeepExplainer(model, background_data)
+
+        # print("Calculating SHAP values with DeepExplainer...")
+        # DeepExplainer returns a list of arrays for multi-output models
+        shap_values_list = explainer.shap_values(instances_to_explain)
+        instances_to_explain_np = instances_to_explain.cpu().numpy()
+
+    elif explainer_type == 'kernel':
+        # --- 2. Use KernelExplainer ---
+        _LOGGER.warning(
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
+        )
+
+        # Convert all data to numpy
+        background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
+        instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
+
+        if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        background_summary = shap.kmeans(background_data_np, 30)
+
+        def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
+            x_torch = torch.from_numpy(x_np).float().to(device)
+            with torch.no_grad():
+                output = model(x_torch)
+            return output.cpu().numpy()  # Return full multi-output array
+
+        explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
+        # print("Calculating SHAP values with KernelExplainer...")
+        # KernelExplainer also returns a list of arrays for multi-output models
+        shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
+        # instances_to_explain_np is already set
+
+    else:
+        _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
+        raise ValueError("Invalid explainer_type")
+
+    # --- 3. Plotting and Saving (Common Logic) ---
+
+    if shap_values_list is None or instances_to_explain_np is None:
+        _LOGGER.error("SHAP value calculation failed. Aborting plotting.")
+        return
+
+    # Ensure number of SHAP value arrays matches number of target names
+    if len(shap_values_list) != len(target_names):
+        _LOGGER.error(
+            f"SHAP explanation mismatch: Model produced {len(shap_values_list)} "
+            f"outputs, but {len(target_names)} target_names were provided."
+        )
+        return
+
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    plt.ioff()
+
+    # Iterate through each target's SHAP values and generate plots.
+    for i, target_name in enumerate(target_names):
+        print(f" -> Generating SHAP plots for target: '{target_name}'")
+        shap_values_for_target = shap_values_list[i]
+        sanitized_target_name = sanitize_filename(target_name)
+
+        # Save Bar Plot for the target
+        shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
+        plt.title(f"SHAP Feature Importance for '{target_name}'")
+        plt.tight_layout()
+        bar_path = save_dir_path / f"shap_bar_plot_{sanitized_target_name}.svg"
+        plt.savefig(bar_path)
+        plt.close()
+
+        # Save Dot Plot for the target
+        shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
+        plt.title(f"SHAP Feature Importance for '{target_name}'")
+        if plt.gcf().axes and len(plt.gcf().axes) > 1:
+            cb = plt.gcf().axes[-1]
+            cb.set_ylabel("", size=1)
+        plt.tight_layout()
+        dot_path = save_dir_path / f"shap_dot_plot_{sanitized_target_name}.svg"
+        plt.savefig(dot_path)
+        plt.close()
+
+        # --- Save Summary Data to CSV for this target ---
+        shap_summary_filename = f"{SHAPKeys.SAVENAME}_{sanitized_target_name}.csv"
+        summary_path = save_dir_path / shap_summary_filename
+
+        # For a specific target, shap_values_for_target is just a 2D array
+        mean_abs_shap = np.abs(shap_values_for_target).mean(axis=0).flatten()
+
+        summary_df = pd.DataFrame({
+            SHAPKeys.FEATURE_COLUMN: feature_names,
+            SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
+        }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
+
+        summary_df.to_csv(summary_path, index=False)
+
+    plt.ion()
+    _LOGGER.info(f"All SHAP plots saved to '{save_dir_path.name}'")
+
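
As a usage illustration for the file above, here is a minimal sketch of calling shap_summary_plot on a toy regression MLP. Only the function signature comes from the diff; the model, data, and save directory are hypothetical, and the import assumes the package __init__.py re-exports the function.

# Hypothetical model and data; only shap_summary_plot's signature is from the diff.
import numpy as np
import torch
import torch.nn as nn

from ml_tools.ML_evaluation import shap_summary_plot  # assumed re-export

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
rng = np.random.default_rng(0)
background = rng.normal(size=(100, 8)).astype(np.float32)  # background sample for the explainer
instances = rng.normal(size=(20, 8)).astype(np.float32)    # rows to explain

shap_summary_plot(
    model=model,
    background_data=background,
    instances_to_explain=instances,
    feature_names=[f"x{i}" for i in range(8)],
    save_dir="shap_artifacts",  # receives shap_bar_plot.svg, shap_dot_plot.svg, and a CSV
    explainer_type="deep",      # 'kernel' is model-agnostic but far slower
)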
ml_tools/ML_evaluation/_imprimir.py
@@ -0,0 +1,25 @@
+from .._core import _imprimir_disponibles
+
+_GRUPOS = [
+    # regression
+    "regression_metrics",
+    "multi_target_regression_metrics",
+    # classification
+    "classification_metrics",
+    "multi_label_classification_metrics",
+    # loss
+    "plot_losses",
+    # feature importance
+    "shap_summary_plot",
+    "multi_target_shap_summary_plot",
+    "plot_attention_importance",
+    # sequence
+    "sequence_to_value_metrics",
+    "sequence_to_sequence_metrics",
+    # vision
+    "segmentation_metrics",
+    "object_detection_metrics",
+]
+
+def info():
+    _imprimir_disponibles(_GRUPOS)
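
Each restructured package ships an _imprimir.py like this one (imprimir is Spanish for "to print"); its info() presumably prints the groups listed in _GRUPOS. A sketch of the intended discovery flow, assuming info is re-exported by the package __init__.py:

# Assumes ml_tools.ML_evaluation re-exports info from _imprimir.py.
from ml_tools import ML_evaluation

ML_evaluation.info()  # expected to list the names in _GRUPOS above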
ml_tools/ML_evaluation/_loss.py
@@ -0,0 +1,92 @@
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+from pathlib import Path
+from typing import Union
+
+from ..path_manager import make_fullpath
+from .._core import get_logger
+from ..keys._keys import PyTorchLogKeys, _EvaluationConfig
+
+
+_LOGGER = get_logger("Loss Plot")
+
+
+__all__ = [
+    "plot_losses",
+]
+
+
+DPI_value = _EvaluationConfig.DPI
+
+
+def plot_losses(history: dict, save_dir: Union[str, Path]):
+    """
+    Plots training & validation loss curves from a history object.
+    Also plots the learning rate if available in the history.
+
+    Args:
+        history (dict): A dictionary containing 'train_loss' and 'val_loss'.
+        save_dir (str | Path): Directory to save the plot image.
+    """
+    train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
+    val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
+    lr_history = history.get(PyTorchLogKeys.LEARNING_RATE, [])
+
+    if not train_loss and not val_loss:
+        _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
+        return
+
+    fig, ax = plt.subplots(figsize=_EvaluationConfig.LOSS_PLOT_SIZE, dpi=DPI_value)
+
+    # --- Plot Losses (Left Y-axis) ---
+    line_handles = []  # To store line objects for the legend
+
+    # Plot training loss only if data for it exists
+    if train_loss:
+        epochs = range(1, len(train_loss) + 1)
+        line1, = ax.plot(epochs, train_loss, 'o-', label='Training Loss', color='tab:blue')
+        line_handles.append(line1)
+
+    # Plot validation loss only if data for it exists
+    if val_loss:
+        epochs = range(1, len(val_loss) + 1)
+        line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
+        line_handles.append(line2)
+
+    ax.set_title('Training and Validation Loss', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE + 2, pad=_EvaluationConfig.LABEL_PADDING)
+    ax.set_xlabel('Epochs', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
+    ax.set_ylabel('Loss', color='tab:blue', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
+    ax.tick_params(axis='y', labelcolor='tab:blue', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
+    ax.tick_params(axis='x', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
+    ax.grid(True, linestyle='--')
+
+    # --- Plot Learning Rate (Right Y-axis) ---
+    if lr_history:
+        ax2 = ax.twinx()  # Create a second y-axis
+        epochs = range(1, len(lr_history) + 1)
+        line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
+        line_handles.append(line3)
+
+        ax2.set_ylabel('Learning Rate', color='g', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
+        ax2.tick_params(axis='y', labelcolor='g', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
+        # Use scientific notation if the LR is very small
+        ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
+        # increase the size of the scientific notation
+        ax2.yaxis.get_offset_text().set_fontsize(_EvaluationConfig.LOSS_PLOT_TICK_SIZE - 2)
+        # remove grid from second y-axis
+        ax2.grid(False)
+
+    # Combine legends from both axes
+    ax.legend(handles=line_handles, loc='best', fontsize=_EvaluationConfig.LOSS_PLOT_LEGEND_SIZE)
+
+    # ax.grid(True)
+    plt.tight_layout()
+
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    save_path = save_dir_path / "loss_plot.svg"
+    plt.savefig(save_path)
+    _LOGGER.info(f"📉 Loss plot saved as '{save_path.name}'")
+
+    plt.close(fig)
+
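
Closing the ML_evaluation trio, a minimal sketch of feeding plot_losses a training history. The string values behind PyTorchLogKeys are not shown in this diff, so the dict is built from the same constants the module reads; the loss and learning-rate numbers are invented, and the plot_losses import assumes a package-level re-export.

# History numbers are invented; the keys are the constants plot_losses itself uses.
from ml_tools.keys._keys import PyTorchLogKeys
from ml_tools.ML_evaluation import plot_losses  # assumed re-export

history = {
    PyTorchLogKeys.TRAIN_LOSS: [0.92, 0.61, 0.48, 0.41],
    PyTorchLogKeys.VAL_LOSS: [0.95, 0.70, 0.58, 0.55],
    PyTorchLogKeys.LEARNING_RATE: [1e-3, 1e-3, 5e-4, 5e-4],
}
plot_losses(history, save_dir="training_artifacts")  # writes loss_plot.svg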