dragon-ml-toolbox 7.0.0__py3-none-any.whl → 8.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic. Click here for more details.

@@ -0,0 +1,296 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
+ import torch
6
+ import shap
7
+ from sklearn.metrics import (
8
+ classification_report,
9
+ ConfusionMatrixDisplay,
10
+ roc_curve,
11
+ roc_auc_score,
12
+ precision_recall_curve,
13
+ average_precision_score,
14
+ mean_squared_error,
15
+ mean_absolute_error,
16
+ r2_score,
17
+ median_absolute_error,
18
+ hamming_loss,
19
+ jaccard_score
20
+ )
21
+ from pathlib import Path
22
+ from typing import Union, List, Optional
23
+
24
+ from .path_manager import make_fullpath, sanitize_filename
25
+ from ._logger import _LOGGER
26
+ from ._script_info import _script_info
27
+
28
+ __all__ = [
29
+ "multi_target_regression_metrics",
30
+ "multi_label_classification_metrics",
31
+ "multi_target_shap_summary_plot",
32
+ ]
33
+
34
+
35
def multi_target_regression_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    target_names: List[str],
    save_dir: Union[str, Path]
):
    """
    Calculates and saves regression metrics for each target individually.

    For each target, this function saves a residual plot and a true vs. predicted plot.
    It also saves a single CSV file containing the key metrics (RMSE, MAE, R², MedAE)
    for all targets.

    Args:
        y_true (np.ndarray): Ground truth values, shape (n_samples, n_targets).
            Any array-like (e.g. nested lists) is accepted and converted with `np.asarray`.
        y_pred (np.ndarray): Predicted values, shape (n_samples, n_targets).
        target_names (List[str]): A list of names for the target variables, one per column.
        save_dir (str | Path): Directory to save plots and the report
            (resolved via `make_fullpath(..., make=True, enforce="directory")`).

    Raises:
        ValueError: If the inputs are not 2D, their shapes differ, or the number of
            target names does not match the number of columns.

    Files written to `save_dir`:
        - `residual_plot_<sanitized target>.svg` and
          `true_vs_predicted_plot_<sanitized target>.svg` for every target.
        - `regression_report_multi.csv` with one row of metrics per target.
    """
    # Accept any array-like; this is a no-op for inputs that are already ndarrays.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if y_true.ndim != 2 or y_pred.ndim != 2:
        raise ValueError("y_true and y_pred must be 2D arrays for multi-target regression.")
    if y_true.shape != y_pred.shape:
        raise ValueError("Shapes of y_true and y_pred must match.")
    if y_true.shape[1] != len(target_names):
        raise ValueError("Number of target names must match the number of columns in y_true.")

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
    metrics_summary = []

    _LOGGER.info("--- Multi-Target Regression Evaluation ---")

    for i, name in enumerate(target_names):
        _LOGGER.info(f" -> Evaluating target: '{name}'")
        true_i = y_true[:, i]
        pred_i = y_pred[:, i]
        sanitized_name = sanitize_filename(name)

        # --- Calculate Metrics ---
        rmse = np.sqrt(mean_squared_error(true_i, pred_i))
        mae = mean_absolute_error(true_i, pred_i)
        r2 = r2_score(true_i, pred_i)
        medae = median_absolute_error(true_i, pred_i)
        metrics_summary.append({
            'target': name,
            'rmse': rmse,
            'mae': mae,
            'r2_score': r2,
            'median_abs_error': medae
        })

        # --- Save Residual Plot ---
        residuals = true_i - pred_i
        fig_res, ax_res = plt.subplots(figsize=(8, 6), dpi=100)
        try:
            ax_res.scatter(pred_i, residuals, alpha=0.6, edgecolors='k', s=50)
            ax_res.axhline(0, color='red', linestyle='--')
            ax_res.set_xlabel("Predicted Values")
            ax_res.set_ylabel("Residuals (True - Predicted)")
            ax_res.set_title(f"Residual Plot for '{name}'")
            ax_res.grid(True, linestyle='--', alpha=0.6)
            fig_res.tight_layout()
            fig_res.savefig(save_dir_path / f"residual_plot_{sanitized_name}.svg")
        finally:
            # Close even if drawing/saving fails so figures don't accumulate across targets.
            plt.close(fig_res)

        # --- Save True vs. Predicted Plot ---
        fig_tvp, ax_tvp = plt.subplots(figsize=(8, 6), dpi=100)
        try:
            ax_tvp.scatter(true_i, pred_i, alpha=0.6, edgecolors='k', s=50)
            # Identity line spanning the observed range of true values.
            ax_tvp.plot([true_i.min(), true_i.max()], [true_i.min(), true_i.max()], 'k--', lw=2)
            ax_tvp.set_xlabel('True Values')
            ax_tvp.set_ylabel('Predicted Values')
            ax_tvp.set_title(f'True vs. Predicted Values for "{name}"')
            ax_tvp.grid(True, linestyle='--', alpha=0.6)
            fig_tvp.tight_layout()
            fig_tvp.savefig(save_dir_path / f"true_vs_predicted_plot_{sanitized_name}.svg")
        finally:
            plt.close(fig_tvp)

    # --- Save Summary Report ---
    summary_df = pd.DataFrame(metrics_summary)
    report_path = save_dir_path / "regression_report_multi.csv"
    summary_df.to_csv(report_path, index=False)
    _LOGGER.info(f"✅ Full regression report saved to '{report_path.name}'")
117
+
118
+
119
def multi_label_classification_metrics(
    y_true: np.ndarray,
    y_prob: np.ndarray,
    target_names: List[str],
    save_dir: Union[str, Path],
    threshold: float = 0.5
):
    """
    Calculates and saves classification metrics for each label individually.

    This function first computes overall multi-label metrics (Hamming Loss, Jaccard Score)
    and then iterates through each label to generate and save individual reports,
    confusion matrices, ROC curves, and Precision-Recall curves.

    A label whose ground truth contains only one class (e.g. a label that never occurs
    in the evaluation split) still gets a classification report and a confusion matrix,
    but its ROC and Precision-Recall curves are skipped with a warning. In that case
    ROC AUC is undefined and `roc_auc_score` would raise, aborting the remaining labels.

    Args:
        y_true (np.ndarray): Ground truth binary (0/1) labels, shape (n_samples, n_labels).
            Any array-like is accepted and converted with `np.asarray`.
        y_prob (np.ndarray): Predicted probabilities, shape (n_samples, n_labels).
        target_names (List[str]): A list of names for the labels, one per column.
        save_dir (str | Path): Directory to save plots and reports.
        threshold (float): The probability threshold to convert probabilities into
            binary predictions for metrics like the confusion matrix.

    Raises:
        ValueError: If the inputs are not 2D, their shapes differ, or the number of
            target names does not match the number of columns.
    """
    # Accept any array-like; this is a no-op for inputs that are already ndarrays.
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)

    if y_true.ndim != 2 or y_prob.ndim != 2:
        raise ValueError("y_true and y_prob must be 2D arrays for multi-label classification.")
    if y_true.shape != y_prob.shape:
        raise ValueError("Shapes of y_true and y_prob must match.")
    if y_true.shape[1] != len(target_names):
        raise ValueError("Number of target names must match the number of columns in y_true.")

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

    # Generate binary predictions from probabilities
    y_pred = (y_prob >= threshold).astype(int)

    _LOGGER.info("--- Multi-Label Classification Evaluation ---")

    # --- Calculate and Save Overall Metrics ---
    h_loss = hamming_loss(y_true, y_pred)
    j_score_micro = jaccard_score(y_true, y_pred, average='micro')
    j_score_macro = jaccard_score(y_true, y_pred, average='macro')

    overall_report = (
        f"Overall Multi-Label Metrics (Threshold = {threshold}):\n"
        f"--------------------------------------------------\n"
        f"Hamming Loss: {h_loss:.4f}\n"
        f"Jaccard Score (micro): {j_score_micro:.4f}\n"
        f"Jaccard Score (macro): {j_score_macro:.4f}\n"
        f"--------------------------------------------------\n"
    )
    _LOGGER.info(overall_report)
    overall_report_path = save_dir_path / "classification_report_overall.txt"
    overall_report_path.write_text(overall_report)

    # --- Per-Label Metrics and Plots ---
    for i, name in enumerate(target_names):
        _LOGGER.info(f" -> Evaluating label: '{name}'")
        true_i = y_true[:, i]
        pred_i = y_pred[:, i]
        prob_i = y_prob[:, i]
        sanitized_name = sanitize_filename(name)

        # --- Save Classification Report for the label ---
        report_text = classification_report(true_i, pred_i)
        report_path = save_dir_path / f"classification_report_{sanitized_name}.txt"
        report_path.write_text(report_text)  # type: ignore

        # --- Save Confusion Matrix ---
        fig_cm, ax_cm = plt.subplots(figsize=(6, 6), dpi=100)
        try:
            # labels=[0, 1] keeps the matrix 2x2 even when a class is absent from both
            # true_i and pred_i, so every label's plot has the same layout.
            ConfusionMatrixDisplay.from_predictions(true_i, pred_i, labels=[0, 1], cmap="Blues", ax=ax_cm)
            ax_cm.set_title(f"Confusion Matrix for '{name}'")
            fig_cm.savefig(save_dir_path / f"confusion_matrix_{sanitized_name}.svg")
        finally:
            # Close even on failure so figures don't accumulate across labels.
            plt.close(fig_cm)

        # ROC AUC is undefined with a single class in y_true (sklearn raises ValueError),
        # and a PR curve without positives is meaningless: skip both for this label only.
        if np.unique(true_i).size < 2:
            _LOGGER.warning(f"⚠️ Label '{name}' has a single class in y_true; skipping ROC and Precision-Recall curves.")
            continue

        # --- Save ROC Curve ---
        fpr, tpr, _ = roc_curve(true_i, prob_i)
        auc = roc_auc_score(true_i, prob_i)
        fig_roc, ax_roc = plt.subplots(figsize=(6, 6), dpi=100)
        try:
            ax_roc.plot(fpr, tpr, label=f'AUC = {auc:.2f}')
            ax_roc.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
            ax_roc.set_title(f'ROC Curve for "{name}"')
            ax_roc.set_xlabel('False Positive Rate')
            ax_roc.set_ylabel('True Positive Rate')
            ax_roc.legend(loc='lower right')
            ax_roc.grid(True, linestyle='--', alpha=0.6)
            fig_roc.savefig(save_dir_path / f"roc_curve_{sanitized_name}.svg")
        finally:
            plt.close(fig_roc)

        # --- Save Precision-Recall Curve ---
        precision, recall, _ = precision_recall_curve(true_i, prob_i)
        ap_score = average_precision_score(true_i, prob_i)
        fig_pr, ax_pr = plt.subplots(figsize=(6, 6), dpi=100)
        try:
            ax_pr.plot(recall, precision, label=f'AP = {ap_score:.2f}')
            ax_pr.set_title(f'Precision-Recall Curve for "{name}"')
            ax_pr.set_xlabel('Recall')
            ax_pr.set_ylabel('Precision')
            ax_pr.legend(loc='lower left')
            ax_pr.grid(True, linestyle='--', alpha=0.6)
            fig_pr.savefig(save_dir_path / f"pr_curve_{sanitized_name}.svg")
        finally:
            plt.close(fig_pr)

    _LOGGER.info(f"✅ All individual label reports and plots saved to '{save_dir_path.name}'")
219
+
220
+
221
def multi_target_shap_summary_plot(
    model: torch.nn.Module,
    background_data: Union[torch.Tensor, np.ndarray],
    instances_to_explain: Union[torch.Tensor, np.ndarray],
    feature_names: List[str],
    target_names: List[str],
    save_dir: Union[str, Path]
):
    """
    Calculates SHAP values for a multi-target model and saves summary plots for each target.

    The model is put in eval mode and moved to the CPU (this mutates the model that was
    passed in). A `shap.KernelExplainer` is built on a k-means summary of the background
    data (at most 30 centroids), and for every target a bar plot and a dot summary plot
    are saved as SVG files. If either input contains NaN values, an error is logged and
    the function returns without computing anything.

    Both SHAP return conventions for multi-output models are supported: a list with one
    (n_samples, n_features) array per output, or a single (n_samples, n_features, n_outputs)
    array (returned by newer SHAP releases).

    Matplotlib interactive mode is switched off while plotting and restored to the
    caller's previous setting afterwards.

    Args:
        model (torch.nn.Module): The trained PyTorch model.
        background_data (torch.Tensor | np.ndarray): A sample of data for the explainer background.
        instances_to_explain (torch.Tensor | np.ndarray): The specific data instances to explain.
        feature_names (List[str]): Names of the features for plot labeling.
        target_names (List[str]): Names of the output targets, in model-output order.
        save_dir (str | Path): Directory to save SHAP artifacts.

    Raises:
        ValueError: If SHAP returns fewer explained outputs than there are target names.
    """
    def _to_numpy(data):
        # detach()/cpu() make the conversion work for tensors that require grad or
        # live on an accelerator; plain `.numpy()` raises for both.
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        return data

    # Convert all data to numpy
    background_data_np = _to_numpy(background_data)
    instances_to_explain_np = _to_numpy(instances_to_explain)

    if np.isnan(background_data_np).any() or np.isnan(instances_to_explain_np).any():
        _LOGGER.error("❌ Input data for SHAP contains NaN values. Aborting explanation.")
        return

    _LOGGER.info("\n--- Multi-Target SHAP Value Explanation ---")
    model.eval()
    model.cpu()

    # 1. Summarize the background data. k-means cannot form more clusters than there
    #    are samples, so cap the number of centroids at the background size.
    n_clusters = min(30, len(background_data_np))
    background_summary = shap.kmeans(background_data_np, n_clusters)

    # 2. Define a prediction function wrapper for the multi-target model (NumPy in/out).
    def prediction_wrapper(x_np: np.ndarray) -> np.ndarray:
        x_torch = torch.from_numpy(x_np).float()
        with torch.no_grad():
            output = model(x_torch)
        return output.cpu().numpy()

    # 3. Create the KernelExplainer.
    explainer = shap.KernelExplainer(prediction_wrapper, background_summary)

    _LOGGER.info("Calculating SHAP values with KernelExplainer...")
    raw_shap_values = explainer.shap_values(instances_to_explain_np, l1_reg="aic")

    # Normalize to a list holding one (n_samples, n_features) array per model output.
    if isinstance(raw_shap_values, np.ndarray) and raw_shap_values.ndim == 3:
        shap_values_per_target = [raw_shap_values[:, :, k] for k in range(raw_shap_values.shape[2])]
    elif isinstance(raw_shap_values, np.ndarray) and raw_shap_values.ndim == 2:
        # Single-output model: SHAP returns one 2D array.
        shap_values_per_target = [raw_shap_values]
    else:
        shap_values_per_target = list(raw_shap_values)

    if len(shap_values_per_target) < len(target_names):
        raise ValueError(
            f"SHAP returned values for {len(shap_values_per_target)} output(s), "
            f"but {len(target_names)} target name(s) were given."
        )

    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

    def _save_summary_plot(values, plot_type: str, file_stem: str, target_name: str) -> None:
        # Draw one SHAP summary plot on the current figure, save it, and always close it.
        try:
            shap.summary_plot(values, instances_to_explain_np, feature_names=feature_names, plot_type=plot_type, show=False)
            plt.title(f"SHAP Feature Importance for '{target_name}'")
            plt.tight_layout()
            plt.savefig(save_dir_path / f"{file_stem}.svg")
        finally:
            plt.close()

    # Remember the caller's interactive-mode setting so it can be restored afterwards,
    # instead of unconditionally turning interactive mode on at the end.
    was_interactive = plt.isinteractive()
    plt.ioff()
    try:
        # 4. Iterate through each target's SHAP values and generate plots.
        for i, target_name in enumerate(target_names):
            _LOGGER.info(f" -> Generating SHAP plots for target: '{target_name}'")
            shap_values_for_target = shap_values_per_target[i]
            sanitized_target_name = sanitize_filename(target_name)

            # Bar plot, then dot plot, for this target.
            _save_summary_plot(shap_values_for_target, "bar", f"shap_bar_plot_{sanitized_target_name}", target_name)
            _save_summary_plot(shap_values_for_target, "dot", f"shap_dot_plot_{sanitized_target_name}", target_name)
    finally:
        if was_interactive:
            plt.ion()

    _LOGGER.info(f"✅ All SHAP plots saved to '{save_dir_path.name}'")
296
+
ml_tools/ML_inference.py CHANGED
@@ -3,6 +3,7 @@ from torch import nn
3
3
  import numpy as np
4
4
  from pathlib import Path
5
5
  from typing import Union, Literal, Dict, Any, Optional
6
+ from abc import ABC, abstractmethod
6
7
 
7
8
  from .ML_scaler import PytorchScaler
8
9
  from ._script_info import _script_info
@@ -12,38 +13,36 @@ from .keys import PyTorchInferenceKeys
12
13
 
13
14
  __all__ = [
14
15
  "PyTorchInferenceHandler",
16
+ "PyTorchInferenceHandlerMulti",
15
17
  "multi_inference_regression",
16
18
  "multi_inference_classification"
17
19
  ]
18
20
 
19
- class PyTorchInferenceHandler:
21
+
22
+ class _BaseInferenceHandler(ABC):
20
23
  """
21
- Handles loading a PyTorch model's state dictionary and performing inference
22
- for either regression or classification tasks.
24
+ Abstract base class for PyTorch inference handlers.
25
+
26
+ Manages common tasks like loading a model's state dictionary, validating
27
+ the target device, and preprocessing input features.
23
28
  """
24
29
  def __init__(self,
25
30
  model: nn.Module,
26
31
  state_dict: Union[str, Path],
27
- task: Literal["classification", "regression"],
28
32
  device: str = 'cpu',
29
- target_id: Optional[str]=None,
30
33
  scaler: Optional[Union[PytorchScaler, str, Path]] = None):
31
34
  """
32
- Initializes the handler by loading a model's state_dict.
35
+ Initializes the handler.
33
36
 
34
37
  Args:
35
- model (nn.Module): An instantiated PyTorch model with the correct architecture.
36
- state_dict (str | Path): The path to the saved .pth model state_dict file.
37
- task (str): The type of task, 'regression' or 'classification'.
38
+ model (nn.Module): An instantiated PyTorch model.
39
+ state_dict (str | Path): Path to the saved .pth model state_dict file.
38
40
  device (str): The device to run inference on ('cpu', 'cuda', 'mps').
39
- target_id (str | None): Target name as used in the training set.
40
- scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
41
+ scaler (PytorchScaler | str | Path | None): An optional scaler or path to a saved scaler state.
41
42
  """
42
43
  self.model = model
43
- self.task = task
44
44
  self.device = self._validate_device(device)
45
- self.target_id = target_id
46
-
45
+
47
46
  # Load the scaler if a path is provided
48
47
  if scaler is not None:
49
48
  if isinstance(scaler, (str, Path)):
@@ -52,7 +51,7 @@ class PyTorchInferenceHandler:
52
51
  self.scaler = scaler
53
52
  else:
54
53
  self.scaler = None
55
-
54
+
56
55
  model_p = make_fullpath(state_dict, enforce="file")
57
56
 
58
57
  try:
@@ -72,6 +71,7 @@ class PyTorchInferenceHandler:
72
71
  _LOGGER.warning("⚠️ CUDA not available, switching to CPU.")
73
72
  device_lower = "cpu"
74
73
  elif device_lower == "mps" and not torch.backends.mps.is_available():
74
+ # Your M-series Mac will appreciate this check!
75
75
  _LOGGER.warning("⚠️ Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
76
76
  device_lower = "cpu"
77
77
  return torch.device(device_lower)
@@ -84,51 +84,103 @@ class PyTorchInferenceHandler:
84
84
  if isinstance(features, np.ndarray):
85
85
  features_tensor = torch.from_numpy(features).float()
86
86
  else:
87
- # Ensure it's a float tensor for the model
88
87
  features_tensor = features.float()
89
-
90
- # Apply the scaler transformation if the scaler is available
88
+
91
89
  if self.scaler:
92
90
  features_tensor = self.scaler.transform(features_tensor)
93
-
94
- # Ensure tensor is on the correct device
91
+
95
92
  return features_tensor.to(self.device)
96
-
93
+
94
+ @abstractmethod
95
+ def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
96
+ """Core batch prediction method. Must be implemented by subclasses."""
97
+ pass
98
+
99
+ @abstractmethod
100
+ def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
101
+ """Core single-sample prediction method. Must be implemented by subclasses."""
102
+ pass
103
+
104
+
105
+ class PyTorchInferenceHandler(_BaseInferenceHandler):
106
+ """
107
+ Handles loading a PyTorch model's state dictionary and performing inference
108
+ for single-target regression or classification tasks.
109
+ """
110
+ def __init__(self,
111
+ model: nn.Module,
112
+ state_dict: Union[str, Path],
113
+ task: Literal["classification", "regression"],
114
+ device: str = 'cpu',
115
+ target_id: Optional[str] = None,
116
+ scaler: Optional[Union[PytorchScaler, str, Path]] = None):
117
+ """
118
+ Initializes the handler for single-target tasks.
119
+
120
+ Args:
121
+ model (nn.Module): An instantiated PyTorch model architecture.
122
+ state_dict (str | Path): Path to the saved .pth model state_dict file.
123
+ task (str): The type of task, 'regression' or 'classification'.
124
+ device (str): The device to run inference on ('cpu', 'cuda', 'mps').
125
+ target_id (str | None): An optional identifier for the target.
126
+ scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
127
+ """
128
+ # Call the parent constructor to handle model loading, device, and scaler
129
+ super().__init__(model, state_dict, device, scaler)
130
+
131
+ if task not in ["classification", "regression"]:
132
+ raise ValueError("`task` must be 'classification' or 'regression'.")
133
+ self.task = task
134
+ self.target_id = target_id
135
+
97
136
  def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
98
137
  """
99
- Core batch prediction method. Returns results as PyTorch tensors on the model's device.
138
+ Core batch prediction method for single-target models.
139
+
140
+ Args:
141
+ features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
142
+
143
+ Returns:
144
+ A dictionary containing the raw output tensors from the model.
100
145
  """
101
146
  if features.ndim != 2:
102
147
  raise ValueError("Input for batch prediction must be a 2D array or tensor.")
103
148
 
104
149
  input_tensor = self._preprocess_input(features)
105
-
150
+
106
151
  with torch.no_grad():
107
- # Output tensor remains on the model's device (e.g., 'mps' or 'cuda')
108
152
  output = self.model(input_tensor)
109
153
 
110
154
  if self.task == "classification":
111
- probs = nn.functional.softmax(output, dim=1)
155
+ probs = torch.softmax(output, dim=1)
112
156
  labels = torch.argmax(probs, dim=1)
113
157
  return {
114
158
  PyTorchInferenceKeys.LABELS: labels,
115
159
  PyTorchInferenceKeys.PROBABILITIES: probs
116
160
  }
117
161
  else: # regression
118
- return {PyTorchInferenceKeys.PREDICTIONS: output}
162
+ # For single-target regression, ensure output is flattened
163
+ return {PyTorchInferenceKeys.PREDICTIONS: output.flatten()}
119
164
 
120
165
  def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
121
166
  """
122
- Core single-sample prediction. Returns results as PyTorch tensors on the model's device.
167
+ Core single-sample prediction method for single-target models.
168
+
169
+ Args:
170
+ features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
171
+
172
+ Returns:
173
+ A dictionary containing the raw output tensors for a single sample.
123
174
  """
124
175
  if features.ndim == 1:
125
- features = features.reshape(1, -1)
126
-
176
+ features = features.reshape(1, -1) # Reshape to a batch of one
177
+
127
178
  if features.shape[0] != 1:
128
179
  raise ValueError("The predict() method is for a single sample. Use predict_batch() for multiple samples.")
129
180
 
130
181
  batch_results = self.predict_batch(features)
131
-
182
+
183
+ # Extract the first (and only) result from the batch output
132
184
  single_results = {key: value[0] for key, value in batch_results.items()}
133
185
  return single_results
134
186
 
@@ -139,7 +191,6 @@ class PyTorchInferenceHandler:
139
191
  Convenience wrapper for predict_batch that returns NumPy arrays.
140
192
  """
141
193
  tensor_results = self.predict_batch(features)
142
- # Move tensor to CPU before converting to NumPy
143
194
  numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
144
195
  return numpy_results
145
196
 
@@ -148,16 +199,163 @@ class PyTorchInferenceHandler:
148
199
  Convenience wrapper for predict that returns NumPy arrays or scalars.
149
200
  """
150
201
  tensor_results = self.predict(features)
151
-
202
+
152
203
  if self.task == "regression":
153
- # .item() implicitly moves to CPU
204
+ # .item() implicitly moves to CPU and returns a Python scalar
154
205
  return {PyTorchInferenceKeys.PREDICTIONS: tensor_results[PyTorchInferenceKeys.PREDICTIONS].item()}
155
206
  else: # classification
156
207
  return {
157
208
  PyTorchInferenceKeys.LABELS: tensor_results[PyTorchInferenceKeys.LABELS].item(),
158
- # Move tensor to CPU before converting to NumPy
159
209
  PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
160
210
  }
211
+
212
+ def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
213
+ """
214
+ Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
215
+
216
+ `target_id` must be implemented.
217
+ """
218
+ if self.target_id is None:
219
+ raise AttributeError(f"'target_id' has not been implemented.")
220
+
221
+ if self.task == "regression":
222
+ result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS]
223
+ else:
224
+ result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS]
225
+
226
+ return {self.target_id: result}
227
+
228
+
229
class PyTorchInferenceHandlerMulti(_BaseInferenceHandler):
    """
    Handles loading a PyTorch model's state dictionary and performing inference
    for multi-target regression or multi-label classification tasks.

    - 'multi_target_regression': the raw model output is returned as predictions
      (expected shape [batch_size, n_targets]).
    - 'multi_label_classification': a sigmoid is applied to the model output to get
      independent per-label probabilities, which are thresholded into 0/1 labels.
    """
    def __init__(self,
                 model: nn.Module,
                 state_dict: Union[str, Path],
                 task: Literal["multi_target_regression", "multi_label_classification"],
                 device: str = 'cpu',
                 target_ids: Optional[list[str]] = None,
                 scaler: Optional[Union[PytorchScaler, str, Path]] = None):
        """
        Initializes the handler for multi-target tasks.

        Args:
            model (nn.Module): An instantiated PyTorch model.
            state_dict (str | Path): Path to the saved .pth model state_dict file.
            task (str): The type of task, 'multi_target_regression' or 'multi_label_classification'.
            device (str): The device to run inference on ('cpu', 'cuda', 'mps').
            target_ids (list[str] | None): Optional names for the model outputs, in output
                order. Required by `quick_predict`.
            scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.

        Raises:
            ValueError: If `task` is not one of the supported values.
        """
        super().__init__(model, state_dict, device, scaler)

        if task not in ["multi_target_regression", "multi_label_classification"]:
            raise ValueError("`task` must be 'multi_target_regression' or 'multi_label_classification'.")
        self.task = task
        self.target_ids = target_ids

    def predict_batch(self,
                      features: Union[np.ndarray, torch.Tensor],
                      classification_threshold: float = 0.5
                      ) -> Dict[str, torch.Tensor]:
        """
        Core batch prediction method for multi-target models.

        Args:
            features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
            classification_threshold (float): The threshold to convert probabilities
                into binary predictions for multi-label classification.

        Returns:
            A dictionary of tensors on the handler's device: predictions for regression,
            or labels and probabilities for multi-label classification.

        Raises:
            ValueError: If `features` is not 2D.
        """
        if features.ndim != 2:
            raise ValueError("Input for batch prediction must be a 2D array or tensor.")

        input_tensor = self._preprocess_input(features)

        with torch.no_grad():
            output = self.model(input_tensor)

        if self.task == "multi_label_classification":
            # Labels are independent, so sigmoid (not softmax) per output column.
            probs = torch.sigmoid(output)
            # Get binary predictions based on the threshold
            labels = (probs >= classification_threshold).int()
            return {
                PyTorchInferenceKeys.LABELS: labels,
                PyTorchInferenceKeys.PROBABILITIES: probs
            }
        else:  # multi_target_regression
            # The output is already in the correct [batch_size, n_targets] shape
            return {PyTorchInferenceKeys.PREDICTIONS: output}

    def predict(self,
                features: Union[np.ndarray, torch.Tensor],
                classification_threshold: float = 0.5
                ) -> Dict[str, torch.Tensor]:
        """
        Core single-sample prediction method for multi-target models.

        Args:
            features (np.ndarray | torch.Tensor): A 1D array/tensor of input features
                (a 2D input with exactly one row is also accepted).
            classification_threshold (float): The threshold for multi-label tasks.

        Returns:
            A dictionary containing the raw output tensors for a single sample.

        Raises:
            ValueError: If more than one sample is supplied.
        """
        if features.ndim == 1:
            features = features.reshape(1, -1)  # Reshape to a batch of one

        if features.shape[0] != 1:
            raise ValueError("The predict() method is for a single sample. Use predict_batch() for multiple samples.")

        batch_results = self.predict_batch(features, classification_threshold)

        # Extract the first (and only) result from the batch output
        single_results = {key: value[0] for key, value in batch_results.items()}
        return single_results

    # --- NumPy Convenience Wrappers (on CPU) ---

    def predict_batch_numpy(self,
                            features: Union[np.ndarray, torch.Tensor],
                            classification_threshold: float = 0.5
                            ) -> Dict[str, np.ndarray]:
        """
        Convenience wrapper for predict_batch that returns NumPy arrays.
        """
        tensor_results = self.predict_batch(features, classification_threshold)
        numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
        return numpy_results

    def predict_numpy(self,
                      features: Union[np.ndarray, torch.Tensor],
                      classification_threshold: float = 0.5
                      ) -> Dict[str, np.ndarray]:
        """
        Convenience wrapper for predict that returns NumPy arrays for a single sample.
        Note: For multi-target models, the output is always an array.
        """
        tensor_results = self.predict(features, classification_threshold)
        numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
        return numpy_results

    def quick_predict(self,
                      features: Union[np.ndarray, torch.Tensor],
                      classification_threshold: float = 0.5
                      ) -> Dict[str, Any]:
        """
        Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
        for a single sample.

        `target_ids` must be implemented.

        Args:
            features (np.ndarray | torch.Tensor): A single sample of input features.
            classification_threshold (float): The threshold for multi-label tasks.

        Returns:
            A dict mapping each entry of `target_ids` to its Python-scalar prediction
            (regression) or 0/1 label (multi-label classification).

        Raises:
            AttributeError: If `target_ids` was not provided.
            ValueError: If the number of `target_ids` does not match the number of model outputs.
        """
        if self.target_ids is None:
            raise AttributeError("'target_ids' has not been implemented.")

        if self.task == "multi_target_regression":
            result_key = PyTorchInferenceKeys.PREDICTIONS
        else:
            result_key = PyTorchInferenceKeys.LABELS
        result = self.predict_numpy(features, classification_threshold)[result_key].flatten().tolist()

        # zip() would silently drop unmatched names/outputs; fail loudly instead.
        if len(result) != len(self.target_ids):
            raise ValueError(
                f"Number of target_ids ({len(self.target_ids)}) does not match "
                f"the number of model outputs ({len(result)})."
            )

        return dict(zip(self.target_ids, result))
161
359
 
162
360
 
163
361
  def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
ml_tools/ML_models.py CHANGED
@@ -89,10 +89,6 @@ class _BaseMLP(nn.Module):
89
89
  class MultilayerPerceptron(_BaseMLP):
90
90
  """
91
91
  Creates a versatile Multilayer Perceptron (MLP) for regression or classification tasks.
92
-
93
- This model generates raw output values (logits) suitable for use with loss
94
- functions like `nn.CrossEntropyLoss` (for classification) or `nn.MSELoss`
95
- (for regression).
96
92
  """
97
93
  def __init__(self, in_features: int, out_targets: int,
98
94
  hidden_layers: List[int] = [256, 128], drop_out: float = 0.2) -> None: