dragon-ml-toolbox 14.3.1__py3-none-any.whl → 16.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of dragon-ml-toolbox might be problematic.

Files changed (44)
  1. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +10 -5
  2. dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
  3. ml_tools/ETL_cleaning.py +20 -20
  4. ml_tools/ETL_engineering.py +23 -25
  5. ml_tools/GUI_tools.py +20 -20
  6. ml_tools/MICE_imputation.py +3 -3
  7. ml_tools/ML_callbacks.py +43 -26
  8. ml_tools/ML_configuration.py +309 -0
  9. ml_tools/ML_datasetmaster.py +220 -260
  10. ml_tools/ML_evaluation.py +317 -81
  11. ml_tools/ML_evaluation_multi.py +127 -36
  12. ml_tools/ML_inference.py +249 -207
  13. ml_tools/ML_models.py +13 -102
  14. ml_tools/ML_models_advanced.py +1 -1
  15. ml_tools/ML_optimization.py +12 -12
  16. ml_tools/ML_scaler.py +11 -11
  17. ml_tools/ML_sequence_datasetmaster.py +341 -0
  18. ml_tools/ML_sequence_evaluation.py +215 -0
  19. ml_tools/ML_sequence_inference.py +391 -0
  20. ml_tools/ML_sequence_models.py +139 -0
  21. ml_tools/ML_trainer.py +1247 -338
  22. ml_tools/ML_utilities.py +51 -2
  23. ml_tools/ML_vision_datasetmaster.py +262 -118
  24. ml_tools/ML_vision_evaluation.py +26 -6
  25. ml_tools/ML_vision_inference.py +117 -140
  26. ml_tools/ML_vision_models.py +15 -1
  27. ml_tools/ML_vision_transformers.py +233 -7
  28. ml_tools/PSO_optimization.py +6 -6
  29. ml_tools/SQL.py +4 -4
  30. ml_tools/{keys.py → _keys.py} +45 -1
  31. ml_tools/_schema.py +1 -1
  32. ml_tools/ensemble_evaluation.py +54 -11
  33. ml_tools/ensemble_inference.py +7 -33
  34. ml_tools/ensemble_learning.py +1 -1
  35. ml_tools/optimization_tools.py +2 -2
  36. ml_tools/path_manager.py +5 -5
  37. ml_tools/utilities.py +1 -2
  38. dragon_ml_toolbox-14.3.1.dist-info/RECORD +0 -48
  39. ml_tools/RNN_forecast.py +0 -56
  40. ml_tools/_ML_vision_recipe.py +0 -88
  41. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
  42. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
  43. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  44. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation.py CHANGED
@@ -21,10 +21,11 @@ from pathlib import Path
 from typing import Union, Optional, List, Literal
 import warnings
 
-from .path_manager import make_fullpath
+from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .keys import SHAPKeys, PyTorchLogKeys
+from ._keys import SHAPKeys, PyTorchLogKeys
+from .ML_configuration import RegressionMetricsFormat, ClassificationMetricsFormat
 
 
 __all__ = [
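
Note: file 30 in the list above renames ml_tools/keys.py to the private ml_tools/_keys.py, so the import swap in this hunk also ripples out to downstream code. A one-line migration sketch for user code that imported the old module directly:

    # dragon-ml-toolbox <= 14.3.1
    from ml_tools.keys import SHAPKeys, PyTorchLogKeys

    # dragon-ml-toolbox >= 16.0.0
    from ml_tools._keys import SHAPKeys, PyTorchLogKeys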
@@ -35,10 +36,13 @@ __all__ = [
     "plot_attention_importance"
 ]
 
+DPI_value = 250
+
 
 def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
     Plots training & validation loss curves from a history object.
+    Also plots the learning rate if available in the history.
 
     Args:
         history (dict): A dictionary containing 'train_loss' and 'val_loss'.
@@ -46,29 +50,52 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
     train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
     val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
+    lr_history = history.get(PyTorchLogKeys.LEARNING_RATE, [])
 
     if not train_loss and not val_loss:
-        print("Warning: Loss history is empty or incomplete. Cannot plot.")
+        _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
         return
 
-    fig, ax = plt.subplots(figsize=(10, 5), dpi=100)
+    fig, ax = plt.subplots(figsize=(10, 5), dpi=DPI_value)
+
+    # --- Plot Losses (Left Y-axis) ---
+    line_handles = [] # To store line objects for the legend
 
     # Plot training loss only if data for it exists
     if train_loss:
         epochs = range(1, len(train_loss) + 1)
-        ax.plot(epochs, train_loss, 'o-', label='Training Loss')
+        line1, = ax.plot(epochs, train_loss, 'o-', label='Training Loss', color='tab:blue')
+        line_handles.append(line1)
 
     # Plot validation loss only if data for it exists
     if val_loss:
         epochs = range(1, len(val_loss) + 1)
-        ax.plot(epochs, val_loss, 'o-', label='Validation Loss')
+        line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
+        line_handles.append(line2)
 
     ax.set_title('Training and Validation Loss')
     ax.set_xlabel('Epochs')
-    ax.set_ylabel('Loss')
-    ax.legend()
-    ax.grid(True)
-    plt.tight_layout()
+    ax.set_ylabel('Loss', color='tab:blue')
+    ax.tick_params(axis='y', labelcolor='tab:blue')
+    ax.grid(True, linestyle='--')
+
+    # --- Plot Learning Rate (Right Y-axis) ---
+    if lr_history:
+        ax2 = ax.twinx() # Create a second y-axis
+        epochs = range(1, len(lr_history) + 1)
+        line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
+        line_handles.append(line3)
+
+        ax2.set_ylabel('Learning Rate', color='g')
+        ax2.tick_params(axis='y', labelcolor='g')
+        # Use scientific notation if the LR is very small
+        ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
+
+    # Combine legends from both axes
+    ax.legend(handles=line_handles, loc='best')
+
+    # ax.grid(True)
+    plt.tight_layout()
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     save_path = save_dir_path / "loss_plot.svg"
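
Note: plot_losses now also draws the learning-rate schedule on a twin y-axis whenever the history carries it. A minimal call sketch, assuming the history lists come from this package's trainer and that the PyTorchLogKeys attributes are the dict keys the function reads (their concrete string values are not shown in this diff):

    from ml_tools.ML_evaluation import plot_losses
    from ml_tools._keys import PyTorchLogKeys

    # Hypothetical four-epoch run; any missing series is simply skipped.
    history = {
        PyTorchLogKeys.TRAIN_LOSS: [0.92, 0.61, 0.48, 0.41],
        PyTorchLogKeys.VAL_LOSS: [0.95, 0.70, 0.58, 0.55],
        PyTorchLogKeys.LEARNING_RATE: [1e-3, 1e-3, 5e-4, 5e-4],  # optional
    }
    plot_losses(history, save_dir="results")  # writes results/loss_plot.svg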
@@ -78,23 +105,49 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
     plt.close(fig)
 
 
-def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optional[np.ndarray] = None,
-                           cmap: str = "Blues"):
+def classification_metrics(save_dir: Union[str, Path],
+                           y_true: np.ndarray,
+                           y_pred: np.ndarray,
+                           y_prob: Optional[np.ndarray] = None,
+                           config: Optional[ClassificationMetricsFormat] = None):
     """
     Saves classification metrics and plots.
 
     Args:
         y_true (np.ndarray): Ground truth labels.
         y_pred (np.ndarray): Predicted labels.
-        y_prob (np.ndarray, optional): Predicted probabilities for ROC curve.
-        cmap (str): Colormap for the confusion matrix.
+        y_prob (np.ndarray): Predicted probabilities for ROC curve.
+        config (ClassificationMetricsFormat): Formatting configuration object.
        save_dir (str | Path): Directory to save plots.
     """
-    print("--- Classification Report ---")
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        config = ClassificationMetricsFormat()
+
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': config.font_size})
+
+    # print("--- Classification Report ---")
+
+    # --- Parse class_map ---
+    map_labels = None
+    map_display_labels = None
+    if config.class_map:
+        # Sort the map by its values (the indices) to ensure correct order
+        try:
+            sorted_items = sorted(config.class_map.items(), key=lambda item: item[1])
+            map_labels = [item[1] for item in sorted_items]
+            map_display_labels = [item[0] for item in sorted_items]
+        except Exception as e:
+            _LOGGER.warning(f"Could not parse 'class_map': {e}")
+            map_labels = None
+            map_display_labels = None
+
     # Generate report as both text and dictionary
-    report_text: str = classification_report(y_true, y_pred) # type: ignore
-    report_dict: dict = classification_report(y_true, y_pred, output_dict=True) # type: ignore
-    print(report_text)
+    report_text: str = classification_report(y_true, y_pred, labels=map_labels, target_names=map_display_labels) # type: ignore
+    report_dict: dict = classification_report(y_true, y_pred, output_dict=True, labels=map_labels, target_names=map_display_labels) # type: ignore
+    # print(report_text)
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     # Save text report
@@ -104,8 +157,15 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre
 
     # --- Save Classification Report Heatmap ---
     try:
-        plt.figure(figsize=(8, 6), dpi=100)
-        sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T, annot=True, cmap='viridis', fmt='.2f')
+        plt.figure(figsize=(8, 6), dpi=DPI_value)
+        sns.set_theme(font_scale=1.2) # Scale seaborn font
+        sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T,
+                    annot=True,
+                    cmap=config.cmap,
+                    fmt='.2f',
+                    vmin=0.0,
+                    vmax=1.0)
+        sns.set_theme(font_scale=1.0) # Reset seaborn scale
         plt.title("Classification Report")
         plt.tight_layout()
         heatmap_path = save_dir_path / "classification_report_heatmap.svg"
@@ -114,72 +174,224 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre
         plt.close()
     except Exception as e:
         _LOGGER.error(f"Could not generate classification report heatmap: {e}")
-
+
+    # --- labels for Confusion Matrix ---
+    plot_labels = map_labels
+    plot_display_labels = map_display_labels
+
     # Save Confusion Matrix
-    fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
-    ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap, ax=ax_cm)
+    fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+    disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
+                                                    y_pred,
+                                                    cmap=config.cmap,
+                                                    ax=ax_cm,
+                                                    normalize='true',
+                                                    labels=plot_labels,
+                                                    display_labels=plot_display_labels)
+
+    disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+    # Turn off gridlines
+    ax_cm.grid(False)
+
+    # Manually update font size of cell texts
+    for text in ax_cm.texts:
+        text.set_fontsize(config.font_size)
+
+    fig_cm.tight_layout()
+
     ax_cm.set_title("Confusion Matrix")
     cm_path = save_dir_path / "confusion_matrix.svg"
     plt.savefig(cm_path)
     _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
     plt.close(fig_cm)
 
-    # Plotting logic for ROC and PR Curves
-    if y_prob is not None and y_prob.ndim > 1 and y_prob.shape[1] >= 2:
-        # Use probabilities of the positive class
-        y_score = y_prob[:, 1]
+
+    # Plotting logic for ROC, PR, and Calibration Curves
+    if y_prob is not None and y_prob.ndim == 2:
+        num_classes = y_prob.shape[1]
 
-        # --- Save ROC Curve ---
-        fpr, tpr, _ = roc_curve(y_true, y_score)
-        auc = roc_auc_score(y_true, y_score)
-        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
-        ax_roc.plot([0, 1], [0, 1], 'k--')
-        ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
-        ax_roc.set_xlabel('False Positive Rate')
-        ax_roc.set_ylabel('True Positive Rate')
-        ax_roc.legend(loc='lower right')
-        ax_roc.grid(True)
-        roc_path = save_dir_path / "roc_curve.svg"
-        plt.savefig(roc_path)
-        _LOGGER.info(f"📈 ROC curve saved as '{roc_path.name}'")
-        plt.close(fig_roc)
-
-        # --- Save Precision-Recall Curve ---
-        precision, recall, _ = precision_recall_curve(y_true, y_score)
-        ap_score = average_precision_score(y_true, y_score)
-        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
-        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
-        ax_pr.set_title('Precision-Recall Curve')
-        ax_pr.set_xlabel('Recall')
-        ax_pr.set_ylabel('Precision')
-        ax_pr.legend(loc='lower left')
-        ax_pr.grid(True)
-        pr_path = save_dir_path / "pr_curve.svg"
-        plt.savefig(pr_path)
-        _LOGGER.info(f"📈 PR curve saved as '{pr_path.name}'")
-        plt.close(fig_pr)
+        # --- Determine which classes to loop over ---
+        class_indices_to_plot = []
+        plot_titles = []
+        save_suffixes = []
+
+        if num_classes == 2:
+            # Binary case: Only plot for the positive class (index 1)
+            class_indices_to_plot = [1]
+            plot_titles = [""] # No extra title
+            save_suffixes = [""] # No extra suffix
+            _LOGGER.debug("Generating binary classification plots (ROC, PR, Calibration).")
 
-        # --- Save Calibration Plot ---
-        if y_prob.ndim > 1 and y_prob.shape[1] >= 2:
-            y_score = y_prob[:, 1] # Use probabilities of the positive class
+        elif num_classes > 2:
+            _LOGGER.debug(f"Generating One-vs-Rest plots for {num_classes} classes.")
+            # Multiclass case: Plot for every class (One-vs-Rest)
+            class_indices_to_plot = list(range(num_classes))
 
-            fig_cal, ax_cal = plt.subplots(figsize=(8, 8), dpi=100)
-            CalibrationDisplay.from_predictions(y_true, y_score, n_bins=15, ax=ax_cal)
+            # --- Use class_map names if available ---
+            use_generic_names = True
+            if map_display_labels and len(map_display_labels) == num_classes:
+                try:
+                    # Ensure labels are safe for filenames
+                    safe_names = [sanitize_filename(name) for name in map_display_labels]
+                    plot_titles = [f" ({name} vs. Rest)" for name in map_display_labels]
+                    save_suffixes = [f"_{safe_names[i]}" for i in class_indices_to_plot]
+                    use_generic_names = False
+                except Exception as e:
+                    _LOGGER.warning(f"Failed to use 'class_map' for plot titles: {e}. Reverting to generic names.")
+                    use_generic_names = True
 
-            ax_cal.set_title('Reliability Curve')
+            if use_generic_names:
+                plot_titles = [f" (Class {i} vs. Rest)" for i in class_indices_to_plot]
+                save_suffixes = [f"_class_{i}" for i in class_indices_to_plot]
+
+        else:
+            # Should not happen, but good to check
+            _LOGGER.warning(f"Probability array has invalid shape {y_prob.shape}. Skipping ROC/PR/Calibration plots.")
+
+        # --- Loop and generate plots ---
+        for i, class_index in enumerate(class_indices_to_plot):
+            plot_title = plot_titles[i]
+            save_suffix = save_suffixes[i]
+
+            # Get scores for the current class
+            y_score = y_prob[:, class_index]
+
+            # Binarize y_true for the current class
+            y_true_binary = (y_true == class_index).astype(int)
+
+            # --- Save ROC Curve ---
+            fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
+
+            try:
+                # Calculate Youden's J statistic (tpr - fpr)
+                J = tpr - fpr
+                # Find the index of the best threshold
+                best_index = np.argmax(J)
+                optimal_threshold = thresholds[best_index]
+
+                # Define the filename
+                threshold_filename = f"best_threshold{save_suffix}.txt"
+                threshold_path = save_dir_path / threshold_filename
+
+                # Get the class name for the report
+                class_name = ""
+                # Check if we have display labels and the current index is valid
+                if map_display_labels and class_index < len(map_display_labels):
+                    class_name = map_display_labels[class_index]
+                    if num_classes > 2:
+                        # Add 'vs. Rest' for multiclass one-vs-rest plots
+                        class_name += " (vs. Rest)"
+                else:
+                    # Fallback to the generic title or default binary name
+                    class_name = plot_title.strip() or "Binary Positive Class"
+
+                # Create content for the file
+                file_content = (
+                    f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                    f"Class: {class_name}\n"
+                    f"--------------------------------------------------\n"
+                    f"Threshold: {optimal_threshold:.6f}\n"
+                    f"True Positive Rate (TPR): {tpr[best_index]:.6f}\n"
+                    f"False Positive Rate (FPR): {fpr[best_index]:.6f}\n"
+                )
+
+                threshold_path.write_text(file_content, encoding="utf-8")
+                _LOGGER.info(f"💾 Optimal threshold saved as '{threshold_path.name}'")
+
+            except Exception as e:
+                _LOGGER.warning(f"Could not calculate or save optimal threshold: {e}")
+
+            # Calculate AUC.
+            auc = roc_auc_score(y_true_binary, y_score)
+
+            fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+            ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=config.ROC_PR_line)
+            ax_roc.plot([0, 1], [0, 1], 'k--')
+            ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
+            ax_roc.set_xlabel('False Positive Rate')
+            ax_roc.set_ylabel('True Positive Rate')
+            ax_roc.legend(loc='lower right')
+            ax_roc.grid(True)
+            roc_path = save_dir_path / f"roc_curve{save_suffix}.svg"
+            plt.savefig(roc_path)
+            plt.close(fig_roc)
+
+            # --- Save Precision-Recall Curve ---
+            precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
+            ap_score = average_precision_score(y_true_binary, y_score)
+            fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+            ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=config.ROC_PR_line)
+            ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
+            ax_pr.set_xlabel('Recall')
+            ax_pr.set_ylabel('Precision')
+            ax_pr.legend(loc='lower left')
+            ax_pr.grid(True)
+            pr_path = save_dir_path / f"pr_curve{save_suffix}.svg"
+            plt.savefig(pr_path)
+            plt.close(fig_pr)
+
+            # --- Save Calibration Plot ---
+            fig_cal, ax_cal = plt.subplots(figsize=(8, 8), dpi=DPI_value)
+
+            # --- Step 1: Get binned data *without* plotting ---
+            with plt.ioff(): # Suppress showing the temporary plot
+                fig_temp, ax_temp = plt.subplots()
+                cal_display_temp = CalibrationDisplay.from_predictions(
+                    y_true_binary, # Use binarized labels
+                    y_score,
+                    n_bins=config.calibration_bins,
+                    ax=ax_temp,
+                    name="temp" # Add a name to suppress potential warnings
+                )
+                # Get the x, y coordinates of the binned data
+                line_x, line_y = cal_display_temp.line_.get_data() # type: ignore
+                plt.close(fig_temp) # Close the temporary plot
+
+            # --- Step 2: Build the plot from scratch ---
+            ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
+
+            sns.regplot(
+                x=line_x,
+                y=line_y,
+                ax=ax_cal,
+                scatter=False,
+                label=f"Calibration Curve ({config.calibration_bins} bins)",
+                line_kws={
+                    'color': config.ROC_PR_line,
+                    'linestyle': '--',
+                    'linewidth': 2,
+                }
+            )
+
+            ax_cal.set_title(f'Reliability Curve{plot_title}')
             ax_cal.set_xlabel('Mean Predicted Probability')
             ax_cal.set_ylabel('Fraction of Positives')
+
+            # --- Step 3: Set final limits *after* plotting ---
+            ax_cal.set_ylim(0.0, 1.0)
+            ax_cal.set_xlim(0.0, 1.0)
+
+            ax_cal.legend(loc='lower right')
             ax_cal.grid(True)
             plt.tight_layout()
 
-            cal_path = save_dir_path / "calibration_plot.svg"
+            cal_path = save_dir_path / f"calibration_plot{save_suffix}.svg"
             plt.savefig(cal_path)
-            _LOGGER.info(f"📈 Calibration plot saved as '{cal_path.name}'")
             plt.close(fig_cal)
+
+        _LOGGER.info(f"📈 Saved {len(class_indices_to_plot)} sets of ROC, Precision-Recall, and Calibration plots.")
+
+    # restore RC params
+    plt.rcParams.update(original_rc_params)
 
 
-def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[str, Path]):
+def regression_metrics(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    save_dir: Union[str, Path],
+    config: Optional[RegressionMetricsFormat] = None
+):
     """
     Saves regression metrics and plots.
 
@@ -187,7 +399,19 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
         y_true (np.ndarray): Ground truth values.
         y_pred (np.ndarray): Predicted values.
         save_dir (str | Path): Directory to save plots and report.
+        config (RegressionMetricsFormat, optional): Formatting configuration object.
     """
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        config = RegressionMetricsFormat()
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': config.font_size})
+
+    # --- Calculate Metrics ---
     rmse = np.sqrt(mean_squared_error(y_true, y_pred))
     mae = mean_absolute_error(y_true, y_pred)
     r2 = r2_score(y_true, y_pred)
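
Note: the classification hunk above picks the operating point that maximizes Youden's J statistic, J = TPR - FPR, and writes it to best_threshold*.txt. The selection itself is plain scikit-learn; a self-contained sketch of the same computation outside the library:

    import numpy as np
    from sklearn.metrics import roc_curve

    y_true = np.array([0, 0, 1, 1, 1, 0])
    y_score = np.array([0.10, 0.40, 0.35, 0.80, 0.70, 0.20])  # P(class 1)

    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    J = tpr - fpr              # Youden's J at each candidate threshold
    best_index = np.argmax(J)  # index of the maximizing point
    print(thresholds[best_index], tpr[best_index], fpr[best_index])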
@@ -209,11 +433,13 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     report_path.write_text(report_string)
     _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")
 
-    # Save residual plot
+    # --- Save residual plot ---
     residuals = y_true - y_pred
-    fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
-    ax_res.scatter(y_pred, residuals, alpha=0.6)
-    ax_res.axhline(0, color='red', linestyle='--')
+    fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+    ax_res.scatter(y_pred, residuals,
+                   alpha=config.scatter_alpha,
+                   color=config.scatter_color)
+    ax_res.axhline(0, color=config.residual_line_color, linestyle='--')
     ax_res.set_xlabel("Predicted Values")
     ax_res.set_ylabel("Residuals")
     ax_res.set_title("Residual Plot")
@@ -224,10 +450,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
     plt.close(fig_res)
 
-    # Save true vs predicted plot
-    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
-    ax_tvp.scatter(y_true, y_pred, alpha=0.6)
-    ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'k--', lw=2)
+    # --- Save true vs predicted plot ---
+    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+    ax_tvp.scatter(y_true, y_pred,
+                   alpha=config.scatter_alpha,
+                   color=config.scatter_color)
+    ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()],
+                linestyle='--',
+                lw=2,
+                color=config.ideal_line_color)
     ax_tvp.set_xlabel('True Values')
     ax_tvp.set_ylabel('Predictions')
     ax_tvp.set_title('True vs. Predicted Values')
@@ -238,9 +469,11 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
     plt.close(fig_tvp)
 
-    # Save Histogram of Residuals
-    fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=100)
-    sns.histplot(residuals, kde=True, ax=ax_hist)
+    # --- Save Histogram of Residuals ---
+    fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+    sns.histplot(residuals, kde=True, ax=ax_hist,
+                 bins=config.hist_bins,
+                 color=config.scatter_color)
     ax_hist.set_xlabel("Residual Value")
     ax_hist.set_ylabel("Frequency")
     ax_hist.set_title("Distribution of Residuals")
@@ -251,6 +484,9 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
     plt.close(fig_hist)
 
+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)
+
 
 def shap_summary_plot(model,
                       background_data: Union[torch.Tensor,np.ndarray],
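
Note: both refactored entry points now take an optional formatting object from the new ml_tools/ML_configuration.py instead of loose keyword arguments (the old cmap parameter of classification_metrics is gone). A hedged usage sketch; the attribute names (class_map, cmap, font_size, calibration_bins, scatter_alpha, ...) are read by the functions above, but passing them as constructor keywords is an assumption about the config classes, whose definitions are not part of this diff:

    import numpy as np
    from ml_tools.ML_evaluation import classification_metrics, regression_metrics
    from ml_tools.ML_configuration import ClassificationMetricsFormat, RegressionMetricsFormat

    y_true = np.array([0, 1, 1, 0, 1])
    y_pred = np.array([0, 1, 0, 0, 1])
    y_prob = np.array([[0.8, 0.2], [0.3, 0.7], [0.6, 0.4], [0.9, 0.1], [0.2, 0.8]])

    # class_map maps display names to class indices, as parsed in the hunk above.
    fmt = ClassificationMetricsFormat(class_map={"negative": 0, "positive": 1})  # assumed constructor
    classification_metrics("results", y_true, y_pred, y_prob=y_prob, config=fmt)

    # config=None falls back to RegressionMetricsFormat() defaults.
    regression_metrics(y_true.astype(float), y_pred.astype(float), "results")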
@@ -276,7 +512,7 @@ def shap_summary_plot(model,
         slow and memory-intensive.
     """
 
-    print(f"\n--- SHAP Value Explanation Using {explainer_type.upper()} Explainer ---")
+    _LOGGER.info(f"📊 Running SHAP Value Explanation Using {explainer_type.upper()} Explainer")
 
     model.eval()
     # model.cpu() # Run explanations on CPU
@@ -348,9 +584,9 @@ def shap_summary_plot(model,
         _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
         raise ValueError()
 
-    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
+    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1: # type: ignore
         # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
-        shap_values = shap_values.squeeze(-1)
+        shap_values = shap_values.squeeze(-1) # type: ignore
 
     # --- 3. Plotting and Saving ---
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
@@ -455,7 +691,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
     # --- Step 3: Create and save the plot for top N features ---
     plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)
 
-    plt.figure(figsize=(10, 8), dpi=100)
+    plt.figure(figsize=(10, 8), dpi=DPI_value)
 
     # Create horizontal bar plot with error bars
     plt.barh(