dragon-ml-toolbox 6.4.1__py3-none-any.whl → 7.0.0__py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic.

ml_tools/ML_evaluation.py CHANGED
@@ -20,7 +20,7 @@ import shap
 from pathlib import Path
 from .path_manager import make_fullpath
 from ._logger import _LOGGER
-from typing import Union, Optional
+from typing import Union, Optional, List
 from ._script_info import _script_info
 
 
@@ -28,7 +28,8 @@ __all__ = [
     "plot_losses",
     "classification_metrics",
     "regression_metrics",
-    "shap_summary_plot"
+    "shap_summary_plot",
+    "plot_attention_importance"
 ]
 
 
@@ -249,7 +250,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
 
 
 def shap_summary_plot(model, background_data: Union[torch.Tensor,np.ndarray], instances_to_explain: Union[torch.Tensor,np.ndarray],
-                      feature_names: Optional[list[str]]=None, save_dir: Optional[Union[str, Path]] = None):
+                      feature_names: Optional[list[str]], save_dir: Union[str, Path]):
     """
     Calculates SHAP values and saves summary plots and data.
 
@@ -258,7 +259,7 @@ def shap_summary_plot(model, background_data: Union[torch.Tensor,np.ndarray], in
         background_data (torch.Tensor): A sample of data for the explainer background.
         instances_to_explain (torch.Tensor): The specific data instances to explain.
         feature_names (list of str | None): Names of the features for plot labeling.
-        save_dir (str | Path | None): Directory to save SHAP artifacts. If None, dot plot is shown.
+        save_dir (str | Path): Directory to save SHAP artifacts.
     """
     # everything to numpy
     if isinstance(background_data, np.ndarray):
@@ -301,55 +302,119 @@ def shap_summary_plot(model, background_data: Union[torch.Tensor,np.ndarray], in
     print("Calculating SHAP values with KernelExplainer...")
     shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
 
-    if save_dir:
-        save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
-        plt.ioff()
-
-        # Save Bar Plot
-        bar_path = save_dir_path / "shap_bar_plot.svg"
-        shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
-        ax = plt.gca()
-        ax.set_xlabel("SHAP Value Impact", labelpad=10)
-        plt.title("SHAP Feature Importance")
-        plt.tight_layout()
-        plt.savefig(bar_path)
-        _LOGGER.info(f"📊 SHAP bar plot saved as '{bar_path.name}'")
-        plt.close()
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    plt.ioff()
+
+    # Save Bar Plot
+    bar_path = save_dir_path / "shap_bar_plot.svg"
+    shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
+    ax = plt.gca()
+    ax.set_xlabel("SHAP Value Impact", labelpad=10)
+    plt.title("SHAP Feature Importance")
+    plt.tight_layout()
+    plt.savefig(bar_path)
+    _LOGGER.info(f"📊 SHAP bar plot saved as '{bar_path.name}'")
+    plt.close()
 
-        # Save Dot Plot
-        dot_path = save_dir_path / "shap_dot_plot.svg"
-        shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
-        ax = plt.gca()
-        ax.set_xlabel("SHAP Value Impact", labelpad=10)
-        cb = plt.gcf().axes[-1]
-        cb.set_ylabel("", size=1)
-        plt.title("SHAP Feature Importance")
-        plt.tight_layout()
-        plt.savefig(dot_path)
-        _LOGGER.info(f"📊 SHAP dot plot saved as '{dot_path.name}'")
-        plt.close()
+    # Save Dot Plot
+    dot_path = save_dir_path / "shap_dot_plot.svg"
+    shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
+    ax = plt.gca()
+    ax.set_xlabel("SHAP Value Impact", labelpad=10)
+    cb = plt.gcf().axes[-1]
+    cb.set_ylabel("", size=1)
+    plt.title("SHAP Feature Importance")
+    plt.tight_layout()
+    plt.savefig(dot_path)
+    _LOGGER.info(f"📊 SHAP dot plot saved as '{dot_path.name}'")
+    plt.close()
 
-        # Save Summary Data to CSV
-        summary_path = save_dir_path / "shap_summary.csv"
-        # Ensure the array is 1D before creating the DataFrame
-        mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
-
-        if feature_names is None:
-            feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
-
-        summary_df = pd.DataFrame({
-            'feature': feature_names,
-            'mean_abs_shap_value': mean_abs_shap
-        }).sort_values('mean_abs_shap_value', ascending=False)
-
-        summary_df.to_csv(summary_path, index=False)
-
-        _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
-        plt.ion()
+    # Save Summary Data to CSV
+    summary_path = save_dir_path / "shap_summary.csv"
+    # Ensure the array is 1D before creating the DataFrame
+    mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
+
+    if feature_names is None:
+        feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
 
-    else:
-        _LOGGER.info("No save directory provided. Displaying SHAP dot plot.")
-        shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot")
+    summary_df = pd.DataFrame({
+        'feature': feature_names,
+        'mean_abs_shap_value': mean_abs_shap
+    }).sort_values('mean_abs_shap_value', ascending=False)
+
+    summary_df.to_csv(summary_path, index=False)
+
+    _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
+    plt.ion()
+
+
+def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path]):
+    """
+    Aggregates attention weights and plots global feature importance.
+
+    The plot shows the mean attention for each feature as a bar, with the
+    standard deviation represented by error bars.
+
+    Args:
+        weights (List[torch.Tensor]): A list of attention weight tensors from each batch.
+        feature_names (List[str] | None): Names of the features for plot labeling.
+        save_dir (str | Path): Directory to save the plot and summary CSV.
+    """
+    if not weights:
+        _LOGGER.warning("⚠️ Attention weights list is empty. Skipping importance plot.")
+        return
+
+    # --- Step 1: Aggregate data ---
+    # Concatenate the list of tensors into a single large tensor
+    full_weights_tensor = torch.cat(weights, dim=0)
+
+    # Calculate mean and std dev across the batch dimension (dim=0)
+    mean_weights = full_weights_tensor.mean(dim=0)
+    std_weights = full_weights_tensor.std(dim=0)
+
+    # --- Step 2: Create and save summary DataFrame ---
+    if feature_names is None:
+        feature_names = [f'feature_{i}' for i in range(len(mean_weights))]
+
+    summary_df = pd.DataFrame({
+        'feature': feature_names,
+        'mean_attention': mean_weights.numpy(),
+        'std_attention': std_weights.numpy()
+    }).sort_values('mean_attention', ascending=False)
+
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    summary_path = save_dir_path / "attention_summary.csv"
+    summary_df.to_csv(summary_path, index=False)
+    _LOGGER.info(f"📝 Attention summary data saved as '{summary_path.name}'")
+
+    # --- Step 3: Create and save the plot ---
+    plt.figure(figsize=(10, 8), dpi=100)
+
+    # Sort for plotting
+    plot_df = summary_df.sort_values('mean_attention', ascending=True)
+
+    # Create horizontal bar plot with error bars
+    plt.barh(
+        y=plot_df['feature'],
+        width=plot_df['mean_attention'],
+        xerr=plot_df['std_attention'],
+        align='center',
+        alpha=0.7,
+        ecolor='grey',
+        capsize=3,
+        color='cornflowerblue'
+    )
+
+    plt.title('Global Feature Importance')
+    plt.xlabel('Average Attention Weight')
+    plt.ylabel('Feature')
+    plt.grid(axis='x', linestyle='--', alpha=0.6)
+    plt.tight_layout()
+
+    plot_path = save_dir_path / "attention_importance.svg"
+    plt.savefig(plot_path)
+    _LOGGER.info(f"📊 Attention importance plot saved as '{plot_path.name}'")
+    plt.close()
 
 
 def info():
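
For orientation, a minimal usage sketch of the changed evaluation API (the model, tensors, and output directory below are hypothetical placeholders, not part of the package): shap_summary_plot no longer falls back to an interactive plot, so save_dir must now be supplied, and the new plot_attention_importance consumes the per-batch attention tensors collected during evaluation.

    import torch
    import torch.nn as nn
    from ml_tools.ML_evaluation import shap_summary_plot, plot_attention_importance

    # Stand-in regressor and data; any trained torch.nn.Module would do
    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))
    background = torch.randn(100, 8)   # sample for the explainer background
    instances = torch.randn(20, 8)     # rows to explain
    names = [f"x{i}" for i in range(8)]

    shap_summary_plot(
        model,
        background_data=background,
        instances_to_explain=instances,
        feature_names=names,
        save_dir="eval_artifacts",     # required in 7.0.0; plots and CSV are saved here
    )

    # One (batch, features) attention tensor per evaluation batch (placeholder values)
    attention_batches = [torch.softmax(torch.randn(32, 8), dim=1) for _ in range(5)]
    plot_attention_importance(attention_batches, feature_names=names, save_dir="eval_artifacts")
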
ml_tools/ML_inference.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
 from pathlib import Path
 from typing import Union, Literal, Dict, Any, Optional
 
+from .ML_scaler import PytorchScaler
 from ._script_info import _script_info
 from ._logger import _LOGGER
 from .path_manager import make_fullpath
@@ -25,7 +26,8 @@ class PyTorchInferenceHandler:
                  state_dict: Union[str, Path],
                  task: Literal["classification", "regression"],
                  device: str = 'cpu',
-                 target_id: Optional[str]=None):
+                 target_id: Optional[str]=None,
+                 scaler: Optional[Union[PytorchScaler, str, Path]] = None):
         """
         Initializes the handler by loading a model's state_dict.
 
@@ -35,12 +37,22 @@ class PyTorchInferenceHandler:
             task (str): The type of task, 'regression' or 'classification'.
             device (str): The device to run inference on ('cpu', 'cuda', 'mps').
             target_id (str | None): Target name as used in the training set.
+            scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
         """
         self.model = model
         self.task = task
         self.device = self._validate_device(device)
         self.target_id = target_id
-
+
+        # Load the scaler if a path is provided
+        if scaler is not None:
+            if isinstance(scaler, (str, Path)):
+                self.scaler = PytorchScaler.load(scaler)
+            else:
+                self.scaler = scaler
+        else:
+            self.scaler = None
+
         model_p = make_fullpath(state_dict, enforce="file")
 
         try:
@@ -65,12 +77,22 @@ class PyTorchInferenceHandler:
         return torch.device(device_lower)
 
     def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
-        """Converts input to a torch.Tensor and moves it to the correct device."""
+        """
+        Converts input to a torch.Tensor, applies scaling if a scaler is
+        present, and moves it to the correct device.
+        """
         if isinstance(features, np.ndarray):
-            features = torch.from_numpy(features).float()
+            features_tensor = torch.from_numpy(features).float()
+        else:
+            # Ensure it's a float tensor for the model
+            features_tensor = features.float()
+
+        # Apply the scaler transformation if the scaler is available
+        if self.scaler:
+            features_tensor = self.scaler.transform(features_tensor)
 
         # Ensure tensor is on the correct device
-        return features.to(self.device)
+        return features_tensor.to(self.device)
 
     def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """