dragon-ml-toolbox 6.4.1__py3-none-any.whl → 8.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of dragon-ml-toolbox might be problematic.

ml_tools/ML_evaluation.py CHANGED
@@ -20,7 +20,7 @@ import shap
 from pathlib import Path
 from .path_manager import make_fullpath
 from ._logger import _LOGGER
-from typing import Union, Optional
+from typing import Union, Optional, List
 from ._script_info import _script_info
 
 
@@ -28,7 +28,8 @@ __all__ = [
     "plot_losses",
     "classification_metrics",
     "regression_metrics",
-    "shap_summary_plot"
+    "shap_summary_plot",
+    "plot_attention_importance"
 ]
 
 
@@ -248,8 +249,11 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     plt.close(fig_hist)
 
 
-def shap_summary_plot(model, background_data: Union[torch.Tensor,np.ndarray], instances_to_explain: Union[torch.Tensor,np.ndarray],
-                      feature_names: Optional[list[str]]=None, save_dir: Optional[Union[str, Path]] = None):
+def shap_summary_plot(model,
+                      background_data: Union[torch.Tensor,np.ndarray],
+                      instances_to_explain: Union[torch.Tensor,np.ndarray],
+                      feature_names: Optional[list[str]],
+                      save_dir: Union[str, Path]):
     """
     Calculates SHAP values and saves summary plots and data.
 
@@ -258,7 +262,7 @@ def shap_summary_plot(model, background_data: Union[torch.Tensor,np.ndarray], in
         background_data (torch.Tensor): A sample of data for the explainer background.
         instances_to_explain (torch.Tensor): The specific data instances to explain.
         feature_names (list of str | None): Names of the features for plot labeling.
-        save_dir (str | Path | None): Directory to save SHAP artifacts. If None, dot plot is shown.
+        save_dir (str | Path): Directory to save SHAP artifacts.
     """
     # everything to numpy
     if isinstance(background_data, np.ndarray):
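
Note the breaking API change above: feature_names loses its default and save_dir is now required, so SHAP artifacts are always written to disk instead of being optionally shown. A minimal usage sketch against the 8.0.0 signature (the toy model, data, and paths are illustrative, not from the package):

    import torch
    import torch.nn as nn
    from ml_tools.ML_evaluation import shap_summary_plot

    # Illustrative stand-ins: any trained model with tensor input works.
    model = nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 1))
    X = torch.randn(200, 5)

    shap_summary_plot(
        model,
        background_data=X[:100],          # sample used to build the explainer background
        instances_to_explain=X[100:150],  # rows whose predictions get explained
        feature_names=[f"f{i}" for i in range(5)],  # may still be None for auto-naming
        save_dir="shap_outputs",          # required in 8.0.0; was Optional (None = show plot)
    )
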
@@ -301,55 +305,119 @@ def shap_summary_plot(model, background_data: Union[torch.Tensor,np.ndarray], in
     print("Calculating SHAP values with KernelExplainer...")
     shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
 
-    if save_dir:
-        save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
-        plt.ioff()
-
-        # Save Bar Plot
-        bar_path = save_dir_path / "shap_bar_plot.svg"
-        shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
-        ax = plt.gca()
-        ax.set_xlabel("SHAP Value Impact", labelpad=10)
-        plt.title("SHAP Feature Importance")
-        plt.tight_layout()
-        plt.savefig(bar_path)
-        _LOGGER.info(f"📊 SHAP bar plot saved as '{bar_path.name}'")
-        plt.close()
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    plt.ioff()
+
+    # Save Bar Plot
+    bar_path = save_dir_path / "shap_bar_plot.svg"
+    shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
+    ax = plt.gca()
+    ax.set_xlabel("SHAP Value Impact", labelpad=10)
+    plt.title("SHAP Feature Importance")
+    plt.tight_layout()
+    plt.savefig(bar_path)
+    _LOGGER.info(f"📊 SHAP bar plot saved as '{bar_path.name}'")
+    plt.close()
 
-        # Save Dot Plot
-        dot_path = save_dir_path / "shap_dot_plot.svg"
-        shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
-        ax = plt.gca()
-        ax.set_xlabel("SHAP Value Impact", labelpad=10)
-        cb = plt.gcf().axes[-1]
-        cb.set_ylabel("", size=1)
-        plt.title("SHAP Feature Importance")
-        plt.tight_layout()
-        plt.savefig(dot_path)
-        _LOGGER.info(f"📊 SHAP dot plot saved as '{dot_path.name}'")
-        plt.close()
+    # Save Dot Plot
+    dot_path = save_dir_path / "shap_dot_plot.svg"
+    shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
+    ax = plt.gca()
+    ax.set_xlabel("SHAP Value Impact", labelpad=10)
+    cb = plt.gcf().axes[-1]
+    cb.set_ylabel("", size=1)
+    plt.title("SHAP Feature Importance")
+    plt.tight_layout()
+    plt.savefig(dot_path)
+    _LOGGER.info(f"📊 SHAP dot plot saved as '{dot_path.name}'")
+    plt.close()
 
-        # Save Summary Data to CSV
-        summary_path = save_dir_path / "shap_summary.csv"
-        # Ensure the array is 1D before creating the DataFrame
-        mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
-
-        if feature_names is None:
-            feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
-
-        summary_df = pd.DataFrame({
-            'feature': feature_names,
-            'mean_abs_shap_value': mean_abs_shap
-        }).sort_values('mean_abs_shap_value', ascending=False)
-
-        summary_df.to_csv(summary_path, index=False)
-
-        _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
-        plt.ion()
+    # Save Summary Data to CSV
+    summary_path = save_dir_path / "shap_summary.csv"
+    # Ensure the array is 1D before creating the DataFrame
+    mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
+
+    if feature_names is None:
+        feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
 
-    else:
-        _LOGGER.info("No save directory provided. Displaying SHAP dot plot.")
-        shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot")
+    summary_df = pd.DataFrame({
+        'feature': feature_names,
+        'mean_abs_shap_value': mean_abs_shap
+    }).sort_values('mean_abs_shap_value', ascending=False)
+
+    summary_df.to_csv(summary_path, index=False)
+
+    _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
+    plt.ion()
+
+
+def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path]):
+    """
+    Aggregates attention weights and plots global feature importance.
+
+    The plot shows the mean attention for each feature as a bar, with the
+    standard deviation represented by error bars.
+
+    Args:
+        weights (List[torch.Tensor]): A list of attention weight tensors from each batch.
+        feature_names (List[str] | None): Names of the features for plot labeling.
+        save_dir (str | Path): Directory to save the plot and summary CSV.
+    """
+    if not weights:
+        _LOGGER.warning("⚠️ Attention weights list is empty. Skipping importance plot.")
+        return
+
+    # --- Step 1: Aggregate data ---
+    # Concatenate the list of tensors into a single large tensor
+    full_weights_tensor = torch.cat(weights, dim=0)
+
+    # Calculate mean and std dev across the batch dimension (dim=0)
+    mean_weights = full_weights_tensor.mean(dim=0)
+    std_weights = full_weights_tensor.std(dim=0)
+
+    # --- Step 2: Create and save summary DataFrame ---
+    if feature_names is None:
+        feature_names = [f'feature_{i}' for i in range(len(mean_weights))]
+
+    summary_df = pd.DataFrame({
+        'feature': feature_names,
+        'mean_attention': mean_weights.numpy(),
+        'std_attention': std_weights.numpy()
+    }).sort_values('mean_attention', ascending=False)
+
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    summary_path = save_dir_path / "attention_summary.csv"
+    summary_df.to_csv(summary_path, index=False)
+    _LOGGER.info(f"📝 Attention summary data saved as '{summary_path.name}'")
+
+    # --- Step 3: Create and save the plot ---
+    plt.figure(figsize=(10, 8), dpi=100)
+
+    # Sort for plotting
+    plot_df = summary_df.sort_values('mean_attention', ascending=True)
+
+    # Create horizontal bar plot with error bars
+    plt.barh(
+        y=plot_df['feature'],
+        width=plot_df['mean_attention'],
+        xerr=plot_df['std_attention'],
+        align='center',
+        alpha=0.7,
+        ecolor='grey',
+        capsize=3,
+        color='cornflowerblue'
+    )
+
+    plt.title('Global Feature Importance')
+    plt.xlabel('Average Attention Weight')
+    plt.ylabel('Feature')
+    plt.grid(axis='x', linestyle='--', alpha=0.6)
+    plt.tight_layout()
+
+    plot_path = save_dir_path / "attention_importance.svg"
+    plt.savefig(plot_path)
+    _LOGGER.info(f"📊 Attention importance plot saved as '{plot_path.name}'")
+    plt.close()
 
 
 def info():
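
A sketch of how the new plot_attention_importance helper might be fed; the random tensors stand in for per-batch attention outputs of shape (batch_size, n_features), and all names are illustrative:

    import torch
    from ml_tools.ML_evaluation import plot_attention_importance

    n_features = 8
    # Collected during evaluation, one tensor per batch; softmax keeps rows summing to 1.
    weights = [torch.softmax(torch.randn(32, n_features), dim=1) for _ in range(10)]

    plot_attention_importance(
        weights=weights,
        feature_names=[f"f{i}" for i in range(n_features)],  # or None for feature_0, feature_1, ...
        save_dir="attention_outputs",  # writes attention_summary.csv and attention_importance.svg
    )
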
@@ -0,0 +1,296 @@
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+import torch
+import shap
+from sklearn.metrics import (
+    classification_report,
+    ConfusionMatrixDisplay,
+    roc_curve,
+    roc_auc_score,
+    precision_recall_curve,
+    average_precision_score,
+    mean_squared_error,
+    mean_absolute_error,
+    r2_score,
+    median_absolute_error,
+    hamming_loss,
+    jaccard_score
+)
+from pathlib import Path
+from typing import Union, List, Optional
+
+from .path_manager import make_fullpath, sanitize_filename
+from ._logger import _LOGGER
+from ._script_info import _script_info
+
+__all__ = [
+    "multi_target_regression_metrics",
+    "multi_label_classification_metrics",
+    "multi_target_shap_summary_plot",
+]
+
+
+def multi_target_regression_metrics(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    target_names: List[str],
+    save_dir: Union[str, Path]
+):
+    """
+    Calculates and saves regression metrics for each target individually.
+
+    For each target, this function saves a residual plot and a true vs. predicted plot.
+    It also saves a single CSV file containing the key metrics (RMSE, MAE, R², MedAE)
+    for all targets.
+
+    Args:
+        y_true (np.ndarray): Ground truth values, shape (n_samples, n_targets).
+        y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
+        target_names (List[str]): A list of names for the target variables.
+        save_dir (str | Path): Directory to save plots and the report.
+    """
+    if y_true.ndim != 2 or y_pred.ndim != 2:
+        raise ValueError("y_true and y_pred must be 2D arrays for multi-target regression.")
+    if y_true.shape != y_pred.shape:
+        raise ValueError("Shapes of y_true and y_pred must match.")
+    if y_true.shape[1] != len(target_names):
+        raise ValueError("Number of target names must match the number of columns in y_true.")
+
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    metrics_summary = []
+
+    _LOGGER.info("--- Multi-Target Regression Evaluation ---")
+
+    for i, name in enumerate(target_names):
+        _LOGGER.info(f"  -> Evaluating target: '{name}'")
+        true_i = y_true[:, i]
+        pred_i = y_pred[:, i]
+        sanitized_name = sanitize_filename(name)
+
+        # --- Calculate Metrics ---
+        rmse = np.sqrt(mean_squared_error(true_i, pred_i))
+        mae = mean_absolute_error(true_i, pred_i)
+        r2 = r2_score(true_i, pred_i)
+        medae = median_absolute_error(true_i, pred_i)
+        metrics_summary.append({
+            'target': name,
+            'rmse': rmse,
+            'mae': mae,
+            'r2_score': r2,
+            'median_abs_error': medae
+        })
+
+        # --- Save Residual Plot ---
+        residuals = true_i - pred_i
+        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
+        ax_res.scatter(pred_i, residuals, alpha=0.6, edgecolors='k', s=50)
+        ax_res.axhline(0, color='red', linestyle='--')
+        ax_res.set_xlabel("Predicted Values")
+        ax_res.set_ylabel("Residuals (True - Predicted)")
+        ax_res.set_title(f"Residual Plot for '{name}'")
+        ax_res.grid(True, linestyle='--', alpha=0.6)
+        plt.tight_layout()
+        res_path = save_dir_path / f"residual_plot_{sanitized_name}.svg"
+        plt.savefig(res_path)
+        plt.close(fig_res)
+
+        # --- Save True vs. Predicted Plot ---
+        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
+        ax_tvp.scatter(true_i, pred_i, alpha=0.6, edgecolors='k', s=50)
+        ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()], 'k--', lw=2)
+        ax_tvp.set_xlabel('True Values')
+        ax_tvp.set_ylabel('Predicted Values')
+        ax_tvp.set_title(f'True vs. Predicted Values for "{name}"')
+        ax_tvp.grid(True, linestyle='--', alpha=0.6)
+        plt.tight_layout()
+        tvp_path = save_dir_path / f"true_vs_predicted_plot_{sanitized_name}.svg"
+        plt.savefig(tvp_path)
+        plt.close(fig_tvp)
+
+    # --- Save Summary Report ---
+    summary_df = pd.DataFrame(metrics_summary)
+    report_path = save_dir_path / "regression_report_multi.csv"
+    summary_df.to_csv(report_path, index=False)
+    _LOGGER.info(f"✅ Full regression report saved to '{report_path.name}'")
+
+
+def multi_label_classification_metrics(
+    y_true: np.ndarray,
+    y_prob: np.ndarray,
+    target_names: List[str],
+    save_dir: Union[str, Path],
+    threshold: float = 0.5
+):
+    """
+    Calculates and saves classification metrics for each label individually.
+
+    This function first computes overall multi-label metrics (Hamming Loss, Jaccard Score)
+    and then iterates through each label to generate and save individual reports,
+    confusion matrices, ROC curves, and Precision-Recall curves.
+
+    Args:
+        y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
+        y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
+        target_names (List[str]): A list of names for the labels.
+        save_dir (str | Path): Directory to save plots and reports.
+        threshold (float): The probability threshold to convert probabilities into
+                           binary predictions for metrics like the confusion matrix.
+    """
+    if y_true.ndim != 2 or y_prob.ndim != 2:
+        raise ValueError("y_true and y_prob must be 2D arrays for multi-label classification.")
+    if y_true.shape != y_prob.shape:
+        raise ValueError("Shapes of y_true and y_prob must match.")
+    if y_true.shape[1] != len(target_names):
+        raise ValueError("Number of target names must match the number of columns in y_true.")
+
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+    # Generate binary predictions from probabilities
+    y_pred = (y_prob >= threshold).astype(int)
+
+    _LOGGER.info("--- Multi-Label Classification Evaluation ---")
+
+    # --- Calculate and Save Overall Metrics ---
+    h_loss = hamming_loss(y_true, y_pred)
+    j_score_micro = jaccard_score(y_true, y_pred, average='micro')
+    j_score_macro = jaccard_score(y_true, y_pred, average='macro')
+
+    overall_report = (
+        f"Overall Multi-Label Metrics (Threshold = {threshold}):\n"
+        f"--------------------------------------------------\n"
+        f"Hamming Loss: {h_loss:.4f}\n"
+        f"Jaccard Score (micro): {j_score_micro:.4f}\n"
+        f"Jaccard Score (macro): {j_score_macro:.4f}\n"
+        f"--------------------------------------------------\n"
+    )
+    _LOGGER.info(overall_report)
+    overall_report_path = save_dir_path / "classification_report_overall.txt"
+    overall_report_path.write_text(overall_report)
+
+    # --- Per-Label Metrics and Plots ---
+    for i, name in enumerate(target_names):
+        _LOGGER.info(f"  -> Evaluating label: '{name}'")
+        true_i = y_true[:, i]
+        pred_i = y_pred[:, i]
+        prob_i = y_prob[:, i]
+        sanitized_name = sanitize_filename(name)
+
+        # --- Save Classification Report for the label ---
+        report_text = classification_report(true_i, pred_i)
+        report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
+        report_path.write_text(report_text) # type: ignore
+
+        # --- Save Confusion Matrix ---
+        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
+        ConfusionMatrixDisplay.from_predictions(true_i, pred_i, cmap="Blues", ax=ax_cm)
+        ax_cm.set_title(f"Confusion Matrix for '{name}'")
+        cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
+        plt.savefig(cm_path)
+        plt.close(fig_cm)
+
+        # --- Save ROC Curve ---
+        fpr, tpr, _ = roc_curve(true_i, prob_i)
+        auc = roc_auc_score(true_i, prob_i)
+        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
+        ax_roc.plot([0, 1], [0, 1], 'k--')
+        ax_roc.set_title(f'ROC Curve for "{name}"')
+        ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
+        ax_roc.legend(loc='lower right'); ax_roc.grid(True, linestyle='--', alpha=0.6)
+        roc_path = save_dir_path / f"roc_curve_{sanitized_name}.svg"
+        plt.savefig(roc_path)
+        plt.close(fig_roc)
+
+        # --- Save Precision-Recall Curve ---
+        precision, recall, _ = precision_recall_curve(true_i, prob_i)
+        ap_score = average_precision_score(true_i, prob_i)
+        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
+        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
+        ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
+        ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
+        ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)
+        pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
+        plt.savefig(pr_path)
+        plt.close(fig_pr)
+
+    _LOGGER.info(f"✅ All individual label reports and plots saved to '{save_dir_path.name}'")
+
+
+def multi_target_shap_summary_plot(
+    model: torch.nn.Module,
+    background_data: Union[torch.Tensor, np.ndarray],
+    instances_to_explain: Union[torch.Tensor, np.ndarray],
+    feature_names: List[str],
+    target_names: List[str],
+    save_dir: Union[str, Path]
+):
+    """
+    Calculates SHAP values for a multi-target model and saves summary plots for each target.
+
+    Args:
+        model (torch.nn.Module): The trained PyTorch model.
+        background_data (torch.Tensor | np.ndarray): A sample of data for the explainer background.
+        instances_to_explain (torch.Tensor | np.ndarray): The specific data instances to explain.
+        feature_names (List[str]): Names of the features for plot labeling.
+        target_names (List[str]): Names of the output targets.
+        save_dir (str | Path): Directory to save SHAP artifacts.
+    """
+    # Convert all data to numpy
+    background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
+    instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
+
+    if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
+        _LOGGER.error("❌ Input data for SHAP contains NaN values. Aborting explanation.")
+        return
+
+    _LOGGER.info("\n--- Multi-Target SHAP Value Explanation ---")
+    model.eval()
+    model.cpu()
+
+    # 1. Summarize the background data.
+    background_summary = shap.kmeans(background_data_np, 30)
+
+    # 2. Define a prediction function wrapper for the multi-target model.
+    def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
+        x_torch = torch.from_numpy(x_np).float()
+        with torch.no_grad():
+            output = model(x_torch)
+        return output.cpu().numpy()
+
+    # 3. Create the KernelExplainer.
+    explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
+
+    _LOGGER.info("Calculating SHAP values with KernelExplainer...")
+    # For multi-output models, shap_values is a list of arrays.
+    shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
+
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    plt.ioff()
+
+    # 4. Iterate through each target's SHAP values and generate plots.
+    for i, target_name in enumerate(target_names):
+        _LOGGER.info(f"  -> Generating SHAP plots for target: '{target_name}'")
+        shap_values_for_target = shap_values_list[i]
+        sanitized_target_name = sanitize_filename(target_name)
+
+        # Save Bar Plot for the target
+        shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
+        plt.title(f"SHAP Feature Importance for '{target_name}'")
+        plt.tight_layout()
+        bar_path = save_dir_path / f"shap_bar_plot_{sanitized_target_name}.svg"
+        plt.savefig(bar_path)
+        plt.close()
+
+        # Save Dot Plot for the target
+        shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
+        plt.title(f"SHAP Feature Importance for '{target_name}'")
+        plt.tight_layout()
+        dot_path = save_dir_path / f"shap_dot_plot_{sanitized_target_name}.svg"
+        plt.savefig(dot_path)
+        plt.close()
+
+    plt.ion()
+    _LOGGER.info(f"✅ All SHAP plots saved to '{save_dir_path.name}'")
+
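
The new module mirrors the single-target evaluation API for multi-output models. A hedged end-to-end sketch with synthetic arrays (the import path below is an assumption, since the added file's name is not shown in this diff):

    import numpy as np
    # Assumed module path; adjust to wherever the new file lands in the package.
    from ml_tools.ML_evaluation_multi import (
        multi_target_regression_metrics,
        multi_label_classification_metrics,
    )

    rng = np.random.default_rng(0)
    n_samples, n_targets = 200, 3
    names = ["target_a", "target_b", "target_c"]

    # Multi-target regression: per-target plots plus one regression_report_multi.csv.
    y_true = rng.normal(size=(n_samples, n_targets))
    y_pred = y_true + rng.normal(scale=0.1, size=(n_samples, n_targets))
    multi_target_regression_metrics(y_true, y_pred, names, save_dir="eval_multi")

    # Multi-label classification: probabilities are binarized at threshold (default 0.5).
    y_true_bin = rng.integers(0, 2, size=(n_samples, n_targets))
    y_prob = np.clip(y_true_bin + rng.normal(scale=0.2, size=(n_samples, n_targets)), 0, 1)
    multi_label_classification_metrics(y_true_bin, y_prob, names, save_dir="eval_multi")

multi_target_shap_summary_plot follows the same pattern but additionally takes the PyTorch model and target_names, writing one bar plot and one dot plot per target.
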