dragon-ml-toolbox 6.4.1__py3-none-any.whl → 8.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic. Click here for more details.
- {dragon_ml_toolbox-6.4.1.dist-info → dragon_ml_toolbox-8.0.0.dist-info}/METADATA +4 -1
- {dragon_ml_toolbox-6.4.1.dist-info → dragon_ml_toolbox-8.0.0.dist-info}/RECORD +14 -11
- ml_tools/ML_datasetmaster.py +285 -438
- ml_tools/ML_evaluation.py +119 -51
- ml_tools/ML_evaluation_multi.py +296 -0
- ml_tools/ML_inference.py +251 -31
- ml_tools/ML_models.py +468 -47
- ml_tools/ML_scaler.py +197 -0
- ml_tools/ML_trainer.py +246 -73
- ml_tools/_ML_optimization_multi.py +231 -0
- {dragon_ml_toolbox-6.4.1.dist-info → dragon_ml_toolbox-8.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-6.4.1.dist-info → dragon_ml_toolbox-8.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-6.4.1.dist-info → dragon_ml_toolbox-8.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-6.4.1.dist-info → dragon_ml_toolbox-8.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation.py
CHANGED
|
@@ -20,7 +20,7 @@ import shap
|
|
|
20
20
|
from pathlib import Path
|
|
21
21
|
from .path_manager import make_fullpath
|
|
22
22
|
from ._logger import _LOGGER
|
|
23
|
-
from typing import Union, Optional
|
|
23
|
+
from typing import Union, Optional, List
|
|
24
24
|
from ._script_info import _script_info
|
|
25
25
|
|
|
26
26
|
|
|
@@ -28,7 +28,8 @@ __all__ = [
|
|
|
28
28
|
"plot_losses",
|
|
29
29
|
"classification_metrics",
|
|
30
30
|
"regression_metrics",
|
|
31
|
-
"shap_summary_plot"
|
|
31
|
+
"shap_summary_plot",
|
|
32
|
+
"plot_attention_importance"
|
|
32
33
|
]
|
|
33
34
|
|
|
34
35
|
|
|
@@ -248,8 +249,11 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
|
|
|
248
249
|
plt.close(fig_hist)
|
|
249
250
|
|
|
250
251
|
|
|
251
|
-
def shap_summary_plot(model,
|
|
252
|
-
|
|
252
|
+
def shap_summary_plot(model,
|
|
253
|
+
background_data: Union[torch.Tensor,np.ndarray],
|
|
254
|
+
instances_to_explain: Union[torch.Tensor,np.ndarray],
|
|
255
|
+
feature_names: Optional[list[str]],
|
|
256
|
+
save_dir: Union[str, Path]):
|
|
253
257
|
"""
|
|
254
258
|
Calculates SHAP values and saves summary plots and data.
|
|
255
259
|
|
|
@@ -258,7 +262,7 @@ def shap_summary_plot(model, background_data: Union[torch.Tensor,np.ndarray], in
|
|
|
258
262
|
background_data (torch.Tensor): A sample of data for the explainer background.
|
|
259
263
|
instances_to_explain (torch.Tensor): The specific data instances to explain.
|
|
260
264
|
feature_names (list of str | None): Names of the features for plot labeling.
|
|
261
|
-
save_dir (str | Path
|
|
265
|
+
save_dir (str | Path): Directory to save SHAP artifacts.
|
|
262
266
|
"""
|
|
263
267
|
# everything to numpy
|
|
264
268
|
if isinstance(background_data, np.ndarray):
|
|
@@ -301,55 +305,119 @@ def shap_summary_plot(model, background_data: Union[torch.Tensor,np.ndarray], in
|
|
|
301
305
|
print("Calculating SHAP values with KernelExplainer...")
|
|
302
306
|
shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
|
|
303
307
|
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
plt.close()
|
|
308
|
+
save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
309
|
+
plt.ioff()
|
|
310
|
+
|
|
311
|
+
# Save Bar Plot
|
|
312
|
+
bar_path = save_dir_path / "shap_bar_plot.svg"
|
|
313
|
+
shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
|
|
314
|
+
ax = plt.gca()
|
|
315
|
+
ax.set_xlabel("SHAP Value Impact", labelpad=10)
|
|
316
|
+
plt.title("SHAP Feature Importance")
|
|
317
|
+
plt.tight_layout()
|
|
318
|
+
plt.savefig(bar_path)
|
|
319
|
+
_LOGGER.info(f"📊 SHAP bar plot saved as '{bar_path.name}'")
|
|
320
|
+
plt.close()
|
|
318
321
|
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
322
|
+
# Save Dot Plot
|
|
323
|
+
dot_path = save_dir_path / "shap_dot_plot.svg"
|
|
324
|
+
shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
|
|
325
|
+
ax = plt.gca()
|
|
326
|
+
ax.set_xlabel("SHAP Value Impact", labelpad=10)
|
|
327
|
+
cb = plt.gcf().axes[-1]
|
|
328
|
+
cb.set_ylabel("", size=1)
|
|
329
|
+
plt.title("SHAP Feature Importance")
|
|
330
|
+
plt.tight_layout()
|
|
331
|
+
plt.savefig(dot_path)
|
|
332
|
+
_LOGGER.info(f"📊 SHAP dot plot saved as '{dot_path.name}'")
|
|
333
|
+
plt.close()
|
|
331
334
|
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
summary_df = pd.DataFrame({
|
|
341
|
-
'feature': feature_names,
|
|
342
|
-
'mean_abs_shap_value': mean_abs_shap
|
|
343
|
-
}).sort_values('mean_abs_shap_value', ascending=False)
|
|
344
|
-
|
|
345
|
-
summary_df.to_csv(summary_path, index=False)
|
|
346
|
-
|
|
347
|
-
_LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
|
|
348
|
-
plt.ion()
|
|
335
|
+
# Save Summary Data to CSV
|
|
336
|
+
summary_path = save_dir_path / "shap_summary.csv"
|
|
337
|
+
# Ensure the array is 1D before creating the DataFrame
|
|
338
|
+
mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
|
|
339
|
+
|
|
340
|
+
if feature_names is None:
|
|
341
|
+
feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
|
|
349
342
|
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
343
|
+
summary_df = pd.DataFrame({
|
|
344
|
+
'feature': feature_names,
|
|
345
|
+
'mean_abs_shap_value': mean_abs_shap
|
|
346
|
+
}).sort_values('mean_abs_shap_value', ascending=False)
|
|
347
|
+
|
|
348
|
+
summary_df.to_csv(summary_path, index=False)
|
|
349
|
+
|
|
350
|
+
_LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
|
|
351
|
+
plt.ion()
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path]):
    """
    Aggregates attention weights and plots global feature importance.

    The plot shows the mean attention for each feature as a bar, with the
    standard deviation represented by error bars. A summary CSV
    (``attention_summary.csv``) and the plot (``attention_importance.svg``)
    are written to ``save_dir``.

    Args:
        weights (List[torch.Tensor]): A list of attention weight tensors from each batch.
            The tensors are concatenated along dim 0 and reduced over dim 0, so each
            one is expected to have shape ``(batch_size, n_features)``. They may live
            on any device and may still require grad.
        feature_names (List[str] | None): Names of the features for plot labeling.
            When None, names ``feature_0``, ``feature_1``, ... are generated.
        save_dir (str | Path): Directory to save the plot and summary CSV.

    Raises:
        ValueError: If the aggregated weights are not one value per feature, or if
            ``feature_names`` does not contain exactly one entry per feature.
    """
    if not weights:
        _LOGGER.warning("⚠️ Attention weights list is empty. Skipping importance plot.")
        return

    # --- Step 1: Aggregate data ---
    # Detach and move every batch to the CPU first: torch.cat needs all tensors on one
    # device, and Tensor.numpy() rejects CUDA tensors and tensors that require grad.
    full_weights_tensor = torch.cat([w.detach().cpu() for w in weights], dim=0)

    # Calculate mean across the batch dimension (dim=0)
    mean_weights = full_weights_tensor.mean(dim=0)
    if mean_weights.ndim != 1:
        raise ValueError(
            "Expected attention weights of shape (batch_size, n_features), "
            f"got per-batch shape {tuple(weights[0].shape)}."
        )
    n_features = mean_weights.shape[0]

    if full_weights_tensor.shape[0] > 1:
        std_weights = full_weights_tensor.std(dim=0)
    else:
        # The (unbiased) std of a single sample is NaN; report zero spread instead
        # so the CSV and the error bars stay well defined.
        std_weights = torch.zeros_like(mean_weights)

    # --- Step 2: Create and save summary DataFrame ---
    if feature_names is None:
        feature_names = [f'feature_{i}' for i in range(n_features)]
    elif len(feature_names) != n_features:
        raise ValueError(
            f"Got {len(feature_names)} feature names for {n_features} attention features."
        )

    summary_df = pd.DataFrame({
        'feature': feature_names,
        'mean_attention': mean_weights.numpy(),
        'std_attention': std_weights.numpy()
    }).sort_values('mean_attention', ascending=False)

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    summary_path = save_dir_path / "attention_summary.csv"
    summary_df.to_csv(summary_path, index=False)
    _LOGGER.info(f"📝 Attention summary data saved as '{summary_path.name}'")

    # --- Step 3: Create and save the plot ---
    fig, ax = plt.subplots(figsize=(10, 8), dpi=100)
    try:
        # Ascending order puts the most attended feature at the top of a barh plot.
        plot_df = summary_df.sort_values('mean_attention', ascending=True)

        # Create horizontal bar plot with error bars
        ax.barh(
            y=plot_df['feature'],
            width=plot_df['mean_attention'],
            xerr=plot_df['std_attention'],
            align='center',
            alpha=0.7,
            ecolor='grey',
            capsize=3,
            color='cornflowerblue'
        )

        ax.set_title('Global Feature Importance')
        ax.set_xlabel('Average Attention Weight')
        ax.set_ylabel('Feature')
        ax.grid(axis='x', linestyle='--', alpha=0.6)
        fig.tight_layout()

        plot_path = save_dir_path / "attention_importance.svg"
        fig.savefig(plot_path)
        _LOGGER.info(f"📊 Attention importance plot saved as '{plot_path.name}'")
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)
|
|
353
421
|
|
|
354
422
|
|
|
355
423
|
def info():
|
|
@@ -0,0 +1,296 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
import seaborn as sns
|
|
5
|
+
import torch
|
|
6
|
+
import shap
|
|
7
|
+
from sklearn.metrics import (
|
|
8
|
+
classification_report,
|
|
9
|
+
ConfusionMatrixDisplay,
|
|
10
|
+
roc_curve,
|
|
11
|
+
roc_auc_score,
|
|
12
|
+
precision_recall_curve,
|
|
13
|
+
average_precision_score,
|
|
14
|
+
mean_squared_error,
|
|
15
|
+
mean_absolute_error,
|
|
16
|
+
r2_score,
|
|
17
|
+
median_absolute_error,
|
|
18
|
+
hamming_loss,
|
|
19
|
+
jaccard_score
|
|
20
|
+
)
|
|
21
|
+
from pathlib import Path
|
|
22
|
+
from typing import Union, List, Optional
|
|
23
|
+
|
|
24
|
+
from .path_manager import make_fullpath, sanitize_filename
|
|
25
|
+
from ._logger import _LOGGER
|
|
26
|
+
from ._script_info import _script_info
|
|
27
|
+
|
|
28
|
+
# Public API of this module: the names exported by a star-import.
__all__ = [
    "multi_target_regression_metrics",     # per-target regression metrics and plots
    "multi_label_classification_metrics",  # per-label classification metrics and plots
    "multi_target_shap_summary_plot",      # per-target SHAP summary plots
]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def multi_target_regression_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    target_names: List[str],
    save_dir: Union[str, Path]
):
    """
    Calculates and saves regression metrics for each target individually.

    For each target column this function saves a residual plot
    (``residual_plot_<name>.svg``) and a true vs. predicted plot
    (``true_vs_predicted_plot_<name>.svg``), where ``<name>`` is the sanitized
    target name. It also saves a single CSV file (``regression_report_multi.csv``)
    containing the key metrics (RMSE, MAE, R², MedAE) for all targets.

    Args:
        y_true (np.ndarray): Ground truth values, shape (n_samples, n_targets).
        y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
        target_names (List[str]): A list of names for the target variables.
        save_dir (str | Path): Directory to save plots and the report.

    Raises:
        ValueError: If the inputs are not 2D, their shapes differ, the number of
            target names does not match the number of columns, or there are no samples.
    """
    if y_true.ndim != 2 or y_pred.ndim != 2:
        raise ValueError("y_true and y_pred must be 2D arrays for multi-target regression.")
    if y_true.shape != y_pred.shape:
        raise ValueError("Shapes of y_true and y_pred must match.")
    if y_true.shape[1] != len(target_names):
        raise ValueError("Number of target names must match the number of columns in y_true.")
    if y_true.shape[0] == 0:
        # Fail fast with a clear message before touching the filesystem;
        # the metric functions cannot be evaluated on empty input.
        raise ValueError("y_true and y_pred must contain at least one sample.")

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    metrics_summary = []

    _LOGGER.info("--- Multi-Target Regression Evaluation ---")

    for i, name in enumerate(target_names):
        _LOGGER.info(f" -> Evaluating target: '{name}'")
        true_i = y_true[:, i]
        pred_i = y_pred[:, i]
        sanitized_name = sanitize_filename(name)

        # --- Calculate Metrics ---
        rmse = np.sqrt(mean_squared_error(true_i, pred_i))
        mae = mean_absolute_error(true_i, pred_i)
        r2 = r2_score(true_i, pred_i)
        medae = median_absolute_error(true_i, pred_i)
        metrics_summary.append({
            'target': name,
            'rmse': rmse,
            'mae': mae,
            'r2_score': r2,
            'median_abs_error': medae
        })

        # --- Save Residual Plot ---
        residuals = true_i - pred_i
        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
        ax_res.scatter(pred_i, residuals, alpha=0.6, edgecolors='k', s=50)
        ax_res.axhline(0, color='red', linestyle='--')
        ax_res.set_xlabel("Predicted Values")
        ax_res.set_ylabel("Residuals (True - Predicted)")
        ax_res.set_title(f"Residual Plot for '{name}'")
        ax_res.grid(True, linestyle='--', alpha=0.6)
        fig_res.tight_layout()
        res_path = save_dir_path / f"residual_plot_{sanitized_name}.svg"
        fig_res.savefig(res_path)
        plt.close(fig_res)

        # --- Save True vs. Predicted Plot ---
        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
        ax_tvp.scatter(true_i, pred_i, alpha=0.6, edgecolors='k', s=50)
        # Span the y = x reference line over both true and predicted values so
        # predictions outside the ground-truth range are still covered by it.
        line_min = min(true_i.min(), pred_i.min())
        line_max = max(true_i.max(), pred_i.max())
        ax_tvp.plot([line_min, line_max], [line_min, line_max], 'k--', lw=2)
        ax_tvp.set_xlabel('True Values')
        ax_tvp.set_ylabel('Predicted Values')
        ax_tvp.set_title(f'True vs. Predicted Values for "{name}"')
        ax_tvp.grid(True, linestyle='--', alpha=0.6)
        fig_tvp.tight_layout()
        tvp_path = save_dir_path / f"true_vs_predicted_plot_{sanitized_name}.svg"
        fig_tvp.savefig(tvp_path)
        plt.close(fig_tvp)

    # --- Save Summary Report ---
    summary_df = pd.DataFrame(metrics_summary)
    report_path = save_dir_path / "regression_report_multi.csv"
    summary_df.to_csv(report_path, index=False)
    _LOGGER.info(f"✅ Full regression report saved to '{report_path.name}'")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def multi_label_classification_metrics(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    target_names: List[str],
    save_dir: Union[str, Path],
    threshold: float = 0.5
):
    """
    Calculates and saves classification metrics for each label individually.

    This function first computes overall multi-label metrics (Hamming Loss, Jaccard Score)
    and writes them to ``classification_report_overall.txt``. It then iterates through
    each label and saves a classification report, a confusion matrix, a ROC curve and a
    Precision-Recall curve. ROC and PR curves are undefined when a label's ground truth
    contains only one class; they are skipped (with a warning) for such labels, and the
    remaining labels are still evaluated.

    Args:
        y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
        y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
        target_names (List[str]): A list of names for the labels.
        save_dir (str | Path): Directory to save plots and reports.
        threshold (float): The probability threshold to convert probabilities into
            binary predictions for metrics like the confusion matrix.

    Raises:
        ValueError: If the inputs are not 2D, their shapes differ, the number of
            target names does not match the number of columns, or there are no samples.
    """
    if y_true.ndim != 2 or y_prob.ndim != 2:
        raise ValueError("y_true and y_prob must be 2D arrays for multi-label classification.")
    if y_true.shape != y_prob.shape:
        raise ValueError("Shapes of y_true and y_prob must match.")
    if y_true.shape[1] != len(target_names):
        raise ValueError("Number of target names must match the number of columns in y_true.")
    if y_true.shape[0] == 0:
        # Fail fast before creating any output; no metric is defined on empty input.
        raise ValueError("y_true and y_prob must contain at least one sample.")

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

    # Generate binary predictions from probabilities
    y_pred = (y_prob >= threshold).astype(int)

    _LOGGER.info("--- Multi-Label Classification Evaluation ---")

    # --- Calculate and Save Overall Metrics ---
    h_loss = hamming_loss(y_true, y_pred)
    j_score_micro = jaccard_score(y_true, y_pred, average='micro')
    j_score_macro = jaccard_score(y_true, y_pred, average='macro')

    overall_report = (
        f"Overall Multi-Label Metrics (Threshold = {threshold}):\n"
        f"--------------------------------------------------\n"
        f"Hamming Loss: {h_loss:.4f}\n"
        f"Jaccard Score (micro): {j_score_micro:.4f}\n"
        f"Jaccard Score (macro): {j_score_macro:.4f}\n"
        f"--------------------------------------------------\n"
    )
    _LOGGER.info(overall_report)
    overall_report_path = save_dir_path / "classification_report_overall.txt"
    overall_report_path.write_text(overall_report)

    # --- Per-Label Metrics and Plots ---
    for i, name in enumerate(target_names):
        _LOGGER.info(f" -> Evaluating label: '{name}'")
        true_i = y_true[:, i]
        pred_i = y_pred[:, i]
        prob_i = y_prob[:, i]
        sanitized_name = sanitize_filename(name)

        # --- Save Classification Report for the label ---
        report_text = classification_report(true_i, pred_i)
        report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
        report_path.write_text(report_text) # type: ignore

        # --- Save Confusion Matrix ---
        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
        ConfusionMatrixDisplay.from_predictions(true_i, pred_i, cmap="Blues", ax=ax_cm)
        ax_cm.set_title(f"Confusion Matrix for '{name}'")
        cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
        fig_cm.savefig(cm_path)
        plt.close(fig_cm)

        # roc_auc_score raises ValueError when y_true holds a single class, and the
        # PR curve is equally meaningless; skip both rather than abort all labels.
        if np.unique(true_i).size < 2:
            _LOGGER.warning(
                f"⚠️ Label '{name}' has only one class in y_true. Skipping ROC and PR curves."
            )
            continue

        # --- Save ROC Curve ---
        fpr, tpr, _ = roc_curve(true_i, prob_i)
        auc = roc_auc_score(true_i, prob_i)
        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
        ax_roc.plot([0, 1], [0, 1], 'k--')
        ax_roc.set_title(f'ROC Curve for "{name}"')
        ax_roc.set_xlabel('False Positive Rate')
        ax_roc.set_ylabel('True Positive Rate')
        ax_roc.legend(loc='lower right')
        ax_roc.grid(True, linestyle='--', alpha=0.6)
        roc_path = save_dir_path / f"roc_curve_{sanitized_name}.svg"
        fig_roc.savefig(roc_path)
        plt.close(fig_roc)

        # --- Save Precision-Recall Curve ---
        precision, recall, _ = precision_recall_curve(true_i, prob_i)
        ap_score = average_precision_score(true_i, prob_i)
        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
        ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
        ax_pr.set_xlabel('Recall')
        ax_pr.set_ylabel('Precision')
        ax_pr.legend(loc='lower left')
        ax_pr.grid(True, linestyle='--', alpha=0.6)
        pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
        fig_pr.savefig(pr_path)
        plt.close(fig_pr)

    _LOGGER.info(f"✅ All individual label reports and plots saved to '{save_dir_path.name}'")
|
|
219
|
+
|
|
220
|
+
|
|
221
|
+
def _shap_input_to_numpy(data: Union[torch.Tensor, np.ndarray]) -> np.ndarray:
    """Convert a tensor (on any device, with or without grad) or array-like to a NumPy array."""
    if isinstance(data, torch.Tensor):
        # Tensor.numpy() rejects CUDA tensors and tensors that require grad.
        return data.detach().cpu().numpy()
    return np.asarray(data)


def _select_target_shap_values(shap_values, index: int) -> np.ndarray:
    """Return the (n_samples, n_features) SHAP matrix for model output ``index``.

    Depending on the installed SHAP version, ``KernelExplainer.shap_values`` returns
    either a list with one array per output, or a single array shaped
    (n_samples, n_features, n_outputs). Both layouts are handled here.
    """
    if isinstance(shap_values, list):
        return np.asarray(shap_values[index])
    values = np.asarray(shap_values)
    if values.ndim == 3:
        return values[:, :, index]
    if values.ndim == 2 and index == 0:
        # Single-output model: there is only one SHAP matrix.
        return values
    raise ValueError(f"Unexpected SHAP value shape {values.shape} for output index {index}.")


def multi_target_shap_summary_plot(
    model: torch.nn.Module,
    background_data: Union[torch.Tensor, np.ndarray],
    instances_to_explain: Union[torch.Tensor, np.ndarray],
    feature_names: List[str],
    target_names: List[str],
    save_dir: Union[str, Path]
):
    """
    Calculates SHAP values for a multi-target model and saves summary plots for each target.

    For every target, a bar plot (``shap_bar_plot_<name>.svg``) and a dot plot
    (``shap_dot_plot_<name>.svg``) are saved to ``save_dir``. If either input contains
    NaN values, an error is logged and nothing is computed. The model is put in eval
    mode and moved to the CPU. Matplotlib's interactive mode is turned off while the
    plots are generated and restored to its previous state afterwards.

    Args:
        model (torch.nn.Module): The trained PyTorch model.
        background_data (torch.Tensor | np.ndarray): A sample of data for the explainer background.
        instances_to_explain (torch.Tensor | np.ndarray): The specific data instances to explain.
        feature_names (List[str]): Names of the features for plot labeling.
        target_names (List[str]): Names of the output targets.
        save_dir (str | Path): Directory to save SHAP artifacts.
    """
    # Convert all data to numpy
    background_data_np = _shap_input_to_numpy(background_data)
    instances_to_explain_np = _shap_input_to_numpy(instances_to_explain)

    if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
        _LOGGER.error("❌ Input data for SHAP contains NaN values. Aborting explanation.")
        return

    _LOGGER.info("\n--- Multi-Target SHAP Value Explanation ---")
    model.eval()
    model.cpu()

    # 1. Summarize the background data. k-means cannot produce more clusters than
    #    there are rows, so cap the number of centroids for small backgrounds.
    n_clusters = min(30, background_data_np.shape[0])
    background_summary = shap.kmeans(background_data_np, n_clusters)

    # 2. Define a prediction function wrapper for the multi-target model.
    def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
        x_torch = torch.from_numpy(x_np).float()
        with torch.no_grad():
            output = model(x_torch)
        return output.cpu().numpy()

    # 3. Create the KernelExplainer.
    explainer = shap.KernelExplainer(prediction_wrapper, background_summary)

    _LOGGER.info("Calculating SHAP values with KernelExplainer...")
    shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

    # Remember the caller's interactive-mode setting so it can be restored even if
    # plotting fails, instead of unconditionally switching interactive mode on.
    was_interactive = plt.isinteractive()
    plt.ioff()
    try:
        # 4. Iterate through each target's SHAP values and generate plots.
        for i, target_name in enumerate(target_names):
            _LOGGER.info(f" -> Generating SHAP plots for target: '{target_name}'")
            shap_values_for_target = _select_target_shap_values(shap_values, i)
            sanitized_target_name = sanitize_filename(target_name)

            # Save Bar Plot for the target
            shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
            plt.title(f"SHAP Feature Importance for '{target_name}'")
            plt.tight_layout()
            bar_path = save_dir_path / f"shap_bar_plot_{sanitized_target_name}.svg"
            plt.savefig(bar_path)
            plt.close()

            # Save Dot Plot for the target
            shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
            plt.title(f"SHAP Feature Importance for '{target_name}'")
            plt.tight_layout()
            dot_path = save_dir_path / f"shap_dot_plot_{sanitized_target_name}.svg"
            plt.savefig(dot_path)
            plt.close()
    finally:
        if was_interactive:
            plt.ion()

    _LOGGER.info(f"✅ All SHAP plots saved to '{save_dir_path.name}'")
|
|
296
|
+
|