dragon-ml-toolbox 13.4.0__py3-none-any.whl → 13.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


dragon_ml_toolbox-13.5.0.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dragon-ml-toolbox
-Version: 13.4.0
+Version: 13.5.0
 Summary: A collection of tools for data science and machine learning projects.
 Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
 License-Expression: MIT
dragon_ml_toolbox-13.5.0.dist-info/RECORD CHANGED
@@ -1,18 +1,18 @@
-dragon_ml_toolbox-13.4.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
-dragon_ml_toolbox-13.4.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=iy2r_R7wjzsCbz_Q_jMsp_jfZ6oP8XW9QhwzRBH0mGY,1904
+dragon_ml_toolbox-13.5.0.dist-info/licenses/LICENSE,sha256=L35WDmmLZNTlJvxF6Vy7Uy4SYNi6rCfWUqlTHpoRMoU,1081
+dragon_ml_toolbox-13.5.0.dist-info/licenses/LICENSE-THIRD-PARTY.md,sha256=iy2r_R7wjzsCbz_Q_jMsp_jfZ6oP8XW9QhwzRBH0mGY,1904
 ml_tools/ETL_cleaning.py,sha256=2VBRllV8F-ZiPylPp8Az2gwn5ztgazN0BH5OKnRUhV0,20402
 ml_tools/ETL_engineering.py,sha256=KfYqgsxupAx6e_TxwO1LZXeu5mFkIhVXJrNjP3CzIZc,54927
 ml_tools/GUI_tools.py,sha256=Va6ig-dHULPVRwQYYtH3fvY5XPIoqRcJpRW8oXC55Hw,45413
 ml_tools/MICE_imputation.py,sha256=X273Qlgoqqg7KTmoKd75YDyAPB0UIbTzGP3xsCmRh3E,11717
 ml_tools/ML_callbacks.py,sha256=elD2Yr030sv_6gX_m9GVd6HTyrbmt34nFS8lrgS4HtM,15808
 ml_tools/ML_datasetmaster.py,sha256=6caWbq6eu1RE9V51gmceD71PtMctJRjFuLvkkK5ChiY,36271
-ml_tools/ML_evaluation.py,sha256=3u5dOhS77gn3kAshKr2GwSa5xZBF0YM77ZkFevqNPvA,18528
-ml_tools/ML_evaluation_multi.py,sha256=L6Ub_uObXsI7ToVCF6DtmAFekHRcga5wWMOnRYRR-BY,16121
+ml_tools/ML_evaluation.py,sha256=li77AuP53pCzgrj6p-jTCNtPFgS9Y9XnMWIZn1ulTBM,18946
+ml_tools/ML_evaluation_multi.py,sha256=rJKdgtq-9I7oaI7PRzq7aIZ84XdNV0xzlVePZW4nj0k,16095
 ml_tools/ML_inference.py,sha256=yq2gdN6s_OUYC5ZLQrIJC5BA5H33q8UKODXwb-_0M2c,23549
 ml_tools/ML_models.py,sha256=UVWJHPLVIvFno_csCHH1FwBfTwQ5nX0V8F1TbOByZ4I,31388
 ml_tools/ML_optimization.py,sha256=P0zkhKAwTpkorIBtR0AOIDcyexo5ngmvFUzo3DfNO-E,22692
 ml_tools/ML_scaler.py,sha256=tw6onj9o8_kk3FQYb930HUzvv1zsFZe2YZJdF3LtHkU,7538
-ml_tools/ML_trainer.py,sha256=9BP6JFClqGfe7GL-FGG3n5e-no9ssjEOLol7P6baGrI,29019
+ml_tools/ML_trainer.py,sha256=ZxeOagXW5adFhYIH-oMTlcrLU6VHe4R1EROI7yypNwQ,29665
 ml_tools/ML_utilities.py,sha256=EnKpPTnJ2qjZmz7kvows4Uu5CfSA7ByRmI1v2-KarKw,9337
 ml_tools/PSO_optimization.py,sha256=T-HWHMRJUnPvPwixdU5jif3_rnnI36TzcL8u3oSCwuA,22960
 ml_tools/RNN_forecast.py,sha256=Qa2KoZfdAvSjZ4yE78N4BFXtr3tTr0Gx7tQJZPotsh0,1967
@@ -35,7 +35,7 @@ ml_tools/optimization_tools.py,sha256=TYFQ2nSnp7xxs-VyoZISWgnGJghFbsWasHjruegyJR
 ml_tools/path_manager.py,sha256=CyDU16pOKmC82jPubqJPT6EBt-u-3rGVbxyPIZCvDDY,18432
 ml_tools/serde.py,sha256=c8uDYjYry_VrLvoG4ixqDj5pij88lVn6Tu4NHcPkwDU,6943
 ml_tools/utilities.py,sha256=OcAyV1tEcYAfOWlGjRgopsjDLxU3DcI5EynzvWV4q3A,15754
-dragon_ml_toolbox-13.4.0.dist-info/METADATA,sha256=Ixk5If3BJhjyJy9_mirNJ2QckMELXFQiJa9_8RWfreI,6166
-dragon_ml_toolbox-13.4.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-dragon_ml_toolbox-13.4.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
-dragon_ml_toolbox-13.4.0.dist-info/RECORD,,
+dragon_ml_toolbox-13.5.0.dist-info/METADATA,sha256=EwOjL8T9Vnk1cg7vsDY4JaK9ovZtIkeIN2LcAiN-nvg,6166
+dragon_ml_toolbox-13.5.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+dragon_ml_toolbox-13.5.0.dist-info/top_level.txt,sha256=wm-oxax3ciyez6VoO4zsFd-gSok2VipYXnbg3TH9PtU,9
+dragon_ml_toolbox-13.5.0.dist-info/RECORD,,
ml_tools/ML_evaluation.py CHANGED
@@ -258,7 +258,7 @@ def shap_summary_plot(model,
                       feature_names: Optional[list[str]],
                       save_dir: Union[str, Path],
                       device: torch.device = torch.device('cpu'),
-                      explainer_type: Literal['deep', 'kernel'] = 'deep'):
+                      explainer_type: Literal['deep', 'kernel'] = 'kernel'):
     """
     Calculates SHAP values and saves summary plots and data.

@@ -270,7 +270,7 @@ def shap_summary_plot(model,
         save_dir (str | Path): Directory to save SHAP artifacts.
         device (torch.device): The torch device for SHAP calculations.
         explainer_type (Literal['deep', 'kernel']): The explainer to use.
-            - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient for
              PyTorch models.
            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
              slow and memory-intensive.
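
This release flips the default explainer from 'deep' to 'kernel' (here and in the multi-target and trainer variants below), so call sites that relied on the old default now take the slow KernelExplainer path. A minimal sketch of a version-proof call that pins the explainer explicitly; the model and tensors are illustrative, and the background_data/instances_to_explain keyword names are taken from identifiers visible elsewhere in this diff, not from the full signature:

import torch
from ml_tools.ML_evaluation import shap_summary_plot

# Toy regression model and data, purely illustrative.
model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
background_data = torch.randn(100, 8)      # background sample for the explainer
instances_to_explain = torch.randn(50, 8)  # rows to explain

shap_summary_plot(
    model,
    background_data=background_data,
    instances_to_explain=instances_to_explain,
    feature_names=[f"feat_{i}" for i in range(8)],
    save_dir="shap_artifacts",
    explainer_type="deep",  # pin explicitly; 13.5.0 changes the default to "kernel"
)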
@@ -285,7 +285,7 @@ def shap_summary_plot(model,
     instances_to_explain_np = None

     if explainer_type == 'deep':
-        # --- 1. Use DeepExplainer (Preferred) ---
+        # --- 1. Use DeepExplainer ---

         # Ensure data is torch.Tensor
         if isinstance(background_data, np.ndarray):
@@ -309,10 +309,9 @@
         instances_to_explain_np = instances_to_explain.cpu().numpy()

     elif explainer_type == 'kernel':
-        # --- 2. Use KernelExplainer (Slow Fallback) ---
+        # --- 2. Use KernelExplainer ---
         _LOGGER.warning(
-            "Using KernelExplainer. This is memory-intensive and slow. "
-            "Consider reducing 'n_samples' if the process terminates unexpectedly."
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
         )

         # Ensure data is np.ndarray
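
The reworded warning points at the number of explained rows rather than an 'n_samples' argument. A small sketch of the mitigation it suggests, subsampling instances before a KernelExplainer run (the array and sizes are illustrative; the trainer docstring below recommends staying under 100):

import numpy as np

rng = np.random.default_rng(0)
instances = rng.standard_normal((5000, 8)).astype(np.float32)

# KernelExplainer cost grows with every explained row, so keep this set small.
keep = rng.choice(len(instances), size=64, replace=False)
instances_to_explain = instances[keep]  # 64 rows instead of 5000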
@@ -348,14 +347,26 @@
     else:
         _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
         raise ValueError()
+
+    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
+        # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
+        shap_values = shap_values.squeeze(-1)

     # --- 3. Plotting and Saving ---
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     plt.ioff()

+    # Convert instances to a DataFrame. robust way to ensure SHAP correctly maps values to feature names.
+    if feature_names is None:
+        # Create generic names if none were provided
+        num_features = instances_to_explain_np.shape[1]
+        feature_names = [f'feature_{i}' for i in range(num_features)]
+
+    instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)
+
     # Save Bar Plot
     bar_path = save_dir_path / "shap_bar_plot.svg"
-    shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
+    shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
     plt.title("SHAP Feature Importance")
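
Two normalization steps are new in this hunk: single-output SHAP arrays of shape (N, F, 1) are squeezed to (N, F), and the explained instances are wrapped in a DataFrame so feature names travel with the values instead of being passed separately. A standalone sketch of both steps on random data:

import numpy as np
import pandas as pd

shap_values = np.random.rand(100, 8, 1)    # (N, F, 1): single-output regression
if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
    shap_values = shap_values.squeeze(-1)  # -> (N, F), the layout summary_plot expects

instances = np.random.rand(100, 8)
feature_names = [f'feature_{i}' for i in range(instances.shape[1])]
instances_df = pd.DataFrame(instances, columns=feature_names)

print(shap_values.shape)         # (100, 8)
print(instances_df.columns[:2])  # Index(['feature_0', 'feature_1'], dtype='object')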
@@ -366,7 +377,7 @@

     # Save Dot Plot
     dot_path = save_dir_path / "shap_dot_plot.svg"
-    shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
+    shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
     if plt.gcf().axes and len(plt.gcf().axes) > 1:
@@ -389,9 +400,6 @@
     mean_abs_shap = np.abs(shap_values).mean(axis=0)

     mean_abs_shap = mean_abs_shap.flatten()
-
-    if feature_names is None:
-        feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]

     summary_df = pd.DataFrame({
         SHAPKeys.FEATURE_COLUMN: feature_names,
@@ -401,7 +409,7 @@
     summary_df.to_csv(summary_path, index=False)

     _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
-    plt.ion()
+    plt.ion()


 def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
ml_tools/ML_evaluation_multi.py CHANGED
@@ -235,7 +235,7 @@ def multi_target_shap_summary_plot(
         target_names: List[str],
         save_dir: Union[str, Path],
         device: torch.device = torch.device('cpu'),
-        explainer_type: Literal['deep', 'kernel'] = 'deep'
+        explainer_type: Literal['deep', 'kernel'] = 'kernel'
         ):
     """
     Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
@@ -249,7 +249,7 @@
         save_dir (str | Path): Directory to save SHAP artifacts.
         device (torch.device): The torch device for SHAP calculations.
         explainer_type (Literal['deep', 'kernel']): The explainer to use.
-            - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient.
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient.
             - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
     """
     _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
@@ -260,7 +260,7 @@
     instances_to_explain_np = None

     if explainer_type == 'deep':
-        # --- 1. Use DeepExplainer (Preferred) ---
+        # --- 1. Use DeepExplainer ---

         # Ensure data is torch.Tensor
         if isinstance(background_data, np.ndarray):
@@ -285,10 +285,9 @@
         instances_to_explain_np = instances_to_explain.cpu().numpy()

     elif explainer_type == 'kernel':
-        # --- 2. Use KernelExplainer (Slow Fallback) ---
+        # --- 2. Use KernelExplainer ---
         _LOGGER.warning(
-            "Using KernelExplainer. This is memory-intensive and slow. "
-            "Consider reducing 'n_samples' if the process terminates."
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
         )

         # Convert all data to numpy
ml_tools/ML_trainer.py CHANGED
@@ -9,7 +9,7 @@ from .ML_callbacks import Callback, History, TqdmProgressBar, ModelCheckpoint
 from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
 from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
 from ._script_info import _script_info
-from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
+from .keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys
 from ._logger import _LOGGER
 from .path_manager import make_fullpath

@@ -408,7 +408,7 @@ class MLTrainer:
                 n_samples: int = 300,
                 feature_names: Optional[List[str]] = None,
                 target_names: Optional[List[str]] = None,
-                explainer_type: Literal['deep', 'kernel'] = 'deep'):
+                explainer_type: Literal['deep', 'kernel'] = 'kernel'):
         """
         Explains model predictions using SHAP and saves all artifacts.

@@ -422,11 +422,11 @@
             explain_dataset (Dataset | None): A specific dataset to explain.
                 If None, the trainer's test dataset is used.
             n_samples (int): The number of samples to use for both background and explanation.
-            feature_names (list[str] | None): Feature names.
+            feature_names (list[str] | None): Feature names. If None, the names will be extracted from the Dataset and raise an error on failure.
             target_names (list[str] | None): Target names for multi-target tasks.
             save_dir (str | Path): Directory to save all SHAP artifacts.
             explainer_type (Literal['deep', 'kernel']): The explainer to use.
-                - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
+                - 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
                 - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples' < 100.
         """
         # Internal helper to create a dataloader and get a random sample
@@ -474,10 +474,10 @@
         # attempt to get feature names
         if feature_names is None:
             # _LOGGER.info("`feature_names` not provided. Attempting to extract from dataset...")
-            if hasattr(target_dataset, "feature_names"):
+            if hasattr(target_dataset, DatasetKeys.FEATURE_NAMES):
                 feature_names = target_dataset.feature_names # type: ignore
             else:
-                _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
+                _LOGGER.error(f"Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
                 raise ValueError()

         # move model to device
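
DatasetKeys, newly imported above, centralizes attribute names that 13.4.0 hard-coded as string literals. The class itself lives in ml_tools/keys.py and is not part of this diff; a hypothetical reconstruction, consistent with the literals it replaces:

# Hypothetical sketch, NOT the actual ml_tools/keys.py source.
class DatasetKeys:
    FEATURE_NAMES = "feature_names"  # replaces the literal "feature_names" from 13.4.0
    TARGET_NAMES = "target_names"    # replaces the literal 'target_names' from 13.4.0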
@@ -498,7 +498,7 @@
         # try to get target names
         if target_names is None:
             target_names = []
-            if hasattr(target_dataset, 'target_names'):
+            if hasattr(target_dataset, DatasetKeys.TARGET_NAMES):
                 target_names = target_dataset.target_names # type: ignore
             else:
                 # Infer number of targets from the model's output layer
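
Both lookups rely on plain attributes of the dataset object. A minimal sketch of a custom Dataset that satisfies them, so the trainer can resolve feature and target names without explicit arguments (the class name and data are illustrative):

import torch
from torch.utils.data import Dataset

class TabularDataset(Dataset):  # hypothetical example class
    def __init__(self, X: torch.Tensor, y: torch.Tensor,
                 feature_names: list[str], target_names: list[str]):
        self.X, self.y = X, y
        self.feature_names = feature_names  # found via hasattr(ds, DatasetKeys.FEATURE_NAMES)
        self.target_names = target_names    # found via hasattr(ds, DatasetKeys.TARGET_NAMES)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

ds = TabularDataset(torch.randn(32, 4), torch.randn(32, 2),
                    [f"x{i}" for i in range(4)], ["y0", "y1"])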
@@ -549,7 +549,7 @@
             yield attention_weights

     def explain_attention(self, save_dir: Union[str, Path],
-                          feature_names: Optional[List[str]],
+                          feature_names: Optional[List[str]] = None,
                           explain_dataset: Optional[Dataset] = None,
                           plot_n_features: int = 10):
         """
@@ -559,7 +559,7 @@

         Args:
             save_dir (str | Path): Directory to save the plot and summary data.
-            feature_names (List[str] | None): Names for the features for plot labeling. If not given, generic names will be used.
+            feature_names (List[str] | None): Names for the features for plot labeling. If None, the names will be extracted from the Dataset and raise an error on failure.
             explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's test dataset is used.
             plot_n_features (int): Number of top features to plot.
         """
@@ -580,6 +580,14 @@
             _LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
             return

+        # Get feature names
+        if feature_names is None:
+            if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
+                feature_names = dataset_to_use.feature_names # type: ignore
+            else:
+                _LOGGER.error(f"Could not extract `feature_names` from the dataset for attention plot. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
+                raise ValueError()
+
         explain_loader = DataLoader(
             dataset=dataset_to_use, batch_size=32, shuffle=False,
             num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
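
With feature_names now defaulting to None and resolved from the dataset, the common call shrinks to one argument. A short sketch, where trainer is a hypothetical, already-constructed MLTrainer whose explanation dataset exposes feature_names:

# 13.5.0: names are pulled from the dataset's feature_names attribute.
trainer.explain_attention(save_dir="attention_artifacts")

# 13.4.0 equivalent: the parameter had no default and had to be passed explicitly.
# trainer.explain_attention(save_dir="attention_artifacts", feature_names=names)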