dragon-ml-toolbox 14.3.1__py3-none-any.whl → 16.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +10 -5
- dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +309 -0
- ml_tools/ML_datasetmaster.py +220 -260
- ml_tools/ML_evaluation.py +317 -81
- ml_tools/ML_evaluation_multi.py +127 -36
- ml_tools/ML_inference.py +249 -207
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +215 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1247 -338
- ml_tools/ML_utilities.py +51 -2
- ml_tools/ML_vision_datasetmaster.py +262 -118
- ml_tools/ML_vision_evaluation.py +26 -6
- ml_tools/ML_vision_inference.py +117 -140
- ml_tools/ML_vision_models.py +15 -1
- ml_tools/ML_vision_transformers.py +233 -7
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +45 -1
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +54 -11
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.3.1.dist-info/RECORD +0 -48
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation_multi.py
CHANGED
@@ -19,13 +19,14 @@ from sklearn.metrics import (
     jaccard_score
 )
 from pathlib import Path
-from typing import Union, List, Literal
+from typing import Union, List, Literal, Optional
 import warnings
 
 from .path_manager import make_fullpath, sanitize_filename
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .keys import SHAPKeys
+from ._keys import SHAPKeys
+from .ML_configuration import RegressionMetricsFormat, MultiClassificationMetricsFormat
 
 
 __all__ = [
@@ -34,12 +35,15 @@ __all__ = [
     "multi_target_shap_summary_plot",
 ]
 
+DPI_value = 250
+
 
 def multi_target_regression_metrics(
     y_true: np.ndarray,
     y_pred: np.ndarray,
     target_names: List[str],
-    save_dir: Union[str, Path]
+    save_dir: Union[str, Path],
+    config: Optional[RegressionMetricsFormat] = None
 ):
     """
     Calculates and saves regression metrics for each target individually.
@@ -53,6 +57,7 @@ def multi_target_regression_metrics(
         y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
         target_names (List[str]): A list of names for the target variables.
         save_dir (str | Path): Directory to save plots and the report.
+        config (RegressionMetricsFormat, optional): Formatting configuration object.
     """
     if y_true.ndim != 2 or y_pred.ndim != 2:
         _LOGGER.error("y_true and y_pred must be 2D arrays for multi-target regression.")
@@ -66,8 +71,17 @@ def multi_target_regression_metrics(
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
     metrics_summary = []
+
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        config = RegressionMetricsFormat()
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': config.font_size})
 
-    _LOGGER.info("--- Multi-Target Regression Evaluation ---")
+    _LOGGER.debug("--- Multi-Target Regression Evaluation ---")
 
     for i, name in enumerate(target_names):
         print(f" -> Evaluating target: '{name}'")
@@ -90,9 +104,13 @@ def multi_target_regression_metrics(
 
         # --- Save Residual Plot ---
         residuals = true_i - pred_i
-        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=
-        ax_res.scatter(pred_i, residuals,
-
+        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+        ax_res.scatter(pred_i, residuals,
+                       alpha=config.scatter_alpha,
+                       edgecolors='k',
+                       s=50,
+                       color=config.scatter_color) # Use config color
+        ax_res.axhline(0, color=config.residual_line_color, linestyle='--') # Use config color
         ax_res.set_xlabel("Predicted Values")
         ax_res.set_ylabel("Residuals (True - Predicted)")
         ax_res.set_title(f"Residual Plot for '{name}'")
@@ -103,9 +121,16 @@
         plt.close(fig_res)
 
         # --- Save True vs. Predicted Plot ---
-        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=
-        ax_tvp.scatter(true_i, pred_i,
-
+        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=DPI_value)
+        ax_tvp.scatter(true_i, pred_i,
+                       alpha=config.scatter_alpha,
+                       edgecolors='k',
+                       s=50,
+                       color=config.scatter_color) # Use config color
+        ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()],
+                    linestyle='--',
+                    lw=2,
+                    color=config.ideal_line_color) # Use config color
         ax_tvp.set_xlabel('True Values')
         ax_tvp.set_ylabel('Predicted Values')
         ax_tvp.set_title(f'True vs. Predicted Values for "{name}"')
@@ -120,14 +145,18 @@
     report_path = save_dir_path / "regression_report_multi.csv"
     summary_df.to_csv(report_path, index=False)
     _LOGGER.info(f"Full regression report saved to '{report_path.name}'")
+
+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)
 
 
 def multi_label_classification_metrics(
     y_true: np.ndarray,
+    y_pred: np.ndarray,
     y_prob: np.ndarray,
     target_names: List[str],
     save_dir: Union[str, Path],
-
+    config: Optional[MultiClassificationMetricsFormat] = None # Add config object
 ):
     """
     Calculates and saves classification metrics for each label individually.
@@ -138,17 +167,17 @@
 
     Args:
         y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
+        y_pred (np.ndarray): Predicted binary labels, shape (n_samples, n_labels).
         y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
         target_names (List[str]): A list of names for the labels.
         save_dir (str | Path): Directory to save plots and reports.
-
-            binary predictions for metrics like the confusion matrix.
+        config (MultiClassificationMetricsFormat, optional): Formatting configuration object.
     """
-    if y_true.ndim != 2 or y_prob.ndim != 2:
-        _LOGGER.error("y_true and y_prob must be 2D arrays for multi-label classification.")
+    if y_true.ndim != 2 or y_prob.ndim != 2 or y_pred.ndim != 2:
+        _LOGGER.error("y_true, y_pred, and y_prob must be 2D arrays for multi-label classification.")
         raise ValueError()
-    if y_true.shape != y_prob.shape:
-        _LOGGER.error("Shapes of y_true and y_prob must match.")
+    if y_true.shape != y_prob.shape or y_true.shape != y_pred.shape:
+        _LOGGER.error("Shapes of y_true, y_pred, and y_prob must match.")
         raise ValueError()
     if y_true.shape[1] != len(target_names):
         _LOGGER.error("Number of target names must match the number of columns in y_true.")
@@ -156,25 +185,33 @@
 
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
 
-    #
-
+    # --- Parse Config or use defaults ---
+    if config is None:
+        # Create a default config if one wasn't provided
+        config = MultiClassificationMetricsFormat()
+
+    # y_pred is now passed in directly, no threshold needed.
+
+    # --- Save current RC params and update font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': config.font_size})
 
-    _LOGGER.info("--- Multi-Label Classification Evaluation ---")
+    # _LOGGER.info("--- Multi-Label Classification Evaluation ---")
 
-    # --- Calculate and Save Overall Metrics ---
+    # --- Calculate and Save Overall Metrics (using y_pred) ---
     h_loss = hamming_loss(y_true, y_pred)
     j_score_micro = jaccard_score(y_true, y_pred, average='micro')
     j_score_macro = jaccard_score(y_true, y_pred, average='macro')
 
     overall_report = (
-        f"Overall Multi-Label Metrics
+        f"Overall Multi-Label Metrics:\n" # No threshold to report here
         f"--------------------------------------------------\n"
         f"Hamming Loss: {h_loss:.4f}\n"
         f"Jaccard Score (micro): {j_score_micro:.4f}\n"
         f"Jaccard Score (macro): {j_score_macro:.4f}\n"
         f"--------------------------------------------------\n"
     )
-    print(overall_report)
+    # print(overall_report)
     overall_report_path = save_dir_path / "classification_report_overall.txt"
     overall_report_path.write_text(overall_report)
 
@@ -182,28 +219,79 @@
     for i, name in enumerate(target_names):
         print(f" -> Evaluating label: '{name}'")
         true_i = y_true[:, i]
-        pred_i = y_pred[:, i]
-        prob_i = y_prob[:, i]
+        pred_i = y_pred[:, i] # Use passed-in y_pred
+        prob_i = y_prob[:, i] # Use passed-in y_prob
         sanitized_name = sanitize_filename(name)
 
-        # --- Save Classification Report for the label ---
+        # --- Save Classification Report for the label (uses y_pred) ---
         report_text = classification_report(true_i, pred_i)
         report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
         report_path.write_text(report_text) # type: ignore
 
-        # --- Save Confusion Matrix ---
-        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=
-        ConfusionMatrixDisplay.from_predictions(true_i,
+        # --- Save Confusion Matrix (uses y_pred) ---
+        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        disp_ = ConfusionMatrixDisplay.from_predictions(true_i,
+                                                        pred_i,
+                                                        cmap=config.cmap, # Use config cmap
+                                                        ax=ax_cm,
+                                                        normalize='true',
+                                                        labels=[0, 1],
+                                                        display_labels=["Negative", "Positive"])
+
+        disp_.im_.set_clim(vmin=0.0, vmax=1.0)
+
+        # Turn off gridlines
+        ax_cm.grid(False)
+
+        # Manually update font size of cell texts
+        for text in ax_cm.texts:
+            text.set_fontsize(config.font_size) # Use config font_size
+
+        fig_cm.tight_layout()
+
         ax_cm.set_title(f"Confusion Matrix for '{name}'")
         cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
         plt.savefig(cm_path)
         plt.close(fig_cm)
 
-        # --- Save ROC Curve ---
-        fpr, tpr,
+        # --- Save ROC Curve (uses y_prob) ---
+        fpr, tpr, thresholds = roc_curve(true_i, prob_i)
+
+        try:
+            # Calculate Youden's J statistic (tpr - fpr)
+            J = tpr - fpr
+            # Find the index of the best threshold
+            best_index = np.argmax(J)
+            optimal_threshold = thresholds[best_index]
+            best_tpr = tpr[best_index]
+            best_fpr = fpr[best_index]
+
+            # Define the filename
+            threshold_filename = f"best_threshold_{sanitized_name}.txt"
+            threshold_path = save_dir_path / threshold_filename
+
+            # The class name is the target_name for this label
+            class_name = name
+
+            # Create content for the file
+            file_content = (
+                f"Optimal Classification Threshold (Youden's J Statistic)\n"
+                f"Class/Label: {class_name}\n"
+                f"--------------------------------------------------\n"
+                f"Threshold: {optimal_threshold:.6f}\n"
+                f"True Positive Rate (TPR): {best_tpr:.6f}\n"
+                f"False Positive Rate (FPR): {best_fpr:.6f}\n"
+            )
+
+            threshold_path.write_text(file_content, encoding="utf-8")
+            _LOGGER.info(f"💾 Optimal threshold for '{name}' saved to '{threshold_path.name}'")
+
+        except Exception as e:
+            _LOGGER.warning(f"Could not calculate or save optimal threshold for '{name}': {e}")
+
         auc = roc_auc_score(true_i, prob_i)
-        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=
-        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
+        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}', color=config.ROC_PR_line) # Use config color
         ax_roc.plot([0, 1], [0, 1], 'k--')
         ax_roc.set_title(f'ROC Curve for "{name}"')
         ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
@@ -212,17 +300,20 @@
         plt.savefig(roc_path)
         plt.close(fig_roc)
 
-        # --- Save Precision-Recall Curve ---
+        # --- Save Precision-Recall Curve (uses y_prob) ---
        precision, recall, _ = precision_recall_curve(true_i, prob_i)
         ap_score = average_precision_score(true_i, prob_i)
-        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=
-        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
+        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=DPI_value)
+        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}', color=config.ROC_PR_line) # Use config color
         ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
         ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
         ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)
         pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
         plt.savefig(pr_path)
         plt.close(fig_pr)
+
+    # restore RC params
+    plt.rcParams.update(original_rc_params)
 
     _LOGGER.info(f"All individual label reports and plots saved to '{save_dir_path.name}'")
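
Note on the API change above: in 16.0.0, multi_label_classification_metrics no longer binarizes y_prob with a threshold argument; callers apply their own cutoff and pass y_pred alongside y_prob, plus an optional formatting config. A minimal usage sketch under that reading follows; the label names and the 0.5 cutoff are illustrative, and omitting config mirrors the default-config fallback visible in the diff:

    import numpy as np
    from ml_tools.ML_evaluation_multi import multi_label_classification_metrics

    rng = np.random.default_rng(0)
    y_true = rng.integers(0, 2, size=(100, 3))  # ground-truth binary labels, shape (n_samples, n_labels)
    y_prob = rng.random(size=(100, 3))          # predicted probabilities, same shape
    y_pred = (y_prob >= 0.5).astype(int)        # the caller applies the threshold now (0.5 is illustrative)

    # config=None falls back to a default MultiClassificationMetricsFormat(), per the diff.
    multi_label_classification_metrics(
        y_true=y_true,
        y_pred=y_pred,
        y_prob=y_prob,
        target_names=["label_a", "label_b", "label_c"],
        save_dir="reports/multilabel_eval",
    )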
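The new per-label best_threshold_*.txt files are computed with Youden's J statistic: J = TPR - FPR, maximized over the thresholds returned by roc_curve. A self-contained sketch of the same computation using only numpy and scikit-learn (the helper name is ours, not the package's):

    import numpy as np
    from sklearn.metrics import roc_curve

    def youden_optimal_threshold(y_true_1d: np.ndarray, y_prob_1d: np.ndarray) -> float:
        """Return the probability cutoff maximizing J = TPR - FPR for one label."""
        fpr, tpr, thresholds = roc_curve(y_true_1d, y_prob_1d)
        best_index = np.argmax(tpr - fpr)  # Youden's J statistic
        return float(thresholds[best_index])

    # Example: positives score higher on average, so the cutoff lands between the classes.
    y_true = np.array([0, 0, 0, 1, 1, 1])
    y_prob = np.array([0.1, 0.4, 0.35, 0.8, 0.65, 0.9])
    print(youden_optimal_threshold(y_true, y_prob))  # -> 0.65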
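A design note on the rcParams handling added to both functions: the diff copies plt.rcParams, overrides font.size, and restores the copy before returning. Matplotlib's rc_context provides the same scoping as a context manager; this alternative sketch is not what the package does, just the standard idiom it approximates:

    import matplotlib.pyplot as plt

    # rcParams revert automatically when the block exits, even on an exception,
    # which a manual copy/update/restore sequence does not guarantee.
    with plt.rc_context({'font.size': 12}):
        fig, ax = plt.subplots(figsize=(8, 6), dpi=250)
        ax.set_title("Scoped font size")
        plt.close(fig)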