dragon-ml-toolbox 2.3.0__py3-none-any.whl → 3.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,6 +6,7 @@ import numpy as np
6
6
  from .utilities import load_dataframe, list_csv_paths, sanitize_filename, _script_info, merge_dataframes, save_dataframe, threshold_binary_values, make_fullpath
7
7
  from plotnine import ggplot, labs, theme, element_blank # type: ignore
8
8
  from typing import Optional, Union
9
+ from .logger import _LOGGER
9
10
 
10
11
 
11
12
  __all__ = [
@@ -40,7 +41,9 @@ def apply_mice(df: pd.DataFrame, df_name: str, binary_columns: Optional[list[str
40
41
  if binary_columns is not None:
41
42
  invalid_binary_columns = set(binary_columns) - set(df.columns)
42
43
  if invalid_binary_columns:
43
- print(f"⚠️ These 'binary columns' are not in the dataset: {invalid_binary_columns}")
44
+ _LOGGER.warning(f"⚠️ These 'binary columns' are not in the dataset:")
45
+ for invalid_binary_col in invalid_binary_columns:
46
+ print(f" - {invalid_binary_col}")
44
47
  valid_binary_columns = [col for col in binary_columns if col not in invalid_binary_columns]
45
48
  for imputed_df in imputed_datasets:
46
49
  for binary_column_name in valid_binary_columns:
@@ -125,7 +128,7 @@ def get_convergence_diagnostic(kernel: mf.ImputationKernel, imputed_dataset_name
125
128
  plt.savefig(save_path, bbox_inches='tight', format="svg")
126
129
  plt.close()
127
130
 
128
- print(f"{dataset_file_dir} completed.")
131
+ _LOGGER.info(f"{dataset_file_dir} completed.")
129
132
 
130
133
 
131
134
  # Imputed distributions
@@ -210,7 +213,7 @@ def get_imputed_distributions(kernel: mf.ImputationKernel, df_name: str, root_di
210
213
  fig = kernel.plot_imputed_distributions(variables=[feature])
211
214
  _process_figure(fig, feature)
212
215
 
213
- print(f"{local_dir_name} completed.")
216
+ _LOGGER.info(f"{local_dir_name} completed.")
214
217
 
215
218
 
216
219
  def run_mice_pipeline(df_path_or_dir: Union[str,Path], target_columns: list[str],
@@ -240,7 +243,8 @@ def run_mice_pipeline(df_path_or_dir: Union[str,Path], target_columns: list[str]
240
243
  all_file_paths = list(list_csv_paths(input_path).values())
241
244
 
242
245
  for df_path in all_file_paths:
243
- df, df_name = load_dataframe(df_path=df_path)
246
+ df: pd.DataFrame
247
+ df, df_name = load_dataframe(df_path=df_path, kind="pandas") # type: ignore
244
248
 
245
249
  df, df_targets = _skip_targets(df, target_columns)
246
250
 
@@ -0,0 +1,341 @@
1
+ import numpy as np
2
+ import torch
3
+ from tqdm.auto import tqdm
4
+ from .utilities import make_fullpath, LogKeys
5
+ from .logger import _LOGGER
6
+ from typing import Optional
7
+
8
+
9
+ __all__ = [
10
+ "Callback",
11
+ "History",
12
+ "TqdmProgressBar",
13
+ "EarlyStopping",
14
+ "ModelCheckpoint",
15
+ "LRScheduler"
16
+ ]
17
+
18
+
19
class Callback:
    """
    Base class for training callbacks.

    The Trainer invokes the hook methods below at fixed points of the
    training loop. Every hook is a no-op here; subclasses override only the
    hooks they are interested in.
    """
    def __init__(self):
        # Filled in later through `set_trainer`.
        self.trainer = None

    def set_trainer(self, trainer):
        """Attach the Trainer that drives this callback."""
        self.trainer = trainer

    def on_train_begin(self, logs=None):
        """Hook for the start of training."""

    def on_train_end(self, logs=None):
        """Hook for the end of training."""

    def on_epoch_begin(self, epoch, logs=None):
        """Hook for the start of an epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Hook for the end of an epoch."""

    def on_batch_begin(self, batch, logs=None):
        """Hook for the start of a training batch."""

    def on_batch_end(self, batch, logs=None):
        """Hook for the end of a training batch."""
57
+
58
+
59
class History(Callback):
    """
    Callback that accumulates epoch logs on the trainer.

    This callback is automatically applied to every MyTrainer model.
    It maintains `trainer.history`, a dictionary mapping each metric name
    (e.g., 'val_loss') to the list of values logged for it at epoch end.
    """
    def on_train_begin(self, logs=None):
        # Every training run starts with an empty record.
        self.trainer.history = dict() # type: ignore

    def on_epoch_end(self, epoch, logs=None):
        # Nothing was logged for this epoch: leave the record untouched.
        if not logs:
            return

        record = self.trainer.history # type: ignore
        for metric_name, metric_value in logs.items():
            # Create the per-metric list on first sight, then extend it.
            record.setdefault(metric_name, []).append(metric_value)
76
+
77
+
78
class TqdmProgressBar(Callback):
    """
    Callback that provides a tqdm progress bar for training.

    Two bars are used: an outer one advancing once per epoch and an inner,
    transient one (``leave=False``) advancing once per batch. The inner bar
    is recreated at the start of every epoch.
    """
    def __init__(self):
        # Initialise the base class so `self.trainer` always exists.
        super().__init__()
        self.epochs = None
        self.epoch_bar = None
        self.batch_bar = None

    def _close_batch_bar(self):
        """Close and forget the per-batch bar, if one is open."""
        if self.batch_bar is not None:
            self.batch_bar.close()
            self.batch_bar = None

    def on_train_begin(self, logs=None):
        """Create the outer bar sized by the trainer's epoch count."""
        self.epochs = self.trainer.epochs # type: ignore
        self.epoch_bar = tqdm(total=self.epochs, desc="Training Progress")

    def on_epoch_begin(self, epoch, logs=None):
        """Create a fresh inner bar sized by the number of training batches."""
        total_batches = len(self.trainer.train_loader) # type: ignore
        self.batch_bar = tqdm(total=total_batches, desc=f"Epoch {epoch}/{self.epochs}", leave=False)

    def on_batch_end(self, batch, logs=None):
        """Advance the inner bar and show the batch loss when it is logged."""
        if self.batch_bar is None:
            return
        self.batch_bar.update(1)
        if logs:
            self.batch_bar.set_postfix(loss=f"{logs.get(LogKeys.BATCH_LOSS, 0):.4f}")

    def on_epoch_end(self, epoch, logs=None):
        """Close the inner bar, advance the outer one and show epoch losses."""
        self._close_batch_bar()
        if self.epoch_bar is None:
            return
        self.epoch_bar.update(1)
        if logs:
            train_loss_str = f"{logs.get(LogKeys.TRAIN_LOSS, 0):.4f}"
            val_loss_str = f"{logs.get(LogKeys.VAL_LOSS, 0):.4f}"
            self.epoch_bar.set_postfix_str(f"Train Loss: {train_loss_str}, Val Loss: {val_loss_str}")

    def on_train_end(self, logs=None):
        """Close whatever bars are still open."""
        # A batch bar can still be open if training ended mid-epoch.
        self._close_batch_bar()
        if self.epoch_bar is not None:
            self.epoch_bar.close()
            self.epoch_bar = None
107
+
108
+
109
class EarlyStopping(Callback):
    """
    Stop training when a monitored metric has stopped improving.

    Args:
        monitor (str): Quantity to be monitored. Defaults to 'val_loss'.
        min_delta (float): Minimum change in the monitored quantity to qualify as an improvement.
        patience (int): Number of epochs with no improvement after which training will be stopped.
        mode (str): One of {'auto', 'min', 'max'}. In 'min' mode, training will stop when the quantity
            monitored has stopped decreasing; in 'max' mode it will stop when the quantity
            monitored has stopped increasing; in 'auto' mode, the direction is 'max' when the
            monitored name contains 'acc' and 'min' otherwise.
        verbose (int): Verbosity mode. Values above 0 log the stop event; values above 1
            also log every improvement.

    Raises:
        ValueError: If `mode` is not one of 'auto', 'min' or 'max'.
    """
    def __init__(self, monitor: str=LogKeys.VAL_LOSS, min_delta=0.0, patience=3, mode='auto', verbose=1):
        super().__init__()
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        self.stopped_epoch = 0
        self.verbose = verbose

        if mode not in ['auto', 'min', 'max']:
            raise ValueError(f"EarlyStopping mode {mode} is unknown, choose one of ('auto', 'min', 'max')")
        self.mode = mode

        # Determine the comparison operator based on the mode
        if self.mode == 'min':
            self.monitor_op = np.less
        elif self.mode == 'max':
            self.monitor_op = np.greater
        elif 'acc' in self.monitor.lower():
            # auto mode: accuracy-like metrics should increase
            self.monitor_op = np.greater
        else:
            # auto mode: default to min for loss or other metrics
            self.monitor_op = np.less

        self.best = self._initial_best()

    def _initial_best(self):
        """Return the worst possible score for the active comparison direction."""
        # Lowercase `np.inf`: the `np.Inf` alias was removed in NumPy 2.0.
        return np.inf if self.monitor_op is np.less else -np.inf

    def on_train_begin(self, logs=None):
        """Reset the counters and the best score at the beginning of training."""
        self.wait = 0
        self.stopped_epoch = 0
        self.best = self._initial_best()

    def on_epoch_end(self, epoch, logs=None):
        """Update the patience counter and flag the trainer to stop when it runs out."""
        # `logs` defaults to None; treat that like "nothing was logged".
        logs = logs or {}
        current = logs.get(self.monitor)
        if current is None:
            return

        # An improvement must beat the best score by at least `min_delta`:
        # current < best - min_delta ('min') or current > best + min_delta ('max').
        if self.monitor_op is np.less:
            threshold = self.best - self.min_delta
        else:
            threshold = self.best + self.min_delta

        if self.monitor_op(current, threshold):
            if self.verbose > 1:
                _LOGGER.info(f"EarlyStopping: {self.monitor} improved from {self.best:.4f} to {current:.4f}")
            self.best = current
            self.wait = 0
            return

        self.wait += 1
        if self.wait >= self.patience:
            self.stopped_epoch = epoch
            self.trainer.stop_training = True # type: ignore
            if self.verbose > 0:
                # Blank line so the message does not share a line with a progress bar.
                print("")
                _LOGGER.info(f"Epoch {epoch+1}: early stopping after {self.wait} epochs with no improvement.")
183
+
184
+
185
class ModelCheckpoint(Callback):
    """
    Saves the model to a directory with automated filename generation and rotation. The filename includes the epoch and score.

    - If `save_best_only` is True, it saves the single best model, deleting the
      previous best.
    - If `save_best_only` is False, it keeps the 3 most recent checkpoints,
      deleting the oldest ones automatically.

    Args:
        save_dir (str): Directory where checkpoint files will be saved.
        monitor (str): Metric to monitor for `save_best_only=True`.
        save_best_only (bool): If true, save only the best model.
        mode (str): One of {'auto', 'min', 'max'}. In 'auto' mode the metric is minimised
            when its name contains 'loss' and maximised otherwise.
        verbose (int): Verbosity mode. Values above 0 log every save and deletion.

    Raises:
        IOError: If `save_dir` does not resolve to a directory.
        ValueError: If `mode` is not one of 'auto', 'min' or 'max'.
    """
    # Number of most recent checkpoints kept when `save_best_only` is False.
    _MAX_RECENT_CHECKPOINTS = 3

    def __init__(self, save_dir: str, monitor: str = LogKeys.VAL_LOSS,
                 save_best_only: bool = False, mode: str = 'auto', verbose: int = 1):
        super().__init__()
        self.save_dir = make_fullpath(save_dir, make=True)
        if not self.save_dir.is_dir():
            message = f"{save_dir} is not a valid directory."
            _LOGGER.error(message)
            # Carry the reason in the exception too, not only in the log.
            raise IOError(message)

        self.monitor = monitor
        self.save_best_only = save_best_only
        self.verbose = verbose

        # State variables to be managed during training
        self.saved_checkpoints = []
        self.last_best_filepath = None

        if mode not in ['auto', 'min', 'max']:
            raise ValueError(f"ModelCheckpoint mode {mode} is unknown.")
        self.mode = mode

        if self.mode == 'min':
            self.monitor_op = np.less
        elif self.mode == 'max':
            self.monitor_op = np.greater
        else:
            self.monitor_op = np.less if 'loss' in self.monitor else np.greater

        self.best = self._initial_best()

    def _initial_best(self):
        """Return the worst possible score for the active comparison direction."""
        # Lowercase `np.inf`: the `np.Inf` alias was removed in NumPy 2.0.
        return np.inf if self.monitor_op is np.less else -np.inf

    def on_train_begin(self, logs=None):
        """Reset state when training starts."""
        self.best = self._initial_best()
        self.saved_checkpoints = []
        self.last_best_filepath = None

    def on_epoch_end(self, epoch, logs=None):
        """Save a checkpoint for this epoch according to the configured policy."""
        logs = logs or {}
        # Re-create the directory in case it was removed during training.
        self.save_dir.mkdir(parents=True, exist_ok=True)

        if self.save_best_only:
            self._save_best_model(epoch, logs)
        else:
            self._save_rolling_checkpoints(epoch, logs)

    def _save_best_model(self, epoch, logs):
        """Saves a single best model and deletes the previous one."""
        current = logs.get(self.monitor)
        if current is None:
            return

        if not self.monitor_op(current, self.best):
            return

        # Before the first save `best` is still +/- infinity.
        old_best_str = "inf" if self.best in (np.inf, -np.inf) else f"{self.best:.4f}"

        # Create a descriptive filename
        filename = f"epoch_{epoch}-{self.monitor}_{current:.4f}.pth"
        new_filepath = self.save_dir / filename

        if self.verbose > 0:
            print("")
            _LOGGER.info(f"Epoch {epoch}: {self.monitor} improved from {old_best_str} to {current:.4f}, saving model to {new_filepath}")

        # Save the new best model before removing the old one, so a checkpoint
        # always exists on disk.
        torch.save(self.trainer.model.state_dict(), new_filepath) # type: ignore

        # Delete the old best model file
        if self.last_best_filepath and self.last_best_filepath.exists():
            self.last_best_filepath.unlink()

        # Update state
        self.best = current
        self.last_best_filepath = new_filepath

    def _save_rolling_checkpoints(self, epoch, logs):
        """Saves the latest model and keeps only the 3 most recent checkpoints."""
        filename = f"epoch_{epoch}.pth"
        filepath = self.save_dir / filename

        if self.verbose > 0:
            print("")
            _LOGGER.info(f'Epoch {epoch}: saving model to {filepath}')
        torch.save(self.trainer.model.state_dict(), filepath) # type: ignore

        self.saved_checkpoints.append(filepath)

        # If we have more than the allowed number of checkpoints, remove the oldest one
        if len(self.saved_checkpoints) > self._MAX_RECENT_CHECKPOINTS:
            file_to_delete = self.saved_checkpoints.pop(0)
            if file_to_delete.exists():
                if self.verbose > 0:
                    _LOGGER.info(f" -> Deleting old checkpoint: {file_to_delete.name}")
                file_to_delete.unlink()
292
+
293
+
294
class LRScheduler(Callback):
    """
    Callback to manage a PyTorch learning rate scheduler.

    This callback automatically calls the scheduler's `step()` method at the
    end of each epoch. It also logs a message when the learning rate of the
    optimizer's first parameter group changes.

    Args:
        scheduler: An initialized PyTorch learning rate scheduler.
        monitor (str, optional): The metric to monitor for schedulers that
            require it, like `ReduceLROnPlateau`.
            Should match a key in the logs (e.g., 'val_loss').
    """
    def __init__(self, scheduler, monitor: Optional[str] = None):
        super().__init__()
        self.scheduler = scheduler
        self.monitor = monitor
        self.previous_lr = None

    def on_train_begin(self, logs=None):
        """Store the initial learning rate."""
        self.previous_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore

    def on_epoch_end(self, epoch, logs=None):
        """
        Step the scheduler and log any change in learning rate.

        Raises:
            ValueError: If the scheduler is a `ReduceLROnPlateau` and no
                `monitor` metric was configured.
        """
        # `logs` defaults to None; treat that like "nothing was logged".
        logs = logs or {}

        # For schedulers that need a metric (e.g., val_loss)
        if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            if self.monitor is None:
                raise ValueError("LRScheduler needs a `monitor` metric for ReduceLROnPlateau.")

            metric_val = logs.get(self.monitor)
            if metric_val is not None:
                self.scheduler.step(metric_val)
            else:
                # The scheduler is not stepped this epoch when the metric is missing.
                print("")
                _LOGGER.warning(f"LRScheduler could not find metric '{self.monitor}' in logs.")

        # For all other schedulers
        else:
            self.scheduler.step()

        # Log the change if the LR was updated
        current_lr = self.trainer.optimizer.param_groups[0]['lr'] # type: ignore
        if current_lr != self.previous_lr:
            print("")
            _LOGGER.info(f"Epoch {epoch}: Learning rate changed to {current_lr:.6f}")
            self.previous_lr = current_lr
341
+
@@ -0,0 +1,255 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ from sklearn.metrics import (
5
+ classification_report,
6
+ ConfusionMatrixDisplay,
7
+ roc_curve,
8
+ roc_auc_score,
9
+ mean_squared_error,
10
+ mean_absolute_error,
11
+ r2_score,
12
+ median_absolute_error
13
+ )
14
+ import torch
15
+ import shap
16
+ from pathlib import Path
17
+ from .utilities import make_fullpath
18
+ from .logger import _LOGGER
19
+ from typing import Union, Optional
20
+
21
+
22
+ __all__ = [
23
+ "plot_losses",
24
+ "classification_metrics",
25
+ "regression_metrics",
26
+ "shap_summary_plot"
27
+ ]
28
+
29
+
30
def plot_losses(history: dict, save_dir: Optional[Union[str, Path]] = None):
    """
    Plots training & validation loss curves from a history object.

    Each curve is drawn only when its list is non-empty. When both are
    missing or empty a warning is logged and nothing is plotted.

    Args:
        history (dict): A dictionary containing 'train_loss' and 'val_loss'.
        save_dir (str | Path | None): Directory to save the plot image
            ('loss_plot.svg'). If None, the plot is shown instead.
    """
    train_loss = history.get('train_loss', [])
    val_loss = history.get('val_loss', [])

    if not train_loss and not val_loss:
        # Routed through the module logger like the other messages of this function.
        _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
        return

    fig, ax = plt.subplots(figsize=(10, 5), dpi=100)

    # One curve per available series; x axis is the 1-based epoch number.
    for series, label in ((train_loss, 'Training Loss'), (val_loss, 'Validation Loss')):
        if series:
            epochs = range(1, len(series) + 1)
            ax.plot(epochs, series, 'o-', label=label)

    ax.set_title('Training and Validation Loss')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Loss')
    ax.legend()
    ax.grid(True)
    plt.tight_layout()

    if save_dir:
        save_dir_path = make_fullpath(save_dir, make=True)
        save_path = save_dir_path / "loss_plot.svg"
        plt.savefig(save_path)
        _LOGGER.info(f"Loss plot saved as '{save_path.name}'")
    else:
        plt.show()
    plt.close(fig)
72
+
73
+
74
def _build_roc_figure(y_true, y_score, title):
    """Draw a ROC curve with its AUC in the legend and return the new figure."""
    fpr, tpr, _ = roc_curve(y_true, y_score)
    auc = roc_auc_score(y_true, y_score)
    fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
    ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
    # Diagonal reference line.
    ax_roc.plot([0, 1], [0, 1], 'k--')
    ax_roc.set_title(title)
    ax_roc.set_xlabel('False Positive Rate')
    ax_roc.set_ylabel('True Positive Rate')
    ax_roc.legend(loc='lower right')
    ax_roc.grid(True)
    return fig_roc


def classification_metrics(y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optional[np.ndarray] = None,
                           cmap: str = "Blues", save_dir: Optional[Union[str, Path]] = None):
    """
    Displays and optionally saves classification metrics and plots.

    The text report is always printed. A ROC curve is produced only when
    `y_prob` is a 2-D array with at least two columns; column 1 is used as
    the score of the positive class.

    Args:
        y_true (np.ndarray): Ground truth labels.
        y_pred (np.ndarray): Predicted labels.
        y_prob (np.ndarray, optional): Predicted probabilities for ROC curve.
        cmap (str): Colormap for the confusion matrix.
        save_dir (str | Path | None): Directory to save plots. If None, plots are shown not saved.
            When given, 'classification_report.txt', 'confusion_matrix.svg' and
            (if applicable) 'roc_curve.svg' are written there.
    """
    print("--- Classification Report ---")
    report: str = classification_report(y_true, y_pred) # type: ignore
    print(report)

    # Scores used for the ROC curve, or None when no curve can be drawn.
    positive_scores = None
    if y_prob is not None and y_prob.ndim > 1 and y_prob.shape[1] >= 2:
        positive_scores = y_prob[:, 1]

    if save_dir:
        save_dir_path = make_fullpath(save_dir, make=True)
        # Save text report
        report_path = save_dir_path / "classification_report.txt"
        report_path.write_text(report, encoding="utf-8")
        _LOGGER.info(f"Classification report saved as '{report_path.name}'")

        # Save Confusion Matrix
        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
        ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap, ax=ax_cm)
        ax_cm.set_title("Confusion Matrix")
        cm_path = save_dir_path / "confusion_matrix.svg"
        plt.savefig(cm_path)
        _LOGGER.info(f"Confusion matrix saved as '{cm_path.name}'")
        plt.close(fig_cm)

        # Save ROC Curve
        if positive_scores is not None:
            fig_roc = _build_roc_figure(y_true, positive_scores,
                                        'Receiver Operating Characteristic (ROC) Curve')
            roc_path = save_dir_path / "roc_curve.svg"
            # The ROC figure was just created, so it is the current figure.
            plt.savefig(roc_path)
            _LOGGER.info(f"ROC curve saved as '{roc_path.name}'")
            plt.close(fig_roc)
    else:
        # Show plots if not saving
        ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap)
        plt.show()
        if positive_scores is not None:
            # Same drawing code as the saved plot, so the AUC legend is visible here too.
            fig_roc = _build_roc_figure(y_true, positive_scores, 'ROC Curve')
            plt.show()
            plt.close(fig_roc)
134
+
135
+
136
def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Optional[Union[str, Path]] = None):
    """
    Displays regression metrics and optionally saves plots and report.

    The report (RMSE, MAE, MedAE and R²) is always printed. Plots are only
    produced when `save_dir` is given.

    Args:
        y_true (np.ndarray): Ground truth values.
        y_pred (np.ndarray): Predicted values.
        save_dir (str | Path | None): Directory to save plots and report. When given,
            'regression_report.txt', 'residual_plot.svg' and
            'true_vs_predicted_plot.svg' are written there.
    """
    rmse = np.sqrt(mean_squared_error(y_true, y_pred))
    mae = mean_absolute_error(y_true, y_pred)
    r2 = r2_score(y_true, y_pred)
    medae = median_absolute_error(y_true, y_pred)

    report_lines = [
        "--- Regression Report ---",
        f" Root Mean Squared Error (RMSE): {rmse:.4f}",
        f" Mean Absolute Error (MAE): {mae:.4f}",
        f" Median Absolute Error (MedAE): {medae:.4f}",
        f" Coefficient of Determination (R²): {r2:.4f}"
    ]
    report_string = "\n".join(report_lines)
    print(report_string)

    if not save_dir:
        return

    save_dir_path = make_fullpath(save_dir, make=True)
    # Save text report. The report contains a non-ASCII character ("R²"), so
    # the encoding is pinned instead of relying on the platform default.
    report_path = save_dir_path / "regression_report.txt"
    report_path.write_text(report_string, encoding="utf-8")
    _LOGGER.info(f"Regression report saved as '{report_path.name}'")

    # Save residual plot
    residuals = y_true - y_pred
    fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
    ax_res.scatter(y_pred, residuals, alpha=0.6)
    ax_res.axhline(0, color='red', linestyle='--')
    ax_res.set_xlabel("Predicted Values")
    ax_res.set_ylabel("Residuals")
    ax_res.set_title("Residual Plot")
    ax_res.grid(True)
    plt.tight_layout()
    res_path = save_dir_path / "residual_plot.svg"
    plt.savefig(res_path)
    _LOGGER.info(f"Residual plot saved as '{res_path.name}'")
    plt.close(fig_res)

    # Save true vs predicted plot
    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
    ax_tvp.scatter(y_true, y_pred, alpha=0.6)
    # Identity line spanning the range of the true values.
    true_range = [y_true.min(), y_true.max()]
    ax_tvp.plot(true_range, true_range, 'k--', lw=2)
    ax_tvp.set_xlabel('True Values')
    ax_tvp.set_ylabel('Predictions')
    ax_tvp.set_title('True vs. Predicted Values')
    ax_tvp.grid(True)
    plt.tight_layout()
    tvp_path = save_dir_path / "true_vs_predicted_plot.svg"
    plt.savefig(tvp_path)
    _LOGGER.info(f"True vs. Predicted plot saved as '{tvp_path.name}'")
    plt.close(fig_tvp)
195
+
196
+
197
def _select_shap_values_for_plot(shap_values, n_feature_dims):
    """
    Reduce the explainer output to a single array suitable for a summary plot.

    Returns a tuple `(values, used_class_1)` where `used_class_1` tells whether
    the values of output index 1 were picked out of a multi-output result.
    """
    # Legacy format: one array per model output.
    if isinstance(shap_values, list):
        return shap_values[1], True

    # Stacked format: the outputs live on an extra trailing axis.
    # NOTE(review): matches the ndarray layout of recent SHAP releases — confirm
    # against the installed version.
    values = np.asarray(shap_values)
    if values.ndim == n_feature_dims + 1:
        class_index = 1 if values.shape[-1] > 1 else 0
        return values[..., class_index], class_index == 1

    return shap_values, False


def shap_summary_plot(model, background_data: torch.Tensor, instances_to_explain: torch.Tensor,
                      feature_names: Optional[list[str]]=None, save_dir: Optional[Union[str, Path]] = None):
    """
    Calculates SHAP values and saves summary plots and data.

    The model is switched to eval mode and moved to the CPU (in place) before
    the explainer is built; the input tensors are copied to the CPU as needed.

    Args:
        model (nn.Module): The trained PyTorch model.
        background_data (torch.Tensor): A sample of data for the explainer background.
        instances_to_explain (torch.Tensor): The specific data instances to explain.
        feature_names (list of str | None): Names of the features for plot labeling.
        save_dir (str | Path | None): Directory to save SHAP artifacts. If None, dot plot is shown.
            When given, 'shap_bar_plot.svg', 'shap_dot_plot.svg' and
            'shap_summary.csv' are written there.
    """
    print("\n--- SHAP Value Explanation ---")
    print("Calculating SHAP values... ")

    model.eval()
    model.cpu()
    # The model now lives on the CPU, so the tensors fed to it must as well.
    background_data = background_data.cpu()
    instances_to_explain = instances_to_explain.cpu()

    explainer = shap.DeepExplainer(model, background_data)
    shap_values = explainer.shap_values(instances_to_explain)

    shap_values_for_plot, used_class_1 = _select_shap_values_for_plot(
        shap_values, instances_to_explain.dim()
    )
    if used_class_1:
        _LOGGER.info("Using SHAP values for the positive class (class 1) for plots.")

    # The plotting helpers work on NumPy data rather than torch tensors.
    features_for_plot = instances_to_explain.detach().numpy()

    if save_dir:
        save_dir_path = make_fullpath(save_dir, make=True)
        # Save Bar Plot
        bar_path = save_dir_path / "shap_bar_plot.svg"
        shap.summary_plot(shap_values_for_plot, features_for_plot, feature_names=feature_names, plot_type="bar", show=False)
        plt.title("SHAP Feature Importance")
        plt.tight_layout()
        plt.savefig(bar_path)
        _LOGGER.info(f"SHAP bar plot saved as '{bar_path.name}'")
        plt.close()

        # Save Dot Plot
        dot_path = save_dir_path / "shap_dot_plot.svg"
        shap.summary_plot(shap_values_for_plot, features_for_plot, feature_names=feature_names, plot_type="dot", show=False)
        plt.title("SHAP Feature Importance")
        plt.tight_layout()
        plt.savefig(dot_path)
        _LOGGER.info(f"SHAP dot plot saved as '{dot_path.name}'")
        plt.close()

        # Save Summary Data to CSV: mean absolute SHAP value per feature, largest first.
        summary_path = save_dir_path / "shap_summary.csv"
        mean_abs_shap = np.abs(shap_values_for_plot).mean(axis=0)
        if feature_names is None:
            feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
        summary_df = pd.DataFrame({
            'feature': feature_names,
            'mean_abs_shap_value': mean_abs_shap
        }).sort_values('mean_abs_shap_value', ascending=False)
        summary_df.to_csv(summary_path, index=False)
        _LOGGER.info(f"SHAP summary data saved as '{summary_path.name}'")
    else:
        _LOGGER.info("No save directory provided. Displaying SHAP dot plot.")
        shap.summary_plot(shap_values_for_plot, features_for_plot, feature_names=feature_names, plot_type="dot")