dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
  2. dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
  3. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/ETL_cleaning.py +20 -20
  5. ml_tools/ETL_engineering.py +23 -25
  6. ml_tools/GUI_tools.py +20 -20
  7. ml_tools/MICE_imputation.py +207 -5
  8. ml_tools/ML_callbacks.py +43 -26
  9. ml_tools/ML_configuration.py +788 -0
  10. ml_tools/ML_datasetmaster.py +303 -448
  11. ml_tools/ML_evaluation.py +351 -93
  12. ml_tools/ML_evaluation_multi.py +139 -42
  13. ml_tools/ML_inference.py +290 -209
  14. ml_tools/ML_models.py +33 -106
  15. ml_tools/ML_models_advanced.py +323 -0
  16. ml_tools/ML_optimization.py +12 -12
  17. ml_tools/ML_scaler.py +11 -11
  18. ml_tools/ML_sequence_datasetmaster.py +341 -0
  19. ml_tools/ML_sequence_evaluation.py +219 -0
  20. ml_tools/ML_sequence_inference.py +391 -0
  21. ml_tools/ML_sequence_models.py +139 -0
  22. ml_tools/ML_trainer.py +1604 -179
  23. ml_tools/ML_utilities.py +351 -4
  24. ml_tools/ML_vision_datasetmaster.py +1540 -0
  25. ml_tools/ML_vision_evaluation.py +284 -0
  26. ml_tools/ML_vision_inference.py +405 -0
  27. ml_tools/ML_vision_models.py +641 -0
  28. ml_tools/ML_vision_transformers.py +284 -0
  29. ml_tools/PSO_optimization.py +6 -6
  30. ml_tools/SQL.py +4 -4
  31. ml_tools/_keys.py +171 -0
  32. ml_tools/_schema.py +1 -1
  33. ml_tools/custom_logger.py +37 -14
  34. ml_tools/data_exploration.py +502 -93
  35. ml_tools/ensemble_evaluation.py +54 -11
  36. ml_tools/ensemble_inference.py +7 -33
  37. ml_tools/ensemble_learning.py +1 -1
  38. ml_tools/math_utilities.py +1 -1
  39. ml_tools/optimization_tools.py +2 -2
  40. ml_tools/path_manager.py +5 -5
  41. ml_tools/serde.py +2 -2
  42. ml_tools/utilities.py +192 -4
  43. dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
  44. ml_tools/RNN_forecast.py +0 -56
  45. ml_tools/keys.py +0 -87
  46. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
  47. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
  48. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation.py CHANGED
@@ -21,10 +21,17 @@ from pathlib import Path
 from typing import Union, Optional, List, Literal
 import warnings
 
-from .path_manager import make_fullpath
+from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .keys import SHAPKeys
+from ._keys import SHAPKeys, PyTorchLogKeys
+from .ML_configuration import (RegressionMetricsFormat,
+                               BinaryClassificationMetricsFormat,
+                               MultiClassClassificationMetricsFormat,
+                               BinaryImageClassificationMetricsFormat,
+                               MultiClassImageClassificationMetricsFormat,
+                               _BaseClassificationFormat,
+                               _BaseRegressionFormat)
 
 
 __all__ = [
@@ -35,40 +42,66 @@ __all__ = [
     "plot_attention_importance"
 ]
 
+DPI_value = 250
+
 
 def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
     Plots training & validation loss curves from a history object.
+    Also plots the learning rate if available in the history.
 
     Args:
         history (dict): A dictionary containing 'train_loss' and 'val_loss'.
        save_dir (str | Path): Directory to save the plot image.
     """
-    train_loss = history.get('train_loss', [])
-    val_loss = history.get('val_loss', [])
+    train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
+    val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
+    lr_history = history.get(PyTorchLogKeys.LEARNING_RATE, [])
 
     if not train_loss and not val_loss:
-        print("Warning: Loss history is empty or incomplete. Cannot plot.")
+        _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
         return
 
-    fig, ax = plt.subplots(figsize=(10, 5), dpi=100)
+    fig, ax = plt.subplots(figsize=(10, 5), dpi=DPI_value)
+
+    # --- Plot Losses (Left Y-axis) ---
+    line_handles = []  # To store line objects for the legend
 
     # Plot training loss only if data for it exists
     if train_loss:
         epochs = range(1, len(train_loss) + 1)
-        ax.plot(epochs, train_loss, 'o-', label='Training Loss')
+        line1, = ax.plot(epochs, train_loss, 'o-', label='Training Loss', color='tab:blue')
+        line_handles.append(line1)
 
     # Plot validation loss only if data for it exists
     if val_loss:
         epochs = range(1, len(val_loss) + 1)
-        ax.plot(epochs, val_loss, 'o-', label='Validation Loss')
+        line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
+        line_handles.append(line2)
 
     ax.set_title('Training and Validation Loss')
     ax.set_xlabel('Epochs')
-    ax.set_ylabel('Loss')
-    ax.legend()
-    ax.grid(True)
-    plt.tight_layout()
+    ax.set_ylabel('Loss', color='tab:blue')
+    ax.tick_params(axis='y', labelcolor='tab:blue')
+    ax.grid(True, linestyle='--')
+
+    # --- Plot Learning Rate (Right Y-axis) ---
+    if lr_history:
+        ax2 = ax.twinx()  # Create a second y-axis
+        epochs = range(1, len(lr_history) + 1)
+        line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
+        line_handles.append(line3)
+
+        ax2.set_ylabel('Learning Rate', color='g')
+        ax2.tick_params(axis='y', labelcolor='g')
+        # Use scientific notation if the LR is very small
+        ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
+
+    # Combine legends from both axes
+    ax.legend(handles=line_handles, loc='best')
+
+    # ax.grid(True)
+    plt.tight_layout()
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     save_path = save_dir_path / "loss_plot.svg"
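Note (usage sketch, not part of the package diff): the reworked plot_losses reads its series through PyTorchLogKeys and draws the learning rate on a secondary y-axis when present. The module path and key constants below are assumed from the imports shown in this file.

    from ml_tools.ML_evaluation import plot_losses
    from ml_tools._keys import PyTorchLogKeys

    history = {
        PyTorchLogKeys.TRAIN_LOSS: [0.92, 0.61, 0.48, 0.41],
        PyTorchLogKeys.VAL_LOSS: [0.95, 0.70, 0.58, 0.55],
        PyTorchLogKeys.LEARNING_RATE: [1e-3, 1e-3, 5e-4, 5e-4],  # optional; plotted on a twin y-axis
    }

    plot_losses(history, save_dir="reports")  # writes loss_plot.svg into reports/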
@@ -78,23 +111,55 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
     plt.close(fig)
 
 
-def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optional[np.ndarray] = None,
-                           cmap: str = "Blues"):
+def classification_metrics(save_dir: Union[str, Path],
+                           y_true: np.ndarray,
+                           y_pred: np.ndarray,
+                           y_prob: Optional[np.ndarray] = None,
+                           class_map: Optional[dict[str,int]] = None,
+                           config: Optional[Union[BinaryClassificationMetricsFormat,
+                                                  MultiClassClassificationMetricsFormat,
+                                                  BinaryImageClassificationMetricsFormat,
+                                                  MultiClassImageClassificationMetricsFormat]] = None):
     """
     Saves classification metrics and plots.
 
     Args:
         y_true (np.ndarray): Ground truth labels.
         y_pred (np.ndarray): Predicted labels.
-        y_prob (np.ndarray, optional): Predicted probabilities for ROC curve.
-        cmap (str): Colormap for the confusion matrix.
+        y_prob (np.ndarray): Predicted probabilities for ROC curve.
+        config (object): Formatting configuration object.
         save_dir (str | Path): Directory to save plots.
     """
-    print("--- Classification Report ---")
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseClassificationFormat()
+    else:
+        format_config = config
+
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': format_config.font_size})
+
+    # print("--- Classification Report ---")
+
+    # --- Parse class_map ---
+    map_labels = None
+    map_display_labels = None
+    if class_map:
+        # Sort the map by its values (the indices) to ensure correct order
+        try:
+            sorted_items = sorted(class_map.items(), key=lambda item: item[1])
+            map_labels = [item[1] for item in sorted_items]
+            map_display_labels = [item[0] for item in sorted_items]
+        except Exception as e:
+            _LOGGER.warning(f"Could not parse 'class_map': {e}")
+            map_labels = None
+            map_display_labels = None
+
     # Generate report as both text and dictionary
-    report_text: str = classification_report(y_true, y_pred) # type: ignore
-    report_dict: dict = classification_report(y_true, y_pred, output_dict=True) # type: ignore
-    print(report_text)
+    report_text: str = classification_report(y_true, y_pred, labels=map_labels, target_names=map_display_labels) # type: ignore
+    report_dict: dict = classification_report(y_true, y_pred, output_dict=True, labels=map_labels, target_names=map_display_labels) # type: ignore
+    # print(report_text)
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     # Save text report
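Note (usage sketch, not part of the package diff): the new signature adds an optional class_map (display name to integer index) and a formatting config; passing config=None falls back to _BaseClassificationFormat, as the hunk above shows. The import path is assumed from this file's location in the wheel.

    import numpy as np
    from ml_tools.ML_evaluation import classification_metrics

    y_true = np.array([0, 1, 1, 0, 1])
    y_pred = np.array([0, 1, 0, 0, 1])
    y_prob = np.array([[0.8, 0.2],
                       [0.1, 0.9],
                       [0.6, 0.4],
                       [0.7, 0.3],
                       [0.2, 0.8]])  # one probability column per class

    classification_metrics(
        save_dir="reports/classification",
        y_true=y_true,
        y_pred=y_pred,
        y_prob=y_prob,                             # enables ROC / PR / calibration outputs
        class_map={"negative": 0, "positive": 1},  # sorted by index before use
        config=None,                               # defaults to _BaseClassificationFormat
    )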
@@ -104,8 +169,15 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre
 
     # --- Save Classification Report Heatmap ---
     try:
-        plt.figure(figsize=(8, 6), dpi=100)
-        sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T, annot=True, cmap='viridis', fmt='.2f')
+        plt.figure(figsize=(8, 6), dpi=DPI_value)
+        sns.set_theme(font_scale=1.2)  # Scale seaborn font
+        sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T,
+                    annot=True,
+                    cmap=format_config.cmap,
+                    fmt='.2f',
+                    vmin=0.0,
+                    vmax=1.0)
+        sns.set_theme(font_scale=1.0)  # Reset seaborn scale
         plt.title("Classification Report")
         plt.tight_layout()
         heatmap_path = save_dir_path / "classification_report_heatmap.svg"
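For reference, the heatmap above is driven by scikit-learn's dict-form report; dropping the last row and transposing gives one row per class or average and one metric per column. A small standalone illustration (not part of the package):

    import pandas as pd
    from sklearn.metrics import classification_report

    y_true = [0, 1, 1, 0, 1]
    y_pred = [0, 1, 0, 0, 1]

    report_dict = classification_report(y_true, y_pred, output_dict=True)
    heatmap_df = pd.DataFrame(report_dict).iloc[:-1, :].T  # drop the 'support' row, then transpose
    print(heatmap_df)  # rows: classes and averages; columns: precision, recall, f1-score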
@@ -114,72 +186,224 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre
         plt.close()
     except Exception as e:
         _LOGGER.error(f"Could not generate classification report heatmap: {e}")
-
+
+    # --- labels for Confusion Matrix ---
+    plot_labels = map_labels
+    plot_display_labels = map_display_labels
+
     # Save Confusion Matrix
-    fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
-    ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap, ax=ax_cm)
+    fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+    disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
+                                                    y_pred,
+                                                    cmap=format_config.cmap,
+                                                    ax=ax_cm,
+                                                    normalize='true',
+                                                    labels=plot_labels,
+                                                    display_labels=plot_display_labels)
+
+    disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+    # Turn off gridlines
+    ax_cm.grid(False)
+
+    # Manually update font size of cell texts
+    for text in ax_cm.texts:
+        text.set_fontsize(format_config.font_size)
+
+    fig_cm.tight_layout()
+
     ax_cm.set_title("Confusion Matrix")
     cm_path = save_dir_path / "confusion_matrix.svg"
     plt.savefig(cm_path)
     _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
     plt.close(fig_cm)
 
-    # Plotting logic for ROC and PR Curves
-    if y_prob is not None and y_prob.ndim > 1 and y_prob.shape[1] >= 2:
-        # Use probabilities of the positive class
-        y_score = y_prob[:, 1]
+
+    # Plotting logic for ROC, PR, and Calibration Curves
+    if y_prob is not None and y_prob.ndim == 2:
+        num_classes = y_prob.shape[1]
 
-        # --- Save ROC Curve ---
-        fpr, tpr, _ = roc_curve(y_true, y_score)
-        auc = roc_auc_score(y_true, y_score)
-        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
-        ax_roc.plot([0, 1], [0, 1], 'k--')
-        ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
-        ax_roc.set_xlabel('False Positive Rate')
-        ax_roc.set_ylabel('True Positive Rate')
-        ax_roc.legend(loc='lower right')
-        ax_roc.grid(True)
-        roc_path = save_dir_path / "roc_curve.svg"
-        plt.savefig(roc_path)
-        _LOGGER.info(f"📈 ROC curve saved as '{roc_path.name}'")
-        plt.close(fig_roc)
-
-        # --- Save Precision-Recall Curve ---
-        precision, recall, _ = precision_recall_curve(y_true, y_score)
-        ap_score = average_precision_score(y_true, y_score)
-        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
-        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
-        ax_pr.set_title('Precision-Recall Curve')
-        ax_pr.set_xlabel('Recall')
-        ax_pr.set_ylabel('Precision')
-        ax_pr.legend(loc='lower left')
-        ax_pr.grid(True)
-        pr_path = save_dir_path / "pr_curve.svg"
-        plt.savefig(pr_path)
-        _LOGGER.info(f"📈 PR curve saved as '{pr_path.name}'")
-        plt.close(fig_pr)
+        # --- Determine which classes to loop over ---
+        class_indices_to_plot = []
+        plot_titles = []
+        save_suffixes = []
+
+        if num_classes == 2:
+            # Binary case: Only plot for the positive class (index 1)
+            class_indices_to_plot = [1]
+            plot_titles = [""]   # No extra title
+            save_suffixes = [""] # No extra suffix
+            _LOGGER.debug("Generating binary classification plots (ROC, PR, Calibration).")
+
+        elif num_classes > 2:
+            _LOGGER.debug(f"Generating One-vs-Rest plots for {num_classes} classes.")
+            # Multiclass case: Plot for every class (One-vs-Rest)
+            class_indices_to_plot = list(range(num_classes))
+
+            # --- Use class_map names if available ---
+            use_generic_names = True
+            if map_display_labels and len(map_display_labels) == num_classes:
+                try:
+                    # Ensure labels are safe for filenames
+                    safe_names = [sanitize_filename(name) for name in map_display_labels]
+                    plot_titles = [f" ({name} vs. Rest)" for name in map_display_labels]
+                    save_suffixes = [f"_{safe_names[i]}" for i in class_indices_to_plot]
+                    use_generic_names = False
+                except Exception as e:
+                    _LOGGER.warning(f"Failed to use 'class_map' for plot titles: {e}. Reverting to generic names.")
+                    use_generic_names = True
+
+            if use_generic_names:
+                plot_titles = [f" (Class {i} vs. Rest)" for i in class_indices_to_plot]
+                save_suffixes = [f"_class_{i}" for i in class_indices_to_plot]
 
-        # --- Save Calibration Plot ---
-        if y_prob.ndim > 1 and y_prob.shape[1] >= 2:
-            y_score = y_prob[:, 1] # Use probabilities of the positive class
+        else:
+            # Should not happen, but good to check
+            _LOGGER.warning(f"Probability array has invalid shape {y_prob.shape}. Skipping ROC/PR/Calibration plots.")
+
+        # --- Loop and generate plots ---
+        for i, class_index in enumerate(class_indices_to_plot):
+            plot_title = plot_titles[i]
+            save_suffix = save_suffixes[i]
+
+            # Get scores for the current class
+            y_score = y_prob[:, class_index]
+
+            # Binarize y_true for the current class
+            y_true_binary = (y_true == class_index).astype(int)
+
+            # --- Save ROC Curve ---
+            fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
+
+            try:
+                # Calculate Youden's J statistic (tpr - fpr)
+                J = tpr - fpr
+                # Find the index of the best threshold
+                best_index = np.argmax(J)
+                optimal_threshold = thresholds[best_index]
+
+                # Define the filename
+                threshold_filename = f"best_threshold{save_suffix}.txt"
+                threshold_path = save_dir_path / threshold_filename
+
+                # Get the class name for the report
+                class_name = ""
+                # Check if we have display labels and the current index is valid
+                if map_display_labels and class_index < len(map_display_labels):
+                    class_name = map_display_labels[class_index]
+                    if num_classes > 2:
+                        # Add 'vs. Rest' for multiclass one-vs-rest plots
+                        class_name += " (vs. Rest)"
+                else:
+                    # Fallback to the generic title or default binary name
+                    class_name = plot_title.strip() or "Binary Positive Class"
+
+                # Create content for the file
+                file_content = (
+                    f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                    f"Class: {class_name}\n"
+                    f"--------------------------------------------------\n"
+                    f"Threshold: {optimal_threshold:.6f}\n"
+                    f"True Positive Rate (TPR): {tpr[best_index]:.6f}\n"
+                    f"False Positive Rate (FPR): {fpr[best_index]:.6f}\n"
+                )
+
+                threshold_path.write_text(file_content, encoding="utf-8")
+                _LOGGER.info(f"💾 Optimal threshold saved as '{threshold_path.name}'")
+
+            except Exception as e:
+                _LOGGER.warning(f"Could not calculate or save optimal threshold: {e}")
 
-            fig_cal, ax_cal = plt.subplots(figsize=(8, 8), dpi=100)
-            CalibrationDisplay.from_predictions(y_true, y_score, n_bins=15, ax=ax_cal)
+            # Calculate AUC.
+            auc = roc_auc_score(y_true_binary, y_score)
 
-            ax_cal.set_title('Reliability Curve')
+            fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+            ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)
+            ax_roc.plot([0, 1], [0, 1], 'k--')
+            ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
+            ax_roc.set_xlabel('False Positive Rate')
+            ax_roc.set_ylabel('True Positive Rate')
+            ax_roc.legend(loc='lower right')
+            ax_roc.grid(True)
+            roc_path = save_dir_path / f"roc_curve{save_suffix}.svg"
+            plt.savefig(roc_path)
+            plt.close(fig_roc)
+
+            # --- Save Precision-Recall Curve ---
+            precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
+            ap_score = average_precision_score(y_true_binary, y_score)
+            fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+            ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=format_config.ROC_PR_line)
+            ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
+            ax_pr.set_xlabel('Recall')
+            ax_pr.set_ylabel('Precision')
+            ax_pr.legend(loc='lower left')
+            ax_pr.grid(True)
+            pr_path = save_dir_path / f"pr_curve{save_suffix}.svg"
+            plt.savefig(pr_path)
+            plt.close(fig_pr)
+
+            # --- Save Calibration Plot ---
+            fig_cal, ax_cal = plt.subplots(figsize=(8, 8), dpi=DPI_value)
+
+            # --- Step 1: Get binned data *without* plotting ---
+            with plt.ioff(): # Suppress showing the temporary plot
+                fig_temp, ax_temp = plt.subplots()
+                cal_display_temp = CalibrationDisplay.from_predictions(
+                    y_true_binary, # Use binarized labels
+                    y_score,
+                    n_bins=format_config.calibration_bins,
+                    ax=ax_temp,
+                    name="temp" # Add a name to suppress potential warnings
+                )
+                # Get the x, y coordinates of the binned data
+                line_x, line_y = cal_display_temp.line_.get_data() # type: ignore
+                plt.close(fig_temp) # Close the temporary plot
+
+            # --- Step 2: Build the plot from scratch ---
+            ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
+
+            sns.regplot(
+                x=line_x,
+                y=line_y,
+                ax=ax_cal,
+                scatter=False,
+                label=f"Calibration Curve ({format_config.calibration_bins} bins)",
+                line_kws={
+                    'color': format_config.ROC_PR_line,
+                    'linestyle': '--',
+                    'linewidth': 2,
+                }
+            )
+
+            ax_cal.set_title(f'Reliability Curve{plot_title}')
             ax_cal.set_xlabel('Mean Predicted Probability')
             ax_cal.set_ylabel('Fraction of Positives')
+
+            # --- Step 3: Set final limits *after* plotting ---
+            ax_cal.set_ylim(0.0, 1.0)
+            ax_cal.set_xlim(0.0, 1.0)
+
+            ax_cal.legend(loc='lower right')
             ax_cal.grid(True)
             plt.tight_layout()
 
-            cal_path = save_dir_path / "calibration_plot.svg"
+            cal_path = save_dir_path / f"calibration_plot{save_suffix}.svg"
             plt.savefig(cal_path)
-            _LOGGER.info(f"📈 Calibration plot saved as '{cal_path.name}'")
             plt.close(fig_cal)
+
+        _LOGGER.info(f"📈 Saved {len(class_indices_to_plot)} sets of ROC, Precision-Recall, and Calibration plots.")
+
+    # restore RC params
+    plt.rcParams.update(original_rc_params)
 
 
-def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[str, Path]):
+def regression_metrics(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    save_dir: Union[str, Path],
+    config: Optional[RegressionMetricsFormat] = None
+):
     """
     Saves regression metrics and plots.
 
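The classification hunk above also derives a suggested decision threshold from the ROC curve using Youden's J statistic (J = TPR - FPR) and writes it to a best_threshold*.txt file. The calculation itself is a few lines of scikit-learn; a standalone sketch with made-up scores:

    import numpy as np
    from sklearn.metrics import roc_curve

    y_true_binary = np.array([0, 0, 1, 1, 1, 0, 1, 0])
    y_score = np.array([0.1, 0.35, 0.8, 0.65, 0.9, 0.4, 0.7, 0.2])  # positive-class probabilities

    fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
    J = tpr - fpr                      # Youden's J for every candidate threshold
    best_index = np.argmax(J)          # threshold that maximizes TPR - FPR
    optimal_threshold = thresholds[best_index]
    print(optimal_threshold, tpr[best_index], fpr[best_index])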
@@ -187,7 +411,21 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
         y_true (np.ndarray): Ground truth values.
         y_pred (np.ndarray): Predicted values.
         save_dir (str | Path): Directory to save plots and report.
+        config (RegressionMetricsFormat, optional): Formatting configuration object.
     """
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseRegressionFormat()
+    else:
+        format_config = config
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': format_config.font_size})
+
+    # --- Calculate Metrics ---
     rmse = np.sqrt(mean_squared_error(y_true, y_pred))
     mae = mean_absolute_error(y_true, y_pred)
     r2 = r2_score(y_true, y_pred)
@@ -209,11 +447,13 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     report_path.write_text(report_string)
     _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")
 
-    # Save residual plot
+    # --- Save residual plot ---
     residuals = y_true - y_pred
-    fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
-    ax_res.scatter(y_pred, residuals, alpha=0.6)
-    ax_res.axhline(0, color='red', linestyle='--')
+    fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+    ax_res.scatter(y_pred, residuals,
+                   alpha=format_config.scatter_alpha,
+                   color=format_config.scatter_color)
+    ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--')
     ax_res.set_xlabel("Predicted Values")
     ax_res.set_ylabel("Residuals")
     ax_res.set_title("Residual Plot")
@@ -224,10 +464,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
     plt.close(fig_res)
 
-    # Save true vs predicted plot
-    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
-    ax_tvp.scatter(y_true, y_pred, alpha=0.6)
-    ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'k--', lw=2)
+    # --- Save true vs predicted plot ---
+    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+    ax_tvp.scatter(y_true, y_pred,
+                   alpha=format_config.scatter_alpha,
+                   color=format_config.scatter_color)
+    ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()],
+                linestyle='--',
+                lw=2,
+                color=format_config.ideal_line_color)
     ax_tvp.set_xlabel('True Values')
     ax_tvp.set_ylabel('Predictions')
     ax_tvp.set_title('True vs. Predicted Values')
@@ -238,9 +483,11 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
     plt.close(fig_tvp)
 
-    # Save Histogram of Residuals
-    fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=100)
-    sns.histplot(residuals, kde=True, ax=ax_hist)
+    # --- Save Histogram of Residuals ---
+    fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+    sns.histplot(residuals, kde=True, ax=ax_hist,
+                 bins=format_config.hist_bins,
+                 color=format_config.scatter_color)
     ax_hist.set_xlabel("Residual Value")
     ax_hist.set_ylabel("Frequency")
     ax_hist.set_title("Distribution of Residuals")
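Note (usage sketch, not part of the package diff): regression_metrics keeps its old positional arguments and only gains the optional config; with config=None the _BaseRegressionFormat defaults shown above apply. Synthetic data, same import assumption as earlier:

    import numpy as np
    from ml_tools.ML_evaluation import regression_metrics

    rng = np.random.default_rng(0)
    y_true = rng.normal(size=200)
    y_pred = y_true + rng.normal(scale=0.3, size=200)  # simulated predictions

    # Writes the text report, residual plot, true-vs-predicted plot, and residuals histogram
    regression_metrics(y_true, y_pred, save_dir="reports/regression")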
@@ -251,6 +498,9 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
     plt.close(fig_hist)
 
+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)
+
 
 def shap_summary_plot(model,
                       background_data: Union[torch.Tensor,np.ndarray],
@@ -258,7 +508,7 @@ def shap_summary_plot(model,
                       feature_names: Optional[list[str]],
                       save_dir: Union[str, Path],
                       device: torch.device = torch.device('cpu'),
-                      explainer_type: Literal['deep', 'kernel'] = 'deep'):
+                      explainer_type: Literal['deep', 'kernel'] = 'kernel'):
     """
     Calculates SHAP values and saves summary plots and data.
 
@@ -270,13 +520,13 @@ def shap_summary_plot(model,
         save_dir (str | Path): Directory to save SHAP artifacts.
         device (torch.device): The torch device for SHAP calculations.
         explainer_type (Literal['deep', 'kernel']): The explainer to use.
-            - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient for
              PyTorch models.
            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
              slow and memory-intensive.
     """
 
-    print(f"\n--- SHAP Value Explanation Using {explainer_type.upper()} Explainer ---")
+    _LOGGER.info(f"📊 Running SHAP Value Explanation Using {explainer_type.upper()} Explainer")
 
     model.eval()
     # model.cpu() # Run explanations on CPU
@@ -285,7 +535,7 @@ def shap_summary_plot(model,
     instances_to_explain_np = None
 
     if explainer_type == 'deep':
-        # --- 1. Use DeepExplainer (Preferred) ---
+        # --- 1. Use DeepExplainer ---
 
         # Ensure data is torch.Tensor
         if isinstance(background_data, np.ndarray):
@@ -309,10 +559,9 @@ def shap_summary_plot(model,
         instances_to_explain_np = instances_to_explain.cpu().numpy()
 
     elif explainer_type == 'kernel':
-        # --- 2. Use KernelExplainer (Slow Fallback) ---
+        # --- 2. Use KernelExplainer ---
         _LOGGER.warning(
-            "Using KernelExplainer. This is memory-intensive and slow. "
-            "Consider reducing 'n_samples' if the process terminates unexpectedly."
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
         )
 
         # Ensure data is np.ndarray
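The default explainer flips from 'deep' to 'kernel' in this release. For orientation, the two SHAP back-ends referenced above differ roughly as sketched below; this is raw shap usage on a throwaway model, not the toolbox wrapper.

    import numpy as np
    import torch
    import torch.nn as nn
    import shap

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
    background = torch.randn(50, 4)   # background sample for the explainer
    instances = torch.randn(5, 4)     # rows to explain

    # DeepExplainer: works directly on the PyTorch model with tensors (fast).
    deep_values = shap.DeepExplainer(model, background).shap_values(instances)

    # KernelExplainer: model-agnostic, needs a NumPy prediction function (slow).
    def predict_fn(x: np.ndarray) -> np.ndarray:
        with torch.no_grad():
            return model(torch.from_numpy(x).float()).numpy()

    kernel_values = shap.KernelExplainer(predict_fn, background.numpy()).shap_values(instances.numpy())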
@@ -348,14 +597,26 @@ def shap_summary_plot(model,
     else:
         _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
         raise ValueError()
+
+    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1: # type: ignore
+        # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
+        shap_values = shap_values.squeeze(-1) # type: ignore
 
     # --- 3. Plotting and Saving ---
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     plt.ioff()
 
+    # Convert instances to a DataFrame. robust way to ensure SHAP correctly maps values to feature names.
+    if feature_names is None:
+        # Create generic names if none were provided
+        num_features = instances_to_explain_np.shape[1]
+        feature_names = [f'feature_{i}' for i in range(num_features)]
+
+    instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)
+
     # Save Bar Plot
     bar_path = save_dir_path / "shap_bar_plot.svg"
-    shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
+    shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
     plt.title("SHAP Feature Importance")
@@ -366,7 +627,7 @@ def shap_summary_plot(model,
 
     # Save Dot Plot
     dot_path = save_dir_path / "shap_dot_plot.svg"
-    shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
+    shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
     if plt.gcf().axes and len(plt.gcf().axes) > 1:
@@ -389,9 +650,6 @@ def shap_summary_plot(model,
     mean_abs_shap = np.abs(shap_values).mean(axis=0)
 
     mean_abs_shap = mean_abs_shap.flatten()
-
-    if feature_names is None:
-        feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
 
     summary_df = pd.DataFrame({
         SHAPKeys.FEATURE_COLUMN: feature_names,
@@ -401,7 +659,7 @@ def shap_summary_plot(model,
     summary_df.to_csv(summary_path, index=False)
 
     _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
-    plt.ion()
+    plt.ion()
 
 
 def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
@@ -447,7 +705,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
     # --- Step 3: Create and save the plot for top N features ---
     plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)
 
-    plt.figure(figsize=(10, 8), dpi=100)
+    plt.figure(figsize=(10, 8), dpi=DPI_value)
 
     # Create horizontal bar plot with error bars
     plt.barh(