dragon-ml-toolbox 6.4.1__py3-none-any.whl → 8.0.0__py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as published.

Potentially problematic release: this version of dragon-ml-toolbox might be problematic.

ml_tools/ML_scaler.py ADDED
@@ -0,0 +1,197 @@
+import torch
+from torch.utils.data import Dataset, DataLoader
+from pathlib import Path
+from typing import Union, List, Optional
+from ._logger import _LOGGER
+from ._script_info import _script_info
+from .path_manager import make_fullpath
+
+__all__ = [
+    "PytorchScaler"
+]
+
+class PytorchScaler:
+    """
+    Standardizes continuous features in a PyTorch dataset by subtracting the
+    mean and dividing by the standard deviation.
+
+    The scaler is fitted on a training dataset and can then be saved and
+    loaded for consistent transformation during inference.
+    """
+    def __init__(self,
+                 mean: Optional[torch.Tensor] = None,
+                 std: Optional[torch.Tensor] = None,
+                 continuous_feature_indices: Optional[List[int]] = None):
+        """
+        Initializes the scaler.
+
+        Args:
+            mean (torch.Tensor, optional): The mean of the features to scale.
+            std (torch.Tensor, optional): The standard deviation of the features.
+            continuous_feature_indices (List[int], optional): The column indices of the features to standardize.
+        """
+        self.mean_ = mean
+        self.std_ = std
+        self.continuous_feature_indices = continuous_feature_indices
+
+    @classmethod
+    def fit(cls, dataset: Dataset, continuous_feature_indices: List[int], batch_size: int = 64) -> 'PytorchScaler':
+        """
+        Fits the scaler by computing the mean and std dev from a dataset using a
+        fast, single-pass, vectorized algorithm.
+
+        Args:
+            dataset (Dataset): The PyTorch Dataset to fit on.
+            continuous_feature_indices (List[int]): The column indices of the
+                features to standardize.
+            batch_size (int): The batch size for iterating through the dataset.
+
+        Returns:
+            PytorchScaler: A new, fitted instance of the scaler.
+        """
+        if not continuous_feature_indices:
+            _LOGGER.warning("⚠️ No continuous feature indices provided. Scaler will not be fitted.")
+            return cls()
+
+        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
+
+        running_sum, running_sum_sq = None, None
+        count = 0
+        num_continuous_features = len(continuous_feature_indices)
+
+        for features, _ in loader:
+            if running_sum is None:
+                device = features.device
+                running_sum = torch.zeros(num_continuous_features, device=device)
+                running_sum_sq = torch.zeros(num_continuous_features, device=device)
+
+            continuous_features = features[:, continuous_feature_indices].to(device)
+
+            running_sum += torch.sum(continuous_features, dim=0)
+            running_sum_sq += torch.sum(continuous_features**2, dim=0) # type: ignore
+            count += continuous_features.size(0)
+
+        if count == 0:
+            _LOGGER.warning("⚠️ Dataset is empty. Scaler cannot be fitted.")
+            return cls(continuous_feature_indices=continuous_feature_indices)
+
+        # Calculate mean
+        mean = running_sum / count
+
+        # Calculate standard deviation
+        if count < 2:
+            _LOGGER.warning(f"⚠️ Only one sample found. Standard deviation cannot be calculated and is set to 1.")
+            std = torch.ones_like(mean)
+        else:
+            # var = E[X^2] - (E[X])^2
+            var = (running_sum_sq / count) - mean**2
+            std = torch.sqrt(torch.clamp(var, min=1e-8)) # Clamp for numerical stability
+
+        _LOGGER.info(f"Scaler fitted on {count} samples for {num_continuous_features} continuous features.")
+        return cls(mean=mean, std=std, continuous_feature_indices=continuous_feature_indices)
+
+    def transform(self, data: torch.Tensor) -> torch.Tensor:
+        """
+        Applies standardization to the specified continuous features.
+
+        Args:
+            data (torch.Tensor): The input data tensor.
+
+        Returns:
+            torch.Tensor: The transformed data tensor.
+        """
+        if self.mean_ is None or self.std_ is None or self.continuous_feature_indices is None:
+            _LOGGER.warning("⚠️ Scaler has not been fitted. Returning original data.")
+            return data
+
+        data_clone = data.clone()
+
+        # Ensure mean and std are on the same device as the data
+        mean = self.mean_.to(data.device)
+        std = self.std_.to(data.device)
+
+        # Extract the columns to be scaled
+        features_to_scale = data_clone[:, self.continuous_feature_indices]
+
+        # Apply scaling, adding epsilon to std to prevent division by zero
+        scaled_features = (features_to_scale - mean) / (std + 1e-8)
+
+        # Place the scaled features back into the cloned tensor
+        data_clone[:, self.continuous_feature_indices] = scaled_features
+
+        return data_clone
+
+    def inverse_transform(self, data: torch.Tensor) -> torch.Tensor:
+        """
+        Applies the inverse of the standardization transformation.
+
+        Args:
+            data (torch.Tensor): The scaled data tensor.
+
+        Returns:
+            torch.Tensor: The original-scale data tensor.
+        """
+        if self.mean_ is None or self.std_ is None or self.continuous_feature_indices is None:
+            _LOGGER.warning("⚠️ Scaler has not been fitted. Returning original data.")
+            return data
+
+        data_clone = data.clone()
+
+        mean = self.mean_.to(data.device)
+        std = self.std_.to(data.device)
+
+        features_to_inverse = data_clone[:, self.continuous_feature_indices]
+
+        # Apply inverse scaling
+        original_scale_features = (features_to_inverse * (std + 1e-8)) + mean
+
+        data_clone[:, self.continuous_feature_indices] = original_scale_features
+
+        return data_clone
+
+    def save(self, filepath: Union[str, Path]):
+        """
+        Saves the scaler's state (mean, std, indices) to a .pth file.
+
+        Args:
+            filepath (str | Path): The path to save the file.
+        """
+        path_obj = make_fullpath(filepath)
+        state = {
+            'mean': self.mean_,
+            'std': self.std_,
+            'continuous_feature_indices': self.continuous_feature_indices
+        }
+        torch.save(state, path_obj)
+        _LOGGER.info(f"✅ PytorchScaler state saved to '{path_obj.name}'.")
+
+    @staticmethod
+    def load(filepath: Union[str, Path]) -> 'PytorchScaler':
+        """
+        Loads a scaler's state from a .pth file.
+
+        Args:
+            filepath (str | Path): The path to the saved scaler file.
+
+        Returns:
+            PytorchScaler: An instance of the scaler with the loaded state.
+        """
+        path_obj = make_fullpath(filepath, enforce="file")
+        state = torch.load(path_obj)
+        _LOGGER.info(f"✅ PytorchScaler state loaded from '{path_obj.name}'.")
+        return PytorchScaler(
+            mean=state['mean'],
+            std=state['std'],
+            continuous_feature_indices=state['continuous_feature_indices']
+        )
+
+    def __repr__(self) -> str:
+        """Returns the developer-friendly string representation of the scaler."""
+        if self.continuous_feature_indices:
+            num_features = len(self.continuous_feature_indices)
+            return f"PytorchScaler(fitted for {num_features} features)"
+        return "PytorchScaler(not fitted)"
+
+
+def info():
+    _script_info(__all__)
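
For orientation, a minimal usage sketch of the new PytorchScaler follows. It is written for this diff rather than taken from the package documentation; it assumes the module is importable as ml_tools.ML_scaler and that the dataset yields (features, target) tuples, as PytorchScaler.fit() expects. The TensorDataset, column indices, and file name are illustrative only.

# Hypothetical usage sketch (not from the package docs).
import torch
from torch.utils.data import TensorDataset
from ml_tools.ML_scaler import PytorchScaler  # assumed import path

# Toy data: columns 0-2 are continuous, columns 3-4 are already binary flags.
X = torch.randn(256, 5)
y = torch.randint(0, 2, (256,)).float()
train_ds = TensorDataset(X, y)

# Fit on the continuous columns only, then standardize a feature tensor.
scaler = PytorchScaler.fit(train_ds, continuous_feature_indices=[0, 1, 2])
X_scaled = scaler.transform(X)
X_restored = scaler.inverse_transform(X_scaled)  # back to the original scale

# Persist the state and reload it for inference.
scaler.save("scaler.pth")
scaler = PytorchScaler.load("scaler.pth")
print(scaler)  # PytorchScaler(fitted for 3 features)
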
ml_tools/ML_trainer.py CHANGED
@@ -6,7 +6,8 @@ from torch import nn
 import numpy as np
 
 from .ML_callbacks import Callback, History, TqdmProgressBar
-from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot
+from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot, plot_attention_importance
+from .ML_evaluation_multi import multi_target_regression_metrics, multi_label_classification_metrics, multi_target_shap_summary_plot
 from ._script_info import _script_info
 from .keys import PyTorchLogKeys
 from ._logger import _LOGGER
@@ -19,7 +20,7 @@ __all__ = [
 
 class MLTrainer:
     def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
-                 kind: Literal["regression", "classification"],
+                 kind: Literal["regression", "classification", "multi_target_regression", "multi_label_classification"],
                  criterion: nn.Module, optimizer: torch.optim.Optimizer,
                  device: Union[Literal['cuda', 'mps', 'cpu'],str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
         """
@@ -31,20 +32,22 @@ class MLTrainer:
             model (nn.Module): The PyTorch model to train.
             train_dataset (Dataset): The training dataset.
             test_dataset (Dataset): The testing/validation dataset.
-            kind (str): The type of task, 'regression' or 'classification'.
+            kind (str): Can be 'regression', 'classification', 'multi_target_regression', or 'multi_label_classification'.
             criterion (nn.Module): The loss function.
             optimizer (torch.optim.Optimizer): The optimizer.
             device (str): The device to run training on ('cpu', 'cuda', 'mps').
-            dataloader_workers (int): Subprocesses for data loading. Defaults to 2.
+            dataloader_workers (int): Subprocesses for data loading.
             callbacks (List[Callback] | None): A list of callbacks to use during training.
 
         Note:
-            For **regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`.
-
-            For **classification** tasks, `nn.CrossEntropyLoss` (multi-class) or `nn.BCEWithLogitsLoss` (binary) are common choices.
+            - For **regression** and **multi_target_regression** tasks, suggested criterions include `nn.MSELoss` or `nn.L1Loss`.
+
+            - For **single-label, multi-class classification** tasks, `nn.CrossEntropyLoss` is the standard choice.
+
+            - For **multi-label, binary classification** tasks (where each label is a 0 or 1), `nn.BCEWithLogitsLoss` is the correct choice as it treats each output as an independent binary problem.
         """
-        if kind not in ["regression", "classification"]:
-            raise TypeError("Kind must be 'regression' or 'classification'.")
+        if kind not in ["regression", "classification", "multi_target_regression", "multi_label_classification"]:
+            raise ValueError(f"'{kind}' is not a valid task type.")
 
         self.model = model
         self.train_dataset = train_dataset
@@ -157,7 +160,6 @@ class MLTrainer:
     def _train_step(self):
         self.model.train()
         running_loss = 0.0
-        # Enumerate to get batch index
         for batch_idx, (features, target) in enumerate(self.train_loader): # type: ignore
             # Create a log dictionary for the batch
             batch_logs = {
@@ -168,22 +170,26 @@ class MLTrainer:
 
             features, target = features.to(self.device), target.to(self.device)
             self.optimizer.zero_grad()
+
             output = self.model(features)
-            if isinstance(self.criterion, (nn.MSELoss, nn.L1Loss)):
+
+            # Apply shape correction only for single-target regression
+            if self.kind == "regression":
                 output = output.view_as(target)
+
             loss = self.criterion(output, target)
+
             loss.backward()
             self.optimizer.step()
 
             # Calculate batch loss and update running loss for the epoch
             batch_loss = loss.item()
             running_loss += batch_loss * features.size(0)
-
+
             # Add the batch loss to the logs and call the end-of-batch hook
             batch_logs[PyTorchLogKeys.BATCH_LOSS] = batch_loss
             self.callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
 
-        # Return the average loss for the entire epoch
         return {PyTorchLogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)} # type: ignore
 
     def _validation_step(self):
@@ -192,25 +198,27 @@ class MLTrainer:
         with torch.no_grad():
             for features, target in self.test_loader: # type: ignore
                 features, target = features.to(self.device), target.to(self.device)
+
                 output = self.model(features)
-                if isinstance(self.criterion, (nn.MSELoss, nn.L1Loss)):
+                # Apply shape correction only for single-target regression
+                if self.kind == "regression":
                     output = output.view_as(target)
+
                 loss = self.criterion(output, target)
+
                 running_loss += loss.item() * features.size(0)
+
         logs = {PyTorchLogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)} # type: ignore
         return logs
 
-    def _predict_for_eval(self, dataloader: DataLoader):
+    def _predict_for_eval(self, dataloader: DataLoader, classification_threshold: float = 0.5):
         """
         Private method to yield model predictions batch by batch for evaluation.
-        This is used internally by the `evaluate` method.
-
-        Args:
-            dataloader (DataLoader): The dataloader to predict on.
-
+
         Yields:
             tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
-            y_prob_batch is None for regression tasks.
+
+            - y_prob_batch is None for regression tasks.
         """
         self.model.eval()
         self.model.to(self.device)
@@ -220,81 +228,135 @@ class MLTrainer:
                 output = self.model(features).cpu()
                 y_true_batch = target.numpy()
 
-                if self.kind == "classification":
-                    probs = nn.functional.softmax(output, dim=1)
+                y_pred_batch = None
+                y_prob_batch = None
+
+                if self.kind in ["regression", "multi_target_regression"]:
+                    y_pred_batch = output.numpy()
+
+                elif self.kind == "classification":
+                    probs = torch.softmax(output, dim=1)
                     preds = torch.argmax(probs, dim=1)
                     y_pred_batch = preds.numpy()
                     y_prob_batch = probs.numpy()
-                # regression
-                else:
-                    y_pred_batch = output.numpy()
-                    y_prob_batch = None
-
+
+                elif self.kind == "multi_label_classification":
+                    probs = torch.sigmoid(output)
+                    preds = (probs >= classification_threshold).int()
+                    y_pred_batch = preds.numpy()
+                    y_prob_batch = probs.numpy()
+
                 yield y_pred_batch, y_prob_batch, y_true_batch
-
-    def evaluate(self, save_dir: Union[str,Path], data: Optional[Union[DataLoader, Dataset]] = None):
+
+    def evaluate(self, save_dir: Union[str, Path], data: Optional[Union[DataLoader, Dataset]] = None, classification_threshold: float = 0.5):
         """
-        Evaluates the model on the given data.
+        Evaluates the model, routing to the correct evaluation function based on task `kind`.
 
         Args:
-            data (DataLoader | Dataset | None ): The data to evaluate on.
-                Can be a DataLoader or a Dataset. If None, defaults to the trainer's internal test_dataset.
             save_dir (str | Path): Directory to save all reports and plots.
+            data (DataLoader | Dataset | None): The data to evaluate on. If None, defaults to the trainer's internal test_dataset.
+            classification_threshold (float): Probability threshold for multi-label tasks.
         """
+        dataset_for_names = None
         eval_loader = None
+
         if isinstance(data, DataLoader):
             eval_loader = data
-        else:
-            # Determine which dataset to use (the one passed in, or the default test_dataset)
-            dataset_to_use = data if data is not None else self.test_dataset
-            if not isinstance(dataset_to_use, Dataset):
-                raise ValueError("Cannot evaluate. No valid DataLoader or Dataset was provided, "
-                                 "and no test_dataset is available in the trainer.")
-
-            # Create a new DataLoader from the dataset
-            eval_loader = DataLoader(
-                dataset=dataset_to_use,
-                batch_size=32, # A sensible default for evaluation
-                shuffle=False,
-                num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
-                pin_memory=(self.device.type == "cuda")
-            )
-
+            # Try to get the dataset from the loader for fetching target names
+            if hasattr(data, 'dataset'):
+                dataset_for_names = data.dataset
+        elif isinstance(data, Dataset):
+            # Create a new loader from the provided dataset
+            eval_loader = DataLoader(data,
+                                     batch_size=32,
+                                     shuffle=False,
+                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                     pin_memory=(self.device.type == "cuda"))
+            dataset_for_names = data
+        else: # data is None, use the trainer's default test dataset
+            if self.test_dataset is None:
+                raise ValueError("Cannot evaluate. No data provided and no test_dataset available in the trainer.")
+            # Create a fresh DataLoader from the test_dataset
+            eval_loader = DataLoader(self.test_dataset,
+                                     batch_size=32,
+                                     shuffle=False,
+                                     num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                                     pin_memory=(self.device.type == "cuda"))
+            dataset_for_names = self.test_dataset
+
+        if eval_loader is None:
+            raise ValueError("Cannot evaluate. No valid data was provided or found.")
+
         print("\n--- Model Evaluation ---")
 
-        # Collect results from the predict generator
         all_preds, all_probs, all_true = [], [], []
-        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader):
-            all_preds.append(y_pred_b)
-            if y_prob_b is not None:
-                all_probs.append(y_prob_b)
-            all_true.append(y_true_b)
+        for y_pred_b, y_prob_b, y_true_b in self._predict_for_eval(eval_loader, classification_threshold):
+            if y_pred_b is not None: all_preds.append(y_pred_b)
+            if y_prob_b is not None: all_probs.append(y_prob_b)
+            if y_true_b is not None: all_true.append(y_true_b)
+
+        if not all_true:
+            _LOGGER.error("❌ Evaluation failed: No data was processed.")
+            return
 
         y_pred = np.concatenate(all_preds)
         y_true = np.concatenate(all_true)
-        y_prob = np.concatenate(all_probs) if self.kind == "classification" else None
+        y_prob = np.concatenate(all_probs) if all_probs else None
 
-        if self.kind == "classification":
-            classification_metrics(save_dir, y_true, y_pred, y_prob)
-        else:
+        # --- Routing Logic ---
+        if self.kind == "regression":
             regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir)
 
+        elif self.kind == "classification":
+            classification_metrics(save_dir, y_true, y_pred, y_prob)
+
+        elif self.kind == "multi_target_regression":
+            try:
+                target_names = dataset_for_names.target_names # type: ignore
+            except AttributeError:
+                num_targets = y_true.shape[1]
+                target_names = [f"target_{i}" for i in range(num_targets)]
+                _LOGGER.warning(f"⚠️ Dataset has no 'target_names' attribute. Using generic names.")
+            multi_target_regression_metrics(y_true, y_pred, target_names, save_dir)
+
+        elif self.kind == "multi_label_classification":
+            try:
+                target_names = dataset_for_names.target_names # type: ignore
+            except AttributeError:
+                num_targets = y_true.shape[1]
+                target_names = [f"label_{i}" for i in range(num_targets)]
+                _LOGGER.warning(f"⚠️ Dataset has no 'target_names' attribute. Using generic names.")
+
+            if y_prob is None:
+                _LOGGER.error("❌ Evaluation for multi_label_classification requires probabilities (y_prob).")
+                return
+            multi_label_classification_metrics(y_true, y_prob, target_names, save_dir, classification_threshold)
+
         print("\n--- Training History ---")
         plot_losses(self.history, save_dir=save_dir)
 
-    def explain(self, explain_dataset: Optional[Dataset] = None, n_samples: int = 1000,
-                feature_names: Optional[List[str]] = None, save_dir: Optional[Union[str,Path]] = None):
+    def explain(self,
+                save_dir: Union[str,Path],
+                explain_dataset: Optional[Dataset] = None,
+                n_samples: int = 1000,
+                feature_names: Optional[List[str]] = None,
+                target_names: Optional[List[str]] = None):
         """
         Explains model predictions using SHAP and saves all artifacts.
 
         The background data is automatically sampled from the trainer's training dataset.
+
+        This method automatically routes to the appropriate SHAP summary plot
+        function based on the task. If `feature_names` or `target_names` (multi-target) are not provided,
+        it will attempt to extract them from the dataset.
 
         Args:
-            explain_dataset (Dataset, optional): A specific dataset to explain.
+            explain_dataset (Dataset | None): A specific dataset to explain.
                 If None, the trainer's test dataset is used.
             n_samples (int): The number of samples to use for both background and explanation.
-            feature_names (List[str], optional): Names for the features.
-            save_dir (str, optional): Directory to save all SHAP artifacts.
+            feature_names (list[str] | None): Feature names.
+            target_names (list[str] | None): Target names
+            save_dir (str | Path): Directory to save all SHAP artifacts.
         """
         # Internal helper to create a dataloader and get a random sample
         def _get_random_sample(dataset: Dataset, num_samples: int):
@@ -328,26 +390,137 @@ class MLTrainer:
         # 1. Get background data from the trainer's train_dataset
         background_data = _get_random_sample(self.train_dataset, n_samples)
         if background_data is None:
-            print("Warning: Trainer's train_dataset is empty or invalid. Skipping SHAP analysis.")
+            _LOGGER.error(" Trainer's train_dataset is empty or invalid. Skipping SHAP analysis.")
             return
 
         # 2. Determine target dataset and get explanation instances
         target_dataset = explain_dataset if explain_dataset is not None else self.test_dataset
         instances_to_explain = _get_random_sample(target_dataset, n_samples)
         if instances_to_explain is None:
-            print("Warning: Explanation dataset is empty or invalid. Skipping SHAP analysis.")
+            _LOGGER.error(" Explanation dataset is empty or invalid. Skipping SHAP analysis.")
             return
+
+        # attempt to get feature names
+        if feature_names is None:
+            # _LOGGER.info("`feature_names` not provided. Attempting to extract from dataset...")
+            if hasattr(target_dataset, "feature_names"):
+                feature_names = target_dataset.feature_names # type: ignore
+            else:
+                try:
+                    # Handle PyTorch Subset
+                    feature_names = target_dataset.dataset.feature_names # type: ignore
+                except AttributeError:
+                    _LOGGER.error("❌ Could not extract `feature_names` from the dataset.")
+                    raise ValueError("`feature_names` must be provided if the dataset object does not have a `feature_names` attribute.")
 
         # 3. Call the plotting function
-        shap_summary_plot(
-            model=self.model,
-            background_data=background_data,
-            instances_to_explain=instances_to_explain,
-            feature_names=feature_names,
-            save_dir=save_dir
-        )
+        if self.kind in ["regression", "classification"]:
+            shap_summary_plot(
+                model=self.model,
+                background_data=background_data,
+                instances_to_explain=instances_to_explain,
+                feature_names=feature_names,
+                save_dir=save_dir
+            )
+        elif self.kind in ["multi_target_regression", "multi_label_classification"]:
+            # try to get target names
+            if target_names is None:
+                target_names = []
+                if hasattr(target_dataset, 'target_names'):
+                    target_names = target_dataset.target_names # type: ignore
+                else:
+                    # Infer number of targets from the model's output layer
+                    try:
+                        num_targets = self.model.output_layer.out_features # type: ignore
+                        target_names = [f"target_{i}" for i in range(num_targets)] # type: ignore
+                        _LOGGER.warning("Dataset has no 'target_names' attribute. Using generic names.")
+                    except AttributeError:
+                        _LOGGER.error("Cannot determine target names for multi-target SHAP plot. Skipping.")
+                        return
+
+            multi_target_shap_summary_plot(
+                model=self.model,
+                background_data=background_data,
+                instances_to_explain=instances_to_explain,
+                feature_names=feature_names, # type: ignore
+                target_names=target_names, # type: ignore
+                save_dir=save_dir
+            )
+
+    def _attention_helper(self, dataloader: DataLoader):
+        """
+        Private method to yield model attention weights batch by batch for evaluation.
+
+        Args:
+            dataloader (DataLoader): The dataloader to predict on.
+
+        Yields:
+            (torch.Tensor): Attention weights
+        """
+        self.model.eval()
+        self.model.to(self.device)
+
+        with torch.no_grad():
+            for features, target in dataloader:
+                features = features.to(self.device)
+                attention_weights = None
+
+                # Get model output
+                # Unpack logits and weights from the special forward method
+                _output, attention_weights = self.model.forward_attention(features) # type: ignore
+
+                if attention_weights is not None:
+                    attention_weights = attention_weights.cpu()
+
+                yield attention_weights
 
+    def explain_attention(self, save_dir: Union[str, Path], feature_names: Optional[List[str]], explain_dataset: Optional[Dataset] = None):
+        """
+        Generates and saves a feature importance plot based on attention weights.
+
+        This method only works for models with a `forward_attention` method.
+
+        Args:
+            save_dir (str | Path): Directory to save the plot and summary data.
+            feature_names (List[str] | None): Names for the features for plot labeling.
+            explain_dataset (Dataset, optional): A specific dataset to explain. If None, the trainer's test dataset is used.
+        """
+
+        print("\n--- Attention Analysis ---")
+
+        # --- Step 1: Check if the model supports this explanation ---
+        if not hasattr(self.model, 'forward_attention'):
+            _LOGGER.error("❌ Model does not have a `forward_attention` method. Skipping attention explanation.")
+            return
 
+        # --- Step 2: Set up the dataloader ---
+        dataset_to_use = explain_dataset if explain_dataset is not None else self.test_dataset
+        if not isinstance(dataset_to_use, Dataset):
+            _LOGGER.error("❌ The explanation dataset is empty or invalid. Skipping attention analysis.")
+            return
+
+        explain_loader = DataLoader(
+            dataset=dataset_to_use, batch_size=32, shuffle=False,
+            num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+            pin_memory=("cuda" in self.device.type)
+        )
+
+        # --- Step 3: Collect weights ---
+        all_weights = []
+        for att_weights_b in self._attention_helper(explain_loader):
+            if att_weights_b is not None:
+                all_weights.append(att_weights_b)
+
+        # --- Step 4: Call the plotting function ---
+        if all_weights:
+            plot_attention_importance(
+                weights=all_weights,
+                feature_names=feature_names,
+                save_dir=save_dir
+            )
+        else:
+            _LOGGER.error("❌ No attention weights were collected from the model.")
+
     def callbacks_hook(self, method_name: str, *args, **kwargs):
         """Calls the specified method on all callbacks."""
         for callback in self.callbacks:
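
The trainer changes above add two multi-output task types, a routed evaluate(), and attention-based explanations. The sketch below shows how the extended API might be driven for a multi-label task, based only on the signatures visible in this diff. The model, datasets, and directory names are placeholders, and the training call itself is elided because the trainer's fit/train method is not part of this diff.

# Hypothetical driver for the extended MLTrainer API (signatures taken from this diff only).
import torch
from torch import nn
from torch.utils.data import TensorDataset
from ml_tools.ML_trainer import MLTrainer  # assumed import path

# Toy multi-label data: 10 features, 3 independent binary labels.
X = torch.randn(512, 10)
Y = torch.randint(0, 2, (512, 3)).float()
train_ds = TensorDataset(X[:400], Y[:400])
test_ds = TensorDataset(X[400:], Y[400:])

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 3))

trainer = MLTrainer(
    model=model,
    train_dataset=train_ds,
    test_dataset=test_ds,
    kind="multi_label_classification",  # new task type in this release
    criterion=nn.BCEWithLogitsLoss(),   # per the Note in __init__
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    device="cpu",
)

# ... run training here; the trainer's training method is outside this diff ...

# evaluate() routes to multi_label_classification_metrics: sigmoid probabilities are
# thresholded at classification_threshold, and since TensorDataset has no
# `target_names` attribute, generic names ("label_0", ...) are used with a warning.
trainer.evaluate(save_dir="eval_report", classification_threshold=0.5)

# explain() routes to multi_target_shap_summary_plot for this task kind;
# feature_names must be supplied because TensorDataset has no `feature_names` attribute.
trainer.explain(save_dir="eval_report",
                feature_names=[f"feat_{i}" for i in range(10)],
                target_names=["label_a", "label_b", "label_c"])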