dragon-ml-toolbox 12.12.0__tar.gz → 13.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-12.12.0/dragon_ml_toolbox.egg-info → dragon_ml_toolbox-13.0.0}/PKG-INFO +1 -1
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0/dragon_ml_toolbox.egg-info}/PKG-INFO +1 -1
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_callbacks.py +40 -8
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_evaluation.py +94 -44
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_evaluation_multi.py +107 -32
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_inference.py +14 -4
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_trainer.py +113 -15
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/keys.py +9 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/pyproject.toml +1 -1
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/LICENSE +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/README.md +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/dragon_ml_toolbox.egg-info/SOURCES.txt +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/dragon_ml_toolbox.egg-info/dependency_links.txt +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/dragon_ml_toolbox.egg-info/requires.txt +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/dragon_ml_toolbox.egg-info/top_level.txt +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ETL_cleaning.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ETL_engineering.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/GUI_tools.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/MICE_imputation.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_datasetmaster.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_models.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_optimization.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_scaler.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_simple_optimization.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_utilities.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/PSO_optimization.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/RNN_forecast.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/SQL.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/VIF_factor.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/__init__.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/_logger.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/_script_info.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/constants.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/custom_logger.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/data_exploration.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ensemble_evaluation.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ensemble_inference.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ensemble_learning.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/handle_excel.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/math_utilities.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/optimization_tools.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/path_manager.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/serde.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/utilities.py +0 -0
- {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/setup.cfg +0 -0
{dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_callbacks.py (+40 -8)

@@ -5,7 +5,7 @@ from typing import Union, Literal, Optional
 from pathlib import Path

 from .path_manager import make_fullpath, sanitize_filename
-from .keys import PyTorchLogKeys
+from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
 from ._logger import _LOGGER
 from ._script_info import _script_info

@@ -189,7 +189,7 @@ class EarlyStopping(Callback):

 class ModelCheckpoint(Callback):
     """
-    Saves the model weights to a directory with automated filename generation and rotation.
+    Saves the model weights, optimizer state, LR scheduler state (if any), and epoch number to a directory with automated filename generation and rotation.
     """
     def __init__(self, save_dir: Union[str,Path], checkpoint_name: Optional[str]=None, monitor: str = PyTorchLogKeys.VAL_LOSS,
                  save_best_only: bool = True, mode: Literal['auto', 'min', 'max']= 'auto', verbose: int = 0):
@@ -200,7 +200,7 @@ class ModelCheckpoint(Callback):
         Args:
             save_dir (str): Directory where checkpoint files will be saved.
             checkpoint_name (str| None): If None, the filename will include the epoch and score.
-            monitor (str): Metric to monitor
+            monitor (str): Metric to monitor.
             save_best_only (bool): If true, save only the best model.
             mode (str): One of {'auto', 'min', 'max'}.
             verbose (int): Verbosity mode.
@@ -270,15 +270,29 @@ class ModelCheckpoint(Callback):
             if self.verbose > 0:
                 _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current:.4f}, saving model to {new_filepath}")

+            # Update best score *before* saving
+            self.best = current
+
+            # Create a comprehensive checkpoint dictionary
+            checkpoint_data = {
+                PyTorchCheckpointKeys.EPOCH: epoch,
+                PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
+                PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
+                PyTorchCheckpointKeys.BEST_SCORE: self.best,
+            }
+
+            # Check for scheduler
+            if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
+                checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
+
             # Save the new best model
-            torch.save(
+            torch.save(checkpoint_data, new_filepath)

             # Delete the old best model file
             if self.last_best_filepath and self.last_best_filepath.exists():
                 self.last_best_filepath.unlink()

             # Update state
-            self.best = current
             self.last_best_filepath = new_filepath

     def _save_rolling_checkpoints(self, epoch, logs):
@@ -292,7 +306,19 @@ class ModelCheckpoint(Callback):

         if self.verbose > 0:
             _LOGGER.info(f'Epoch {epoch}: saving model to {filepath}')
-
+
+        # Create a comprehensive checkpoint dictionary
+        checkpoint_data = {
+            PyTorchCheckpointKeys.EPOCH: epoch,
+            PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
+            PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
+            PyTorchCheckpointKeys.BEST_SCORE: self.best, # Save the current best score
+        }
+
+        if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
+            checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
+
+        torch.save(checkpoint_data, filepath)

         self.saved_checkpoints.append(filepath)

@@ -309,19 +335,25 @@ class LRScheduler(Callback):
     """
     Callback to manage a PyTorch learning rate scheduler.
     """
-    def __init__(self, scheduler, monitor: Optional[str] =
+    def __init__(self, scheduler, monitor: Optional[str] = PyTorchLogKeys.VAL_LOSS):
         """
         This callback automatically calls the scheduler's `step()` method at the
         end of each epoch. It also logs a message when the learning rate changes.

         Args:
             scheduler: An initialized PyTorch learning rate scheduler.
-            monitor (str
+            monitor (str): The metric to monitor for schedulers that require it, like `ReduceLROnPlateau`. Should match a key in the logs (e.g., 'val_loss').
         """
         super().__init__()
         self.scheduler = scheduler
         self.monitor = monitor
         self.previous_lr = None
+
+    def set_trainer(self, trainer):
+        """This is called by the Trainer to associate itself with the callback."""
+        super().set_trainer(trainer)
+        # Register the scheduler with the trainer so it can be added to the checkpoint
+        self.trainer.scheduler = self.scheduler # type: ignore

     def on_train_begin(self, logs=None):
         """Store the initial learning rate."""
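Checkpoint files written by the updated ModelCheckpoint callback are now dictionaries rather than bare state_dicts. A minimal sketch of inspecting one offline, assuming the package is importable as ml_tools and using a hypothetical file name:

    import torch
    from ml_tools.keys import PyTorchCheckpointKeys

    # Hypothetical path to a file produced by ModelCheckpoint in 13.0.0
    checkpoint = torch.load("checkpoints/epoch_12_0.3421.pth", map_location="cpu")

    print(checkpoint[PyTorchCheckpointKeys.EPOCH])       # epoch at which the file was saved
    print(checkpoint[PyTorchCheckpointKeys.BEST_SCORE])  # monitored metric at save time
    weights = checkpoint[PyTorchCheckpointKeys.MODEL_STATE]          # model state_dict
    optim_state = checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE]  # optimizer state_dict
    # SCHEDULER_STATE is present only when an LRScheduler callback was registered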
{dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_evaluation.py (+94 -44)

@@ -18,7 +18,8 @@ from sklearn.metrics import (
 import torch
 import shap
 from pathlib import Path
-from typing import Union, Optional, List
+from typing import Union, Optional, List, Literal
+import warnings

 from .path_manager import make_fullpath
 from ._logger import _LOGGER
@@ -249,13 +250,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     plt.savefig(hist_path)
     _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
     plt.close(fig_hist)
-
+


 def shap_summary_plot(model,
                       background_data: Union[torch.Tensor,np.ndarray],
                       instances_to_explain: Union[torch.Tensor,np.ndarray],
                       feature_names: Optional[list[str]],
-                      save_dir: Union[str, Path]
+                      save_dir: Union[str, Path],
+                      device: torch.device = torch.device('cpu'),
+                      explainer_type: Literal['deep', 'kernel'] = 'deep'):
     """
     Calculates SHAP values and saves summary plots and data.

@@ -265,48 +268,88 @@ def shap_summary_plot(model,
         instances_to_explain (torch.Tensor): The specific data instances to explain.
         feature_names (list of str | None): Names of the features for plot labeling.
         save_dir (str | Path): Directory to save SHAP artifacts.
+        device (torch.device): The torch device for SHAP calculations.
+        explainer_type (Literal['deep', 'kernel']): The explainer to use.
+            - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for
+              PyTorch models.
+            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
+              slow and memory-intensive.
     """
-    # everything to numpy
-    if isinstance(background_data, np.ndarray):
-        background_data_np = background_data
-    else:
-        background_data_np = background_data.numpy()
-
-    if isinstance(instances_to_explain, np.ndarray):
-        instances_to_explain_np = instances_to_explain
-    else:
-        instances_to_explain_np = instances_to_explain.numpy()
-
-    # --- Data Validation Step ---
-    if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
-        _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
-        return

-    print("\n--- SHAP Value Explanation ---")
+    print(f"\n--- SHAP Value Explanation Using {explainer_type.upper()} Explainer ---")

     model.eval()
-    model.cpu()
-
-    # 1. Summarize the background data.
-    # Summarize the background data using k-means. 10-50 clusters is a good starting point.
-    background_summary = shap.kmeans(background_data_np, 30)
-
-    # 2. Define a prediction function wrapper that SHAP can use. It must take a numpy array and return a numpy array.
-    def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
-        # Convert numpy data to torch tensor
-        x_torch = torch.from_numpy(x_np).float()
-        with torch.no_grad():
-            # Get model output
-            output = model(x_torch)
-        # Return as numpy array
-        return output.cpu().numpy().flatten()
-
-    # 3. Create the KernelExplainer
-    explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
+    # model.cpu() # Run explanations on CPU

-
-
+    shap_values = None
+    instances_to_explain_np = None
+
+    if explainer_type == 'deep':
+        # --- 1. Use DeepExplainer (Preferred) ---
+
+        # Ensure data is torch.Tensor
+        if isinstance(background_data, np.ndarray):
+            background_data = torch.from_numpy(background_data).float()
+        if isinstance(instances_to_explain, np.ndarray):
+            instances_to_explain = torch.from_numpy(instances_to_explain).float()
+
+        if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        background_data = background_data.to(device)
+        instances_to_explain = instances_to_explain.to(device)
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=UserWarning)
+            explainer = shap.DeepExplainer(model, background_data)
+
+        # print("Calculating SHAP values with DeepExplainer...")
+        shap_values = explainer.shap_values(instances_to_explain)
+        instances_to_explain_np = instances_to_explain.cpu().numpy()
+
+    elif explainer_type == 'kernel':
+        # --- 2. Use KernelExplainer (Slow Fallback) ---
+        _LOGGER.warning(
+            "Using KernelExplainer. This is memory-intensive and slow. "
+            "Consider reducing 'n_samples' if the process terminates unexpectedly."
+        )
+
+        # Ensure data is np.ndarray
+        if isinstance(background_data, torch.Tensor):
+            background_data_np = background_data.cpu().numpy()
+        else:
+            background_data_np = background_data
+
+        if isinstance(instances_to_explain, torch.Tensor):
+            instances_to_explain_np = instances_to_explain.cpu().numpy()
+        else:
+            instances_to_explain_np = instances_to_explain
+
+        if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        # Summarize background data
+        background_summary = shap.kmeans(background_data_np, 30)
+
+        def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
+            x_torch = torch.from_numpy(x_np).float().to(device)
+            with torch.no_grad():
+                output = model(x_torch)
+            # Return as numpy array
+            return output.cpu().numpy()
+
+        explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
+        # print("Calculating SHAP values with KernelExplainer...")
+        shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
+        # instances_to_explain_np is already set

+    else:
+        _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
+        raise ValueError()
+
+    # --- 3. Plotting and Saving ---
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     plt.ioff()

@@ -326,8 +369,9 @@ def shap_summary_plot(model,
     shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
-
-
+    if plt.gcf().axes and len(plt.gcf().axes) > 1:
+        cb = plt.gcf().axes[-1]
+        cb.set_ylabel("", size=1)
     plt.title("SHAP Feature Importance")
     plt.tight_layout()
     plt.savefig(dot_path)
@@ -337,8 +381,14 @@ def shap_summary_plot(model,
     # Save Summary Data to CSV
     shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
     summary_path = save_dir_path / shap_summary_filename
-
-
+
+    # Handle multi-class (list of arrays) vs. regression (single array)
+    if isinstance(shap_values, list):
+        mean_abs_shap = np.abs(np.stack(shap_values)).mean(axis=0).mean(axis=0)
+    else:
+        mean_abs_shap = np.abs(shap_values).mean(axis=0)
+
+    mean_abs_shap = mean_abs_shap.flatten()

     if feature_names is None:
         feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
@@ -351,7 +401,7 @@ def shap_summary_plot(model,
     summary_df.to_csv(summary_path, index=False)

     _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
-    plt.ion()
+    plt.ion()


 def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
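The reworked shap_summary_plot selects between DeepExplainer and KernelExplainer and accepts a target device. A sketch of a call against a toy model (argument names follow the new signature above; the model and data are throwaway placeholders):

    import torch
    from torch import nn
    from ml_tools.ML_evaluation import shap_summary_plot

    model = nn.Sequential(nn.Linear(5, 16), nn.ReLU(), nn.Linear(16, 1))  # toy regressor
    shap_summary_plot(
        model=model,
        background_data=torch.randn(100, 5),      # background sample for the explainer
        instances_to_explain=torch.randn(20, 5),  # rows to explain
        feature_names=[f"feat_{i}" for i in range(5)],
        save_dir="shap_output",
        device=torch.device("cpu"),               # new in 13.0.0
        explainer_type="deep",                    # new in 13.0.0; "kernel" is the slow fallback
    )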
{dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_evaluation_multi.py (+107 -32)

@@ -19,11 +19,13 @@ from sklearn.metrics import (
     jaccard_score
 )
 from pathlib import Path
-from typing import Union, List
+from typing import Union, List, Literal
+import warnings

 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
+from .keys import SHAPKeys


 __all__ = [
@@ -231,10 +233,12 @@ def multi_target_shap_summary_plot(
     instances_to_explain: Union[torch.Tensor, np.ndarray],
     feature_names: List[str],
     target_names: List[str],
-    save_dir: Union[str, Path]
+    save_dir: Union[str, Path],
+    device: torch.device = torch.device('cpu'),
+    explainer_type: Literal['deep', 'kernel'] = 'deep'
 ):
     """
-    Calculates SHAP values for a multi-target model and saves summary plots for each target.
+    Calculates SHAP values for a multi-target model and saves summary plots and data for each target.

     Args:
         model (torch.nn.Module): The trained PyTorch model.
@@ -243,40 +247,94 @@ def multi_target_shap_summary_plot(
         feature_names (List[str]): Names of the features for plot labeling.
         target_names (List[str]): Names of the output targets.
         save_dir (str | Path): Directory to save SHAP artifacts.
+        device (torch.device): The torch device for SHAP calculations.
+        explainer_type (Literal['deep', 'kernel']): The explainer to use.
+            - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient.
+            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
     """
-
-    background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
-    instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
-
-    if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
-        _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
-        return
-
-    _LOGGER.info("--- Multi-Target SHAP Value Explanation ---")
+    _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
     model.eval()
-    model.cpu()
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    # model.cpu()
+
+    shap_values_list = None
+    instances_to_explain_np = None
+
+    if explainer_type == 'deep':
+        # --- 1. Use DeepExplainer (Preferred) ---
+
+        # Ensure data is torch.Tensor
+        if isinstance(background_data, np.ndarray):
+            background_data = torch.from_numpy(background_data).float()
+        if isinstance(instances_to_explain, np.ndarray):
+            instances_to_explain = torch.from_numpy(instances_to_explain).float()
+
+        if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        background_data = background_data.to(device)
+        instances_to_explain = instances_to_explain.to(device)
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore", category=UserWarning)
+            explainer = shap.DeepExplainer(model, background_data)
+
+        # print("Calculating SHAP values with DeepExplainer...")
+        # DeepExplainer returns a list of arrays for multi-output models
+        shap_values_list = explainer.shap_values(instances_to_explain)
+        instances_to_explain_np = instances_to_explain.cpu().numpy()
+
+    elif explainer_type == 'kernel':
+        # --- 2. Use KernelExplainer (Slow Fallback) ---
+        _LOGGER.warning(
+            "Using KernelExplainer. This is memory-intensive and slow. "
+            "Consider reducing 'n_samples' if the process terminates."
+        )
+
+        # Convert all data to numpy
+        background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
+        instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
+
+        if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
+            _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
+            return
+
+        background_summary = shap.kmeans(background_data_np, 30)
+
+        def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
+            x_torch = torch.from_numpy(x_np).float().to(device)
+            with torch.no_grad():
+                output = model(x_torch)
+            return output.cpu().numpy() # Return full multi-output array
+
+        explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
+        # print("Calculating SHAP values with KernelExplainer...")
+        # KernelExplainer also returns a list of arrays for multi-output models
+        shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
+        # instances_to_explain_np is already set
+
+    else:
+        _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
+        raise ValueError("Invalid explainer_type")
+
+    # --- 3. Plotting and Saving (Common Logic) ---
+
+    if shap_values_list is None or instances_to_explain_np is None:
+        _LOGGER.error("SHAP value calculation failed. Aborting plotting.")
+        return
+
+    # Ensure number of SHAP value arrays matches number of target names
+    if len(shap_values_list) != len(target_names):
+        _LOGGER.error(
+            f"SHAP explanation mismatch: Model produced {len(shap_values_list)} "
+            f"outputs, but {len(target_names)} target_names were provided."
+        )
+        return

     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     plt.ioff()

-    #
+    # Iterate through each target's SHAP values and generate plots.
     for i, target_name in enumerate(target_names):
         print(f" -> Generating SHAP plots for target: '{target_name}'")
         shap_values_for_target = shap_values_list[i]
@@ -293,11 +351,28 @@ def multi_target_shap_summary_plot(
         # Save Dot Plot for the target
         shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
         plt.title(f"SHAP Feature Importance for '{target_name}'")
+        if plt.gcf().axes and len(plt.gcf().axes) > 1:
+            cb = plt.gcf().axes[-1]
+            cb.set_ylabel("", size=1)
         plt.tight_layout()
         dot_path = save_dir_path / f"shap_dot_plot_{sanitized_target_name}.svg"
         plt.savefig(dot_path)
         plt.close()
-
+
+        # --- Save Summary Data to CSV for this target ---
+        shap_summary_filename = f"{SHAPKeys.SAVENAME}_{sanitized_target_name}.csv"
+        summary_path = save_dir_path / shap_summary_filename
+
+        # For a specific target, shap_values_for_target is just a 2D array
+        mean_abs_shap = np.abs(shap_values_for_target).mean(axis=0).flatten()
+
+        summary_df = pd.DataFrame({
+            SHAPKeys.FEATURE_COLUMN: feature_names,
+            SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
+        }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
+
+        summary_df.to_csv(summary_path, index=False)
+
     plt.ion()
     _LOGGER.info(f"All SHAP plots saved to '{save_dir_path.name}'")

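The multi-target variant takes the same two new arguments and now also writes one shap_summary_<target>.csv per output. A corresponding sketch with a toy two-output model (the first two arguments are passed positionally since their names are not shown in this hunk; all data are placeholders):

    import torch
    from torch import nn
    from ml_tools.ML_evaluation_multi import multi_target_shap_summary_plot

    model = nn.Linear(4, 2)  # toy model with two outputs
    multi_target_shap_summary_plot(
        model,
        torch.randn(100, 4),                    # background data
        instances_to_explain=torch.randn(20, 4),
        feature_names=[f"feat_{i}" for i in range(4)],
        target_names=["target_a", "target_b"],
        save_dir="shap_multi_output",
        device=torch.device("cpu"),
        explainer_type="deep",
    )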
{dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_inference.py (+14 -4)

@@ -9,7 +9,7 @@ from .ML_scaler import PytorchScaler
 from ._script_info import _script_info
 from ._logger import _LOGGER
 from .path_manager import make_fullpath
-from .keys import PyTorchInferenceKeys
+from .keys import PyTorchInferenceKeys, PyTorchCheckpointKeys


 __all__ = [
@@ -56,11 +56,21 @@ class _BaseInferenceHandler(ABC):
         model_p = make_fullpath(state_dict, enforce="file")

         try:
-            # Load
-
+            # Load whatever is in the file
+            loaded_data = torch.load(model_p, map_location=self.device)
+
+            # Check if it's the new checkpoint dictionary or an old weights-only file
+            if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
+                # It's a new training checkpoint, extract the weights
+                self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
+            else:
+                # It's an old-style file (or just a state_dict), load it directly
+                self.model.load_state_dict(loaded_data)
+
+            _LOGGER.info(f"Model state loaded from '{model_p.name}'.")
+
             self.model.to(self.device)
             self.model.eval() # Set the model to evaluation mode
-            _LOGGER.info(f"Model state loaded from '{model_p.name}' and set to evaluation mode.")
         except Exception as e:
             _LOGGER.error(f"Failed to load model state from '{model_p}': {e}")
             raise
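The inference handlers now accept either the new checkpoint dictionaries or legacy weights-only files. The same branching logic, reproduced as a standalone sketch outside the handler (the path and architecture are placeholders):

    import torch
    from torch import nn
    from ml_tools.keys import PyTorchCheckpointKeys

    model = nn.Linear(8, 1)  # placeholder architecture matching the saved weights
    loaded = torch.load("model_or_checkpoint.pth", map_location="cpu")

    if isinstance(loaded, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded:
        # New-style training checkpoint: extract just the weights
        model.load_state_dict(loaded[PyTorchCheckpointKeys.MODEL_STATE])
    else:
        # Old-style file containing a bare state_dict
        model.load_state_dict(loaded)
    model.eval()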
{dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_trainer.py (+113 -15)

@@ -5,12 +5,13 @@ import torch
 from torch import nn
 import numpy as np

-from .ML_callbacks import Callback, History, TqdmProgressBar
+from .ML_callbacks import Callback, History, TqdmProgressBar, ModelCheckpoint
 from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
 from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
 from ._script_info import _script_info
-from .keys import PyTorchLogKeys
+from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
 from ._logger import _LOGGER
+from .path_manager import make_fullpath


 __all__ = [
@@ -55,6 +56,7 @@ class MLTrainer:
         self.kind = kind
         self.criterion = criterion
         self.optimizer = optimizer
+        self.scheduler = None
         self.device = self._validate_device(device)
         self.dataloader_workers = dataloader_workers

@@ -70,6 +72,7 @@ class MLTrainer:
         self.history = {}
         self.epoch = 0
         self.epochs = 0 # Total epochs for the fit run
+        self.start_epoch = 1
         self.stop_training = False

     def _validate_device(self, device: str) -> torch.device:
@@ -109,8 +112,66 @@ class MLTrainer:
             num_workers=loader_workers,
             pin_memory=("cuda" in self.device.type)
         )
+
+    def _load_checkpoint(self, path: Union[str, Path]):
+        """Loads a training checkpoint to resume training."""
+        p = make_fullpath(path, enforce="file")
+        _LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
+
+        try:
+            checkpoint = torch.load(p, map_location=self.device)
+
+            if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
+                _LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
+                raise KeyError()

-
+            self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
+            self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
+            self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
+
+            # --- Scheduler State Loading Logic ---
+            scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
+            scheduler_object_exists = self.scheduler is not None
+
+            if scheduler_object_exists and scheduler_state_exists:
+                # Case 1: Both exist. Attempt to load.
+                try:
+                    self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
+                    scheduler_name = self.scheduler.__class__.__name__
+                    _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
+                except Exception as e:
+                    # Loading failed, likely a mismatch
+                    scheduler_name = self.scheduler.__class__.__name__
+                    _LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
+                    raise e
+
+            elif scheduler_object_exists and not scheduler_state_exists:
+                # Case 2: Scheduler provided, but no state in checkpoint.
+                scheduler_name = self.scheduler.__class__.__name__
+                _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
+
+            elif not scheduler_object_exists and scheduler_state_exists:
+                # Case 3: State in checkpoint, but no scheduler provided.
+                _LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
+                raise ValueError()
+
+            # Restore callback states
+            for cb in self.callbacks:
+                if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
+                    cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
+                    _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
+
+            _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
+
+        except Exception as e:
+            _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
+            raise
+
+    def fit(self,
+            epochs: int = 10,
+            batch_size: int = 10,
+            shuffle: bool = True,
+            resume_from_checkpoint: Optional[Union[str, Path]] = None):
         """
         Starts the training-validation process of the model.

@@ -120,6 +181,7 @@ class MLTrainer:
             epochs (int): The total number of epochs to train for.
             batch_size (int): The number of samples per batch.
             shuffle (bool): Whether to shuffle the training data at each epoch.
+            resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.

         Note:
             For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
@@ -132,15 +194,18 @@ class MLTrainer:
         self._create_dataloaders(batch_size, shuffle)
         self.model.to(self.device)

+        if resume_from_checkpoint:
+            self._load_checkpoint(resume_from_checkpoint)
+
         # Reset stop_training flag on the trainer
         self.stop_training = False

-        self.
+        self._callbacks_hook('on_train_begin')

-        for epoch in range(
+        for epoch in range(self.start_epoch, self.epochs + 1):
             self.epoch = epoch
             epoch_logs = {}
-            self.
+            self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)

             train_logs = self._train_step()
             epoch_logs.update(train_logs)
@@ -148,13 +213,13 @@ class MLTrainer:
             val_logs = self._validation_step()
             epoch_logs.update(val_logs)

-            self.
+            self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)

             # Check the early stopping flag
             if self.stop_training:
                 break

-        self.
+        self._callbacks_hook('on_train_end')
         return self.history

     def _train_step(self):
@@ -166,7 +231,7 @@ class MLTrainer:
                 PyTorchLogKeys.BATCH_INDEX: batch_idx,
                 PyTorchLogKeys.BATCH_SIZE: features.size(0)
             }
-            self.
+            self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)

             features, target = features.to(self.device), target.to(self.device)
             self.optimizer.zero_grad()
@@ -188,7 +253,7 @@ class MLTrainer:

             # Add the batch loss to the logs and call the end-of-batch hook
             batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
-            self.
+            self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)

         return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore

@@ -340,9 +405,10 @@ class MLTrainer:
     def explain(self,
                 save_dir: Union[str,Path],
                 explain_dataset: Optional[Dataset] = None,
-                n_samples: int =
+                n_samples: int = 300,
                 feature_names: Optional[List[str]] = None,
-                target_names: Optional[List[str]] = None
+                target_names: Optional[List[str]] = None,
+                explainer_type: Literal['deep', 'kernel'] = 'deep'):
         """
         Explains model predictions using SHAP and saves all artifacts.

@@ -359,6 +425,9 @@ class MLTrainer:
             feature_names (list[str] | None): Feature names.
             target_names (list[str] | None): Target names for multi-target tasks.
             save_dir (str | Path): Directory to save all SHAP artifacts.
+            explainer_type (Literal['deep', 'kernel']): The explainer to use.
+                - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
+                - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples'< 100.
         """
         # Internal helper to create a dataloader and get a random sample
         def _get_random_sample(dataset: Dataset, num_samples: int):
@@ -410,6 +479,9 @@ class MLTrainer:
         else:
             _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
             raise ValueError()
+
+        # move model to device
+        self.model.to(self.device)

         # 3. Call the plotting function
         if self.kind in ["regression", "classification"]:
@@ -418,7 +490,9 @@ class MLTrainer:
                 background_data=background_data,
                 instances_to_explain=instances_to_explain,
                 feature_names=feature_names,
-                save_dir=save_dir
+                save_dir=save_dir,
+                explainer_type=explainer_type,
+                device=self.device
             )
         elif self.kind in ["multi_target_regression", "multi_label_classification"]:
             # try to get target names
@@ -442,7 +516,9 @@ class MLTrainer:
                 instances_to_explain=instances_to_explain,
                 feature_names=feature_names, # type: ignore
                 target_names=target_names, # type: ignore
-                save_dir=save_dir
+                save_dir=save_dir,
+                explainer_type=explainer_type,
+                device=self.device
             )

     def _attention_helper(self, dataloader: DataLoader):
@@ -527,11 +603,33 @@ class MLTrainer:
         else:
             _LOGGER.error("No attention weights were collected from the model.")

-    def
+    def _callbacks_hook(self, method_name: str, *args, **kwargs):
         """Calls the specified method on all callbacks."""
         for callback in self.callbacks:
             method = getattr(callback, method_name)
             method(*args, **kwargs)
+
+    def to_cpu(self):
+        """
+        Moves the model to the CPU and updates the trainer's device setting.
+
+        This is useful for running operations that require the CPU.
+        """
+        self.device = torch.device('cpu')
+        self.model.to(self.device)
+        _LOGGER.info("Trainer and model moved to CPU.")
+
+    def to_device(self, device: str):
+        """
+        Moves the model to the specified device and updates the trainer's device setting.
+
+        Args:
+            device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
+        """
+        self.device = self._validate_device(device)
+        self.model.to(self.device)
+        _LOGGER.info(f"Trainer and model moved to {self.device}.")
+

 def info():
     _script_info(__all__)
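Resuming an interrupted run uses the new fit() keyword, and the device helpers plus explainer selection are forwarded from the trainer. A sketch of the calls only (trainer construction is unchanged from 12.x and elided; the checkpoint path is a placeholder):

    # trainer = MLTrainer(...)  # built exactly as in previous versions
    history = trainer.fit(
        epochs=50,
        batch_size=32,
        shuffle=True,
        resume_from_checkpoint="checkpoints/epoch_12_0.3421.pth",  # new in 13.0.0
    )

    # New convenience methods for moving the trainer between devices:
    trainer.to_cpu()           # e.g. before a CPU-bound step
    trainer.to_device("cuda")  # assumes a CUDA device is available

    # explain() now forwards the trainer's device and the chosen explainer:
    trainer.explain(save_dir="shap_output", n_samples=300, explainer_type="deep")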
{dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/keys.py (+9 -0)

@@ -68,6 +68,15 @@ class SHAPKeys:
     SAVENAME = "shap_summary"


+class PyTorchCheckpointKeys:
+    """Keys for saving/loading a training checkpoint dictionary."""
+    MODEL_STATE = "model_state_dict"
+    OPTIMIZER_STATE = "optimizer_state_dict"
+    SCHEDULER_STATE = "scheduler_state_dict"
+    EPOCH = "epoch"
+    BEST_SCORE = "best_score"
+
+
 class _OneHotOtherPlaceholder:
     """Used internally by GUI_tools."""
     OTHER_GUI = "OTHER"
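Because the key names are now centralized in PyTorchCheckpointKeys, an old weights-only file can be wrapped by hand into a dictionary that the new loaders accept. A speculative sketch (paths are placeholders; 'model' and its optimizer are assumed to be already built, and float('inf') is only a sensible starting best score when the monitored metric is minimized):

    import torch
    from ml_tools.keys import PyTorchCheckpointKeys

    old_weights = torch.load("old_model_weights.pth", map_location="cpu")
    optimizer = torch.optim.Adam(model.parameters())  # 'model' assumed to exist

    checkpoint = {
        PyTorchCheckpointKeys.EPOCH: 0,
        PyTorchCheckpointKeys.MODEL_STATE: old_weights,
        PyTorchCheckpointKeys.OPTIMIZER_STATE: optimizer.state_dict(),
        PyTorchCheckpointKeys.BEST_SCORE: float("inf"),
    }
    torch.save(checkpoint, "converted_checkpoint.pth")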