dragon-ml-toolbox 7.0.0__py3-none-any.whl → 8.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-7.0.0.dist-info → dragon_ml_toolbox-8.1.0.dist-info}/METADATA +2 -1
- {dragon_ml_toolbox-7.0.0.dist-info → dragon_ml_toolbox-8.1.0.dist-info}/RECORD +14 -12
- ml_tools/ML_datasetmaster.py +165 -116
- ml_tools/ML_evaluation.py +5 -2
- ml_tools/ML_evaluation_multi.py +296 -0
- ml_tools/ML_inference.py +232 -34
- ml_tools/ML_models.py +0 -4
- ml_tools/ML_trainer.py +168 -71
- ml_tools/_ML_optimization_multi.py +231 -0
- ml_tools/data_exploration.py +80 -2
- {dragon_ml_toolbox-7.0.0.dist-info → dragon_ml_toolbox-8.1.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-7.0.0.dist-info → dragon_ml_toolbox-8.1.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-7.0.0.dist-info → dragon_ml_toolbox-8.1.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-7.0.0.dist-info → dragon_ml_toolbox-8.1.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_evaluation_multi.py
ADDED
@@ -0,0 +1,296 @@
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import seaborn as sns
+import torch
+import shap
+from sklearn.metrics import (
+    classification_report,
+    ConfusionMatrixDisplay,
+    roc_curve,
+    roc_auc_score,
+    precision_recall_curve,
+    average_precision_score,
+    mean_squared_error,
+    mean_absolute_error,
+    r2_score,
+    median_absolute_error,
+    hamming_loss,
+    jaccard_score
+)
+from pathlib import Path
+from typing import Union, List, Optional
+
+from .path_manager import make_fullpath, sanitize_filename
+from ._logger import _LOGGER
+from ._script_info import _script_info
+
+__all__ = [
+    "multi_target_regression_metrics",
+    "multi_label_classification_metrics",
+    "multi_target_shap_summary_plot",
+]
+
+
+def multi_target_regression_metrics(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    target_names: List[str],
+    save_dir: Union[str, Path]
+):
+    """
+    Calculates and saves regression metrics for each target individually.
+
+    For each target, this function saves a residual plot and a true vs. predicted plot.
+    It also saves a single CSV file containing the key metrics (RMSE, MAE, R², MedAE)
+    for all targets.
+
+    Args:
+        y_true (np.ndarray): Ground truth values, shape (n_samples, n_targets).
+        y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
+        target_names (List[str]): A list of names for the target variables.
+        save_dir (str | Path): Directory to save plots and the report.
+    """
+    if y_true.ndim != 2 or y_pred.ndim != 2:
+        raise ValueError("y_true and y_pred must be 2D arrays for multi-target regression.")
+    if y_true.shape != y_pred.shape:
+        raise ValueError("Shapes of y_true and y_pred must match.")
+    if y_true.shape[1] != len(target_names):
+        raise ValueError("Number of target names must match the number of columns in y_true.")
+
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    metrics_summary = []
+
+    _LOGGER.info("--- Multi-Target Regression Evaluation ---")
+
+    for i, name in enumerate(target_names):
+        _LOGGER.info(f" -> Evaluating target: '{name}'")
+        true_i = y_true[:, i]
+        pred_i = y_pred[:, i]
+        sanitized_name = sanitize_filename(name)
+
+        # --- Calculate Metrics ---
+        rmse = np.sqrt(mean_squared_error(true_i, pred_i))
+        mae = mean_absolute_error(true_i, pred_i)
+        r2 = r2_score(true_i, pred_i)
+        medae = median_absolute_error(true_i, pred_i)
+        metrics_summary.append({
+            'target': name,
+            'rmse': rmse,
+            'mae': mae,
+            'r2_score': r2,
+            'median_abs_error': medae
+        })
+
+        # --- Save Residual Plot ---
+        residuals = true_i - pred_i
+        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
+        ax_res.scatter(pred_i, residuals, alpha=0.6, edgecolors='k', s=50)
+        ax_res.axhline(0, color='red', linestyle='--')
+        ax_res.set_xlabel("Predicted Values")
+        ax_res.set_ylabel("Residuals (True - Predicted)")
+        ax_res.set_title(f"Residual Plot for '{name}'")
+        ax_res.grid(True, linestyle='--', alpha=0.6)
+        plt.tight_layout()
+        res_path = save_dir_path / f"residual_plot_{sanitized_name}.svg"
+        plt.savefig(res_path)
+        plt.close(fig_res)
+
+        # --- Save True vs. Predicted Plot ---
+        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
+        ax_tvp.scatter(true_i, pred_i, alpha=0.6, edgecolors='k', s=50)
+        ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()], 'k--', lw=2)
+        ax_tvp.set_xlabel('True Values')
+        ax_tvp.set_ylabel('Predicted Values')
+        ax_tvp.set_title(f'True vs. Predicted Values for "{name}"')
+        ax_tvp.grid(True, linestyle='--', alpha=0.6)
+        plt.tight_layout()
+        tvp_path = save_dir_path / f"true_vs_predicted_plot_{sanitized_name}.svg"
+        plt.savefig(tvp_path)
+        plt.close(fig_tvp)
+
+    # --- Save Summary Report ---
+    summary_df = pd.DataFrame(metrics_summary)
+    report_path = save_dir_path / "regression_report_multi.csv"
+    summary_df.to_csv(report_path, index=False)
+    _LOGGER.info(f"✅ Full regression report saved to '{report_path.name}'")
+
+
+def multi_label_classification_metrics(
+    y_true: np.ndarray,
+    y_prob: np.ndarray,
+    target_names: List[str],
+    save_dir: Union[str, Path],
+    threshold: float = 0.5
+):
+    """
+    Calculates and saves classification metrics for each label individually.
+
+    This function first computes overall multi-label metrics (Hamming Loss, Jaccard Score)
+    and then iterates through each label to generate and save individual reports,
+    confusion matrices, ROC curves, and Precision-Recall curves.
+
+    Args:
+        y_true (np.ndarray): Ground truth binary labels, shape (n_samples, n_labels).
+        y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
+        target_names (List[str]): A list of names for the labels.
+        save_dir (str | Path): Directory to save plots and reports.
+        threshold (float): The probability threshold to convert probabilities into
+            binary predictions for metrics like the confusion matrix.
+    """
+    if y_true.ndim != 2 or y_prob.ndim != 2:
+        raise ValueError("y_true and y_prob must be 2D arrays for multi-label classification.")
+    if y_true.shape != y_prob.shape:
+        raise ValueError("Shapes of y_true and y_prob must match.")
+    if y_true.shape[1] != len(target_names):
+        raise ValueError("Number of target names must match the number of columns in y_true.")
+
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+    # Generate binary predictions from probabilities
+    y_pred = (y_prob >= threshold).astype(int)
+
+    _LOGGER.info("--- Multi-Label Classification Evaluation ---")
+
+    # --- Calculate and Save Overall Metrics ---
+    h_loss = hamming_loss(y_true, y_pred)
+    j_score_micro = jaccard_score(y_true, y_pred, average='micro')
+    j_score_macro = jaccard_score(y_true, y_pred, average='macro')
+
+    overall_report = (
+        f"Overall Multi-Label Metrics (Threshold = {threshold}):\n"
+        f"--------------------------------------------------\n"
+        f"Hamming Loss: {h_loss:.4f}\n"
+        f"Jaccard Score (micro): {j_score_micro:.4f}\n"
+        f"Jaccard Score (macro): {j_score_macro:.4f}\n"
+        f"--------------------------------------------------\n"
+    )
+    _LOGGER.info(overall_report)
+    overall_report_path = save_dir_path / "classification_report_overall.txt"
+    overall_report_path.write_text(overall_report)
+
+    # --- Per-Label Metrics and Plots ---
+    for i, name in enumerate(target_names):
+        _LOGGER.info(f" -> Evaluating label: '{name}'")
+        true_i = y_true[:, i]
+        pred_i = y_pred[:, i]
+        prob_i = y_prob[:, i]
+        sanitized_name = sanitize_filename(name)
+
+        # --- Save Classification Report for the label ---
+        report_text = classification_report(true_i, pred_i)
+        report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
+        report_path.write_text(report_text) # type: ignore
+
+        # --- Save Confusion Matrix ---
+        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
+        ConfusionMatrixDisplay.from_predictions(true_i, pred_i, cmap="Blues", ax=ax_cm)
+        ax_cm.set_title(f"Confusion Matrix for '{name}'")
+        cm_path = save_dir_path / f"confusion_matrix_{sanitized_name}.svg"
+        plt.savefig(cm_path)
+        plt.close(fig_cm)
+
+        # --- Save ROC Curve ---
+        fpr, tpr, _ = roc_curve(true_i, prob_i)
+        auc = roc_auc_score(true_i, prob_i)
+        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
+        ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
+        ax_roc.plot([0, 1], [0, 1], 'k--')
+        ax_roc.set_title(f'ROC Curve for "{name}"')
+        ax_roc.set_xlabel('False Positive Rate'); ax_roc.set_ylabel('True Positive Rate')
+        ax_roc.legend(loc='lower right'); ax_roc.grid(True, linestyle='--', alpha=0.6)
+        roc_path = save_dir_path / f"roc_curve_{sanitized_name}.svg"
+        plt.savefig(roc_path)
+        plt.close(fig_roc)
+
+        # --- Save Precision-Recall Curve ---
+        precision, recall, _ = precision_recall_curve(true_i, prob_i)
+        ap_score = average_precision_score(true_i, prob_i)
+        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
+        ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
+        ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
+        ax_pr.set_xlabel('Recall'); ax_pr.set_ylabel('Precision')
+        ax_pr.legend(loc='lower left'); ax_pr.grid(True, linestyle='--', alpha=0.6)
+        pr_path = save_dir_path / f"pr_curve_{sanitized_name}.svg"
+        plt.savefig(pr_path)
+        plt.close(fig_pr)
+
+    _LOGGER.info(f"✅ All individual label reports and plots saved to '{save_dir_path.name}'")
+
+
+def multi_target_shap_summary_plot(
+    model: torch.nn.Module,
+    background_data: Union[torch.Tensor, np.ndarray],
+    instances_to_explain: Union[torch.Tensor, np.ndarray],
+    feature_names: List[str],
+    target_names: List[str],
+    save_dir: Union[str, Path]
+):
+    """
+    Calculates SHAP values for a multi-target model and saves summary plots for each target.
+
+    Args:
+        model (torch.nn.Module): The trained PyTorch model.
+        background_data (torch.Tensor | np.ndarray): A sample of data for the explainer background.
+        instances_to_explain (torch.Tensor | np.ndarray): The specific data instances to explain.
+        feature_names (List[str]): Names of the features for plot labeling.
+        target_names (List[str]): Names of the output targets.
+        save_dir (str | Path): Directory to save SHAP artifacts.
+    """
+    # Convert all data to numpy
+    background_data_np = background_data.numpy() if isinstance(background_data, torch.Tensor) else background_data
+    instances_to_explain_np = instances_to_explain.numpy() if isinstance(instances_to_explain, torch.Tensor) else instances_to_explain
+
+    if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
+        _LOGGER.error("❌ Input data for SHAP contains NaN values. Aborting explanation.")
+        return
+
+    _LOGGER.info("\n--- Multi-Target SHAP Value Explanation ---")
+    model.eval()
+    model.cpu()
+
+    # 1. Summarize the background data.
+    background_summary = shap.kmeans(background_data_np, 30)
+
+    # 2. Define a prediction function wrapper for the multi-target model.
+    def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
+        x_torch = torch.from_numpy(x_np).float()
+        with torch.no_grad():
+            output = model(x_torch)
+        return output.cpu().numpy()
+
+    # 3. Create the KernelExplainer.
+    explainer = shap.KernelExplainer(prediction_wrapper, background_summary)
+
+    _LOGGER.info("Calculating SHAP values with KernelExplainer...")
+    # For multi-output models, shap_values is a list of arrays.
+    shap_values_list = explainer.shap_values(instances_to_explain_np, l1_reg="aic")
+
+    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+    plt.ioff()
+
+    # 4. Iterate through each target's SHAP values and generate plots.
+    for i, target_name in enumerate(target_names):
+        _LOGGER.info(f" -> Generating SHAP plots for target: '{target_name}'")
+        shap_values_for_target = shap_values_list[i]
+        sanitized_target_name = sanitize_filename(target_name)
+
+        # Save Bar Plot for the target
+        shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="bar", show=False)
+        plt.title(f"SHAP Feature Importance for '{target_name}'")
+        plt.tight_layout()
+        bar_path = save_dir_path / f"shap_bar_plot_{sanitized_target_name}.svg"
+        plt.savefig(bar_path)
+        plt.close()
+
+        # Save Dot Plot for the target
+        shap.summary_plot(shap_values_for_target, instances_to_explain_np, feature_names=feature_names, plot_type="dot", show=False)
+        plt.title(f"SHAP Feature Importance for '{target_name}'")
+        plt.tight_layout()
+        dot_path = save_dir_path / f"shap_dot_plot_{sanitized_target_name}.svg"
+        plt.savefig(dot_path)
+        plt.close()
+
+    plt.ion()
+    _LOGGER.info(f"✅ All SHAP plots saved to '{save_dir_path.name}'")
+
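For orientation, a minimal usage sketch of the new evaluation module (the arrays, target names, and output directory below are illustrative; the import path follows the wheel's `ml_tools` top-level package):

import numpy as np
from ml_tools.ML_evaluation_multi import multi_target_regression_metrics

# Illustrative predictions for 2 targets over 100 samples
rng = np.random.default_rng(0)
y_true = rng.normal(size=(100, 2))
y_pred = y_true + rng.normal(scale=0.1, size=(100, 2))

# Writes per-target residual and true-vs-predicted SVGs plus regression_report_multi.csv
multi_target_regression_metrics(
    y_true=y_true,
    y_pred=y_pred,
    target_names=["target_a", "target_b"],
    save_dir="evaluation_output",
)

`multi_label_classification_metrics` follows the same calling convention, taking predicted probabilities plus a `threshold` instead of point predictions.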
ml_tools/ML_inference.py
CHANGED
@@ -3,6 +3,7 @@ from torch import nn
 import numpy as np
 from pathlib import Path
 from typing import Union, Literal, Dict, Any, Optional
+from abc import ABC, abstractmethod
 
 from .ML_scaler import PytorchScaler
 from ._script_info import _script_info
@@ -12,38 +13,36 @@ from .keys import PyTorchInferenceKeys
 
 __all__ = [
     "PyTorchInferenceHandler",
+    "PyTorchInferenceHandlerMulti",
     "multi_inference_regression",
     "multi_inference_classification"
 ]
 
-class PyTorchInferenceHandler:
+
+class _BaseInferenceHandler(ABC):
     """
-
-
+    Abstract base class for PyTorch inference handlers.
+
+    Manages common tasks like loading a model's state dictionary, validating
+    the target device, and preprocessing input features.
     """
     def __init__(self,
                  model: nn.Module,
                  state_dict: Union[str, Path],
-                 task: Literal["classification", "regression"],
                  device: str = 'cpu',
-                 target_id: Optional[str]=None,
                  scaler: Optional[Union[PytorchScaler, str, Path]] = None):
         """
-        Initializes the handler
+        Initializes the handler.
 
         Args:
-            model (nn.Module): An instantiated PyTorch model
-            state_dict (str | Path):
-            task (str): The type of task, 'regression' or 'classification'.
+            model (nn.Module): An instantiated PyTorch model.
+            state_dict (str | Path): Path to the saved .pth model state_dict file.
             device (str): The device to run inference on ('cpu', 'cuda', 'mps').
-
-            scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
+            scaler (PytorchScaler | str | Path | None): An optional scaler or path to a saved scaler state.
         """
         self.model = model
-        self.task = task
         self.device = self._validate_device(device)
-
-
+
         # Load the scaler if a path is provided
         if scaler is not None:
             if isinstance(scaler, (str, Path)):
@@ -52,7 +51,7 @@ class PyTorchInferenceHandler:
                 self.scaler = scaler
         else:
             self.scaler = None
-
+
         model_p = make_fullpath(state_dict, enforce="file")
 
         try:
@@ -72,6 +71,7 @@ class PyTorchInferenceHandler:
             _LOGGER.warning("⚠️ CUDA not available, switching to CPU.")
             device_lower = "cpu"
         elif device_lower == "mps" and not torch.backends.mps.is_available():
+            # Your M-series Mac will appreciate this check!
            _LOGGER.warning("⚠️ Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
             device_lower = "cpu"
         return torch.device(device_lower)
@@ -84,51 +84,103 @@ class PyTorchInferenceHandler:
         if isinstance(features, np.ndarray):
             features_tensor = torch.from_numpy(features).float()
         else:
-            # Ensure it's a float tensor for the model
             features_tensor = features.float()
-
-        # Apply the scaler transformation if the scaler is available
+
         if self.scaler:
             features_tensor = self.scaler.transform(features_tensor)
-
-        # Ensure tensor is on the correct device
+
         return features_tensor.to(self.device)
-
+
+    @abstractmethod
+    def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """Core batch prediction method. Must be implemented by subclasses."""
+        pass
+
+    @abstractmethod
+    def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """Core single-sample prediction method. Must be implemented by subclasses."""
+        pass
+
+
+class PyTorchInferenceHandler(_BaseInferenceHandler):
+    """
+    Handles loading a PyTorch model's state dictionary and performing inference
+    for single-target regression or classification tasks.
+    """
+    def __init__(self,
+                 model: nn.Module,
+                 state_dict: Union[str, Path],
+                 task: Literal["classification", "regression"],
+                 device: str = 'cpu',
+                 target_id: Optional[str] = None,
+                 scaler: Optional[Union[PytorchScaler, str, Path]] = None):
+        """
+        Initializes the handler for single-target tasks.
+
+        Args:
+            model (nn.Module): An instantiated PyTorch model architecture.
+            state_dict (str | Path): Path to the saved .pth model state_dict file.
+            task (str): The type of task, 'regression' or 'classification'.
+            device (str): The device to run inference on ('cpu', 'cuda', 'mps').
+            target_id (str | None): An optional identifier for the target.
+            scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
+        """
+        # Call the parent constructor to handle model loading, device, and scaler
+        super().__init__(model, state_dict, device, scaler)
+
+        if task not in ["classification", "regression"]:
+            raise ValueError("`task` must be 'classification' or 'regression'.")
+        self.task = task
+        self.target_id = target_id
+
     def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """
-        Core batch prediction method
+        Core batch prediction method for single-target models.
+
+        Args:
+            features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
+
+        Returns:
+            A dictionary containing the raw output tensors from the model.
         """
         if features.ndim != 2:
             raise ValueError("Input for batch prediction must be a 2D array or tensor.")
 
         input_tensor = self._preprocess_input(features)
-
+
         with torch.no_grad():
-            # Output tensor remains on the model's device (e.g., 'mps' or 'cuda')
             output = self.model(input_tensor)
 
         if self.task == "classification":
-            probs =
+            probs = torch.softmax(output, dim=1)
             labels = torch.argmax(probs, dim=1)
             return {
                 PyTorchInferenceKeys.LABELS: labels,
                 PyTorchInferenceKeys.PROBABILITIES: probs
             }
         else: # regression
-
+            # For single-target regression, ensure output is flattened
+            return {PyTorchInferenceKeys.PREDICTIONS: output.flatten()}
 
     def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """
-        Core single-sample prediction
+        Core single-sample prediction method for single-target models.
+
+        Args:
+            features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
+
+        Returns:
+            A dictionary containing the raw output tensors for a single sample.
         """
         if features.ndim == 1:
-            features = features.reshape(1, -1)
-
+            features = features.reshape(1, -1) # Reshape to a batch of one
+
         if features.shape[0] != 1:
             raise ValueError("The predict() method is for a single sample. Use predict_batch() for multiple samples.")
 
         batch_results = self.predict_batch(features)
-
+
+        # Extract the first (and only) result from the batch output
        single_results = {key: value[0] for key, value in batch_results.items()}
         return single_results
 
@@ -139,7 +191,6 @@ class PyTorchInferenceHandler:
         Convenience wrapper for predict_batch that returns NumPy arrays.
         """
         tensor_results = self.predict_batch(features)
-        # Move tensor to CPU before converting to NumPy
         numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
         return numpy_results
 
@@ -148,16 +199,163 @@ class PyTorchInferenceHandler:
         Convenience wrapper for predict that returns NumPy arrays or scalars.
         """
         tensor_results = self.predict(features)
-
+
         if self.task == "regression":
-            # .item() implicitly moves to CPU
+            # .item() implicitly moves to CPU and returns a Python scalar
             return {PyTorchInferenceKeys.PREDICTIONS: tensor_results[PyTorchInferenceKeys.PREDICTIONS].item()}
         else: # classification
             return {
                 PyTorchInferenceKeys.LABELS: tensor_results[PyTorchInferenceKeys.LABELS].item(),
-                # Move tensor to CPU before converting to NumPy
                 PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
             }
+
+    def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
+        """
+        Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
+
+        `target_id` must be implemented.
+        """
+        if self.target_id is None:
+            raise AttributeError(f"'target_id' has not been implemented.")
+
+        if self.task == "regression":
+            result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS]
+        else:
+            result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS]
+
+        return {self.target_id: result}
+
+
+class PyTorchInferenceHandlerMulti(_BaseInferenceHandler):
+    """
+    Handles loading a PyTorch model's state dictionary and performing inference
+    for multi-target regression or multi-label classification tasks.
+    """
+    def __init__(self,
+                 model: nn.Module,
+                 state_dict: Union[str, Path],
+                 task: Literal["multi_target_regression", "multi_label_classification"],
+                 device: str = 'cpu',
+                 target_ids: Optional[list[str]] = None,
+                 scaler: Optional[Union[PytorchScaler, str, Path]] = None):
+        """
+        Initializes the handler for multi-target tasks.
+
+        Args:
+            model (nn.Module): An instantiated PyTorch model.
+            state_dict (str | Path): Path to the saved .pth model state_dict file.
+            task (str): The type of task, 'multi_target_regression' or 'multi_label_classification'.
+            device (str): The device to run inference on ('cpu', 'cuda', 'mps').
+            target_ids (list[str] | None): An optional identifier for the targets.
+            scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
+        """
+        super().__init__(model, state_dict, device, scaler)
+
+        if task not in ["multi_target_regression", "multi_label_classification"]:
+            raise ValueError("`task` must be 'multi_target_regression' or 'multi_label_classification'.")
+        self.task = task
+        self.target_ids = target_ids
+
+    def predict_batch(self,
+                      features: Union[np.ndarray, torch.Tensor],
+                      classification_threshold: float = 0.5
+                      ) -> Dict[str, torch.Tensor]:
+        """
+        Core batch prediction method for multi-target models.
+
+        Args:
+            features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
+            classification_threshold (float): The threshold to convert probabilities
+                into binary predictions for multi-label classification.
+
+        Returns:
+            A dictionary containing the raw output tensors from the model.
+        """
+        if features.ndim != 2:
+            raise ValueError("Input for batch prediction must be a 2D array or tensor.")
+
+        input_tensor = self._preprocess_input(features)
+
+        with torch.no_grad():
+            output = self.model(input_tensor)
+
+        if self.task == "multi_label_classification":
+            probs = torch.sigmoid(output)
+            # Get binary predictions based on the threshold
+            labels = (probs >= classification_threshold).int()
+            return {
+                PyTorchInferenceKeys.LABELS: labels,
+                PyTorchInferenceKeys.PROBABILITIES: probs
+            }
+        else: # multi_target_regression
+            # The output is already in the correct [batch_size, n_targets] shape
+            return {PyTorchInferenceKeys.PREDICTIONS: output}
+
+    def predict(self,
+                features: Union[np.ndarray, torch.Tensor],
+                classification_threshold: float = 0.5
+                ) -> Dict[str, torch.Tensor]:
+        """
+        Core single-sample prediction method for multi-target models.
+
+        Args:
+            features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
+            classification_threshold (float): The threshold for multi-label tasks.
+
+        Returns:
+            A dictionary containing the raw output tensors for a single sample.
+        """
+        if features.ndim == 1:
+            features = features.reshape(1, -1)
+
+        if features.shape[0] != 1:
+            raise ValueError("The predict() method is for a single sample. Use predict_batch() for multiple samples.")
+
+        batch_results = self.predict_batch(features, classification_threshold)
+
+        single_results = {key: value[0] for key, value in batch_results.items()}
+        return single_results
+
+    # --- NumPy Convenience Wrappers (on CPU) ---
+
+    def predict_batch_numpy(self,
+                            features: Union[np.ndarray, torch.Tensor],
+                            classification_threshold: float = 0.5
+                            ) -> Dict[str, np.ndarray]:
+        """
+        Convenience wrapper for predict_batch that returns NumPy arrays.
+        """
+        tensor_results = self.predict_batch(features, classification_threshold)
+        numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
+        return numpy_results
+
+    def predict_numpy(self,
+                      features: Union[np.ndarray, torch.Tensor],
+                      classification_threshold: float = 0.5
+                      ) -> Dict[str, np.ndarray]:
+        """
+        Convenience wrapper for predict that returns NumPy arrays for a single sample.
+        Note: For multi-target models, the output is always an array.
+        """
+        tensor_results = self.predict(features, classification_threshold)
+        numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
+        return numpy_results
+
+    def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
+        """
+        Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
+
+        `target_ids` must be implemented.
+        """
+        if self.target_ids is None:
+            raise AttributeError(f"'target_id' has not been implemented.")
+
+        if self.task == "multi_target_regression":
+            result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS].flatten().tolist()
+        else:
+            result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS].flatten().tolist()
+
+        return {key: value for key, value in zip(self.target_ids, result)}
 
 
 def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
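A minimal sketch of the new multi-target handler added above (the model, state-dict path, and target names are illustrative stand-ins; any `nn.Module` whose saved state_dict matches the instantiated architecture will do):

import torch
from torch import nn
from ml_tools.ML_inference import PyTorchInferenceHandlerMulti

# Illustrative 4-feature, 2-target regression model and a matching saved state_dict
model = nn.Linear(4, 2)
torch.save(model.state_dict(), "model.pth")  # illustrative path

handler = PyTorchInferenceHandlerMulti(
    model=model,
    state_dict="model.pth",
    task="multi_target_regression",
    target_ids=["target_a", "target_b"],
)

features = torch.rand(4)                # a single sample; reshaped internally to a batch of one
print(handler.quick_predict(features))  # e.g. {'target_a': 0.12, 'target_b': -0.34}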
ml_tools/ML_models.py
CHANGED
|
@@ -89,10 +89,6 @@ class _BaseMLP(nn.Module):
|
|
|
89
89
|
class MultilayerPerceptron(_BaseMLP):
|
|
90
90
|
"""
|
|
91
91
|
Creates a versatile Multilayer Perceptron (MLP) for regression or classification tasks.
|
|
92
|
-
|
|
93
|
-
This model generates raw output values (logits) suitable for use with loss
|
|
94
|
-
functions like `nn.CrossEntropyLoss` (for classification) or `nn.MSELoss`
|
|
95
|
-
(for regression).
|
|
96
92
|
"""
|
|
97
93
|
def __init__(self, in_features: int, out_targets: int,
|
|
98
94
|
hidden_layers: List[int] = [256, 128], drop_out: float = 0.2) -> None:
|