dragon-ml-toolbox 10.2.0__py3-none-any.whl → 14.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic. Click here for more details.

Files changed (48)
  1. {dragon_ml_toolbox-10.2.0.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/METADATA +38 -63
  2. dragon_ml_toolbox-14.2.0.dist-info/RECORD +48 -0
  3. {dragon_ml_toolbox-10.2.0.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/licenses/LICENSE +1 -1
  4. {dragon_ml_toolbox-10.2.0.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +11 -0
  5. ml_tools/ETL_cleaning.py +72 -34
  6. ml_tools/ETL_engineering.py +506 -70
  7. ml_tools/GUI_tools.py +2 -1
  8. ml_tools/MICE_imputation.py +212 -7
  9. ml_tools/ML_callbacks.py +73 -40
  10. ml_tools/ML_datasetmaster.py +267 -284
  11. ml_tools/ML_evaluation.py +119 -58
  12. ml_tools/ML_evaluation_multi.py +107 -32
  13. ml_tools/ML_inference.py +15 -5
  14. ml_tools/ML_models.py +234 -170
  15. ml_tools/ML_models_advanced.py +323 -0
  16. ml_tools/ML_optimization.py +321 -97
  17. ml_tools/ML_scaler.py +10 -5
  18. ml_tools/ML_trainer.py +585 -40
  19. ml_tools/ML_utilities.py +528 -0
  20. ml_tools/ML_vision_datasetmaster.py +1315 -0
  21. ml_tools/ML_vision_evaluation.py +260 -0
  22. ml_tools/ML_vision_inference.py +428 -0
  23. ml_tools/ML_vision_models.py +627 -0
  24. ml_tools/ML_vision_transformers.py +58 -0
  25. ml_tools/PSO_optimization.py +10 -7
  26. ml_tools/RNN_forecast.py +2 -0
  27. ml_tools/SQL.py +22 -9
  28. ml_tools/VIF_factor.py +4 -3
  29. ml_tools/_ML_vision_recipe.py +88 -0
  30. ml_tools/__init__.py +1 -0
  31. ml_tools/_logger.py +0 -2
  32. ml_tools/_schema.py +96 -0
  33. ml_tools/constants.py +79 -0
  34. ml_tools/custom_logger.py +164 -16
  35. ml_tools/data_exploration.py +1092 -109
  36. ml_tools/ensemble_evaluation.py +48 -1
  37. ml_tools/ensemble_inference.py +6 -7
  38. ml_tools/ensemble_learning.py +4 -3
  39. ml_tools/handle_excel.py +1 -0
  40. ml_tools/keys.py +80 -0
  41. ml_tools/math_utilities.py +259 -0
  42. ml_tools/optimization_tools.py +198 -24
  43. ml_tools/path_manager.py +144 -45
  44. ml_tools/serde.py +192 -0
  45. ml_tools/utilities.py +287 -227
  46. dragon_ml_toolbox-10.2.0.dist-info/RECORD +0 -36
  47. {dragon_ml_toolbox-10.2.0.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/WHEEL +0 -0
  48. {dragon_ml_toolbox-10.2.0.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation.py CHANGED
@@ -18,10 +18,13 @@ from sklearn.metrics import (
18
18
  import torch
19
19
  import shap
20
20
  from pathlib import Path
21
+ from typing import Union, Optional, List, Literal
22
+ import warnings
23
+
21
24
  from .path_manager import make_fullpath
22
25
  from ._logger import _LOGGER
23
- from typing import Union, Optional, List
24
26
  from ._script_info import _script_info
27
+ from .keys import SHAPKeys, PyTorchLogKeys
25
28
 
26
29
 
27
30
  __all__ = [
@@ -41,8 +44,8 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
41
44
  history (dict): A dictionary containing 'train_loss' and 'val_loss'.
42
45
  save_dir (str | Path): Directory to save the plot image.
43
46
  """
44
- train_loss = history.get('train_loss', [])
45
- val_loss = history.get('val_loss', [])
47
+ train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
48
+ val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
46
49
 
47
50
  if not train_loss and not val_loss:
48
51
  print("Warning: Loss history is empty or incomplete. Cannot plot.")
@@ -247,13 +250,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
247
250
  plt.savefig(hist_path)
248
251
  _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
249
252
  plt.close(fig_hist)
250
-
253
+
251
254
 
252
255
  def shap_summary_plot(model,
253
256
  background_data: Union[torch.Tensor,np.ndarray],
254
257
  instances_to_explain: Union[torch.Tensor,np.ndarray],
255
258
  feature_names: Optional[list[str]],
256
- save_dir: Union[str, Path]):
259
+ save_dir: Union[str, Path],
260
+ device: torch.device = torch.device('cpu'),
261
+ explainer_type: Literal['deep', 'kernel'] = 'kernel'):
257
262
  """
258
263
  Calculates SHAP values and saves summary plots and data.
259
264
 
@@ -263,54 +268,105 @@ def shap_summary_plot(model,
263
268
  instances_to_explain (torch.Tensor): The specific data instances to explain.
264
269
  feature_names (list of str | None): Names of the features for plot labeling.
265
270
  save_dir (str | Path): Directory to save SHAP artifacts.
271
+ device (torch.device): The torch device for SHAP calculations.
272
+ explainer_type (Literal['deep', 'kernel']): The explainer to use.
273
+ - 'deep': Uses shap.DeepExplainer. Fast and efficient for
274
+ PyTorch models.
275
+ - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
276
+ slow and memory-intensive.
266
277
  """
267
- # everything to numpy
268
- if isinstance(background_data, np.ndarray):
269
- background_data_np = background_data
270
- else:
271
- background_data_np = background_data.numpy()
272
-
273
- if isinstance(instances_to_explain, np.ndarray):
274
- instances_to_explain_np = instances_to_explain
275
- else:
276
- instances_to_explain_np = instances_to_explain.numpy()
277
-
278
- # --- Data Validation Step ---
279
- if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
280
- _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
281
- return
282
278
 
283
- print("\n--- SHAP Value Explanation ---")
279
+ print(f"\n--- SHAP Value Explanation Using {explainer_type.upper()} Explainer ---")
284
280
 
285
281
  model.eval()
286
- model.cpu()
282
+ # model.cpu() # Run explanations on CPU
287
283
 
288
- # 1. Summarize the background data.
289
- # Summarize the background data using k-means. 10-50 clusters is a good starting point.
290
- background_summary = shap.kmeans(background_data_np, 30)
291
-
292
- # 2. Define a prediction function wrapper that SHAP can use. It must take a numpy array and return a numpy array.
293
- def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
294
- # Convert numpy data to torch tensor
295
- x_torch = torch.from_numpy(x_np).float()
296
- with torch.no_grad():
297
- # Get model output
298
- output = model(x_torch)
299
- # Return as numpy array
300
- return output.cpu().numpy().flatten()
301
-
302
- # 3. Create the KernelExplainer
303
- explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
284
+ shap_values = None
285
+ instances_to_explain_np = None
286
+
287
+ if explainer_type == 'deep':
288
+ # --- 1. Use DeepExplainer ---
289
+
290
+ # Ensure data is torch.Tensor
291
+ if isinstance(background_data, np.ndarray):
292
+ background_data = torch.from_numpy(background_data).float()
293
+ if isinstance(instances_to_explain, np.ndarray):
294
+ instances_to_explain = torch.from_numpy(instances_to_explain).float()
295
+
296
+ if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
297
+ _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
298
+ return
299
+
300
+ background_data = background_data.to(device)
301
+ instances_to_explain = instances_to_explain.to(device)
302
+
303
+ with warnings.catch_warnings():
304
+ warnings.simplefilter("ignore", category=UserWarning)
305
+ explainer = shap.DeepExplainer(model, background_data)
306
+
307
+ # print("Calculating SHAP values with DeepExplainer...")
308
+ shap_values = explainer.shap_values(instances_to_explain)
309
+ instances_to_explain_np = instances_to_explain.cpu().numpy()
310
+
311
+ elif explainer_type == 'kernel':
312
+ # --- 2. Use KernelExplainer ---
313
+ _LOGGER.warning(
314
+ "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
315
+ )
316
+
317
+ # Ensure data is np.ndarray
318
+ if isinstance(background_data, torch.Tensor):
319
+ background_data_np = background_data.cpu().numpy()
320
+ else:
321
+ background_data_np = background_data
322
+
323
+ if isinstance(instances_to_explain, torch.Tensor):
324
+ instances_to_explain_np = instances_to_explain.cpu().numpy()
325
+ else:
326
+ instances_to_explain_np = instances_to_explain
327
+
328
+ if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
329
+ _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
330
+ return
331
+
332
+ # Summarize background data
333
+ background_summary = shap.kmeans(background_data_np, 30)
334
+
335
+ def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
336
+ x_torch = torch.from_numpy(x_np).float().to(device)
337
+ with torch.no_grad():
338
+ output = model(x_torch)
339
+ # Return as numpy array
340
+ return output.cpu().numpy()
341
+
342
+ explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
343
+ # print("Calculating SHAP values with KernelExplainer...")
344
+ shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
345
+ # instances_to_explain_np is already set
304
346
 
305
- print("Calculating SHAP values with KernelExplainer...")
306
- shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
347
+ else:
348
+ _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
349
+ raise ValueError()
307
350
 
351
+ if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
352
+ # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
353
+ shap_values = shap_values.squeeze(-1)
354
+
355
+ # --- 3. Plotting and Saving ---
308
356
  save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
309
357
  plt.ioff()
310
358
 
359
+ # Convert instances to a DataFrame. robust way to ensure SHAP correctly maps values to feature names.
360
+ if feature_names is None:
361
+ # Create generic names if none were provided
362
+ num_features = instances_to_explain_np.shape[1]
363
+ feature_names = [f'feature_{i}' for i in range(num_features)]
364
+
365
+ instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)
366
+
311
367
  # Save Bar Plot
312
368
  bar_path = save_dir_path / "shap_bar_plot.svg"
313
- shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
369
+ shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
314
370
  ax = plt.gca()
315
371
  ax.set_xlabel("SHAP Value Impact", labelpad=10)
316
372
  plt.title("SHAP Feature Importance")
@@ -321,11 +377,12 @@ def shap_summary_plot(model,
321
377
 
322
378
  # Save Dot Plot
323
379
  dot_path = save_dir_path / "shap_dot_plot.svg"
324
- shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
380
+ shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
325
381
  ax = plt.gca()
326
382
  ax.set_xlabel("SHAP Value Impact", labelpad=10)
327
- cb = plt.gcf().axes[-1]
328
- cb.set_ylabel("", size=1)
383
+ if plt.gcf().axes and len(plt.gcf().axes) > 1:
384
+ cb = plt.gcf().axes[-1]
385
+ cb.set_ylabel("", size=1)
329
386
  plt.title("SHAP Feature Importance")
330
387
  plt.tight_layout()
331
388
  plt.savefig(dot_path)
@@ -333,17 +390,21 @@ def shap_summary_plot(model,
333
390
  plt.close()
334
391
 
335
392
  # Save Summary Data to CSV
336
- summary_path = save_dir_path / "shap_summary.csv"
337
- # Ensure the array is 1D before creating the DataFrame
338
- mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
393
+ shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
394
+ summary_path = save_dir_path / shap_summary_filename
339
395
 
340
- if feature_names is None:
341
- feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
396
+ # Handle multi-class (list of arrays) vs. regression (single array)
397
+ if isinstance(shap_values, list):
398
+ mean_abs_shap = np.abs(np.stack(shap_values)).mean(axis=0).mean(axis=0)
399
+ else:
400
+ mean_abs_shap = np.abs(shap_values).mean(axis=0)
401
+
402
+ mean_abs_shap = mean_abs_shap.flatten()
342
403
 
343
404
  summary_df = pd.DataFrame({
344
- 'feature': feature_names,
345
- 'mean_abs_shap_value': mean_abs_shap
346
- }).sort_values('mean_abs_shap_value', ascending=False)
405
+ SHAPKeys.FEATURE_COLUMN: feature_names,
406
+ SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
407
+ }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
347
408
 
348
409
  summary_df.to_csv(summary_path, index=False)
349
410
 
@@ -351,7 +412,7 @@ def shap_summary_plot(model,
351
412
  plt.ion()
352
413
 
353
414
 
354
- def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path]):
415
+ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
355
416
  """
356
417
  Aggregates attention weights and plots global feature importance.
357
418
 
@@ -362,6 +423,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
362
423
  weights (List[torch.Tensor]): A list of attention weight tensors from each batch.
363
424
  feature_names (List[str] | None): Names of the features for plot labeling.
364
425
  save_dir (str | Path): Directory to save the plot and summary CSV.
426
+ top_n (int): The number of top features to display in the plot.
365
427
  """
366
428
  if not weights:
367
429
  _LOGGER.error("Attention weights list is empty. Skipping importance plot.")
@@ -390,11 +452,10 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
390
452
  summary_df.to_csv(summary_path, index=False)
391
453
  _LOGGER.info(f"📝 Attention summary data saved as '{summary_path.name}'")
392
454
 
393
- # --- Step 3: Create and save the plot ---
394
- plt.figure(figsize=(10, 8), dpi=100)
455
+ # --- Step 3: Create and save the plot for top N features ---
456
+ plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)
395
457
 
396
- # Sort for plotting
397
- plot_df = summary_df.sort_values('mean_attention', ascending=True)
458
+ plt.figure(figsize=(10, 8), dpi=100)
398
459
 
399
460
  # Create horizontal bar plot with error bars
400
461
  plt.barh(
@@ -408,7 +469,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
408
469
  color='cornflowerblue'
409
470
  )
410
471
 
411
- plt.title('Global Feature Importance')
472
+ plt.title('Top Features by Attention')
412
473
  plt.xlabel('Average Attention Weight')
413
474
  plt.ylabel('Feature')
414
475
  plt.grid(axis='x', linestyle='--', alpha=0.6)
ml_tools/ML_evaluation_multi.py CHANGED
@@ -19,11 +19,14 @@ from sklearn.metrics import (
19
19
  jaccard_score
20
20
  )
21
21
  from pathlib import Path
22
- from typing import Union, List
22
+ from typing import Union, List, Literal
23
+ import warnings
23
24
 
24
25
  from .path_manager import make_fullpath, sanitize_filename
25
26
  from ._logger import _LOGGER
26
27
  from ._script_info import _script_info
28
+ from .keys import SHAPKeys
29
+
27
30
 
28
31
  __all__ = [
29
32
  "multi_target_regression_metrics",
@@ -230,10 +233,12 @@ def multi_target_shap_summary_plot(
230
233
  instances_to_explain: Union[torch.Tensor, np.ndarray],
231
234
  feature_names: List[str],
232
235
  target_names: List[str],
233
- save_dir: Union[str, Path]
236
+ save_dir: Union[str, Path],
237
+ device: torch.device = torch.device('cpu'),
238
+ explainer_type: Literal['deep', 'kernel'] = 'kernel'
234
239
  ):
235
240
  """
236
- Calculates SHAP values for a multi-target model and saves summary plots for each target.
241
+ Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
237
242
 
238
243
  Args:
239
244
  model (torch.nn.Module): The trained PyTorch model.
@@ -242,40 +247,93 @@ def multi_target_shap_summary_plot(
242
247
  feature_names (List[str]): Names of the features for plot labeling.
243
248
  target_names (List[str]): Names of the output targets.
244
249
  save_dir (str | Path): Directory to save SHAP artifacts.
250
+ device (torch.device): The torch device for SHAP calculations.
251
+ explainer_type (Literal['deep', 'kernel']): The explainer to use.
252
+ - 'deep': Uses shap.DeepExplainer. Fast and efficient.
253
+ - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
245
254
  """
246
- # Convert all data to numpy
247
- background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
248
- instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
249
-
250
- if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
251
- _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
252
- return
253
-
254
- _LOGGER.info("--- Multi-Target SHAP Value Explanation ---")
255
+ _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
255
256
  model.eval()
256
- model.cpu()
257
-
258
- # 1. Summarize the background data.
259
- background_summary = shap.kmeans(background_data_np, 30)
260
-
261
- # 2. Define a prediction function wrapper for the multi-target model.
262
- def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
263
- x_torch = torch.from_numpy(x_np).float()
264
- with torch.no_grad():
265
- output = model(x_torch)
266
- return output.cpu().numpy()
267
-
268
- # 3. Create the KernelExplainer.
269
- explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
270
-
271
- print("Calculating SHAP values with KernelExplainer...")
272
- # For multi-output models, shap_values is a list of arrays.
273
- shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
257
+ # model.cpu()
258
+
259
+ shap_values_list = None
260
+ instances_to_explain_np = None
261
+
262
+ if explainer_type == 'deep':
263
+ # --- 1. Use DeepExplainer ---
264
+
265
+ # Ensure data is torch.Tensor
266
+ if isinstance(background_data, np.ndarray):
267
+ background_data = torch.from_numpy(background_data).float()
268
+ if isinstance(instances_to_explain, np.ndarray):
269
+ instances_to_explain = torch.from_numpy(instances_to_explain).float()
270
+
271
+ if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
272
+ _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
273
+ return
274
+
275
+ background_data = background_data.to(device)
276
+ instances_to_explain = instances_to_explain.to(device)
277
+
278
+ with warnings.catch_warnings():
279
+ warnings.simplefilter("ignore", category=UserWarning)
280
+ explainer = shap.DeepExplainer(model, background_data)
281
+
282
+ # print("Calculating SHAP values with DeepExplainer...")
283
+ # DeepExplainer returns a list of arrays for multi-output models
284
+ shap_values_list = explainer.shap_values(instances_to_explain)
285
+ instances_to_explain_np = instances_to_explain.cpu().numpy()
286
+
287
+ elif explainer_type == 'kernel':
288
+ # --- 2. Use KernelExplainer ---
289
+ _LOGGER.warning(
290
+ "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
291
+ )
292
+
293
+ # Convert all data to numpy
294
+ background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
295
+ instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
296
+
297
+ if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
298
+ _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
299
+ return
300
+
301
+ background_summary = shap.kmeans(background_data_np, 30)
302
+
303
+ def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
304
+ x_torch = torch.from_numpy(x_np).float().to(device)
305
+ with torch.no_grad():
306
+ output = model(x_torch)
307
+ return output.cpu().numpy() # Return full multi-output array
308
+
309
+ explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
310
+ # print("Calculating SHAP values with KernelExplainer...")
311
+ # KernelExplainer also returns a list of arrays for multi-output models
312
+ shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
313
+ # instances_to_explain_np is already set
314
+
315
+ else:
316
+ _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
317
+ raise ValueError("Invalid explainer_type")
318
+
319
+ # --- 3. Plotting and Saving (Common Logic) ---
320
+
321
+ if shap_values_list is None or instances_to_explain_np is None:
322
+ _LOGGER.error("SHAP value calculation failed. Aborting plotting.")
323
+ return
324
+
325
+ # Ensure number of SHAP value arrays matches number of target names
326
+ if len(shap_values_list) != len(target_names):
327
+ _LOGGER.error(
328
+ f"SHAP explanation mismatch: Model produced {len(shap_values_list)} "
329
+ f"outputs, but {len(target_names)} target_names were provided."
330
+ )
331
+ return
274
332
 
275
333
  save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
276
334
  plt.ioff()
277
335
 
278
- # 4. Iterate through each target's SHAP values and generate plots.
336
+ # Iterate through each target's SHAP values and generate plots.
279
337
  for i, target_name in enumerate(target_names):
280
338
  print(f" -> Generating SHAP plots for target: '{target_name}'")
281
339
  shap_values_for_target = shap_values_list[i]
@@ -292,11 +350,28 @@ def multi_target_shap_summary_plot(
292
350
  # Save Dot Plot for the target
293
351
  shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
294
352
  plt.title(f"SHAP Feature Importance for '{target_name}'")
353
+ if plt.gcf().axes and len(plt.gcf().axes) > 1:
354
+ cb = plt.gcf().axes[-1]
355
+ cb.set_ylabel("", size=1)
295
356
  plt.tight_layout()
296
357
  dot_path = save_dir_path / f"shap_dot_plot_{sanitized_target_name}.svg"
297
358
  plt.savefig(dot_path)
298
359
  plt.close()
299
-
360
+
361
+ # --- Save Summary Data to CSV for this target ---
362
+ shap_summary_filename = f"{SHAPKeys.SAVENAME}_{sanitized_target_name}.csv"
363
+ summary_path = save_dir_path / shap_summary_filename
364
+
365
+ # For a specific target, shap_values_for_target is just a 2D array
366
+ mean_abs_shap = np.abs(shap_values_for_target).mean(axis=0).flatten()
367
+
368
+ summary_df = pd.DataFrame({
369
+ SHAPKeys.FEATURE_COLUMN: feature_names,
370
+ SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
371
+ }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
372
+
373
+ summary_df.to_csv(summary_path, index=False)
374
+
300
375
  plt.ion()
301
376
  _LOGGER.info(f"All SHAP plots saved to '{save_dir_path.name}'")
302
377
 
ml_tools/ML_inference.py CHANGED
@@ -9,7 +9,8 @@ from .ML_scaler import PytorchScaler
9
9
  from ._script_info import _script_info
10
10
  from ._logger import _LOGGER
11
11
  from .path_manager import make_fullpath
12
- from .keys import PyTorchInferenceKeys
12
+ from .keys import PyTorchInferenceKeys, PyTorchCheckpointKeys
13
+
13
14
 
14
15
  __all__ = [
15
16
  "PyTorchInferenceHandler",
@@ -55,11 +56,21 @@ class _BaseInferenceHandler(ABC):
55
56
  model_p = make_fullpath(state_dict, enforce="file")
56
57
 
57
58
  try:
58
- # Load the state dictionary and apply it to the model structure
59
- self.model.load_state_dict(torch.load(model_p, map_location=self.device))
59
+ # Load whatever is in the file
60
+ loaded_data = torch.load(model_p, map_location=self.device)
61
+
62
+ # Check if it's the new checkpoint dictionary or an old weights-only file
63
+ if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
64
+ # It's a new training checkpoint, extract the weights
65
+ self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
66
+ else:
67
+ # It's an old-style file (or just a state_dict), load it directly
68
+ self.model.load_state_dict(loaded_data)
69
+
70
+ _LOGGER.info(f"Model state loaded from '{model_p.name}'.")
71
+
60
72
  self.model.to(self.device)
61
73
  self.model.eval() # Set the model to evaluation mode
62
- _LOGGER.info(f"Model state loaded from '{model_p.name}' and set to evaluation mode.")
63
74
  except Exception as e:
64
75
  _LOGGER.error(f"Failed to load model state from '{model_p}': {e}")
65
76
  raise
@@ -71,7 +82,6 @@ class _BaseInferenceHandler(ABC):
71
82
  _LOGGER.warning("CUDA not available, switching to CPU.")
72
83
  device_lower = "cpu"
73
84
  elif device_lower == "mps" and not torch.backends.mps.is_available():
74
- # Your M-series Mac will appreciate this check!
75
85
  _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
76
86
  device_lower = "cpu"
77
87
  return torch.device(device_lower)