dragon-ml-toolbox 6.4.0__py3-none-any.whl → 7.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic. See the package's page on the registry for more details.

ml_tools/ML_evaluation.py CHANGED
@@ -20,7 +20,7 @@ import shap
20
20
  from pathlib import Path
21
21
  from .path_manager import make_fullpath
22
22
  from ._logger import _LOGGER
23
- from typing import Union, Optional
23
+ from typing import Union, Optional, List
24
24
  from ._script_info import _script_info
25
25
 
26
26
 
@@ -28,7 +28,8 @@ __all__ = [
28
28
  "plot_losses",
29
29
  "classification_metrics",
30
30
  "regression_metrics",
31
- "shap_summary_plot"
31
+ "shap_summary_plot",
32
+ "plot_attention_importance"
32
33
  ]
33
34
 
34
35
 
@@ -249,7 +250,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
249
250
 
250
251
 
251
252
  def shap_summary_plot(model, background_data: Union[torch.Tensor,np.ndarray], instances_to_explain: Union[torch.Tensor,np.ndarray],
252
- feature_names: Optional[list[str]]=None, save_dir: Optional[Union[str, Path]] = None):
253
+ feature_names: Optional[list[str]], save_dir: Union[str, Path]):
253
254
  """
254
255
  Calculates SHAP values and saves summary plots and data.
255
256
 
@@ -258,7 +259,7 @@ def shap_summary_plot(model, background_data: Union[torch.Tensor,np.ndarray], in
258
259
  background_data (torch.Tensor): A sample of data for the explainer background.
259
260
  instances_to_explain (torch.Tensor): The specific data instances to explain.
260
261
  feature_names (list of str | None): Names of the features for plot labeling.
261
- save_dir (str | Path | None): Directory to save SHAP artifacts. If None, dot plot is shown.
262
+ save_dir (str | Path): Directory to save SHAP artifacts.
262
263
  """
263
264
  # everything to numpy
264
265
  if isinstance(background_data, np.ndarray):
@@ -301,55 +302,119 @@ def shap_summary_plot(model, background_data: Union[torch.Tensor,np.ndarray], in
301
302
  print("Calculating SHAP values with KernelExplainer...")
302
303
  shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
303
304
 
304
- if save_dir:
305
- save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
306
- plt.ioff()
307
-
308
- # Save Bar Plot
309
- bar_path = save_dir_path / "shap_bar_plot.svg"
310
- shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
311
- ax = plt.gca()
312
- ax.set_xlabel("SHAP Value Impact", labelpad=10)
313
- plt.title("SHAP Feature Importance")
314
- plt.tight_layout()
315
- plt.savefig(bar_path)
316
- _LOGGER.info(f"📊 SHAP bar plot saved as '{bar_path.name}'")
317
- plt.close()
305
+ save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
306
+ plt.ioff()
307
+
308
+ # Save Bar Plot
309
+ bar_path = save_dir_path / "shap_bar_plot.svg"
310
+ shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
311
+ ax = plt.gca()
312
+ ax.set_xlabel("SHAP Value Impact", labelpad=10)
313
+ plt.title("SHAP Feature Importance")
314
+ plt.tight_layout()
315
+ plt.savefig(bar_path)
316
+ _LOGGER.info(f"📊 SHAP bar plot saved as '{bar_path.name}'")
317
+ plt.close()
318
318
 
319
- # Save Dot Plot
320
- dot_path = save_dir_path / "shap_dot_plot.svg"
321
- shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
322
- ax = plt.gca()
323
- ax.set_xlabel("SHAP Value Impact", labelpad=10)
324
- cb = plt.gcf().axes[-1]
325
- cb.set_ylabel("", size=1)
326
- plt.title("SHAP Feature Importance")
327
- plt.tight_layout()
328
- plt.savefig(dot_path)
329
- _LOGGER.info(f"📊 SHAP dot plot saved as '{dot_path.name}'")
330
- plt.close()
319
+ # Save Dot Plot
320
+ dot_path = save_dir_path / "shap_dot_plot.svg"
321
+ shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
322
+ ax = plt.gca()
323
+ ax.set_xlabel("SHAP Value Impact", labelpad=10)
324
+ cb = plt.gcf().axes[-1]
325
+ cb.set_ylabel("", size=1)
326
+ plt.title("SHAP Feature Importance")
327
+ plt.tight_layout()
328
+ plt.savefig(dot_path)
329
+ _LOGGER.info(f"📊 SHAP dot plot saved as '{dot_path.name}'")
330
+ plt.close()
331
331
 
332
- # Save Summary Data to CSV
333
- summary_path = save_dir_path / "shap_summary.csv"
334
- # Ensure the array is 1D before creating the DataFrame
335
- mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
336
-
337
- if feature_names is None:
338
- feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
339
-
340
- summary_df = pd.DataFrame({
341
- 'feature': feature_names,
342
- 'mean_abs_shap_value': mean_abs_shap
343
- }).sort_values('mean_abs_shap_value', ascending=False)
344
-
345
- summary_df.to_csv(summary_path, index=False)
346
-
347
- _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
348
- plt.ion()
332
+ # Save Summary Data to CSV
333
+ summary_path = save_dir_path / "shap_summary.csv"
334
+ # Ensure the array is 1D before creating the DataFrame
335
+ mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
336
+
337
+ if feature_names is None:
338
+ feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
349
339
 
350
- else:
351
- _LOGGER.info("No save directory provided. Displaying SHAP dot plot.")
352
- shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot")
340
+ summary_df = pd.DataFrame({
341
+ 'feature': feature_names,
342
+ 'mean_abs_shap_value': mean_abs_shap
343
+ }).sort_values('mean_abs_shap_value', ascending=False)
344
+
345
+ summary_df.to_csv(summary_path, index=False)
346
+
347
+ _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
348
+ plt.ion()
349
+
350
+
351
def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path]):
    """
    Aggregates attention weights and plots global feature importance.

    The plot shows the mean attention for each feature as a bar, with the
    standard deviation represented by error bars. A CSV with the per-feature
    mean and standard deviation (sorted by mean, descending) is saved next to
    the plot.

    Args:
        weights (List[torch.Tensor]): A list of attention weight tensors from each batch.
            The tensors are concatenated along dim 0 and the statistics are taken
            over that dimension, so after concatenation the data must be
            2-D, (n_samples, n_features). Tensors may live on any device and may
            require grad; they are detached and copied to CPU as float32 first.
        feature_names (List[str] | None): Names of the features for plot labeling.
            If None, generic names ``feature_0``, ``feature_1``, ... are used.
        save_dir (str | Path): Directory to save the plot and summary CSV.

    Raises:
        ValueError: If the aggregated weights are not one value per feature, or if
            ``feature_names`` does not have one entry per feature.
    """
    if not weights:
        _LOGGER.warning("⚠️ Attention weights list is empty. Skipping importance plot.")
        return

    # --- Step 1: Aggregate data ---
    # Detach and move every batch to CPU float32 before concatenating. `.numpy()`
    # below fails on accelerator tensors, on tensors that require grad, and on
    # dtypes such as bfloat16. `mean()` also rejects integer dtypes.
    full_weights_tensor = torch.cat(
        [w.detach().to(device="cpu", dtype=torch.float32) for w in weights],
        dim=0,
    )

    # Calculate mean and std dev across the sample dimension (dim=0)
    mean_weights = full_weights_tensor.mean(dim=0)
    if full_weights_tensor.shape[0] > 1:
        std_weights = full_weights_tensor.std(dim=0)
    else:
        # The unbiased std of a single sample is NaN; report zero spread instead.
        std_weights = torch.zeros_like(mean_weights)

    if mean_weights.dim() != 1:
        raise ValueError(
            "Expected attention weights of shape (n_samples, n_features); "
            f"aggregating over dim 0 produced shape {tuple(mean_weights.shape)}."
        )
    n_features = mean_weights.shape[0]

    # --- Step 2: Create and save summary DataFrame ---
    if feature_names is None:
        feature_names = [f'feature_{i}' for i in range(n_features)]
    elif len(feature_names) != n_features:
        raise ValueError(
            f"Got {len(feature_names)} feature names, but the attention weights "
            f"cover {n_features} features."
        )

    summary_df = pd.DataFrame({
        'feature': feature_names,
        'mean_attention': mean_weights.numpy(),
        'std_attention': std_weights.numpy()
    }).sort_values('mean_attention', ascending=False)

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    summary_path = save_dir_path / "attention_summary.csv"
    summary_df.to_csv(summary_path, index=False)
    _LOGGER.info(f"📝 Attention summary data saved as '{summary_path.name}'")

    # --- Step 3: Create and save the plot ---
    # Grow the figure with the number of bars so labels do not overlap;
    # the original 8-inch height is kept as the minimum.
    fig = plt.figure(figsize=(10, max(8, 0.3 * n_features)), dpi=100)
    try:
        # Ascending sort so that barh draws the most important feature on top.
        plot_df = summary_df.sort_values('mean_attention', ascending=True)

        # Create horizontal bar plot with error bars
        plt.barh(
            y=plot_df['feature'],
            width=plot_df['mean_attention'],
            xerr=plot_df['std_attention'],
            align='center',
            alpha=0.7,
            ecolor='grey',
            capsize=3,
            color='cornflowerblue'
        )

        plt.title('Global Feature Importance')
        plt.xlabel('Average Attention Weight')
        plt.ylabel('Feature')
        plt.grid(axis='x', linestyle='--', alpha=0.6)
        plt.tight_layout()

        plot_path = save_dir_path / "attention_importance.svg"
        plt.savefig(plot_path)
        _LOGGER.info(f"📊 Attention importance plot saved as '{plot_path.name}'")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
353
418
 
354
419
 
355
420
  def info():
ml_tools/ML_inference.py CHANGED
@@ -4,6 +4,7 @@ import numpy as np
4
4
  from pathlib import Path
5
5
  from typing import Union, Literal, Dict, Any, Optional
6
6
 
7
+ from .ML_scaler import PytorchScaler
7
8
  from ._script_info import _script_info
8
9
  from ._logger import _LOGGER
9
10
  from .path_manager import make_fullpath
@@ -25,7 +26,8 @@ class PyTorchInferenceHandler:
25
26
  state_dict: Union[str, Path],
26
27
  task: Literal["classification", "regression"],
27
28
  device: str = 'cpu',
28
- target_id: Optional[str]=None):
29
+ target_id: Optional[str]=None,
30
+ scaler: Optional[Union[PytorchScaler, str, Path]] = None):
29
31
  """
30
32
  Initializes the handler by loading a model's state_dict.
31
33
 
@@ -35,12 +37,22 @@ class PyTorchInferenceHandler:
35
37
  task (str): The type of task, 'regression' or 'classification'.
36
38
  device (str): The device to run inference on ('cpu', 'cuda', 'mps').
37
39
  target_id (str | None): Target name as used in the training set.
40
+ scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
38
41
  """
39
42
  self.model = model
40
43
  self.task = task
41
44
  self.device = self._validate_device(device)
42
45
  self.target_id = target_id
43
-
46
+
47
+ # Load the scaler if a path is provided
48
+ if scaler is not None:
49
+ if isinstance(scaler, (str, Path)):
50
+ self.scaler = PytorchScaler.load(scaler)
51
+ else:
52
+ self.scaler = scaler
53
+ else:
54
+ self.scaler = None
55
+
44
56
  model_p = make_fullpath(state_dict, enforce="file")
45
57
 
46
58
  try:
@@ -65,12 +77,22 @@ class PyTorchInferenceHandler:
65
77
  return torch.device(device_lower)
66
78
 
67
79
  def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
68
- """Converts input to a torch.Tensor and moves it to the correct device."""
80
+ """
81
+ Converts input to a torch.Tensor, applies scaling if a scaler is
82
+ present, and moves it to the correct device.
83
+ """
69
84
  if isinstance(features, np.ndarray):
70
- features = torch.from_numpy(features).float()
85
+ features_tensor = torch.from_numpy(features).float()
86
+ else:
87
+ # Ensure it's a float tensor for the model
88
+ features_tensor = features.float()
89
+
90
+ # Apply the scaler transformation if the scaler is available
91
+ if self.scaler:
92
+ features_tensor = self.scaler.transform(features_tensor)
71
93
 
72
94
  # Ensure tensor is on the correct device
73
- return features.to(self.device)
95
+ return features_tensor.to(self.device)
74
96
 
75
97
  def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
76
98
  """
@@ -190,18 +212,27 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
190
212
  f"Invalid task type: The handler for target_id '{handler.target_id}' "
191
213
  f"is for '{handler.task}', but only 'regression' tasks are supported."
192
214
  )
215
+
193
216
  # inference
194
217
  if output == "numpy":
195
- result = handler.predict_batch_numpy(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
196
- else: # torch
197
- result = handler.predict_batch(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
198
-
199
- # Unpack single results and update result dictionary
200
- # If the original input was 1D, extract the single prediction from the array.
201
- if is_single_sample:
202
- results[handler.target_id] = result[0]
203
- else:
204
- results[handler.target_id] = result
218
+ # This path returns NumPy arrays or standard Python scalars
219
+ numpy_result = handler.predict_batch_numpy(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
220
+ if is_single_sample:
221
+ # For a single sample, convert the 1-element array to a Python scalar
222
+ results[handler.target_id] = numpy_result.item()
223
+ else:
224
+ # For a batch, return the full NumPy array of predictions
225
+ results[handler.target_id] = numpy_result
226
+
227
+ else: # output == "torch"
228
+ # This path returns PyTorch tensors on the model's device
229
+ torch_result = handler.predict_batch(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
230
+ if is_single_sample:
231
+ # For a single sample, return the 0-dim tensor
232
+ results[handler.target_id] = torch_result[0]
233
+ else:
234
+ # For a batch, return the full tensor of predictions
235
+ results[handler.target_id] = torch_result
205
236
 
206
237
  return results
207
238
 
@@ -263,18 +294,26 @@ def multi_inference_classification(
263
294
  f"is for '{handler.task}', but this function only supports 'classification'."
264
295
  )
265
296
 
266
- # Always use the batch method to get both labels and probabilities
297
+ # Inference
267
298
  if output == "numpy":
299
+ # predict_batch_numpy returns a dict of NumPy arrays
268
300
  result = handler.predict_batch_numpy(feature_vector)
269
301
  else: # torch
302
+ # predict_batch returns a dict of Torch tensors
270
303
  result = handler.predict_batch(feature_vector)
271
304
 
272
305
  labels = result[PyTorchInferenceKeys.LABELS]
273
306
  probabilities = result[PyTorchInferenceKeys.PROBABILITIES]
274
307
 
275
- # If the original input was 1D, unpack the single result from the batch array
276
308
  if is_single_sample:
277
- labels_results[handler.target_id] = labels[0]
309
+ # For "numpy", convert the single label to a Python int scalar.
310
+ # For "torch", get the 0-dim tensor label.
311
+ if output == "numpy":
312
+ labels_results[handler.target_id] = labels.item()
313
+ else: # torch
314
+ labels_results[handler.target_id] = labels[0]
315
+
316
+ # The probabilities are an array/tensor of values
278
317
  probs_results[handler.target_id] = probabilities[0]
279
318
  else:
280
319
  labels_results[handler.target_id] = labels