dragon-ml-toolbox 14.8.0__py3-none-any.whl → 16.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +9 -5
- dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +204 -11
- ml_tools/ML_datasetmaster.py +198 -280
- ml_tools/ML_evaluation.py +132 -41
- ml_tools/ML_evaluation_multi.py +96 -35
- ml_tools/ML_inference.py +249 -207
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +215 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1237 -354
- ml_tools/ML_utilities.py +1 -1
- ml_tools/ML_vision_datasetmaster.py +73 -67
- ml_tools/ML_vision_evaluation.py +26 -6
- ml_tools/ML_vision_inference.py +117 -140
- ml_tools/ML_vision_models.py +1 -1
- ml_tools/ML_vision_transformers.py +121 -40
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +43 -0
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +1 -1
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.8.0.dist-info/RECORD +0 -49
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
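Note the rename ml_tools/{keys.py → _keys.py} in the list above: the leading underscore marks the module as internal, a breaking change for downstream code importing it by its old name. A minimal sketch of the import change, assuming SHAPKeys and PyTorchLogKeys keep their names (both appear in the diffs below):

# Before (14.8.0): public module name
from ml_tools.keys import SHAPKeys, PyTorchLogKeys

# After (16.0.0): module renamed with a leading underscore (internal)
from ml_tools._keys import SHAPKeys, PyTorchLogKeys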
ml_tools/ML_evaluation.py
CHANGED
@@ -24,7 +24,8 @@ import warnings
 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .
+from ._keys import SHAPKeys, PyTorchLogKeys
+from .ML_configuration import RegressionMetricsFormat, ClassificationMetricsFormat
 
 
 __all__ = [
@@ -41,6 +42,7 @@ DPI_value = 250
 def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
     Plots training & validation loss curves from a history object.
+    Also plots the learning rate if available in the history.
 
     Args:
         history (dict): A dictionary containing 'train_loss' and 'val_loss'.
@@ -48,6 +50,7 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
     """
     train_loss = history.get(PyTorchLogKeys.TRAIN_LOSS, [])
     val_loss = history.get(PyTorchLogKeys.VAL_LOSS, [])
+    lr_history = history.get(PyTorchLogKeys.LEARNING_RATE, [])
 
     if not train_loss and not val_loss:
         _LOGGER.warning("Loss history is empty or incomplete. Cannot plot.")
@@ -55,22 +58,44 @@ def plot_losses(history: dict, save_dir: Union[str, Path]):
 
     fig, ax = plt.subplots(figsize=(10, 5), dpi=DPI_value)
 
+    # --- Plot Losses (Left Y-axis) ---
+    line_handles = []  # To store line objects for the legend
+
     # Plot training loss only if data for it exists
     if train_loss:
         epochs = range(1, len(train_loss) + 1)
-        ax.plot(epochs, train_loss, 'o-', label='Training Loss')
+        line1, = ax.plot(epochs, train_loss, 'o-', label='Training Loss', color='tab:blue')
+        line_handles.append(line1)
 
     # Plot validation loss only if data for it exists
     if val_loss:
         epochs = range(1, len(val_loss) + 1)
-        ax.plot(epochs, val_loss, 'o-', label='Validation Loss')
+        line2, = ax.plot(epochs, val_loss, 'o-', label='Validation Loss', color='tab:orange')
+        line_handles.append(line2)
 
     ax.set_title('Training and Validation Loss')
     ax.set_xlabel('Epochs')
-    ax.set_ylabel('Loss')
-    ax.
-    ax.grid(True)
-
+    ax.set_ylabel('Loss', color='tab:blue')
+    ax.tick_params(axis='y', labelcolor='tab:blue')
+    ax.grid(True, linestyle='--')
+
+    # --- Plot Learning Rate (Right Y-axis) ---
+    if lr_history:
+        ax2 = ax.twinx()  # Create a second y-axis
+        epochs = range(1, len(lr_history) + 1)
+        line3, = ax2.plot(epochs, lr_history, 'g--', label='Learning Rate')
+        line_handles.append(line3)
+
+        ax2.set_ylabel('Learning Rate', color='g')
+        ax2.tick_params(axis='y', labelcolor='g')
+        # Use scientific notation if the LR is very small
+        ax2.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
+
+    # Combine legends from both axes
+    ax.legend(handles=line_handles, loc='best')
+
+    # ax.grid(True)
+    plt.tight_layout()
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     save_path = save_dir_path / "loss_plot.svg"
@@ -84,34 +109,34 @@ def classification_metrics(save_dir: Union[str, Path],
                            y_true: np.ndarray,
                            y_pred: np.ndarray,
                            y_prob: Optional[np.ndarray] = None,
-
-                           class_map: Optional[dict[str,int]]=None,
-                           ROC_PR_line: str='darkorange',
-                           calibration_bins: int=15,
-                           font_size: int=16):
+                           config: Optional[ClassificationMetricsFormat] = None):
     """
     Saves classification metrics and plots.
 
     Args:
         y_true (np.ndarray): Ground truth labels.
         y_pred (np.ndarray): Predicted labels.
-        y_prob (np.ndarray
-
+        y_prob (np.ndarray): Predicted probabilities for ROC curve.
+        config (ClassificationMetricsFormat): Formatting configuration object.
         save_dir (str | Path): Directory to save plots.
-        class_map (dict[str, int], None): A map of {class_name: index} used to order and label the confusion matrix.
     """
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        config = ClassificationMetricsFormat()
+
     original_rc_params = plt.rcParams.copy()
-    plt.rcParams.update({'font.size': font_size})
+    plt.rcParams.update({'font.size': config.font_size})
 
     # print("--- Classification Report ---")
 
     # --- Parse class_map ---
     map_labels = None
     map_display_labels = None
-    if class_map:
+    if config.class_map:
         # Sort the map by its values (the indices) to ensure correct order
         try:
-            sorted_items = sorted(class_map.items(), key=lambda item: item[1])
+            sorted_items = sorted(config.class_map.items(), key=lambda item: item[1])
             map_labels = [item[1] for item in sorted_items]
             map_display_labels = [item[0] for item in sorted_items]
         except Exception as e:
@@ -136,7 +161,7 @@ def classification_metrics(save_dir: Union[str, Path],
     sns.set_theme(font_scale=1.2)  # Scale seaborn font
     sns.heatmap(pd.DataFrame(report_dict).iloc[:-1, :].T,
                 annot=True,
-                cmap=cmap,
+                cmap=config.cmap,
                 fmt='.2f',
                 vmin=0.0,
                 vmax=1.0)
@@ -158,7 +183,7 @@ def classification_metrics(save_dir: Union[str, Path],
     fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
     disp_ = ConfusionMatrixDisplay.from_predictions(y_true,
                                                     y_pred,
-                                                    cmap=cmap,
+                                                    cmap=config.cmap,
                                                     ax=ax_cm,
                                                     normalize='true',
                                                     labels=plot_labels,
@@ -171,7 +196,7 @@ def classification_metrics(save_dir: Union[str, Path],
 
     # Manually update font size of cell texts
     for text in ax_cm.texts:
-        text.set_fontsize(font_size)
+        text.set_fontsize(config.font_size)
 
     fig_cm.tight_layout()
 
@@ -196,10 +221,10 @@ def classification_metrics(save_dir: Union[str, Path],
         class_indices_to_plot = [1]
         plot_titles = [""]  # No extra title
         save_suffixes = [""]  # No extra suffix
-        _LOGGER.
+        _LOGGER.debug("Generating binary classification plots (ROC, PR, Calibration).")
 
     elif num_classes > 2:
-        _LOGGER.
+        _LOGGER.debug(f"Generating One-vs-Rest plots for {num_classes} classes.")
         # Multiclass case: Plot for every class (One-vs-Rest)
         class_indices_to_plot = list(range(num_classes))
 
@@ -234,17 +259,54 @@ def classification_metrics(save_dir: Union[str, Path],
 
         # Binarize y_true for the current class
         y_true_binary = (y_true == class_index).astype(int)
-
+
         # --- Save ROC Curve ---
-        fpr, tpr,
+        fpr, tpr, thresholds = roc_curve(y_true_binary, y_score)
+
+        try:
+            # Calculate Youden's J statistic (tpr - fpr)
+            J = tpr - fpr
+            # Find the index of the best threshold
+            best_index = np.argmax(J)
+            optimal_threshold = thresholds[best_index]
+
+            # Define the filename
+            threshold_filename = f"best_threshold{save_suffix}.txt"
+            threshold_path = save_dir_path / threshold_filename
+
+            # Get the class name for the report
+            class_name = ""
+            # Check if we have display labels and the current index is valid
+            if map_display_labels and class_index < len(map_display_labels):
+                class_name = map_display_labels[class_index]
+                if num_classes > 2:
+                    # Add 'vs. Rest' for multiclass one-vs-rest plots
+                    class_name += " (vs. Rest)"
+            else:
+                # Fallback to the generic title or default binary name
+                class_name = plot_title.strip() or "Binary Positive Class"
+
+            # Create content for the file
+            file_content = (
+                f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                f"Class: {class_name}\n"
+                f"--------------------------------------------------\n"
+                f"Threshold: {optimal_threshold:.6f}\n"
+                f"True Positive Rate (TPR): {tpr[best_index]:.6f}\n"
+                f"False Positive Rate (FPR): {fpr[best_index]:.6f}\n"
+            )
+
+            threshold_path.write_text(file_content, encoding="utf-8")
+            _LOGGER.info(f"💾 Optimal threshold saved as '{threshold_path.name}'")
+
+        except Exception as e:
+            _LOGGER.warning(f"Could not calculate or save optimal threshold: {e}")
 
         # Calculate AUC.
-        # Note: For multiclass, roc_auc_score(y_true, y_prob, multi_class='ovr') could average, but plotting individual curves is more informative.
-        # Here we calculate the specific AUC for the binarized problem.
         auc = roc_auc_score(y_true_binary, y_score)
 
         fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=config.ROC_PR_line)
         ax_roc.plot([0, 1], [0, 1], 'k--')
         ax_roc.set_title(f'Receiver Operating Characteristic{plot_title}')
         ax_roc.set_xlabel('False Positive Rate')
@@ -259,7 +321,7 @@ def classification_metrics(save_dir: Union[str, Path],
         precision, recall, _ = precision_recall_curve(y_true_binary, y_score)
         ap_score = average_precision_score(y_true_binary, y_score)
         fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=ROC_PR_line)
+        ax_pr.plot(recall, precision, label=f'Avg Precision = {ap_score:.2f}', color=config.ROC_PR_line)
         ax_pr.set_title(f'Precision-Recall Curve{plot_title}')
         ax_pr.set_xlabel('Recall')
         ax_pr.set_ylabel('Precision')
@@ -278,7 +340,7 @@ def classification_metrics(save_dir: Union[str, Path],
         cal_display_temp = CalibrationDisplay.from_predictions(
             y_true_binary,  # Use binarized labels
             y_score,
-            n_bins=calibration_bins,
+            n_bins=config.calibration_bins,
             ax=ax_temp,
             name="temp"  # Add a name to suppress potential warnings
         )
@@ -294,9 +356,9 @@ def classification_metrics(save_dir: Union[str, Path],
             y=line_y,
             ax=ax_cal,
             scatter=False,
-            label=f"Calibration Curve ({calibration_bins} bins)",
+            label=f"Calibration Curve ({config.calibration_bins} bins)",
             line_kws={
-                'color': ROC_PR_line,
+                'color': config.ROC_PR_line,
                 'linestyle': '--',
                 'linewidth': 2,
             }
@@ -324,7 +386,12 @@ def classification_metrics(save_dir: Union[str, Path],
     plt.rcParams.update(original_rc_params)
 
 
-def regression_metrics(
+def regression_metrics(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    save_dir: Union[str, Path],
+    config: Optional[RegressionMetricsFormat] = None
+    ):
     """
     Saves regression metrics and plots.
 
@@ -332,7 +399,19 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
         y_true (np.ndarray): Ground truth values.
         y_pred (np.ndarray): Predicted values.
         save_dir (str | Path): Directory to save plots and report.
+        config (RegressionMetricsFormat, optional): Formatting configuration object.
     """
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        config = RegressionMetricsFormat()
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': config.font_size})
+
+    # --- Calculate Metrics ---
     rmse = np.sqrt(mean_squared_error(y_true, y_pred))
     mae = mean_absolute_error(y_true, y_pred)
     r2 = r2_score(y_true, y_pred)
@@ -354,11 +433,13 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     report_path.write_text(report_string)
     _LOGGER.info(f"📝 Regression report saved as '{report_path.name}'")
 
-    # Save residual plot
+    # --- Save residual plot ---
     residuals = y_true - y_pred
     fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-    ax_res.scatter(y_pred, residuals,
-
+    ax_res.scatter(y_pred, residuals,
+                   alpha=config.scatter_alpha,
+                   color=config.scatter_color)
+    ax_res.axhline(0, color=config.residual_line_color, linestyle='--')
     ax_res.set_xlabel("Predicted Values")
     ax_res.set_ylabel("Residuals")
     ax_res.set_title("Residual Plot")
@@ -369,10 +450,15 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📈 Residual plot saved as '{res_path.name}'")
     plt.close(fig_res)
 
-    # Save true vs predicted plot
+    # --- Save true vs predicted plot ---
    fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-    ax_tvp.scatter(y_true, y_pred,
-
+    ax_tvp.scatter(y_true, y_pred,
+                   alpha=config.scatter_alpha,
+                   color=config.scatter_color)
+    ax_tvp.plot([y_true.min(), y_true.max()], [y_true.min(), y_true.max()],
+                linestyle='--',
+                lw=2,
+                color=config.ideal_line_color)
     ax_tvp.set_xlabel('True Values')
     ax_tvp.set_ylabel('Predictions')
     ax_tvp.set_title('True vs. Predicted Values')
@@ -383,9 +469,11 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📉 True vs. Predicted plot saved as '{tvp_path.name}'")
     plt.close(fig_tvp)
 
-    # Save Histogram of Residuals
+    # --- Save Histogram of Residuals ---
     fig_hist, ax_hist = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-    sns.histplot(residuals, kde=True, ax=ax_hist
+    sns.histplot(residuals, kde=True, ax=ax_hist,
+                 bins=config.hist_bins,
+                 color=config.scatter_color)
     ax_hist.set_xlabel("Residual Value")
     ax_hist.set_ylabel("Frequency")
     ax_hist.set_title("Distribution of Residuals")
@@ -396,6 +484,9 @@ def regression_metrics(y_true: np.ndarray, y_pred: np.ndarray, save_dir: Union[s
     _LOGGER.info(f"📊 Residuals histogram saved as '{hist_path.name}'")
     plt.close(fig_hist)
 
+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)
+
 
 def shap_summary_plot(model,
                       background_data: Union[torch.Tensor,np.ndarray],
ml_tools/ML_evaluation_multi.py
CHANGED
@@ -19,13 +19,14 @@ from sklearn.metrics import (
     jaccard_score
 )
 from pathlib import Path
-from typing import Union, List, Literal
+from typing import Union, List, Literal, Optional
 import warnings
 
 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .
+from ._keys import SHAPKeys
+from .ML_configuration import RegressionMetricsFormat, MultiClassificationMetricsFormat
 
 
 __all__ = [
@@ -41,7 +42,8 @@ def multi_target_regression_metrics(
     y_true: np.ndarray,
     y_pred: np.ndarray,
     target_names: List[str],
-    save_dir: Union[str, Path]
+    save_dir: Union[str, Path],
+    config: Optional[RegressionMetricsFormat] = None
 ):
     """
     Calculates and saves regression metrics for each target individually.
@@ -55,6 +57,7 @@ def multi_target_regression_metrics(
         y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
         target_names (List[str]): A list of names for the target variables.
         save_dir (str | Path): Directory to save plots and the report.
+        config (RegressionMetricsFormat, optional): Formatting configuration object.
     """
     if y_true.ndim != 2 or y_pred.ndim != 2:
         _LOGGER.error("y_true and y_pred must be 2D arrays for multi-target regression.")
@@ -68,8 +71,17 @@ def multi_target_regression_metrics(
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     metrics_summary = []
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        config = RegressionMetricsFormat()
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': config.font_size})
 
-    _LOGGER.
+    _LOGGER.debug("--- Multi-Target Regression Evaluation ---")
 
     for i, name in enumerate(target_names):
         print(f" -> Evaluating target: '{name}'")
@@ -93,8 +105,12 @@ def multi_target_regression_metrics(
         # --- Save Residual Plot ---
         residuals = true_i - pred_i
         fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-        ax_res.scatter(pred_i, residuals,
-
+        ax_res.scatter(pred_i, residuals,
+                       alpha=config.scatter_alpha,
+                       edgecolors='k',
+                       s=50,
+                       color=config.scatter_color)  # Use config color
+        ax_res.axhline(0, color=config.residual_line_color, linestyle='--')  # Use config color
         ax_res.set_xlabel("Predicted Values")
         ax_res.set_ylabel("Residuals (True - Predicted)")
         ax_res.set_title(f"Residual Plot for '{name}'")
@@ -106,8 +122,15 @@ def multi_target_regression_metrics(
 
         # --- Save True vs. Predicted Plot ---
         fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
-        ax_tvp.scatter(true_i, pred_i,
-
+        ax_tvp.scatter(true_i, pred_i,
+                       alpha=config.scatter_alpha,
+                       edgecolors='k',
+                       s=50,
+                       color=config.scatter_color)  # Use config color
+        ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()],
+                    linestyle='--',
+                    lw=2,
+                    color=config.ideal_line_color)  # Use config color
         ax_tvp.set_xlabel('True Values')
         ax_tvp.set_ylabel('Predicted Values')
         ax_tvp.set_title(f'True vs. Predicted Values for "{name}"')
@@ -122,17 +145,18 @@ def multi_target_regression_metrics(
     report_path = save_dir_path / "regression_report_multi.csv"
     summary_df.to_csv(report_path, index=False)
     _LOGGER.info(f"Full regression report saved to '{report_path.name}'")
+
+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)
 
 
 def multi_label_classification_metrics(
     y_true: np.ndarray,
+    y_pred: np.ndarray,
     y_prob: np.ndarray,
     target_names: List[str],
     save_dir: Union[str, Path],
-
-    ROC_PR_line: str='darkorange',
-    cmap: str = "Blues",
-    font_size: int = 16
+    config: Optional[MultiClassificationMetricsFormat] = None  # Add config object
 ):
     """
     Calculates and saves classification metrics for each label individually.
@@ -143,17 +167,17 @@ def multi_label_classification_metrics(
 
     Args:
         y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
+        y_pred (np.ndarray): Predicted binary labels, shape (n_samples, n_labels).
         y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
         target_names (List[str]): A list of names for the labels.
         save_dir (str | Path): Directory to save plots and reports.
-
-            binary predictions for metrics like the confusion matrix.
+        config (MultiClassificationMetricsFormat, optional): Formatting configuration object.
     """
-    if y_true.ndim != 2 or y_prob.ndim != 2:
-        _LOGGER.error("y_true and y_prob must be 2D arrays for multi-label classification.")
+    if y_true.ndim != 2 or y_prob.ndim != 2 or y_pred.ndim != 2:
+        _LOGGER.error("y_true, y_pred, and y_prob must be 2D arrays for multi-label classification.")
         raise ValueError()
-    if y_true.shape != y_prob.shape:
-        _LOGGER.error("Shapes of y_true and y_prob must match.")
+    if y_true.shape != y_prob.shape or y_true.shape != y_pred.shape:
+        _LOGGER.error("Shapes of y_true, y_pred, and y_prob must match.")
         raise ValueError()
     if y_true.shape[1] != len(target_names):
         _LOGGER.error("Number of target names must match the number of columns in y_true.")
@@ -161,22 +185,26 @@ def multi_label_classification_metrics(
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
 
-    #
-
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        config = MultiClassificationMetricsFormat()
+
+    # y_pred is now passed in directly, no threshold needed.
 
     # --- Save current RC params and update font size ---
     original_rc_params = plt.rcParams.copy()
-    plt.rcParams.update({'font.size': font_size})
+    plt.rcParams.update({'font.size': config.font_size})
 
-    _LOGGER.info("--- Multi-Label Classification Evaluation ---")
+    # _LOGGER.info("--- Multi-Label Classification Evaluation ---")
 
-    # --- Calculate and Save Overall Metrics ---
+    # --- Calculate and Save Overall Metrics (using y_pred) ---
     h_loss = hamming_loss(y_true, y_pred)
     j_score_micro = jaccard_score(y_true, y_pred, average='micro')
     j_score_macro = jaccard_score(y_true, y_pred, average='macro')
 
     overall_report = (
-        f"Overall Multi-Label Metrics
+        f"Overall Multi-Label Metrics:\n"  # No threshold to report here
         f"--------------------------------------------------\n"
         f"Hamming Loss: {h_loss:.4f}\n"
         f"Jaccard Score (micro): {j_score_micro:.4f}\n"
@@ -191,20 +219,20 @@ def multi_label_classification_metrics(
     for i, name in enumerate(target_names):
         print(f" -> Evaluating label: '{name}'")
         true_i = y_true[:, i]
-        pred_i = y_pred[:, i]
-        prob_i = y_prob[:, i]
+        pred_i = y_pred[:, i]  # Use passed-in y_pred
+        prob_i = y_prob[:, i]  # Use passed-in y_prob
         sanitized_name = sanitize_filename(name)
 
-        # --- Save Classification Report for the label ---
+        # --- Save Classification Report for the label (uses y_pred) ---
         report_text = classification_report(true_i, pred_i)
         report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
         report_path.write_text(report_text)  # type: ignore
 
-        # --- Save Confusion Matrix ---
+        # --- Save Confusion Matrix (uses y_pred) ---
         fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
         disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
                                                         pred_i,
-                                                        cmap=cmap,
+                                                        cmap=config.cmap,  # Use config cmap
                                                         ax=ax_cm,
                                                         normalize='true',
                                                         labels=[0, 1],
@@ -217,7 +245,7 @@ def multi_label_classification_metrics(
 
         # Manually update font size of cell texts
         for text in ax_cm.texts:
-            text.set_fontsize(font_size)
+            text.set_fontsize(config.font_size)  # Use config font_size
 
         fig_cm.tight_layout()
 
@@ -226,11 +254,44 @@ def multi_label_classification_metrics(
         plt.savefig(cm_path)
         plt.close(fig_cm)
 
-        # --- Save ROC Curve ---
-        fpr, tpr,
+        # --- Save ROC Curve (uses y_prob) ---
+        fpr, tpr, thresholds = roc_curve(true_i, prob_i)
+
+        try:
+            # Calculate Youden's J statistic (tpr - fpr)
+            J = tpr - fpr
+            # Find the index of the best threshold
+            best_index = np.argmax(J)
+            optimal_threshold = thresholds[best_index]
+            best_tpr = tpr[best_index]
+            best_fpr = fpr[best_index]
+
+            # Define the filename
+            threshold_filename = f"best_threshold_{sanitized_name}.txt"
+            threshold_path = save_dir_path / threshold_filename
+
+            # The class name is the target_name for this label
+            class_name = name
+
+            # Create content for the file
+            file_content = (
+                f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                f"Class/Label: {class_name}\n"
+                f"--------------------------------------------------\n"
+                f"Threshold: {optimal_threshold:.6f}\n"
+                f"True Positive Rate (TPR): {best_tpr:.6f}\n"
+                f"False Positive Rate (FPR): {best_fpr:.6f}\n"
+            )
+
+            threshold_path.write_text(file_content, encoding="utf-8")
+            _LOGGER.info(f"💾 Optimal threshold for '{name}' saved to '{threshold_path.name}'")
+
+        except Exception as e:
+            _LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")
 
         auc = roc_auc_score(true_i, prob_i)
         fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=ROC_PR_line)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=config.ROC_PR_line)  # Use config color
         ax_roc.plot([0, 1], [0, 1], 'k--')
         ax_roc.set_title(f'ROC Curve for "{name}"')
         ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
@@ -239,11 +300,11 @@ def multi_label_classification_metrics(
         plt.savefig(roc_path)
         plt.close(fig_roc)
 
-        # --- Save Precision-Recall Curve ---
+        # --- Save Precision-Recall Curve (uses y_prob) ---
         precision, recall, _ = precision_recall_curve(true_i, prob_i)
         ap_score = average_precision_score(true_i, prob_i)
         fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
-        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=ROC_PR_line)
+        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=config.ROC_PR_line)  # Use config color
         ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
         ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
         ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)
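Both evaluation modules now embed the same optimal-threshold logic: for each ROC curve they maximize Youden's J statistic, J = TPR - FPR, and write the winning threshold to a best_threshold*.txt report. A standalone sketch of that computation (the helper name and toy data are illustrative, not toolbox API):

import numpy as np
from sklearn.metrics import roc_curve

def youden_optimal_threshold(y_true, y_score):
    """Return (threshold, tpr, fpr) at the point maximizing J = TPR - FPR."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    J = tpr - fpr                   # Youden's J per candidate threshold
    best_index = int(np.argmax(J))  # first index attaining the largest J
    return thresholds[best_index], tpr[best_index], fpr[best_index]

# Perfectly separable toy data: the optimum lands at the lowest
# positive-class score (0.6), giving TPR = 1.0 and FPR = 0.0.
y_true = np.array([0, 0, 0, 1, 1, 1])
y_score = np.array([0.1, 0.2, 0.3, 0.6, 0.7, 0.9])
thr, tpr_at, fpr_at = youden_optimal_threshold(y_true, y_score)
print(f"Threshold: {thr:.6f}  TPR: {tpr_at:.6f}  FPR: {fpr_at:.6f}")

Separately, note the breaking change in multi_label_classification_metrics(): it no longer derives binary predictions from y_prob via an internal threshold; callers must pass y_pred explicitly, e.g. y_pred = (y_prob >= 0.5).astype(int).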