dragon-ml-toolbox 12.12.0__tar.gz → 13.0.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic; further details are linked from the original diff page.

Files changed (46)
  1. {dragon_ml_toolbox-12.12.0/dragon_ml_toolbox.egg-info → dragon_ml_toolbox-13.0.0}/PKG-INFO +1 -1
  2. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0/dragon_ml_toolbox.egg-info}/PKG-INFO +1 -1
  3. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_callbacks.py +40 -8
  4. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_evaluation.py +94 -44
  5. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_evaluation_multi.py +107 -32
  6. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_inference.py +14 -4
  7. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_trainer.py +113 -15
  8. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/keys.py +9 -0
  9. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/pyproject.toml +1 -1
  10. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/LICENSE +0 -0
  11. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/LICENSE-THIRD-PARTY.md +0 -0
  12. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/README.md +0 -0
  13. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/dragon_ml_toolbox.egg-info/SOURCES.txt +0 -0
  14. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/dragon_ml_toolbox.egg-info/dependency_links.txt +0 -0
  15. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/dragon_ml_toolbox.egg-info/requires.txt +0 -0
  16. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/dragon_ml_toolbox.egg-info/top_level.txt +0 -0
  17. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ETL_cleaning.py +0 -0
  18. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ETL_engineering.py +0 -0
  19. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/GUI_tools.py +0 -0
  20. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/MICE_imputation.py +0 -0
  21. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_datasetmaster.py +0 -0
  22. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_models.py +0 -0
  23. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_optimization.py +0 -0
  24. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_scaler.py +0 -0
  25. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_simple_optimization.py +0 -0
  26. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ML_utilities.py +0 -0
  27. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/PSO_optimization.py +0 -0
  28. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/RNN_forecast.py +0 -0
  29. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/SQL.py +0 -0
  30. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/VIF_factor.py +0 -0
  31. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/__init__.py +0 -0
  32. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/_logger.py +0 -0
  33. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/_script_info.py +0 -0
  34. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/constants.py +0 -0
  35. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/custom_logger.py +0 -0
  36. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/data_exploration.py +0 -0
  37. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ensemble_evaluation.py +0 -0
  38. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ensemble_inference.py +0 -0
  39. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/ensemble_learning.py +0 -0
  40. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/handle_excel.py +0 -0
  41. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/math_utilities.py +0 -0
  42. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/optimization_tools.py +0 -0
  43. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/path_manager.py +0 -0
  44. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/serde.py +0 -0
  45. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/ml_tools/utilities.py +0 -0
  46. {dragon_ml_toolbox-12.12.0 → dragon_ml_toolbox-13.0.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dragon-ml-toolbox
3
- Version: 12.12.0
3
+ Version: 13.0.0
4
4
  Summary: A collection of tools for data science and machine learning projects.
5
5
  Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
6
6
  License-Expression: MIT
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dragon-ml-toolbox
3
- Version: 12.12.0
3
+ Version: 13.0.0
4
4
  Summary: A collection of tools for data science and machine learning projects.
5
5
  Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
6
6
  License-Expression: MIT
@@ -5,7 +5,7 @@ from typing import Union, Literal, Optional
5
5
  from pathlib import Path
6
6
 
7
7
  from .path_manager import make_fullpath, sanitize_filename
8
- from .keys import PyTorchLogKeys
8
+ from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
9
9
  from ._logger import _LOGGER
10
10
  from ._script_info import _script_info
11
11
 
@@ -189,7 +189,7 @@ class EarlyStopping(Callback):
189
189
 
190
190
  class ModelCheckpoint(Callback):
191
191
  """
192
- Saves the model weights to a directory with automated filename generation and rotation.
192
+ Saves the model weights, optimizer state, LR scheduler state (if any), and epoch number to a directory with automated filename generation and rotation.
193
193
  """
194
194
  def __init__(self, save_dir: Union[str,Path], checkpoint_name: Optional[str]=None, monitor: str = PyTorchLogKeys.VAL_LOSS,
195
195
  save_best_only: bool = True, mode: Literal['auto', 'min', 'max']= 'auto', verbose: int = 0):
@@ -200,7 +200,7 @@ class ModelCheckpoint(Callback):
200
200
  Args:
201
201
  save_dir (str): Directory where checkpoint files will be saved.
202
202
  checkpoint_name (str| None): If None, the filename will include the epoch and score.
203
- monitor (str): Metric to monitor for `save_best_only=True`.
203
+ monitor (str): Metric to monitor.
204
204
  save_best_only (bool): If true, save only the best model.
205
205
  mode (str): One of {'auto', 'min', 'max'}.
206
206
  verbose (int): Verbosity mode.
@@ -270,15 +270,29 @@ class ModelCheckpoint(Callback):
270
270
  if self.verbose > 0:
271
271
  _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current:.4f}, saving model to {new_filepath}")
272
272
 
273
+ # Update best score *before* saving
274
+ self.best = current
275
+
276
+ # Create a comprehensive checkpoint dictionary
277
+ checkpoint_data = {
278
+ PyTorchCheckpointKeys.EPOCH: epoch,
279
+ PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
280
+ PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
281
+ PyTorchCheckpointKeys.BEST_SCORE: self.best,
282
+ }
283
+
284
+ # Check for scheduler
285
+ if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
286
+ checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
287
+
273
288
  # Save the new best model
274
- torch.save(self.trainer.model.state_dict(), new_filepath) # type: ignore
289
+ torch.save(checkpoint_data, new_filepath)
275
290
 
276
291
  # Delete the old best model file
277
292
  if self.last_best_filepath and self.last_best_filepath.exists():
278
293
  self.last_best_filepath.unlink()
279
294
 
280
295
  # Update state
281
- self.best = current
282
296
  self.last_best_filepath = new_filepath
283
297
 
284
298
  def _save_rolling_checkpoints(self, epoch, logs):
@@ -292,7 +306,19 @@ class ModelCheckpoint(Callback):
292
306
 
293
307
  if self.verbose > 0:
294
308
  _LOGGER.info(f'Epoch {epoch}: saving model to {filepath}')
295
- torch.save(self.trainer.model.state_dict(), filepath) # type: ignore
309
+
310
+ # Create a comprehensive checkpoint dictionary
311
+ checkpoint_data = {
312
+ PyTorchCheckpointKeys.EPOCH: epoch,
313
+ PyTorchCheckpointKeys.MODEL_STATE: self.trainer.model.state_dict(), # type: ignore
314
+ PyTorchCheckpointKeys.OPTIMIZER_STATE: self.trainer.optimizer.state_dict(), # type: ignore
315
+ PyTorchCheckpointKeys.BEST_SCORE: self.best, # Save the current best score
316
+ }
317
+
318
+ if hasattr(self.trainer, 'scheduler') and self.trainer.scheduler is not None: # type: ignore
319
+ checkpoint_data[PyTorchCheckpointKeys.SCHEDULER_STATE] = self.trainer.scheduler.state_dict() # type: ignore
320
+
321
+ torch.save(checkpoint_data, filepath)
296
322
 
297
323
  self.saved_checkpoints.append(filepath)
298
324
 
@@ -309,19 +335,25 @@ class LRScheduler(Callback):
309
335
  """
310
336
  Callback to manage a PyTorch learning rate scheduler.
311
337
  """
312
- def __init__(self, scheduler, monitor: Optional[str] = None):
338
+ def __init__(self, scheduler, monitor: Optional[str] = PyTorchLogKeys.VAL_LOSS):
313
339
  """
314
340
  This callback automatically calls the scheduler's `step()` method at the
315
341
  end of each epoch. It also logs a message when the learning rate changes.
316
342
 
317
343
  Args:
318
344
  scheduler: An initialized PyTorch learning rate scheduler.
319
- monitor (str, optional): The metric to monitor for schedulers that require it, like `ReduceLROnPlateau`. Should match a key in the logs (e.g., 'val_loss').
345
+ monitor (str): The metric to monitor for schedulers that require it, like `ReduceLROnPlateau`. Should match a key in the logs (e.g., 'val_loss').
320
346
  """
321
347
  super().__init__()
322
348
  self.scheduler = scheduler
323
349
  self.monitor = monitor
324
350
  self.previous_lr = None
351
+
352
+ def set_trainer(self, trainer):
353
+ """This is called by the Trainer to associate itself with the callback."""
354
+ super().set_trainer(trainer)
355
+ # Register the scheduler with the trainer so it can be added to the checkpoint
356
+ self.trainer.scheduler = self.scheduler # type: ignore
325
357
 
326
358
  def on_train_begin(self, logs=None):
327
359
  """Store the initial learning rate."""
@@ -18,7 +18,8 @@ from sklearn.metrics import (
18
18
  import torch
19
19
  import shap
20
20
  from pathlib import Path
21
- from typing import Union, Optional, List
21
+ from typing import Union, Optional, List, Literal
22
+ import warnings
22
23
 
23
24
  from .path_manager import make_fullpath
24
25
  from ._logger import _LOGGER
@@ -249,13 +250,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
249
250
  plt.savefig(hist_path)
250
251
  _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
251
252
  plt.close(fig_hist)
252
-
253
+
253
254
 
254
255
  def shap_summary_plot(model,
255
256
  background_data: Union[torch.Tensor,np.ndarray],
256
257
  instances_to_explain: Union[torch.Tensor,np.ndarray],
257
258
  feature_names: Optional[list[str]],
258
- save_dir: Union[str, Path]):
259
+ save_dir: Union[str, Path],
260
+ device: torch.device = torch.device('cpu'),
261
+ explainer_type: Literal['deep', 'kernel'] = 'deep'):
259
262
  """
260
263
  Calculates SHAP values and saves summary plots and data.
261
264
 
@@ -265,48 +268,88 @@ def shap_summary_plot(model,
265
268
  instances_to_explain (torch.Tensor): The specific data instances to explain.
266
269
  feature_names (list of str | None): Names of the features for plot labeling.
267
270
  save_dir (str | Path): Directory to save SHAP artifacts.
271
+ device (torch.device): The torch device for SHAP calculations.
272
+ explainer_type (Literal['deep', 'kernel']): The explainer to use.
273
+ - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for
274
+ PyTorch models.
275
+ - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
276
+ slow and memory-intensive.
268
277
  """
269
- # everything to numpy
270
- if isinstance(background_data, np.ndarray):
271
- background_data_np = background_data
272
- else:
273
- background_data_np = background_data.numpy()
274
-
275
- if isinstance(instances_to_explain, np.ndarray):
276
- instances_to_explain_np = instances_to_explain
277
- else:
278
- instances_to_explain_np = instances_to_explain.numpy()
279
-
280
- # --- Data Validation Step ---
281
- if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
282
- _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
283
- return
284
278
 
285
- print("\n--- SHAP Value Explanation ---")
279
+ print(f"\n--- SHAP Value Explanation Using {explainer_type.upper()} Explainer ---")
286
280
 
287
281
  model.eval()
288
- model.cpu()
289
-
290
- # 1. Summarize the background data.
291
- # Summarize the background data using k-means. 10-50 clusters is a good starting point.
292
- background_summary = shap.kmeans(background_data_np, 30)
293
-
294
- # 2. Define a prediction function wrapper that SHAP can use. It must take a numpy array and return a numpy array.
295
- def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
296
- # Convert numpy data to torch tensor
297
- x_torch = torch.from_numpy(x_np).float()
298
- with torch.no_grad():
299
- # Get model output
300
- output = model(x_torch)
301
- # Return as numpy array
302
- return output.cpu().numpy().flatten()
303
-
304
- # 3. Create the KernelExplainer
305
- explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
282
+ # model.cpu() # Run explanations on CPU
306
283
 
307
- print("Calculating SHAP values with KernelExplainer...")
308
- shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
284
+ shap_values = None
285
+ instances_to_explain_np = None
286
+
287
+ if explainer_type == 'deep':
288
+ # --- 1. Use DeepExplainer (Preferred) ---
289
+
290
+ # Ensure data is torch.Tensor
291
+ if isinstance(background_data, np.ndarray):
292
+ background_data = torch.from_numpy(background_data).float()
293
+ if isinstance(instances_to_explain, np.ndarray):
294
+ instances_to_explain = torch.from_numpy(instances_to_explain).float()
295
+
296
+ if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
297
+ _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
298
+ return
299
+
300
+ background_data = background_data.to(device)
301
+ instances_to_explain = instances_to_explain.to(device)
302
+
303
+ with warnings.catch_warnings():
304
+ warnings.simplefilter("ignore", category=UserWarning)
305
+ explainer = shap.DeepExplainer(model, background_data)
306
+
307
+ # print("Calculating SHAP values with DeepExplainer...")
308
+ shap_values = explainer.shap_values(instances_to_explain)
309
+ instances_to_explain_np = instances_to_explain.cpu().numpy()
310
+
311
+ elif explainer_type == 'kernel':
312
+ # --- 2. Use KernelExplainer (Slow Fallback) ---
313
+ _LOGGER.warning(
314
+ "Using KernelExplainer. This is memory-intensive and slow. "
315
+ "Consider reducing 'n_samples' if the process terminates unexpectedly."
316
+ )
317
+
318
+ # Ensure data is np.ndarray
319
+ if isinstance(background_data, torch.Tensor):
320
+ background_data_np = background_data.cpu().numpy()
321
+ else:
322
+ background_data_np = background_data
323
+
324
+ if isinstance(instances_to_explain, torch.Tensor):
325
+ instances_to_explain_np = instances_to_explain.cpu().numpy()
326
+ else:
327
+ instances_to_explain_np = instances_to_explain
328
+
329
+ if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
330
+ _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
331
+ return
332
+
333
+ # Summarize background data
334
+ background_summary = shap.kmeans(background_data_np, 30)
335
+
336
+ def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
337
+ x_torch = torch.from_numpy(x_np).float().to(device)
338
+ with torch.no_grad():
339
+ output = model(x_torch)
340
+ # Return as numpy array
341
+ return output.cpu().numpy()
342
+
343
+ explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
344
+ # print("Calculating SHAP values with KernelExplainer...")
345
+ shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
346
+ # instances_to_explain_np is already set
309
347
 
348
+ else:
349
+ _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
350
+ raise ValueError()
351
+
352
+ # --- 3. Plotting and Saving ---
310
353
  save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
311
354
  plt.ioff()
312
355
 
@@ -326,8 +369,9 @@ def shap_summary_plot(model,
326
369
  shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
327
370
  ax = plt.gca()
328
371
  ax.set_xlabel("SHAP Value Impact", labelpad=10)
329
- cb = plt.gcf().axes[-1]
330
- cb.set_ylabel("", size=1)
372
+ if plt.gcf().axes and len(plt.gcf().axes) > 1:
373
+ cb = plt.gcf().axes[-1]
374
+ cb.set_ylabel("", size=1)
331
375
  plt.title("SHAP Feature Importance")
332
376
  plt.tight_layout()
333
377
  plt.savefig(dot_path)
@@ -337,8 +381,14 @@ def shap_summary_plot(model,
337
381
  # Save Summary Data to CSV
338
382
  shap_summary_filename = SHAPKeys.SAVENAME + ".csv"
339
383
  summary_path = save_dir_path / shap_summary_filename
340
- # Ensure the array is 1D before creating the DataFrame
341
- mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
384
+
385
+ # Handle multi-class (list of arrays) vs. regression (single array)
386
+ if isinstance(shap_values, list):
387
+ mean_abs_shap = np.abs(np.stack(shap_values)).mean(axis=0).mean(axis=0)
388
+ else:
389
+ mean_abs_shap = np.abs(shap_values).mean(axis=0)
390
+
391
+ mean_abs_shap = mean_abs_shap.flatten()
342
392
 
343
393
  if feature_names is None:
344
394
  feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
@@ -351,7 +401,7 @@ def shap_summary_plot(model,
351
401
  summary_df.to_csv(summary_path, index=False)
352
402
 
353
403
  _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
354
- plt.ion()
404
+ plt.ion()
355
405
 
356
406
 
357
407
  def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
@@ -19,11 +19,13 @@ from sklearn.metrics import (
19
19
  jaccard_score
20
20
  )
21
21
  from pathlib import Path
22
- from typing import Union, List
22
+ from typing import Union, List, Literal
23
+ import warnings
23
24
 
24
25
  from .path_manager import make_fullpath, sanitize_filename
25
26
  from ._logger import _LOGGER
26
27
  from ._script_info import _script_info
28
+ from .keys import SHAPKeys
27
29
 
28
30
 
29
31
  __all__ = [
@@ -231,10 +233,12 @@ def multi_target_shap_summary_plot(
231
233
  instances_to_explain: Union[torch.Tensor, np.ndarray],
232
234
  feature_names: List[str],
233
235
  target_names: List[str],
234
- save_dir: Union[str, Path]
236
+ save_dir: Union[str, Path],
237
+ device: torch.device = torch.device('cpu'),
238
+ explainer_type: Literal['deep', 'kernel'] = 'deep'
235
239
  ):
236
240
  """
237
- Calculates SHAP values for a multi-target model and saves summary plots for each target.
241
+ Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
238
242
 
239
243
  Args:
240
244
  model (torch.nn.Module): The trained PyTorch model.
@@ -243,40 +247,94 @@ def multi_target_shap_summary_plot(
243
247
  feature_names (List[str]): Names of the features for plot labeling.
244
248
  target_names (List[str]): Names of the output targets.
245
249
  save_dir (str | Path): Directory to save SHAP artifacts.
250
+ device (torch.device): The torch device for SHAP calculations.
251
+ explainer_type (Literal['deep', 'kernel']): The explainer to use.
252
+ - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient.
253
+ - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
246
254
  """
247
- # Convert all data to numpy
248
- background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
249
- instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
250
-
251
- if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
252
- _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
253
- return
254
-
255
- _LOGGER.info("--- Multi-Target SHAP Value Explanation ---")
255
+ _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
256
256
  model.eval()
257
- model.cpu()
258
-
259
- # 1. Summarize the background data.
260
- background_summary = shap.kmeans(background_data_np, 30)
261
-
262
- # 2. Define a prediction function wrapper for the multi-target model.
263
- def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
264
- x_torch = torch.from_numpy(x_np).float()
265
- with torch.no_grad():
266
- output = model(x_torch)
267
- return output.cpu().numpy()
268
-
269
- # 3. Create the KernelExplainer.
270
- explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
271
-
272
- print("Calculating SHAP values with KernelExplainer...")
273
- # For multi-output models, shap_values is a list of arrays.
274
- shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
257
+ # model.cpu()
258
+
259
+ shap_values_list = None
260
+ instances_to_explain_np = None
261
+
262
+ if explainer_type == 'deep':
263
+ # --- 1. Use DeepExplainer (Preferred) ---
264
+
265
+ # Ensure data is torch.Tensor
266
+ if isinstance(background_data, np.ndarray):
267
+ background_data = torch.from_numpy(background_data).float()
268
+ if isinstance(instances_to_explain, np.ndarray):
269
+ instances_to_explain = torch.from_numpy(instances_to_explain).float()
270
+
271
+ if torch.isnan(background_data).any() or torch.isnan(instances_to_explain).any():
272
+ _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
273
+ return
274
+
275
+ background_data = background_data.to(device)
276
+ instances_to_explain = instances_to_explain.to(device)
277
+
278
+ with warnings.catch_warnings():
279
+ warnings.simplefilter("ignore", category=UserWarning)
280
+ explainer = shap.DeepExplainer(model, background_data)
281
+
282
+ # print("Calculating SHAP values with DeepExplainer...")
283
+ # DeepExplainer returns a list of arrays for multi-output models
284
+ shap_values_list = explainer.shap_values(instances_to_explain)
285
+ instances_to_explain_np = instances_to_explain.cpu().numpy()
286
+
287
+ elif explainer_type == 'kernel':
288
+ # --- 2. Use KernelExplainer (Slow Fallback) ---
289
+ _LOGGER.warning(
290
+ "Using KernelExplainer. This is memory-intensive and slow. "
291
+ "Consider reducing 'n_samples' if the process terminates."
292
+ )
293
+
294
+ # Convert all data to numpy
295
+ background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
296
+ instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
297
+
298
+ if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
299
+ _LOGGER.error("Input data for SHAP contains NaN values. Aborting explanation.")
300
+ return
301
+
302
+ background_summary = shap.kmeans(background_data_np, 30)
303
+
304
+ def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
305
+ x_torch = torch.from_numpy(x_np).float().to(device)
306
+ with torch.no_grad():
307
+ output = model(x_torch)
308
+ return output.cpu().numpy() # Return full multi-output array
309
+
310
+ explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
311
+ # print("Calculating SHAP values with KernelExplainer...")
312
+ # KernelExplainer also returns a list of arrays for multi-output models
313
+ shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
314
+ # instances_to_explain_np is already set
315
+
316
+ else:
317
+ _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
318
+ raise ValueError("Invalid explainer_type")
319
+
320
+ # --- 3. Plotting and Saving (Common Logic) ---
321
+
322
+ if shap_values_list is None or instances_to_explain_np is None:
323
+ _LOGGER.error("SHAP value calculation failed. Aborting plotting.")
324
+ return
325
+
326
+ # Ensure number of SHAP value arrays matches number of target names
327
+ if len(shap_values_list) != len(target_names):
328
+ _LOGGER.error(
329
+ f"SHAP explanation mismatch: Model produced {len(shap_values_list)} "
330
+ f"outputs, but {len(target_names)} target_names were provided."
331
+ )
332
+ return
275
333
 
276
334
  save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
277
335
  plt.ioff()
278
336
 
279
- # 4. Iterate through each target's SHAP values and generate plots.
337
+ # Iterate through each target's SHAP values and generate plots.
280
338
  for i, target_name in enumerate(target_names):
281
339
  print(f" -> Generating SHAP plots for target: '{target_name}'")
282
340
  shap_values_for_target = shap_values_list[i]
@@ -293,11 +351,28 @@ def multi_target_shap_summary_plot(
293
351
  # Save Dot Plot for the target
294
352
  shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
295
353
  plt.title(f"SHAP Feature Importance for '{target_name}'")
354
+ if plt.gcf().axes and len(plt.gcf().axes) > 1:
355
+ cb = plt.gcf().axes[-1]
356
+ cb.set_ylabel("", size=1)
296
357
  plt.tight_layout()
297
358
  dot_path = save_dir_path / f"shap_dot_plot_{sanitized_target_name}.svg"
298
359
  plt.savefig(dot_path)
299
360
  plt.close()
300
-
361
+
362
+ # --- Save Summary Data to CSV for this target ---
363
+ shap_summary_filename = f"{SHAPKeys.SAVENAME}_{sanitized_target_name}.csv"
364
+ summary_path = save_dir_path / shap_summary_filename
365
+
366
+ # For a specific target, shap_values_for_target is just a 2D array
367
+ mean_abs_shap = np.abs(shap_values_for_target).mean(axis=0).flatten()
368
+
369
+ summary_df = pd.DataFrame({
370
+ SHAPKeys.FEATURE_COLUMN: feature_names,
371
+ SHAPKeys.SHAP_VALUE_COLUMN: mean_abs_shap
372
+ }).sort_values(SHAPKeys.SHAP_VALUE_COLUMN, ascending=False)
373
+
374
+ summary_df.to_csv(summary_path, index=False)
375
+
301
376
  plt.ion()
302
377
  _LOGGER.info(f"All SHAP plots saved to '{save_dir_path.name}'")
303
378
 
@@ -9,7 +9,7 @@ from .ML_scaler import PytorchScaler
9
9
  from ._script_info import _script_info
10
10
  from ._logger import _LOGGER
11
11
  from .path_manager import make_fullpath
12
- from .keys import PyTorchInferenceKeys
12
+ from .keys import PyTorchInferenceKeys, PyTorchCheckpointKeys
13
13
 
14
14
 
15
15
  __all__ = [
@@ -56,11 +56,21 @@ class _BaseInferenceHandler(ABC):
56
56
  model_p = make_fullpath(state_dict, enforce="file")
57
57
 
58
58
  try:
59
- # Load the state dictionary and apply it to the model structure
60
- self.model.load_state_dict(torch.load(model_p, map_location=self.device))
59
+ # Load whatever is in the file
60
+ loaded_data = torch.load(model_p, map_location=self.device)
61
+
62
+ # Check if it's the new checkpoint dictionary or an old weights-only file
63
+ if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
64
+ # It's a new training checkpoint, extract the weights
65
+ self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
66
+ else:
67
+ # It's an old-style file (or just a state_dict), load it directly
68
+ self.model.load_state_dict(loaded_data)
69
+
70
+ _LOGGER.info(f"Model state loaded from '{model_p.name}'.")
71
+
61
72
  self.model.to(self.device)
62
73
  self.model.eval() # Set the model to evaluation mode
63
- _LOGGER.info(f"Model state loaded from '{model_p.name}' and set to evaluation mode.")
64
74
  except Exception as e:
65
75
  _LOGGER.error(f"Failed to load model state from '{model_p}': {e}")
66
76
  raise
@@ -5,12 +5,13 @@ import torch
5
5
  from torch import nn
6
6
  import numpy as np
7
7
 
8
- from .ML_callbacks import Callback, History, TqdmProgressBar
8
+ from .ML_callbacks import Callback, History, TqdmProgressBar, ModelCheckpoint
9
9
  from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
10
10
  from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
11
11
  from ._script_info import _script_info
12
- from .keys import PyTorchLogKeys
12
+ from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
13
13
  from ._logger import _LOGGER
14
+ from .path_manager import make_fullpath
14
15
 
15
16
 
16
17
  __all__ = [
@@ -55,6 +56,7 @@ class MLTrainer:
55
56
  self.kind = kind
56
57
  self.criterion = criterion
57
58
  self.optimizer = optimizer
59
+ self.scheduler = None
58
60
  self.device = self._validate_device(device)
59
61
  self.dataloader_workers = dataloader_workers
60
62
 
@@ -70,6 +72,7 @@ class MLTrainer:
70
72
  self.history = {}
71
73
  self.epoch = 0
72
74
  self.epochs = 0 # Total epochs for the fit run
75
+ self.start_epoch = 1
73
76
  self.stop_training = False
74
77
 
75
78
  def _validate_device(self, device: str) -> torch.device:
@@ -109,8 +112,66 @@ class MLTrainer:
109
112
  num_workers=loader_workers,
110
113
  pin_memory=("cuda" in self.device.type)
111
114
  )
115
+
116
+ def _load_checkpoint(self, path: Union[str, Path]):
117
+ """Loads a training checkpoint to resume training."""
118
+ p = make_fullpath(path, enforce="file")
119
+ _LOGGER.info(f"Loading checkpoint from '{p.name}' to resume training...")
120
+
121
+ try:
122
+ checkpoint = torch.load(p, map_location=self.device)
123
+
124
+ if PyTorchCheckpointKeys.MODEL_STATE not in checkpoint or PyTorchCheckpointKeys.OPTIMIZER_STATE not in checkpoint:
125
+ _LOGGER.error(f"Checkpoint file '{p.name}' is invalid. Missing 'model_state_dict' or 'optimizer_state_dict'.")
126
+ raise KeyError()
112
127
 
113
- def fit(self, epochs: int = 10, batch_size: int = 10, shuffle: bool = True):
128
+ self.model.load_state_dict(checkpoint[PyTorchCheckpointKeys.MODEL_STATE])
129
+ self.optimizer.load_state_dict(checkpoint[PyTorchCheckpointKeys.OPTIMIZER_STATE])
130
+ self.start_epoch = checkpoint.get(PyTorchCheckpointKeys.EPOCH, 0) + 1 # Resume on the *next* epoch
131
+
132
+ # --- Scheduler State Loading Logic ---
133
+ scheduler_state_exists = PyTorchCheckpointKeys.SCHEDULER_STATE in checkpoint
134
+ scheduler_object_exists = self.scheduler is not None
135
+
136
+ if scheduler_object_exists and scheduler_state_exists:
137
+ # Case 1: Both exist. Attempt to load.
138
+ try:
139
+ self.scheduler.load_state_dict(checkpoint[PyTorchCheckpointKeys.SCHEDULER_STATE]) # type: ignore
140
+ scheduler_name = self.scheduler.__class__.__name__
141
+ _LOGGER.info(f"Restored LR scheduler state for: {scheduler_name}")
142
+ except Exception as e:
143
+ # Loading failed, likely a mismatch
144
+ scheduler_name = self.scheduler.__class__.__name__
145
+ _LOGGER.error(f"Failed to load scheduler state for '{scheduler_name}'. A different scheduler type might have been used.")
146
+ raise e
147
+
148
+ elif scheduler_object_exists and not scheduler_state_exists:
149
+ # Case 2: Scheduler provided, but no state in checkpoint.
150
+ scheduler_name = self.scheduler.__class__.__name__
151
+ _LOGGER.warning(f"'{scheduler_name}' was provided, but no scheduler state was found in the checkpoint. The scheduler will start from its initial state.")
152
+
153
+ elif not scheduler_object_exists and scheduler_state_exists:
154
+ # Case 3: State in checkpoint, but no scheduler provided.
155
+ _LOGGER.error("Checkpoint contains an LR scheduler state, but no LRScheduler callback was provided.")
156
+ raise ValueError()
157
+
158
+ # Restore callback states
159
+ for cb in self.callbacks:
160
+ if isinstance(cb, ModelCheckpoint) and PyTorchCheckpointKeys.BEST_SCORE in checkpoint:
161
+ cb.best = checkpoint[PyTorchCheckpointKeys.BEST_SCORE]
162
+ _LOGGER.info(f"Restored {cb.__class__.__name__} 'best' score to: {cb.best:.4f}")
163
+
164
+ _LOGGER.info(f"Checkpoint loaded. Resuming training from epoch {self.start_epoch}.")
165
+
166
+ except Exception as e:
167
+ _LOGGER.error(f"Failed to load checkpoint from '{p}': {e}")
168
+ raise
169
+
170
+ def fit(self,
171
+ epochs: int = 10,
172
+ batch_size: int = 10,
173
+ shuffle: bool = True,
174
+ resume_from_checkpoint: Optional[Union[str, Path]] = None):
114
175
  """
115
176
  Starts the training-validation process of the model.
116
177
 
@@ -120,6 +181,7 @@ class MLTrainer:
120
181
  epochs (int): The total number of epochs to train for.
121
182
  batch_size (int): The number of samples per batch.
122
183
  shuffle (bool): Whether to shuffle the training data at each epoch.
184
+ resume_from_checkpoint (str | Path | None): Optional path to a checkpoint to resume training.
123
185
 
124
186
  Note:
125
187
  For regression tasks using `nn.MSELoss` or `nn.L1Loss`, the trainer
@@ -132,15 +194,18 @@ class MLTrainer:
132
194
  self._create_dataloaders(batch_size, shuffle)
133
195
  self.model.to(self.device)
134
196
 
197
+ if resume_from_checkpoint:
198
+ self._load_checkpoint(resume_from_checkpoint)
199
+
135
200
  # Reset stop_training flag on the trainer
136
201
  self.stop_training = False
137
202
 
138
- self.callbacks_hook('on_train_begin')
203
+ self._callbacks_hook('on_train_begin')
139
204
 
140
- for epoch in range(1, self.epochs + 1):
205
+ for epoch in range(self.start_epoch, self.epochs + 1):
141
206
  self.epoch = epoch
142
207
  epoch_logs = {}
143
- self.callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
208
+ self._callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
144
209
 
145
210
  train_logs = self._train_step()
146
211
  epoch_logs.update(train_logs)
@@ -148,13 +213,13 @@ class MLTrainer:
148
213
  val_logs = self._validation_step()
149
214
  epoch_logs.update(val_logs)
150
215
 
151
- self.callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
216
+ self._callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
152
217
 
153
218
  # Check the early stopping flag
154
219
  if self.stop_training:
155
220
  break
156
221
 
157
- self.callbacks_hook('on_train_end')
222
+ self._callbacks_hook('on_train_end')
158
223
  return self.history
159
224
 
160
225
  def _train_step(self):
@@ -166,7 +231,7 @@ class MLTrainer:
166
231
  PyTorchLogKeys.BATCH_INDEX: batch_idx,
167
232
  PyTorchLogKeys.BATCH_SIZE: features.size(0)
168
233
  }
169
- self.callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
234
+ self._callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
170
235
 
171
236
  features, target = features.to(self.device), target.to(self.device)
172
237
  self.optimizer.zero_grad()
@@ -188,7 +253,7 @@ class MLTrainer:
188
253
 
189
254
  # Add the batch loss to the logs and call the end-of-batch hook
190
255
  batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
191
- self.callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
256
+ self._callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
192
257
 
193
258
  return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore
194
259
 
@@ -340,9 +405,10 @@ class MLTrainer:
340
405
  def explain(self,
341
406
  save_dir: Union[str,Path],
342
407
  explain_dataset: Optional[Dataset] = None,
343
- n_samples: int = 1000,
408
+ n_samples: int = 300,
344
409
  feature_names: Optional[List[str]] = None,
345
- target_names: Optional[List[str]] = None):
410
+ target_names: Optional[List[str]] = None,
411
+ explainer_type: Literal['deep', 'kernel'] = 'deep'):
346
412
  """
347
413
  Explains model predictions using SHAP and saves all artifacts.
348
414
 
@@ -359,6 +425,9 @@ class MLTrainer:
359
425
  feature_names (list[str] | None): Feature names.
360
426
  target_names (list[str] | None): Target names for multi-target tasks.
361
427
  save_dir (str | Path): Directory to save all SHAP artifacts.
428
+ explainer_type (Literal['deep', 'kernel']): The explainer to use.
429
+ - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
430
+ - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples'< 100.
362
431
  """
363
432
  # Internal helper to create a dataloader and get a random sample
364
433
  def _get_random_sample(dataset: Dataset, num_samples: int):
@@ -410,6 +479,9 @@ class MLTrainer:
410
479
  else:
411
480
  _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
412
481
  raise ValueError()
482
+
483
+ # move model to device
484
+ self.model.to(self.device)
413
485
 
414
486
  # 3. Call the plotting function
415
487
  if self.kind in ["regression", "classification"]:
@@ -418,7 +490,9 @@ class MLTrainer:
418
490
  background_data=background_data,
419
491
  instances_to_explain=instances_to_explain,
420
492
  feature_names=feature_names,
421
- save_dir=save_dir
493
+ save_dir=save_dir,
494
+ explainer_type=explainer_type,
495
+ device=self.device
422
496
  )
423
497
  elif self.kind in ["multi_target_regression", "multi_label_classification"]:
424
498
  # try to get target names
@@ -442,7 +516,9 @@ class MLTrainer:
442
516
  instances_to_explain=instances_to_explain,
443
517
  feature_names=feature_names, # type: ignore
444
518
  target_names=target_names, # type: ignore
445
- save_dir=save_dir
519
+ save_dir=save_dir,
520
+ explainer_type=explainer_type,
521
+ device=self.device
446
522
  )
447
523
 
448
524
  def _attention_helper(self, dataloader: DataLoader):
@@ -527,11 +603,33 @@ class MLTrainer:
527
603
  else:
528
604
  _LOGGER.error("No attention weights were collected from the model.")
529
605
 
530
- def callbacks_hook(self, method_name: str, *args, **kwargs):
606
def _callbacks_hook(self, method_name: str, *args, **kwargs):
    """
    Dispatch one lifecycle event to every registered callback.

    Callbacks are visited in the order they appear in `self.callbacks`.

    Args:
        method_name (str): Name of the callback method to invoke (for example 'on_epoch_end').
        *args: Positional arguments forwarded unchanged to each callback method.
        **kwargs: Keyword arguments forwarded unchanged to each callback method.

    Raises:
        AttributeError: If a callback has no attribute named `method_name`.
    """
    for cb in self.callbacks:
        # Resolve the handler by name on this callback and invoke it right away.
        getattr(cb, method_name)(*args, **kwargs)
611
+
612
def to_cpu(self):
    """
    Switch the trainer's device setting to the CPU and relocate the model there.

    Handy before running operations that require the CPU.
    """
    cpu_device = torch.device('cpu')
    # Record the new device on the trainer first, then move the model to match.
    self.device = cpu_device
    self.model.to(cpu_device)
    _LOGGER.info("Trainer and model moved to CPU.")
621
+
622
def to_device(self, device: str):
    """
    Relocate the model to the requested device and remember it on the trainer.

    The trainer stores whatever `self._validate_device(device)` returns,
    and the model is moved to that same value.

    Args:
        device (str): The target device (e.g., 'cuda', 'mps', 'cpu').
    """
    target = self._validate_device(device)
    self.device = target
    self.model.to(target)
    _LOGGER.info(f"Trainer and model moved to {self.device}.")
632
+
535
633
 
536
634
def info():
    """Pass this module's `__all__` to `_script_info`."""
    public_names = __all__
    _script_info(public_names)
@@ -68,6 +68,15 @@ class SHAPKeys:
68
68
  SAVENAME = "shap_summary"
69
69
 
70
70
 
71
class PyTorchCheckpointKeys:
    """Dictionary keys used when saving or loading a training checkpoint."""

    # State-dict entries.
    MODEL_STATE = "model_state_dict"
    OPTIMIZER_STATE = "optimizer_state_dict"
    # The trainer's checkpoint loader feeds this entry to the LR scheduler's load_state_dict.
    SCHEDULER_STATE = "scheduler_state_dict"

    # Training-progress entries.
    EPOCH = "epoch"
    # Restored into ModelCheckpoint callbacks as their 'best' score when present.
    BEST_SCORE = "best_score"
78
+
79
+
71
80
  class _OneHotOtherPlaceholder:
72
81
  """Used internally by GUI_tools."""
73
82
  OTHER_GUI = "OTHER"
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "dragon-ml-toolbox"
3
- version = "12.12.0"
3
+ version = "13.0.0"
4
4
  description = "A collection of tools for data science and machine learning projects."
5
5
  authors = [
6
6
  { name = "Karl L. Loza Vidaurre", email = "luigiloza@gmail.com" }