dragon-ml-toolbox 13.3.0__py3-none-any.whl → 14.7.0__py3-none-any.whl

This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
ml_tools/ML_evaluation.py CHANGED
@@ -21,10 +21,10 @@ from pathlib import Path
21
21
  from typing import Union, Optional, List, Literal
22
22
  import warnings
23
23
 
24
- from .path_manager import make_fullpath
24
+ from .path_manager import make_fullpath, sanitize_filename
25
25
  from ._logger import _LOGGER
26
26
  from ._script_info import _script_info
27
- from .keys import SHAPKeys
27
+ from .keys import SHAPKeys, PyTorchLogKeys
28
28
 
29
29
 
30
30
  __all__ = [
@@ -35,6 +35,8 @@ __all__ = [
35
35
  "plot_attention_importance"
36
36
  ]
37
37
 
38
+ DPI_value = 250
39
+
38
40
 
39
41
  def plot_losses(history: dict, save_dir: Union[str, Path]):
40
42
  """
@@ -44,14 +46,14 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
44
46
  history (dict): A dictionary containing 'train_loss' and 'val_loss'.
45
47
  save_dir (str | Path): Directory to save the plot image.
46
48
  """
47
- train_loss = history.get('train_loss', [])
48
- val_loss = history.get('val_loss', [])
49
+ train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
50
+ val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
49
51
 
50
52
  if not train_loss and not val_loss:
51
- print("Warning: Loss history is empty or incomplete. Cannot plot.")
53
+ _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
52
54
  return
53
55
 
54
- fig, ax = plt.subplots(figsize=(10, 5), dpi=100)
56
+ fig, ax = plt.subplots(figsize=(10, 5), dpi=DPI_value)
55
57
 
56
58
  # Plot training loss only if data for it exists
57
59
  if train_loss:
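A quick usage sketch of the reworked plot_losses (not taken from the package): it assumes PyTorchLogKeys.TRAIN_LOSS and PyTorchLogKeys.VAL_LOSS resolve to the same strings the trainer uses when it builds the history dict (their concrete values live in ml_tools/keys.py and are not shown in this diff), and that the module is importable as ml_tools.ML_evaluation.

from ml_tools.keys import PyTorchLogKeys
from ml_tools.ML_evaluation import plot_losses

# Per-epoch losses keyed by the shared log-key constants (values here are made up).
history = {
    PyTorchLogKeys.TRAIN_LOSS: [0.92, 0.71, 0.55, 0.48],
    PyTorchLogKeys.VAL_LOSS:   [0.95, 0.78, 0.66, 0.61],
}
plot_losses(history, save_dir="outputs/metrics")  # figures are now rendered at dpi=DPI_value (250)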
@@ -78,8 +80,15 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
78
80
  plt.close(fig)
79
81
 
80
82
 
81
- def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pred: np.ndarray, y_prob: Optional[np.ndarray] = None,
82
- cmap: str = "Blues"):
83
+ def classification_metrics(save_dir: Union[str, Path],
84
+ y_true: np.ndarray,
85
+ y_pred: np.ndarray,
86
+ y_prob: Optional[np.ndarray] = None,
87
+ cmap: str = "Blues",
88
+ class_map: Optional[dict[str,int]]=None,
89
+ ROC_PR_line: str='darkorange',
90
+ calibration_bins: int=15,
91
+ font_size: int=16):
83
92
  """
84
93
  Saves classification metrics and plots.
85
94
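A hypothetical call exercising the new keyword arguments (the class names and probabilities below are invented; only the parameter names come from the signature above):

import numpy as np
from ml_tools.ML_evaluation import classification_metrics

y_true = np.array([0, 1, 2, 2, 0, 2])
y_prob = np.array([[0.80, 0.15, 0.05],
                   [0.10, 0.70, 0.20],
                   [0.05, 0.25, 0.70],
                   [0.30, 0.45, 0.25],
                   [0.60, 0.30, 0.10],
                   [0.20, 0.20, 0.60]])   # shape (n_samples, n_classes)
y_pred = y_prob.argmax(axis=1)

classification_metrics(
    save_dir="outputs/eval",
    y_true=y_true,
    y_pred=y_pred,
    y_prob=y_prob,
    class_map={"setosa": 0, "versicolor": 1, "virginica": 2},  # orders and labels the confusion matrix and per-class plots
    ROC_PR_line="crimson",        # colour of the ROC / PR / calibration curves
    calibration_bins=10,
    font_size=14,
)

With three classes, the function now emits one ROC, PR, and calibration figure per class (One-vs-Rest), suffixed with the sanitized class name.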
 
@@ -89,12 +98,31 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre
89
98
  y_prob (np.ndarray, optional): Predicted probabilities for ROC curve.
90
99
  cmap (str): Colormap for the confusion matrix.
91
100
  save_dir (str | Path): Directory to save plots.
101
+ class_map (dict[str, int], None): A map of {class_name: index} used to order and label the confusion matrix.
92
102
  """
93
- print("--- Classification Report ---")
103
+ original_rc_params = plt.rcParams.copy()
104
+ plt.rcParams.update({'font.size': font_size})
105
+
106
+ # print("--- Classification Report ---")
107
+
108
+ # --- Parse class_map ---
109
+ map_labels = None
110
+ map_display_labels = None
111
+ if class_map:
112
+ # Sort the map by its values (the indices) to ensure correct order
113
+ try:
114
+ sorted_items = sorted(class_map.items(), key=lambda item: item[1])
115
+ map_labels = [item[1] for item in sorted_items]
116
+ map_display_labels = [item[0] for item in sorted_items]
117
+ except Exception as e:
118
+ _LOGGER.warning(f"Could not parse 'class_map': {e}")
119
+ map_labels = None
120
+ map_display_labels = None
121
+
94
122
  # Generate report as both text and dictionary
95
- report_text: str = classification_report(y_true, y_pred) # type: ignore
96
- report_dict: dict = classification_report(y_true, y_pred, output_dict=True) # type: ignore
97
- print(report_text)
123
+ report_text: str = classification_report(y_true, y_pred, labels=map_labels, target_names=map_display_labels) # type: ignore
124
+ report_dict: dict = classification_report(y_true, y_pred, output_dict=True, labels=map_labels, target_names=map_display_labels) # type: ignore
125
+ # print(report_text)
98
126
 
99
127
  save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
100
128
  # Save text report
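The parsing above only relies on sorting the mapping by its index values; a standalone illustration with made-up labels:

class_map = {"cat": 1, "dog": 0, "bird": 2}
sorted_items = sorted(class_map.items(), key=lambda item: item[1])
map_labels = [item[1] for item in sorted_items]          # [0, 1, 2] -> passed as labels=
map_display_labels = [item[0] for item in sorted_items]  # ['dog', 'cat', 'bird'] -> passed as target_names=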
@@ -104,8 +132,15 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre
104
132
 
105
133
  # --- Save Classification Report Heatmap ---
106
134
  try:
107
- plt.figure(figsize=(8, 6), dpi=100)
108
- sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T, annot=True, cmap='viridis', fmt='.2f')
135
+ plt.figure(figsize=(8, 6), dpi=DPI_value)
136
+ sns.set_theme(font_scale=1.2) # Scale seaborn font
137
+ sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T,
138
+ annot=True,
139
+ cmap=cmap,
140
+ fmt='.2f',
141
+ vmin=0.0,
142
+ vmax=1.0)
143
+ sns.set_theme(font_scale=1.0) # Reset seaborn scale
109
144
  plt.title("Classification Report")
110
145
  plt.tight_layout()
111
146
  heatmap_path = save_dir_path / "classification_report_heatmap.svg"
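Pinning the heatmap to vmin=0.0 / vmax=1.0 keeps the colour scale comparable across runs, since every remaining cell of the report is a score in [0, 1] (the support row is dropped by iloc[:-1]). A standalone sketch of the same pattern using only scikit-learn, pandas, and seaborn:

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report

report = classification_report([0, 1, 1, 0, 1], [0, 1, 0, 0, 1], output_dict=True)
sns.heatmap(pd.DataFrame(report).iloc[:-1, :].T, annot=True, fmt=".2f", cmap="Blues", vmin=0.0, vmax=1.0)
plt.title("Classification Report")
plt.tight_layout()
plt.savefig("classification_report_heatmap.svg")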
@@ -114,69 +149,179 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre
114
149
  plt.close()
115
150
  except Exception as e:
116
151
  _LOGGER.error(f"Could not generate classification report heatmap: {e}")
117
-
152
+
153
+ # --- labels for Confusion Matrix ---
154
+ plot_labels = map_labels
155
+ plot_display_labels = map_display_labels
156
+
118
157
  # Save Confusion Matrix
119
- fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
120
- ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap, ax=ax_cm)
158
+ fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
159
+ disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
160
+ y_pred,
161
+ cmap=cmap,
162
+ ax=ax_cm,
163
+ normalize='true',
164
+ labels=plot_labels,
165
+ display_labels=plot_display_labels)
166
+
167
+ disp_.im_.set_clim(vmin=0.0, vmax=1.0)
168
+
169
+ # Turn off gridlines
170
+ ax_cm.grid(False)
171
+
172
+ # Manually update font size of cell texts
173
+ for text in ax_cm.texts:
174
+ text.set_fontsize(font_size)
175
+
176
+ fig_cm.tight_layout()
177
+
121
178
  ax_cm.set_title("Confusion Matrix")
122
179
  cm_path = save_dir_path / "confusion_matrix.svg"
123
180
  plt.savefig(cm_path)
124
181
  _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
125
182
  plt.close(fig_cm)
126
183
 
127
- # Plotting logic for ROC and PR Curves
128
- if y_prob is not None and y_prob.ndim > 1 and y_prob.shape[1] >= 2:
129
- # Use probabilities of the positive class
130
- y_score = y_prob[:, 1]
184
+
185
+ # Plotting logic for ROC, PR, and Calibration Curves
186
+ if y_prob is not None and y_prob.ndim == 2:
187
+ num_classes = y_prob.shape[1]
131
188
 
132
- # --- Save ROC Curve ---
133
- fpr, tpr, _ = roc_curve(y_true, y_score)
134
- auc = roc_auc_score(y_true, y_score)
135
- fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
136
- ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
137
- ax_roc.plot([0, 1], [0, 1], 'k--')
138
- ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
139
- ax_roc.set_xlabel('False Positive Rate')
140
- ax_roc.set_ylabel('True Positive Rate')
141
- ax_roc.legend(loc='lower right')
142
- ax_roc.grid(True)
143
- roc_path = save_dir_path / "roc_curve.svg"
144
- plt.savefig(roc_path)
145
- _LOGGER.info(f"📈 ROC curve saved as '{roc_path.name}'")
146
- plt.close(fig_roc)
147
-
148
- # --- Save Precision-Recall Curve ---
149
- precision, recall, _ = precision_recall_curve(y_true, y_score)
150
- ap_score = average_precision_score(y_true, y_score)
151
- fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
152
- ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
153
- ax_pr.set_title('Precision-Recall Curve')
154
- ax_pr.set_xlabel('Recall')
155
- ax_pr.set_ylabel('Precision')
156
- ax_pr.legend(loc='lower left')
157
- ax_pr.grid(True)
158
- pr_path = save_dir_path / "pr_curve.svg"
159
- plt.savefig(pr_path)
160
- _LOGGER.info(f"📈 PR curve saved as '{pr_path.name}'")
161
- plt.close(fig_pr)
189
+ # --- Determine which classes to loop over ---
190
+ class_indices_to_plot = []
191
+ plot_titles = []
192
+ save_suffixes = []
193
+
194
+ if num_classes == 2:
195
+ # Binary case: Only plot for the positive class (index 1)
196
+ class_indices_to_plot = [1]
197
+ plot_titles = [""] # No extra title
198
+ save_suffixes = [""] # No extra suffix
199
+ _LOGGER.info("Generating binary classification plots (ROC, PR, Calibration).")
162
200
 
163
- # --- Save Calibration Plot ---
164
- if y_prob.ndim > 1 and y_prob.shape[1] >= 2:
165
- y_score = y_prob[:, 1] # Use probabilities of the positive class
201
+ elif num_classes > 2:
202
+ _LOGGER.info(f"Generating One-vs-Rest plots for {num_classes} classes.")
203
+ # Multiclass case: Plot for every class (One-vs-Rest)
204
+ class_indices_to_plot = list(range(num_classes))
205
+
206
+ # --- Use class_map names if available ---
207
+ use_generic_names = True
208
+ if map_display_labels and len(map_display_labels) == num_classes:
209
+ try:
210
+ # Ensure labels are safe for filenames
211
+ safe_names = [sanitize_filename(name) for name in map_display_labels]
212
+ plot_titles = [f" ({name} vs. Rest)" for name in map_display_labels]
213
+ save_suffixes = [f"_{safe_names[i]}" for i in class_indices_to_plot]
214
+ use_generic_names = False
215
+ except Exception as e:
216
+ _LOGGER.warning(f"Failed to use 'class_map' for plot titles: {e}. Reverting to generic names.")
217
+ use_generic_names = True
218
+
219
+ if use_generic_names:
220
+ plot_titles = [f" (Class {i} vs. Rest)" for i in class_indices_to_plot]
221
+ save_suffixes = [f"_class_{i}" for i in class_indices_to_plot]
222
+
223
+ else:
224
+ # Should not happen, but good to check
225
+ _LOGGER.warning(f"Probability array has invalid shape {y_prob.shape}. Skipping ROC/PR/Calibration plots.")
226
+
227
+ # --- Loop and generate plots ---
228
+ for i, class_index in enumerate(class_indices_to_plot):
229
+ plot_title = plot_titles[i]
230
+ save_suffix = save_suffixes[i]
231
+
232
+ # Get scores for the current class
233
+ y_score = y_prob[:, class_index]
234
+
235
+ # Binarize y_true for the current class
236
+ y_true_binary = (y_true == class_index).astype(int)
237
+
238
+ # --- Save ROC Curve ---
239
+ fpr, tpr, _ = roc_curve(y_true_binary, y_score)
240
+
241
+ # Calculate AUC.
242
+ # Note: For multiclass, roc_auc_score(y_true, y_prob, multi_class='ovr') could average, but plotting individual curves is more informative.
243
+ # Here we calculate the specific AUC for the binarized problem.
244
+ auc = roc_auc_score(y_true_binary, y_score)
166
245
 
167
- fig_cal, ax_cal = plt.subplots(figsize=(8, 8), dpi=100)
168
- CalibrationDisplay.from_predictions(y_true, y_score, n_bins=15, ax=ax_cal)
246
+ fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
247
+ ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
248
+ ax_roc.plot([0, 1], [0, 1], 'k--')
249
+ ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
250
+ ax_roc.set_xlabel('False Positive Rate')
251
+ ax_roc.set_ylabel('True Positive Rate')
252
+ ax_roc.legend(loc='lower right')
253
+ ax_roc.grid(True)
254
+ roc_path = save_dir_path / f"roc_curve{save_suffix}.svg"
255
+ plt.savefig(roc_path)
256
+ plt.close(fig_roc)
257
+
258
+ # --- Save Precision-Recall Curve ---
259
+ precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
260
+ ap_score = average_precision_score(y_true_binary, y_score)
261
+ fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
262
+ ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=ROC_PR_line)
263
+ ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
264
+ ax_pr.set_xlabel('Recall')
265
+ ax_pr.set_ylabel('Precision')
266
+ ax_pr.legend(loc='lower left')
267
+ ax_pr.grid(True)
268
+ pr_path = save_dir_path / f"pr_curve{save_suffix}.svg"
269
+ plt.savefig(pr_path)
270
+ plt.close(fig_pr)
169
271
 
170
- ax_cal.set_title('Reliability Curve')
272
+ # --- Save Calibration Plot ---
273
+ fig_cal, ax_cal = plt.subplots(figsize=(8, 8), dpi=DPI_value)
274
+
275
+ # --- Step 1: Get binned data *without* plotting ---
276
+ with plt.ioff(): # Suppress showing the temporary plot
277
+ fig_temp, ax_temp = plt.subplots()
278
+ cal_display_temp = CalibrationDisplay.from_predictions(
279
+ y_true_binary, # Use binarized labels
280
+ y_score,
281
+ n_bins=calibration_bins,
282
+ ax=ax_temp,
283
+ name="temp" # Add a name to suppress potential warnings
284
+ )
285
+ # Get the x, y coordinates of the binned data
286
+ line_x, line_y = cal_display_temp.line_.get_data() # type: ignore
287
+ plt.close(fig_temp) # Close the temporary plot
288
+
289
+ # --- Step 2: Build the plot from scratch ---
290
+ ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
291
+
292
+ sns.regplot(
293
+ x=line_x,
294
+ y=line_y,
295
+ ax=ax_cal,
296
+ scatter=False,
297
+ label=f"Calibration Curve ({calibration_bins} bins)",
298
+ line_kws={
299
+ 'color': ROC_PR_line,
300
+ 'linestyle': '--',
301
+ 'linewidth': 2,
302
+ }
303
+ )
304
+
305
+ ax_cal.set_title(f'Reliability Curve{plot_title}')
171
306
  ax_cal.set_xlabel('Mean Predicted Probability')
172
307
  ax_cal.set_ylabel('Fraction of Positives')
308
+
309
+ # --- Step 3: Set final limits *after* plotting ---
310
+ ax_cal.set_ylim(0.0, 1.0)
311
+ ax_cal.set_xlim(0.0, 1.0)
312
+
313
+ ax_cal.legend(loc='lower right')
173
314
  ax_cal.grid(True)
174
315
  plt.tight_layout()
175
316
 
176
- cal_path = save_dir_path / "calibration_plot.svg"
317
+ cal_path = save_dir_path / f"calibration_plot{save_suffix}.svg"
177
318
  plt.savefig(cal_path)
178
- _LOGGER.info(f"📈 Calibration plot saved as '{cal_path.name}'")
179
319
  plt.close(fig_cal)
320
+
321
+ _LOGGER.info(f"📈 Saved {len(class_indices_to_plot)} sets of ROC, Precision-Recall, and Calibration plots.")
322
+
323
+ # restore RC params
324
+ plt.rcParams.update(original_rc_params)
180
325
 
181
326
 
182
327
  def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[str, Path]):
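The loop above reduces the multiclass problem to one binary problem per class; a standalone sketch of that One-vs-Rest reduction with scikit-learn (illustrative data only, class index 2 of a 3-class problem):

import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

y_true = np.array([0, 2, 1, 2, 0, 1])
y_prob = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.2, 0.7],
                   [0.2, 0.6, 0.2],
                   [0.2, 0.3, 0.5],
                   [0.8, 0.1, 0.1],
                   [0.3, 0.5, 0.2]])

class_index = 2
y_score = y_prob[:, class_index]                      # probability of "class 2"
y_true_binary = (y_true == class_index).astype(int)   # 1 = class 2, 0 = everything else

print(roc_auc_score(y_true_binary, y_score))            # per-class AUC, as plotted above
print(average_precision_score(y_true_binary, y_score))  # per-class AP, as plotted above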
@@ -211,7 +356,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
211
356
 
212
357
  # Save residual plot
213
358
  residuals = y_true - y_pred
214
- fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
359
+ fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
215
360
  ax_res.scatter(y_pred, residuals, alpha=0.6)
216
361
  ax_res.axhline(0, color='red', linestyle='--')
217
362
  ax_res.set_xlabel("Predicted Values")
@@ -225,7 +370,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
225
370
  plt.close(fig_res)
226
371
 
227
372
  # Save true vs predicted plot
228
- fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
373
+ fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
229
374
  ax_tvp.scatter(y_true, y_pred, alpha=0.6)
230
375
  ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'k--', lw=2)
231
376
  ax_tvp.set_xlabel('True Values')
@@ -239,7 +384,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
239
384
  plt.close(fig_tvp)
240
385
 
241
386
  # Save Histogram of Residuals
242
- fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=100)
387
+ fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=DPI_value)
243
388
  sns.histplot(residuals, kde=True, ax=ax_hist)
244
389
  ax_hist.set_xlabel("Residual Value")
245
390
  ax_hist.set_ylabel("Frequency")
@@ -258,7 +403,7 @@ def shap_summary_plot(model,
258
403
  feature_names: Optional[list[str]],
259
404
  save_dir: Union[str, Path],
260
405
  device: torch.device = torch.device('cpu'),
261
- explainer_type: Literal['deep', 'kernel'] = 'deep'):
406
+ explainer_type: Literal['deep', 'kernel'] = 'kernel'):
262
407
  """
263
408
  Calculates SHAP values and saves summary plots and data.
264
409
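The default explainer changes from 'deep' to 'kernel' here. A hypothetical call that keeps the previous DeepExplainer behaviour; the toy model and random tensors are placeholders, and the keyword names background_data / instances_to_explain are taken from the variables used in the function body shown in this diff:

import torch
from ml_tools.ML_evaluation import shap_summary_plot

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
shap_summary_plot(
    model=model,
    background_data=torch.randn(64, 4),
    instances_to_explain=torch.randn(16, 4),
    feature_names=["f0", "f1", "f2", "f3"],
    save_dir="outputs/shap",
    explainer_type="deep",   # pass explicitly; 'kernel' is now the default
)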
 
@@ -270,13 +415,13 @@ def shap_summary_plot(model,
270
415
  save_dir (str | Path): Directory to save SHAP artifacts.
271
416
  device (torch.device): The torch device for SHAP calculations.
272
417
  explainer_type (Literal['deep', 'kernel']): The explainer to use.
273
- - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient for
418
+ - 'deep': Uses shap.DeepExplainer. Fast and efficient for
274
419
  PyTorch models.
275
420
  - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
276
421
  slow and memory-intensive.
277
422
  """
278
423
 
279
- print(f"\n--- SHAP Value Explanation Using {explainer_type.upper()} Explainer ---")
424
+ _LOGGER.info(f"📊 Running SHAP Value Explanation Using {explainer_type.upper()} Explainer")
280
425
 
281
426
  model.eval()
282
427
  # model.cpu() # Run explanations on CPU
@@ -285,7 +430,7 @@ def shap_summary_plot(model,
285
430
  instances_to_explain_np = None
286
431
 
287
432
  if explainer_type == 'deep':
288
- # --- 1. Use DeepExplainer (Preferred) ---
433
+ # --- 1. Use DeepExplainer ---
289
434
 
290
435
  # Ensure data is torch.Tensor
291
436
  if isinstance(background_data, np.ndarray):
@@ -309,10 +454,9 @@ def shap_summary_plot(model,
309
454
  instances_to_explain_np = instances_to_explain.cpu().numpy()
310
455
 
311
456
  elif explainer_type == 'kernel':
312
- # --- 2. Use KernelExplainer (Slow Fallback) ---
457
+ # --- 2. Use KernelExplainer ---
313
458
  _LOGGER.warning(
314
- "Using KernelExplainer. This is memory-intensive and slow. "
315
- "Consider reducing 'n_samples' if the process terminates unexpectedly."
459
+ "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
316
460
  )
317
461
 
318
462
  # Ensure data is np.ndarray
@@ -348,14 +492,26 @@ def shap_summary_plot(model,
348
492
  else:
349
493
  _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
350
494
  raise ValueError()
495
+
496
+ if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1: # type: ignore
497
+ # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
498
+ shap_values = shap_values.squeeze(-1) # type: ignore
351
499
 
352
500
  # --- 3. Plotting and Saving ---
353
501
  save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
354
502
  plt.ioff()
355
503
 
504
+ # Convert instances to a DataFrame. robust way to ensure SHAP correctly maps values to feature names.
505
+ if feature_names is None:
506
+ # Create generic names if none were provided
507
+ num_features = instances_to_explain_np.shape[1]
508
+ feature_names = [f'feature_{i}' for i in range(num_features)]
509
+
510
+ instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)
511
+
356
512
  # Save Bar Plot
357
513
  bar_path = save_dir_path / "shap_bar_plot.svg"
358
- shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
514
+ shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
359
515
  ax = plt.gca()
360
516
  ax.set_xlabel("SHAP Value Impact", labelpad=10)
361
517
  plt.title("SHAP Feature Importance")
@@ -366,7 +522,7 @@ def shap_summary_plot(model,
366
522
 
367
523
  # Save Dot Plot
368
524
  dot_path = save_dir_path / "shap_dot_plot.svg"
369
- shap.summary_plot(shap_values, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
525
+ shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
370
526
  ax = plt.gca()
371
527
  ax.set_xlabel("SHAP Value Impact", labelpad=10)
372
528
  if plt.gcf().axes and len(plt.gcf().axes) > 1:
@@ -389,9 +545,6 @@ def shap_summary_plot(model,
389
545
  mean_abs_shap = np.abs(shap_values).mean(axis=0)
390
546
 
391
547
  mean_abs_shap = mean_abs_shap.flatten()
392
-
393
- if feature_names is None:
394
- feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
395
548
 
396
549
  summary_df = pd.DataFrame({
397
550
  SHAPKeys.FEATURE_COLUMN: feature_names,
@@ -401,7 +554,7 @@ def shap_summary_plot(model,
401
554
  summary_df.to_csv(summary_path, index=False)
402
555
 
403
556
  _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
404
- plt.ion()
557
+ plt.ion()
405
558
 
406
559
 
407
560
  def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
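A sketch of the summary CSV assembled across the two hunks above (random numbers; the real column names come from SHAPKeys in ml_tools/keys.py, so the literal strings below are assumptions):

import numpy as np
import pandas as pd

shap_values = np.random.rand(16, 4)                       # already squeezed to (N, F)
mean_abs_shap = np.abs(shap_values).mean(axis=0).flatten()
summary_df = pd.DataFrame({
    "feature": [f"feature_{i}" for i in range(4)],        # SHAPKeys.FEATURE_COLUMN in the package
    "mean_abs_shap": mean_abs_shap,                       # second SHAPKeys column (name not shown in this diff)
})
summary_df.to_csv("shap_summary.csv", index=False)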
@@ -447,7 +600,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
447
600
  # --- Step 3: Create and save the plot for top N features ---
448
601
  plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)
449
602
 
450
- plt.figure(figsize=(10, 8), dpi=100)
603
+ plt.figure(figsize=(10, 8), dpi=DPI_value)
451
604
 
452
605
  # Create horizontal bar plot with error bars
453
606
  plt.barh(
@@ -34,6 +34,8 @@ __all__ = [
34
34
  "multi_target_shap_summary_plot",
35
35
  ]
36
36
 
37
+ DPI_value = 250
38
+
37
39
 
38
40
  def multi_target_regression_metrics(
39
41
  y_true: np.ndarray,
@@ -90,7 +92,7 @@ def multi_target_regression_metrics(
90
92
 
91
93
  # --- Save Residual Plot ---
92
94
  residuals = true_i - pred_i
93
- fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
95
+ fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
94
96
  ax_res.scatter(pred_i, residuals, alpha=0.6, edgecolors='k', s=50)
95
97
  ax_res.axhline(0, color='red', linestyle='--')
96
98
  ax_res.set_xlabel("Predicted Values")
@@ -103,7 +105,7 @@ def multi_target_regression_metrics(
103
105
  plt.close(fig_res)
104
106
 
105
107
  # --- Save True vs. Predicted Plot ---
106
- fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
108
+ fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
107
109
  ax_tvp.scatter(true_i, pred_i, alpha=0.6, edgecolors='k', s=50)
108
110
  ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()], 'k--', lw=2)
109
111
  ax_tvp.set_xlabel('True Values')
@@ -127,7 +129,10 @@ def multi_label_classification_metrics(
127
129
  y_prob: np.ndarray,
128
130
  target_names: List[str],
129
131
  save_dir: Union[str, Path],
130
- threshold: float = 0.5
132
+ threshold: float = 0.5,
133
+ ROC_PR_line: str='darkorange',
134
+ cmap: str = "Blues",
135
+ font_size: int = 16
131
136
  ):
132
137
  """
133
138
  Calculates and saves classification metrics for each label individually.
@@ -158,6 +163,10 @@ def multi_label_classification_metrics(
158
163
 
159
164
  # Generate binary predictions from probabilities
160
165
  y_pred = (y_prob >= threshold).astype(int)
166
+
167
+ # --- Save current RC params and update font size ---
168
+ original_rc_params = plt.rcParams.copy()
169
+ plt.rcParams.update({'font.size': font_size})
161
170
 
162
171
  _LOGGER.info("--- Multi-Label Classification Evaluation ---")
163
172
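A standalone sketch of the two additions above: per-label predictions come from thresholding the probability matrix, and the font size is applied by saving, updating, and later restoring matplotlib's rcParams (random data, three labels):

import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
y_prob = rng.random((8, 3))               # (n_samples, n_labels) probabilities
y_pred = (y_prob >= 0.5).astype(int)      # per-label binary predictions

original_rc_params = plt.rcParams.copy()  # save ...
plt.rcParams.update({"font.size": 16})    # ... enlarge fonts for the per-label plots ...
# ... per-label plotting happens here ...
plt.rcParams.update(original_rc_params)   # ... and restore afterwards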
 
@@ -174,7 +183,7 @@ def multi_label_classification_metrics(
174
183
  f"Jaccard Score (macro): {j_score_macro:.4f}\n"
175
184
  f"--------------------------------------------------\n"
176
185
  )
177
- print(overall_report)
186
+ # print(overall_report)
178
187
  overall_report_path = save_dir_path / "classification_report_overall.txt"
179
188
  overall_report_path.write_text(overall_report)
180
189
 
@@ -192,8 +201,26 @@ def multi_label_classification_metrics(
192
201
  report_path.write_text(report_text) # type: ignore
193
202
 
194
203
  # --- Save Confusion Matrix ---
195
- fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
196
- ConfusionMatrixDisplay.from_predictions(true_i, pred_i, cmap="Blues", ax=ax_cm)
204
+ fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
205
+ disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
206
+ pred_i,
207
+ cmap=cmap,
208
+ ax=ax_cm,
209
+ normalize='true',
210
+ labels=[0, 1],
211
+ display_labels=["Negative", "Positive"])
212
+
213
+ disp_.im_.set_clim(vmin=0.0, vmax=1.0)
214
+
215
+ # Turn off gridlines
216
+ ax_cm.grid(False)
217
+
218
+ # Manually update font size of cell texts
219
+ for text in ax_cm.texts:
220
+ text.set_fontsize(font_size)
221
+
222
+ fig_cm.tight_layout()
223
+
197
224
  ax_cm.set_title(f"Confusion Matrix for '{name}'")
198
225
  cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
199
226
  plt.savefig(cm_path)
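A standalone sketch of the per-label confusion matrix configured above: normalize='true' makes each row sum to 1, so the colour range can be clamped to [0, 1] and matrices stay comparable across labels (toy data):

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

true_i = np.array([0, 1, 1, 0, 1, 0, 1, 1])
pred_i = np.array([0, 1, 0, 0, 1, 1, 1, 1])

fig, ax = plt.subplots(figsize=(6, 6))
disp = ConfusionMatrixDisplay.from_predictions(
    true_i, pred_i,
    normalize="true",
    labels=[0, 1],
    display_labels=["Negative", "Positive"],
    cmap="Blues",
    ax=ax,
)
disp.im_.set_clim(vmin=0.0, vmax=1.0)   # fix the colour range across labels
ax.grid(False)
fig.savefig("confusion_matrix_example.svg")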
@@ -202,8 +229,8 @@ def multi_label_classification_metrics(
202
229
  # --- Save ROC Curve ---
203
230
  fpr, tpr, _ = roc_curve(true_i, prob_i)
204
231
  auc = roc_auc_score(true_i, prob_i)
205
- fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
206
- ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
232
+ fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
233
+ ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
207
234
  ax_roc.plot([0, 1], [0, 1], 'k--')
208
235
  ax_roc.set_title(f'ROC Curve for "{name}"')
209
236
  ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
@@ -215,14 +242,17 @@ def multi_label_classification_metrics(
215
242
  # --- Save Precision-Recall Curve ---
216
243
  precision, recall, _ = precision_recall_curve(true_i, prob_i)
217
244
  ap_score = average_precision_score(true_i, prob_i)
218
- fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
219
- ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
245
+ fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
246
+ ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=ROC_PR_line)
220
247
  ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
221
248
  ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
222
249
  ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)
223
250
  pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
224
251
  plt.savefig(pr_path)
225
252
  plt.close(fig_pr)
253
+
254
+ # restore RC params
255
+ plt.rcParams.update(original_rc_params)
226
256
 
227
257
  _LOGGER.info(f"All individual label reports and plots saved to '{save_dir_path.name}'")
228
258
 
@@ -235,7 +265,7 @@ def multi_target_shap_summary_plot(
235
265
  target_names: List[str],
236
266
  save_dir: Union[str, Path],
237
267
  device: torch.device = torch.device('cpu'),
238
- explainer_type: Literal['deep', 'kernel'] = 'deep'
268
+ explainer_type: Literal['deep', 'kernel'] = 'kernel'
239
269
  ):
240
270
  """
241
271
  Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
@@ -249,7 +279,7 @@ def multi_target_shap_summary_plot(
249
279
  save_dir (str | Path): Directory to save SHAP artifacts.
250
280
  device (torch.device): The torch device for SHAP calculations.
251
281
  explainer_type (Literal['deep', 'kernel']): The explainer to use.
252
- - 'deep': (Default) Uses shap.DeepExplainer. Fast and efficient.
282
+ - 'deep': Uses shap.DeepExplainer. Fast and efficient.
253
283
  - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
254
284
  """
255
285
  _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
@@ -260,7 +290,7 @@ def multi_target_shap_summary_plot(
260
290
  instances_to_explain_np = None
261
291
 
262
292
  if explainer_type == 'deep':
263
- # --- 1. Use DeepExplainer (Preferred) ---
293
+ # --- 1. Use DeepExplainer ---
264
294
 
265
295
  # Ensure data is torch.Tensor
266
296
  if isinstance(background_data, np.ndarray):
@@ -285,10 +315,9 @@ def multi_target_shap_summary_plot(
285
315
  instances_to_explain_np = instances_to_explain.cpu().numpy()
286
316
 
287
317
  elif explainer_type == 'kernel':
288
- # --- 2. Use KernelExplainer (Slow Fallback) ---
318
+ # --- 2. Use KernelExplainer ---
289
319
  _LOGGER.warning(
290
- "Using KernelExplainer. This is memory-intensive and slow. "
291
- "Consider reducing 'n_samples' if the process terminates."
320
+ "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
292
321
  )
293
322
 
294
323
  # Convert all data to numpy
ml_tools/ML_inference.py CHANGED
@@ -82,7 +82,6 @@ class _BaseInferenceHandler(ABC):
82
82
  _LOGGER.warning("CUDA not available, switching to CPU.")
83
83
  device_lower = "cpu"
84
84
  elif device_lower == "mps" and not torch.backends.mps.is_available():
85
- # Your M-series Mac will appreciate this check!
86
85
  _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
87
86
  device_lower = "cpu"
88
87
  return torch.device(device_lower)
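The removed line here is only a comment; the surrounding validation logic is unchanged. For reference, a minimal standalone sketch of the same CUDA/MPS fallback (the function name resolve_device is hypothetical, not the handler's actual method, and plain print stands in for the package logger):

import torch

def resolve_device(device: str) -> torch.device:
    device_lower = device.lower()
    if device_lower == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, switching to CPU.")
        device_lower = "cpu"
    elif device_lower == "mps" and not torch.backends.mps.is_available():
        print("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
        device_lower = "cpu"
    return torch.device(device_lower)

print(resolve_device("mps"))  # falls back to cpu on machines without MPS support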