dragon-ml-toolbox 14.7.0__py3-none-any.whl → 16.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +9 -5
- dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +704 -24
- ml_tools/ML_datasetmaster.py +235 -280
- ml_tools/ML_evaluation.py +144 -39
- ml_tools/ML_evaluation_multi.py +103 -35
- ml_tools/ML_inference.py +290 -208
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +219 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1342 -386
- ml_tools/ML_utilities.py +1 -1
- ml_tools/ML_vision_datasetmaster.py +120 -72
- ml_tools/ML_vision_evaluation.py +30 -6
- ml_tools/ML_vision_inference.py +129 -152
- ml_tools/ML_vision_models.py +1 -1
- ml_tools/ML_vision_transformers.py +121 -40
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +45 -0
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +1 -1
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.7.0.dist-info/RECORD +0 -49
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation.py
CHANGED
|
@@ -24,7 +24,14 @@ import warnings
|
|
|
24
24
|
from .path_manager import make_fullpath, sanitize_filename
|
|
25
25
|
from ._logger import _LOGGER
|
|
26
26
|
from ._script_info import _script_info
|
|
27
|
-
from .
|
|
27
|
+
from ._keys import SHAPKeys, PyTorchLogKeys
|
|
28
|
+
from .ML_configuration import (RegressionMetricsFormat,
|
|
29
|
+
BinaryClassificationMetricsFormat,
|
|
30
|
+
MultiClassClassificationMetricsFormat,
|
|
31
|
+
BinaryImageClassificationMetricsFormat,
|
|
32
|
+
MultiClassImageClassificationMetricsFormat,
|
|
33
|
+
_BaseClassificationFormat,
|
|
34
|
+
_BaseRegressionFormat)
|
|
28
35
|
|
|
29
36
|
|
|
30
37
|
__all__ = [
|
|
@@ -41,6 +48,7 @@ DPI_value = 250
|
|
|
41
48
|
def plot_losses(history: dict, save_dir: Union[str, Path]):
|
|
42
49
|
"""
|
|
43
50
|
Plots training & validation loss curves from a history object.
|
|
51
|
+
Also plots the learning rate if available in the history.
|
|
44
52
|
|
|
45
53
|
Args:
|
|
46
54
|
history (dict): A dictionary containing 'train_loss' and 'val_loss'.
|
|
@@ -48,6 +56,7 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
|
|
|
48
56
|
"""
|
|
49
57
|
train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
|
|
50
58
|
val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
|
|
59
|
+
lr_history = history.get(PyTorchLogKeys.LEARNING_RATE, [])
|
|
51
60
|
|
|
52
61
|
if not train_loss and not val_loss:
|
|
53
62
|
_LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
|
|
@@ -55,22 +64,44 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
|
|
|
55
64
|
|
|
56
65
|
fig, ax = plt.subplots(figsize=(10, 5), dpi=DPI_value)
|
|
57
66
|
|
|
67
|
+
# --- Plot Losses (Left Y-axis) ---
|
|
68
|
+
line_handles = [] # To store line objects for the legend
|
|
69
|
+
|
|
58
70
|
# Plot training loss only if data for it exists
|
|
59
71
|
if train_loss:
|
|
60
72
|
epochs = range(1, len(train_loss) + 1)
|
|
61
|
-
ax.plot(epochs, train_loss, 'o-', label='Training Loss')
|
|
73
|
+
line1, = ax.plot(epochs, train_loss, 'o-', label='Training Loss', color='tab:blue')
|
|
74
|
+
line_handles.append(line1)
|
|
62
75
|
|
|
63
76
|
# Plot validation loss only if data for it exists
|
|
64
77
|
if val_loss:
|
|
65
78
|
epochs = range(1, len(val_loss) + 1)
|
|
66
|
-
ax.plot(epochs, val_loss, 'o-', label='Validation Loss')
|
|
79
|
+
line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
|
|
80
|
+
line_handles.append(line2)
|
|
67
81
|
|
|
68
82
|
ax.set_title('Training and Validation Loss')
|
|
69
83
|
ax.set_xlabel('Epochs')
|
|
70
|
-
ax.set_ylabel('Loss')
|
|
71
|
-
ax.
|
|
72
|
-
ax.grid(True)
|
|
73
|
-
|
|
84
|
+
ax.set_ylabel('Loss', color='tab:blue')
|
|
85
|
+
ax.tick_params(axis='y', labelcolor='tab:blue')
|
|
86
|
+
ax.grid(True, linestyle='--')
|
|
87
|
+
|
|
88
|
+
# --- Plot Learning Rate (Right Y-axis) ---
|
|
89
|
+
if lr_history:
|
|
90
|
+
ax2 = ax.twinx() # Create a second y-axis
|
|
91
|
+
epochs = range(1, len(lr_history) + 1)
|
|
92
|
+
line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
|
|
93
|
+
line_handles.append(line3)
|
|
94
|
+
|
|
95
|
+
ax2.set_ylabel('Learning Rate', color='g')
|
|
96
|
+
ax2.tick_params(axis='y', labelcolor='g')
|
|
97
|
+
# Use scientific notation if the LR is very small
|
|
98
|
+
ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
|
|
99
|
+
|
|
100
|
+
# Combine legends from both axes
|
|
101
|
+
ax.legend(handles=line_handles, loc='best')
|
|
102
|
+
|
|
103
|
+
# ax.grid(True)
|
|
104
|
+
plt.tight_layout()
|
|
74
105
|
|
|
75
106
|
save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
76
107
|
save_path = save_dir_path / "loss_plot.svg"
|
|
@@ -84,24 +115,30 @@ def classification_metrics(save_dir: Union[str, Path],
|
|
|
84
115
|
y_true: np.ndarray,
|
|
85
116
|
y_pred: np.ndarray,
|
|
86
117
|
y_prob: Optional[np.ndarray] = None,
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
118
|
+
class_map: Optional[dict[str,int]] = None,
|
|
119
|
+
config: Optional[Union[BinaryClassificationMetricsFormat,
|
|
120
|
+
MultiClassClassificationMetricsFormat,
|
|
121
|
+
BinaryImageClassificationMetricsFormat,
|
|
122
|
+
MultiClassImageClassificationMetricsFormat]] = None):
|
|
92
123
|
"""
|
|
93
124
|
Saves classification metrics and plots.
|
|
94
125
|
|
|
95
126
|
Args:
|
|
96
127
|
y_true (np.ndarray): Ground truth labels.
|
|
97
128
|
y_pred (np.ndarray): Predicted labels.
|
|
98
|
-
y_prob (np.ndarray
|
|
99
|
-
|
|
129
|
+
y_prob (np.ndarray): Predicted probabilities for ROC curve.
|
|
130
|
+
config (object): Formatting configuration object.
|
|
100
131
|
save_dir (str | Path): Directory to save plots.
|
|
101
|
-
class_map (dict[str, int], None): A map of {class_name: index} used to order and label the confusion matrix.
|
|
102
132
|
"""
|
|
133
|
+
# --- Parse Config or use defaults ---
|
|
134
|
+
if config is None:
|
|
135
|
+
# Create a default config if one wasn't provided
|
|
136
|
+
format_config = _BaseClassificationFormat()
|
|
137
|
+
else:
|
|
138
|
+
format_config = config
|
|
139
|
+
|
|
103
140
|
original_rc_params = plt.rcParams.copy()
|
|
104
|
-
plt.rcParams.update({'font.size': font_size})
|
|
141
|
+
plt.rcParams.update({'font.size': format_config.font_size})
|
|
105
142
|
|
|
106
143
|
# print("--- Classification Report ---")
|
|
107
144
|
|
|
@@ -136,7 +173,7 @@ def classification_metrics(save_dir: Union[str, Path],
|
|
|
136
173
|
sns.set_theme(font_scale=1.2) # Scale seaborn font
|
|
137
174
|
sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T,
|
|
138
175
|
annot=True,
|
|
139
|
-
cmap=cmap,
|
|
176
|
+
cmap=format_config.cmap,
|
|
140
177
|
fmt='.2f',
|
|
141
178
|
vmin=0.0,
|
|
142
179
|
vmax=1.0)
|
|
@@ -158,7 +195,7 @@ def classification_metrics(save_dir: Union[str, Path],
|
|
|
158
195
|
fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
|
|
159
196
|
disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
|
|
160
197
|
y_pred,
|
|
161
|
-
cmap=cmap,
|
|
198
|
+
cmap=format_config.cmap,
|
|
162
199
|
ax=ax_cm,
|
|
163
200
|
normalize='true',
|
|
164
201
|
labels=plot_labels,
|
|
@@ -171,7 +208,7 @@ def classification_metrics(save_dir: Union[str, Path],
|
|
|
171
208
|
|
|
172
209
|
# Manually update font size of cell texts
|
|
173
210
|
for text in ax_cm.texts:
|
|
174
|
-
text.set_fontsize(font_size)
|
|
211
|
+
text.set_fontsize(format_config.font_size)
|
|
175
212
|
|
|
176
213
|
fig_cm.tight_layout()
|
|
177
214
|
|
|
@@ -196,10 +233,10 @@ def classification_metrics(save_dir: Union[str, Path],
|
|
|
196
233
|
class_indices_to_plot = [1]
|
|
197
234
|
plot_titles = [""] # No extra title
|
|
198
235
|
save_suffixes = [""] # No extra suffix
|
|
199
|
-
_LOGGER.
|
|
236
|
+
_LOGGER.debug("Generating binary classification plots (ROC, PR, Calibration).")
|
|
200
237
|
|
|
201
238
|
elif num_classes > 2:
|
|
202
|
-
_LOGGER.
|
|
239
|
+
_LOGGER.debug(f"Generating One-vs-Rest plots for {num_classes} classes.")
|
|
203
240
|
# Multiclass case: Plot for every class (One-vs-Rest)
|
|
204
241
|
class_indices_to_plot = list(range(num_classes))
|
|
205
242
|
|
|
@@ -234,17 +271,54 @@ def classification_metrics(save_dir: Union[str, Path],
|
|
|
234
271
|
|
|
235
272
|
# Binarize y_true for the current class
|
|
236
273
|
y_true_binary = (y_true == class_index).astype(int)
|
|
237
|
-
|
|
274
|
+
|
|
238
275
|
# --- Save ROC Curve ---
|
|
239
|
-
fpr, tpr,
|
|
276
|
+
fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
|
|
277
|
+
|
|
278
|
+
try:
|
|
279
|
+
# Calculate Youden's J statistic (tpr - fpr)
|
|
280
|
+
J = tpr - fpr
|
|
281
|
+
# Find the index of the best threshold
|
|
282
|
+
best_index = np.argmax(J)
|
|
283
|
+
optimal_threshold = thresholds[best_index]
|
|
284
|
+
|
|
285
|
+
# Define the filename
|
|
286
|
+
threshold_filename = f"best_threshold{save_suffix}.txt"
|
|
287
|
+
threshold_path = save_dir_path / threshold_filename
|
|
288
|
+
|
|
289
|
+
# Get the class name for the report
|
|
290
|
+
class_name = ""
|
|
291
|
+
# Check if we have display labels and the current index is valid
|
|
292
|
+
if map_display_labels and class_index < len(map_display_labels):
|
|
293
|
+
class_name = map_display_labels[class_index]
|
|
294
|
+
if num_classes > 2:
|
|
295
|
+
# Add 'vs. Rest' for multiclass one-vs-rest plots
|
|
296
|
+
class_name += " (vs. Rest)"
|
|
297
|
+
else:
|
|
298
|
+
# Fallback to the generic title or default binary name
|
|
299
|
+
class_name = plot_title.strip() or "Binary Positive Class"
|
|
300
|
+
|
|
301
|
+
# Create content for the file
|
|
302
|
+
file_content = (
|
|
303
|
+
f"Optimal Classification Threshold (Youden's J Statistic)\n"
|
|
304
|
+
f"Class: {class_name}\n"
|
|
305
|
+
f"--------------------------------------------------\n"
|
|
306
|
+
f"Threshold: {optimal_threshold:.6f}\n"
|
|
307
|
+
f"True Positive Rate (TPR): {tpr[best_index]:.6f}\n"
|
|
308
|
+
f"False Positive Rate (FPR): {fpr[best_index]:.6f}\n"
|
|
309
|
+
)
|
|
310
|
+
|
|
311
|
+
threshold_path.write_text(file_content, encoding="utf-8")
|
|
312
|
+
_LOGGER.info(f"💾 Optimal threshold saved as '{threshold_path.name}'")
|
|
313
|
+
|
|
314
|
+
except Exception as e:
|
|
315
|
+
_LOGGER.warning(f"Could not calculate or save optimal threshold: {e}")
|
|
240
316
|
|
|
241
317
|
# Calculate AUC.
|
|
242
|
-
# Note: For multiclass, roc_auc_score(y_true, y_prob, multi_class='ovr') could average, but plotting individual curves is more informative.
|
|
243
|
-
# Here we calculate the specific AUC for the binarized problem.
|
|
244
318
|
auc = roc_auc_score(y_true_binary, y_score)
|
|
245
319
|
|
|
246
320
|
fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
|
|
247
|
-
ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
|
|
321
|
+
ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)
|
|
248
322
|
ax_roc.plot([0, 1], [0, 1], 'k--')
|
|
249
323
|
ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
|
|
250
324
|
ax_roc.set_xlabel('False Positive Rate')
|
|
@@ -259,7 +333,7 @@ def classification_metrics(save_dir: Union[str, Path],
|
|
|
259
333
|
precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
|
|
260
334
|
ap_score = average_precision_score(y_true_binary, y_score)
|
|
261
335
|
fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
|
|
262
|
-
ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=ROC_PR_line)
|
|
336
|
+
ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=format_config.ROC_PR_line)
|
|
263
337
|
ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
|
|
264
338
|
ax_pr.set_xlabel('Recall')
|
|
265
339
|
ax_pr.set_ylabel('Precision')
|
|
@@ -278,7 +352,7 @@ def classification_metrics(save_dir: Union[str, Path],
|
|
|
278
352
|
cal_display_temp = CalibrationDisplay.from_predictions(
|
|
279
353
|
y_true_binary, # Use binarized labels
|
|
280
354
|
y_score,
|
|
281
|
-
n_bins=calibration_bins,
|
|
355
|
+
n_bins=format_config.calibration_bins,
|
|
282
356
|
ax=ax_temp,
|
|
283
357
|
name="temp" # Add a name to suppress potential warnings
|
|
284
358
|
)
|
|
@@ -294,9 +368,9 @@ def classification_metrics(save_dir: Union[str, Path],
|
|
|
294
368
|
y=line_y,
|
|
295
369
|
ax=ax_cal,
|
|
296
370
|
scatter=False,
|
|
297
|
-
label=f"Calibration Curve ({calibration_bins} bins)",
|
|
371
|
+
label=f"Calibration Curve ({format_config.calibration_bins} bins)",
|
|
298
372
|
line_kws={
|
|
299
|
-
'color': ROC_PR_line,
|
|
373
|
+
'color': format_config.ROC_PR_line,
|
|
300
374
|
'linestyle': '--',
|
|
301
375
|
'linewidth': 2,
|
|
302
376
|
}
|
|
@@ -324,7 +398,12 @@ def classification_metrics(save_dir: Union[str, Path],
|
|
|
324
398
|
plt.rcParams.update(original_rc_params)
|
|
325
399
|
|
|
326
400
|
|
|
327
|
-
def regression_metrics(
|
|
401
|
+
def regression_metrics(
|
|
402
|
+
y_true: np.ndarray,
|
|
403
|
+
y_pred: np.ndarray,
|
|
404
|
+
save_dir: Union[str, Path],
|
|
405
|
+
config: Optional[RegressionMetricsFormat] = None
|
|
406
|
+
):
|
|
328
407
|
"""
|
|
329
408
|
Saves regression metrics and plots.
|
|
330
409
|
|
|
@@ -332,7 +411,21 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
|
|
|
332
411
|
y_true (np.ndarray): Ground truth values.
|
|
333
412
|
y_pred (np.ndarray): Predicted values.
|
|
334
413
|
save_dir (str | Path): Directory to save plots and report.
|
|
414
|
+
config (RegressionMetricsFormat, optional): Formatting configuration object.
|
|
335
415
|
"""
|
|
416
|
+
|
|
417
|
+
# --- Parse Config or use defaults ---
|
|
418
|
+
if config is None:
|
|
419
|
+
# Create a default config if one wasn't provided
|
|
420
|
+
format_config = _BaseRegressionFormat()
|
|
421
|
+
else:
|
|
422
|
+
format_config = config
|
|
423
|
+
|
|
424
|
+
# --- Set Matplotlib font size ---
|
|
425
|
+
original_rc_params = plt.rcParams.copy()
|
|
426
|
+
plt.rcParams.update({'font.size': format_config.font_size})
|
|
427
|
+
|
|
428
|
+
# --- Calculate Metrics ---
|
|
336
429
|
rmse = np.sqrt(mean_squared_error(y_true, y_pred))
|
|
337
430
|
mae = mean_absolute_error(y_true, y_pred)
|
|
338
431
|
r2 = r2_score(y_true, y_pred)
|
|
@@ -354,11 +447,13 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
|
|
|
354
447
|
report_path.write_text(report_string)
|
|
355
448
|
_LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")
|
|
356
449
|
|
|
357
|
-
# Save residual plot
|
|
450
|
+
# --- Save residual plot ---
|
|
358
451
|
residuals = y_true - y_pred
|
|
359
452
|
fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
|
|
360
|
-
ax_res.scatter(y_pred, residuals,
|
|
361
|
-
|
|
453
|
+
ax_res.scatter(y_pred, residuals,
|
|
454
|
+
alpha=format_config.scatter_alpha,
|
|
455
|
+
color=format_config.scatter_color)
|
|
456
|
+
ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--')
|
|
362
457
|
ax_res.set_xlabel("Predicted Values")
|
|
363
458
|
ax_res.set_ylabel("Residuals")
|
|
364
459
|
ax_res.set_title("Residual Plot")
|
|
@@ -369,10 +464,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
|
|
|
369
464
|
_LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
|
|
370
465
|
plt.close(fig_res)
|
|
371
466
|
|
|
372
|
-
# Save true vs predicted plot
|
|
467
|
+
# --- Save true vs predicted plot ---
|
|
373
468
|
fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
|
|
374
|
-
ax_tvp.scatter(y_true, y_pred,
|
|
375
|
-
|
|
469
|
+
ax_tvp.scatter(y_true, y_pred,
|
|
470
|
+
alpha=format_config.scatter_alpha,
|
|
471
|
+
color=format_config.scatter_color)
|
|
472
|
+
ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()],
|
|
473
|
+
linestyle='--',
|
|
474
|
+
lw=2,
|
|
475
|
+
color=format_config.ideal_line_color)
|
|
376
476
|
ax_tvp.set_xlabel('True Values')
|
|
377
477
|
ax_tvp.set_ylabel('Predictions')
|
|
378
478
|
ax_tvp.set_title('True vs. Predicted Values')
|
|
@@ -383,9 +483,11 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
|
|
|
383
483
|
_LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
|
|
384
484
|
plt.close(fig_tvp)
|
|
385
485
|
|
|
386
|
-
# Save Histogram of Residuals
|
|
486
|
+
# --- Save Histogram of Residuals ---
|
|
387
487
|
fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=DPI_value)
|
|
388
|
-
sns.histplot(residuals, kde=True, ax=ax_hist
|
|
488
|
+
sns.histplot(residuals, kde=True, ax=ax_hist,
|
|
489
|
+
bins=format_config.hist_bins,
|
|
490
|
+
color=format_config.scatter_color)
|
|
389
491
|
ax_hist.set_xlabel("Residual Value")
|
|
390
492
|
ax_hist.set_ylabel("Frequency")
|
|
391
493
|
ax_hist.set_title("Distribution of Residuals")
|
|
@@ -396,6 +498,9 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
|
|
|
396
498
|
_LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
|
|
397
499
|
plt.close(fig_hist)
|
|
398
500
|
|
|
501
|
+
# --- Restore RC params ---
|
|
502
|
+
plt.rcParams.update(original_rc_params)
|
|
503
|
+
|
|
399
504
|
|
|
400
505
|
def shap_summary_plot(model,
|
|
401
506
|
background_data: Union[torch.Tensor,np.ndarray],
|
ml_tools/ML_evaluation_multi.py
CHANGED
|
@@ -19,13 +19,17 @@ from sklearn.metrics import (
|
|
|
19
19
|
jaccard_score
|
|
20
20
|
)
|
|
21
21
|
from pathlib import Path
|
|
22
|
-
from typing import Union, List, Literal
|
|
22
|
+
from typing import Union, List, Literal, Optional
|
|
23
23
|
import warnings
|
|
24
24
|
|
|
25
25
|
from .path_manager import make_fullpath, sanitize_filename
|
|
26
26
|
from ._logger import _LOGGER
|
|
27
27
|
from ._script_info import _script_info
|
|
28
|
-
from .
|
|
28
|
+
from ._keys import SHAPKeys
|
|
29
|
+
from .ML_configuration import (MultiTargetRegressionMetricsFormat,
|
|
30
|
+
_BaseRegressionFormat,
|
|
31
|
+
MultiLabelBinaryClassificationMetricsFormat,
|
|
32
|
+
_BaseMultiLabelFormat)
|
|
29
33
|
|
|
30
34
|
|
|
31
35
|
__all__ = [
|
|
@@ -41,7 +45,8 @@ def multi_target_regression_metrics(
|
|
|
41
45
|
y_true: np.ndarray,
|
|
42
46
|
y_pred: np.ndarray,
|
|
43
47
|
target_names: List[str],
|
|
44
|
-
save_dir: Union[str, Path]
|
|
48
|
+
save_dir: Union[str, Path],
|
|
49
|
+
config: Optional[MultiTargetRegressionMetricsFormat] = None
|
|
45
50
|
):
|
|
46
51
|
"""
|
|
47
52
|
Calculates and saves regression metrics for each target individually.
|
|
@@ -55,6 +60,7 @@ def multi_target_regression_metrics(
|
|
|
55
60
|
y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
|
|
56
61
|
target_names (List[str]): A list of names for the target variables.
|
|
57
62
|
save_dir (str | Path): Directory to save plots and the report.
|
|
63
|
+
config (object): Formatting configuration object.
|
|
58
64
|
"""
|
|
59
65
|
if y_true.ndim != 2 or y_pred.ndim != 2:
|
|
60
66
|
_LOGGER.error("y_true and y_pred must be 2D arrays for multi-target regression.")
|
|
@@ -68,8 +74,19 @@ def multi_target_regression_metrics(
|
|
|
68
74
|
|
|
69
75
|
save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
70
76
|
metrics_summary = []
|
|
77
|
+
|
|
78
|
+
# --- Parse Config or use defaults ---
|
|
79
|
+
if config is None:
|
|
80
|
+
# Create a default config if one wasn't provided
|
|
81
|
+
format_config = _BaseRegressionFormat()
|
|
82
|
+
else:
|
|
83
|
+
format_config = config
|
|
84
|
+
|
|
85
|
+
# --- Set Matplotlib font size ---
|
|
86
|
+
original_rc_params = plt.rcParams.copy()
|
|
87
|
+
plt.rcParams.update({'font.size': format_config.font_size})
|
|
71
88
|
|
|
72
|
-
_LOGGER.
|
|
89
|
+
_LOGGER.debug("--- Multi-Target Regression Evaluation ---")
|
|
73
90
|
|
|
74
91
|
for i, name in enumerate(target_names):
|
|
75
92
|
print(f" -> Evaluating target: '{name}'")
|
|
@@ -93,8 +110,12 @@ def multi_target_regression_metrics(
|
|
|
93
110
|
# --- Save Residual Plot ---
|
|
94
111
|
residuals = true_i - pred_i
|
|
95
112
|
fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
|
|
96
|
-
ax_res.scatter(pred_i, residuals,
|
|
97
|
-
|
|
113
|
+
ax_res.scatter(pred_i, residuals,
|
|
114
|
+
alpha=format_config.scatter_alpha,
|
|
115
|
+
edgecolors='k',
|
|
116
|
+
s=50,
|
|
117
|
+
color=format_config.scatter_color) # Use config color
|
|
118
|
+
ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--') # Use config color
|
|
98
119
|
ax_res.set_xlabel("Predicted Values")
|
|
99
120
|
ax_res.set_ylabel("Residuals (True - Predicted)")
|
|
100
121
|
ax_res.set_title(f"Residual Plot for '{name}'")
|
|
@@ -106,8 +127,15 @@ def multi_target_regression_metrics(
|
|
|
106
127
|
|
|
107
128
|
# --- Save True vs. Predicted Plot ---
|
|
108
129
|
fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
|
|
109
|
-
ax_tvp.scatter(true_i, pred_i,
|
|
110
|
-
|
|
130
|
+
ax_tvp.scatter(true_i, pred_i,
|
|
131
|
+
alpha=format_config.scatter_alpha,
|
|
132
|
+
edgecolors='k',
|
|
133
|
+
s=50,
|
|
134
|
+
color=format_config.scatter_color) # Use config color
|
|
135
|
+
ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()],
|
|
136
|
+
linestyle='--',
|
|
137
|
+
lw=2,
|
|
138
|
+
color=format_config.ideal_line_color) # Use config color
|
|
111
139
|
ax_tvp.set_xlabel('True Values')
|
|
112
140
|
ax_tvp.set_ylabel('Predicted Values')
|
|
113
141
|
ax_tvp.set_title(f'True vs. Predicted Values for "{name}"')
|
|
@@ -122,17 +150,18 @@ def multi_target_regression_metrics(
|
|
|
122
150
|
report_path = save_dir_path / "regression_report_multi.csv"
|
|
123
151
|
summary_df.to_csv(report_path, index=False)
|
|
124
152
|
_LOGGER.info(f"Full regression report saved to '{report_path.name}'")
|
|
153
|
+
|
|
154
|
+
# --- Restore RC params ---
|
|
155
|
+
plt.rcParams.update(original_rc_params)
|
|
125
156
|
|
|
126
157
|
|
|
127
158
|
def multi_label_classification_metrics(
|
|
128
159
|
y_true: np.ndarray,
|
|
160
|
+
y_pred: np.ndarray,
|
|
129
161
|
y_prob: np.ndarray,
|
|
130
162
|
target_names: List[str],
|
|
131
163
|
save_dir: Union[str, Path],
|
|
132
|
-
|
|
133
|
-
ROC_PR_line: str='darkorange',
|
|
134
|
-
cmap: str = "Blues",
|
|
135
|
-
font_size: int = 16
|
|
164
|
+
config: Optional[MultiLabelBinaryClassificationMetricsFormat] = None
|
|
136
165
|
):
|
|
137
166
|
"""
|
|
138
167
|
Calculates and saves classification metrics for each label individually.
|
|
@@ -143,17 +172,17 @@ def multi_label_classification_metrics(
|
|
|
143
172
|
|
|
144
173
|
Args:
|
|
145
174
|
y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
|
|
175
|
+
y_pred (np.ndarray): Predicted binary labels, shape (n_samples, n_labels).
|
|
146
176
|
y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
|
|
147
177
|
target_names (List[str]): A list of names for the labels.
|
|
148
178
|
save_dir (str | Path): Directory to save plots and reports.
|
|
149
|
-
|
|
150
|
-
binary predictions for metrics like the confusion matrix.
|
|
179
|
+
config (object): Formatting configuration object.
|
|
151
180
|
"""
|
|
152
|
-
if y_true.ndim != 2 or y_prob.ndim != 2:
|
|
153
|
-
_LOGGER.error("y_true and y_prob must be 2D arrays for multi-label classification.")
|
|
181
|
+
if y_true.ndim != 2 or y_prob.ndim != 2 or y_pred.ndim != 2:
|
|
182
|
+
_LOGGER.error("y_true, y_pred, and y_prob must be 2D arrays for multi-label classification.")
|
|
154
183
|
raise ValueError()
|
|
155
|
-
if y_true.shape != y_prob.shape:
|
|
156
|
-
_LOGGER.error("Shapes of y_true and y_prob must match.")
|
|
184
|
+
if y_true.shape != y_prob.shape or y_true.shape != y_pred.shape:
|
|
185
|
+
_LOGGER.error("Shapes of y_true, y_pred, and y_prob must match.")
|
|
157
186
|
raise ValueError()
|
|
158
187
|
if y_true.shape[1] != len(target_names):
|
|
159
188
|
_LOGGER.error("Number of target names must match the number of columns in y_true.")
|
|
@@ -161,22 +190,28 @@ def multi_label_classification_metrics(
|
|
|
161
190
|
|
|
162
191
|
save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
|
|
163
192
|
|
|
164
|
-
#
|
|
165
|
-
|
|
193
|
+
# --- Parse Config or use defaults ---
|
|
194
|
+
if config is None:
|
|
195
|
+
# Create a default config if one wasn't provided
|
|
196
|
+
format_config = _BaseMultiLabelFormat()
|
|
197
|
+
else:
|
|
198
|
+
format_config = config
|
|
199
|
+
|
|
200
|
+
# y_pred is now passed in directly, no threshold needed.
|
|
166
201
|
|
|
167
202
|
# --- Save current RC params and update font size ---
|
|
168
203
|
original_rc_params = plt.rcParams.copy()
|
|
169
|
-
plt.rcParams.update({'font.size': font_size})
|
|
204
|
+
plt.rcParams.update({'font.size': format_config.font_size})
|
|
170
205
|
|
|
171
|
-
_LOGGER.info("--- Multi-Label Classification Evaluation ---")
|
|
206
|
+
# _LOGGER.info("--- Multi-Label Classification Evaluation ---")
|
|
172
207
|
|
|
173
|
-
# --- Calculate and Save Overall Metrics ---
|
|
208
|
+
# --- Calculate and Save Overall Metrics (using y_pred) ---
|
|
174
209
|
h_loss = hamming_loss(y_true, y_pred)
|
|
175
210
|
j_score_micro = jaccard_score(y_true, y_pred, average='micro')
|
|
176
211
|
j_score_macro = jaccard_score(y_true, y_pred, average='macro')
|
|
177
212
|
|
|
178
213
|
overall_report = (
|
|
179
|
-
f"Overall Multi-Label Metrics
|
|
214
|
+
f"Overall Multi-Label Metrics:\n" # No threshold to report here
|
|
180
215
|
f"--------------------------------------------------\n"
|
|
181
216
|
f"Hamming Loss: {h_loss:.4f}\n"
|
|
182
217
|
f"Jaccard Score (micro): {j_score_micro:.4f}\n"
|
|
@@ -191,20 +226,20 @@ def multi_label_classification_metrics(
|
|
|
191
226
|
for i, name in enumerate(target_names):
|
|
192
227
|
print(f" -> Evaluating label: '{name}'")
|
|
193
228
|
true_i = y_true[:, i]
|
|
194
|
-
pred_i = y_pred[:, i]
|
|
195
|
-
prob_i = y_prob[:, i]
|
|
229
|
+
pred_i = y_pred[:, i] # Use passed-in y_pred
|
|
230
|
+
prob_i = y_prob[:, i] # Use passed-in y_prob
|
|
196
231
|
sanitized_name = sanitize_filename(name)
|
|
197
232
|
|
|
198
|
-
# --- Save Classification Report for the label ---
|
|
233
|
+
# --- Save Classification Report for the label (uses y_pred) ---
|
|
199
234
|
report_text = classification_report(true_i, pred_i)
|
|
200
235
|
report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
|
|
201
236
|
report_path.write_text(report_text) # type: ignore
|
|
202
237
|
|
|
203
|
-
# --- Save Confusion Matrix ---
|
|
238
|
+
# --- Save Confusion Matrix (uses y_pred) ---
|
|
204
239
|
fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
|
|
205
240
|
disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
|
|
206
241
|
pred_i,
|
|
207
|
-
cmap=cmap,
|
|
242
|
+
cmap=format_config.cmap, # Use config cmap
|
|
208
243
|
ax=ax_cm,
|
|
209
244
|
normalize='true',
|
|
210
245
|
labels=[0, 1],
|
|
@@ -217,7 +252,7 @@ def multi_label_classification_metrics(
|
|
|
217
252
|
|
|
218
253
|
# Manually update font size of cell texts
|
|
219
254
|
for text in ax_cm.texts:
|
|
220
|
-
text.set_fontsize(font_size)
|
|
255
|
+
text.set_fontsize(format_config.font_size) # Use config font_size
|
|
221
256
|
|
|
222
257
|
fig_cm.tight_layout()
|
|
223
258
|
|
|
@@ -226,11 +261,44 @@ def multi_label_classification_metrics(
|
|
|
226
261
|
plt.savefig(cm_path)
|
|
227
262
|
plt.close(fig_cm)
|
|
228
263
|
|
|
229
|
-
# --- Save ROC Curve ---
|
|
230
|
-
fpr, tpr,
|
|
264
|
+
# --- Save ROC Curve (uses y_prob) ---
|
|
265
|
+
fpr, tpr, thresholds = roc_curve(true_i, prob_i)
|
|
266
|
+
|
|
267
|
+
try:
|
|
268
|
+
# Calculate Youden's J statistic (tpr - fpr)
|
|
269
|
+
J = tpr - fpr
|
|
270
|
+
# Find the index of the best threshold
|
|
271
|
+
best_index = np.argmax(J)
|
|
272
|
+
optimal_threshold = thresholds[best_index]
|
|
273
|
+
best_tpr = tpr[best_index]
|
|
274
|
+
best_fpr = fpr[best_index]
|
|
275
|
+
|
|
276
|
+
# Define the filename
|
|
277
|
+
threshold_filename = f"best_threshold_{sanitized_name}.txt"
|
|
278
|
+
threshold_path = save_dir_path / threshold_filename
|
|
279
|
+
|
|
280
|
+
# The class name is the target_name for this label
|
|
281
|
+
class_name = name
|
|
282
|
+
|
|
283
|
+
# Create content for the file
|
|
284
|
+
file_content = (
|
|
285
|
+
f"Optimal Classification Threshold (Youden's J Statistic)\n"
|
|
286
|
+
f"Class/Label: {class_name}\n"
|
|
287
|
+
f"--------------------------------------------------\n"
|
|
288
|
+
f"Threshold: {optimal_threshold:.6f}\n"
|
|
289
|
+
f"True Positive Rate (TPR): {best_tpr:.6f}\n"
|
|
290
|
+
f"False Positive Rate (FPR): {best_fpr:.6f}\n"
|
|
291
|
+
)
|
|
292
|
+
|
|
293
|
+
threshold_path.write_text(file_content, encoding="utf-8")
|
|
294
|
+
_LOGGER.info(f"💾 Optimal threshold for '{name}' saved to '{threshold_path.name}'")
|
|
295
|
+
|
|
296
|
+
except Exception as e:
|
|
297
|
+
_LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")
|
|
298
|
+
|
|
231
299
|
auc = roc_auc_score(true_i, prob_i)
|
|
232
300
|
fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
|
|
233
|
-
ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
|
|
301
|
+
ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line) # Use config color
|
|
234
302
|
ax_roc.plot([0, 1], [0, 1], 'k--')
|
|
235
303
|
ax_roc.set_title(f'ROC Curve for "{name}"')
|
|
236
304
|
ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
|
|
@@ -239,11 +307,11 @@ def multi_label_classification_metrics(
|
|
|
239
307
|
plt.savefig(roc_path)
|
|
240
308
|
plt.close(fig_roc)
|
|
241
309
|
|
|
242
|
-
# --- Save Precision-Recall Curve ---
|
|
310
|
+
# --- Save Precision-Recall Curve (uses y_prob) ---
|
|
243
311
|
precision, recall, _ = precision_recall_curve(true_i, prob_i)
|
|
244
312
|
ap_score = average_precision_score(true_i, prob_i)
|
|
245
313
|
fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
|
|
246
|
-
ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=ROC_PR_line)
|
|
314
|
+
ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line) # Use config color
|
|
247
315
|
ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
|
|
248
316
|
ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
|
|
249
317
|
ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)
|