dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their respective public registries.
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
- dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +788 -0
- ml_tools/ML_datasetmaster.py +303 -448
- ml_tools/ML_evaluation.py +351 -93
- ml_tools/ML_evaluation_multi.py +139 -42
- ml_tools/ML_inference.py +290 -209
- ml_tools/ML_models.py +33 -106
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +219 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1604 -179
- ml_tools/ML_utilities.py +351 -4
- ml_tools/ML_vision_datasetmaster.py +1540 -0
- ml_tools/ML_vision_evaluation.py +284 -0
- ml_tools/ML_vision_inference.py +405 -0
- ml_tools/ML_vision_models.py +641 -0
- ml_tools/ML_vision_transformers.py +284 -0
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/_keys.py +171 -0
- ml_tools/_schema.py +1 -1
- ml_tools/custom_logger.py +37 -14
- ml_tools/data_exploration.py +502 -93
- ml_tools/ensemble_evaluation.py +54 -11
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/math_utilities.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/serde.py +2 -2
- ml_tools/utilities.py +192 -4
- dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/keys.py +0 -87
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation.py
CHANGED
```diff
--- ml_tools/ML_evaluation.py (dragon_ml_toolbox 13.3.0)
+++ ml_tools/ML_evaluation.py (dragon_ml_toolbox 16.2.0)
@@ -21,10 +21,17 @@ from pathlib import Path
 from typing import Union, Optional, List, Literal
 import warnings
 
-from .path_manager import make_fullpath
+from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .
+from ._keys import SHAPKeys, PyTorchLogKeys
+from .ML_configuration import (RegressionMetricsFormat,
+                               BinaryClassificationMetricsFormat,
+                               MultiClassClassificationMetricsFormat,
+                               BinaryImageClassificationMetricsFormat,
+                               MultiClassImageClassificationMetricsFormat,
+                               _BaseClassificationFormat,
+                               _BaseRegressionFormat)
 
 
 __all__ = [
@@ -35,40 +42,66 @@ __all__ = [
     "plot_attention_importance"
 ]
 
+DPI_value = 250
+
 
 def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
     Plots training & validation loss curves from a history object.
+    Also plots the learning rate if available in the history.
 
     Args:
         history (dict): A dictionary containing 'train_loss' and 'val_loss'.
         save_dir (str | Path): Directory to save the plot image.
     """
-    train_loss = history.get(
-    val_loss = history.get(
+    train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
+    val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
+    lr_history = history.get(PyTorchLogKeys.LEARNING_RATE, [])
 
     if not train_loss and not val_loss:
-        […]
+        _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
         return
 
-    fig, ax = plt.subplots(figsize=(10, 5), dpi=
+    fig, ax = plt.subplots(figsize=(10, 5), dpi=DPI_value)
+
+    # --- Plot Losses (Left Y-axis) ---
+    line_handles = []  # To store line objects for the legend
 
     # Plot training loss only if data for it exists
    if train_loss:
         epochs = range(1, len(train_loss) + 1)
-        ax.plot(epochs, train_loss, 'o-', label='Training Loss')
+        line1, = ax.plot(epochs, train_loss, 'o-', label='Training Loss', color='tab:blue')
+        line_handles.append(line1)
 
     # Plot validation loss only if data for it exists
     if val_loss:
         epochs = range(1, len(val_loss) + 1)
-        ax.plot(epochs, val_loss, 'o-', label='Validation Loss')
+        line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
+        line_handles.append(line2)
 
     ax.set_title('Training and Validation Loss')
     ax.set_xlabel('Epochs')
-    ax.set_ylabel('Loss')
-    ax.
-    ax.grid(True)
-    […]
+    ax.set_ylabel('Loss', color='tab:blue')
+    ax.tick_params(axis='y', labelcolor='tab:blue')
+    ax.grid(True, linestyle='--')
+
+    # --- Plot Learning Rate (Right Y-axis) ---
+    if lr_history:
+        ax2 = ax.twinx()  # Create a second y-axis
+        epochs = range(1, len(lr_history) + 1)
+        line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
+        line_handles.append(line3)
+
+        ax2.set_ylabel('Learning Rate', color='g')
+        ax2.tick_params(axis='y', labelcolor='g')
+        # Use scientific notation if the LR is very small
+        ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
+
+    # Combine legends from both axes
+    ax.legend(handles=line_handles, loc='best')
+
+    # ax.grid(True)
+    plt.tight_layout()
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     save_path = save_dir_path / "loss_plot.svg"
```

(Removed lines shown as `[…]` were hidden by the source diff viewer; lines ending abruptly, such as `history.get(`, are truncated in the source.)
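The reworked `plot_losses` reads its series through the `PyTorchLogKeys` constants and draws an optional learning-rate curve on a twin y-axis. A minimal usage sketch, assuming dragon-ml-toolbox 16.2.0 is installed (the concrete string values behind the key constants live in `ml_tools/_keys.py` and are not shown in this diff, so they are referenced by attribute):

```python
# Hypothetical usage sketch; history values are made up.
from ml_tools._keys import PyTorchLogKeys
from ml_tools.ML_evaluation import plot_losses

history = {
    PyTorchLogKeys.TRAIN_LOSS: [0.92, 0.61, 0.48, 0.41],
    PyTorchLogKeys.VAL_LOSS: [0.95, 0.70, 0.58, 0.55],
    PyTorchLogKeys.LEARNING_RATE: [1e-3, 1e-3, 5e-4, 5e-4],  # optional; drawn on a twin y-axis
}
plot_losses(history, save_dir="reports")  # writes reports/loss_plot.svg
```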
```diff
@@ -78,23 +111,55 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
     plt.close(fig)
 
 
-def classification_metrics(save_dir: Union[str, Path],
-[…]
+def classification_metrics(save_dir: Union[str, Path],
+                           y_true: np.ndarray,
+                           y_pred: np.ndarray,
+                           y_prob: Optional[np.ndarray] = None,
+                           class_map: Optional[dict[str,int]] = None,
+                           config: Optional[Union[BinaryClassificationMetricsFormat,
+                                                  MultiClassClassificationMetricsFormat,
+                                                  BinaryImageClassificationMetricsFormat,
+                                                  MultiClassImageClassificationMetricsFormat]] = None):
     """
     Saves classification metrics and plots.
 
     Args:
         y_true (np.ndarray): Ground truth labels.
         y_pred (np.ndarray): Predicted labels.
-        y_prob (np.ndarray
-[…]
+        y_prob (np.ndarray): Predicted probabilities for ROC curve.
+        config (object): Formatting configuration object.
         save_dir (str | Path): Directory to save plots.
     """
-    […]
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseClassificationFormat()
+    else:
+        format_config = config
+
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': format_config.font_size})
+
+    # print("--- Classification Report ---")
+
+    # --- Parse class_map ---
+    map_labels = None
+    map_display_labels = None
+    if class_map:
+        # Sort the map by its values (the indices) to ensure correct order
+        try:
+            sorted_items = sorted(class_map.items(), key=lambda item: item[1])
+            map_labels = [item[1] for item in sorted_items]
+            map_display_labels = [item[0] for item in sorted_items]
+        except Exception as e:
+            _LOGGER.warning(f"Could not parse 'class_map': {e}")
+            map_labels = None
+            map_display_labels = None
+
     # Generate report as both text and dictionary
-    report_text: str = classification_report(y_true, y_pred) # type: ignore
-    report_dict: dict = classification_report(y_true, y_pred, output_dict=True) # type: ignore
-    print(report_text)
+    report_text: str = classification_report(y_true, y_pred, labels=map_labels, target_names=map_display_labels) # type: ignore
+    report_dict: dict = classification_report(y_true, y_pred, output_dict=True, labels=map_labels, target_names=map_display_labels) # type: ignore
+    # print(report_text)
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     # Save text report
```
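With the new signature, callers pass an optional `class_map` that fixes both the label order and the display names used in the report and confusion matrix. A sketch of a binary call under that signature (all data values are made up):

```python
import numpy as np
from ml_tools.ML_evaluation import classification_metrics

y_true = np.array([0, 1, 1, 0, 1, 0])
y_pred = np.array([0, 1, 0, 0, 1, 1])
# One probability column per class; column order must match the integer labels.
y_prob = np.array([[0.8, 0.2], [0.3, 0.7], [0.6, 0.4],
                   [0.9, 0.1], [0.2, 0.8], [0.4, 0.6]])

classification_metrics(
    save_dir="reports/clf",
    y_true=y_true,
    y_pred=y_pred,
    y_prob=y_prob,
    class_map={"negative": 0, "positive": 1},  # display name -> integer label
)
```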
```diff
@@ -104,8 +169,15 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre
 
     # --- Save Classification Report Heatmap ---
     try:
-        plt.figure(figsize=(8, 6), dpi=
-        sns.
+        plt.figure(figsize=(8, 6), dpi=DPI_value)
+        sns.set_theme(font_scale=1.2) # Scale seaborn font
+        sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T,
+                    annot=True,
+                    cmap=format_config.cmap,
+                    fmt='.2f',
+                    vmin=0.0,
+                    vmax=1.0)
+        sns.set_theme(font_scale=1.0) # Reset seaborn scale
         plt.title("Classification Report")
         plt.tight_layout()
         heatmap_path = save_dir_path / "classification_report_heatmap.svg"
```
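The heatmap feeds `classification_report(..., output_dict=True)` straight into pandas and pins the color scale to 0-1. Dropping the last index row and transposing turns the nested dict into one row per class (plus the averages) with precision/recall/F1 columns. A standalone sketch of that transform:

```python
import pandas as pd
from sklearn.metrics import classification_report

y_true = [0, 1, 1, 0, 1, 0]
y_pred = [0, 1, 0, 0, 1, 1]

report = classification_report(y_true, y_pred, output_dict=True)
# iloc[:-1, :] drops the 'support' row; .T makes each class (and each
# average) a row with precision / recall / f1-score columns.
heatmap_df = pd.DataFrame(report).iloc[:-1, :].T
print(heatmap_df)
```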
```diff
@@ -114,72 +186,224 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre
         plt.close()
     except Exception as e:
         _LOGGER.error(f"Could not generate classification report heatmap: {e}")
-    […]
+
+    # --- labels for Confusion Matrix ---
+    plot_labels = map_labels
+    plot_display_labels = map_display_labels
+
     # Save Confusion Matrix
-    fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=
-    ConfusionMatrixDisplay.from_predictions(y_true,
+    fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+    disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
+                                                    y_pred,
+                                                    cmap=format_config.cmap,
+                                                    ax=ax_cm,
+                                                    normalize='true',
+                                                    labels=plot_labels,
+                                                    display_labels=plot_display_labels)
+
+    disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+    # Turn off gridlines
+    ax_cm.grid(False)
+
+    # Manually update font size of cell texts
+    for text in ax_cm.texts:
+        text.set_fontsize(format_config.font_size)
+
+    fig_cm.tight_layout()
+
     ax_cm.set_title("Confusion Matrix")
     cm_path = save_dir_path / "confusion_matrix.svg"
     plt.savefig(cm_path)
     _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
     plt.close(fig_cm)
 
-    […]
+
+    # Plotting logic for ROC, PR, and Calibration Curves
+    if y_prob is not None and y_prob.ndim == 2:
+        num_classes = y_prob.shape[1]
 
-    # ---
-    […]
+        # --- Determine which classes to loop over ---
+        class_indices_to_plot = []
+        plot_titles = []
+        save_suffixes = []
+
+        if num_classes == 2:
+            # Binary case: Only plot for the positive class (index 1)
+            class_indices_to_plot = [1]
+            plot_titles = [""] # No extra title
+            save_suffixes = [""] # No extra suffix
+            _LOGGER.debug("Generating binary classification plots (ROC, PR, Calibration).")
+
+        elif num_classes > 2:
+            _LOGGER.debug(f"Generating One-vs-Rest plots for {num_classes} classes.")
+            # Multiclass case: Plot for every class (One-vs-Rest)
+            class_indices_to_plot = list(range(num_classes))
+
+            # --- Use class_map names if available ---
+            use_generic_names = True
+            if map_display_labels and len(map_display_labels) == num_classes:
+                try:
+                    # Ensure labels are safe for filenames
+                    safe_names = [sanitize_filename(name) for name in map_display_labels]
+                    plot_titles = [f" ({name} vs. Rest)" for name in map_display_labels]
+                    save_suffixes = [f"_{safe_names[i]}" for i in class_indices_to_plot]
+                    use_generic_names = False
+                except Exception as e:
+                    _LOGGER.warning(f"Failed to use 'class_map' for plot titles: {e}. Reverting to generic names.")
+                    use_generic_names = True
+
+            if use_generic_names:
+                plot_titles = [f" (Class {i} vs. Rest)" for i in class_indices_to_plot]
+                save_suffixes = [f"_class_{i}" for i in class_indices_to_plot]
 
-    […]
+        else:
+            # Should not happen, but good to check
+            _LOGGER.warning(f"Probability array has invalid shape {y_prob.shape}. Skipping ROC/PR/Calibration plots.")
+
+        # --- Loop and generate plots ---
+        for i, class_index in enumerate(class_indices_to_plot):
+            plot_title = plot_titles[i]
+            save_suffix = save_suffixes[i]
+
+            # Get scores for the current class
+            y_score = y_prob[:, class_index]
+
+            # Binarize y_true for the current class
+            y_true_binary = (y_true == class_index).astype(int)
+
+            # --- Save ROC Curve ---
+            fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
+
+            try:
+                # Calculate Youden's J statistic (tpr - fpr)
+                J = tpr - fpr
+                # Find the index of the best threshold
+                best_index = np.argmax(J)
+                optimal_threshold = thresholds[best_index]
+
+                # Define the filename
+                threshold_filename = f"best_threshold{save_suffix}.txt"
+                threshold_path = save_dir_path / threshold_filename
+
+                # Get the class name for the report
+                class_name = ""
+                # Check if we have display labels and the current index is valid
+                if map_display_labels and class_index < len(map_display_labels):
+                    class_name = map_display_labels[class_index]
+                    if num_classes > 2:
+                        # Add 'vs. Rest' for multiclass one-vs-rest plots
+                        class_name += " (vs. Rest)"
+                else:
+                    # Fallback to the generic title or default binary name
+                    class_name = plot_title.strip() or "Binary Positive Class"
+
+                # Create content for the file
+                file_content = (
+                    f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                    f"Class: {class_name}\n"
+                    f"--------------------------------------------------\n"
+                    f"Threshold: {optimal_threshold:.6f}\n"
+                    f"True Positive Rate (TPR): {tpr[best_index]:.6f}\n"
+                    f"False Positive Rate (FPR): {fpr[best_index]:.6f}\n"
+                )
+
+                threshold_path.write_text(file_content, encoding="utf-8")
+                _LOGGER.info(f"💾 Optimal threshold saved as '{threshold_path.name}'")
+
+            except Exception as e:
+                _LOGGER.warning(f"Could not calculate or save optimal threshold: {e}")
 
-    […]
+            # Calculate AUC.
+            auc = roc_auc_score(y_true_binary, y_score)
 
-    […]
+            fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+            ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)
+            ax_roc.plot([0, 1], [0, 1], 'k--')
+            ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
+            ax_roc.set_xlabel('False Positive Rate')
+            ax_roc.set_ylabel('True Positive Rate')
+            ax_roc.legend(loc='lower right')
+            ax_roc.grid(True)
+            roc_path = save_dir_path / f"roc_curve{save_suffix}.svg"
+            plt.savefig(roc_path)
+            plt.close(fig_roc)
+
+            # --- Save Precision-Recall Curve ---
+            precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
+            ap_score = average_precision_score(y_true_binary, y_score)
+            fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+            ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=format_config.ROC_PR_line)
+            ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
+            ax_pr.set_xlabel('Recall')
+            ax_pr.set_ylabel('Precision')
+            ax_pr.legend(loc='lower left')
+            ax_pr.grid(True)
+            pr_path = save_dir_path / f"pr_curve{save_suffix}.svg"
+            plt.savefig(pr_path)
+            plt.close(fig_pr)
+
+            # --- Save Calibration Plot ---
+            fig_cal, ax_cal = plt.subplots(figsize=(8, 8), dpi=DPI_value)
+
+            # --- Step 1: Get binned data *without* plotting ---
+            with plt.ioff(): # Suppress showing the temporary plot
+                fig_temp, ax_temp = plt.subplots()
+                cal_display_temp = CalibrationDisplay.from_predictions(
+                    y_true_binary, # Use binarized labels
+                    y_score,
+                    n_bins=format_config.calibration_bins,
+                    ax=ax_temp,
+                    name="temp" # Add a name to suppress potential warnings
+                )
+                # Get the x, y coordinates of the binned data
+                line_x, line_y = cal_display_temp.line_.get_data() # type: ignore
+                plt.close(fig_temp) # Close the temporary plot
+
+            # --- Step 2: Build the plot from scratch ---
+            ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
+
+            sns.regplot(
+                x=line_x,
+                y=line_y,
+                ax=ax_cal,
+                scatter=False,
+                label=f"Calibration Curve ({format_config.calibration_bins} bins)",
+                line_kws={
+                    'color': format_config.ROC_PR_line,
+                    'linestyle': '--',
+                    'linewidth': 2,
+                }
+            )
+
+            ax_cal.set_title(f'Reliability Curve{plot_title}')
             ax_cal.set_xlabel('Mean Predicted Probability')
             ax_cal.set_ylabel('Fraction of Positives')
+
+            # --- Step 3: Set final limits *after* plotting ---
+            ax_cal.set_ylim(0.0, 1.0)
+            ax_cal.set_xlim(0.0, 1.0)
+
+            ax_cal.legend(loc='lower right')
             ax_cal.grid(True)
             plt.tight_layout()
 
-    cal_path = save_dir_path / "calibration_plot.svg"
+            cal_path = save_dir_path / f"calibration_plot{save_suffix}.svg"
             plt.savefig(cal_path)
-    _LOGGER.info(f"📈 Calibration plot saved as '{cal_path.name}'")
             plt.close(fig_cal)
+
+        _LOGGER.info(f"📈 Saved {len(class_indices_to_plot)} sets of ROC, Precision-Recall, and Calibration plots.")
+
+    # restore RC params
+    plt.rcParams.update(original_rc_params)
 
 
-def regression_metrics(
+def regression_metrics(
+        y_true: np.ndarray,
+        y_pred: np.ndarray,
+        save_dir: Union[str, Path],
+        config: Optional[RegressionMetricsFormat] = None
+    ):
     """
     Saves regression metrics and plots.
 
```
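The new per-class threshold report is built on Youden's J statistic: for every ROC threshold it scores `J = TPR - FPR` and keeps the maximizer. The core computation in isolation, with toy data:

```python
import numpy as np
from sklearn.metrics import roc_curve

y_true_binary = np.array([0, 0, 1, 1, 1, 0])
y_score = np.array([0.10, 0.40, 0.35, 0.80, 0.70, 0.20])

fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
J = tpr - fpr                 # Youden's J for every candidate threshold
best_index = np.argmax(J)     # threshold that best separates the classes
print(f"threshold={thresholds[best_index]:.3f}, "
      f"TPR={tpr[best_index]:.3f}, FPR={fpr[best_index]:.3f}")
```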
```diff
@@ -187,7 +411,21 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
         y_true (np.ndarray): Ground truth values.
         y_pred (np.ndarray): Predicted values.
         save_dir (str | Path): Directory to save plots and report.
+        config (RegressionMetricsFormat, optional): Formatting configuration object.
     """
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseRegressionFormat()
+    else:
+        format_config = config
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': format_config.font_size})
+
+    # --- Calculate Metrics ---
     rmse = np.sqrt(mean_squared_error(y_true, y_pred))
     mae = mean_absolute_error(y_true, y_pred)
     r2 = r2_score(y_true, y_pred)
```
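Because `config` defaults to `None` and falls back to `_BaseRegressionFormat`, existing call sites only gain an optional argument. A minimal sketch with synthetic data and the default formatting:

```python
import numpy as np
from ml_tools.ML_evaluation import regression_metrics

rng = np.random.default_rng(42)
y_true = rng.normal(size=200)
y_pred = y_true + rng.normal(scale=0.3, size=200)  # imperfect predictions

# config=None -> _BaseRegressionFormat defaults (font size, colors, bins).
regression_metrics(y_true, y_pred, save_dir="reports/reg")
```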
```diff
@@ -209,11 +447,13 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     report_path.write_text(report_string)
     _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")
 
-    # Save residual plot
+    # --- Save residual plot ---
     residuals = y_true - y_pred
-    fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=
-    ax_res.scatter(y_pred, residuals,
-    […]
+    fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+    ax_res.scatter(y_pred, residuals,
+                   alpha=format_config.scatter_alpha,
+                   color=format_config.scatter_color)
+    ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--')
     ax_res.set_xlabel("Predicted Values")
     ax_res.set_ylabel("Residuals")
     ax_res.set_title("Residual Plot")
@@ -224,10 +464,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
     plt.close(fig_res)
 
-    # Save true vs predicted plot
-    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=
-    ax_tvp.scatter(y_true, y_pred,
-    […]
+    # --- Save true vs predicted plot ---
+    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+    ax_tvp.scatter(y_true, y_pred,
+                   alpha=format_config.scatter_alpha,
+                   color=format_config.scatter_color)
+    ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()],
+                linestyle='--',
+                lw=2,
+                color=format_config.ideal_line_color)
     ax_tvp.set_xlabel('True Values')
     ax_tvp.set_ylabel('Predictions')
     ax_tvp.set_title('True vs. Predicted Values')
@@ -238,9 +483,11 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
     plt.close(fig_tvp)
 
-    # Save Histogram of Residuals
-    fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=
-    sns.histplot(residuals, kde=True, ax=ax_hist
+    # --- Save Histogram of Residuals ---
+    fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+    sns.histplot(residuals, kde=True, ax=ax_hist,
+                 bins=format_config.hist_bins,
+                 color=format_config.scatter_color)
     ax_hist.set_xlabel("Residual Value")
     ax_hist.set_ylabel("Frequency")
     ax_hist.set_title("Distribution of Residuals")
@@ -251,6 +498,9 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
     plt.close(fig_hist)
 
+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)
+
 
 def shap_summary_plot(model,
                       background_data: Union[torch.Tensor,np.ndarray],
```
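Both metric functions now bracket their plotting with a snapshot/restore of matplotlib's global style, so a per-report font size cannot leak into the caller's other figures. The idiom in isolation (the try/finally guard is an extra safety net added here, not something the diff shows):

```python
import matplotlib.pyplot as plt

original_rc_params = plt.rcParams.copy()      # snapshot the global style
plt.rcParams.update({'font.size': 14})        # per-report styling
try:
    fig, ax = plt.subplots()
    ax.set_title("styled at font.size=14")
    fig.savefig("styled.svg")
    plt.close(fig)
finally:
    plt.rcParams.update(original_rc_params)   # restore even on error
```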
```diff
@@ -258,7 +508,7 @@ def shap_summary_plot(model,
                       feature_names: Optional[list[str]],
                       save_dir: Union[str, Path],
                       device: torch.device = torch.device('cpu'),
-                      explainer_type: Literal['deep', 'kernel'] = '
+                      explainer_type: Literal['deep', 'kernel'] = 'kernel'):
     """
     Calculates SHAP values and saves summary plots and data.
 
@@ -270,13 +520,13 @@ def shap_summary_plot(model,
         save_dir (str | Path): Directory to save SHAP artifacts.
         device (torch.device): The torch device for SHAP calculations.
         explainer_type (Literal['deep', 'kernel']): The explainer to use.
-            - 'deep':
+            - 'deep': Uses shap.DeepExplainer. Fast and efficient for
               PyTorch models.
             - 'kernel': Uses shap.KernelExplainer. Model-agnostic but EXTREMELY
               slow and memory-intensive.
     """
 
-    […]
+    _LOGGER.info(f"📊 Running SHAP Value Explanation Using {explainer_type.upper()} Explainer")
 
     model.eval()
     # model.cpu() # Run explanations on CPU
@@ -285,7 +535,7 @@ def shap_summary_plot(model,
     instances_to_explain_np = None
 
     if explainer_type == 'deep':
-        # --- 1. Use DeepExplainer
+        # --- 1. Use DeepExplainer ---
 
         # Ensure data is torch.Tensor
         if isinstance(background_data, np.ndarray):
@@ -309,10 +559,9 @@ def shap_summary_plot(model,
         instances_to_explain_np = instances_to_explain.cpu().numpy()
 
     elif explainer_type == 'kernel':
-        # --- 2. Use KernelExplainer
+        # --- 2. Use KernelExplainer ---
         _LOGGER.warning(
-            "
-            "Consider reducing 'n_samples' if the process terminates unexpectedly."
+            "KernelExplainer is memory-intensive and slow. Consider reducing the number of instances to explain if the process terminates unexpectedly."
         )
 
         # Ensure data is np.ndarray
```
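The warning exists because `shap.KernelExplainer` never touches the network directly: it repeatedly calls a black-box NumPy function, which makes it model-agnostic but far slower than `DeepExplainer`. A self-contained sketch of the wrapping this implies for a PyTorch model (toy model and data, not the library's internals):

```python
import numpy as np
import shap
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))
model.eval()

def predict_fn(x_np: np.ndarray) -> np.ndarray:
    # NumPy in, NumPy out: the only interface KernelExplainer needs.
    with torch.no_grad():
        return model(torch.from_numpy(x_np).float()).numpy()

background = np.random.randn(20, 4).astype(np.float32)  # keep the background small
instances = np.random.randn(5, 4).astype(np.float32)    # few instances: each one is costly

explainer = shap.KernelExplainer(predict_fn, background)
shap_values = explainer.shap_values(instances, nsamples=100)
```

For a single-output model the returned values can carry a trailing unit dimension, which is exactly what the `(N, F, 1) -> (N, F)` squeeze in the next hunk normalizes.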
```diff
@@ -348,14 +597,26 @@ def shap_summary_plot(model,
     else:
         _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
         raise ValueError()
+
+    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1: # type: ignore
+        # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
+        shap_values = shap_values.squeeze(-1) # type: ignore
 
     # --- 3. Plotting and Saving ---
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     plt.ioff()
 
+    # Convert instances to a DataFrame. robust way to ensure SHAP correctly maps values to feature names.
+    if feature_names is None:
+        # Create generic names if none were provided
+        num_features = instances_to_explain_np.shape[1]
+        feature_names = [f'feature_{i}' for i in range(num_features)]
+
+    instances_df = pd.DataFrame(instances_to_explain_np, columns=feature_names)
+
     # Save Bar Plot
     bar_path = save_dir_path / "shap_bar_plot.svg"
-    shap.summary_plot(shap_values,
+    shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
     plt.title("SHAP Feature Importance")
```
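Passing the instances as a DataFrame means `shap.summary_plot` takes the feature names from the columns, so names and values cannot drift apart the way a separately supplied name list can. A sketch reusing `shap_values` and `instances` from the previous snippet, with hypothetical column names:

```python
import pandas as pd

feature_names = ["age", "dose", "weight", "ph"]  # hypothetical names
instances_df = pd.DataFrame(instances, columns=feature_names)

# Column labels now travel with the data into the plot.
shap.summary_plot(shap_values, instances_df, plot_type="bar", show=False)
```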
```diff
@@ -366,7 +627,7 @@ def shap_summary_plot(model,
 
     # Save Dot Plot
     dot_path = save_dir_path / "shap_dot_plot.svg"
-    shap.summary_plot(shap_values,
+    shap.summary_plot(shap_values, instances_df, plot_type="dot", show=False)
     ax = plt.gca()
     ax.set_xlabel("SHAP Value Impact", labelpad=10)
     if plt.gcf().axes and len(plt.gcf().axes) > 1:
@@ -389,9 +650,6 @@ def shap_summary_plot(model,
     mean_abs_shap = np.abs(shap_values).mean(axis=0)
 
     mean_abs_shap = mean_abs_shap.flatten()
-
-    if feature_names is None:
-        feature_names = [f'feature_{i}' for i in range(len(mean_abs_shap))]
 
     summary_df = pd.DataFrame({
         SHAPKeys.FEATURE_COLUMN: feature_names,
@@ -401,7 +659,7 @@ def shap_summary_plot(model,
     summary_df.to_csv(summary_path, index=False)
 
     _LOGGER.info(f"📝 SHAP summary data saved as '{summary_path.name}'")
-    plt.ion()
+    plt.ion()
 
 
 def plot_attention_importance(weights: List[torch.Tensor], feature_names: Optional[List[str]], save_dir: Union[str, Path], top_n: int = 10):
@@ -447,7 +705,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
     # --- Step 3: Create and save the plot for top N features ---
     plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)
 
-    plt.figure(figsize=(10, 8), dpi=
+    plt.figure(figsize=(10, 8), dpi=DPI_value)
 
     # Create horizontal bar plot with error bars
     plt.barh(
```

(The source diff is truncated here, mid-hunk.)
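`plot_attention_importance` aggregates a list of per-batch attention tensors into a ranked bar plot for the top features. A usage sketch; the expected shape of one `(batch_size, n_features)` tensor per batch is an assumption read off the signature, not confirmed by this diff:

```python
import torch
from ml_tools.ML_evaluation import plot_attention_importance

# Assumed shape: one (batch_size, n_features) attention tensor per batch.
weights = [torch.rand(32, 6) for _ in range(10)]
feature_names = [f"feat_{i}" for i in range(6)]

plot_attention_importance(weights=weights,
                          feature_names=feature_names,
                          save_dir="reports/attention",
                          top_n=5)
```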
|