dragon-ml-toolbox 14.7.0__py3-none-any.whl → 16.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/METADATA +9 -5
- dragon_ml_toolbox-16.2.1.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +726 -32
- ml_tools/ML_datasetmaster.py +235 -280
- ml_tools/ML_evaluation.py +160 -42
- ml_tools/ML_evaluation_multi.py +103 -35
- ml_tools/ML_inference.py +290 -208
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +219 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1342 -386
- ml_tools/ML_utilities.py +1 -1
- ml_tools/ML_vision_datasetmaster.py +120 -72
- ml_tools/ML_vision_evaluation.py +30 -6
- ml_tools/ML_vision_inference.py +129 -152
- ml_tools/ML_vision_models.py +1 -1
- ml_tools/ML_vision_transformers.py +121 -40
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +45 -0
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +1 -1
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.7.0.dist-info/RECORD +0 -49
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation.py
CHANGED
@@ -24,7 +24,14 @@ import warnings
 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .
+from ._keys import SHAPKeys, PyTorchLogKeys
+from .ML_configuration import (RegressionMetricsFormat,
+                               BinaryClassificationMetricsFormat,
+                               MultiClassClassificationMetricsFormat,
+                               BinaryImageClassificationMetricsFormat,
+                               MultiClassImageClassificationMetricsFormat,
+                               _BaseClassificationFormat,
+                               _BaseRegressionFormat)


 __all__ = [
@@ -41,6 +48,7 @@ DPI_value = 250
 def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
     Plots training & validation loss curves from a history object.
+    Also plots the learning rate if available in the history.

     Args:
         history (dict): A dictionary containing 'train_loss' and 'val_loss'.
@@ -48,6 +56,7 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
     train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
     val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
+    lr_history = history.get(PyTorchLogKeys.LEARNING_RATE, [])

     if not train_loss and not val_loss:
         _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
@@ -55,22 +64,44 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):

     fig, ax = plt.subplots(figsize=(10, 5), dpi=DPI_value)

+    # --- Plot Losses (Left Y-axis) ---
+    line_handles = [] # To store line objects for the legend
+
     # Plot training loss only if data for it exists
     if train_loss:
         epochs = range(1, len(train_loss) + 1)
-        ax.plot(epochs, train_loss, 'o-', label='Training Loss')
+        line1, = ax.plot(epochs, train_loss, 'o-', label='Training Loss', color='tab:blue')
+        line_handles.append(line1)

     # Plot validation loss only if data for it exists
     if val_loss:
         epochs = range(1, len(val_loss) + 1)
-        ax.plot(epochs, val_loss, 'o-', label='Validation Loss')
+        line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
+        line_handles.append(line2)

     ax.set_title('Training and Validation Loss')
     ax.set_xlabel('Epochs')
-    ax.set_ylabel('Loss')
-    ax.
-    ax.grid(True)
-
+    ax.set_ylabel('Loss', color='tab:blue')
+    ax.tick_params(axis='y', labelcolor='tab:blue')
+    ax.grid(True, linestyle='--')
+
+    # --- Plot Learning Rate (Right Y-axis) ---
+    if lr_history:
+        ax2 = ax.twinx() # Create a second y-axis
+        epochs = range(1, len(lr_history) + 1)
+        line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
+        line_handles.append(line3)
+
+        ax2.set_ylabel('Learning Rate', color='g')
+        ax2.tick_params(axis='y', labelcolor='g')
+        # Use scientific notation if the LR is very small
+        ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
+
+    # Combine legends from both axes
+    ax.legend(handles=line_handles, loc='best')
+
+    # ax.grid(True)
+    plt.tight_layout()

     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     save_path = save_dir_path / "loss_plot.svg"
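Note: `plot_losses` now also reads `PyTorchLogKeys.LEARNING_RATE` and, when present, draws the learning rate on a twinned right y-axis. A minimal usage sketch (the shape of the history dict is inferred from the `.get()` calls above; the key constants live in `ml_tools/_keys.py`):

    from ml_tools.ML_evaluation import plot_losses
    from ml_tools._keys import PyTorchLogKeys

    # One value per epoch, keyed by the trainer's log-key constants (assumed layout).
    history = {
        PyTorchLogKeys.TRAIN_LOSS: [0.92, 0.61, 0.48, 0.41],
        PyTorchLogKeys.VAL_LOSS: [0.95, 0.70, 0.58, 0.55],
        PyTorchLogKeys.LEARNING_RATE: [1e-3, 1e-3, 5e-4, 5e-4],  # optional
    }
    plot_losses(history, save_dir="./eval")  # writes eval/loss_plot.svg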
@@ -84,24 +115,30 @@ def classification_metrics(save_dir: Union[str, Path],
                            y_true: np.ndarray,
                            y_pred: np.ndarray,
                            y_prob: Optional[np.ndarray] = None,
-
-
-
-
-
+                           class_map: Optional[dict[str,int]] = None,
+                           config: Optional[Union[BinaryClassificationMetricsFormat,
+                                                  MultiClassClassificationMetricsFormat,
+                                                  BinaryImageClassificationMetricsFormat,
+                                                  MultiClassImageClassificationMetricsFormat]] = None):
     """
     Saves classification metrics and plots.

     Args:
         y_true (np.ndarray): Ground truth labels.
         y_pred (np.ndarray): Predicted labels.
-        y_prob (np.ndarray
-
+        y_prob (np.ndarray): Predicted probabilities for ROC curve.
+        config (object): Formatting configuration object.
         save_dir (str | Path): Directory to save plots.
-        class_map (dict[str, int], None): A map of {class_name: index} used to order and label the confusion matrix.
     """
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseClassificationFormat()
+    else:
+        format_config = config
+
     original_rc_params = plt.rcParams.copy()
-    plt.rcParams.update({'font.size': font_size})
+    plt.rcParams.update({'font.size': format_config.font_size})

     # print("--- Classification Report ---")

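Note: the loose styling keywords of the old signature are replaced by a single config object; the body only requires that object to expose `font_size`, `cmap`, `ROC_PR_line`, and `calibration_bins`. A hedged call sketch (the constructor arguments are assumed from those attribute reads in this diff, not taken from the `ML_configuration` source):

    import numpy as np
    from ml_tools.ML_evaluation import classification_metrics
    from ml_tools.ML_configuration import BinaryClassificationMetricsFormat

    y_true = np.array([0, 1, 1, 0, 1])
    y_prob = np.array([0.2, 0.8, 0.6, 0.3, 0.9])
    y_pred = (y_prob >= 0.5).astype(int)

    # Assumed field names; defaults come from _BaseClassificationFormat when config=None.
    fmt = BinaryClassificationMetricsFormat(font_size=14, cmap="Greens",
                                            ROC_PR_line="crimson", calibration_bins=15)
    classification_metrics(save_dir="./eval", y_true=y_true, y_pred=y_pred,
                           y_prob=y_prob, class_map={"neg": 0, "pos": 1}, config=fmt)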
@@ -132,16 +169,29 @@ def classification_metrics(save_dir: Union[str, Path],

     # --- Save Classification Report Heatmap ---
     try:
-
+        # Create DataFrame from report
+        report_df = pd.DataFrame(report_dict)
+
+        # 1. Drop the 'accuracy' column (single float)
+        if 'accuracy' in report_df.columns:
+            report_df = report_df.drop(columns=['accuracy'])
+
+        # 2. Select all metric rows *except* the last one ('support')
+        # 3. Transpose the DataFrame
+        plot_df = report_df.iloc[:-1, :].T
+
+        fig_height = max(5.0, len(plot_df.index) * 0.5 + 2.0)
+        plt.figure(figsize=(7, fig_height), dpi=DPI_value)
+
         sns.set_theme(font_scale=1.2) # Scale seaborn font
-        sns.heatmap(
+        sns.heatmap(plot_df,
                     annot=True,
-                    cmap=cmap,
+                    cmap=format_config.cmap,
                     fmt='.2f',
                     vmin=0.0,
                     vmax=1.0)
         sns.set_theme(font_scale=1.0) # Reset seaborn scale
-        plt.title("Classification Report")
+        plt.title("Classification Report Heatmap")
         plt.tight_layout()
         heatmap_path = save_dir_path / "classification_report_heatmap.svg"
         plt.savefig(heatmap_path)
@@ -158,7 +208,7 @@ def classification_metrics(save_dir: Union[str, Path],
     fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
     disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
                                                     y_pred,
-                                                    cmap=cmap,
+                                                    cmap=format_config.cmap,
                                                     ax=ax_cm,
                                                     normalize='true',
                                                     labels=plot_labels,
@@ -171,7 +221,7 @@ def classification_metrics(save_dir: Union[str, Path],

     # Manually update font size of cell texts
     for text in ax_cm.texts:
-        text.set_fontsize(font_size)
+        text.set_fontsize(format_config.font_size)

     fig_cm.tight_layout()

@@ -196,10 +246,10 @@ def classification_metrics(save_dir: Union[str, Path],
         class_indices_to_plot = [1]
         plot_titles = [""] # No extra title
         save_suffixes = [""] # No extra suffix
-        _LOGGER.
+        _LOGGER.debug("Generating binary classification plots (ROC, PR, Calibration).")

     elif num_classes > 2:
-        _LOGGER.
+        _LOGGER.debug(f"Generating One-vs-Rest plots for {num_classes} classes.")
         # Multiclass case: Plot for every class (One-vs-Rest)
         class_indices_to_plot = list(range(num_classes))

@@ -234,17 +284,54 @@ def classification_metrics(save_dir: Union[str, Path],

         # Binarize y_true for the current class
         y_true_binary = (y_true == class_index).astype(int)
-
+
         # --- Save ROC Curve ---
-        fpr, tpr,
+        fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
+
+        try:
+            # Calculate Youden's J statistic (tpr - fpr)
+            J = tpr - fpr
+            # Find the index of the best threshold
+            best_index = np.argmax(J)
+            optimal_threshold = thresholds[best_index]
+
+            # Define the filename
+            threshold_filename = f"best_threshold{save_suffix}.txt"
+            threshold_path = save_dir_path / threshold_filename
+
+            # Get the class name for the report
+            class_name = ""
+            # Check if we have display labels and the current index is valid
+            if map_display_labels and class_index < len(map_display_labels):
+                class_name = map_display_labels[class_index]
+                if num_classes > 2:
+                    # Add 'vs. Rest' for multiclass one-vs-rest plots
+                    class_name += " (vs. Rest)"
+            else:
+                # Fallback to the generic title or default binary name
+                class_name = plot_title.strip() or "Binary Positive Class"
+
+            # Create content for the file
+            file_content = (
+                f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                f"Class: {class_name}\n"
+                f"--------------------------------------------------\n"
+                f"Threshold: {optimal_threshold:.6f}\n"
+                f"True Positive Rate (TPR): {tpr[best_index]:.6f}\n"
+                f"False Positive Rate (FPR): {fpr[best_index]:.6f}\n"
+            )
+
+            threshold_path.write_text(file_content, encoding="utf-8")
+            _LOGGER.info(f"💾 Optimal threshold saved as '{threshold_path.name}'")
+
+        except Exception as e:
+            _LOGGER.warning(f"Could not calculate or save optimal threshold: {e}")

         # Calculate AUC.
-        # Note: For multiclass, roc_auc_score(y_true, y_prob, multi_class='ovr') could average, but plotting individual curves is more informative.
-        # Here we calculate the specific AUC for the binarized problem.
         auc = roc_auc_score(y_true_binary, y_score)

         fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line)
         ax_roc.plot([0, 1], [0, 1], 'k--')
         ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
         ax_roc.set_xlabel('False Positive Rate')
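Note: the threshold written to `best_threshold*.txt` is the classic Youden's J optimum, i.e. the ROC point maximizing TPR − FPR. A standalone sketch of the same computation, using only sklearn and numpy:

    import numpy as np
    from sklearn.metrics import roc_curve

    y_true = np.array([0, 0, 1, 1, 1, 0])
    y_score = np.array([0.10, 0.40, 0.35, 0.80, 0.70, 0.20])

    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    J = tpr - fpr                # Youden's J for every candidate threshold
    best = int(np.argmax(J))     # index of the maximizing ROC point
    print(thresholds[best], tpr[best], fpr[best])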
@@ -259,7 +346,7 @@ def classification_metrics(save_dir: Union[str, Path],
         precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
         ap_score = average_precision_score(y_true_binary, y_score)
         fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=ROC_PR_line)
+        ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=format_config.ROC_PR_line)
         ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
         ax_pr.set_xlabel('Recall')
         ax_pr.set_ylabel('Precision')
@@ -278,7 +365,7 @@ def classification_metrics(save_dir: Union[str, Path],
         cal_display_temp = CalibrationDisplay.from_predictions(
             y_true_binary, # Use binarized labels
             y_score,
-            n_bins=calibration_bins,
+            n_bins=format_config.calibration_bins,
             ax=ax_temp,
             name="temp" # Add a name to suppress potential warnings
         )
@@ -294,9 +381,9 @@ def classification_metrics(save_dir: Union[str, Path],
             y=line_y,
             ax=ax_cal,
             scatter=False,
-            label=f"Calibration Curve ({calibration_bins} bins)",
+            label=f"Calibration Curve ({format_config.calibration_bins} bins)",
             line_kws={
-                'color': ROC_PR_line,
+                'color': format_config.ROC_PR_line,
                 'linestyle': '--',
                 'linewidth': 2,
             }
@@ -324,7 +411,12 @@ def classification_metrics(save_dir: Union[str, Path],
     plt.rcParams.update(original_rc_params)


-def regression_metrics(
+def regression_metrics(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    save_dir: Union[str, Path],
+    config: Optional[RegressionMetricsFormat] = None
+):
     """
     Saves regression metrics and plots.

@@ -332,7 +424,21 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
         y_true (np.ndarray): Ground truth values.
         y_pred (np.ndarray): Predicted values.
         save_dir (str | Path): Directory to save plots and report.
+        config (RegressionMetricsFormat, optional): Formatting configuration object.
     """
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseRegressionFormat()
+    else:
+        format_config = config
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': format_config.font_size})
+
+    # --- Calculate Metrics ---
     rmse = np.sqrt(mean_squared_error(y_true, y_pred))
     mae = mean_absolute_error(y_true, y_pred)
     r2 = r2_score(y_true, y_pred)
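Note: because `config` defaults to `None` and falls back to `_BaseRegressionFormat()`, existing callers keep working unchanged. A minimal call sketch with synthetic data:

    import numpy as np
    from ml_tools.ML_evaluation import regression_metrics

    rng = np.random.default_rng(0)
    y_true = rng.normal(size=200)
    y_pred = y_true + rng.normal(scale=0.3, size=200)  # noisy predictions

    regression_metrics(y_true, y_pred, save_dir="./eval")  # config=None -> defaults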
@@ -354,11 +460,13 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     report_path.write_text(report_string)
     _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")

-    # Save residual plot
+    # --- Save residual plot ---
     residuals = y_true - y_pred
     fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-    ax_res.scatter(y_pred, residuals,
-
+    ax_res.scatter(y_pred, residuals,
+                   alpha=format_config.scatter_alpha,
+                   color=format_config.scatter_color)
+    ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--')
     ax_res.set_xlabel("Predicted Values")
     ax_res.set_ylabel("Residuals")
     ax_res.set_title("Residual Plot")
@@ -369,10 +477,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
     plt.close(fig_res)

-    # Save true vs predicted plot
+    # --- Save true vs predicted plot ---
     fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-    ax_tvp.scatter(y_true, y_pred,
-
+    ax_tvp.scatter(y_true, y_pred,
+                   alpha=format_config.scatter_alpha,
+                   color=format_config.scatter_color)
+    ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()],
+                linestyle='--',
+                lw=2,
+                color=format_config.ideal_line_color)
     ax_tvp.set_xlabel('True Values')
     ax_tvp.set_ylabel('Predictions')
     ax_tvp.set_title('True vs. Predicted Values')
@@ -383,9 +496,11 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
     plt.close(fig_tvp)

-    # Save Histogram of Residuals
+    # --- Save Histogram of Residuals ---
     fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-    sns.histplot(residuals, kde=True, ax=ax_hist
+    sns.histplot(residuals, kde=True, ax=ax_hist,
+                 bins=format_config.hist_bins,
+                 color=format_config.scatter_color)
     ax_hist.set_xlabel("Residual Value")
     ax_hist.set_ylabel("Frequency")
     ax_hist.set_title("Distribution of Residuals")
@@ -396,6 +511,9 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
     plt.close(fig_hist)

+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)
+

 def shap_summary_plot(model,
                       background_data: Union[torch.Tensor,np.ndarray],
ml_tools/ML_evaluation_multi.py
CHANGED
@@ -19,13 +19,17 @@ from sklearn.metrics import (
     jaccard_score
 )
 from pathlib import Path
-from typing import Union, List, Literal
+from typing import Union, List, Literal, Optional
 import warnings

 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .
+from ._keys import SHAPKeys
+from .ML_configuration import (MultiTargetRegressionMetricsFormat,
+                               _BaseRegressionFormat,
+                               MultiLabelBinaryClassificationMetricsFormat,
+                               _BaseMultiLabelFormat)


 __all__ = [
@@ -41,7 +45,8 @@ def multi_target_regression_metrics(
     y_true: np.ndarray,
     y_pred: np.ndarray,
     target_names: List[str],
-    save_dir: Union[str, Path]
+    save_dir: Union[str, Path],
+    config: Optional[MultiTargetRegressionMetricsFormat] = None
 ):
     """
     Calculates and saves regression metrics for each target individually.
@@ -55,6 +60,7 @@ def multi_target_regression_metrics(
         y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
         target_names (List[str]): A list of names for the target variables.
         save_dir (str | Path): Directory to save plots and the report.
+        config (object): Formatting configuration object.
     """
     if y_true.ndim != 2 or y_pred.ndim != 2:
         _LOGGER.error("y_true and y_pred must be 2D arrays for multi-target regression.")
@@ -68,8 +74,19 @@ def multi_target_regression_metrics(

     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     metrics_summary = []
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseRegressionFormat()
+    else:
+        format_config = config
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': format_config.font_size})

-    _LOGGER.
+    _LOGGER.debug("--- Multi-Target Regression Evaluation ---")

     for i, name in enumerate(target_names):
         print(f" -> Evaluating target: '{name}'")
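Note: the multi-target version follows the same config-or-default pattern as its single-target counterpart. A minimal call sketch (two synthetic targets; the names are illustrative):

    import numpy as np
    from ml_tools.ML_evaluation_multi import multi_target_regression_metrics

    rng = np.random.default_rng(0)
    y_true = rng.normal(size=(100, 2))                      # (n_samples, n_targets)
    y_pred = y_true + rng.normal(scale=0.2, size=(100, 2))

    multi_target_regression_metrics(y_true, y_pred,
                                    target_names=["yield", "purity"],
                                    save_dir="./eval_multi")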
@@ -93,8 +110,12 @@ def multi_target_regression_metrics(
         # --- Save Residual Plot ---
         residuals = true_i - pred_i
         fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-        ax_res.scatter(pred_i, residuals,
-
+        ax_res.scatter(pred_i, residuals,
+                       alpha=format_config.scatter_alpha,
+                       edgecolors='k',
+                       s=50,
+                       color=format_config.scatter_color) # Use config color
+        ax_res.axhline(0, color=format_config.residual_line_color, linestyle='--') # Use config color
         ax_res.set_xlabel("Predicted Values")
         ax_res.set_ylabel("Residuals (True - Predicted)")
         ax_res.set_title(f"Residual Plot for '{name}'")
@@ -106,8 +127,15 @@ def multi_target_regression_metrics(

         # --- Save True vs. Predicted Plot ---
         fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-        ax_tvp.scatter(true_i, pred_i,
-
+        ax_tvp.scatter(true_i, pred_i,
+                       alpha=format_config.scatter_alpha,
+                       edgecolors='k',
+                       s=50,
+                       color=format_config.scatter_color) # Use config color
+        ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()],
+                    linestyle='--',
+                    lw=2,
+                    color=format_config.ideal_line_color) # Use config color
         ax_tvp.set_xlabel('True Values')
         ax_tvp.set_ylabel('Predicted Values')
         ax_tvp.set_title(f'True vs. Predicted Values for "{name}"')
@@ -122,17 +150,18 @@ def multi_target_regression_metrics(
     report_path = save_dir_path / "regression_report_multi.csv"
     summary_df.to_csv(report_path, index=False)
     _LOGGER.info(f"Full regression report saved to '{report_path.name}'")
+
+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)


 def multi_label_classification_metrics(
     y_true: np.ndarray,
+    y_pred: np.ndarray,
     y_prob: np.ndarray,
     target_names: List[str],
     save_dir: Union[str, Path],
-
-    ROC_PR_line: str='darkorange',
-    cmap: str = "Blues",
-    font_size: int = 16
+    config: Optional[MultiLabelBinaryClassificationMetricsFormat] = None
 ):
     """
     Calculates and saves classification metrics for each label individually.
@@ -143,17 +172,17 @@ def multi_label_classification_metrics(

     Args:
         y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
+        y_pred (np.ndarray): Predicted binary labels, shape (n_samples, n_labels).
         y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
         target_names (List[str]): A list of names for the labels.
         save_dir (str | Path): Directory to save plots and reports.
-
-            binary predictions for metrics like the confusion matrix.
+        config (object): Formatting configuration object.
     """
-    if y_true.ndim != 2 or y_prob.ndim != 2:
-        _LOGGER.error("y_true and y_prob must be 2D arrays for multi-label classification.")
+    if y_true.ndim != 2 or y_prob.ndim != 2 or y_pred.ndim != 2:
+        _LOGGER.error("y_true, y_pred, and y_prob must be 2D arrays for multi-label classification.")
         raise ValueError()
-    if y_true.shape != y_prob.shape:
-        _LOGGER.error("Shapes of y_true and y_prob must match.")
+    if y_true.shape != y_prob.shape or y_true.shape != y_pred.shape:
+        _LOGGER.error("Shapes of y_true, y_pred, and y_prob must match.")
         raise ValueError()
     if y_true.shape[1] != len(target_names):
         _LOGGER.error("Number of target names must match the number of columns in y_true.")
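Note: thresholding has moved out of this function, so the caller now derives `y_pred` from `y_prob` itself, e.g. with a global 0.5 cut (or with per-label values taken from the `best_threshold_*.txt` files written further down):

    import numpy as np

    # y_prob: (n_samples, n_labels) probabilities from the model
    y_prob = np.array([[0.9, 0.2], [0.4, 0.7], [0.6, 0.5]])
    y_pred = (y_prob >= 0.5).astype(int)  # binarize before calling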
@@ -161,22 +190,28 @@ def multi_label_classification_metrics(

     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

-    #
-
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        format_config = _BaseMultiLabelFormat()
+    else:
+        format_config = config
+
+    # y_pred is now passed in directly, no threshold needed.

     # --- Save current RC params and update font size ---
     original_rc_params = plt.rcParams.copy()
-    plt.rcParams.update({'font.size': font_size})
+    plt.rcParams.update({'font.size': format_config.font_size})

-    _LOGGER.info("--- Multi-Label Classification Evaluation ---")
+    # _LOGGER.info("--- Multi-Label Classification Evaluation ---")

-    # --- Calculate and Save Overall Metrics ---
+    # --- Calculate and Save Overall Metrics (using y_pred) ---
     h_loss = hamming_loss(y_true, y_pred)
     j_score_micro = jaccard_score(y_true, y_pred, average='micro')
     j_score_macro = jaccard_score(y_true, y_pred, average='macro')

     overall_report = (
-        f"Overall Multi-Label Metrics
+        f"Overall Multi-Label Metrics:\n" # No threshold to report here
         f"--------------------------------------------------\n"
         f"Hamming Loss: {h_loss:.4f}\n"
         f"Jaccard Score (micro): {j_score_micro:.4f}\n"
@@ -191,20 +226,20 @@ def multi_label_classification_metrics(
     for i, name in enumerate(target_names):
         print(f" -> Evaluating label: '{name}'")
         true_i = y_true[:, i]
-        pred_i = y_pred[:, i]
-        prob_i = y_prob[:, i]
+        pred_i = y_pred[:, i] # Use passed-in y_pred
+        prob_i = y_prob[:, i] # Use passed-in y_prob
         sanitized_name = sanitize_filename(name)

-        # --- Save Classification Report for the label ---
+        # --- Save Classification Report for the label (uses y_pred) ---
         report_text = classification_report(true_i, pred_i)
         report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
         report_path.write_text(report_text) # type: ignore

-        # --- Save Confusion Matrix ---
+        # --- Save Confusion Matrix (uses y_pred) ---
         fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
         disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
                                                         pred_i,
-                                                        cmap=cmap,
+                                                        cmap=format_config.cmap, # Use config cmap
                                                         ax=ax_cm,
                                                         normalize='true',
                                                         labels=[0, 1],
@@ -217,7 +252,7 @@ def multi_label_classification_metrics(

         # Manually update font size of cell texts
         for text in ax_cm.texts:
-            text.set_fontsize(font_size)
+            text.set_fontsize(format_config.font_size) # Use config font_size

         fig_cm.tight_layout()

@@ -226,11 +261,44 @@ def multi_label_classification_metrics(
         plt.savefig(cm_path)
         plt.close(fig_cm)

-        # --- Save ROC Curve ---
-        fpr, tpr,
+        # --- Save ROC Curve (uses y_prob) ---
+        fpr, tpr, thresholds = roc_curve(true_i, prob_i)
+
+        try:
+            # Calculate Youden's J statistic (tpr - fpr)
+            J = tpr - fpr
+            # Find the index of the best threshold
+            best_index = np.argmax(J)
+            optimal_threshold = thresholds[best_index]
+            best_tpr = tpr[best_index]
+            best_fpr = fpr[best_index]
+
+            # Define the filename
+            threshold_filename = f"best_threshold_{sanitized_name}.txt"
+            threshold_path = save_dir_path / threshold_filename
+
+            # The class name is the target_name for this label
+            class_name = name
+
+            # Create content for the file
+            file_content = (
+                f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                f"Class/Label: {class_name}\n"
+                f"--------------------------------------------------\n"
+                f"Threshold: {optimal_threshold:.6f}\n"
+                f"True Positive Rate (TPR): {best_tpr:.6f}\n"
+                f"False Positive Rate (FPR): {best_fpr:.6f}\n"
+            )
+
+            threshold_path.write_text(file_content, encoding="utf-8")
+            _LOGGER.info(f"💾 Optimal threshold for '{name}' saved to '{threshold_path.name}'")
+
+        except Exception as e:
+            _LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")
+
         auc = roc_auc_score(true_i, prob_i)
         fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=format_config.ROC_PR_line) # Use config color
         ax_roc.plot([0, 1], [0, 1], 'k--')
         ax_roc.set_title(f'ROC Curve for "{name}"')
         ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
@@ -239,11 +307,11 @@ def multi_label_classification_metrics(
         plt.savefig(roc_path)
         plt.close(fig_roc)

-        # --- Save Precision-Recall Curve ---
+        # --- Save Precision-Recall Curve (uses y_prob) ---
         precision, recall, _ = precision_recall_curve(true_i, prob_i)
         ap_score = average_precision_score(true_i, prob_i)
         fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=ROC_PR_line)
+        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=format_config.ROC_PR_line) # Use config color
         ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
         ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
         ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)