dragon-ml-toolbox 13.3.2__tar.gz → 13.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic. Click here for more details.

Files changed (46) hide show
  1. {dragon_ml_toolbox-13.3.2/dragon_ml_toolbox.egg-info → dragon_ml_toolbox-13.5.0}/PKG-INFO +1 -1
  2. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0/dragon_ml_toolbox.egg-info}/PKG-INFO +1 -1
  3. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/ML_datasetmaster.py +61 -20
  4. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/ML_evaluation.py +20 -12
  5. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/ML_evaluation_multi.py +5 -6
  6. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/ML_trainer.py +17 -9
  7. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/pyproject.toml +1 -1
  8. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/LICENSE +0 -0
  9. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/LICENSE-THIRD-PARTY.md +0 -0
  10. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/README.md +0 -0
  11. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/dragon_ml_toolbox.egg-info/SOURCES.txt +0 -0
  12. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/dragon_ml_toolbox.egg-info/dependency_links.txt +0 -0
  13. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/dragon_ml_toolbox.egg-info/requires.txt +0 -0
  14. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/dragon_ml_toolbox.egg-info/top_level.txt +0 -0
  15. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/ETL_cleaning.py +0 -0
  16. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/ETL_engineering.py +0 -0
  17. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/GUI_tools.py +0 -0
  18. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/MICE_imputation.py +0 -0
  19. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/ML_callbacks.py +0 -0
  20. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/ML_inference.py +0 -0
  21. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/ML_models.py +0 -0
  22. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/ML_optimization.py +0 -0
  23. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/ML_scaler.py +0 -0
  24. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/ML_utilities.py +0 -0
  25. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/PSO_optimization.py +0 -0
  26. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/RNN_forecast.py +0 -0
  27. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/SQL.py +0 -0
  28. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/VIF_factor.py +0 -0
  29. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/__init__.py +0 -0
  30. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/_logger.py +0 -0
  31. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/_schema.py +0 -0
  32. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/_script_info.py +0 -0
  33. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/constants.py +0 -0
  34. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/custom_logger.py +0 -0
  35. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/data_exploration.py +0 -0
  36. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/ensemble_evaluation.py +0 -0
  37. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/ensemble_inference.py +0 -0
  38. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/ensemble_learning.py +0 -0
  39. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/handle_excel.py +0 -0
  40. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/keys.py +0 -0
  41. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/math_utilities.py +0 -0
  42. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/optimization_tools.py +0 -0
  43. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/path_manager.py +0 -0
  44. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/serde.py +0 -0
  45. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/ml_tools/utilities.py +0 -0
  46. {dragon_ml_toolbox-13.3.2 → dragon_ml_toolbox-13.5.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dragon-ml-toolbox
3
- Version: 13.3.2
3
+ Version: 13.5.0
4
4
  Summary: A collection of tools for data science and machine learning projects.
5
5
  Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
6
6
  License-Expression: MIT
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dragon-ml-toolbox
3
- Version: 13.3.2
3
+ Version: 13.5.0
4
4
  Summary: A collection of tools for data science and machine learning projects.
5
5
  Author-email: "Karl L. Loza Vidaurre" <luigiloza@gmail.com>
6
6
  License-Expression: MIT
@@ -126,8 +126,8 @@ class _BaseDatasetMaker(ABC):
126
126
  else:
127
127
  _LOGGER.info("No continuous features listed in schema. Scaler will not be fitted.")
128
128
 
129
- X_train_values = X_train.values
130
- X_test_values = X_test.values
129
+ X_train_values = X_train.to_numpy()
130
+ X_test_values = X_test.to_numpy()
131
131
 
132
132
  # continuous_feature_indices is derived
133
133
  if self.scaler is None and continuous_feature_indices:
@@ -253,26 +253,42 @@ class DatasetMaker(_BaseDatasetMaker):
253
253
  pandas_df: pandas.DataFrame,
254
254
  schema: FeatureSchema,
255
255
  kind: Literal["regression", "classification"],
256
+ scaler: Union[Literal["fit"], Literal["none"], PytorchScaler],
256
257
  test_size: float = 0.2,
257
- random_state: int = 42,
258
- scaler: Optional[PytorchScaler] = None):
258
+ random_state: int = 42):
259
259
  """
260
260
  Args:
261
261
  pandas_df (pandas.DataFrame):
262
262
  The pre-processed input DataFrame containing all columns. (features and single target).
263
263
  schema (FeatureSchema):
264
264
  The definitive schema object from data_exploration.
265
- kind (Literal["regression", "classification"]):
265
+ kind ("regression" | "classification"):
266
266
  The type of ML task. This determines the data type of the labels.
267
+ scaler ("fit" | "none" | PytorchScaler):
268
+ Strategy for data scaling:
269
+ - "fit": Fit a new PytorchScaler on continuous features.
270
+ - "none": Do not scale data (e.g., for TabularTransformer).
271
+ - PytorchScaler instance: Use a pre-fitted scaler to transform data.
267
272
  test_size (float):
268
273
  The proportion of the dataset to allocate to the test split.
269
274
  random_state (int):
270
275
  The seed for the random number generator for reproducibility.
271
- scaler (PytorchScaler | None):
272
- A pre-fitted PytorchScaler instance, if None a new scaler will be created.
276
+
273
277
  """
274
278
  super().__init__()
275
- self.scaler = scaler
279
+
280
+ _apply_scaling: bool = False
281
+ if scaler == "fit":
282
+ self.scaler = None # To be created
283
+ _apply_scaling = True
284
+ elif scaler == "none":
285
+ self.scaler = None
286
+ elif isinstance(scaler, PytorchScaler):
287
+ self.scaler = scaler # Use the provided one
288
+ _apply_scaling = True
289
+ else:
290
+ _LOGGER.error(f"Invalid 'scaler' argument. Must be 'fit', 'none', or a PytorchScaler instance.")
291
+ raise ValueError()
276
292
 
277
293
  # --- 1. Identify features (from schema) ---
278
294
  self._feature_names = list(schema.feature_names)
@@ -310,9 +326,14 @@ class DatasetMaker(_BaseDatasetMaker):
310
326
  label_dtype = torch.float32 if kind == "regression" else torch.int64
311
327
 
312
328
  # --- 4. Scale (using the schema) ---
313
- X_train_final, X_test_final = self._prepare_scaler(
314
- X_train, y_train, X_test, label_dtype, schema
315
- )
329
+ if _apply_scaling:
330
+ X_train_final, X_test_final = self._prepare_scaler(
331
+ X_train, y_train, X_test, label_dtype, schema
332
+ )
333
+ else:
334
+ _LOGGER.info("Features have not been scaled as specified.")
335
+ X_train_final = X_train.to_numpy()
336
+ X_test_final = X_test.to_numpy()
316
337
 
317
338
  # --- 5. Create Datasets ---
318
339
  self._train_ds = _PytorchDataset(X_train_final, y_train, labels_dtype=label_dtype, feature_names=self._feature_names, target_names=self._target_names)
@@ -336,9 +357,9 @@ class DatasetMakerMulti(_BaseDatasetMaker):
336
357
  pandas_df: pandas.DataFrame,
337
358
  target_columns: List[str],
338
359
  schema: FeatureSchema,
360
+ scaler: Union[Literal["fit"], Literal["none"], PytorchScaler],
339
361
  test_size: float = 0.2,
340
- random_state: int = 42,
341
- scaler: Optional[PytorchScaler] = None):
362
+ random_state: int = 42):
342
363
  """
343
364
  Args:
344
365
  pandas_df (pandas.DataFrame):
@@ -348,20 +369,35 @@ class DatasetMakerMulti(_BaseDatasetMaker):
348
369
  List of target column names.
349
370
  schema (FeatureSchema):
350
371
  The definitive schema object from data_exploration.
372
+ scaler ("fit" | "none" | PytorchScaler):
373
+ Strategy for data scaling:
374
+ - "fit": Fit a new PytorchScaler on continuous features.
375
+ - "none": Do not scale data (e.g., for TabularTransformer).
376
+ - PytorchScaler instance: Use a pre-fitted scaler to transform data.
351
377
  test_size (float):
352
378
  The proportion of the dataset to allocate to the test split.
353
379
  random_state (int):
354
380
  The seed for the random number generator for reproducibility.
355
- scaler (PytorchScaler | None):
356
- A pre-fitted PytorchScaler instance.
357
381
 
358
382
  ## Note:
359
383
  For multi-binary classification, the most common PyTorch loss function is nn.BCEWithLogitsLoss.
360
384
  This loss function requires the labels to be torch.float32 which is the same type required for regression (multi-regression) tasks.
361
385
  """
362
386
  super().__init__()
363
- self.scaler = scaler
364
-
387
+
388
+ _apply_scaling: bool = False
389
+ if scaler == "fit":
390
+ self.scaler = None
391
+ _apply_scaling = True
392
+ elif scaler == "none":
393
+ self.scaler = None
394
+ elif isinstance(scaler, PytorchScaler):
395
+ self.scaler = scaler # Use the provided one
396
+ _apply_scaling = True
397
+ else:
398
+ _LOGGER.error(f"Invalid 'scaler' argument. Must be 'fit', 'none', or a PytorchScaler instance.")
399
+ raise ValueError()
400
+
365
401
  # --- 1. Get features and targets from schema/args ---
366
402
  self._feature_names = list(schema.feature_names)
367
403
  self._target_names = target_columns
@@ -403,9 +439,14 @@ class DatasetMakerMulti(_BaseDatasetMaker):
403
439
  label_dtype = torch.float32
404
440
 
405
441
  # --- 4. Scale (using the schema) ---
406
- X_train_final, X_test_final = self._prepare_scaler(
407
- X_train, y_train, X_test, label_dtype, schema
408
- )
442
+ if _apply_scaling:
443
+ X_train_final, X_test_final = self._prepare_scaler(
444
+ X_train, y_train, X_test, label_dtype, schema
445
+ )
446
+ else:
447
+ _LOGGER.info("Features have not been scaled as specified.")
448
+ X_train_final = X_train.to_numpy()
449
+ X_test_final = X_test.to_numpy()
409
450
 
410
451
  # --- 5. Create Datasets ---
411
452
  # _PytorchDataset now correctly handles y_train (a DataFrame)
@@ -258,7 +258,7 @@ def shap_summary_plot(model,
258
258
  feature_names: Optional[list[str]],
259
259
  save_dir: Union[str, Path],
260
260
  device: torch.device = torch.device('cpu'),
261
- explainer_type: Literal['deep', 'kernel'] = 'deep'):
261
+ explainer_type: Literal['deep', 'kernel'] = 'kernel'):
262
262
  """
263
263
  Calculates SHAP values and saves summary plots and data.
264
264
 
@@ -270,7 +270,7 @@ def shap_summary_plot(model,
270
270
  save_dir (str | Path): Directory to save SHAP artifacts.
271
271
  device (torch.device): The torch device for SHAP calculations.
272
272
  explainer_type (Literal['deep', 'kernel']): The explainer to use.
273
- - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for
273
+ - 'deep': Uses shap.DeepExplainer. Fast and efficient for
274
274
  PyTorch models.
275
275
  - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
276
276
  slow and memory-intensive.
@@ -285,7 +285,7 @@ def shap_summary_plot(model,
285
285
  instances_to_explain_np = None
286
286
 
287
287
  if explainer_type == 'deep':
288
- # --- 1. Use DeepExplainer (Preferred) ---
288
+ # --- 1. Use DeepExplainer ---
289
289
 
290
290
  # Ensure data is torch.Tensor
291
291
  if isinstance(background_data, np.ndarray):
@@ -309,10 +309,9 @@ def shap_summary_plot(model,
309
309
  instances_to_explain_np = instances_to_explain.cpu().numpy()
310
310
 
311
311
  elif explainer_type == 'kernel':
312
- # --- 2. Use KernelExplainer (Slow Fallback) ---
312
+ # --- 2. Use KernelExplainer ---
313
313
  _LOGGER.warning(
314
- "Using KernelExplainer. This is memory-intensive and slow. "
315
- "Consider reducing 'n_samples' if the process terminates unexpectedly."
314
+ "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
316
315
  )
317
316
 
318
317
  # Ensure data is np.ndarray
@@ -348,14 +347,26 @@ def shap_summary_plot(model,
348
347
  else:
349
348
  _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
350
349
  raise ValueError()
350
+
351
+ if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
352
+ # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
353
+ shap_values = shap_values.squeeze(-1)
351
354
 
352
355
  # --- 3. Plotting and Saving ---
353
356
  save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
354
357
  plt.ioff()
355
358
 
359
+ # Convert instances to a DataFrame: a robust way to ensure SHAP correctly maps values to feature names.
360
+ if feature_names is None:
361
+ # Create generic names if none were provided
362
+ num_features = instances_to_explain_np.shape[1]
363
+ feature_names = [f'feature_{i}' for i in range(num_features)]
364
+
365
+ instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)
366
+
356
367
  # Save Bar Plot
357
368
  bar_path = save_dir_path / "shap_bar_plot.svg"
358
- shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
369
+ shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
359
370
  ax = plt.gca()
360
371
  ax.set_xlabel("SHAP Value Impact", labelpad=10)
361
372
  plt.title("SHAP Feature Importance")
@@ -366,7 +377,7 @@ def shap_summary_plot(model,
366
377
 
367
378
  # Save Dot Plot
368
379
  dot_path = save_dir_path / "shap_dot_plot.svg"
369
- shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
380
+ shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
370
381
  ax = plt.gca()
371
382
  ax.set_xlabel("SHAP Value Impact", labelpad=10)
372
383
  if plt.gcf().axes and len(plt.gcf().axes) > 1:
@@ -389,9 +400,6 @@ def shap_summary_plot(model,
389
400
  mean_abs_shap = np.abs(shap_values).mean(axis=0)
390
401
 
391
402
  mean_abs_shap = mean_abs_shap.flatten()
392
-
393
- if feature_names is None:
394
- feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
395
403
 
396
404
  summary_df = pd.DataFrame({
397
405
  SHAPKeys.FEATURE_COLUMN: feature_names,
@@ -401,7 +409,7 @@ def shap_summary_plot(model,
401
409
  summary_df.to_csv(summary_path, index=False)
402
410
 
403
411
  _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
404
- plt.ion()
412
+ plt.ion()
405
413
 
406
414
 
407
415
  def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
@@ -235,7 +235,7 @@ def multi_target_shap_summary_plot(
235
235
  target_names: List[str],
236
236
  save_dir: Union[str, Path],
237
237
  device: torch.device = torch.device('cpu'),
238
- explainer_type: Literal['deep', 'kernel'] = 'deep'
238
+ explainer_type: Literal['deep', 'kernel'] = 'kernel'
239
239
  ):
240
240
  """
241
241
  Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
@@ -249,7 +249,7 @@ def multi_target_shap_summary_plot(
249
249
  save_dir (str | Path): Directory to save SHAP artifacts.
250
250
  device (torch.device): The torch device for SHAP calculations.
251
251
  explainer_type (Literal['deep', 'kernel']): The explainer to use.
252
- - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient.
252
+ - 'deep': Uses shap.DeepExplainer. Fast and efficient.
253
253
  - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
254
254
  """
255
255
  _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
@@ -260,7 +260,7 @@ def multi_target_shap_summary_plot(
260
260
  instances_to_explain_np = None
261
261
 
262
262
  if explainer_type == 'deep':
263
- # --- 1. Use DeepExplainer (Preferred) ---
263
+ # --- 1. Use DeepExplainer ---
264
264
 
265
265
  # Ensure data is torch.Tensor
266
266
  if isinstance(background_data, np.ndarray):
@@ -285,10 +285,9 @@ def multi_target_shap_summary_plot(
285
285
  instances_to_explain_np = instances_to_explain.cpu().numpy()
286
286
 
287
287
  elif explainer_type == 'kernel':
288
- # --- 2. Use KernelExplainer (Slow Fallback) ---
288
+ # --- 2. Use KernelExplainer ---
289
289
  _LOGGER.warning(
290
- "Using KernelExplainer. This is memory-intensive and slow. "
291
- "Consider reducing 'n_samples' if the process terminates."
290
+ "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
292
291
  )
293
292
 
294
293
  # Convert all data to numpy
@@ -9,7 +9,7 @@ from .ML_callbacks import Callback, History, TqdmProgressBar, ModelCheckpoint
9
9
  from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
10
10
  from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
11
11
  from ._script_info import _script_info
12
- from .keys import PyTorchLogKeys, PyTorchCheckpointKeys
12
+ from .keys import PyTorchLogKeys, PyTorchCheckpointKeys, DatasetKeys
13
13
  from ._logger import _LOGGER
14
14
  from .path_manager import make_fullpath
15
15
 
@@ -408,7 +408,7 @@ class MLTrainer:
408
408
  n_samples: int = 300,
409
409
  feature_names: Optional[List[str]] = None,
410
410
  target_names: Optional[List[str]] = None,
411
- explainer_type: Literal['deep', 'kernel'] = 'deep'):
411
+ explainer_type: Literal['deep', 'kernel'] = 'kernel'):
412
412
  """
413
413
  Explains model predictions using SHAP and saves all artifacts.
414
414
 
@@ -422,11 +422,11 @@ class MLTrainer:
422
422
  explain_dataset (Dataset | None): A specific dataset to explain.
423
423
  If None, the trainer's test dataset is used.
424
424
  n_samples (int): The number of samples to use for both background and explanation.
425
- feature_names (list[str] | None): Feature names.
425
+ feature_names (list[str] | None): Feature names. If None, the names are extracted from the Dataset; an error is raised if extraction fails.
426
426
  target_names (list[str] | None): Target names for multi-target tasks.
427
427
  save_dir (str | Path): Directory to save all SHAP artifacts.
428
428
  explainer_type (Literal['deep', 'kernel']): The explainer to use.
429
- - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
429
+ - 'deep': Uses shap.DeepExplainer. Fast and efficient for PyTorch models.
430
430
  - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY slow and memory-intensive. Use with a very low 'n_samples' (< 100).
431
431
  """
432
432
  # Internal helper to create a dataloader and get a random sample
@@ -474,10 +474,10 @@ class MLTrainer:
474
474
  # attempt to get feature names
475
475
  if feature_names is None:
476
476
  # _LOGGER.info("`feature_names` not provided. Attempting to extract from dataset...")
477
- if hasattr(target_dataset, "feature_names"):
477
+ if hasattr(target_dataset, DatasetKeys.FEATURE_NAMES):
478
478
  feature_names = target_dataset.feature_names # type: ignore
479
479
  else:
480
- _LOGGER.error("Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a `feature_names` attribute.")
480
+ _LOGGER.error(f"Could not extract `feature_names` from the dataset. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
481
481
  raise ValueError()
482
482
 
483
483
  # move model to device
@@ -498,7 +498,7 @@ class MLTrainer:
498
498
  # try to get target names
499
499
  if target_names is None:
500
500
  target_names = []
501
- if hasattr(target_dataset, 'target_names'):
501
+ if hasattr(target_dataset, DatasetKeys.TARGET_NAMES):
502
502
  target_names = target_dataset.target_names # type: ignore
503
503
  else:
504
504
  # Infer number of targets from the model's output layer
@@ -549,7 +549,7 @@ class MLTrainer:
549
549
  yield attention_weights
550
550
 
551
551
  def explain_attention(self, save_dir: Union[str, Path],
552
- feature_names: Optional[List[str]],
552
+ feature_names: Optional[List[str]] = None,
553
553
  explain_dataset: Optional[Dataset] = None,
554
554
  plot_n_features: int = 10):
555
555
  """
@@ -559,7 +559,7 @@ class MLTrainer:
559
559
 
560
560
  Args:
561
561
  save_dir (str | Path): Directory to save the plot and summary data.
562
- feature_names (List[str] | None): Names for the features for plot labeling. If not given, generic names will be used.
562
+ feature_names (List[str] | None): Names for the features for plot labeling. If None, the names are extracted from the Dataset; an error is raised if extraction fails.
563
563
  explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's test dataset is used.
564
564
  plot_n_features (int): Number of top features to plot.
565
565
  """
@@ -580,6 +580,14 @@ class MLTrainer:
580
580
  _LOGGER.error("The explanation dataset is empty or invalid. Skipping attention analysis.")
581
581
  return
582
582
 
583
+ # Get feature names
584
+ if feature_names is None:
585
+ if hasattr(dataset_to_use, DatasetKeys.FEATURE_NAMES):
586
+ feature_names = dataset_to_use.feature_names # type: ignore
587
+ else:
588
+ _LOGGER.error(f"Could not extract `feature_names` from the dataset for attention plot. It must be provided if the dataset object does not have a '{DatasetKeys.FEATURE_NAMES}' attribute.")
589
+ raise ValueError()
590
+
583
591
  explain_loader = DataLoader(
584
592
  dataset=dataset_to_use, batch_size=32, shuffle=False,
585
593
  num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "dragon-ml-toolbox"
3
- version = "13.3.2"
3
+ version = "13.5.0"
4
4
  description = "A collection of tools for data science and machine learning projects."
5
5
  authors = [
6
6
  { name = "Karl L. Loza Vidaurre", email = "luigiloza@gmail.com" }