dragon-ml-toolbox 10.1.1__py3-none-any.whl → 14.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic. Click here for more details.
- {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/METADATA +38 -63
- dragon_ml_toolbox-14.2.0.dist-info/RECORD +48 -0
- {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/licenses/LICENSE +1 -1
- {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +11 -0
- ml_tools/ETL_cleaning.py +175 -59
- ml_tools/ETL_engineering.py +506 -70
- ml_tools/GUI_tools.py +2 -1
- ml_tools/MICE_imputation.py +212 -7
- ml_tools/ML_callbacks.py +73 -40
- ml_tools/ML_datasetmaster.py +267 -284
- ml_tools/ML_evaluation.py +119 -58
- ml_tools/ML_evaluation_multi.py +107 -32
- ml_tools/ML_inference.py +15 -5
- ml_tools/ML_models.py +234 -170
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_optimization.py +321 -97
- ml_tools/ML_scaler.py +10 -5
- ml_tools/ML_trainer.py +585 -40
- ml_tools/ML_utilities.py +528 -0
- ml_tools/ML_vision_datasetmaster.py +1315 -0
- ml_tools/ML_vision_evaluation.py +260 -0
- ml_tools/ML_vision_inference.py +428 -0
- ml_tools/ML_vision_models.py +627 -0
- ml_tools/ML_vision_transformers.py +58 -0
- ml_tools/PSO_optimization.py +10 -7
- ml_tools/RNN_forecast.py +2 -0
- ml_tools/SQL.py +22 -9
- ml_tools/VIF_factor.py +4 -3
- ml_tools/_ML_vision_recipe.py +88 -0
- ml_tools/__init__.py +1 -0
- ml_tools/_logger.py +0 -2
- ml_tools/_schema.py +96 -0
- ml_tools/constants.py +79 -0
- ml_tools/custom_logger.py +164 -16
- ml_tools/data_exploration.py +1092 -109
- ml_tools/ensemble_evaluation.py +48 -1
- ml_tools/ensemble_inference.py +6 -7
- ml_tools/ensemble_learning.py +4 -3
- ml_tools/handle_excel.py +1 -0
- ml_tools/keys.py +80 -0
- ml_tools/math_utilities.py +259 -0
- ml_tools/optimization_tools.py +198 -24
- ml_tools/path_manager.py +144 -45
- ml_tools/serde.py +192 -0
- ml_tools/utilities.py +287 -227
- dragon_ml_toolbox-10.1.1.dist-info/RECORD +0 -36
- {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-10.1.1.dist-info → dragon_ml_toolbox-14.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation.py
CHANGED
|
@@ -18,10 +18,13 @@ from sklearn.metrics import (
|
|
|
18
18
|
import torch
|
|
19
19
|
import shap
|
|
20
20
|
from pathlib import Path
|
|
21
|
+
from typing import Union, Optional, List, Literal
|
|
22
|
+
import warnings
|
|
23
|
+
|
|
21
24
|
from .path_manager import make_fullpath
|
|
22
25
|
from ._logger import _LOGGER
|
|
23
|
-
from typing import Union, Optional, List
|
|
24
26
|
from ._script_info import _script_info
|
|
27
|
+
from .keys import SHAPKeys, PyTorchLogKeys
|
|
25
28
|
|
|
26
29
|
|
|
27
30
|
__all__ = [
|
|
@@ -41,8 +44,8 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
|
|
|
41
44
|
history (dict): A dictionary containing 'train_loss' and 'val_loss'.
|
|
42
45
|
save_dir (str | Path): Directory to save the plot image.
|
|
43
46
|
"""
|
|
44
|
-
train_loss = history.get(
|
|
45
|
-
val_loss = history.get(
|
|
47
|
+
train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
|
|
48
|
+
val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
|
|
46
49
|
|
|
47
50
|
if not train_loss and not val_loss:
|
|
48
51
|
print("Warning: Loss history is empty or incomplete. Cannot plot.")
|
|
@@ -247,13 +250,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
|
|
|
247
250
|
plt.savefig(hist_path)
|
|
248
251
|
_LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
|
|
249
252
|
plt.close(fig_hist)
|
|
250
|
-
|
|
253
|
+
|
|
251
254
|
|
|
252
255
|
def shap_summary_plot(model,
|
|
253
256
|
background_data: Union[torch.Tensor,np.ndarray],
|
|
254
257
|
instances_to_explain: Union[torch.Tensor,np.ndarray],
|
|
255
258
|
feature_names: Optional[list[str]],
|
|
256
|
-
save_dir: Union[str, Path]
|
|
259
|
+
save_dir: Union[str, Path],
|
|
260
|
+
device: torch.device = torch.device('cpu'),
|
|
261
|
+
explainer_type: Literal['deep', 'kernel'] = 'kernel'):
|
|
257
262
|
"""
|
|
258
263
|
Calculates SHAP values and saves summary plots and data.
|
|
259
264
|
|
|
@@ -263,54 +268,105 @@ def shap_summary_plot(model,
|
|
|
263
268
|
instances_to_explain (torch.Tensor): The specific data instances to explain.
|
|
264
269
|
feature_names (list of str | None): Names of the features for plot labeling.
|
|
265
270
|
save_dir (str | Path): Directory to save SHAP artifacts.
|
|
271
|
+
device (torch.device): The torch device for SHAP calculations.
|
|
272
|
+
explainer_type (Literal['deep', 'kernel']): The explainer to use.
|
|
273
|
+
- 'deep': Uses shap.DeepExplainer. Fast and efficient for
|
|
274
|
+
PyTorch models.
|
|
275
|
+
- 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
|
|
276
|
+
slow and memory-intensive.
|
|
266
277
|
"""
|
|
267
|
-
# everything to numpy
|
|
268
|
-
if isinstance(background_data, np.ndarray):
|
|
269
|
-
background_data_np = background_data
|
|
270
|
-
else:
|
|
271
|
-
background_data_np = background_data.numpy()
|
|
272
|
-
|
|
273
|
-
if isinstance(instances_to_explain, np.ndarray):
|
|
274
|
-
instances_to_explain_np = instances_to_explain
|
|
275
|
-
else:
|
|
276
|
-
instances_to_explain_np = instances_to_explain.numpy()
|
|
277
|
-
|
|
278
|
-
# --- Data Validation Step ---
|
|
279
|
-
if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
|
|
280
|
-
_LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
|
|
281
|
-
return
|
|
282
278
|
|
|
283
|
-
print("\n--- SHAP Value Explanation ---")
|
|
279
|
+
print(f"\n--- SHAP Value Explanation Using {explainer_type.upper()} Explainer ---")
|
|
284
280
|
|
|
285
281
|
model.eval()
|
|
286
|
-
model.cpu()
|
|
282
|
+
# model.cpu() # Run explanations on CPU
|
|
287
283
|
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
#
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
284
|
+
shap_values = None
|
|
285
|
+
instances_to_explain_np = None
|
|
286
|
+
|
|
287
|
+
if explainer_type == 'deep':
|
|
288
|
+
# --- 1. Use DeepExplainer ---
|
|
289
|
+
|
|
290
|
+
# Ensure data is torch.Tensor
|
|
291
|
+
if isinstance(background_data, np.ndarray):
|
|
292
|
+
background_data = torch.from_numpy(background_data).float()
|
|
293
|
+
if isinstance(instances_to_explain, np.ndarray):
|
|
294
|
+
instances_to_explain = torch.from_numpy(instances_to_explain).float()
|
|
295
|
+
|
|
296
|
+
if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
|
|
297
|
+
_LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
|
|
298
|
+
return
|
|
299
|
+
|
|
300
|
+
background_data = background_data.to(device)
|
|
301
|
+
instances_to_explain = instances_to_explain.to(device)
|
|
302
|
+
|
|
303
|
+
with warnings.catch_warnings():
|
|
304
|
+
warnings.simplefilter("ignore", category=UserWarning)
|
|
305
|
+
explainer = shap.DeepExplainer(model, background_data)
|
|
306
|
+
|
|
307
|
+
# print("Calculating SHAP values with DeepExplainer...")
|
|
308
|
+
shap_values = explainer.shap_values(instances_to_explain)
|
|
309
|
+
instances_to_explain_np = instances_to_explain.cpu().numpy()
|
|
310
|
+
|
|
311
|
+
elif explainer_type == 'kernel':
|
|
312
|
+
# --- 2. Use KernelExplainer ---
|
|
313
|
+
_LOGGER.warning(
|
|
314
|
+
"KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
|
|
315
|
+
)
|
|
316
|
+
|
|
317
|
+
# Ensure data is np.ndarray
|
|
318
|
+
if isinstance(background_data, torch.Tensor):
|
|
319
|
+
background_data_np = background_data.cpu().numpy()
|
|
320
|
+
else:
|
|
321
|
+
background_data_np = background_data
|
|
322
|
+
|
|
323
|
+
if isinstance(instances_to_explain, torch.Tensor):
|
|
324
|
+
instances_to_explain_np = instances_to_explain.cpu().numpy()
|
|
325
|
+
else:
|
|
326
|
+
instances_to_explain_np = instances_to_explain
|
|
327
|
+
|
|
328
|
+
if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
|
|
329
|
+
_LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
|
|
330
|
+
return
|
|
331
|
+
|
|
332
|
+
# Summarize background data
|
|
333
|
+
background_summary = shap.kmeans(background_data_np, 30)
|
|
334
|
+
|
|
335
|
+
def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
|
|
336
|
+
x_torch = torch.from_numpy(x_np).float().to(device)
|
|
337
|
+
with torch.no_grad():
|
|
338
|
+
output = model(x_torch)
|
|
339
|
+
# Return as numpy array
|
|
340
|
+
return output.cpu().numpy()
|
|
341
|
+
|
|
342
|
+
explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
|
|
343
|
+
# print("Calculating SHAP values with KernelExplainer...")
|
|
344
|
+
shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
|
|
345
|
+
# instances_to_explain_np is already set
|
|
304
346
|
|
|
305
|
-
|
|
306
|
-
|
|
347
|
+
else:
|
|
348
|
+
_LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
|
|
349
|
+
raise ValueError()
|
|
307
350
|
|
|
351
|
+
if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
|
|
352
|
+
# _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
|
|
353
|
+
shap_values = shap_values.squeeze(-1)
|
|
354
|
+
|
|
355
|
+
# --- 3. Plotting and Saving ---
|
|
308
356
|
save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
309
357
|
plt.ioff()
|
|
310
358
|
|
|
359
|
+
# Convert instances to a DataFrame. robust way to ensure SHAP correctly maps values to feature names.
|
|
360
|
+
if feature_names is None:
|
|
361
|
+
# Create generic names if none were provided
|
|
362
|
+
num_features = instances_to_explain_np.shape[1]
|
|
363
|
+
feature_names = [f'feature_{i}' for i in range(num_features)]
|
|
364
|
+
|
|
365
|
+
instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)
|
|
366
|
+
|
|
311
367
|
# Save Bar Plot
|
|
312
368
|
bar_path = save_dir_path / "shap_bar_plot.svg"
|
|
313
|
-
shap.summary_plot(shap_values,
|
|
369
|
+
shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
|
|
314
370
|
ax = plt.gca()
|
|
315
371
|
ax.set_xlabel("SHAP Value Impact", labelpad=10)
|
|
316
372
|
plt.title("SHAP Feature Importance")
|
|
@@ -321,11 +377,12 @@ def shap_summary_plot(model,
|
|
|
321
377
|
|
|
322
378
|
# Save Dot Plot
|
|
323
379
|
dot_path = save_dir_path / "shap_dot_plot.svg"
|
|
324
|
-
shap.summary_plot(shap_values,
|
|
380
|
+
shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
|
|
325
381
|
ax = plt.gca()
|
|
326
382
|
ax.set_xlabel("SHAP Value Impact", labelpad=10)
|
|
327
|
-
|
|
328
|
-
|
|
383
|
+
if plt.gcf().axes and len(plt.gcf().axes) > 1:
|
|
384
|
+
cb = plt.gcf().axes[-1]
|
|
385
|
+
cb.set_ylabel("", size=1)
|
|
329
386
|
plt.title("SHAP Feature Importance")
|
|
330
387
|
plt.tight_layout()
|
|
331
388
|
plt.savefig(dot_path)
|
|
@@ -333,17 +390,21 @@ def shap_summary_plot(model,
|
|
|
333
390
|
plt.close()
|
|
334
391
|
|
|
335
392
|
# Save Summary Data to CSV
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
|
|
393
|
+
shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
|
|
394
|
+
summary_path = save_dir_path / shap_summary_filename
|
|
339
395
|
|
|
340
|
-
|
|
341
|
-
|
|
396
|
+
# Handle multi-class (list of arrays) vs. regression (single array)
|
|
397
|
+
if isinstance(shap_values, list):
|
|
398
|
+
mean_abs_shap = np.abs(np.stack(shap_values)).mean(axis=0).mean(axis=0)
|
|
399
|
+
else:
|
|
400
|
+
mean_abs_shap = np.abs(shap_values).mean(axis=0)
|
|
401
|
+
|
|
402
|
+
mean_abs_shap = mean_abs_shap.flatten()
|
|
342
403
|
|
|
343
404
|
summary_df = pd.DataFrame({
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
}).sort_values(
|
|
405
|
+
SHAPKeys.FEATURE_COLUMN: feature_names,
|
|
406
|
+
SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
|
|
407
|
+
}).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
|
|
347
408
|
|
|
348
409
|
summary_df.to_csv(summary_path, index=False)
|
|
349
410
|
|
|
@@ -351,7 +412,7 @@ def shap_summary_plot(model,
|
|
|
351
412
|
plt.ion()
|
|
352
413
|
|
|
353
414
|
|
|
354
|
-
def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path]):
|
|
415
|
+
def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
|
|
355
416
|
"""
|
|
356
417
|
Aggregates attention weights and plots global feature importance.
|
|
357
418
|
|
|
@@ -362,6 +423,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
|
|
|
362
423
|
weights (List[torch.Tensor]): A list of attention weight tensors from each batch.
|
|
363
424
|
feature_names (List[str] | None): Names of the features for plot labeling.
|
|
364
425
|
save_dir (str | Path): Directory to save the plot and summary CSV.
|
|
426
|
+
top_n (int): The number of top features to display in the plot.
|
|
365
427
|
"""
|
|
366
428
|
if not weights:
|
|
367
429
|
_LOGGER.error("Attention weights list is empty. Skipping importance plot.")
|
|
@@ -390,11 +452,10 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
|
|
|
390
452
|
summary_df.to_csv(summary_path, index=False)
|
|
391
453
|
_LOGGER.info(f"📝 Attention summary data saved as '{summary_path.name}'")
|
|
392
454
|
|
|
393
|
-
# --- Step 3: Create and save the plot ---
|
|
394
|
-
|
|
455
|
+
# --- Step 3: Create and save the plot for top N features ---
|
|
456
|
+
plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)
|
|
395
457
|
|
|
396
|
-
|
|
397
|
-
plot_df = summary_df.sort_values('mean_attention', ascending=True)
|
|
458
|
+
plt.figure(figsize=(10, 8), dpi=100)
|
|
398
459
|
|
|
399
460
|
# Create horizontal bar plot with error bars
|
|
400
461
|
plt.barh(
|
|
@@ -408,7 +469,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
|
|
|
408
469
|
color='cornflowerblue'
|
|
409
470
|
)
|
|
410
471
|
|
|
411
|
-
plt.title('
|
|
472
|
+
plt.title('Top Features by Attention')
|
|
412
473
|
plt.xlabel('Average Attention Weight')
|
|
413
474
|
plt.ylabel('Feature')
|
|
414
475
|
plt.grid(axis='x', linestyle='--', alpha=0.6)
|
ml_tools/ML_evaluation_multi.py
CHANGED
|
@@ -19,11 +19,14 @@ from sklearn.metrics import (
|
|
|
19
19
|
jaccard_score
|
|
20
20
|
)
|
|
21
21
|
from pathlib import Path
|
|
22
|
-
from typing import Union, List
|
|
22
|
+
from typing import Union, List, Literal
|
|
23
|
+
import warnings
|
|
23
24
|
|
|
24
25
|
from .path_manager import make_fullpath, sanitize_filename
|
|
25
26
|
from ._logger import _LOGGER
|
|
26
27
|
from ._script_info import _script_info
|
|
28
|
+
from .keys import SHAPKeys
|
|
29
|
+
|
|
27
30
|
|
|
28
31
|
__all__ = [
|
|
29
32
|
"multi_target_regression_metrics",
|
|
@@ -230,10 +233,12 @@ def multi_target_shap_summary_plot(
|
|
|
230
233
|
instances_to_explain: Union[torch.Tensor, np.ndarray],
|
|
231
234
|
feature_names: List[str],
|
|
232
235
|
target_names: List[str],
|
|
233
|
-
save_dir: Union[str, Path]
|
|
236
|
+
save_dir: Union[str, Path],
|
|
237
|
+
device: torch.device = torch.device('cpu'),
|
|
238
|
+
explainer_type: Literal['deep', 'kernel'] = 'kernel'
|
|
234
239
|
):
|
|
235
240
|
"""
|
|
236
|
-
Calculates SHAP values for a multi-target model and saves summary plots for each target.
|
|
241
|
+
Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
|
|
237
242
|
|
|
238
243
|
Args:
|
|
239
244
|
model (torch.nn.Module): The trained PyTorch model.
|
|
@@ -242,40 +247,93 @@ def multi_target_shap_summary_plot(
|
|
|
242
247
|
feature_names (List[str]): Names of the features for plot labeling.
|
|
243
248
|
target_names (List[str]): Names of the output targets.
|
|
244
249
|
save_dir (str | Path): Directory to save SHAP artifacts.
|
|
250
|
+
device (torch.device): The torch device for SHAP calculations.
|
|
251
|
+
explainer_type (Literal['deep', 'kernel']): The explainer to use.
|
|
252
|
+
- 'deep': Uses shap.DeepExplainer. Fast and efficient.
|
|
253
|
+
- 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
|
|
245
254
|
"""
|
|
246
|
-
|
|
247
|
-
background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
|
|
248
|
-
instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
|
|
249
|
-
|
|
250
|
-
if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
|
|
251
|
-
_LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
|
|
252
|
-
return
|
|
253
|
-
|
|
254
|
-
_LOGGER.info("--- Multi-Target SHAP Value Explanation ---")
|
|
255
|
+
_LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
|
|
255
256
|
model.eval()
|
|
256
|
-
model.cpu()
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
257
|
+
# model.cpu()
|
|
258
|
+
|
|
259
|
+
shap_values_list = None
|
|
260
|
+
instances_to_explain_np = None
|
|
261
|
+
|
|
262
|
+
if explainer_type == 'deep':
|
|
263
|
+
# --- 1. Use DeepExplainer ---
|
|
264
|
+
|
|
265
|
+
# Ensure data is torch.Tensor
|
|
266
|
+
if isinstance(background_data, np.ndarray):
|
|
267
|
+
background_data = torch.from_numpy(background_data).float()
|
|
268
|
+
if isinstance(instances_to_explain, np.ndarray):
|
|
269
|
+
instances_to_explain = torch.from_numpy(instances_to_explain).float()
|
|
270
|
+
|
|
271
|
+
if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
|
|
272
|
+
_LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
|
|
273
|
+
return
|
|
274
|
+
|
|
275
|
+
background_data = background_data.to(device)
|
|
276
|
+
instances_to_explain = instances_to_explain.to(device)
|
|
277
|
+
|
|
278
|
+
with warnings.catch_warnings():
|
|
279
|
+
warnings.simplefilter("ignore", category=UserWarning)
|
|
280
|
+
explainer = shap.DeepExplainer(model, background_data)
|
|
281
|
+
|
|
282
|
+
# print("Calculating SHAP values with DeepExplainer...")
|
|
283
|
+
# DeepExplainer returns a list of arrays for multi-output models
|
|
284
|
+
shap_values_list = explainer.shap_values(instances_to_explain)
|
|
285
|
+
instances_to_explain_np = instances_to_explain.cpu().numpy()
|
|
286
|
+
|
|
287
|
+
elif explainer_type == 'kernel':
|
|
288
|
+
# --- 2. Use KernelExplainer ---
|
|
289
|
+
_LOGGER.warning(
|
|
290
|
+
"KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
|
|
291
|
+
)
|
|
292
|
+
|
|
293
|
+
# Convert all data to numpy
|
|
294
|
+
background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
|
|
295
|
+
instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
|
|
296
|
+
|
|
297
|
+
if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
|
|
298
|
+
_LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
|
|
299
|
+
return
|
|
300
|
+
|
|
301
|
+
background_summary = shap.kmeans(background_data_np, 30)
|
|
302
|
+
|
|
303
|
+
def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
|
|
304
|
+
x_torch = torch.from_numpy(x_np).float().to(device)
|
|
305
|
+
with torch.no_grad():
|
|
306
|
+
output = model(x_torch)
|
|
307
|
+
return output.cpu().numpy() # Return full multi-output array
|
|
308
|
+
|
|
309
|
+
explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
|
|
310
|
+
# print("Calculating SHAP values with KernelExplainer...")
|
|
311
|
+
# KernelExplainer also returns a list of arrays for multi-output models
|
|
312
|
+
shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
|
|
313
|
+
# instances_to_explain_np is already set
|
|
314
|
+
|
|
315
|
+
else:
|
|
316
|
+
_LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
|
|
317
|
+
raise ValueError("Invalid explainer_type")
|
|
318
|
+
|
|
319
|
+
# --- 3. Plotting and Saving (Common Logic) ---
|
|
320
|
+
|
|
321
|
+
if shap_values_list is None or instances_to_explain_np is None:
|
|
322
|
+
_LOGGER.error("SHAP value calculation failed. Aborting plotting.")
|
|
323
|
+
return
|
|
324
|
+
|
|
325
|
+
# Ensure number of SHAP value arrays matches number of target names
|
|
326
|
+
if len(shap_values_list) != len(target_names):
|
|
327
|
+
_LOGGER.error(
|
|
328
|
+
f"SHAP explanation mismatch: Model produced {len(shap_values_list)} "
|
|
329
|
+
f"outputs, but {len(target_names)} target_names were provided."
|
|
330
|
+
)
|
|
331
|
+
return
|
|
274
332
|
|
|
275
333
|
save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
276
334
|
plt.ioff()
|
|
277
335
|
|
|
278
|
-
#
|
|
336
|
+
# Iterate through each target's SHAP values and generate plots.
|
|
279
337
|
for i, target_name in enumerate(target_names):
|
|
280
338
|
print(f" -> Generating SHAP plots for target: '{target_name}'")
|
|
281
339
|
shap_values_for_target = shap_values_list[i]
|
|
@@ -292,11 +350,28 @@ def multi_target_shap_summary_plot(
|
|
|
292
350
|
# Save Dot Plot for the target
|
|
293
351
|
shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
|
|
294
352
|
plt.title(f"SHAP Feature Importance for '{target_name}'")
|
|
353
|
+
if plt.gcf().axes and len(plt.gcf().axes) > 1:
|
|
354
|
+
cb = plt.gcf().axes[-1]
|
|
355
|
+
cb.set_ylabel("", size=1)
|
|
295
356
|
plt.tight_layout()
|
|
296
357
|
dot_path = save_dir_path / f"shap_dot_plot_{sanitized_target_name}.svg"
|
|
297
358
|
plt.savefig(dot_path)
|
|
298
359
|
plt.close()
|
|
299
|
-
|
|
360
|
+
|
|
361
|
+
# --- Save Summary Data to CSV for this target ---
|
|
362
|
+
shap_summary_filename = f"{SHAPKeys.SAVENAME}_{sanitized_target_name}.csv"
|
|
363
|
+
summary_path = save_dir_path / shap_summary_filename
|
|
364
|
+
|
|
365
|
+
# For a specific target, shap_values_for_target is just a 2D array
|
|
366
|
+
mean_abs_shap = np.abs(shap_values_for_target).mean(axis=0).flatten()
|
|
367
|
+
|
|
368
|
+
summary_df = pd.DataFrame({
|
|
369
|
+
SHAPKeys.FEATURE_COLUMN: feature_names,
|
|
370
|
+
SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
|
|
371
|
+
}).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
|
|
372
|
+
|
|
373
|
+
summary_df.to_csv(summary_path, index=False)
|
|
374
|
+
|
|
300
375
|
plt.ion()
|
|
301
376
|
_LOGGER.info(f"All SHAP plots saved to '{save_dir_path.name}'")
|
|
302
377
|
|
ml_tools/ML_inference.py
CHANGED
|
@@ -9,7 +9,8 @@ from .ML_scaler import PytorchScaler
|
|
|
9
9
|
from ._script_info import _script_info
|
|
10
10
|
from ._logger import _LOGGER
|
|
11
11
|
from .path_manager import make_fullpath
|
|
12
|
-
from .keys import PyTorchInferenceKeys
|
|
12
|
+
from .keys import PyTorchInferenceKeys, PyTorchCheckpointKeys
|
|
13
|
+
|
|
13
14
|
|
|
14
15
|
__all__ = [
|
|
15
16
|
"PyTorchInferenceHandler",
|
|
@@ -55,11 +56,21 @@ class _BaseInferenceHandler(ABC):
|
|
|
55
56
|
model_p = make_fullpath(state_dict, enforce="file")
|
|
56
57
|
|
|
57
58
|
try:
|
|
58
|
-
# Load
|
|
59
|
-
|
|
59
|
+
# Load whatever is in the file
|
|
60
|
+
loaded_data = torch.load(model_p, map_location=self.device)
|
|
61
|
+
|
|
62
|
+
# Check if it's the new checkpoint dictionary or an old weights-only file
|
|
63
|
+
if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
|
|
64
|
+
# It's a new training checkpoint, extract the weights
|
|
65
|
+
self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
|
|
66
|
+
else:
|
|
67
|
+
# It's an old-style file (or just a state_dict), load it directly
|
|
68
|
+
self.model.load_state_dict(loaded_data)
|
|
69
|
+
|
|
70
|
+
_LOGGER.info(f"Model state loaded from '{model_p.name}'.")
|
|
71
|
+
|
|
60
72
|
self.model.to(self.device)
|
|
61
73
|
self.model.eval() # Set the model to evaluation mode
|
|
62
|
-
_LOGGER.info(f"Model state loaded from '{model_p.name}' and set to evaluation mode.")
|
|
63
74
|
except Exception as e:
|
|
64
75
|
_LOGGER.error(f"Failed to load model state from '{model_p}': {e}")
|
|
65
76
|
raise
|
|
@@ -71,7 +82,6 @@ class _BaseInferenceHandler(ABC):
|
|
|
71
82
|
_LOGGER.warning("CUDA not available, switching to CPU.")
|
|
72
83
|
device_lower = "cpu"
|
|
73
84
|
elif device_lower == "mps" and not torch.backends.mps.is_available():
|
|
74
|
-
# Your M-series Mac will appreciate this check!
|
|
75
85
|
_LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
|
|
76
86
|
device_lower = "cpu"
|
|
77
87
|
return torch.device(device_lower)
|