dragon-ml-toolbox 6.4.1__py3-none-any.whl → 7.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic. Click here for more details.

ml_tools/ML_scaler.py ADDED
@@ -0,0 +1,197 @@
1
+ import torch
2
+ from torch.utils.data import Dataset, DataLoader
3
+ from pathlib import Path
4
+ from typing import Union, List, Optional
5
+ from ._logger import _LOGGER
6
+ from ._script_info import _script_info
7
+ from .path_manager import make_fullpath
8
+
9
+ __all__ = [
10
+ "PytorchScaler"
11
+ ]
12
+
13
class PytorchScaler:
    """
    Standardizes selected columns of a 2-D tensor by subtracting a per-column
    mean and dividing by a per-column standard deviation.

    The statistics are computed once with `fit()` and can then be persisted
    with `save()` and restored with `load()` so the same transformation is
    applied at inference time.

    Attributes:
        mean_: Per-feature mean, or None when the scaler is not fitted.
        std_: Per-feature standard deviation, or None when not fitted.
        continuous_feature_indices: Column indices the scaler operates on.
    """
    def __init__(self,
                 mean: Optional[torch.Tensor] = None,
                 std: Optional[torch.Tensor] = None,
                 continuous_feature_indices: Optional[List[int]] = None):
        """
        Initializes the scaler.

        Args:
            mean (torch.Tensor, optional): The mean of the features to scale.
            std (torch.Tensor, optional): The standard deviation of the features.
            continuous_feature_indices (List[int], optional): The column indices of the features to standardize.
        """
        self.mean_ = mean
        self.std_ = std
        self.continuous_feature_indices = continuous_feature_indices

    @classmethod
    def fit(cls, dataset: Dataset, continuous_feature_indices: List[int], batch_size: int = 64) -> 'PytorchScaler':
        """
        Fits the scaler by computing the mean and (population) std dev of the
        selected columns in a single pass over the dataset.

        Each dataset item must unpack into a `(features, target)` pair; only
        `features` is used.

        Args:
            dataset (Dataset): The PyTorch Dataset to fit on.
            continuous_feature_indices (List[int]): The column indices of the
                features to standardize.
            batch_size (int): The batch size for iterating through the dataset.

        Returns:
            PytorchScaler: A new scaler. It is unfitted (no mean/std) when no
            indices are given or when the dataset yields no samples.
        """
        if not continuous_feature_indices:
            _LOGGER.warning("⚠️ No continuous feature indices provided. Scaler will not be fitted.")
            return cls()

        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        num_continuous_features = len(continuous_feature_indices)

        # Accumulate in float64 on the CPU: the E[X^2] - E[X]^2 formula used
        # below loses most of its significant digits in float32 whenever the
        # mean is large compared to the spread of the data.
        running_sum = torch.zeros(num_continuous_features, dtype=torch.float64)
        running_sum_sq = torch.zeros(num_continuous_features, dtype=torch.float64)
        count = 0
        device = None

        for features, _ in loader:
            if device is None:
                # Remember where the data lives so the result ends up there too.
                device = features.device

            batch = features[:, continuous_feature_indices].detach().cpu().to(torch.float64)

            running_sum += batch.sum(dim=0)
            running_sum_sq += (batch * batch).sum(dim=0)
            count += batch.size(0)

        if count == 0:
            _LOGGER.warning("⚠️ Dataset is empty. Scaler cannot be fitted.")
            return cls(continuous_feature_indices=continuous_feature_indices)

        mean64 = running_sum / count

        if count < 2:
            _LOGGER.warning("⚠️ Only one sample found. Standard deviation cannot be calculated and is set to 1.")
            std64 = torch.ones_like(mean64)
        else:
            # var = E[X^2] - (E[X])^2
            var = (running_sum_sq / count) - mean64 ** 2
            # Clamp so rounding can never produce a negative variance.
            std64 = torch.sqrt(torch.clamp(var, min=1e-8))

        # Cast before moving: some accelerator backends cannot hold float64.
        out_dtype = torch.get_default_dtype()
        mean = mean64.to(dtype=out_dtype).to(device)
        std = std64.to(dtype=out_dtype).to(device)

        _LOGGER.info(f"Scaler fitted on {count} samples for {num_continuous_features} continuous features.")
        return cls(mean=mean, std=std, continuous_feature_indices=continuous_feature_indices)

    def transform(self, data: torch.Tensor) -> torch.Tensor:
        """
        Applies standardization to the specified continuous features.

        Args:
            data (torch.Tensor): The input data tensor; it is indexed as
                `data[:, indices]`, so columns are the last of two dimensions.

        Returns:
            torch.Tensor: A transformed copy of `data`. When the scaler is not
            fitted, `data` itself is returned unchanged.
        """
        if self.mean_ is None or self.std_ is None or self.continuous_feature_indices is None:
            _LOGGER.warning("⚠️ Scaler has not been fitted. Returning original data.")
            return data

        # Work on a copy so the caller's tensor is never modified.
        data_clone = data.clone()

        # Ensure mean and std are on the same device as the data
        mean = self.mean_.to(data.device)
        std = self.std_.to(data.device)

        features_to_scale = data_clone[:, self.continuous_feature_indices]

        # Epsilon guards against division by zero for a zero std.
        scaled_features = (features_to_scale - mean) / (std + 1e-8)

        data_clone[:, self.continuous_feature_indices] = scaled_features

        return data_clone

    def inverse_transform(self, data: torch.Tensor) -> torch.Tensor:
        """
        Applies the inverse of the standardization transformation.

        Args:
            data (torch.Tensor): The scaled data tensor.

        Returns:
            torch.Tensor: A copy of `data` mapped back to the original scale.
            When the scaler is not fitted, `data` itself is returned unchanged.
        """
        if self.mean_ is None or self.std_ is None or self.continuous_feature_indices is None:
            _LOGGER.warning("⚠️ Scaler has not been fitted. Returning original data.")
            return data

        data_clone = data.clone()

        mean = self.mean_.to(data.device)
        std = self.std_.to(data.device)

        features_to_inverse = data_clone[:, self.continuous_feature_indices]

        # Same epsilon as in transform() so the two operations are exact inverses.
        original_scale_features = (features_to_inverse * (std + 1e-8)) + mean

        data_clone[:, self.continuous_feature_indices] = original_scale_features

        return data_clone

    def save(self, filepath: Union[str, Path]):
        """
        Saves the scaler's state (mean, std, indices) with `torch.save`.

        Args:
            filepath (str | Path): The path to save the file.
        """
        path_obj = make_fullpath(filepath)
        state = {
            'mean': self.mean_,
            'std': self.std_,
            'continuous_feature_indices': self.continuous_feature_indices
        }
        torch.save(state, path_obj)
        _LOGGER.info(f"✅ PytorchScaler state saved to '{path_obj.name}'.")

    @staticmethod
    def load(filepath: Union[str, Path]) -> 'PytorchScaler':
        """
        Loads a scaler's state from a file written by `save()`.

        Args:
            filepath (str | Path): The path to the saved scaler file.

        Returns:
            PytorchScaler: An instance of the scaler with the loaded state.
        """
        path_obj = make_fullpath(filepath, enforce="file")
        # NOTE(security): torch.load unpickles the file; only load scaler files
        # from trusted sources.
        # map_location="cpu" lets a scaler saved from an accelerator be loaded
        # on any machine; transform()/inverse_transform() move the statistics
        # to the data's device on use.
        state = torch.load(path_obj, map_location="cpu")
        _LOGGER.info(f"✅ PytorchScaler state loaded from '{path_obj.name}'.")
        return PytorchScaler(
            mean=state['mean'],
            std=state['std'],
            continuous_feature_indices=state['continuous_feature_indices']
        )

    def __repr__(self) -> str:
        """Returns the developer-friendly string representation of the scaler."""
        # Mirror the fitted check used by transform(): indices alone (e.g. after
        # fitting on an empty dataset) do not make the scaler fitted.
        is_fitted = (
            self.mean_ is not None
            and self.std_ is not None
            and bool(self.continuous_feature_indices)
        )
        if is_fitted:
            num_features = len(self.continuous_feature_indices)  # type: ignore
            return f"PytorchScaler(fitted for {num_features} features)"
        return "PytorchScaler(not fitted)"
194
+
195
+
196
def info():
    """Pass this module's `__all__` list to `_script_info`."""
    exported_names = __all__
    _script_info(exported_names)
ml_tools/ML_trainer.py CHANGED
@@ -6,7 +6,7 @@ from torch import nn
6
6
  import numpy as np
7
7
 
8
8
  from .ML_callbacks import Callback, History, TqdmProgressBar
9
- from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot
9
+ from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
10
10
  from ._script_info import _script_info
11
11
  from .keys import PyTorchLogKeys
12
12
  from ._logger import _LOGGER
@@ -282,8 +282,11 @@ class MLTrainer:
282
282
  print("\n--- Training History ---")
283
283
  plot_losses(self.history, save_dir=save_dir)
284
284
 
285
- def explain(self, explain_dataset: Optional[Dataset] = None, n_samples: int = 1000,
286
- feature_names: Optional[List[str]] = None, save_dir: Optional[Union[str,Path]] = None):
285
+ def explain(self,
286
+ feature_names: Optional[List[str]],
287
+ save_dir: Union[str,Path],
288
+ explain_dataset: Optional[Dataset] = None,
289
+ n_samples: int = 1000):
287
290
  """
288
291
  Explains model predictions using SHAP and saves all artifacts.
289
292
 
@@ -328,14 +331,14 @@ class MLTrainer:
328
331
  # 1. Get background data from the trainer's train_dataset
329
332
  background_data = _get_random_sample(self.train_dataset, n_samples)
330
333
  if background_data is None:
331
- print("Warning: Trainer's train_dataset is empty or invalid. Skipping SHAP analysis.")
334
+ _LOGGER.error(" Trainer's train_dataset is empty or invalid. Skipping SHAP analysis.")
332
335
  return
333
336
 
334
337
  # 2. Determine target dataset and get explanation instances
335
338
  target_dataset = explain_dataset if explain_dataset is not None else self.test_dataset
336
339
  instances_to_explain = _get_random_sample(target_dataset, n_samples)
337
340
  if instances_to_explain is None:
338
- print("Warning: Explanation dataset is empty or invalid. Skipping SHAP analysis.")
341
+ _LOGGER.error(" Explanation dataset is empty or invalid. Skipping SHAP analysis.")
339
342
  return
340
343
 
341
344
  # 3. Call the plotting function
@@ -347,7 +350,80 @@ class MLTrainer:
347
350
  save_dir=save_dir
348
351
  )
349
352
 
353
+ def _attention_helper(self, dataloader: DataLoader):
354
+ """
355
+ Private method to yield model attention weights batch by batch for evaluation.
350
356
 
357
+ Args:
358
+ dataloader (DataLoader): The dataloader to predict on.
359
+
360
+ Yields:
361
+ (torch.Tensor): Attention weights
362
+ """
363
+ self.model.eval()
364
+ self.model.to(self.device)
365
+
366
+ with torch.no_grad():
367
+ for features, target in dataloader:
368
+ features = features.to(self.device)
369
+ attention_weights = None
370
+
371
+ # Get model output
372
+ # Unpack logits and weights from the special forward method
373
+ _output, attention_weights = self.model.forward_attention(features) # type: ignore
374
+
375
+ if attention_weights is not None:
376
+ attention_weights = attention_weights.cpu()
377
+
378
+ yield attention_weights
379
+
380
def explain_attention(self, save_dir: Union[str, Path], feature_names: Optional[List[str]], explain_dataset: Optional[Dataset] = None):
    """
    Collects attention weights over a dataset and hands them to
    `plot_attention_importance`.

    Nothing is plotted (an error is logged instead) when the model has no
    `forward_attention` method, when the chosen dataset is not a `Dataset`
    instance, or when no attention weights are produced.

    Args:
        save_dir (str | Path): Directory forwarded to the plotting function.
        feature_names (List[str] | None): Feature names forwarded to the plotting function.
        explain_dataset (Dataset, optional): Dataset to explain. When None, the trainer's test dataset is used.
    """
    print("\n--- Attention Analysis ---")

    # Guard: the model must expose the attention-aware forward pass.
    if not hasattr(self.model, 'forward_attention'):
        _LOGGER.error("❌ Model does not have a `forward_attention` method. Skipping attention explanation.")
        return

    # Guard: fall back to the test dataset and make sure it is usable.
    dataset_to_use = self.test_dataset if explain_dataset is None else explain_dataset
    if not isinstance(dataset_to_use, Dataset):
        _LOGGER.error("❌ The explanation dataset is empty or invalid. Skipping attention analysis.")
        return

    # Worker processes are disabled on 'mps'; pinned memory is only used for CUDA.
    device_type = self.device.type
    worker_count = 0 if device_type == 'mps' else self.dataloader_workers
    explain_loader = DataLoader(
        dataset=dataset_to_use,
        batch_size=32,
        shuffle=False,
        num_workers=worker_count,
        pin_memory=("cuda" in device_type)
    )

    # Keep only the batches for which the model actually returned weights.
    all_weights = [
        batch_weights
        for batch_weights in self._attention_helper(explain_loader)
        if batch_weights is not None
    ]

    if not all_weights:
        _LOGGER.error("❌ No attention weights were collected from the model.")
        return

    plot_attention_importance(
        weights=all_weights,
        feature_names=feature_names,
        save_dir=save_dir
    )
426
+
351
427
  def callbacks_hook(self, method_name: str, *args, **kwargs):
352
428
  """Calls the specified method on all callbacks."""
353
429
  for callback in self.callbacks: