dragon-ml-toolbox 19.14.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff compares the contents of publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.14.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1909
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
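
The headline change in 20.0.0 is structural: every flat `ml_tools/*.py` module becomes a package whose `__init__.py` re-exports its public API, the old `_core/_*.py` implementation modules move into those packages, and each package gains an `_imprimir.py` helper. A minimal, hedged import sketch of what that means in practice; the exact re-exported names are assumptions based on the `__all__` lists visible in the hunks below, so verify against the installed package:

# Hypothetical import sketch for the 19.x -> 20.0.0 restructure (assumptions noted).

# 19.14.0: flat module, e.g. ml_tools/ML_evaluation.py
# from ml_tools.ML_evaluation import shap_summary_plot

# 20.0.0: package with a re-exporting __init__.py, e.g. ml_tools/ML_evaluation/
from ml_tools.ML_evaluation import shap_summary_plot, plot_losses  # assumed re-exports

# The import path stays the same; only the layout behind it changes. Each package
# also ships an _imprimir.py helper whose info() lists the names it provides
# (see the _GRUPOS list further below).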
@@ -0,0 +1,409 @@ ml_tools/ML_evaluation/_feature_importance.py (new file)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import shap
from pathlib import Path
from typing import Union, Optional, Literal
import warnings

from ..path_manager import make_fullpath, sanitize_filename
from .._core import get_logger
from ..keys._keys import SHAPKeys, _EvaluationConfig


_LOGGER = get_logger("Feature Importance")


__all__ = [
    "shap_summary_plot",
    "plot_attention_importance",
    "multi_target_shap_summary_plot",
]


DPI_value = _EvaluationConfig.DPI


def shap_summary_plot(model,
                      background_data: Union[torch.Tensor, np.ndarray],
                      instances_to_explain: Union[torch.Tensor, np.ndarray],
                      feature_names: Optional[list[str]],
                      save_dir: Union[str, Path],
                      device: torch.device = torch.device('cpu'),
                      explainer_type: Literal['deep', 'kernel'] = 'kernel'):
    """
    Calculates SHAP values and saves summary plots and data.

    Args:
        model (nn.Module): The trained PyTorch model.
        background_data (torch.Tensor): A sample of data for the explainer background.
        instances_to_explain (torch.Tensor): The specific data instances to explain.
        feature_names (list[str] | None): Names of the features for plot labeling.
        save_dir (str | Path): Directory to save SHAP artifacts.
        device (torch.device): The torch device for SHAP calculations.
        explainer_type (Literal['deep', 'kernel']): The explainer to use.
            - 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but extremely slow and memory-intensive.
    """
    _LOGGER.info(f"📊 Running SHAP Value Explanation Using {explainer_type.upper()} Explainer")

    model.eval()
    # model.cpu()  # Run explanations on CPU

    shap_values = None
    instances_to_explain_np = None

    if explainer_type == 'deep':
        # --- 1. Use DeepExplainer ---

        # Ensure data is torch.Tensor
        if isinstance(background_data, np.ndarray):
            background_data = torch.from_numpy(background_data).float()
        if isinstance(instances_to_explain, np.ndarray):
            instances_to_explain = torch.from_numpy(instances_to_explain).float()

        if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
            return

        background_data = background_data.to(device)
        instances_to_explain = instances_to_explain.to(device)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            explainer = shap.DeepExplainer(model, background_data)

        shap_values = explainer.shap_values(instances_to_explain)
        instances_to_explain_np = instances_to_explain.cpu().numpy()

    elif explainer_type == 'kernel':
        # --- 2. Use KernelExplainer ---
        _LOGGER.warning(
            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
        )

        # Ensure data is np.ndarray
        if isinstance(background_data, torch.Tensor):
            background_data_np = background_data.cpu().numpy()
        else:
            background_data_np = background_data

        if isinstance(instances_to_explain, torch.Tensor):
            instances_to_explain_np = instances_to_explain.cpu().numpy()
        else:
            instances_to_explain_np = instances_to_explain

        if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
            return

        # Summarize background data
        background_summary = shap.kmeans(background_data_np, 30)

        def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
            x_torch = torch.from_numpy(x_np).float().to(device)
            with torch.no_grad():
                output = model(x_torch)
            # Return as numpy array
            return output.cpu().numpy()

        explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
        shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
        # instances_to_explain_np is already set

    else:
        _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
        raise ValueError(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")

    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:  # type: ignore
        # Squeeze SHAP values from (N, F, 1) to (N, F) for the regression plot.
        shap_values = shap_values.squeeze(-1)  # type: ignore

    # --- 3. Plotting and Saving ---
    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    plt.ioff()

    # Convert instances to a DataFrame; a robust way to ensure SHAP correctly maps values to feature names.
    if feature_names is None:
        # Create generic names if none were provided
        num_features = instances_to_explain_np.shape[1]
        feature_names = [f'feature_{i}' for i in range(num_features)]

    instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)

    # Save Bar Plot
    bar_path = save_dir_path / "shap_bar_plot.svg"
    shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
    ax = plt.gca()
    ax.set_xlabel("SHAP Value Impact", labelpad=10)
    plt.title("SHAP Feature Importance")
    plt.tight_layout()
    plt.savefig(bar_path)
    _LOGGER.info(f"📊 SHAP bar plot saved as '{bar_path.name}'")
    plt.close()

    # Save Dot Plot
    dot_path = save_dir_path / "shap_dot_plot.svg"
    shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
    ax = plt.gca()
    ax.set_xlabel("SHAP Value Impact", labelpad=10)
    if plt.gcf().axes and len(plt.gcf().axes) > 1:
        cb = plt.gcf().axes[-1]
        cb.set_ylabel("", size=1)
    plt.title("SHAP Feature Importance")
    plt.tight_layout()
    plt.savefig(dot_path)
    _LOGGER.info(f"📊 SHAP dot plot saved as '{dot_path.name}'")
    plt.close()

    # Save Summary Data to CSV
    shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
    summary_path = save_dir_path / shap_summary_filename

    # Handle multi-class (list of arrays) vs. regression (single array)
    if isinstance(shap_values, list):
        mean_abs_shap = np.abs(np.stack(shap_values)).mean(axis=0).mean(axis=0)
    else:
        mean_abs_shap = np.abs(shap_values).mean(axis=0)

    mean_abs_shap = mean_abs_shap.flatten()

    summary_df = pd.DataFrame({
        SHAPKeys.FEATURE_COLUMN: feature_names,
        SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
    }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)

    summary_df.to_csv(summary_path, index=False)

    _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
    plt.ion()


def plot_attention_importance(weights: list[torch.Tensor], feature_names: Optional[list[str]], save_dir: Union[str, Path], top_n: int = 10):
    """
    Aggregates attention weights and plots global feature importance.

    The plot shows the mean attention for each feature as a bar, with the
    standard deviation represented by error bars.

    Args:
        weights (list[torch.Tensor]): A list of attention weight tensors from each batch.
        feature_names (list[str] | None): Names of the features for plot labeling.
        save_dir (str | Path): Directory to save the plot and summary CSV.
        top_n (int): The number of top features to display in the plot.
    """
    if not weights:
        _LOGGER.error("Attention weights list is empty. Skipping importance plot.")
        return

    # --- Step 1: Aggregate data ---
    # Concatenate the list of tensors into a single large tensor
    full_weights_tensor = torch.cat(weights, dim=0)

    # Calculate mean and std dev across the batch dimension (dim=0)
    mean_weights = full_weights_tensor.mean(dim=0)
    std_weights = full_weights_tensor.std(dim=0)

    # --- Step 2: Create and save summary DataFrame ---
    if feature_names is None:
        feature_names = [f'feature_{i}' for i in range(len(mean_weights))]

    summary_df = pd.DataFrame({
        'feature': feature_names,
        'mean_attention': mean_weights.numpy(),
        'std_attention': std_weights.numpy()
    }).sort_values('mean_attention', ascending=False)

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    summary_path = save_dir_path / "attention_summary.csv"
    summary_df.to_csv(summary_path, index=False)
    _LOGGER.info(f"📝 Attention summary data saved as '{summary_path.name}'")

    # --- Step 3: Create and save the plot for top N features ---
    plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)

    plt.figure(figsize=(10, 8), dpi=DPI_value)

    # Create horizontal bar plot with error bars
    plt.barh(
        y=plot_df['feature'],
        width=plot_df['mean_attention'],
        xerr=plot_df['std_attention'],
        align='center',
        alpha=0.7,
        ecolor='grey',
        capsize=3,
        color='cornflowerblue'
    )

    plt.title('Top Features by Attention')
    plt.xlabel('Average Attention Weight')
    plt.ylabel('Feature')
    plt.grid(axis='x', linestyle='--', alpha=0.6)
    plt.tight_layout()

    plot_path = save_dir_path / "attention_importance.svg"
    plt.savefig(plot_path)
    _LOGGER.info(f"📊 Attention importance plot saved as '{plot_path.name}'")
    plt.close()


def multi_target_shap_summary_plot(
    model: torch.nn.Module,
    background_data: Union[torch.Tensor, np.ndarray],
    instances_to_explain: Union[torch.Tensor, np.ndarray],
    feature_names: list[str],
    target_names: list[str],
    save_dir: Union[str, Path],
    device: torch.device = torch.device('cpu'),
    explainer_type: Literal['deep', 'kernel'] = 'kernel'
):
    """
    DEPRECATED

    Calculates SHAP values for a multi-target model and saves summary plots and data for each target.

    Args:
        model (torch.nn.Module): The trained PyTorch model.
        background_data (torch.Tensor | np.ndarray): A sample of data for the explainer background.
        instances_to_explain (torch.Tensor | np.ndarray): The specific data instances to explain.
        feature_names (list[str]): Names of the features for plot labeling.
        target_names (list[str]): Names of the output targets.
        save_dir (str | Path): Directory to save SHAP artifacts.
        device (torch.device): The torch device for SHAP calculations.
        explainer_type (Literal['deep', 'kernel']): The explainer to use.
            - 'deep': Uses shap.DeepExplainer. Fast and efficient.
            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
    """
    _LOGGER.warning("This function is deprecated and may be removed in future versions. Use the Captum module instead.")

    _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()} Explainer) ---")
    model.eval()
    # model.cpu()

    shap_values_list = None
    instances_to_explain_np = None

    if explainer_type == 'deep':
        # --- 1. Use DeepExplainer ---

        # Ensure data is torch.Tensor
        if isinstance(background_data, np.ndarray):
            background_data = torch.from_numpy(background_data).float()
        if isinstance(instances_to_explain, np.ndarray):
            instances_to_explain = torch.from_numpy(instances_to_explain).float()

        if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
            return

        background_data = background_data.to(device)
        instances_to_explain = instances_to_explain.to(device)

        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            explainer = shap.DeepExplainer(model, background_data)

        # DeepExplainer returns a list of arrays for multi-output models
        shap_values_list = explainer.shap_values(instances_to_explain)
        instances_to_explain_np = instances_to_explain.cpu().numpy()

    elif explainer_type == 'kernel':
        # --- 2. Use KernelExplainer ---
        _LOGGER.warning(
            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
        )

        # Convert all data to numpy (via .cpu() in case the tensors live on another device)
        background_data_np = background_data.cpu().numpy() if isinstance(background_data, torch.Tensor) else background_data
        instances_to_explain_np = instances_to_explain.cpu().numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain

        if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
            return

        background_summary = shap.kmeans(background_data_np, 30)

        def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
            x_torch = torch.from_numpy(x_np).float().to(device)
            with torch.no_grad():
                output = model(x_torch)
            return output.cpu().numpy()  # Return full multi-output array

        explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
        # KernelExplainer also returns a list of arrays for multi-output models
        shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
        # instances_to_explain_np is already set

    else:
        _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
        raise ValueError("Invalid explainer_type")

    # --- 3. Plotting and Saving (Common Logic) ---

    if shap_values_list is None or instances_to_explain_np is None:
        _LOGGER.error("SHAP value calculation failed. Aborting plotting.")
        return

    # Ensure number of SHAP value arrays matches number of target names
    if len(shap_values_list) != len(target_names):
        _LOGGER.error(
            f"SHAP explanation mismatch: Model produced {len(shap_values_list)} "
            f"outputs, but {len(target_names)} target_names were provided."
        )
        return

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    plt.ioff()

    # Iterate through each target's SHAP values and generate plots.
    for i, target_name in enumerate(target_names):
        print(f"  -> Generating SHAP plots for target: '{target_name}'")
        shap_values_for_target = shap_values_list[i]
        sanitized_target_name = sanitize_filename(target_name)

        # Save Bar Plot for the target
        shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
        plt.title(f"SHAP Feature Importance for '{target_name}'")
        plt.tight_layout()
        bar_path = save_dir_path / f"shap_bar_plot_{sanitized_target_name}.svg"
        plt.savefig(bar_path)
        plt.close()

        # Save Dot Plot for the target
        shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
        plt.title(f"SHAP Feature Importance for '{target_name}'")
        if plt.gcf().axes and len(plt.gcf().axes) > 1:
            cb = plt.gcf().axes[-1]
            cb.set_ylabel("", size=1)
        plt.tight_layout()
        dot_path = save_dir_path / f"shap_dot_plot_{sanitized_target_name}.svg"
        plt.savefig(dot_path)
        plt.close()

        # --- Save Summary Data to CSV for this target ---
        shap_summary_filename = f"{SHAPKeys.SAVENAME}_{sanitized_target_name}.csv"
        summary_path = save_dir_path / shap_summary_filename

        # For a specific target, shap_values_for_target is just a 2D array
        mean_abs_shap = np.abs(shap_values_for_target).mean(axis=0).flatten()

        summary_df = pd.DataFrame({
            SHAPKeys.FEATURE_COLUMN: feature_names,
            SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
        }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)

        summary_df.to_csv(summary_path, index=False)

    plt.ion()
    _LOGGER.info(f"All SHAP plots saved to '{save_dir_path.name}'")
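
For orientation, a minimal usage sketch of the two non-deprecated entry points above. Only the function signatures come from the file itself; the model, tensor shapes, and output directories are illustrative assumptions, and the package-level re-export is assumed from the `__all__` list:

import torch
import torch.nn as nn
from ml_tools.ML_evaluation import shap_summary_plot, plot_attention_importance  # assumed re-exports

# Toy regression model and data (assumptions for illustration).
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
background = torch.randn(100, 8)   # explainer background sample
instances = torch.randn(20, 8)     # rows to explain
names = [f"x{i}" for i in range(8)]

# 'deep' is the fast path for PyTorch models; 'kernel' (the default) is
# model-agnostic but markedly slower, per the warning in the source.
shap_summary_plot(
    model,
    background_data=background,
    instances_to_explain=instances,
    feature_names=names,
    save_dir="shap_artifacts",
    explainer_type="deep",
)

# plot_attention_importance expects one attention-weight tensor per batch,
# each of shape (batch, n_features); random weights stand in here.
weights = [torch.rand(32, 8) for _ in range(5)]
plot_attention_importance(weights, feature_names=names, save_dir="attention_artifacts", top_n=5)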
@@ -0,0 +1,25 @@ ml_tools/ML_evaluation/_imprimir.py (new file)

from .._core import _imprimir_disponibles

_GRUPOS = [
    # regression
    "regression_metrics",
    "multi_target_regression_metrics",
    # classification
    "classification_metrics",
    "multi_label_classification_metrics",
    # loss
    "plot_losses",
    # feature importance
    "shap_summary_plot",
    "multi_target_shap_summary_plot",
    "plot_attention_importance",
    # sequence
    "sequence_to_value_metrics",
    "sequence_to_sequence_metrics",
    # vision
    "segmentation_metrics",
    "object_detection_metrics",
]

def info():
    _imprimir_disponibles(_GRUPOS)
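
Each restructured package carries an `_imprimir.py` like this one; `info()` simply prints the names in `_GRUPOS` through the shared `_imprimir_disponibles` helper. A hedged usage sketch (whether `info` is re-exported from the package `__init__` is an assumption not confirmed by this listing):

# Assumption: info() is re-exported by ml_tools/ML_evaluation/__init__.py.
from ml_tools.ML_evaluation import info

info()  # prints the available entry points, e.g. "plot_losses", "shap_summary_plot"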
@@ -0,0 +1,92 @@ ml_tools/ML_evaluation/_loss.py (new file)

import matplotlib.pyplot as plt
import seaborn as sns

from pathlib import Path
from typing import Union

from ..path_manager import make_fullpath
from .._core import get_logger
from ..keys._keys import PyTorchLogKeys, _EvaluationConfig


_LOGGER = get_logger("Loss Plot")


__all__ = [
    "plot_losses",
]


DPI_value = _EvaluationConfig.DPI


def plot_losses(history: dict, save_dir: Union[str, Path]):
    """
    Plots training & validation loss curves from a history object.
    Also plots the learning rate if available in the history.

    Args:
        history (dict): A dictionary containing 'train_loss' and 'val_loss'.
        save_dir (str | Path): Directory to save the plot image.
    """
    train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
    val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
    lr_history = history.get(PyTorchLogKeys.LEARNING_RATE, [])

    if not train_loss and not val_loss:
        _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
        return

    fig, ax = plt.subplots(figsize=_EvaluationConfig.LOSS_PLOT_SIZE, dpi=DPI_value)

    # --- Plot Losses (Left Y-axis) ---
    line_handles = []  # To store line objects for the legend

    # Plot training loss only if data for it exists
    if train_loss:
        epochs = range(1, len(train_loss) + 1)
        line1, = ax.plot(epochs, train_loss, 'o-', label='Training Loss', color='tab:blue')
        line_handles.append(line1)

    # Plot validation loss only if data for it exists
    if val_loss:
        epochs = range(1, len(val_loss) + 1)
        line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
        line_handles.append(line2)

    ax.set_title('Training and Validation Loss', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE + 2, pad=_EvaluationConfig.LABEL_PADDING)
    ax.set_xlabel('Epochs', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
    ax.set_ylabel('Loss', color='tab:blue', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
    ax.tick_params(axis='y', labelcolor='tab:blue', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
    ax.tick_params(axis='x', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
    ax.grid(True, linestyle='--')

    # --- Plot Learning Rate (Right Y-axis) ---
    if lr_history:
        ax2 = ax.twinx()  # Create a second y-axis
        epochs = range(1, len(lr_history) + 1)
        line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
        line_handles.append(line3)

        ax2.set_ylabel('Learning Rate', color='g', fontsize=_EvaluationConfig.LOSS_PLOT_LABEL_SIZE, labelpad=_EvaluationConfig.LABEL_PADDING)
        ax2.tick_params(axis='y', labelcolor='g', labelsize=_EvaluationConfig.LOSS_PLOT_TICK_SIZE)
        # Use scientific notation if the LR is very small
        ax2.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
        # Increase the size of the scientific-notation offset text
        ax2.yaxis.get_offset_text().set_fontsize(_EvaluationConfig.LOSS_PLOT_TICK_SIZE - 2)
        # Remove grid from second y-axis
        ax2.grid(False)

    # Combine legends from both axes
    ax.legend(handles=line_handles, loc='best', fontsize=_EvaluationConfig.LOSS_PLOT_LEGEND_SIZE)

    plt.tight_layout()

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    save_path = save_dir_path / "loss_plot.svg"
    plt.savefig(save_path)
    _LOGGER.info(f"📉 Loss plot saved as '{save_path.name}'")

    plt.close(fig)
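
To close, a minimal sketch of driving `plot_losses` by hand. The history values are fabricated for illustration; the key constants are the ones the module reads, imported here via the private `ml_tools.keys._keys` path shown in this diff (a public re-export from `ml_tools.keys` may exist but is not confirmed by this listing):

from ml_tools.ML_evaluation import plot_losses  # assumed re-export
from ml_tools.keys._keys import PyTorchLogKeys

# Fabricated four-epoch history for illustration only.
history = {
    PyTorchLogKeys.TRAIN_LOSS: [0.90, 0.55, 0.41, 0.35],
    PyTorchLogKeys.VAL_LOSS: [0.95, 0.60, 0.48, 0.44],
    PyTorchLogKeys.LEARNING_RATE: [1e-3, 1e-3, 5e-4, 5e-4],  # optional right-hand axis
}

plot_losses(history, save_dir="training_plots")  # writes training_plots/loss_plot.svg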