dragon-ml-toolbox 19.14.0__py3-none-any.whl → 20.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/METADATA +29 -46
- dragon_ml_toolbox-20.0.0.dist-info/RECORD +178 -0
- ml_tools/{ETL_cleaning.py → ETL_cleaning/__init__.py} +13 -5
- ml_tools/ETL_cleaning/_basic_clean.py +351 -0
- ml_tools/ETL_cleaning/_clean_tools.py +128 -0
- ml_tools/ETL_cleaning/_dragon_cleaner.py +245 -0
- ml_tools/ETL_cleaning/_imprimir.py +13 -0
- ml_tools/{ETL_engineering.py → ETL_engineering/__init__.py} +8 -4
- ml_tools/ETL_engineering/_dragon_engineering.py +261 -0
- ml_tools/ETL_engineering/_imprimir.py +24 -0
- ml_tools/{_core/_ETL_engineering.py → ETL_engineering/_transforms.py} +14 -267
- ml_tools/{_core → GUI_tools}/_GUI_tools.py +37 -40
- ml_tools/{GUI_tools.py → GUI_tools/__init__.py} +7 -5
- ml_tools/GUI_tools/_imprimir.py +12 -0
- ml_tools/IO_tools/_IO_loggers.py +235 -0
- ml_tools/IO_tools/_IO_save_load.py +151 -0
- ml_tools/IO_tools/_IO_utils.py +140 -0
- ml_tools/{IO_tools.py → IO_tools/__init__.py} +13 -5
- ml_tools/IO_tools/_imprimir.py +14 -0
- ml_tools/MICE/_MICE_imputation.py +132 -0
- ml_tools/{MICE_imputation.py → MICE/__init__.py} +6 -7
- ml_tools/{_core/_MICE_imputation.py → MICE/_dragon_mice.py} +243 -322
- ml_tools/MICE/_imprimir.py +11 -0
- ml_tools/{ML_callbacks.py → ML_callbacks/__init__.py} +12 -4
- ml_tools/ML_callbacks/_base.py +101 -0
- ml_tools/ML_callbacks/_checkpoint.py +232 -0
- ml_tools/ML_callbacks/_early_stop.py +208 -0
- ml_tools/ML_callbacks/_imprimir.py +12 -0
- ml_tools/ML_callbacks/_scheduler.py +197 -0
- ml_tools/{ML_chaining_utilities.py → ML_chain/__init__.py} +8 -3
- ml_tools/{_core/_ML_chaining_utilities.py → ML_chain/_chaining_tools.py} +5 -129
- ml_tools/ML_chain/_dragon_chain.py +140 -0
- ml_tools/ML_chain/_imprimir.py +11 -0
- ml_tools/ML_configuration/__init__.py +90 -0
- ml_tools/ML_configuration/_base_model_config.py +69 -0
- ml_tools/ML_configuration/_finalize.py +366 -0
- ml_tools/ML_configuration/_imprimir.py +47 -0
- ml_tools/ML_configuration/_metrics.py +593 -0
- ml_tools/ML_configuration/_models.py +206 -0
- ml_tools/ML_configuration/_training.py +124 -0
- ml_tools/ML_datasetmaster/__init__.py +28 -0
- ml_tools/ML_datasetmaster/_base_datasetmaster.py +337 -0
- ml_tools/{_core/_ML_datasetmaster.py → ML_datasetmaster/_datasetmaster.py} +9 -329
- ml_tools/ML_datasetmaster/_imprimir.py +15 -0
- ml_tools/{_core/_ML_sequence_datasetmaster.py → ML_datasetmaster/_sequence_datasetmaster.py} +13 -15
- ml_tools/{_core/_ML_vision_datasetmaster.py → ML_datasetmaster/_vision_datasetmaster.py} +63 -65
- ml_tools/ML_evaluation/__init__.py +53 -0
- ml_tools/ML_evaluation/_classification.py +629 -0
- ml_tools/ML_evaluation/_feature_importance.py +409 -0
- ml_tools/ML_evaluation/_imprimir.py +25 -0
- ml_tools/ML_evaluation/_loss.py +92 -0
- ml_tools/ML_evaluation/_regression.py +273 -0
- ml_tools/{_core/_ML_sequence_evaluation.py → ML_evaluation/_sequence.py} +8 -11
- ml_tools/{_core/_ML_vision_evaluation.py → ML_evaluation/_vision.py} +12 -17
- ml_tools/{_core → ML_evaluation_captum}/_ML_evaluation_captum.py +11 -38
- ml_tools/{ML_evaluation_captum.py → ML_evaluation_captum/__init__.py} +6 -4
- ml_tools/ML_evaluation_captum/_imprimir.py +10 -0
- ml_tools/{_core → ML_finalize_handler}/_ML_finalize_handler.py +3 -7
- ml_tools/ML_finalize_handler/__init__.py +10 -0
- ml_tools/ML_finalize_handler/_imprimir.py +8 -0
- ml_tools/ML_inference/__init__.py +22 -0
- ml_tools/ML_inference/_base_inference.py +166 -0
- ml_tools/{_core/_ML_chaining_inference.py → ML_inference/_chain_inference.py} +14 -17
- ml_tools/ML_inference/_dragon_inference.py +332 -0
- ml_tools/ML_inference/_imprimir.py +11 -0
- ml_tools/ML_inference/_multi_inference.py +180 -0
- ml_tools/ML_inference_sequence/__init__.py +10 -0
- ml_tools/ML_inference_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_inference.py → ML_inference_sequence/_sequence_inference.py} +11 -15
- ml_tools/ML_inference_vision/__init__.py +10 -0
- ml_tools/ML_inference_vision/_imprimir.py +8 -0
- ml_tools/{_core/_ML_vision_inference.py → ML_inference_vision/_vision_inference.py} +15 -19
- ml_tools/ML_models/__init__.py +32 -0
- ml_tools/{_core/_ML_models_advanced.py → ML_models/_advanced_models.py} +22 -18
- ml_tools/ML_models/_base_mlp_attention.py +198 -0
- ml_tools/{_core/_models_advanced_base.py → ML_models/_base_save_load.py} +73 -49
- ml_tools/ML_models/_dragon_tabular.py +248 -0
- ml_tools/ML_models/_imprimir.py +18 -0
- ml_tools/ML_models/_mlp_attention.py +134 -0
- ml_tools/{_core → ML_models}/_models_advanced_helpers.py +13 -13
- ml_tools/ML_models_sequence/__init__.py +10 -0
- ml_tools/ML_models_sequence/_imprimir.py +8 -0
- ml_tools/{_core/_ML_sequence_models.py → ML_models_sequence/_sequence_models.py} +5 -8
- ml_tools/ML_models_vision/__init__.py +29 -0
- ml_tools/ML_models_vision/_base_wrapper.py +254 -0
- ml_tools/ML_models_vision/_image_classification.py +182 -0
- ml_tools/ML_models_vision/_image_segmentation.py +108 -0
- ml_tools/ML_models_vision/_imprimir.py +16 -0
- ml_tools/ML_models_vision/_object_detection.py +135 -0
- ml_tools/ML_optimization/__init__.py +21 -0
- ml_tools/ML_optimization/_imprimir.py +13 -0
- ml_tools/{_core/_ML_optimization_pareto.py → ML_optimization/_multi_dragon.py} +18 -24
- ml_tools/ML_optimization/_single_dragon.py +203 -0
- ml_tools/{_core/_ML_optimization.py → ML_optimization/_single_manual.py} +75 -213
- ml_tools/{_core → ML_scaler}/_ML_scaler.py +8 -11
- ml_tools/ML_scaler/__init__.py +10 -0
- ml_tools/ML_scaler/_imprimir.py +8 -0
- ml_tools/ML_trainer/__init__.py +20 -0
- ml_tools/ML_trainer/_base_trainer.py +297 -0
- ml_tools/ML_trainer/_dragon_detection_trainer.py +402 -0
- ml_tools/ML_trainer/_dragon_sequence_trainer.py +540 -0
- ml_tools/ML_trainer/_dragon_trainer.py +1160 -0
- ml_tools/ML_trainer/_imprimir.py +10 -0
- ml_tools/{ML_utilities.py → ML_utilities/__init__.py} +14 -6
- ml_tools/ML_utilities/_artifact_finder.py +382 -0
- ml_tools/ML_utilities/_imprimir.py +16 -0
- ml_tools/ML_utilities/_inspection.py +325 -0
- ml_tools/ML_utilities/_train_tools.py +205 -0
- ml_tools/{ML_vision_transformers.py → ML_vision_transformers/__init__.py} +9 -6
- ml_tools/{_core/_ML_vision_transformers.py → ML_vision_transformers/_core_transforms.py} +11 -155
- ml_tools/ML_vision_transformers/_imprimir.py +14 -0
- ml_tools/ML_vision_transformers/_offline_augmentation.py +159 -0
- ml_tools/{_core/_PSO_optimization.py → PSO_optimization/_PSO.py} +58 -15
- ml_tools/{PSO_optimization.py → PSO_optimization/__init__.py} +5 -3
- ml_tools/PSO_optimization/_imprimir.py +10 -0
- ml_tools/SQL/__init__.py +7 -0
- ml_tools/{_core/_SQL.py → SQL/_dragon_SQL.py} +7 -11
- ml_tools/SQL/_imprimir.py +8 -0
- ml_tools/{_core → VIF}/_VIF_factor.py +5 -8
- ml_tools/{VIF_factor.py → VIF/__init__.py} +4 -2
- ml_tools/VIF/_imprimir.py +10 -0
- ml_tools/_core/__init__.py +7 -1
- ml_tools/_core/_logger.py +8 -18
- ml_tools/_core/_schema_load_ops.py +43 -0
- ml_tools/_core/_script_info.py +2 -2
- ml_tools/{data_exploration.py → data_exploration/__init__.py} +32 -16
- ml_tools/data_exploration/_analysis.py +214 -0
- ml_tools/data_exploration/_cleaning.py +566 -0
- ml_tools/data_exploration/_features.py +583 -0
- ml_tools/data_exploration/_imprimir.py +32 -0
- ml_tools/data_exploration/_plotting.py +487 -0
- ml_tools/data_exploration/_schema_ops.py +176 -0
- ml_tools/{ensemble_evaluation.py → ensemble_evaluation/__init__.py} +6 -4
- ml_tools/{_core → ensemble_evaluation}/_ensemble_evaluation.py +3 -7
- ml_tools/ensemble_evaluation/_imprimir.py +14 -0
- ml_tools/{ensemble_inference.py → ensemble_inference/__init__.py} +5 -3
- ml_tools/{_core → ensemble_inference}/_ensemble_inference.py +15 -18
- ml_tools/ensemble_inference/_imprimir.py +9 -0
- ml_tools/{ensemble_learning.py → ensemble_learning/__init__.py} +4 -6
- ml_tools/{_core → ensemble_learning}/_ensemble_learning.py +7 -10
- ml_tools/ensemble_learning/_imprimir.py +10 -0
- ml_tools/{excel_handler.py → excel_handler/__init__.py} +5 -3
- ml_tools/{_core → excel_handler}/_excel_handler.py +6 -10
- ml_tools/excel_handler/_imprimir.py +13 -0
- ml_tools/{keys.py → keys/__init__.py} +4 -1
- ml_tools/keys/_imprimir.py +11 -0
- ml_tools/{_core → keys}/_keys.py +2 -0
- ml_tools/{math_utilities.py → math_utilities/__init__.py} +5 -2
- ml_tools/math_utilities/_imprimir.py +11 -0
- ml_tools/{_core → math_utilities}/_math_utilities.py +1 -5
- ml_tools/{optimization_tools.py → optimization_tools/__init__.py} +9 -4
- ml_tools/optimization_tools/_imprimir.py +13 -0
- ml_tools/optimization_tools/_optimization_bounds.py +236 -0
- ml_tools/optimization_tools/_optimization_plots.py +218 -0
- ml_tools/{path_manager.py → path_manager/__init__.py} +6 -3
- ml_tools/{_core/_path_manager.py → path_manager/_dragonmanager.py} +11 -347
- ml_tools/path_manager/_imprimir.py +15 -0
- ml_tools/path_manager/_path_tools.py +346 -0
- ml_tools/plot_fonts/__init__.py +8 -0
- ml_tools/plot_fonts/_imprimir.py +8 -0
- ml_tools/{_core → plot_fonts}/_plot_fonts.py +2 -5
- ml_tools/schema/__init__.py +15 -0
- ml_tools/schema/_feature_schema.py +223 -0
- ml_tools/schema/_gui_schema.py +191 -0
- ml_tools/schema/_imprimir.py +10 -0
- ml_tools/{serde.py → serde/__init__.py} +4 -2
- ml_tools/serde/_imprimir.py +10 -0
- ml_tools/{_core → serde}/_serde.py +3 -8
- ml_tools/{utilities.py → utilities/__init__.py} +11 -6
- ml_tools/utilities/_imprimir.py +18 -0
- ml_tools/{_core/_utilities.py → utilities/_utility_save_load.py} +13 -190
- ml_tools/utilities/_utility_tools.py +192 -0
- dragon_ml_toolbox-19.14.0.dist-info/RECORD +0 -111
- ml_tools/ML_chaining_inference.py +0 -8
- ml_tools/ML_configuration.py +0 -86
- ml_tools/ML_configuration_pytab.py +0 -14
- ml_tools/ML_datasetmaster.py +0 -10
- ml_tools/ML_evaluation.py +0 -16
- ml_tools/ML_evaluation_multi.py +0 -12
- ml_tools/ML_finalize_handler.py +0 -8
- ml_tools/ML_inference.py +0 -12
- ml_tools/ML_models.py +0 -14
- ml_tools/ML_models_advanced.py +0 -14
- ml_tools/ML_models_pytab.py +0 -14
- ml_tools/ML_optimization.py +0 -14
- ml_tools/ML_optimization_pareto.py +0 -8
- ml_tools/ML_scaler.py +0 -8
- ml_tools/ML_sequence_datasetmaster.py +0 -8
- ml_tools/ML_sequence_evaluation.py +0 -10
- ml_tools/ML_sequence_inference.py +0 -8
- ml_tools/ML_sequence_models.py +0 -8
- ml_tools/ML_trainer.py +0 -12
- ml_tools/ML_vision_datasetmaster.py +0 -12
- ml_tools/ML_vision_evaluation.py +0 -10
- ml_tools/ML_vision_inference.py +0 -8
- ml_tools/ML_vision_models.py +0 -18
- ml_tools/SQL.py +0 -8
- ml_tools/_core/_ETL_cleaning.py +0 -694
- ml_tools/_core/_IO_tools.py +0 -498
- ml_tools/_core/_ML_callbacks.py +0 -702
- ml_tools/_core/_ML_configuration.py +0 -1332
- ml_tools/_core/_ML_configuration_pytab.py +0 -102
- ml_tools/_core/_ML_evaluation.py +0 -867
- ml_tools/_core/_ML_evaluation_multi.py +0 -544
- ml_tools/_core/_ML_inference.py +0 -646
- ml_tools/_core/_ML_models.py +0 -668
- ml_tools/_core/_ML_models_pytab.py +0 -693
- ml_tools/_core/_ML_trainer.py +0 -2323
- ml_tools/_core/_ML_utilities.py +0 -886
- ml_tools/_core/_ML_vision_models.py +0 -644
- ml_tools/_core/_data_exploration.py +0 -1909
- ml_tools/_core/_optimization_tools.py +0 -493
- ml_tools/_core/_schema.py +0 -359
- ml_tools/plot_fonts.py +0 -8
- ml_tools/schema.py +0 -12
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-19.14.0.dist-info → dragon_ml_toolbox-20.0.0.dist-info}/top_level.txt +0 -0
|
@@ -1,544 +0,0 @@
|
|
|
1
|
-
import numpy as np
|
|
2
|
-
import pandas as pd
|
|
3
|
-
import matplotlib.pyplot as plt
|
|
4
|
-
import seaborn as sns
|
|
5
|
-
import torch
|
|
6
|
-
import shap
|
|
7
|
-
from sklearn.metrics import (
|
|
8
|
-
classification_report,
|
|
9
|
-
ConfusionMatrixDisplay,
|
|
10
|
-
roc_curve,
|
|
11
|
-
roc_auc_score,
|
|
12
|
-
precision_recall_curve,
|
|
13
|
-
average_precision_score,
|
|
14
|
-
mean_squared_error,
|
|
15
|
-
mean_absolute_error,
|
|
16
|
-
r2_score,
|
|
17
|
-
median_absolute_error,
|
|
18
|
-
hamming_loss,
|
|
19
|
-
jaccard_score
|
|
20
|
-
)
|
|
21
|
-
from pathlib import Path
|
|
22
|
-
from typing import Union, List, Literal, Optional
|
|
23
|
-
import warnings
|
|
24
|
-
|
|
25
|
-
from ._path_manager import make_fullpath, sanitize_filename
|
|
26
|
-
from ._logger import get_logger
|
|
27
|
-
from ._script_info import _script_info
|
|
28
|
-
from ._keys import SHAPKeys, _EvaluationConfig
|
|
29
|
-
from ._ML_configuration import (MultiTargetRegressionMetricsFormat,
|
|
30
|
-
_BaseRegressionFormat,
|
|
31
|
-
MultiLabelBinaryClassificationMetricsFormat,
|
|
32
|
-
_BaseMultiLabelFormat)
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
# Module-level logger shared by all evaluation helpers in this file.
_LOGGER = get_logger("Evaluation Multi")


# Public API of this module.
__all__ = [
    "multi_target_regression_metrics",
    "multi_label_classification_metrics",
    "multi_target_shap_summary_plot",
]


# Plot-formatting constants, centralized in _EvaluationConfig so that all
# evaluation modules render figures consistently.
DPI_value = _EvaluationConfig.DPI
REGRESSION_PLOT_SIZE = _EvaluationConfig.REGRESSION_PLOT_SIZE
CLASSIFICATION_PLOT_SIZE = _EvaluationConfig.CLASSIFICATION_PLOT_SIZE
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
def multi_target_regression_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    target_names: List[str],
    save_dir: Union[str, Path],
    config: Optional[MultiTargetRegressionMetricsFormat] = None
):
    """
    Calculates and saves regression metrics for each target individually.

    For each target, this function saves a residual plot and a true vs. predicted plot.
    It also saves a single CSV file containing the key metrics (RMSE, MAE, R², MedAE)
    for all targets.

    Args:
        y_true (np.ndarray): Ground truth values, shape (n_samples, n_targets).
        y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
        target_names (List[str]): A list of names for the target variables.
        save_dir (str | Path): Directory to save plots and the report.
        config (object): Formatting configuration object.
    """
    # --- Validate inputs (log first, then raise, per module convention) ---
    if y_true.ndim != 2 or y_pred.ndim != 2:
        _LOGGER.error("y_true and y_pred must be 2D arrays for multi-target regression.")
        raise ValueError()
    if y_true.shape != y_pred.shape:
        _LOGGER.error("Shapes of y_true and y_pred must match.")
        raise ValueError()
    if y_true.shape[1] != len(target_names):
        _LOGGER.error("Number of target names must match the number of columns in y_true.")
        raise ValueError()

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

    # --- Formatting: fall back to the default config when none is given ---
    format_config = _BaseRegressionFormat() if config is None else config
    xtick_size = format_config.xtick_size
    ytick_size = format_config.ytick_size
    base_font_size = format_config.font_size

    _LOGGER.debug("--- Multi-Target Regression Evaluation ---")

    metrics_summary = []
    for col_idx, target_name in enumerate(target_names):
        truth = y_true[:, col_idx]
        preds = y_pred[:, col_idx]
        safe_name = sanitize_filename(target_name)

        # --- Per-target metrics (column order fixed for the CSV report) ---
        metrics_summary.append({
            'Target': target_name,
            'RMSE': np.sqrt(mean_squared_error(truth, preds)),
            'MAE': mean_absolute_error(truth, preds),
            'MedAE': median_absolute_error(truth, preds),
            'R2-score': r2_score(truth, preds),
        })

        # --- Residual plot (predicted vs. residual, zero reference line) ---
        residuals = truth - preds
        fig_res, ax_res = plt.subplots(figsize=REGRESSION_PLOT_SIZE, dpi=DPI_value)
        ax_res.scatter(
            preds,
            residuals,
            alpha=format_config.scatter_alpha,
            edgecolors='k',
            s=50,
            color=format_config.scatter_color,
        )
        ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--')
        ax_res.set_xlabel("Predicted Values", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_res.set_ylabel("Residuals", labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_res.set_title(f"Residual Plot for '{target_name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
        ax_res.tick_params(axis='x', labelsize=xtick_size)
        ax_res.tick_params(axis='y', labelsize=ytick_size)
        ax_res.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        plt.savefig(save_dir_path / f"residual_plot_{safe_name}.svg")
        plt.close(fig_res)

        # --- True vs. predicted plot with the ideal y = x reference line ---
        fig_tvp, ax_tvp = plt.subplots(figsize=REGRESSION_PLOT_SIZE, dpi=DPI_value)
        ax_tvp.scatter(
            truth,
            preds,
            alpha=format_config.scatter_alpha,
            edgecolors='k',
            s=50,
            color=format_config.scatter_color,
        )
        ax_tvp.plot(
            [truth.min(), truth.max()],
            [truth.min(), truth.max()],
            linestyle='--',
            lw=2,
            color=format_config.ideal_line_color,
        )
        ax_tvp.set_xlabel('True Values', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_tvp.set_ylabel('Predicted Values', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_tvp.set_title(f"True vs. Predicted for '{target_name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
        ax_tvp.tick_params(axis='x', labelsize=xtick_size)
        ax_tvp.tick_params(axis='y', labelsize=ytick_size)
        ax_tvp.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        plt.savefig(save_dir_path / f"true_vs_predicted_plot_{safe_name}.svg")
        plt.close(fig_tvp)

    # --- Single CSV report covering all targets ---
    report_path = save_dir_path / "regression_report_multi.csv"
    pd.DataFrame(metrics_summary).to_csv(report_path, index=False)
    _LOGGER.info(f"Full regression report saved to '{report_path.name}'")
|
|
178
|
-
|
|
179
|
-
def multi_label_classification_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_prob: np.ndarray,
    target_names: List[str],
    save_dir: Union[str, Path],
    config: Optional[MultiLabelBinaryClassificationMetricsFormat] = None
):
    """
    Calculates and saves classification metrics for each label individually.

    This function first computes overall multi-label metrics (Hamming Loss, Jaccard Score)
    and then iterates through each label to generate and save individual reports,
    confusion matrices, ROC curves, and Precision-Recall curves. For each label it also
    attempts to save the optimal decision threshold (Youden's J statistic from the ROC curve).

    Args:
        y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
        y_pred (np.ndarray): Predicted binary labels, shape (n_samples, n_labels).
        y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
        target_names (List[str]): A list of names for the labels.
        save_dir (str | Path): Directory to save plots and reports.
        config (object): Formatting configuration object.

    Raises:
        ValueError: If inputs are not 2D, their shapes differ, or the number of
            target names does not match the number of label columns.
    """
    # --- Validate inputs (log first, then raise, per module convention) ---
    if y_true.ndim != 2 or y_prob.ndim != 2 or y_pred.ndim != 2:
        _LOGGER.error("y_true, y_pred, and y_prob must be 2D arrays for multi-label classification.")
        raise ValueError()
    if y_true.shape != y_prob.shape or y_true.shape != y_pred.shape:
        _LOGGER.error("Shapes of y_true, y_pred, and y_prob must match.")
        raise ValueError()
    if y_true.shape[1] != len(target_names):
        _LOGGER.error("Number of target names must match the number of columns in y_true.")
        raise ValueError()

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

    # --- Formatting: fall back to the default config when none is given ---
    if config is None:
        format_config = _BaseMultiLabelFormat()
    else:
        format_config = config

    # y_pred is passed in directly, so no thresholding is performed here.

    # ticks and legend font sizes
    xtick_size = format_config.xtick_size
    ytick_size = format_config.ytick_size
    legend_size = format_config.legend_size
    base_font_size = format_config.font_size

    # --- Overall multi-label metrics (computed from the hard predictions) ---
    h_loss = hamming_loss(y_true, y_pred)
    j_score_micro = jaccard_score(y_true, y_pred, average='micro')
    j_score_macro = jaccard_score(y_true, y_pred, average='macro')

    overall_report = (
        f"Overall Multi-Label Metrics:\n"
        f"--------------------------------------------------\n"
        f"Hamming Loss: {h_loss:.4f}\n"
        f"Jaccard Score (micro): {j_score_micro:.4f}\n"
        f"Jaccard Score (macro): {j_score_macro:.4f}\n"
        f"--------------------------------------------------\n"
    )
    overall_report_path = save_dir_path / "classification_report.txt"
    overall_report_path.write_text(overall_report)

    # --- Per-label metrics and plots ---
    for i, name in enumerate(target_names):
        # FIX: was a bare print(); use the module logger like the rest of the file.
        _LOGGER.info(f" -> Evaluating label: '{name}'")
        true_i = y_true[:, i]
        pred_i = y_pred[:, i]
        prob_i = y_prob[:, i]
        sanitized_name = sanitize_filename(name)

        # --- Per-label classification report (uses y_pred) ---
        report_text = classification_report(true_i, pred_i)
        report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
        report_path.write_text(report_text)  # type: ignore

        # --- Confusion matrix, row-normalized, fixed 0..1 color scale (uses y_pred) ---
        fig_cm, ax_cm = plt.subplots(figsize=_EvaluationConfig.CM_SIZE, dpi=_EvaluationConfig.DPI)
        disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
                                                        pred_i,
                                                        cmap=format_config.cmap,
                                                        ax=ax_cm,
                                                        normalize='true',
                                                        labels=[0, 1],
                                                        display_labels=["Negative", "Positive"],
                                                        colorbar=False)
        # Pin the color range so matrices are comparable across labels.
        disp_.im_.set_clim(vmin=0.0, vmax=1.0)
        ax_cm.grid(False)

        # Manually bump the font size of the in-cell value texts.
        for text in ax_cm.texts:
            text.set_fontsize(base_font_size + 2)

        ax_cm.tick_params(axis='x', labelsize=xtick_size)
        ax_cm.tick_params(axis='y', labelsize=ytick_size)

        ax_cm.set_title(f"Confusion Matrix for '{name}'", pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
        ax_cm.set_xlabel(ax_cm.get_xlabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_cm.set_ylabel(ax_cm.get_ylabel(), labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)

        # Manually add the colorbar so we can shrink it and set its tick size.
        cbar = fig_cm.colorbar(disp_.im_, ax=ax_cm, shrink=0.8)
        cbar.ax.tick_params(labelsize=ytick_size)  # type: ignore

        plt.tight_layout()
        cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
        plt.savefig(cm_path)
        plt.close(fig_cm)

        # --- ROC curve and optimal threshold (uses y_prob) ---
        fpr, tpr, thresholds = roc_curve(true_i, prob_i)

        # Best-effort: save the optimal threshold; failure must not abort evaluation.
        try:
            # Youden's J statistic: maximize (TPR - FPR).
            J = tpr - fpr
            best_index = np.argmax(J)
            optimal_threshold = thresholds[best_index]
            best_tpr = tpr[best_index]
            best_fpr = fpr[best_index]

            threshold_path = save_dir_path / f"best_threshold_{sanitized_name}.txt"

            file_content = (
                f"Optimal Classification Threshold (Youden's J Statistic)\n"
                f"Class/Label: {name}\n"
                f"--------------------------------------------------\n"
                f"Threshold: {optimal_threshold:.6f}\n"
                f"True Positive Rate (TPR): {best_tpr:.6f}\n"
                f"False Positive Rate (FPR): {best_fpr:.6f}\n"
            )
            threshold_path.write_text(file_content, encoding="utf-8")
            _LOGGER.info(f"💾 Optimal threshold for '{name}' saved to '{threshold_path.name}'")
        except Exception as e:
            _LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")

        auc = roc_auc_score(true_i, prob_i)
        fig_roc, ax_roc = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)
        ax_roc.plot([0, 1], [0, 1], 'k--')  # chance diagonal

        ax_roc.set_title(f'ROC Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
        ax_roc.set_xlabel('False Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_roc.set_ylabel('True Positive Rate', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)

        ax_roc.tick_params(axis='x', labelsize=xtick_size)
        ax_roc.tick_params(axis='y', labelsize=ytick_size)
        ax_roc.legend(loc='lower right', fontsize=legend_size)
        ax_roc.grid(True, linestyle='--', alpha=0.6)

        plt.tight_layout()
        roc_path = save_dir_path / f"roc_curve_{sanitized_name}.svg"
        plt.savefig(roc_path)
        plt.close(fig_roc)

        # --- Precision-Recall curve (uses y_prob) ---
        precision, recall, _ = precision_recall_curve(true_i, prob_i)
        ap_score = average_precision_score(true_i, prob_i)
        fig_pr, ax_pr = plt.subplots(figsize=CLASSIFICATION_PLOT_SIZE, dpi=DPI_value)
        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line)
        ax_pr.set_title(f'Precision-Recall Curve for "{name}"', pad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size + 2)
        ax_pr.set_xlabel('Recall', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)
        ax_pr.set_ylabel('Precision', labelpad=_EvaluationConfig.LABEL_PADDING, fontsize=base_font_size)

        ax_pr.tick_params(axis='x', labelsize=xtick_size)
        ax_pr.tick_params(axis='y', labelsize=ytick_size)
        ax_pr.legend(loc='lower left', fontsize=legend_size)
        ax_pr.grid(True, linestyle='--', alpha=0.6)

        fig_pr.tight_layout()
        pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
        plt.savefig(pr_path)
        plt.close(fig_pr)

    _LOGGER.info(f"All individual label reports and plots saved to '{save_dir_path.name}'")
|
|
389
|
-
|
|
390
|
-
def multi_target_shap_summary_plot(
    model: torch.nn.Module,
    background_data: Union[torch.Tensor, np.ndarray],
    instances_to_explain: Union[torch.Tensor, np.ndarray],
    feature_names: List[str],
    target_names: List[str],
    save_dir: Union[str, Path],
    device: torch.device = torch.device('cpu'),
    explainer_type: Literal['deep', 'kernel'] = 'kernel'
):
    """
    DEPRECATED

    Calculates SHAP values for a multi-target model and saves summary plots and data for each target.

    For every target a bar plot, a dot plot (both SVG), and a CSV of mean
    absolute SHAP values per feature are written to ``save_dir``.

    Args:
        model (torch.nn.Module): The trained PyTorch model.
        background_data (torch.Tensor | np.ndarray): A sample of data for the explainer background.
        instances_to_explain (torch.Tensor | np.ndarray): The specific data instances to explain.
        feature_names (List[str]): Names of the features for plot labeling.
        target_names (List[str]): Names of the output targets.
        save_dir (str | Path): Directory to save SHAP artifacts.
        device (torch.device): The torch device for SHAP calculations.
        explainer_type (Literal['deep', 'kernel']): The explainer to use.
            - 'deep': Uses shap.DeepExplainer. Fast and efficient.
            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.

    Raises:
        ValueError: If ``explainer_type`` is neither 'deep' nor 'kernel'.
    """
    _LOGGER.warning("This function is deprecated and may be removed in future versions. Use Captum module instead.")

    _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
    model.eval()

    shap_values_list = None
    instances_to_explain_np = None

    if explainer_type == 'deep':
        # --- 1. Use DeepExplainer ---

        # Ensure data is torch.Tensor
        if isinstance(background_data, np.ndarray):
            background_data = torch.from_numpy(background_data).float()
        if isinstance(instances_to_explain, np.ndarray):
            instances_to_explain = torch.from_numpy(instances_to_explain).float()

        # NaNs would silently corrupt SHAP values; bail out early instead.
        if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
            return

        background_data = background_data.to(device)
        instances_to_explain = instances_to_explain.to(device)

        # DeepExplainer commonly emits harmless UserWarnings about model tracing.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=UserWarning)
            explainer = shap.DeepExplainer(model, background_data)

        # For multi-output models this is a list of arrays (older SHAP) or a
        # single 3-D array (newer SHAP); normalized below before plotting.
        shap_values_list = explainer.shap_values(instances_to_explain)
        instances_to_explain_np = instances_to_explain.cpu().numpy()

    elif explainer_type == 'kernel':
        # --- 2. Use KernelExplainer ---
        _LOGGER.warning(
            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
        )

        # Convert all data to numpy
        background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
        instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain

        if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
            return

        # Summarize the background with k-means (30 centroids) to keep
        # KernelExplainer's runtime and memory bounded.
        background_summary = shap.kmeans(background_data_np, 30)

        def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
            # SHAP feeds numpy batches; run them through the torch model on `device`.
            x_torch = torch.from_numpy(x_np).float().to(device)
            with torch.no_grad():
                output = model(x_torch)
            return output.cpu().numpy()  # Return full multi-output array

        explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
        shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
        # instances_to_explain_np is already set

    else:
        _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
        raise ValueError("Invalid explainer_type")

    # --- 3. Plotting and Saving (Common Logic) ---

    if shap_values_list is None or instances_to_explain_np is None:
        _LOGGER.error("SHAP value calculation failed. Aborting plotting.")
        return

    # Newer SHAP versions return one (samples, features, outputs) array for
    # multi-output models instead of a list; split it so both layouts work.
    if isinstance(shap_values_list, np.ndarray) and shap_values_list.ndim == 3:
        shap_values_list = [shap_values_list[:, :, i] for i in range(shap_values_list.shape[2])]

    # Ensure number of SHAP value arrays matches number of target names
    if len(shap_values_list) != len(target_names):
        _LOGGER.error(
            f"SHAP explanation mismatch: Model produced {len(shap_values_list)} "
            f"outputs, but {len(target_names)} target_names were provided."
        )
        return

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

    # Disable interactive mode while batch-saving figures, but remember the
    # caller's state so we can restore it (the original unconditionally
    # re-enabled interactive mode, clobbering a previously-off state).
    was_interactive = plt.isinteractive()
    plt.ioff()

    # Iterate through each target's SHAP values and generate plots.
    for i, target_name in enumerate(target_names):
        print(f" -> Generating SHAP plots for target: '{target_name}'")
        shap_values_for_target = shap_values_list[i]
        sanitized_target_name = sanitize_filename(target_name)

        # Save Bar Plot for the target
        shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
        plt.title(f"SHAP Feature Importance for '{target_name}'")
        plt.tight_layout()
        bar_path = save_dir_path / f"shap_bar_plot_{sanitized_target_name}.svg"
        plt.savefig(bar_path)
        plt.close()

        # Save Dot Plot for the target
        shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
        plt.title(f"SHAP Feature Importance for '{target_name}'")
        # Shrink the colorbar label (last axes) so it does not overlap the plot.
        if plt.gcf().axes and len(plt.gcf().axes) > 1:
            cb = plt.gcf().axes[-1]
            cb.set_ylabel("", size=1)
        plt.tight_layout()
        dot_path = save_dir_path / f"shap_dot_plot_{sanitized_target_name}.svg"
        plt.savefig(dot_path)
        plt.close()

        # --- Save Summary Data to CSV for this target ---
        shap_summary_filename = f"{SHAPKeys.SAVENAME}_{sanitized_target_name}.csv"
        summary_path = save_dir_path / shap_summary_filename

        # For a specific target, shap_values_for_target is just a 2D array
        mean_abs_shap = np.abs(shap_values_for_target).mean(axis=0).flatten()

        summary_df = pd.DataFrame({
            SHAPKeys.FEATURE_COLUMN: feature_names,
            SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
        }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)

        summary_df.to_csv(summary_path, index=False)

    # Restore the caller's interactive-mode state instead of forcing it on.
    if was_interactive:
        plt.ion()
    _LOGGER.info(f"All SHAP plots saved to '{save_dir_path.name}'")
|
-
def info():
    """Report this module's public API by passing ``__all__`` to the shared helper.

    NOTE(review): ``_script_info`` is defined elsewhere in the package;
    presumably it pretty-prints the given names — confirm against its definition.
    """
    _script_info(__all__)
|