dragon-ml-toolbox 12.12.0__tar.gz → 12.13.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-12.12.0/dragon_ml_toolbox.egg-info → dragon_ml_toolbox-12.13.0}/PKG-INFO +1 -1
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0/dragon_ml_toolbox.egg-info}/PKG-INFO +1 -1
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ML_evaluation.py +90 -44
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ML_evaluation_multi.py +103 -32
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ML_trainer.py +15 -4
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/pyproject.toml +1 -1
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/LICENSE +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/README.md +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/dragon_ml_toolbox.egg-info/SOURCES.txt +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/dragon_ml_toolbox.egg-info/dependency_links.txt +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/dragon_ml_toolbox.egg-info/requires.txt +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/dragon_ml_toolbox.egg-info/top_level.txt +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ETL_cleaning.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ETL_engineering.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/GUI_tools.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/MICE_imputation.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ML_callbacks.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ML_datasetmaster.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ML_inference.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ML_models.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ML_optimization.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ML_scaler.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ML_simple_optimization.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ML_utilities.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/PSO_optimization.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/RNN_forecast.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/SQL.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/VIF_factor.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/__init__.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/_logger.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/_script_info.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/constants.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/custom_logger.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/data_exploration.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ensemble_evaluation.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ensemble_inference.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ensemble_learning.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/handle_excel.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/keys.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/math_utilities.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/optimization_tools.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/path_manager.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/serde.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/utilities.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/setup.cfg +0 -0
{dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ML_evaluation.py
CHANGED

```diff
@@ -18,7 +18,7 @@ from sklearn.metrics import (
 import torch
 import shap
 from pathlib import Path
-from typing import Union, Optional, List
+from typing import Union, Optional, List, Literal
 
 from .path_manager import make_fullpath
 from ._logger import _LOGGER
```
```diff
@@ -249,13 +249,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     plt.savefig(hist_path)
     _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
     plt.close(fig_hist)
-
+
 
 def shap_summary_plot(model,
                       background_data: Union[torch.Tensor,np.ndarray],
                       instances_to_explain: Union[torch.Tensor,np.ndarray],
                       feature_names: Optional[list[str]],
-                      save_dir: Union[str, Path]):
+                      save_dir: Union[str, Path],
+                      device: torch.device = torch.device('cpu'),
+                      explainer_type: Literal['deep', 'kernel'] = 'deep'):
     """
     Calculates SHAP values and saves summary plots and data.
 
```
```diff
@@ -265,48 +267,85 @@ def shap_summary_plot(model,
         instances_to_explain (torch.Tensor): The specific data instances to explain.
         feature_names (list of str | None): Names of the features for plot labeling.
         save_dir (str | Path): Directory to save SHAP artifacts.
+        device (torch.device): The torch device for SHAP calculations.
+        explainer_type (Literal['deep', 'kernel']): The explainer to use.
+            - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for
+              PyTorch models.
+            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
+              slow and memory-intensive.
     """
-    # everything to numpy
-    if isinstance(background_data, np.ndarray):
-        background_data_np = background_data
-    else:
-        background_data_np = background_data.numpy()
-
-    if isinstance(instances_to_explain, np.ndarray):
-        instances_to_explain_np = instances_to_explain
-    else:
-        instances_to_explain_np = instances_to_explain.numpy()
 
-
-    if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
-        _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
-        return
-
-    print("\n--- SHAP Value Explanation ---")
+    print(f"\n--- SHAP Value Explanation Using {explainer_type.upper()} Explainer ---")
 
     model.eval()
-    model.cpu()
-
-    # 1. Summarize the background data.
-    # Summarize the background data using k-means. 10-50 clusters is a good starting point.
-    background_summary = shap.kmeans(background_data_np, 30)
-
-    # 2. Define a prediction function wrapper that SHAP can use. It must take a numpy array and return a numpy array.
-    def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
-        # Convert numpy data to torch tensor
-        x_torch = torch.from_numpy(x_np).float()
-        with torch.no_grad():
-            # Get model output
-            output = model(x_torch)
-        # Return as numpy array
-        return output.cpu().numpy().flatten()
-
-    # 3. Create the KernelExplainer
-    explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
+    # model.cpu() # Run explanations on CPU
 
-
-
+    shap_values = None
+    instances_to_explain_np = None
+
+    if explainer_type == 'deep':
+        # --- 1. Use DeepExplainer (Preferred) ---
+
+        # Ensure data is torch.Tensor
+        if isinstance(background_data, np.ndarray):
+            background_data = torch.from_numpy(background_data).float()
+        if isinstance(instances_to_explain, np.ndarray):
+            instances_to_explain = torch.from_numpy(instances_to_explain).float()
+
+        if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        background_data = background_data.to(device)
+        instances_to_explain = instances_to_explain.to(device)
+
+        explainer = shap.DeepExplainer(model, background_data)
+        # print("Calculating SHAP values with DeepExplainer...")
+        shap_values = explainer.shap_values(instances_to_explain)
+        instances_to_explain_np = instances_to_explain.cpu().numpy()
+
+    elif explainer_type == 'kernel':
+        # --- 2. Use KernelExplainer (Slow Fallback) ---
+        _LOGGER.warning(
+            "Using KernelExplainer. This is memory-intensive and slow. "
+            "Consider reducing 'n_samples' if the process terminates unexpectedly."
+        )
+
+        # Ensure data is np.ndarray
+        if isinstance(background_data, torch.Tensor):
+            background_data_np = background_data.cpu().numpy()
+        else:
+            background_data_np = background_data
+
+        if isinstance(instances_to_explain, torch.Tensor):
+            instances_to_explain_np = instances_to_explain.cpu().numpy()
+        else:
+            instances_to_explain_np = instances_to_explain
+
+        if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        # Summarize background data
+        background_summary = shap.kmeans(background_data_np, 30)
+
+        def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
+            x_torch = torch.from_numpy(x_np).float().to(device)
+            with torch.no_grad():
+                output = model(x_torch)
+            # Return as numpy array
+            return output.cpu().numpy()
+
+        explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
+        # print("Calculating SHAP values with KernelExplainer...")
+        shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
+        # instances_to_explain_np is already set
 
+    else:
+        _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
+        raise ValueError()
+
+    # --- 3. Plotting and Saving ---
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     plt.ioff()
 
```
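For orientation, the kernel branch above follows the standard model-agnostic SHAP recipe: compress the background set with `shap.kmeans`, then hand `shap.KernelExplainer` a numpy-in/numpy-out closure around the model, so the explainer never touches torch tensors. A minimal sketch of that pattern, using a toy linear model and made-up sizes (nothing below is from the package itself):

```python
import numpy as np
import shap
import torch

# Toy stand-in for the package's trained PyTorch model.
model = torch.nn.Linear(4, 1)
model.eval()

background = np.random.rand(100, 4).astype(np.float32)  # reference distribution
instances = np.random.rand(5, 4).astype(np.float32)     # rows to explain

# 1. Compress the background set; KernelExplainer cost scales with its size.
background_summary = shap.kmeans(background, 10)

# 2. numpy-in / numpy-out closure around the torch model.
def predict(x_np: np.ndarray) -> np.ndarray:
    with torch.no_grad():
        return model(torch.from_numpy(x_np).float()).cpu().numpy()

# 3. Explain. l1_reg="aic" matches the setting used in the hunk above.
explainer = shap.KernelExplainer(predict, background_summary)
shap_values = explainer.shap_values(instances, l1_reg="aic")
```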
```diff
@@ -326,8 +365,9 @@ def shap_summary_plot(model,
     shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
-
-
+    if plt.gcf().axes and len(plt.gcf().axes) > 1:
+        cb = plt.gcf().axes[-1]
+        cb.set_ylabel("", size=1)
     plt.title("SHAP Feature Importance")
     plt.tight_layout()
     plt.savefig(dot_path)
```
```diff
@@ -337,8 +377,14 @@ def shap_summary_plot(model,
     # Save Summary Data to CSV
     shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
     summary_path = save_dir_path / shap_summary_filename
-
-
+
+    # Handle multi-class (list of arrays) vs. regression (single array)
+    if isinstance(shap_values, list):
+        mean_abs_shap = np.abs(np.stack(shap_values)).mean(axis=0).mean(axis=0)
+    else:
+        mean_abs_shap = np.abs(shap_values).mean(axis=0)
+
+    mean_abs_shap = mean_abs_shap.flatten()
 
     if feature_names is None:
         feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
```
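The new CSV logic must cope with the two shapes `shap_values` can take: a list of `(n_samples, n_features)` arrays for multi-class output, or a single array for regression. A quick numpy sanity check of the aggregation used above, with illustrative shapes:

```python
import numpy as np

n_samples, n_features, n_classes = 8, 5, 3

# Regression: one array -> mean |SHAP| per feature.
single = np.random.randn(n_samples, n_features)
assert np.abs(single).mean(axis=0).shape == (n_features,)

# Multi-class: stack the list to (n_classes, n_samples, n_features),
# then average over classes and samples to get one value per feature.
multi = [np.random.randn(n_samples, n_features) for _ in range(n_classes)]
mean_abs = np.abs(np.stack(multi)).mean(axis=0).mean(axis=0)
assert mean_abs.shape == (n_features,)
```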
```diff
@@ -351,7 +397,7 @@ def shap_summary_plot(model,
     summary_df.to_csv(summary_path, index=False)
 
     _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
-    plt.ion()
+    plt.ion()
 
 
 def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
```
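Putting the ML_evaluation.py hunks together, a caller on 12.13.0 would invoke the updated signature roughly as below. This is a hedged sketch: the model, data, feature names, and output directory are placeholders, and only `device` and `explainer_type` are new in this release.

```python
import numpy as np
import torch
from ml_tools.ML_evaluation import shap_summary_plot

# Placeholder model and data, for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
X_train = np.random.rand(200, 4).astype(np.float32)
X_test = np.random.rand(30, 4).astype(np.float32)

shap_summary_plot(
    model,
    background_data=X_train,
    instances_to_explain=X_test,
    feature_names=["f0", "f1", "f2", "f3"],
    save_dir="shap_artifacts",          # hypothetical output directory
    device=torch.device("cpu"),         # new in 12.13.0
    explainer_type="deep",              # new in 12.13.0; 'kernel' is the slow fallback
)
```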
{dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ML_evaluation_multi.py
CHANGED

```diff
@@ -19,11 +19,12 @@ from sklearn.metrics import (
     jaccard_score
 )
 from pathlib import Path
-from typing import Union, List
+from typing import Union, List, Literal
 
 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
+from .keys import SHAPKeys
 
 
 __all__ = [
```
```diff
@@ -231,10 +232,12 @@ def multi_target_shap_summary_plot(
     instances_to_explain: Union[torch.Tensor, np.ndarray],
     feature_names: List[str],
     target_names: List[str],
-    save_dir: Union[str, Path]
+    save_dir: Union[str, Path],
+    device: torch.device = torch.device('cpu'),
+    explainer_type: Literal['deep', 'kernel'] = 'deep'
 ):
     """
-    Calculates SHAP values for a multi-target model and saves summary plots for each target.
+    Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
 
     Args:
         model (torch.nn.Module): The trained PyTorch model.
```
```diff
@@ -243,40 +246,91 @@ def multi_target_shap_summary_plot(
     feature_names (List[str]): Names of the features for plot labeling.
     target_names (List[str]): Names of the output targets.
     save_dir (str | Path): Directory to save SHAP artifacts.
+    device (torch.device): The torch device for SHAP calculations.
+    explainer_type (Literal['deep', 'kernel']): The explainer to use.
+        - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient.
+        - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
     """
-
-    background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
-    instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
-
-    if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
-        _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
-        return
-
-    _LOGGER.info("--- Multi-Target SHAP Value Explanation ---")
+    _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
     model.eval()
-    model.cpu()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # model.cpu()
+
+    shap_values_list = None
+    instances_to_explain_np = None
+
+    if explainer_type == 'deep':
+        # --- 1. Use DeepExplainer (Preferred) ---
+
+        # Ensure data is torch.Tensor
+        if isinstance(background_data, np.ndarray):
+            background_data = torch.from_numpy(background_data).float()
+        if isinstance(instances_to_explain, np.ndarray):
+            instances_to_explain = torch.from_numpy(instances_to_explain).float()
+
+        if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        background_data = background_data.to(device)
+        instances_to_explain = instances_to_explain.to(device)
+
+        explainer = shap.DeepExplainer(model, background_data)
+        print("Calculating SHAP values with DeepExplainer...")
+        # DeepExplainer returns a list of arrays for multi-output models
+        shap_values_list = explainer.shap_values(instances_to_explain)
+        instances_to_explain_np = instances_to_explain.cpu().numpy()
+
+    elif explainer_type == 'kernel':
+        # --- 2. Use KernelExplainer (Slow Fallback) ---
+        _LOGGER.warning(
+            "Using KernelExplainer. This is memory-intensive and slow. "
+            "Consider reducing 'n_samples' if the process terminates."
+        )
+
+        # Convert all data to numpy
+        background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
+        instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
+
+        if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        background_summary = shap.kmeans(background_data_np, 30)
+
+        def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
+            x_torch = torch.from_numpy(x_np).float().to(device)
+            with torch.no_grad():
+                output = model(x_torch)
+            return output.cpu().numpy()  # Return full multi-output array
+
+        explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
+        print("Calculating SHAP values with KernelExplainer...")
+        # KernelExplainer also returns a list of arrays for multi-output models
+        shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
+        # instances_to_explain_np is already set
+
+    else:
+        _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
+        raise ValueError("Invalid explainer_type")
+
+    # --- 3. Plotting and Saving (Common Logic) ---
+
+    if shap_values_list is None or instances_to_explain_np is None:
+        _LOGGER.error("SHAP value calculation failed. Aborting plotting.")
+        return
+
+    # Ensure number of SHAP value arrays matches number of target names
+    if len(shap_values_list) != len(target_names):
+        _LOGGER.error(
+            f"SHAP explanation mismatch: Model produced {len(shap_values_list)} "
+            f"outputs, but {len(target_names)} target_names were provided."
+        )
+        return
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     plt.ioff()
 
-    #
+    # Iterate through each target's SHAP values and generate plots.
     for i, target_name in enumerate(target_names):
         print(f" -> Generating SHAP plots for target: '{target_name}'")
         shap_values_for_target = shap_values_list[i]
```
```diff
@@ -293,11 +347,28 @@ def multi_target_shap_summary_plot(
         # Save Dot Plot for the target
         shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
         plt.title(f"SHAP Feature Importance for '{target_name}'")
+        if plt.gcf().axes and len(plt.gcf().axes) > 1:
+            cb = plt.gcf().axes[-1]
+            cb.set_ylabel("", size=1)
         plt.tight_layout()
         dot_path = save_dir_path / f"shap_dot_plot_{sanitized_target_name}.svg"
         plt.savefig(dot_path)
         plt.close()
-
+
+        # --- Save Summary Data to CSV for this target ---
+        shap_summary_filename = f"{SHAPKeys.SAVENAME}_{sanitized_target_name}.csv"
+        summary_path = save_dir_path / shap_summary_filename
+
+        # For a specific target, shap_values_for_target is just a 2D array
+        mean_abs_shap = np.abs(shap_values_for_target).mean(axis=0).flatten()
+
+        summary_df = pd.DataFrame({
+            SHAPKeys.FEATURE_COLUMN: feature_names,
+            SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
+        }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
+
+        summary_df.to_csv(summary_path, index=False)
+
     plt.ion()
     _LOGGER.info(f"All SHAP plots saved to '{save_dir_path.name}'")
 
```
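The multi-target variant now mirrors the single-target signature and additionally writes one CSV of mean absolute SHAP values per target. A hedged usage sketch with a toy two-output model (the first two positional parameters are inferred from the docstring above; all names and sizes are illustrative):

```python
import numpy as np
import torch
from ml_tools.ML_evaluation_multi import multi_target_shap_summary_plot

model = torch.nn.Linear(4, 2)  # toy model with two outputs
model.eval()

multi_target_shap_summary_plot(
    model,
    background_data=np.random.rand(200, 4).astype(np.float32),
    instances_to_explain=np.random.rand(30, 4).astype(np.float32),
    feature_names=["f0", "f1", "f2", "f3"],
    target_names=["target_a", "target_b"],   # must match the model's output count
    save_dir="shap_multi",                   # hypothetical output directory
    device=torch.device("cpu"),
    explainer_type="deep",
)
```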
{dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-12.13.0}/ml_tools/ML_trainer.py
CHANGED

```diff
@@ -340,9 +340,10 @@ class MLTrainer:
     def explain(self,
                 save_dir: Union[str,Path],
                 explain_dataset: Optional[Dataset] = None,
-                n_samples: int =
+                n_samples: int = 300,
                 feature_names: Optional[List[str]] = None,
-                target_names: Optional[List[str]] = None):
+                target_names: Optional[List[str]] = None,
+                explainer_type: Literal['deep', 'kernel'] = 'deep'):
         """
         Explains model predictions using SHAP and saves all artifacts.
 
```
```diff
@@ -359,6 +360,9 @@ class MLTrainer:
             feature_names (list[str] | None): Feature names.
             target_names (list[str] | None): Target names for multi-target tasks.
             save_dir (str | Path): Directory to save all SHAP artifacts.
+            explainer_type (Literal['deep', 'kernel']): The explainer to use.
+                - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
+                - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples' (< 100).
         """
         # Internal helper to create a dataloader and get a random sample
         def _get_random_sample(dataset: Dataset, num_samples: int):
```
```diff
@@ -410,6 +414,9 @@ class MLTrainer:
         else:
             _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
             raise ValueError()
+
+        # move model to device
+        self.model.to(self.device)
 
         # 3. Call the plotting function
         if self.kind in ["regression", "classification"]:
```
```diff
@@ -418,7 +425,9 @@ class MLTrainer:
                 background_data=background_data,
                 instances_to_explain=instances_to_explain,
                 feature_names=feature_names,
-                save_dir=save_dir
+                save_dir=save_dir,
+                explainer_type=explainer_type,
+                device=self.device
             )
         elif self.kind in ["multi_target_regression", "multi_label_classification"]:
             # try to get target names
```
```diff
@@ -442,7 +451,9 @@ class MLTrainer:
                 instances_to_explain=instances_to_explain,
                 feature_names=feature_names, # type: ignore
                 target_names=target_names, # type: ignore
-                save_dir=save_dir
+                save_dir=save_dir,
+                explainer_type=explainer_type,
+                device=self.device
             )
 
     def _attention_helper(self, dataloader: DataLoader):
```
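At the trainer level the new arguments simply thread through to the plotting functions. A sketch of the updated `explain` call, assuming `trainer` is an already-constructed and trained `MLTrainer` (its construction is unchanged by this release and omitted here):

```python
# trainer: MLTrainer, already trained; construction omitted (unchanged in 12.13.0).

# Default path: DeepExplainer on the trainer's own device.
trainer.explain(
    save_dir="shap_artifacts",      # hypothetical output directory
    n_samples=300,                  # new default in 12.13.0
    explainer_type="deep",          # new parameter in 12.13.0
)

# Model-agnostic fallback; the docstring advises a very low n_samples (< 100).
trainer.explain(
    save_dir="shap_artifacts_kernel",
    n_samples=64,
    explainer_type="kernel",
)
```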