dragon-ml-toolbox 5.3.1__py3-none-any.whl → 6.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-5.3.1.dist-info → dragon_ml_toolbox-6.0.1.dist-info}/METADATA +9 -6
- {dragon_ml_toolbox-5.3.1.dist-info → dragon_ml_toolbox-6.0.1.dist-info}/RECORD +15 -14
- ml_tools/ML_callbacks.py +6 -6
- ml_tools/ML_evaluation.py +154 -95
- ml_tools/ML_trainer.py +13 -13
- ml_tools/PSO_optimization.py +5 -5
- ml_tools/ensemble_evaluation.py +639 -0
- ml_tools/ensemble_inference.py +10 -10
- ml_tools/ensemble_learning.py +47 -413
- ml_tools/keys.py +2 -2
- ml_tools/utilities.py +27 -3
- {dragon_ml_toolbox-5.3.1.dist-info → dragon_ml_toolbox-6.0.1.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-5.3.1.dist-info → dragon_ml_toolbox-6.0.1.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-5.3.1.dist-info → dragon_ml_toolbox-6.0.1.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-5.3.1.dist-info → dragon_ml_toolbox-6.0.1.dist-info}/top_level.txt +0 -0
ml_tools/ensemble_evaluation.py
ADDED
@@ -0,0 +1,639 @@
+import pandas as pd
+import numpy as np
+import seaborn as sns
+import matplotlib.pyplot as plt
+from matplotlib.colors import Colormap
+from matplotlib import rcdefaults
+import shap
+import xgboost as xgb
+import lightgbm as lgb
+from sklearn.model_selection import learning_curve
+from sklearn.calibration import CalibrationDisplay
+from sklearn.metrics import (accuracy_score,
+                             classification_report,
+                             ConfusionMatrixDisplay,
+                             mean_absolute_error,
+                             mean_squared_error,
+                             r2_score,
+                             roc_curve,
+                             roc_auc_score,
+                             precision_recall_curve,
+                             average_precision_score)
+from pathlib import Path
+from typing import Union, Optional, Literal
+
+from .path_manager import sanitize_filename, make_fullpath
+from ._script_info import _script_info
+from ._logger import _LOGGER
+
+
+__all__ = [
+    "evaluate_model_classification",
+    "plot_roc_curve",
+    "plot_precision_recall_curve",
+    "plot_calibration_curve",
+    "evaluate_model_regression",
+    "get_shap_values",
+    "plot_learning_curves",
+]
+
+
+# function to evaluate the model and save metrics (Classification)
+def evaluate_model_classification(
+    model,
+    model_name: str,
+    save_dir: Union[str,Path],
+    x_test_scaled: np.ndarray,
+    single_y_test: np.ndarray,
+    target_name: str,
+    figsize: tuple = (10, 8),
+    base_fontsize: int = 24,
+    cmap: Colormap = plt.cm.Blues, # type: ignore
+    heatmap_cmap: str = "viridis"
+) -> np.ndarray:
+    """
+    Evaluates a classification model, saves the classification report (text and heatmap) and the confusion matrix plot.
+
+    Parameters:
+        model: Trained classifier with .predict() method
+        model_name: Identifier for the model
+        save_dir: Directory where results are saved
+        x_test_scaled: Feature matrix for test set
+        single_y_test: True targets
+        target_name: Target name
+        figsize: Size of the confusion matrix figure (width, height)
+        fontsize: Font size used for title, axis labels and ticks
+        heatmap_cmap: Colormap for the classification report heatmap.
+        cmap: Color map for the confusion matrix. Examples include:
+            - plt.cm.Blues (default)
+            - plt.cm.Greens
+            - plt.cm.Oranges
+            - plt.cm.Purples
+            - plt.cm.Reds
+            - plt.cm.cividis
+            - plt.cm.inferno
+
+    Returns:
+        y_pred: Predicted class labels
+    """
+    save_path = make_fullpath(save_dir, make=True)
+    sanitized_target_name = sanitize_filename(target_name)
+
+    y_pred = model.predict(x_test_scaled)
+    accuracy = accuracy_score(single_y_test, y_pred)
+
+    # Generate report as dictionary for the heatmap
+    report_dict = classification_report(
+        single_y_test,
+        y_pred,
+        target_names=["Negative", "Positive"],
+        output_dict=True
+    )
+
+    # text report to save
+    report_text = classification_report(
+        single_y_test,
+        y_pred,
+        target_names=["Negative", "Positive"],
+        output_dict=False
+    )
+
+    # Save text report
+
+    report_path = save_path / f"Classification_Report_{sanitized_target_name}.txt"
+    with open(report_path, "w") as f:
+        f.write(f"{model_name} - {target_name}\t\tAccuracy: {accuracy:.2f}\n")
+        f.write("Classification Report:\n")
+        f.write(report_text) # type: ignore
+
+    # 3. Create and save the classification report heatmap
+    try:
+        report_df = pd.DataFrame(report_dict).iloc[:-1, :].T
+        plt.figure(figsize=figsize)
+        sns.heatmap(report_df, annot=True, cmap=heatmap_cmap, fmt='.2f',
+                    annot_kws={"size": base_fontsize - 4})
+        plt.title(f"{model_name} - {target_name}", fontsize=base_fontsize)
+        plt.xticks(fontsize=base_fontsize - 2)
+        plt.yticks(fontsize=base_fontsize - 2)
+
+        heatmap_path = save_path / f"Classification_Report_{sanitized_target_name}.svg"
+        plt.savefig(heatmap_path, format="svg", bbox_inches="tight")
+        plt.close()
+    except Exception as e:
+        _LOGGER.error(f"❌ Could not generate classification report heatmap for {target_name}: {e}")
+
+    # Create confusion matrix
+    fig, ax = plt.subplots(figsize=figsize)
+    disp = ConfusionMatrixDisplay.from_predictions(
+        y_true=single_y_test,
+        y_pred=y_pred,
+        display_labels=["Negative", "Positive"],
+        cmap=cmap,
+        normalize="true",
+        ax=ax
+    )
+
+    ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
+    ax.tick_params(axis='both', labelsize=base_fontsize)
+    ax.set_xlabel("Predicted label", fontsize=base_fontsize)
+    ax.set_ylabel("True label", fontsize=base_fontsize)
+
+    # Turn off gridlines
+    ax.grid(False)
+
+    # Manually update font size of cell texts
+    for text in ax.texts:
+        text.set_fontsize(base_fontsize+4)
+
+    fig.tight_layout()
+    fig_path = save_path / f"Confusion_Matrix_{sanitized_target_name}.svg"
+    fig.savefig(fig_path, format="svg", bbox_inches="tight") # type: ignore
+    plt.close(fig)
+
+    return y_pred
+
+#Function to save ROC and ROC AUC (Classification)
+def plot_roc_curve(
+    true_labels: np.ndarray,
+    probabilities_or_model: Union[np.ndarray, xgb.XGBClassifier, lgb.LGBMClassifier, object],
+    model_name: str,
+    target_name: str,
+    save_directory: Union[str,Path],
+    color: str = "darkorange",
+    figure_size: tuple = (10, 10),
+    linewidth: int = 2,
+    base_fontsize: int = 24,
+    input_features: Optional[np.ndarray] = None,
+) -> plt.Figure: # type: ignore
+    """
+    Plots the ROC curve and computes AUC for binary classification. Positive class is assumed to be in the second column of the probabilities array.
+
+    Parameters:
+        true_labels: np.ndarray of shape (n_samples,), ground truth binary labels (0 or 1).
+        probabilities_or_model: either predicted probabilities (ndarray), or a trained model with attribute `.predict_proba()`.
+        target_name: str, Target name.
+        save_directory: str or Path, path to directory where figure is saved.
+        color: color of the ROC curve. Accepts any valid Matplotlib color specification. Examples:
+            - Named colors: "darkorange", "blue", "red", "green", "black"
+            - Hex codes: "#1f77b4", "#ff7f0e"
+            - RGB tuples: (0.2, 0.4, 0.6)
+            - Colormap value: plt.cm.viridis(0.6)
+        figure_size: Tuple for figure size (width, height).
+        linewidth: int, width of the plotted ROC line.
+        title_fontsize: int, font size of the title.
+        label_fontsize: int, font size for axes labels.
+        input_features: np.ndarray of shape (n_samples, n_features), required if a model is passed.
+
+    Returns:
+        fig: matplotlib Figure object
+    """
+
+    # Determine predicted probabilities
+    if isinstance(probabilities_or_model, np.ndarray):
+        # Input is already probabilities
+        if probabilities_or_model.ndim == 2: # type: ignore
+            y_score = probabilities_or_model[:, 1] # type: ignore
+        else:
+            y_score = probabilities_or_model
+
+    elif hasattr(probabilities_or_model, "predict_proba"):
+        if input_features is None:
+            raise ValueError("input_features must be provided when using a classifier.")
+
+        try:
+            classes = probabilities_or_model.classes_ # type: ignore
+            positive_class_index = list(classes).index(1)
+        except (AttributeError, ValueError):
+            positive_class_index = 1
+
+        y_score = probabilities_or_model.predict_proba(input_features)[:, positive_class_index] # type: ignore
+
+    else:
+        raise TypeError("Unsupported type for 'probabilities_or_model'. Must be a NumPy array or a model with support for '.predict_proba()'.")
+
+    # ROC and AUC
+    fpr, tpr, _ = roc_curve(true_labels, y_score)
+    auc_score = roc_auc_score(true_labels, y_score)
+
+    # Plot
+    fig, ax = plt.subplots(figsize=figure_size)
+    ax.plot(fpr, tpr, color=color, lw=linewidth, label=f"AUC = {auc_score:.2f}")
+    ax.plot([0, 1], [0, 1], color="gray", linestyle="--", lw=1)
+
+    ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
+    ax.set_xlabel("False Positive Rate", fontsize=base_fontsize)
+    ax.set_ylabel("True Positive Rate", fontsize=base_fontsize)
+    ax.tick_params(axis='both', labelsize=base_fontsize)
+    ax.legend(loc="lower right", fontsize=base_fontsize)
+    ax.grid(True)
+
+    # Save figure
+    save_path = make_fullpath(save_directory, make=True)
+    sanitized_target_name = sanitize_filename(target_name)
+    full_save_path = save_path / f"ROC_{sanitized_target_name}.svg"
+    fig.savefig(full_save_path, bbox_inches="tight", format="svg") # type: ignore
+
+    return fig
+
+
+# Precision-Recall curve (Classification)
+def plot_precision_recall_curve(
+    true_labels: np.ndarray,
+    probabilities_or_model: Union[np.ndarray, xgb.XGBClassifier, lgb.LGBMClassifier, object],
+    model_name: str,
+    target_name: str,
+    save_directory: Union[str, Path],
+    color: str = "teal",
+    figure_size: tuple = (10, 10),
+    linewidth: int = 2,
+    base_fontsize: int = 24,
+    input_features: Optional[np.ndarray] = None,
+) -> plt.Figure: # type: ignore
+    """
+    Plots the Precision-Recall curve and computes Average Precision (AP) for binary classification.
+
+    Parameters:
+        true_labels: np.ndarray of shape (n_samples,), ground truth binary labels (0 or 1).
+        probabilities_or_model: either predicted probabilities (ndarray), or a trained model with attribute `.predict_proba()`.
+        model_name: Identifier for the model.
+        target_name: Name of the target variable.
+        save_directory: Path to the directory where the figure will be saved.
+        color: str, color of the PR curve.
+        figure_size: Tuple for figure size (width, height).
+        linewidth: int, width of the plotted PR line.
+        base_fontsize: int, base font size for titles and labels.
+        input_features: np.ndarray, required if a model object is passed instead of probabilities.
+
+    Returns:
+        fig: matplotlib Figure object
+    """
+    # Determine predicted probabilities for the positive class
+    if isinstance(probabilities_or_model, np.ndarray):
+        if probabilities_or_model.ndim == 2:
+            y_score = probabilities_or_model[:, 1]
+        else:
+            y_score = probabilities_or_model
+
+    elif hasattr(probabilities_or_model, "predict_proba"):
+        if input_features is None:
+            raise ValueError("input_features must be provided when using a classifier.")
+        try:
+            classes = probabilities_or_model.classes_ # type: ignore
+            positive_class_index = list(classes).index(1)
+        except (AttributeError, ValueError):
+            positive_class_index = 1
+        y_score = probabilities_or_model.predict_proba(input_features)[:, positive_class_index] # type: ignore
+    else:
+        raise TypeError("Unsupported type for 'probabilities_or_model'. Must be a NumPy array or a model with support for '.predict_proba()'.")
+
+    # Calculate PR curve and AP score
+    precision, recall, _ = precision_recall_curve(true_labels, y_score)
+    ap_score = average_precision_score(true_labels, y_score)
+
+    # Plot
+    fig, ax = plt.subplots(figsize=figure_size)
+    ax.plot(recall, precision, color=color, lw=linewidth, label=f"AP = {ap_score:.2f}")
+
+    ax.set_title(f"{model_name} - {target_name}", fontsize=base_fontsize)
+    ax.set_xlabel("Recall", fontsize=base_fontsize)
+    ax.set_ylabel("Precision", fontsize=base_fontsize)
+    ax.tick_params(axis='both', labelsize=base_fontsize)
+    ax.legend(loc="lower left", fontsize=base_fontsize)
+    ax.grid(True)
+    fig.tight_layout()
+
+    # Save figure
+    save_path = make_fullpath(save_directory, make=True)
+    sanitized_target_name = sanitize_filename(target_name)
+    full_save_path = save_path / f"PR_Curve_{sanitized_target_name}.svg"
+    fig.savefig(full_save_path, bbox_inches="tight", format="svg") # type: ignore
+    plt.close(fig)
+
+    return fig
+
+
+# Calibration curve (classification)
+def plot_calibration_curve(
+    model,
+    model_name: str,
+    save_dir: Union[str, Path],
+    x_test: np.ndarray,
+    y_test: np.ndarray,
+    target_name: str,
+    figure_size: tuple = (10, 10),
+    base_fontsize: int = 24,
+    n_bins: int = 15
+) -> plt.Figure: # type: ignore
+    """
+    Plots the calibration curve (reliability diagram) for a classifier.
+
+    Parameters:
+        model: Trained classifier with .predict_proba() method.
+        model_name: Identifier for the model.
+        save_dir: Directory where the plot will be saved.
+        x_test: Feature matrix for the test set.
+        y_test: True labels for the test set.
+        target_name: Name of the target variable.
+        figure_size: Tuple for figure size (width, height).
+        base_fontsize: Base font size for titles and labels.
+        n_bins: Number of bins to discretize predictions into.
+
+    Returns:
+        fig: matplotlib Figure object
+    """
+    fig, ax = plt.subplots(figsize=figure_size)
+
+    disp = CalibrationDisplay.from_estimator(
+        model,
+        x_test,
+        y_test,
+        n_bins=n_bins,
+        ax=ax
+    )
+
+    ax.set_title(f"{model_name} - Reliability Curve for {target_name}", fontsize=base_fontsize)
+    ax.tick_params(axis='both', labelsize=base_fontsize - 2)
+    ax.set_xlabel("Mean Predicted Probability", fontsize=base_fontsize)
+    ax.set_ylabel("Fraction of Positives", fontsize=base_fontsize)
+    ax.legend(fontsize=base_fontsize - 4)
+    fig.tight_layout()
+
+    # Save figure
+    save_path = make_fullpath(save_dir, make=True)
+    sanitized_target_name = sanitize_filename(target_name)
+    full_save_path = save_path / f"Calibration_Plot_{sanitized_target_name}.svg"
+    fig.savefig(full_save_path, bbox_inches="tight", format="svg") # type: ignore
+    plt.close(fig)
+
+    return fig
+
+
+# function to evaluate the model and save metrics (Regression)
+def evaluate_model_regression(model, model_name: str,
+                              save_dir: Union[str,Path],
+                              x_test_scaled: np.ndarray, single_y_test: np.ndarray,
+                              target_name: str,
+                              figure_size: tuple = (12, 8),
+                              alpha_transparency: float = 0.5,
+                              base_fontsize: int = 24,
+                              hist_bins: int = 30):
+    # Generate predictions
+    y_pred = model.predict(x_test_scaled)
+
+    # Calculate regression metrics
+    mae = mean_absolute_error(single_y_test, y_pred)
+    mse = mean_squared_error(single_y_test, y_pred)
+    rmse = np.sqrt(mse)
+    r2 = r2_score(single_y_test, y_pred)
+
+    # Create formatted report
+    sanitized_target_name = sanitize_filename(target_name)
+    save_path = make_fullpath(save_dir, make=True)
+    report_path = save_path / f"Regression_Report_{sanitized_target_name}.txt"
+    with open(report_path, "w") as f:
+        f.write(f"{model_name} - Regression Performance for '{target_name}'\n\n")
+        f.write(f"Mean Absolute Error (MAE): {mae:.4f}\n")
+        f.write(f"Mean Squared Error (MSE): {mse:.4f}\n")
+        f.write(f"Root Mean Squared Error (RMSE): {rmse:.4f}\n")
+        f.write(f"R² Score: {r2:.4f}\n")
+
+    # Generate and save residual plot
+    residuals = single_y_test - y_pred
+
+    plt.figure(figsize=figure_size)
+    plt.scatter(y_pred, residuals, alpha=alpha_transparency)
+    plt.axhline(0, color='red', linestyle='--')
+    plt.xlabel("Predicted Values", fontsize=base_fontsize)
+    plt.ylabel("Residuals", fontsize=base_fontsize)
+    plt.title(f"{model_name} - Residual Plot for {target_name}", fontsize=base_fontsize)
+    plt.grid(True)
+    plt.tight_layout()
+    residual_path = save_path / f"Residuals_Plot_{sanitized_target_name}.svg"
+    plt.savefig(residual_path, bbox_inches='tight', format="svg")
+    plt.close()
+
+    # Create true vs predicted values plot
+    plt.figure(figsize=figure_size)
+    plt.scatter(single_y_test, y_pred, alpha=alpha_transparency)
+    plt.plot([single_y_test.min(), single_y_test.max()],
+             [single_y_test.min(), single_y_test.max()],
+             'k--', lw=2)
+    plt.xlabel('True Values', fontsize=base_fontsize)
+    plt.ylabel('Predictions', fontsize=base_fontsize)
+    plt.title(f"{model_name} - True vs Predicted for {target_name}", fontsize=base_fontsize)
+    plt.grid(True)
+    plot_path = save_path / f"True_Vs_Predict_Plot_{sanitized_target_name}.svg"
+    plt.savefig(plot_path, bbox_inches='tight', format="svg")
+    plt.close()
+
+    # Generate and save histogram of residuals
+    plt.figure(figsize=figure_size)
+    sns.histplot(residuals, bins=hist_bins, kde=True)
+    plt.xlabel("Residual Value", fontsize=base_fontsize)
+    plt.ylabel("Frequency", fontsize=base_fontsize)
+    plt.title(f"{model_name} - Distribution of Residuals for {target_name}", fontsize=base_fontsize)
+    plt.grid(True)
+    plt.tight_layout()
+    hist_path = save_path / f"Residuals_Distribution_{sanitized_target_name}.svg"
+    plt.savefig(hist_path, bbox_inches='tight', format="svg")
+    plt.close()
+
+    return y_pred
+
+
+# Get SHAP values
+def get_shap_values(
+    model,
+    model_name: str,
+    save_dir: Union[str, Path],
+    features_to_explain: np.ndarray,
+    feature_names: list[str],
+    target_name: str,
+    task: Literal["classification", "regression"],
+    max_display_features: int = 10,
+    figsize: tuple = (16, 20),
+    base_fontsize: int = 38,
+):
+    """
+    Universal SHAP explainer for regression and classification.
+    * Use `X_train` (or a subsample of it) to see how the model explains the data it was trained on.
+
+    * Use `X_test` (or a hold-out set) to see how the model explains unseen data.
+
+    * Use the entire dataset to get the global view.
+
+    Parameters:
+        task: 'regression' or 'classification'.
+        features_to_explain: Should match the model's training data format, including scaling.
+        save_dir: Directory to save visualizations.
+    """
+    sanitized_target_name = sanitize_filename(target_name)
+    global_save_path = make_fullpath(save_dir, make=True)
+
+    def _apply_plot_style():
+        styles = ['seaborn', 'seaborn-v0_8-darkgrid', 'seaborn-v0_8', 'default']
+        for style in styles:
+            if style in plt.style.available or style == 'default':
+                plt.style.use(style)
+                break
+
+    def _configure_rcparams():
+        plt.rc('font', size=base_fontsize)
+        plt.rc('axes', titlesize=base_fontsize)
+        plt.rc('axes', labelsize=base_fontsize)
+        plt.rc('xtick', labelsize=base_fontsize)
+        plt.rc('ytick', labelsize=base_fontsize + 2)
+        plt.rc('legend', fontsize=base_fontsize)
+        plt.rc('figure', titlesize=base_fontsize)
+
+    def _create_shap_plot(shap_values, features, save_path: Path, plot_type: str, title: str):
+        _apply_plot_style()
+        _configure_rcparams()
+        plt.figure(figsize=figsize)
+
+        shap.summary_plot(
+            shap_values=shap_values,
+            features=features,
+            feature_names=feature_names,
+            plot_type=plot_type,
+            show=False,
+            plot_size=figsize,
+            max_display=max_display_features,
+            alpha=0.7,
+            # color='viridis'
+        )
+
+        ax = plt.gca()
+        ax.set_xlabel("SHAP Value Impact", fontsize=base_fontsize + 2, weight='bold', labelpad=20)
+        plt.title(title, fontsize=base_fontsize + 2, pad=20, weight='bold')
+
+        for tick in ax.get_xticklabels():
+            tick.set_fontsize(base_fontsize)
+            tick.set_rotation(30)
+        for tick in ax.get_yticklabels():
+            tick.set_fontsize(base_fontsize + 2)
+
+        if plot_type == "dot":
+            cb = plt.gcf().axes[-1]
+            cb.set_ylabel("", size=1)
+            cb.tick_params(labelsize=base_fontsize - 2)
+
+        plt.savefig(save_path, bbox_inches='tight', facecolor='white', format="svg")
+        plt.close()
+        rcdefaults()
+
+    def _plot_for_classification(shap_values, class_names):
+        is_multiclass = isinstance(shap_values, list) and len(shap_values) > 1
+
+        if is_multiclass:
+            for class_shap, class_name in zip(shap_values, class_names):
+                for plot_type in ["bar", "dot"]:
+                    _create_shap_plot(
+                        shap_values=class_shap,
+                        features=features_to_explain,
+                        save_path=global_save_path / f"SHAP_{sanitized_target_name}_Class{class_name}_{plot_type}.svg",
+                        plot_type=plot_type,
+                        title=f"{model_name} - {target_name} (Class {class_name})"
+                    )
+        else:
+            values = shap_values[1] if isinstance(shap_values, list) else shap_values
+            for plot_type in ["bar", "dot"]:
+                _create_shap_plot(
+                    shap_values=values,
+                    features=features_to_explain,
+                    save_path=global_save_path / f"SHAP_{sanitized_target_name}_{plot_type}.svg",
+                    plot_type=plot_type,
+                    title=f"{model_name} - {target_name}"
+                )
+
+    def _plot_for_regression(shap_values):
+        for plot_type in ["bar", "dot"]:
+            _create_shap_plot(
+                shap_values=shap_values,
+                features=features_to_explain,
+                save_path=global_save_path / f"SHAP_{sanitized_target_name}_{plot_type}.svg",
+                plot_type=plot_type,
+                title=f"{model_name} - {target_name}"
+            )
+    #START_O
+
+    explainer = shap.TreeExplainer(model)
+    shap_values = explainer.shap_values(features_to_explain)
+
+    if task == 'classification':
+        try:
+            class_names = model.classes_ if hasattr(model, 'classes_') else list(range(len(shap_values)))
+        except Exception:
+            class_names = list(range(len(shap_values)))
+        _plot_for_classification(shap_values, class_names)
+    else:
+        _plot_for_regression(shap_values)
+
+
+# Learning curves for regression and classification
+def plot_learning_curves(
+    estimator,
+    X: np.ndarray,
+    y: np.ndarray,
+    task: Literal["classification", "regression"],
+    model_name: str,
+    target_name: str,
+    save_directory: Union[str, Path],
+    cv: int = 5,
+    n_jobs: int = -1,
+    figure_size: tuple = (12, 8),
+    base_fontsize: int = 24
+):
+    """
+    Generates and saves a plot of the learning curves for a given estimator
+    to diagnose bias vs. variance.
+
+    Computationally expensive, requires a fresh, unfitted instance of the model.
+    """
+    save_path = make_fullpath(save_directory, make=True)
+    sanitized_target_name = sanitize_filename(target_name)
+
+    # Select scoring metric based on task
+    scoring = "accuracy" if task == "classification" else "r2"
+
+    train_sizes_abs, train_scores, val_scores, *_ = learning_curve(
+        estimator, X, y,
+        cv=cv,
+        n_jobs=n_jobs,
+        train_sizes=np.linspace(0.1, 1.0, 10),
+        scoring=scoring
+    )
+
+    train_scores_mean = np.mean(train_scores, axis=1)
+    train_scores_std = np.std(train_scores, axis=1)
+    val_scores_mean = np.mean(val_scores, axis=1)
+    val_scores_std = np.std(val_scores, axis=1)
+
+    fig, ax = plt.subplots(figsize=figure_size)
+    ax.grid(True)
+
+    # Plot the mean scores
+    ax.plot(train_sizes_abs, train_scores_mean, 'o-', color="r", label="Training score")
+    ax.plot(train_sizes_abs, val_scores_mean, 'o-', color="g", label="Cross-validation score")
+
+    # Plot the standard deviation bands
+    ax.fill_between(train_sizes_abs, train_scores_mean - train_scores_std,
+                    train_scores_mean + train_scores_std, alpha=0.1, color="r")
+    ax.fill_between(train_sizes_abs, val_scores_mean - val_scores_std,
+                    val_scores_mean + val_scores_std, alpha=0.1, color="g")
+
+    ax.set_title(f"{model_name} - Learning Curve for {target_name}", fontsize=base_fontsize)
+    ax.set_xlabel("Training examples", fontsize=base_fontsize)
+    ax.set_ylabel(f"Score ({scoring})", fontsize=base_fontsize)
+    ax.legend(loc="best", fontsize=base_fontsize - 4)
+    ax.tick_params(axis='both', labelsize=base_fontsize - 4)
+    fig.tight_layout()
+
+    # Save figure
+    full_save_path = save_path / f"Learning_Curve_{sanitized_target_name}.svg"
+    fig.savefig(full_save_path, bbox_inches="tight", format="svg")
+    plt.close(fig)


+def info():
+    _script_info(__all__)
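For orientation, a minimal usage sketch of the new evaluation module, assuming the wheel is installed and the module is importable as ml_tools.ensemble_evaluation; the synthetic data, model choice, and output directory are illustrative and not part of the release:

# Hedged example: exercises two of the new functions on synthetic binary data.
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from ml_tools.ensemble_evaluation import evaluate_model_classification, plot_roc_curve

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 8))                 # stand-in for a scaled feature matrix
y = (X[:, 0] + X[:, 1] > 0).astype(int)       # binary target (0/1)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

model = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_train, y_train)

# Writes Classification_Report_*.txt/.svg and Confusion_Matrix_*.svg into save_dir
y_pred = evaluate_model_classification(
    model=model,
    model_name="RandomForest",
    save_dir="evaluation_output",             # hypothetical output directory
    x_test_scaled=X_test,
    single_y_test=y_test,
    target_name="demo_target",
)

# Writes ROC_demo_target.svg; a fitted model plus input_features may be passed instead of probabilities
plot_roc_curve(
    true_labels=y_test,
    probabilities_or_model=model.predict_proba(X_test),
    model_name="RandomForest",
    target_name="demo_target",
    save_directory="evaluation_output",
)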
ml_tools/ensemble_inference.py
CHANGED
@@ -1,7 +1,7 @@
 from ._script_info import _script_info
 from ._logger import _LOGGER
 from .path_manager import make_fullpath, list_files_by_extension
-from .keys import
+from .keys import EnsembleKeys
 
 from typing import Union, Literal, Dict, Any, Optional, List
 from pathlib import Path
@@ -49,9 +49,9 @@ class InferenceHandler:
             verbose=self.verbose,
             raise_on_error=True) # type: ignore
 
-        model: Any = full_object[
-        target_name: str = full_object[
-        feature_names_list: List[str] = full_object[
+        model: Any = full_object[EnsembleKeys.MODEL]
+        target_name: str = full_object[EnsembleKeys.TARGET]
+        feature_names_list: List[str] = full_object[EnsembleKeys.FEATURES]
 
         # Check that feature names match
         if self._feature_names is None:
@@ -102,8 +102,8 @@ class InferenceHandler:
         else: # Classification
             label = model.predict(features)[0]
             probabilities = model.predict_proba(features)[0]
-            results[target_name] = {
-
+            results[target_name] = {EnsembleKeys.CLASSIFICATION_LABEL: label,
+                                    EnsembleKeys.CLASSIFICATION_PROBABILITIES: probabilities}
 
         if self.verbose:
             _LOGGER.info("✅ Inference process complete.")
@@ -170,15 +170,15 @@ def model_report(
     # --- 2. Deserialize and Extract Info ---
     try:
         full_object: dict = _deserialize_object(model_p) # type: ignore
-        model = full_object[
-        target = full_object[
-        features = full_object[
+        model = full_object[EnsembleKeys.MODEL]
+        target = full_object[EnsembleKeys.TARGET]
+        features = full_object[EnsembleKeys.FEATURES]
     except FileNotFoundError:
         _LOGGER.error(f"❌ Model file not found at '{model_p}'")
         raise
     except (KeyError, TypeError) as e:
         _LOGGER.error(
-            f"❌ The serialized object is missing required keys '{
+            f"❌ The serialized object is missing required keys '{EnsembleKeys.MODEL}', '{EnsembleKeys.TARGET}', '{EnsembleKeys.FEATURES}'"
        )
        raise e
 
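For context on the EnsembleKeys change above, a rough sketch of the keyed dictionary the inference code now expects; only the EnsembleKeys attribute names (MODEL, TARGET, FEATURES) are taken from the diff, and the placeholder values are made up for illustration:

# Hedged example: mirrors the dictionary lookups shown in the diff.
from ml_tools.keys import EnsembleKeys

trained_model = object()                 # placeholder for a fitted estimator
full_object = {
    EnsembleKeys.MODEL: trained_model,
    EnsembleKeys.TARGET: "demo_target",
    EnsembleKeys.FEATURES: ["feat_1", "feat_2"],
}

# These lookups correspond to what InferenceHandler and model_report now perform:
model = full_object[EnsembleKeys.MODEL]
target = full_object[EnsembleKeys.TARGET]
features = full_object[EnsembleKeys.FEATURES]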