dragon-ml-toolbox 13.3.0__py3-none-any.whl → 14.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/METADATA +12 -2
- dragon_ml_toolbox-14.7.0.dist-info/RECORD +49 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_configuration.py +108 -0
- ml_tools/ML_datasetmaster.py +106 -206
- ml_tools/ML_evaluation.py +229 -76
- ml_tools/ML_evaluation_multi.py +45 -16
- ml_tools/ML_inference.py +0 -1
- ml_tools/ML_models.py +22 -6
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_trainer.py +498 -29
- ml_tools/ML_utilities.py +351 -4
- ml_tools/ML_vision_datasetmaster.py +1492 -0
- ml_tools/ML_vision_evaluation.py +260 -0
- ml_tools/ML_vision_inference.py +428 -0
- ml_tools/ML_vision_models.py +641 -0
- ml_tools/ML_vision_transformers.py +203 -0
- ml_tools/_ML_vision_recipe.py +88 -0
- ml_tools/custom_logger.py +37 -14
- ml_tools/data_exploration.py +502 -93
- ml_tools/ensemble_evaluation.py +53 -10
- ml_tools/keys.py +39 -0
- ml_tools/math_utilities.py +1 -1
- ml_tools/serde.py +2 -2
- ml_tools/utilities.py +192 -3
- dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation.py
CHANGED
@@ -21,10 +21,10 @@ from pathlib import Path
 from typing import Union, Optional, List, Literal
 import warnings

-from .path_manager import make_fullpath
+from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .keys import SHAPKeys
+from .keys import SHAPKeys, PyTorchLogKeys


 __all__ = [
@@ -35,6 +35,8 @@ __all__ = [
     "plot_attention_importance"
 ]

+DPI_value = 250
+

 def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
@@ -44,14 +46,14 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
         history (dict): A dictionary containing 'train_loss' and 'val_loss'.
         save_dir (str | Path): Directory to save the plot image.
     """
-    train_loss = history.get(
-    val_loss = history.get(
+    train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
+    val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])

     if not train_loss and not val_loss:
-
+        _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
         return

-    fig, ax = plt.subplots(figsize=(10, 5), dpi=
+    fig, ax = plt.subplots(figsize=(10, 5), dpi=DPI_value)

     # Plot training loss only if data for it exists
     if train_loss:
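Usage note (not part of the diff): plot_losses now reads the loss history through the PyTorchLogKeys constants instead of hard-coded strings, so any history dictionary passed to it should be keyed with those constants. A minimal sketch, assuming the package is importable as ml_tools (the wheel's top-level package); the loss values and output path are made up:

    from ml_tools.keys import PyTorchLogKeys
    from ml_tools.ML_evaluation import plot_losses

    # Hypothetical history; in practice a training loop (e.g., in ML_trainer.py) would populate this dict.
    history = {
        PyTorchLogKeys.TRAIN_LOSS: [0.92, 0.61, 0.48, 0.41],
        PyTorchLogKeys.VAL_LOSS: [0.95, 0.70, 0.58, 0.55],
    }
    plot_losses(history, save_dir="outputs/plots")  # figure is rendered at the module-level DPI_value (250)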
@@ -78,8 +80,15 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
     plt.close(fig)


-def classification_metrics(save_dir: Union[str, Path],
-
+def classification_metrics(save_dir: Union[str, Path],
+                           y_true: np.ndarray,
+                           y_pred: np.ndarray,
+                           y_prob: Optional[np.ndarray] = None,
+                           cmap: str = "Blues",
+                           class_map: Optional[dict[str,int]]=None,
+                           ROC_PR_line: str='darkorange',
+                           calibration_bins: int=15,
+                           font_size: int=16):
     """
     Saves classification metrics and plots.

@@ -89,12 +98,31 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre
         y_prob (np.ndarray, optional): Predicted probabilities for ROC curve.
         cmap (str): Colormap for the confusion matrix.
         save_dir (str | Path): Directory to save plots.
+        class_map (dict[str, int], None): A map of {class_name: index} used to order and label the confusion matrix.
     """
-
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': font_size})
+
+    # print("--- Classification Report ---")
+
+    # --- Parse class_map ---
+    map_labels = None
+    map_display_labels = None
+    if class_map:
+        # Sort the map by its values (the indices) to ensure correct order
+        try:
+            sorted_items = sorted(class_map.items(), key=lambda item: item[1])
+            map_labels = [item[1] for item in sorted_items]
+            map_display_labels = [item[0] for item in sorted_items]
+        except Exception as e:
+            _LOGGER.warning(f"Could not parse 'class_map': {e}")
+            map_labels = None
+            map_display_labels = None
+
     # Generate report as both text and dictionary
-    report_text: str = classification_report(y_true, y_pred) # type: ignore
-    report_dict: dict = classification_report(y_true, y_pred, output_dict=True) # type: ignore
-    print(report_text)
+    report_text: str = classification_report(y_true, y_pred, labels=map_labels, target_names=map_display_labels) # type: ignore
+    report_dict: dict = classification_report(y_true, y_pred, output_dict=True, labels=map_labels, target_names=map_display_labels) # type: ignore
+    # print(report_text)

     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     # Save text report
@@ -104,8 +132,15 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre

     # --- Save Classification Report Heatmap ---
     try:
-        plt.figure(figsize=(8, 6), dpi=
-        sns.
+        plt.figure(figsize=(8, 6), dpi=DPI_value)
+        sns.set_theme(font_scale=1.2) # Scale seaborn font
+        sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T,
+                    annot=True,
+                    cmap=cmap,
+                    fmt='.2f',
+                    vmin=0.0,
+                    vmax=1.0)
+        sns.set_theme(font_scale=1.0) # Reset seaborn scale
         plt.title("Classification Report")
         plt.tight_layout()
         heatmap_path = save_dir_path / "classification_report_heatmap.svg"
@@ -114,69 +149,179 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre
         plt.close()
     except Exception as e:
         _LOGGER.error(f"Could not generate classification report heatmap: {e}")
-
+
+    # --- labels for Confusion Matrix ---
+    plot_labels = map_labels
+    plot_display_labels = map_display_labels
+
     # Save Confusion Matrix
-    fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=
-    ConfusionMatrixDisplay.from_predictions(y_true,
+    fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+    disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
+                                                    y_pred,
+                                                    cmap=cmap,
+                                                    ax=ax_cm,
+                                                    normalize='true',
+                                                    labels=plot_labels,
+                                                    display_labels=plot_display_labels)
+
+    disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+    # Turn off gridlines
+    ax_cm.grid(False)
+
+    # Manually update font size of cell texts
+    for text in ax_cm.texts:
+        text.set_fontsize(font_size)
+
+    fig_cm.tight_layout()
+
     ax_cm.set_title("Confusion Matrix")
     cm_path = save_dir_path / "confusion_matrix.svg"
     plt.savefig(cm_path)
     _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
     plt.close(fig_cm)

-
-
-
-
+
+    # Plotting logic for ROC, PR, and Calibration Curves
+    if y_prob is not None and y_prob.ndim == 2:
+        num_classes = y_prob.shape[1]

-        # ---
-
-
-
-
-
-
-
-
-
-
-        roc_path = save_dir_path / "roc_curve.svg"
-        plt.savefig(roc_path)
-        _LOGGER.info(f"📈 ROC curve saved as '{roc_path.name}'")
-        plt.close(fig_roc)
-
-        # --- Save Precision-Recall Curve ---
-        precision, recall, _ = precision_recall_curve(y_true, y_score)
-        ap_score = average_precision_score(y_true, y_score)
-        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
-        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
-        ax_pr.set_title('Precision-Recall Curve')
-        ax_pr.set_xlabel('Recall')
-        ax_pr.set_ylabel('Precision')
-        ax_pr.legend(loc='lower left')
-        ax_pr.grid(True)
-        pr_path = save_dir_path / "pr_curve.svg"
-        plt.savefig(pr_path)
-        _LOGGER.info(f"📈 PR curve saved as '{pr_path.name}'")
-        plt.close(fig_pr)
+        # --- Determine which classes to loop over ---
+        class_indices_to_plot = []
+        plot_titles = []
+        save_suffixes = []
+
+        if num_classes == 2:
+            # Binary case: Only plot for the positive class (index 1)
+            class_indices_to_plot = [1]
+            plot_titles = [""] # No extra title
+            save_suffixes = [""] # No extra suffix
+            _LOGGER.info("Generating binary classification plots (ROC, PR, Calibration).")

-
-
-
+        elif num_classes > 2:
+            _LOGGER.info(f"Generating One-vs-Rest plots for {num_classes} classes.")
+            # Multiclass case: Plot for every class (One-vs-Rest)
+            class_indices_to_plot = list(range(num_classes))
+
+            # --- Use class_map names if available ---
+            use_generic_names = True
+            if map_display_labels and len(map_display_labels) == num_classes:
+                try:
+                    # Ensure labels are safe for filenames
+                    safe_names = [sanitize_filename(name) for name in map_display_labels]
+                    plot_titles = [f" ({name} vs. Rest)" for name in map_display_labels]
+                    save_suffixes = [f"_{safe_names[i]}" for i in class_indices_to_plot]
+                    use_generic_names = False
+                except Exception as e:
+                    _LOGGER.warning(f"Failed to use 'class_map' for plot titles: {e}. Reverting to generic names.")
+                    use_generic_names = True
+
+            if use_generic_names:
+                plot_titles = [f" (Class {i} vs. Rest)" for i in class_indices_to_plot]
+                save_suffixes = [f"_class_{i}" for i in class_indices_to_plot]
+
+        else:
+            # Should not happen, but good to check
+            _LOGGER.warning(f"Probability array has invalid shape {y_prob.shape}. Skipping ROC/PR/Calibration plots.")
+
+        # --- Loop and generate plots ---
+        for i, class_index in enumerate(class_indices_to_plot):
+            plot_title = plot_titles[i]
+            save_suffix = save_suffixes[i]
+
+            # Get scores for the current class
+            y_score = y_prob[:, class_index]
+
+            # Binarize y_true for the current class
+            y_true_binary = (y_true == class_index).astype(int)
+
+            # --- Save ROC Curve ---
+            fpr, tpr, _ = roc_curve(y_true_binary, y_score)
+
+            # Calculate AUC.
+            # Note: For multiclass, roc_auc_score(y_true, y_prob, multi_class='ovr') could average, but plotting individual curves is more informative.
+            # Here we calculate the specific AUC for the binarized problem.
+            auc = roc_auc_score(y_true_binary, y_score)

-
-
+            fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+            ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
+            ax_roc.plot([0, 1], [0, 1], 'k--')
+            ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
+            ax_roc.set_xlabel('False Positive Rate')
+            ax_roc.set_ylabel('True Positive Rate')
+            ax_roc.legend(loc='lower right')
+            ax_roc.grid(True)
+            roc_path = save_dir_path / f"roc_curve{save_suffix}.svg"
+            plt.savefig(roc_path)
+            plt.close(fig_roc)
+
+            # --- Save Precision-Recall Curve ---
+            precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
+            ap_score = average_precision_score(y_true_binary, y_score)
+            fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+            ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=ROC_PR_line)
+            ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
+            ax_pr.set_xlabel('Recall')
+            ax_pr.set_ylabel('Precision')
+            ax_pr.legend(loc='lower left')
+            ax_pr.grid(True)
+            pr_path = save_dir_path / f"pr_curve{save_suffix}.svg"
+            plt.savefig(pr_path)
+            plt.close(fig_pr)

-
+            # --- Save Calibration Plot ---
+            fig_cal, ax_cal = plt.subplots(figsize=(8, 8), dpi=DPI_value)
+
+            # --- Step 1: Get binned data *without* plotting ---
+            with plt.ioff(): # Suppress showing the temporary plot
+                fig_temp, ax_temp = plt.subplots()
+                cal_display_temp = CalibrationDisplay.from_predictions(
+                    y_true_binary, # Use binarized labels
+                    y_score,
+                    n_bins=calibration_bins,
+                    ax=ax_temp,
+                    name="temp" # Add a name to suppress potential warnings
+                )
+                # Get the x, y coordinates of the binned data
+                line_x, line_y = cal_display_temp.line_.get_data() # type: ignore
+                plt.close(fig_temp) # Close the temporary plot
+
+            # --- Step 2: Build the plot from scratch ---
+            ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
+
+            sns.regplot(
+                x=line_x,
+                y=line_y,
+                ax=ax_cal,
+                scatter=False,
+                label=f"Calibration Curve ({calibration_bins} bins)",
+                line_kws={
+                    'color': ROC_PR_line,
+                    'linestyle': '--',
+                    'linewidth': 2,
+                }
+            )
+
+            ax_cal.set_title(f'Reliability Curve{plot_title}')
             ax_cal.set_xlabel('Mean Predicted Probability')
             ax_cal.set_ylabel('Fraction of Positives')
+
+            # --- Step 3: Set final limits *after* plotting ---
+            ax_cal.set_ylim(0.0, 1.0)
+            ax_cal.set_xlim(0.0, 1.0)
+
+            ax_cal.legend(loc='lower right')
             ax_cal.grid(True)
             plt.tight_layout()

-            cal_path = save_dir_path / "calibration_plot.svg"
+            cal_path = save_dir_path / f"calibration_plot{save_suffix}.svg"
             plt.savefig(cal_path)
-            _LOGGER.info(f"📈 Calibration plot saved as '{cal_path.name}'")
             plt.close(fig_cal)
+
+        _LOGGER.info(f"📈 Saved {len(class_indices_to_plot)} sets of ROC, Precision-Recall, and Calibration plots.")
+
+    # restore RC params
+    plt.rcParams.update(original_rc_params)


 def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[str, Path]):
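Usage note (not part of the diff): a minimal sketch of the reworked classification_metrics, assuming the package imports as ml_tools; the class names, output directory, and synthetic data below are made up. The class_map values fix the confusion-matrix row/column order, its keys become the display labels, and with more than two classes the function writes one ROC, PR, and calibration plot per class in One-vs-Rest fashion, suffixing filenames with the sanitized class name:

    import numpy as np
    from ml_tools.ML_evaluation import classification_metrics

    rng = np.random.default_rng(0)
    class_map = {"setosa": 0, "versicolor": 1, "virginica": 2}   # hypothetical classes

    y_true = rng.integers(0, 3, size=60)
    logits = rng.normal(size=(60, 3))
    y_prob = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # rows sum to 1
    y_pred = y_prob.argmax(axis=1)

    classification_metrics(
        save_dir="outputs/clf_eval",   # hypothetical output directory
        y_true=y_true,
        y_pred=y_pred,
        y_prob=y_prob,
        class_map=class_map,
        calibration_bins=10,
    )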
@@ -211,7 +356,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s

     # Save residual plot
     residuals = y_true - y_pred
-    fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=
+    fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
     ax_res.scatter(y_pred, residuals, alpha=0.6)
     ax_res.axhline(0, color='red', linestyle='--')
     ax_res.set_xlabel("Predicted Values")
@@ -225,7 +370,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     plt.close(fig_res)

     # Save true vs predicted plot
-    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=
+    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
     ax_tvp.scatter(y_true, y_pred, alpha=0.6)
     ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()], 'k--', lw=2)
     ax_tvp.set_xlabel('True Values')
@@ -239,7 +384,7 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     plt.close(fig_tvp)

     # Save Histogram of Residuals
-    fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=
+    fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=DPI_value)
     sns.histplot(residuals, kde=True, ax=ax_hist)
     ax_hist.set_xlabel("Residual Value")
     ax_hist.set_ylabel("Frequency")
@@ -258,7 +403,7 @@ def shap_summary_plot(model,
                       feature_names: Optional[list[str]],
                       save_dir: Union[str, Path],
                       device: torch.device = torch.device('cpu'),
-                      explainer_type: Literal['deep', 'kernel'] = '
+                      explainer_type: Literal['deep', 'kernel'] = 'kernel'):
     """
     Calculates SHAP values and saves summary plots and data.

@@ -270,13 +415,13 @@ def shap_summary_plot(model,
         save_dir (str | Path): Directory to save SHAP artifacts.
         device (torch.device): The torch device for SHAP calculations.
         explainer_type (Literal['deep', 'kernel']): The explainer to use.
-            - 'deep':
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient for
              PyTorch models.
            - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
              slow and memory-intensive.
     """

-
+    _LOGGER.info(f"📊 Running SHAP Value Explanation Using {explainer_type.upper()} Explainer")

     model.eval()
     # model.cpu() # Run explanations on CPU
@@ -285,7 +430,7 @@ def shap_summary_plot(model,
     instances_to_explain_np = None

     if explainer_type == 'deep':
-        # --- 1. Use DeepExplainer
+        # --- 1. Use DeepExplainer ---

         # Ensure data is torch.Tensor
         if isinstance(background_data, np.ndarray):
@@ -309,10 +454,9 @@ def shap_summary_plot(model,
         instances_to_explain_np = instances_to_explain.cpu().numpy()

     elif explainer_type == 'kernel':
-        # --- 2. Use KernelExplainer
+        # --- 2. Use KernelExplainer ---
         _LOGGER.warning(
-            "
-            "Consider reducing 'n_samples' if the process terminates unexpectedly."
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
         )

         # Ensure data is np.ndarray
@@ -348,14 +492,26 @@ def shap_summary_plot(model,
     else:
         _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
         raise ValueError()
+
+    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1: # type: ignore
+        # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
+        shap_values = shap_values.squeeze(-1) # type: ignore

     # --- 3. Plotting and Saving ---
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     plt.ioff()

+    # Convert instances to a DataFrame. robust way to ensure SHAP correctly maps values to feature names.
+    if feature_names is None:
+        # Create generic names if none were provided
+        num_features = instances_to_explain_np.shape[1]
+        feature_names = [f'feature_{i}' for i in range(num_features)]
+
+    instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)
+
     # Save Bar Plot
     bar_path = save_dir_path / "shap_bar_plot.svg"
-    shap.summary_plot(shap_values,
+    shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
     plt.title("SHAP Feature Importance")
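Note (not part of the diff): shap_summary_plot now defaults to the slower but model-agnostic KernelExplainer, squeezes single-output SHAP arrays before plotting, and wraps the explained instances in a DataFrame so feature names are mapped reliably. A minimal sketch of the squeeze guard in isolation, using made-up shapes:

    import numpy as np

    # Hypothetical SHAP output for a single-output (regression) model: (N samples, F features, 1 output).
    shap_values = np.random.default_rng(0).normal(size=(100, 8, 1))

    # Same check the new code performs before plotting: drop the trailing singleton axis
    # so the summary plots receive the (N, F) layout they expect.
    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
        shap_values = shap_values.squeeze(-1)

    print(shap_values.shape)  # (100, 8)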
@@ -366,7 +522,7 @@ def shap_summary_plot(model,

     # Save Dot Plot
     dot_path = save_dir_path / "shap_dot_plot.svg"
-    shap.summary_plot(shap_values,
+    shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
     if plt.gcf().axes and len(plt.gcf().axes) > 1:
@@ -389,9 +545,6 @@ def shap_summary_plot(model,
     mean_abs_shap = np.abs(shap_values).mean(axis=0)

     mean_abs_shap = mean_abs_shap.flatten()
-
-    if feature_names is None:
-        feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]

     summary_df = pd.DataFrame({
         SHAPKeys.FEATURE_COLUMN: feature_names,
@@ -401,7 +554,7 @@ def shap_summary_plot(model,
     summary_df.to_csv(summary_path, index=False)

     _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
-    plt.ion()
+    plt.ion()


 def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
@@ -447,7 +600,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
     # --- Step 3: Create and save the plot for top N features ---
     plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)

-    plt.figure(figsize=(10, 8), dpi=
+    plt.figure(figsize=(10, 8), dpi=DPI_value)

     # Create horizontal bar plot with error bars
     plt.barh(
ml_tools/ML_evaluation_multi.py
CHANGED
@@ -34,6 +34,8 @@ __all__ = [
     "multi_target_shap_summary_plot",
 ]

+DPI_value = 250
+

 def multi_target_regression_metrics(
     y_true: np.ndarray,
@@ -90,7 +92,7 @@ def multi_target_regression_metrics(

         # --- Save Residual Plot ---
         residuals = true_i - pred_i
-        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=
+        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
         ax_res.scatter(pred_i, residuals, alpha=0.6, edgecolors='k', s=50)
         ax_res.axhline(0, color='red', linestyle='--')
         ax_res.set_xlabel("Predicted Values")
@@ -103,7 +105,7 @@ def multi_target_regression_metrics(
         plt.close(fig_res)

         # --- Save True vs. Predicted Plot ---
-        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=
+        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
         ax_tvp.scatter(true_i, pred_i, alpha=0.6, edgecolors='k', s=50)
         ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()], 'k--', lw=2)
         ax_tvp.set_xlabel('True Values')
@@ -127,7 +129,10 @@ def multi_label_classification_metrics(
     y_prob: np.ndarray,
     target_names: List[str],
     save_dir: Union[str, Path],
-    threshold: float = 0.5
+    threshold: float = 0.5,
+    ROC_PR_line: str='darkorange',
+    cmap: str = "Blues",
+    font_size: int = 16
 ):
     """
     Calculates and saves classification metrics for each label individually.
@@ -158,6 +163,10 @@

     # Generate binary predictions from probabilities
     y_pred = (y_prob >= threshold).astype(int)
+
+    # --- Save current RC params and update font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': font_size})

     _LOGGER.info("--- Multi-Label Classification Evaluation ---")

@@ -174,7 +183,7 @@
         f"Jaccard Score (macro): {j_score_macro:.4f}\n"
         f"--------------------------------------------------\n"
     )
-    print(overall_report)
+    # print(overall_report)
     overall_report_path = save_dir_path / "classification_report_overall.txt"
     overall_report_path.write_text(overall_report)

@@ -192,8 +201,26 @@
         report_path.write_text(report_text) # type: ignore

         # --- Save Confusion Matrix ---
-        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=
-        ConfusionMatrixDisplay.from_predictions(true_i,
+        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
+                                                        pred_i,
+                                                        cmap=cmap,
+                                                        ax=ax_cm,
+                                                        normalize='true',
+                                                        labels=[0, 1],
+                                                        display_labels=["Negative", "Positive"])
+
+        disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+        # Turn off gridlines
+        ax_cm.grid(False)
+
+        # Manually update font size of cell texts
+        for text in ax_cm.texts:
+            text.set_fontsize(font_size)
+
+        fig_cm.tight_layout()
+
         ax_cm.set_title(f"Confusion Matrix for '{name}'")
         cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
         plt.savefig(cm_path)
@@ -202,8 +229,8 @@
         # --- Save ROC Curve ---
         fpr, tpr, _ = roc_curve(true_i, prob_i)
         auc = roc_auc_score(true_i, prob_i)
-        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
+        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
         ax_roc.plot([0, 1], [0, 1], 'k--')
         ax_roc.set_title(f'ROC Curve for "{name}"')
         ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
@@ -215,14 +242,17 @@
         # --- Save Precision-Recall Curve ---
         precision, recall, _ = precision_recall_curve(true_i, prob_i)
         ap_score = average_precision_score(true_i, prob_i)
-        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=
-        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
+        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=ROC_PR_line)
         ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
         ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
         ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)
         pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
         plt.savefig(pr_path)
         plt.close(fig_pr)
+
+    # restore RC params
+    plt.rcParams.update(original_rc_params)

     _LOGGER.info(f"All individual label reports and plots saved to '{save_dir_path.name}'")

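Usage note (not part of the diff): multi_label_classification_metrics gains ROC_PR_line, cmap, and font_size parameters and temporarily overrides Matplotlib's font size while drawing the per-label plots. A minimal sketch with synthetic data, assuming the package imports as ml_tools and that the multi-hot ground-truth array is the first argument (the label names and output path are made up):

    import numpy as np
    from ml_tools.ML_evaluation_multi import multi_label_classification_metrics

    rng = np.random.default_rng(0)
    n_samples = 200
    target_names = ["toxic", "flammable", "corrosive"]      # hypothetical label names

    y_true = rng.integers(0, 2, size=(n_samples, len(target_names)))   # multi-hot ground truth
    y_prob = rng.uniform(size=(n_samples, len(target_names)))          # per-label probabilities

    multi_label_classification_metrics(
        y_true,                          # assumed positional ground-truth argument
        y_prob=y_prob,
        target_names=target_names,
        save_dir="outputs/multilabel_eval",
        threshold=0.5,
        ROC_PR_line="crimson",           # new: color for the ROC and PR curves
        font_size=14,                    # new: font size applied to these plots
    )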
@@ -235,7 +265,7 @@ def multi_target_shap_summary_plot(
     target_names: List[str],
     save_dir: Union[str, Path],
     device: torch.device = torch.device('cpu'),
-    explainer_type: Literal['deep', 'kernel'] = '
+    explainer_type: Literal['deep', 'kernel'] = 'kernel'
 ):
     """
     Calculates SHAP values for a multi-target model and saves summary plots and data for each target.
@@ -249,7 +279,7 @@
         save_dir (str | Path): Directory to save SHAP artifacts.
         device (torch.device): The torch device for SHAP calculations.
         explainer_type (Literal['deep', 'kernel']): The explainer to use.
-            - 'deep':
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient.
             - 'kernel': Uses shap.KernelExplainer. Model-agnostic but slow and memory-intensive.
     """
     _LOGGER.info(f"--- Multi-Target SHAP Value Explanation (Using: {explainer_type.upper()}Explainer) ---")
@@ -260,7 +290,7 @@
     instances_to_explain_np = None

     if explainer_type == 'deep':
-        # --- 1. Use DeepExplainer
+        # --- 1. Use DeepExplainer ---

         # Ensure data is torch.Tensor
         if isinstance(background_data, np.ndarray):
@@ -285,10 +315,9 @@
         instances_to_explain_np = instances_to_explain.cpu().numpy()

     elif explainer_type == 'kernel':
-        # --- 2. Use KernelExplainer
+        # --- 2. Use KernelExplainer ---
         _LOGGER.warning(
-            "
-            "Consider reducing 'n_samples' if the process terminates."
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
         )

         # Convert all data to numpy
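Note (not part of the diff): with KernelExplainer now the default here as well, the updated warning advises explaining fewer instances rather than tuning an 'n_samples' argument. A minimal sketch of subsampling the rows handed to the SHAP helpers (array shapes and names are made up):

    import numpy as np

    rng = np.random.default_rng(42)
    X_test = rng.normal(size=(5000, 12))        # hypothetical feature matrix

    # Keep a small random subset as the instances to explain; KernelExplainer cost
    # grows quickly with the number of explained rows.
    subset_idx = rng.choice(len(X_test), size=200, replace=False)
    instances_to_explain = X_test[subset_idx]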
ml_tools/ML_inference.py
CHANGED
@@ -82,7 +82,6 @@ class _BaseInferenceHandler(ABC):
             _LOGGER.warning("CUDA not available, switching to CPU.")
             device_lower = "cpu"
         elif device_lower == "mps" and not torch.backends.mps.is_available():
-            # Your M-series Mac will appreciate this check!
             _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
             device_lower = "cpu"
         return torch.device(device_lower)