dragon-ml-toolbox 2.4.0__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ml_tools/ML_evaluation.py ADDED
@@ -0,0 +1,255 @@
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ from sklearn.metrics import (
+     classification_report,
+     ConfusionMatrixDisplay,
+     roc_curve,
+     roc_auc_score,
+     mean_squared_error,
+     mean_absolute_error,
+     r2_score,
+     median_absolute_error
+ )
+ import torch
+ import shap
+ from pathlib import Path
+ from .utilities import make_fullpath
+ from .logger import _LOGGER
+ from typing import Union, Optional
+
+
+ __all__ = [
+     "plot_losses",
+     "classification_metrics",
+     "regression_metrics",
+     "shap_summary_plot"
+ ]
+
+
+ def plot_losses(history: dict, save_dir: Optional[Union[str, Path]] = None):
+     """
+     Plots training and validation loss curves from a history object.
+
+     Args:
+         history (dict): A dictionary containing 'train_loss' and 'val_loss' lists.
+         save_dir (str | Path | None): Directory to save the plot image. If None, the plot is shown instead.
+     """
+     train_loss = history.get('train_loss', [])
+     val_loss = history.get('val_loss', [])
+
+     if not train_loss and not val_loss:
+         print("Warning: Loss history is empty. Cannot plot.")
+         return
+
+     fig, ax = plt.subplots(figsize=(10, 5), dpi=100)
+
+     # Plot training loss only if data for it exists
+     if train_loss:
+         epochs = range(1, len(train_loss) + 1)
+         ax.plot(epochs, train_loss, 'o-', label='Training Loss')
+
+     # Plot validation loss only if data for it exists
+     if val_loss:
+         epochs = range(1, len(val_loss) + 1)
+         ax.plot(epochs, val_loss, 'o-', label='Validation Loss')
+
+     ax.set_title('Training and Validation Loss')
+     ax.set_xlabel('Epochs')
+     ax.set_ylabel('Loss')
+     ax.legend()
+     ax.grid(True)
+     plt.tight_layout()
+
+     if save_dir:
+         save_dir_path = make_fullpath(save_dir, make=True)
+         save_path = save_dir_path / "loss_plot.svg"
+         plt.savefig(save_path)
+         _LOGGER.info(f"Loss plot saved as '{save_path.name}'")
+     else:
+         plt.show()
+     plt.close(fig)
+
+
+ def classification_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optional[np.ndarray] = None,
+                            cmap: str = "Blues", save_dir: Optional[Union[str, Path]] = None):
+     """
+     Displays and optionally saves classification metrics and plots.
+
+     Args:
+         y_true (np.ndarray): Ground truth labels.
+         y_pred (np.ndarray): Predicted labels.
+         y_prob (np.ndarray, optional): Predicted probabilities for the ROC curve.
+         cmap (str): Colormap for the confusion matrix.
+         save_dir (str | Path | None): Directory to save plots. If None, plots are shown but not saved.
+     """
+     print("--- Classification Report ---")
+     report: str = classification_report(y_true, y_pred)  # type: ignore
+     print(report)
+
+     if save_dir:
+         save_dir_path = make_fullpath(save_dir, make=True)
+         # Save text report
+         report_path = save_dir_path / "classification_report.txt"
+         report_path.write_text(report, encoding="utf-8")
+         _LOGGER.info(f"Classification report saved as '{report_path.name}'")
+
+         # Save Confusion Matrix
+         fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
+         ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap, ax=ax_cm)
+         ax_cm.set_title("Confusion Matrix")
+         cm_path = save_dir_path / "confusion_matrix.svg"
+         plt.savefig(cm_path)
+         _LOGGER.info(f"Confusion matrix saved as '{cm_path.name}'")
+         plt.close(fig_cm)
+
+         # Save ROC Curve (binary case: column 1 holds positive-class probabilities)
+         if y_prob is not None and y_prob.ndim > 1 and y_prob.shape[1] >= 2:
+             fpr, tpr, _ = roc_curve(y_true, y_prob[:, 1])
+             auc = roc_auc_score(y_true, y_prob[:, 1])
+             fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
+             ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
+             ax_roc.plot([0, 1], [0, 1], 'k--')
+             ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
+             ax_roc.set_xlabel('False Positive Rate')
+             ax_roc.set_ylabel('True Positive Rate')
+             ax_roc.legend(loc='lower right')
+             ax_roc.grid(True)
+             roc_path = save_dir_path / "roc_curve.svg"
+             plt.savefig(roc_path)
+             _LOGGER.info(f"ROC curve saved as '{roc_path.name}'")
+             plt.close(fig_roc)
+     else:
+         # Show plots if not saving
+         ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap)
+         plt.show()
+         if y_prob is not None and y_prob.ndim > 1 and y_prob.shape[1] >= 2:
+             fpr, tpr, _ = roc_curve(y_true, y_prob[:, 1])
+             auc = roc_auc_score(y_true, y_prob[:, 1])
+             plt.figure()
+             plt.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
+             plt.plot([0, 1], [0, 1], 'k--')
+             plt.title('ROC Curve')
+             plt.legend(loc='lower right')  # draw the AUC label set above
+             plt.show()
+
+
+ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Optional[Union[str, Path]] = None):
+     """
+     Displays regression metrics and optionally saves plots and report.
+
+     Args:
+         y_true (np.ndarray): Ground truth values.
+         y_pred (np.ndarray): Predicted values.
+         save_dir (str | Path | None): Directory to save plots and report.
+     """
+     rmse = np.sqrt(mean_squared_error(y_true, y_pred))
+     mae = mean_absolute_error(y_true, y_pred)
+     r2 = r2_score(y_true, y_pred)
+     medae = median_absolute_error(y_true, y_pred)
+
+     report_lines = [
+         "--- Regression Report ---",
+         f" Root Mean Squared Error (RMSE): {rmse:.4f}",
+         f" Mean Absolute Error (MAE): {mae:.4f}",
+         f" Median Absolute Error (MedAE): {medae:.4f}",
+         f" Coefficient of Determination (R²): {r2:.4f}"
+     ]
+     report_string = "\n".join(report_lines)
+     print(report_string)
+
+     if save_dir:
+         save_dir_path = make_fullpath(save_dir, make=True)
+         # Save text report
+         report_path = save_dir_path / "regression_report.txt"
+         report_path.write_text(report_string, encoding="utf-8")
+         _LOGGER.info(f"Regression report saved as '{report_path.name}'")
+
+         # Save residual plot
+         residuals = y_true - y_pred
+         fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
+         ax_res.scatter(y_pred, residuals, alpha=0.6)
+         ax_res.axhline(0, color='red', linestyle='--')
+         ax_res.set_xlabel("Predicted Values")
+         ax_res.set_ylabel("Residuals")
+         ax_res.set_title("Residual Plot")
+         ax_res.grid(True)
+         plt.tight_layout()
+         res_path = save_dir_path / "residual_plot.svg"
+         plt.savefig(res_path)
+         _LOGGER.info(f"Residual plot saved as '{res_path.name}'")
+         plt.close(fig_res)
+
+         # Save true vs predicted plot
+         fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
+         ax_tvp.scatter(y_true, y_pred, alpha=0.6)
+         ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'k--', lw=2)
+         ax_tvp.set_xlabel('True Values')
+         ax_tvp.set_ylabel('Predictions')
+         ax_tvp.set_title('True vs. Predicted Values')
+         ax_tvp.grid(True)
+         plt.tight_layout()
+         tvp_path = save_dir_path / "true_vs_predicted_plot.svg"
+         plt.savefig(tvp_path)
+         _LOGGER.info(f"True vs. Predicted plot saved as '{tvp_path.name}'")
+         plt.close(fig_tvp)
+
+
+ def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain: torch.Tensor,
+                       feature_names: Optional[list[str]] = None, save_dir: Optional[Union[str, Path]] = None):
+     """
+     Calculates SHAP values and saves summary plots and data.
+
+     Args:
+         model (nn.Module): The trained PyTorch model.
+         background_data (torch.Tensor): A sample of data for the explainer background.
+         instances_to_explain (torch.Tensor): The specific data instances to explain.
+         feature_names (list of str | None): Names of the features for plot labeling.
+         save_dir (str | Path | None): Directory to save SHAP artifacts. If None, the dot plot is shown.
+     """
+     print("\n--- SHAP Value Explanation ---")
+     print("Calculating SHAP values... ")
+
+     model.eval()
+     model.cpu()
+
+     explainer = shap.DeepExplainer(model, background_data)
+     shap_values = explainer.shap_values(instances_to_explain)
+
+     shap_values_for_plot = shap_values[1] if isinstance(shap_values, list) else shap_values
+     if isinstance(shap_values, list):
+         _LOGGER.info("Using SHAP values for the positive class (class 1) for plots.")
+
+     if save_dir:
+         save_dir_path = make_fullpath(save_dir, make=True)
+         # Save Bar Plot
+         bar_path = save_dir_path / "shap_bar_plot.svg"
+         shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="bar", show=False)
+         plt.title("SHAP Feature Importance")
+         plt.tight_layout()
+         plt.savefig(bar_path)
+         _LOGGER.info(f"SHAP bar plot saved as '{bar_path.name}'")
+         plt.close()
+
+         # Save Dot Plot
+         dot_path = save_dir_path / "shap_dot_plot.svg"
+         shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="dot", show=False)
+         plt.title("SHAP Feature Importance")
+         plt.tight_layout()
+         plt.savefig(dot_path)
+         _LOGGER.info(f"SHAP dot plot saved as '{dot_path.name}'")
+         plt.close()
+
+         # Save Summary Data to CSV
+         summary_path = save_dir_path / "shap_summary.csv"
+         mean_abs_shap = np.abs(shap_values_for_plot).mean(axis=0)
+         if feature_names is None:
+             feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
+         summary_df = pd.DataFrame({
+             'feature': feature_names,
+             'mean_abs_shap_value': mean_abs_shap
+         }).sort_values('mean_abs_shap_value', ascending=False)
+         summary_df.to_csv(summary_path, index=False)
+         _LOGGER.info(f"SHAP summary data saved as '{summary_path.name}'")
+     else:
+         _LOGGER.info("No save directory provided. Displaying SHAP dot plot.")
+         shap.summary_plot(shap_values_for_plot, instances_to_explain, feature_names=feature_names, plot_type="dot")
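A minimal usage sketch of the helpers above, on synthetic data. This is hedged: it assumes the file is importable as ml_tools.ML_evaluation (consistent with the trainer's import below); the arrays and the "reports" output directory are illustrative only, not part of the package.

# Usage sketch (assumptions: ml_tools.ML_evaluation import path; synthetic data).
import numpy as np
from ml_tools.ML_evaluation import plot_losses, classification_metrics, regression_metrics

rng = np.random.default_rng(0)

# plot_losses() reads the 'train_loss' and 'val_loss' keys of a history dict.
history = {"train_loss": [0.9, 0.6, 0.45, 0.38], "val_loss": [0.95, 0.7, 0.55, 0.50]}
plot_losses(history, save_dir="reports")  # writes reports/loss_plot.svg

# Binary classification: y_prob needs shape (n, 2) to trigger the ROC curve,
# because the function reads column 1 as the positive-class probability.
y_true = rng.integers(0, 2, size=200)
p1 = rng.random(200)
y_prob = np.column_stack([1.0 - p1, p1])
y_pred = (p1 >= 0.5).astype(int)
classification_metrics(y_true, y_pred, y_prob, save_dir="reports")

# Regression: prints RMSE, MAE, MedAE, and R², and saves residual and
# true-vs-predicted plots when save_dir is given.
target = rng.normal(size=200)
regression_metrics(target, target + rng.normal(scale=0.1, size=200), save_dir="reports")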
ml_tools/ML_trainer.py ADDED
@@ -0,0 +1,344 @@
+ from typing import List, Literal, Union, Optional
+ from pathlib import Path
+ from torch.utils.data import DataLoader, Dataset
+ import torch
+ from torch import nn
+ import numpy as np
+
+ from .ML_callbacks import Callback, History, TqdmProgressBar
+ from .ML_evaluation import classification_metrics, regression_metrics, plot_losses, shap_summary_plot
+ from .utilities import _script_info, LogKeys
+ from .logger import _LOGGER
+
+
+ __all__ = [
+     "MyTrainer"
+ ]
+
+
+ class MyTrainer:
+     def __init__(self, model: nn.Module, train_dataset: Dataset, test_dataset: Dataset,
+                  kind: Literal["regression", "classification"],
+                  criterion: nn.Module, optimizer: torch.optim.Optimizer,
+                  device: Union[Literal['cuda', 'mps', 'cpu'], str], dataloader_workers: int = 2, callbacks: Optional[List[Callback]] = None):
+         """
+         Automates the training process of a PyTorch model.
+
+         Built-in callbacks: `History`, `TqdmProgressBar`
+
+         Args:
+             model (nn.Module): The PyTorch model to train.
+             train_dataset (Dataset): The training dataset.
+             test_dataset (Dataset): The testing/validation dataset.
+             kind (str): The type of task, 'regression' or 'classification'.
+             criterion (nn.Module): The loss function.
+             optimizer (torch.optim.Optimizer): The optimizer.
+             device (str): The device to run training on ('cpu', 'cuda', 'mps').
+             dataloader_workers (int): Subprocesses for data loading. Defaults to 2.
+             callbacks (List[Callback] | None): A list of callbacks to use during training.
+
+         Note:
+             For **regression** tasks, suggested criteria include `nn.MSELoss` or `nn.L1Loss`.
+
+             For **classification** tasks, `nn.CrossEntropyLoss` (multi-class) or `nn.BCEWithLogitsLoss` (binary) are common choices.
+         """
+         if kind not in ["regression", "classification"]:
+             raise ValueError("kind must be 'regression' or 'classification'.")
+
+         self.model = model
+         self.train_dataset = train_dataset
+         self.test_dataset = test_dataset
+         self.kind = kind
+         self.criterion = criterion
+         self.optimizer = optimizer
+         self.device = self._validate_device(device)
+         self.dataloader_workers = dataloader_workers
+
+         # Callback handler - History and TqdmProgressBar are added by default
+         default_callbacks = [History(), TqdmProgressBar()]
+         user_callbacks = callbacks if callbacks is not None else []
+         self.callbacks = default_callbacks + user_callbacks
+         self._set_trainer_on_callbacks()
+
+         # Internal state
+         self.train_loader = None
+         self.test_loader = None
+         self.history = {}
+         self.epoch = 0
+         self.epochs = 0  # Total epochs for the fit run
+         self.stop_training = False
+
+     def _validate_device(self, device: str) -> torch.device:
+         """Validates the selected device and returns a torch.device object."""
+         device_lower = device.lower()
+         if "cuda" in device_lower and not torch.cuda.is_available():
+             _LOGGER.warning("CUDA not available, switching to CPU.")
+             device = "cpu"
+         elif device_lower == "mps" and not torch.backends.mps.is_available():
+             _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
+             device = "cpu"
+         return torch.device(device)
+
+     def _set_trainer_on_callbacks(self):
+         """Gives each callback a reference to this trainer instance."""
+         for callback in self.callbacks:
+             callback.set_trainer(self)
+
+     def _create_dataloaders(self, batch_size: int, shuffle: bool):
+         """Initializes the DataLoaders."""
+         # Ensure stability on MPS devices by setting num_workers to 0
+         loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+         self.train_loader = DataLoader(
+             dataset=self.train_dataset,
+             batch_size=batch_size,
+             shuffle=shuffle,
+             num_workers=loader_workers,
+             pin_memory=(self.device.type == "cuda")
+         )
+         self.test_loader = DataLoader(
+             dataset=self.test_dataset,
+             batch_size=batch_size,
+             shuffle=False,
+             num_workers=loader_workers,
+             pin_memory=(self.device.type == "cuda")
+         )
+
+     def fit(self, epochs: int = 10, batch_size: int = 32, shuffle: bool = True):
+         """
+         Starts the training-validation process of the model.
+
+         Args:
+             epochs (int): The total number of epochs to train for.
+             batch_size (int): The number of samples per batch.
+             shuffle (bool): Whether to shuffle the training data at each epoch.
+         """
+         self.epochs = epochs
+         self._create_dataloaders(batch_size, shuffle)
+         self.model.to(self.device)
+
+         # Reset stop_training flag on the trainer
+         self.stop_training = False
+
+         self.callbacks_hook('on_train_begin')
+
+         for epoch in range(1, self.epochs + 1):
+             self.epoch = epoch
+             epoch_logs = {}
+             self.callbacks_hook('on_epoch_begin', epoch, logs=epoch_logs)
+
+             train_logs = self._train_step()
+             epoch_logs.update(train_logs)
+
+             val_logs = self._validation_step()
+             epoch_logs.update(val_logs)
+
+             self.callbacks_hook('on_epoch_end', epoch, logs=epoch_logs)
+
+             # Check the early stopping flag
+             if self.stop_training:
+                 break
+
+         self.callbacks_hook('on_train_end')
+         return self.history
+
+     def _train_step(self):
+         self.model.train()
+         running_loss = 0.0
+         # Enumerate to get batch index
+         for batch_idx, (features, target) in enumerate(self.train_loader):  # type: ignore
+             # Create a log dictionary for the batch
+             batch_logs = {
+                 LogKeys.BATCH_INDEX: batch_idx,
+                 LogKeys.BATCH_SIZE: features.size(0)
+             }
+             self.callbacks_hook('on_batch_begin', batch_idx, logs=batch_logs)
+
+             features, target = features.to(self.device), target.to(self.device)
+             self.optimizer.zero_grad()
+             output = self.model(features)
+             if isinstance(self.criterion, (nn.MSELoss, nn.L1Loss)):
+                 output = output.view_as(target)
+             loss = self.criterion(output, target)
+             loss.backward()
+             self.optimizer.step()
+
+             # Calculate batch loss and update running loss for the epoch
+             batch_loss = loss.item()
+             running_loss += batch_loss * features.size(0)
+
+             # Add the batch loss to the logs and call the end-of-batch hook
+             batch_logs[LogKeys.BATCH_LOSS] = batch_loss
+             self.callbacks_hook('on_batch_end', batch_idx, logs=batch_logs)
+
+         # Return the average loss for the entire epoch
+         return {LogKeys.TRAIN_LOSS: running_loss / len(self.train_loader.dataset)}  # type: ignore
+
+     def _validation_step(self):
+         self.model.eval()
+         running_loss = 0.0
+         with torch.no_grad():
+             for features, target in self.test_loader:  # type: ignore
+                 features, target = features.to(self.device), target.to(self.device)
+                 output = self.model(features)
+                 if isinstance(self.criterion, (nn.MSELoss, nn.L1Loss)):
+                     output = output.view_as(target)
+                 loss = self.criterion(output, target)
+                 running_loss += loss.item() * features.size(0)
+         logs = {LogKeys.VAL_LOSS: running_loss / len(self.test_loader.dataset)}  # type: ignore
+         return logs
+
+     def predict(self, dataloader: DataLoader):
+         """
+         Yields model predictions batch by batch, avoiding loading all predictions into memory at once.
+
+         Args:
+             dataloader (DataLoader): The dataloader to predict on.
+
+         Yields:
+             tuple: A tuple containing (y_pred_batch, y_prob_batch, y_true_batch).
+                 y_prob_batch is None for regression tasks.
+         """
+         self.model.eval()
+         self.model.to(self.device)
+         with torch.no_grad():
+             for features, target in dataloader:
+                 features = features.to(self.device)
+                 output = self.model(features).cpu()
+                 y_true_batch = target.numpy()
+
+                 if self.kind == "classification":
+                     probs = nn.functional.softmax(output, dim=1)
+                     preds = torch.argmax(probs, dim=1)
+                     y_pred_batch = preds.numpy()
+                     y_prob_batch = probs.numpy()
+                 else:
+                     y_pred_batch = output.numpy()
+                     y_prob_batch = None
+
+                 yield y_pred_batch, y_prob_batch, y_true_batch
+
+     def evaluate(self, data: Optional[Union[DataLoader, Dataset]] = None, save_dir: Optional[Union[str, Path]] = None):
+         """
+         Evaluates the model on the given data.
+
+         Args:
+             data (DataLoader | Dataset | None): The data to evaluate on.
+                 Can be a DataLoader or a Dataset. If None, defaults to the trainer's internal test_dataset.
+             save_dir (str | Path | None): Directory to save all reports and plots. If None, metrics are shown but not saved.
+         """
+         eval_loader = None
+         if isinstance(data, DataLoader):
+             eval_loader = data
+         else:
+             # Determine which dataset to use (the one passed in, or the default test_dataset)
+             dataset_to_use = data if data is not None else self.test_dataset
+             if not isinstance(dataset_to_use, Dataset):
+                 raise ValueError("Cannot evaluate. No valid DataLoader or Dataset was provided, "
+                                  "and no test_dataset is available in the trainer.")
+
+             # Create a new DataLoader from the dataset
+             eval_loader = DataLoader(
+                 dataset=dataset_to_use,
+                 batch_size=32,  # A sensible default for evaluation
+                 shuffle=False,
+                 num_workers=0 if self.device.type == 'mps' else self.dataloader_workers,
+                 pin_memory=(self.device.type == "cuda")
+             )
+
+         print("\n--- Model Evaluation ---")
+
+         # Collect results from the predict generator
+         all_preds, all_probs, all_true = [], [], []
+         for y_pred_b, y_prob_b, y_true_b in self.predict(eval_loader):
+             all_preds.append(y_pred_b)
+             if y_prob_b is not None:
+                 all_probs.append(y_prob_b)
+             all_true.append(y_true_b)
+
+         y_pred = np.concatenate(all_preds)
+         y_true = np.concatenate(all_true)
+         y_prob = np.concatenate(all_probs) if self.kind == "classification" else None
+
+         if self.kind == "classification":
+             classification_metrics(y_true, y_pred, y_prob, save_dir=save_dir)
+         else:
+             regression_metrics(y_true.flatten(), y_pred.flatten(), save_dir=save_dir)
+
+         print("\n--- Training History ---")
+         plot_losses(self.history, save_dir=save_dir)
+
+     def explain(self, explain_dataset: Optional[Dataset] = None, n_samples: int = 100,
+                 feature_names: Optional[List[str]] = None, save_dir: Optional[Union[str, Path]] = None):
+         """
+         Explains model predictions using SHAP and saves all artifacts.
+
+         The background data is automatically sampled from the trainer's training dataset.
+
+         Args:
+             explain_dataset (Dataset, optional): A specific dataset to explain.
+                 If None, the trainer's test dataset is used.
+             n_samples (int): The number of samples to use for both background and explanation.
+             feature_names (List[str], optional): Names for the features.
+             save_dir (str | Path | None): Directory to save all SHAP artifacts.
+         """
+         # Internal helper to create a dataloader and get a random sample
+         def _get_random_sample(dataset: Dataset, num_samples: int):
+             if dataset is None:
+                 return None
+
+             # For MPS devices, num_workers must be 0 to ensure stability
+             loader_workers = 0 if self.device.type == 'mps' else self.dataloader_workers
+
+             loader = DataLoader(
+                 dataset,
+                 batch_size=64,
+                 shuffle=False,
+                 num_workers=loader_workers
+             )
+
+             all_features = [features for features, _ in loader]
+             if not all_features:
+                 return None
+
+             full_data = torch.cat(all_features, dim=0)
+
+             if num_samples >= full_data.size(0):
+                 return full_data
+
+             rand_indices = torch.randperm(full_data.size(0))[:num_samples]
+             return full_data[rand_indices]
+
+         print(f"\n--- Preparing SHAP Data (sampling up to {n_samples} instances) ---")
+
+         # 1. Get background data from the trainer's train_dataset
+         background_data = _get_random_sample(self.train_dataset, n_samples)
+         if background_data is None:
+             print("Warning: Trainer's train_dataset is empty or invalid. Skipping SHAP analysis.")
+             return
+
+         # 2. Determine target dataset and get explanation instances
+         target_dataset = explain_dataset if explain_dataset is not None else self.test_dataset
+         instances_to_explain = _get_random_sample(target_dataset, n_samples)
+         if instances_to_explain is None:
+             print("Warning: Explanation dataset is empty or invalid. Skipping SHAP analysis.")
+             return
+
+         # 3. Call the plotting function
+         shap_summary_plot(
+             model=self.model,
+             background_data=background_data,
+             instances_to_explain=instances_to_explain,
+             feature_names=feature_names,
+             save_dir=save_dir
+         )
+
+
+     def callbacks_hook(self, method_name: str, *args, **kwargs):
+         """Calls the specified method on all callbacks."""
+         for callback in self.callbacks:
+             method = getattr(callback, method_name)
+             method(*args, **kwargs)
+
+ def info():
+     _script_info(__all__)
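A minimal end-to-end sketch of the new trainer. This is hedged: the toy dataset, model, and hyperparameters are illustrative, and it assumes the package is importable as ml_tools with the ML_callbacks defaults (History, TqdmProgressBar) shown above.

# Usage sketch (assumptions: ml_tools import path; toy data and model).
import torch
from torch import nn
from torch.utils.data import TensorDataset
from ml_tools.ML_trainer import MyTrainer

# Toy binary-classification data: 2 features -> 2 classes.
X = torch.randn(256, 2)
y = (X.sum(dim=1) > 0).long()
train_ds = TensorDataset(X[:200], y[:200])
test_ds = TensorDataset(X[200:], y[200:])

model = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 2))

trainer = MyTrainer(
    model=model,
    train_dataset=train_ds,
    test_dataset=test_ds,
    kind="classification",
    criterion=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    device="cpu",
)

history = trainer.fit(epochs=5, batch_size=32)  # returns trainer.history
trainer.evaluate(save_dir="reports")            # report, confusion matrix, ROC, loss curves
trainer.explain(n_samples=50, feature_names=["x0", "x1"], save_dir="reports/shap")

Note that fit() returns the history dict that the default History callback is expected to populate; evaluate() reuses it through plot_losses, and explain() samples up to n_samples rows from the training set as the SHAP background before delegating to shap_summary_plot.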