dragon-ml-toolbox 14.3.1-py3-none-any.whl → 16.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +10 -5
- dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +309 -0
- ml_tools/ML_datasetmaster.py +220 -260
- ml_tools/ML_evaluation.py +317 -81
- ml_tools/ML_evaluation_multi.py +127 -36
- ml_tools/ML_inference.py +249 -207
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +215 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1247 -338
- ml_tools/ML_utilities.py +51 -2
- ml_tools/ML_vision_datasetmaster.py +262 -118
- ml_tools/ML_vision_evaluation.py +26 -6
- ml_tools/ML_vision_inference.py +117 -140
- ml_tools/ML_vision_models.py +15 -1
- ml_tools/ML_vision_transformers.py +233 -7
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +45 -1
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +54 -11
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.3.1.dist-info/RECORD +0 -48
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation.py
CHANGED
```diff
@@ -21,10 +21,11 @@ from pathlib import Path
 from typing import Union, Optional, List, Literal
 import warnings
 
-from .path_manager import make_fullpath
+from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .keys import SHAPKeys, PyTorchLogKeys
+from ._keys import SHAPKeys, PyTorchLogKeys
+from .ML_configuration import RegressionMetricsFormat, ClassificationMetricsFormat
 
 
 __all__ = [
```
```diff
@@ -35,10 +36,13 @@ __all__ = [
     "plot_attention_importance"
 ]
 
+DPI_value = 250
+
 
 def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
     Plots training & validation loss curves from a history object.
+    Also plots the learning rate if available in the history.
 
     Args:
         history (dict): A dictionary containing 'train_loss' and 'val_loss'.
```
```diff
@@ -46,29 +50,52 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
     train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
     val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
+    lr_history = history.get(PyTorchLogKeys.LEARNING_RATE, [])
 
     if not train_loss and not val_loss:
-
+        _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
         return
 
-    fig, ax = plt.subplots(figsize=(10, 5), dpi=100)
+    fig, ax = plt.subplots(figsize=(10, 5), dpi=DPI_value)
+
+    # --- Plot Losses (Left Y-axis) ---
+    line_handles = [] # To store line objects for the legend
 
     # Plot training loss only if data for it exists
     if train_loss:
         epochs = range(1, len(train_loss) + 1)
-        ax.plot(epochs, train_loss, 'o-', label='Training Loss')
+        line1, = ax.plot(epochs, train_loss, 'o-', label='Training Loss', color='tab:blue')
+        line_handles.append(line1)
 
     # Plot validation loss only if data for it exists
     if val_loss:
        epochs = range(1, len(val_loss) + 1)
-        ax.plot(epochs, val_loss, 'o-', label='Validation Loss')
+        line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
+        line_handles.append(line2)
 
     ax.set_title('Training and Validation Loss')
     ax.set_xlabel('Epochs')
-    ax.set_ylabel('Loss')
-    ax.
-    ax.grid(True)
-
+    ax.set_ylabel('Loss', color='tab:blue')
+    ax.tick_params(axis='y', labelcolor='tab:blue')
+    ax.grid(True, linestyle='--')
+
+    # --- Plot Learning Rate (Right Y-axis) ---
+    if lr_history:
+        ax2 = ax.twinx() # Create a second y-axis
+        epochs = range(1, len(lr_history) + 1)
+        line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
+        line_handles.append(line3)
+
+        ax2.set_ylabel('Learning Rate', color='g')
+        ax2.tick_params(axis='y', labelcolor='g')
+        # Use scientific notation if the LR is very small
+        ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
+
+    # Combine legends from both axes
+    ax.legend(handles=line_handles, loc='best')
+
+    # ax.grid(True)
+    plt.tight_layout()
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     save_path = save_dir_path / "loss_plot.svg"
```
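The revised `plot_losses` draws the learning rate on a second y-axis, and since `ax.legend()` only collects handles from its own axis, the `Line2D` handles from both axes are gathered and passed explicitly. A minimal standalone sketch of the same twin-axis pattern, using synthetic data rather than the toolbox's history dict:

```python
import matplotlib.pyplot as plt

epochs = range(1, 11)
loss = [1.0 / e for e in epochs]          # synthetic training loss
lrs = [1e-3 * 0.8 ** e for e in epochs]   # synthetic LR schedule

fig, ax = plt.subplots(figsize=(10, 5))
line1, = ax.plot(epochs, loss, 'o-', color='tab:blue', label='Training Loss')
ax.set_ylabel('Loss', color='tab:blue')

ax2 = ax.twinx()  # second y-axis sharing the same x-axis
line3, = ax2.plot(epochs, lrs, 'g--', label='Learning Rate')
ax2.set_ylabel('Learning Rate', color='g')
ax2.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))  # compact ticks for small LRs

# ax.legend() alone would miss line3, so pass handles from both axes.
ax.legend(handles=[line1, line3], loc='best')
plt.show()
```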
```diff
@@ -78,23 +105,49 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
     plt.close(fig)
 
 
-def classification_metrics(save_dir: Union[str, Path],
-
+def classification_metrics(save_dir: Union[str, Path],
+                           y_true: np.ndarray,
+                           y_pred: np.ndarray,
+                           y_prob: Optional[np.ndarray] = None,
+                           config: Optional[ClassificationMetricsFormat] = None):
     """
     Saves classification metrics and plots.
 
     Args:
         y_true (np.ndarray): Ground truth labels.
         y_pred (np.ndarray): Predicted labels.
-        y_prob (np.ndarray
-
+        y_prob (np.ndarray): Predicted probabilities for ROC curve.
+        config (ClassificationMetricsFormat): Formatting configuration object.
         save_dir (str | Path): Directory to save plots.
     """
-
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        config = ClassificationMetricsFormat()
+
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': config.font_size})
+
+    # print("--- Classification Report ---")
+
+    # --- Parse class_map ---
+    map_labels = None
+    map_display_labels = None
+    if config.class_map:
+        # Sort the map by its values (the indices) to ensure correct order
+        try:
+            sorted_items = sorted(config.class_map.items(), key=lambda item: item[1])
+            map_labels = [item[1] for item in sorted_items]
+            map_display_labels = [item[0] for item in sorted_items]
+        except Exception as e:
+            _LOGGER.warning(f"Could not parse 'class_map': {e}")
+            map_labels = None
+            map_display_labels = None
 
     # Generate report as both text and dictionary
-    report_text: str = classification_report(y_true, y_pred) # type: ignore
-    report_dict: dict = classification_report(y_true, y_pred, output_dict=True) # type: ignore
-    print(report_text)
+    report_text: str = classification_report(y_true, y_pred, labels=map_labels, target_names=map_display_labels) # type: ignore
+    report_dict: dict = classification_report(y_true, y_pred, output_dict=True, labels=map_labels, target_names=map_display_labels) # type: ignore
+    # print(report_text)
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     # Save text report
```
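`class_map` is expected to be a `{display_name: class_index}` mapping; sorting by the index values keeps `labels` and `target_names` aligned when they are handed to scikit-learn's `classification_report`. A self-contained sketch of that sorting step (the map and arrays below are made up for illustration):

```python
from sklearn.metrics import classification_report

class_map = {"cat": 0, "dog": 1, "bird": 2}  # hypothetical {name: index} map

# Sort by index so positions in both lists line up.
sorted_items = sorted(class_map.items(), key=lambda item: item[1])
labels = [item[1] for item in sorted_items]          # [0, 1, 2]
target_names = [item[0] for item in sorted_items]    # ['cat', 'dog', 'bird']

y_true = [0, 1, 2, 2, 1, 0]
y_pred = [0, 1, 1, 2, 1, 0]
print(classification_report(y_true, y_pred, labels=labels, target_names=target_names))
```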
```diff
@@ -104,8 +157,15 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre
 
     # --- Save Classification Report Heatmap ---
     try:
-        plt.figure(figsize=(8, 6), dpi=100)
-        sns.
+        plt.figure(figsize=(8, 6), dpi=DPI_value)
+        sns.set_theme(font_scale=1.2) # Scale seaborn font
+        sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T,
+                    annot=True,
+                    cmap=config.cmap,
+                    fmt='.2f',
+                    vmin=0.0,
+                    vmax=1.0)
+        sns.set_theme(font_scale=1.0) # Reset seaborn scale
         plt.title("Classification Report")
         plt.tight_layout()
         heatmap_path = save_dir_path / "classification_report_heatmap.svg"
```
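The heatmap relies on the shape of the frame built from `classification_report(..., output_dict=True)`: metrics (precision, recall, f1-score, support) become rows and classes become columns, so `iloc[:-1, :]` drops the support row and `.T` puts classes on the y-axis. Roughly, with a default colormap standing in for `config.cmap`:

```python
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report

y_true = [0, 1, 1, 0, 1, 0]
y_pred = [0, 1, 0, 0, 1, 1]
report_dict = classification_report(y_true, y_pred, output_dict=True)

# Drop the 'support' row, then transpose so each class is one heatmap row.
df = pd.DataFrame(report_dict).iloc[:-1, :].T
sns.heatmap(df, annot=True, fmt='.2f', vmin=0.0, vmax=1.0)
plt.title("Classification Report")
plt.tight_layout()
plt.show()
```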
```diff
@@ -114,72 +174,224 @@ def classification_metrics(save_dir: Union[str, Path], y_true: np.ndarray, y_pre
         plt.close()
     except Exception as e:
         _LOGGER.error(f"Could not generate classification report heatmap: {e}")
-
+
+    # --- labels for Confusion Matrix ---
+    plot_labels = map_labels
+    plot_display_labels = map_display_labels
+
     # Save Confusion Matrix
-    fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
-    ConfusionMatrixDisplay.from_predictions(y_true,
+    fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+    disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
+                                                    y_pred,
+                                                    cmap=config.cmap,
+                                                    ax=ax_cm,
+                                                    normalize='true',
+                                                    labels=plot_labels,
+                                                    display_labels=plot_display_labels)
+
+    disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+    # Turn off gridlines
+    ax_cm.grid(False)
+
+    # Manually update font size of cell texts
+    for text in ax_cm.texts:
+        text.set_fontsize(config.font_size)
+
+    fig_cm.tight_layout()
+
     ax_cm.set_title("Confusion Matrix")
     cm_path = save_dir_path / "confusion_matrix.svg"
     plt.savefig(cm_path)
     _LOGGER.info(f"❇️ Confusion matrix saved as '{cm_path.name}'")
     plt.close(fig_cm)
 
-
-
-
-
+
+    # Plotting logic for ROC, PR, and Calibration Curves
+    if y_prob is not None and y_prob.ndim == 2:
+        num_classes = y_prob.shape[1]
 
-    # ---
-
-
-
-
-
-
-
-
-
-        roc_path = save_dir_path / "roc_curve.svg"
-        plt.savefig(roc_path)
-        _LOGGER.info(f"📈 ROC curve saved as '{roc_path.name}'")
-        plt.close(fig_roc)
-
-        # --- Save Precision-Recall Curve ---
-        precision, recall, _ = precision_recall_curve(y_true, y_score)
-        ap_score = average_precision_score(y_true, y_score)
-        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
-        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
-        ax_pr.set_title('Precision-Recall Curve')
-        ax_pr.set_xlabel('Recall')
-        ax_pr.set_ylabel('Precision')
-        ax_pr.legend(loc='lower left')
-        ax_pr.grid(True)
-        pr_path = save_dir_path / "pr_curve.svg"
-        plt.savefig(pr_path)
-        _LOGGER.info(f"📈 PR curve saved as '{pr_path.name}'")
-        plt.close(fig_pr)
+        # --- Determine which classes to loop over ---
+        class_indices_to_plot = []
+        plot_titles = []
+        save_suffixes = []
+
+        if num_classes == 2:
+            # Binary case: Only plot for the positive class (index 1)
+            class_indices_to_plot = [1]
+            plot_titles = [""] # No extra title
+            save_suffixes = [""] # No extra suffix
+            _LOGGER.debug("Generating binary classification plots (ROC, PR, Calibration).")
 
-
-
+        elif num_classes > 2:
+            _LOGGER.debug(f"Generating One-vs-Rest plots for {num_classes} classes.")
+            # Multiclass case: Plot for every class (One-vs-Rest)
+            class_indices_to_plot = list(range(num_classes))
 
-
-
+            # --- Use class_map names if available ---
+            use_generic_names = True
+            if map_display_labels and len(map_display_labels) == num_classes:
+                try:
+                    # Ensure labels are safe for filenames
+                    safe_names = [sanitize_filename(name) for name in map_display_labels]
+                    plot_titles = [f" ({name} vs. Rest)" for name in map_display_labels]
+                    save_suffixes = [f"_{safe_names[i]}" for i in class_indices_to_plot]
+                    use_generic_names = False
+                except Exception as e:
+                    _LOGGER.warning(f"Failed to use 'class_map' for plot titles: {e}. Reverting to generic names.")
+                    use_generic_names = True
 
-
+            if use_generic_names:
+                plot_titles = [f" (Class {i} vs. Rest)" for i in class_indices_to_plot]
+                save_suffixes = [f"_class_{i}" for i in class_indices_to_plot]
+
+        else:
+            # Should not happen, but good to check
+            _LOGGER.warning(f"Probability array has invalid shape {y_prob.shape}. Skipping ROC/PR/Calibration plots.")
+
+        # --- Loop and generate plots ---
+        for i, class_index in enumerate(class_indices_to_plot):
+            plot_title = plot_titles[i]
+            save_suffix = save_suffixes[i]
+
+            # Get scores for the current class
+            y_score = y_prob[:, class_index]
+
+            # Binarize y_true for the current class
+            y_true_binary = (y_true == class_index).astype(int)
+
+            # --- Save ROC Curve ---
+            fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
+
+            try:
+                # Calculate Youden's J statistic (tpr - fpr)
+                J = tpr - fpr
+                # Find the index of the best threshold
+                best_index = np.argmax(J)
+                optimal_threshold = thresholds[best_index]
+
+                # Define the filename
+                threshold_filename = f"best_threshold{save_suffix}.txt"
+                threshold_path = save_dir_path / threshold_filename
+
+                # Get the class name for the report
+                class_name = ""
+                # Check if we have display labels and the current index is valid
+                if map_display_labels and class_index < len(map_display_labels):
+                    class_name = map_display_labels[class_index]
+                    if num_classes > 2:
+                        # Add 'vs. Rest' for multiclass one-vs-rest plots
+                        class_name += " (vs. Rest)"
+                else:
+                    # Fallback to the generic title or default binary name
+                    class_name = plot_title.strip() or "Binary Positive Class"
+
+                # Create content for the file
+                file_content = (
+                    f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                    f"Class: {class_name}\n"
+                    f"--------------------------------------------------\n"
+                    f"Threshold: {optimal_threshold:.6f}\n"
+                    f"True Positive Rate (TPR): {tpr[best_index]:.6f}\n"
+                    f"False Positive Rate (FPR): {fpr[best_index]:.6f}\n"
+                )
+
+                threshold_path.write_text(file_content, encoding="utf-8")
+                _LOGGER.info(f"💾 Optimal threshold saved as '{threshold_path.name}'")
+
+            except Exception as e:
+                _LOGGER.warning(f"Could not calculate or save optimal threshold: {e}")
+
+            # Calculate AUC.
+            auc = roc_auc_score(y_true_binary, y_score)
+
+            fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+            ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=config.ROC_PR_line)
+            ax_roc.plot([0, 1], [0, 1], 'k--')
+            ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
+            ax_roc.set_xlabel('False Positive Rate')
+            ax_roc.set_ylabel('True Positive Rate')
+            ax_roc.legend(loc='lower right')
+            ax_roc.grid(True)
+            roc_path = save_dir_path / f"roc_curve{save_suffix}.svg"
+            plt.savefig(roc_path)
+            plt.close(fig_roc)
+
+            # --- Save Precision-Recall Curve ---
+            precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
+            ap_score = average_precision_score(y_true_binary, y_score)
+            fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+            ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=config.ROC_PR_line)
+            ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
+            ax_pr.set_xlabel('Recall')
+            ax_pr.set_ylabel('Precision')
+            ax_pr.legend(loc='lower left')
+            ax_pr.grid(True)
+            pr_path = save_dir_path / f"pr_curve{save_suffix}.svg"
+            plt.savefig(pr_path)
+            plt.close(fig_pr)
+
+            # --- Save Calibration Plot ---
+            fig_cal, ax_cal = plt.subplots(figsize=(8, 8), dpi=DPI_value)
+
+            # --- Step 1: Get binned data *without* plotting ---
+            with plt.ioff(): # Suppress showing the temporary plot
+                fig_temp, ax_temp = plt.subplots()
+                cal_display_temp = CalibrationDisplay.from_predictions(
+                    y_true_binary, # Use binarized labels
+                    y_score,
+                    n_bins=config.calibration_bins,
+                    ax=ax_temp,
+                    name="temp" # Add a name to suppress potential warnings
+                )
+                # Get the x, y coordinates of the binned data
+                line_x, line_y = cal_display_temp.line_.get_data() # type: ignore
+                plt.close(fig_temp) # Close the temporary plot
+
+            # --- Step 2: Build the plot from scratch ---
+            ax_cal.plot([0, 1], [0, 1], 'k--', label='Perfectly calibrated')
+
+            sns.regplot(
+                x=line_x,
+                y=line_y,
+                ax=ax_cal,
+                scatter=False,
+                label=f"Calibration Curve ({config.calibration_bins} bins)",
+                line_kws={
+                    'color': config.ROC_PR_line,
+                    'linestyle': '--',
+                    'linewidth': 2,
+                }
+            )
+
+            ax_cal.set_title(f'Reliability Curve{plot_title}')
             ax_cal.set_xlabel('Mean Predicted Probability')
             ax_cal.set_ylabel('Fraction of Positives')
+
+            # --- Step 3: Set final limits *after* plotting ---
+            ax_cal.set_ylim(0.0, 1.0)
+            ax_cal.set_xlim(0.0, 1.0)
+
+            ax_cal.legend(loc='lower right')
             ax_cal.grid(True)
             plt.tight_layout()
 
-        cal_path = save_dir_path / "calibration_plot.svg"
+            cal_path = save_dir_path / f"calibration_plot{save_suffix}.svg"
             plt.savefig(cal_path)
-        _LOGGER.info(f"📈 Calibration plot saved as '{cal_path.name}'")
             plt.close(fig_cal)
+
+        _LOGGER.info(f"📈 Saved {len(class_indices_to_plot)} sets of ROC, Precision-Recall, and Calibration plots.")
+
+    # restore RC params
+    plt.rcParams.update(original_rc_params)
 
 
-def regression_metrics(
+def regression_metrics(
+        y_true: np.ndarray,
+        y_pred: np.ndarray,
+        save_dir: Union[str, Path],
+        config: Optional[RegressionMetricsFormat] = None
+    ):
     """
     Saves regression metrics and plots.
 
```
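The optimal-threshold file is driven by Youden's J statistic, which just maximizes `tpr - fpr` over the thresholds that `roc_curve` returns. A standalone sketch with synthetic binary labels and scores (the same shape of data the one-vs-rest loop produces after binarization):

```python
import numpy as np
from sklearn.metrics import roc_curve

y_true_binary = np.array([0, 0, 1, 1, 0, 1, 1, 0])
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.2, 0.7, 0.6, 0.3])

fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
J = tpr - fpr                  # Youden's J statistic at each threshold
best_index = np.argmax(J)      # threshold that best separates the classes
print(f"Threshold: {thresholds[best_index]:.3f}  "
      f"TPR: {tpr[best_index]:.3f}  FPR: {fpr[best_index]:.3f}")
```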
```diff
@@ -187,7 +399,19 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
         y_true (np.ndarray): Ground truth values.
         y_pred (np.ndarray): Predicted values.
         save_dir (str | Path): Directory to save plots and report.
+        config (RegressionMetricsFormat, optional): Formatting configuration object.
     """
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        config = RegressionMetricsFormat()
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': config.font_size})
+
+    # --- Calculate Metrics ---
     rmse = np.sqrt(mean_squared_error(y_true, y_pred))
     mae = mean_absolute_error(y_true, y_pred)
     r2 = r2_score(y_true, y_pred)
```
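For reference, the three metrics computed here, on toy arrays (the diff takes the square root of `mean_squared_error` explicitly rather than relying on an RMSE flag):

```python
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

y_true = np.array([3.0, -0.5, 2.0, 7.0])
y_pred = np.array([2.5, 0.0, 2.0, 8.0])

rmse = np.sqrt(mean_squared_error(y_true, y_pred))  # root mean squared error
mae = mean_absolute_error(y_true, y_pred)           # mean absolute error
r2 = r2_score(y_true, y_pred)                       # coefficient of determination
print(f"RMSE={rmse:.3f}  MAE={mae:.3f}  R2={r2:.3f}")
```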
```diff
@@ -209,11 +433,13 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     report_path.write_text(report_string)
     _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")
 
-    # Save residual plot
+    # --- Save residual plot ---
     residuals = y_true - y_pred
-    fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
-    ax_res.scatter(y_pred, residuals,
-
+    fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+    ax_res.scatter(y_pred, residuals,
+                   alpha=config.scatter_alpha,
+                   color=config.scatter_color)
+    ax_res.axhline(0, color=config.residual_line_color, linestyle='--')
     ax_res.set_xlabel("Predicted Values")
     ax_res.set_ylabel("Residuals")
     ax_res.set_title("Residual Plot")
```
```diff
@@ -224,10 +450,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
     plt.close(fig_res)
 
-    # Save true vs predicted plot
-    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
-    ax_tvp.scatter(y_true, y_pred,
-
+    # --- Save true vs predicted plot ---
+    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+    ax_tvp.scatter(y_true, y_pred,
+                   alpha=config.scatter_alpha,
+                   color=config.scatter_color)
+    ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()],
+                linestyle='--',
+                lw=2,
+                color=config.ideal_line_color)
     ax_tvp.set_xlabel('True Values')
     ax_tvp.set_ylabel('Predictions')
     ax_tvp.set_title('True vs. Predicted Values')
```
```diff
@@ -238,9 +469,11 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
     plt.close(fig_tvp)
 
-    # Save Histogram of Residuals
-    fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=100)
-    sns.histplot(residuals, kde=True, ax=ax_hist)
+    # --- Save Histogram of Residuals ---
+    fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+    sns.histplot(residuals, kde=True, ax=ax_hist,
+                 bins=config.hist_bins,
+                 color=config.scatter_color)
     ax_hist.set_xlabel("Residual Value")
     ax_hist.set_ylabel("Frequency")
     ax_hist.set_title("Distribution of Residuals")
```
```diff
@@ -251,6 +484,9 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
     plt.close(fig_hist)
 
+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)
+
 
 def shap_summary_plot(model,
                       background_data: Union[torch.Tensor,np.ndarray],
```
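Both evaluation entry points now accept an optional config object and build a default when none is passed, so existing call sites keep working. A hypothetical usage sketch (the output directory and arrays are placeholders; the constructor arguments of the `ML_configuration` classes are not shown in this diff, so defaults are used):

```python
import numpy as np
from ml_tools.ML_evaluation import classification_metrics, regression_metrics

# Placeholder predictions; in practice these come from a trained model.
y_true = np.array([0, 1, 1, 0])
y_pred = np.array([0, 1, 0, 0])
y_prob = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.7, 0.3]])

# config=None falls back to a default ClassificationMetricsFormat / RegressionMetricsFormat.
classification_metrics("eval_out", y_true, y_pred, y_prob=y_prob, config=None)
regression_metrics(y_true.astype(float), y_pred.astype(float), save_dir="eval_out", config=None)
```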
```diff
@@ -276,7 +512,7 @@ def shap_summary_plot(model,
     slow and memory-intensive.
     """
 
-
+    _LOGGER.info(f"📊 Running SHAP Value Explanation Using {explainer_type.upper()} Explainer")
 
     model.eval()
     # model.cpu() # Run explanations on CPU
```
```diff
@@ -348,9 +584,9 @@ def shap_summary_plot(model,
         _LOGGER.error(f"Invalid explainer_type: '{explainer_type}'. Must be 'deep' or 'kernel'.")
         raise ValueError()
 
-    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1:
+    if not isinstance(shap_values, list) and shap_values.ndim == 3 and shap_values.shape[2] == 1: # type: ignore
         # _LOGGER.info("Squeezing SHAP values from (N, F, 1) to (N, F) for regression plot.")
-        shap_values = shap_values.squeeze(-1)
+        shap_values = shap_values.squeeze(-1) # type: ignore
 
     # --- 3. Plotting and Saving ---
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
```
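The guard above normalizes regression SHAP output shaped `(N, F, 1)` to the `(N, F)` shape the summary plot expects; in plain NumPy terms:

```python
import numpy as np

shap_values = np.zeros((100, 8, 1))  # (samples, features, 1), as some explainers return
if shap_values.ndim == 3 and shap_values.shape[2] == 1:
    shap_values = shap_values.squeeze(-1)  # drop the trailing singleton axis
print(shap_values.shape)  # (100, 8)
```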
```diff
@@ -455,7 +691,7 @@ def plot_attention_importance(weights: List[torch.Tensor], feature_names: Option
     # --- Step 3: Create and save the plot for top N features ---
     plot_df = summary_df.head(top_n).sort_values('mean_attention', ascending=True)
 
-    plt.figure(figsize=(10, 8), dpi=100)
+    plt.figure(figsize=(10, 8), dpi=DPI_value)
 
     # Create horizontal bar plot with error bars
     plt.barh(
```