dragon-ml-toolbox 12.12.0__py3-none-any.whl → 12.13.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic. Click here for more details.

@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dragon-ml-toolbox
3
- Version: 12.12.0
3
+ Version: 12.13.0
4
4
  Summary: A collection of tools for data science and machine learning projects.
5
5
  Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
6
6
  License-Expression: MIT
@@ -1,19 +1,19 @@
1
- dragon_ml_toolbox-12.12.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
2
- dragon_ml_toolbox-12.12.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=iy2r_R7wjzsCbz_Q_jMsp_jfZ6oP8XW9QhwzRBH0mGY,1904
1
+ dragon_ml_toolbox-12.13.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
2
+ dragon_ml_toolbox-12.13.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=iy2r_R7wjzsCbz_Q_jMsp_jfZ6oP8XW9QhwzRBH0mGY,1904
3
3
  ml_tools/ETL_cleaning.py,sha256=2VBRllV8F-ZiPylPp8Az2gwn5ztgazN0BH5OKnRUhV0,20402
4
4
  ml_tools/ETL_engineering.py,sha256=KfYqgsxupAx6e_TxwO1LZXeu5mFkIhVXJrNjP3CzIZc,54927
5
5
  ml_tools/GUI_tools.py,sha256=Va6ig-dHULPVRwQYYtH3fvY5XPIoqRcJpRW8oXC55Hw,45413
6
6
  ml_tools/MICE_imputation.py,sha256=X273Qlgoqqg7KTmoKd75YDyAPB0UIbTzGP3xsCmRh3E,11717
7
7
  ml_tools/ML_callbacks.py,sha256=2ZazJjlbClP-ALc8q0ru2oalkugbhO3TFwPg4RFZpck,14056
8
8
  ml_tools/ML_datasetmaster.py,sha256=kedCGneR3S2zui0_JFZN6TBL5e69XWkdpkE_QohyqSM,31433
9
- ml_tools/ML_evaluation.py,sha256=tLswOPgH4G1KExSMn0876YtNkbxPh-W3J4MYOjomMWA,16208
10
- ml_tools/ML_evaluation_multi.py,sha256=6OZyQ4SM9ALh38mOABmiHgIQDWcovsD_iOo7Bg9YZCE,12516
9
+ ml_tools/ML_evaluation.py,sha256=h7fAtk0lS4gTqQ46fiVjucTvFlX4rsufKnEtate6Nu0,18381
10
+ ml_tools/ML_evaluation_multi.py,sha256=Kn9n5lfxo7A0TvgIDMx8UHZCvzTqv1ViezzwJBF-ypM,15970
11
11
  ml_tools/ML_inference.py,sha256=ymFvncFsU10PExq87xnEj541DKV5ck0nMuK8ToJHzVQ,23067
12
12
  ml_tools/ML_models.py,sha256=G64NPhYZfYvHTIUwkIrMrNLgfDTKJwqdc8jwesPqB9E,28090
13
13
  ml_tools/ML_optimization.py,sha256=es3TlQbY7RYgJMZnznkjYGbUxFnAqzZxE_g3_qLK9Q8,22960
14
14
  ml_tools/ML_scaler.py,sha256=tw6onj9o8_kk3FQYb930HUzvv1zsFZe2YZJdF3LtHkU,7538
15
15
  ml_tools/ML_simple_optimization.py,sha256=W2mce1XFCuiOHTOjOsCNbETISHn5MwYlYsTIXH5hMMo,18177
16
- ml_tools/ML_trainer.py,sha256=_g48w5Ak-wQr5fGHdJqlcpnzv3gWyL1ghkOhy9VOZbo,23930
16
+ ml_tools/ML_trainer.py,sha256=UmCuKr_GzQGYqhEZ-kaRv9Buj44DsNyuOzmOM7Fw8N0,24569
17
17
  ml_tools/ML_utilities.py,sha256=EnKpPTnJ2qjZmz7kvows4Uu5CfSA7ByRmI1v2-KarKw,9337
18
18
  ml_tools/PSO_optimization.py,sha256=fVHeemqilBS0zrGV25E5yKwDlGdd2ZKa18d8CZ6Q6Fk,22961
19
19
  ml_tools/RNN_forecast.py,sha256=Qa2KoZfdAvSjZ4yE78N4BFXtr3tTr0Gx7tQJZPotsh0,1967
@@ -35,7 +35,7 @@ ml_tools/optimization_tools.py,sha256=P074YCuZzkqkONnAsM-Zb9DTX_i8cRkkJLpwAWz6CR
35
35
  ml_tools/path_manager.py,sha256=CyDU16pOKmC82jPubqJPT6EBt-u-3rGVbxyPIZCvDDY,18432
36
36
  ml_tools/serde.py,sha256=ll2mVC0sO2jIEdG3K6xMcgEN13N4YSb8VjviGvw_ers,4949
37
37
  ml_tools/utilities.py,sha256=OcAyV1tEcYAfOWlGjRgopsjDLxU3DcI5EynzvWV4q3A,15754
38
- dragon_ml_toolbox-12.12.0.dist-info/METADATA,sha256=PKf7t2ojMJs9-6STvqebRBxS_1rWPv58ff0BqPk2d_A,6167
39
- dragon_ml_toolbox-12.12.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
40
- dragon_ml_toolbox-12.12.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
41
- dragon_ml_toolbox-12.12.0.dist-info/RECORD,,
38
+ dragon_ml_toolbox-12.13.0.dist-info/METADATA,sha256=p3-oOSqq1hhJj13KjIXeFnwBu3UTfBJu5mTDL9MCpdU,6167
39
+ dragon_ml_toolbox-12.13.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
40
+ dragon_ml_toolbox-12.13.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
41
+ dragon_ml_toolbox-12.13.0.dist-info/RECORD,,
ml_tools/ML_evaluation.py CHANGED
@@ -18,7 +18,7 @@ from sklearn.metrics import (
18
18
  import torch
19
19
  import shap
20
20
  from pathlib import Path
21
- from typing import Union, Optional, List
21
+ from typing import Union, Optional, List, Literal
22
22
 
23
23
  from .path_manager import make_fullpath
24
24
  from ._logger import _LOGGER
@@ -249,13 +249,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
249
249
  plt.savefig(hist_path)
250
250
  _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
251
251
  plt.close(fig_hist)
252
-
252
+
253
253
 
254
254
  def shap_summary_plot(model,
255
255
  background_data: Union[torch.Tensor,np.ndarray],
256
256
  instances_to_explain: Union[torch.Tensor,np.ndarray],
257
257
  feature_names: Optional[list[str]],
258
- save_dir: Union[str, Path]):
258
+ save_dir: Union[str, Path],
259
+ device: torch.device = torch.device('cpu'),
260
+ explainer_type: Literal['deep', 'kernel'] = 'deep'):
259
261
  """
260
262
  Calculates SHAP values and saves summary plots and data.
261
263
 
@@ -265,48 +267,85 @@ def shap_summary_plot(model,
265
267
  instances_to_explain (torch.Tensor): The specific data instances to explain.
266
268
  feature_names (list of str | None): Names of the features for plot labeling.
267
269
  save_dir (str | Path): Directory to save SHAP artifacts.
270
+ device (torch.device): The torch device for SHAP calculations.
271
+ explainer_type (Literal['deep', 'kernel']): The explainer to use.
272
+ - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for
273
+ PyTorch models.
274
+ - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
275
+ slow and memory-intensive.
268
276
  """
269
- # everything to numpy
270
- if isinstance(background_data, np.ndarray):
271
- background_data_np = background_data
272
- else:
273
- background_data_np = background_data.numpy()
274
-
275
- if isinstance(instances_to_explain, np.ndarray):
276
- instances_to_explain_np = instances_to_explain
277
- else:
278
- instances_to_explain_np = instances_to_explain.numpy()
279
277
 
280
- # --- Data Validation Step ---
281
- if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
282
- _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
283
- return
284
-
285
- print("\n--- SHAP Value Explanation ---")
278
+ print(f"\n--- SHAP Value Explanation Using {explainer_type.upper()} Explainer ---")
286
279
 
287
280
  model.eval()
288
- model.cpu()
289
-
290
- # 1. Summarize the background data.
291
- # Summarize the background data using k-means. 10-50 clusters is a good starting point.
292
- background_summary = shap.kmeans(background_data_np, 30)
293
-
294
- # 2. Define a prediction function wrapper that SHAP can use. It must take a numpy array and return a numpy array.
295
- def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
296
- # Convert numpy data to torch tensor
297
- x_torch = torch.from_numpy(x_np).float()
298
- with torch.no_grad():
299
- # Get model output
300
- output = model(x_torch)
301
- # Return as numpy array
302
- return output.cpu().numpy().flatten()
303
-
304
- # 3. Create the KernelExplainer
305
- explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
281
+ # model.cpu() # Run explanations on CPU
306
282
 
307
- print("Calculating SHAP values with KernelExplainer...")
308
- shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
283
+ shap_values = None
284
+ instances_to_explain_np = None
285
+
286
+ if explainer_type == 'deep':
287
+ # --- 1. Use DeepExplainer (Preferred) ---
288
+
289
+ # Ensure data is torch.Tensor
290
+ if isinstance(background_data, np.ndarray):
291
+ background_data = torch.from_numpy(background_data).float()
292
+ if isinstance(instances_to_explain, np.ndarray):
293
+ instances_to_explain = torch.from_numpy(instances_to_explain).float()
294
+
295
+ if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
296
+ _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
297
+ return
298
+
299
+ background_data = background_data.to(device)
300
+ instances_to_explain = instances_to_explain.to(device)
301
+
302
+ explainer = shap.DeepExplainer(model, background_data)
303
+ # print("Calculating SHAP values with DeepExplainer...")
304
+ shap_values = explainer.shap_values(instances_to_explain)
305
+ instances_to_explain_np = instances_to_explain.cpu().numpy()
306
+
307
+ elif explainer_type == 'kernel':
308
+ # --- 2. Use KernelExplainer (Slow Fallback) ---
309
+ _LOGGER.warning(
310
+ "Using KernelExplainer. This is memory-intensive and slow. "
311
+ "Consider reducing 'n_samples' if the process terminates unexpectedly."
312
+ )
313
+
314
+ # Ensure data is np.ndarray
315
+ if isinstance(background_data, torch.Tensor):
316
+ background_data_np = background_data.cpu().numpy()
317
+ else:
318
+ background_data_np = background_data
319
+
320
+ if isinstance(instances_to_explain, torch.Tensor):
321
+ instances_to_explain_np = instances_to_explain.cpu().numpy()
322
+ else:
323
+ instances_to_explain_np = instances_to_explain
324
+
325
+ if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
326
+ _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
327
+ return
328
+
329
+ # Summarize background data
330
+ background_summary = shap.kmeans(background_data_np, 30)
331
+
332
+ def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
333
+ x_torch = torch.from_numpy(x_np).float().to(device)
334
+ with torch.no_grad():
335
+ output = model(x_torch)
336
+ # Return as numpy array
337
+ return output.cpu().numpy()
338
+
339
+ explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
340
+ # print("Calculating SHAP values with KernelExplainer...")
341
+ shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
342
+ # instances_to_explain_np is already set
309
343
 
344
+ else:
345
+ _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
346
+ raise ValueError()
347
+
348
+ # --- 3. Plotting and Saving ---
310
349
  save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
311
350
  plt.ioff()
312
351
 
@@ -326,8 +365,9 @@ def shap_summary_plot(model,
326
365
  shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
327
366
  ax = plt.gca()
328
367
  ax.set_xlabel("SHAP Value Impact", labelpad=10)
329
- cb = plt.gcf().axes[-1]
330
- cb.set_ylabel("", size=1)
368
+ if plt.gcf().axes and len(plt.gcf().axes) > 1:
369
+ cb = plt.gcf().axes[-1]
370
+ cb.set_ylabel("", size=1)
331
371
  plt.title("SHAP Feature Importance")
332
372
  plt.tight_layout()
333
373
  plt.savefig(dot_path)
@@ -337,8 +377,14 @@ def shap_summary_plot(model,
337
377
  # Save Summary Data to CSV
338
378
  shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
339
379
  summary_path = save_dir_path / shap_summary_filename
340
- # Ensure the array is 1D before creating the DataFrame
341
- mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
380
+
381
+ # Handle multi-class (list of arrays) vs. regression (single array)
382
+ if isinstance(shap_values, list):
383
+ mean_abs_shap = np.abs(np.stack(shap_values)).mean(axis=0).mean(axis=0)
384
+ else:
385
+ mean_abs_shap = np.abs(shap_values).mean(axis=0)
386
+
387
+ mean_abs_shap = mean_abs_shap.flatten()
342
388
 
343
389
  if feature_names is None:
344
390
  feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
@@ -351,7 +397,7 @@ def shap_summary_plot(model,
351
397
  summary_df.to_csv(summary_path, index=False)
352
398
 
353
399
  _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
354
- plt.ion()
400
+ plt.ion()
355
401
 
356
402
 
357
403
  def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
@@ -19,11 +19,12 @@ from sklearn.metrics import (
19
19
  jaccard_score
20
20
  )
21
21
  from pathlib import Path
22
- from typing import Union, List
22
+ from typing import Union, List, Literal
23
23
 
24
24
  from .path_manager import make_fullpath, sanitize_filename
25
25
  from ._logger import _LOGGER
26
26
  from ._script_info import _script_info
27
+ from .keys import SHAPKeys
27
28
 
28
29
 
29
30
  __all__ = [
@@ -231,10 +232,12 @@ def multi_target_shap_summary_plot(
231
232
  instances_to_explain: Union[torch.Tensor, np.ndarray],
232
233
  feature_names: List[str],
233
234
  target_names: List[str],
234
- save_dir: Union[str, Path]
235
+ save_dir: Union[str, Path],
236
+ device: torch.device = torch.device('cpu'),
237
+ explainer_type: Literal['deep', 'kernel'] = 'deep'
235
238
  ):
236
239
  """
237
- Calculates SHAP values for a multi-target model and saves summary plots for each target.
240
+ Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
238
241
 
239
242
  Args:
240
243
  model (torch.nn.Module): The trained PyTorch model.
@@ -243,40 +246,91 @@ def multi_target_shap_summary_plot(
243
246
  feature_names (List[str]): Names of the features for plot labeling.
244
247
  target_names (List[str]): Names of the output targets.
245
248
  save_dir (str | Path): Directory to save SHAP artifacts.
249
+ device (torch.device): The torch device for SHAP calculations.
250
+ explainer_type (Literal['deep', 'kernel']): The explainer to use.
251
+ - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient.
252
+ - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
246
253
  """
247
- # Convert all data to numpy
248
- background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
249
- instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
250
-
251
- if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
252
- _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
253
- return
254
-
255
- _LOGGER.info("--- Multi-Target SHAP Value Explanation ---")
254
+ _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
256
255
  model.eval()
257
- model.cpu()
258
-
259
- # 1. Summarize the background data.
260
- background_summary = shap.kmeans(background_data_np, 30)
261
-
262
- # 2. Define a prediction function wrapper for the multi-target model.
263
- def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
264
- x_torch = torch.from_numpy(x_np).float()
265
- with torch.no_grad():
266
- output = model(x_torch)
267
- return output.cpu().numpy()
268
-
269
- # 3. Create the KernelExplainer.
270
- explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
271
-
272
- print("Calculating SHAP values with KernelExplainer...")
273
- # For multi-output models, shap_values is a list of arrays.
274
- shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
256
+ # model.cpu()
257
+
258
+ shap_values_list = None
259
+ instances_to_explain_np = None
260
+
261
+ if explainer_type == 'deep':
262
+ # --- 1. Use DeepExplainer (Preferred) ---
263
+
264
+ # Ensure data is torch.Tensor
265
+ if isinstance(background_data, np.ndarray):
266
+ background_data = torch.from_numpy(background_data).float()
267
+ if isinstance(instances_to_explain, np.ndarray):
268
+ instances_to_explain = torch.from_numpy(instances_to_explain).float()
269
+
270
+ if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
271
+ _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
272
+ return
273
+
274
+ background_data = background_data.to(device)
275
+ instances_to_explain = instances_to_explain.to(device)
276
+
277
+ explainer = shap.DeepExplainer(model, background_data)
278
+ print("Calculating SHAP values with DeepExplainer...")
279
+ # DeepExplainer returns a list of arrays for multi-output models
280
+ shap_values_list = explainer.shap_values(instances_to_explain)
281
+ instances_to_explain_np = instances_to_explain.cpu().numpy()
282
+
283
+ elif explainer_type == 'kernel':
284
+ # --- 2. Use KernelExplainer (Slow Fallback) ---
285
+ _LOGGER.warning(
286
+ "Using KernelExplainer. This is memory-intensive and slow. "
287
+ "Consider reducing 'n_samples' if the process terminates."
288
+ )
289
+
290
+ # Convert all data to numpy
291
+ background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
292
+ instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
293
+
294
+ if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
295
+ _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
296
+ return
297
+
298
+ background_summary = shap.kmeans(background_data_np, 30)
299
+
300
+ def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
301
+ x_torch = torch.from_numpy(x_np).float().to(device)
302
+ with torch.no_grad():
303
+ output = model(x_torch)
304
+ return output.cpu().numpy() # Return full multi-output array
305
+
306
+ explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
307
+ print("Calculating SHAP values with KernelExplainer...")
308
+ # KernelExplainer also returns a list of arrays for multi-output models
309
+ shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
310
+ # instances_to_explain_np is already set
311
+
312
+ else:
313
+ _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
314
+ raise ValueError("Invalid explainer_type")
315
+
316
+ # --- 3. Plotting and Saving (Common Logic) ---
317
+
318
+ if shap_values_list is None or instances_to_explain_np is None:
319
+ _LOGGER.error("SHAP value calculation failed. Aborting plotting.")
320
+ return
321
+
322
+ # Ensure number of SHAP value arrays matches number of target names
323
+ if len(shap_values_list) != len(target_names):
324
+ _LOGGER.error(
325
+ f"SHAP explanation mismatch: Model produced {len(shap_values_list)} "
326
+ f"outputs, but {len(target_names)} target_names were provided."
327
+ )
328
+ return
275
329
 
276
330
  save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
277
331
  plt.ioff()
278
332
 
279
- # 4. Iterate through each target's SHAP values and generate plots.
333
+ # Iterate through each target's SHAP values and generate plots.
280
334
  for i, target_name in enumerate(target_names):
281
335
  print(f" -> Generating SHAP plots for target: '{target_name}'")
282
336
  shap_values_for_target = shap_values_list[i]
@@ -293,11 +347,28 @@ def multi_target_shap_summary_plot(
293
347
  # Save Dot Plot for the target
294
348
  shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
295
349
  plt.title(f"SHAP Feature Importance for '{target_name}'")
350
+ if plt.gcf().axes and len(plt.gcf().axes) > 1:
351
+ cb = plt.gcf().axes[-1]
352
+ cb.set_ylabel("", size=1)
296
353
  plt.tight_layout()
297
354
  dot_path = save_dir_path / f"shap_dot_plot_{sanitized_target_name}.svg"
298
355
  plt.savefig(dot_path)
299
356
  plt.close()
300
-
357
+
358
+ # --- Save Summary Data to CSV for this target ---
359
+ shap_summary_filename = f"{SHAPKeys.SAVENAME}_{sanitized_target_name}.csv"
360
+ summary_path = save_dir_path / shap_summary_filename
361
+
362
+ # For a specific target, shap_values_for_target is just a 2D array
363
+ mean_abs_shap = np.abs(shap_values_for_target).mean(axis=0).flatten()
364
+
365
+ summary_df = pd.DataFrame({
366
+ SHAPKeys.FEATURE_COLUMN: feature_names,
367
+ SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
368
+ }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
369
+
370
+ summary_df.to_csv(summary_path, index=False)
371
+
301
372
  plt.ion()
302
373
  _LOGGER.info(f"All SHAP plots saved to '{save_dir_path.name}'")
303
374
 
ml_tools/ML_trainer.py CHANGED
@@ -340,9 +340,10 @@ class MLTrainer:
340
340
  def explain(self,
341
341
  save_dir: Union[str,Path],
342
342
  explain_dataset: Optional[Dataset] = None,
343
- n_samples: int = 1000,
343
+ n_samples: int = 300,
344
344
  feature_names: Optional[List[str]] = None,
345
- target_names: Optional[List[str]] = None):
345
+ target_names: Optional[List[str]] = None,
346
+ explainer_type: Literal['deep', 'kernel'] = 'deep'):
346
347
  """
347
348
  Explains model predictions using SHAP and saves all artifacts.
348
349
 
@@ -359,6 +360,9 @@ class MLTrainer:
359
360
  feature_names (list[str] | None): Feature names.
360
361
  target_names (list[str] | None): Target names for multi-target tasks.
361
362
  save_dir (str | Path): Directory to save all SHAP artifacts.
363
+ explainer_type (Literal['deep', 'kernel']): The explainer to use.
364
+ - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
365
+ - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples'< 100.
362
366
  """
363
367
  # Internal helper to create a dataloader and get a random sample
364
368
  def _get_random_sample(dataset: Dataset, num_samples: int):
@@ -410,6 +414,9 @@ class MLTrainer:
410
414
  else:
411
415
  _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
412
416
  raise ValueError()
417
+
418
+ # move model to device
419
+ self.model.to(self.device)
413
420
 
414
421
  # 3. Call the plotting function
415
422
  if self.kind in ["regression", "classification"]:
@@ -418,7 +425,9 @@ class MLTrainer:
418
425
  background_data=background_data,
419
426
  instances_to_explain=instances_to_explain,
420
427
  feature_names=feature_names,
421
- save_dir=save_dir
428
+ save_dir=save_dir,
429
+ explainer_type=explainer_type,
430
+ device=self.device
422
431
  )
423
432
  elif self.kind in ["multi_target_regression", "multi_label_classification"]:
424
433
  # try to get target names
@@ -442,7 +451,9 @@ class MLTrainer:
442
451
  instances_to_explain=instances_to_explain,
443
452
  feature_names=feature_names, # type: ignore
444
453
  target_names=target_names, # type: ignore
445
- save_dir=save_dir
454
+ save_dir=save_dir,
455
+ explainer_type=explainer_type,
456
+ device=self.device
446
457
  )
447
458
 
448
459
  def _attention_helper(self, dataloader: DataLoader):