dragon-ml-toolbox 14.3.1__py3-none-any.whl → 16.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic. Click here for more details.
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +10 -5
- dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +309 -0
- ml_tools/ML_datasetmaster.py +220 -260
- ml_tools/ML_evaluation.py +317 -81
- ml_tools/ML_evaluation_multi.py +127 -36
- ml_tools/ML_inference.py +249 -207
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +215 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1247 -338
- ml_tools/ML_utilities.py +51 -2
- ml_tools/ML_vision_datasetmaster.py +262 -118
- ml_tools/ML_vision_evaluation.py +26 -6
- ml_tools/ML_vision_inference.py +117 -140
- ml_tools/ML_vision_models.py +15 -1
- ml_tools/ML_vision_transformers.py +233 -7
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +45 -1
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +54 -11
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.3.1.dist-info/RECORD +0 -48
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_inference.py
CHANGED
|
@@ -5,16 +5,15 @@ from pathlib import Path
|
|
|
5
5
|
from typing import Union, Literal, Dict, Any, Optional
|
|
6
6
|
from abc import ABC, abstractmethod
|
|
7
7
|
|
|
8
|
-
from .ML_scaler import
|
|
8
|
+
from .ML_scaler import DragonScaler
|
|
9
9
|
from ._script_info import _script_info
|
|
10
10
|
from ._logger import _LOGGER
|
|
11
11
|
from .path_manager import make_fullpath
|
|
12
|
-
from .
|
|
12
|
+
from ._keys import PyTorchInferenceKeys, PyTorchCheckpointKeys, MLTaskKeys
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
__all__ = [
|
|
16
|
-
"
|
|
17
|
-
"PyTorchInferenceHandlerMulti",
|
|
16
|
+
"DragonInferenceHandler",
|
|
18
17
|
"multi_inference_regression",
|
|
19
18
|
"multi_inference_classification"
|
|
20
19
|
]
|
|
@@ -31,7 +30,7 @@ class _BaseInferenceHandler(ABC):
|
|
|
31
30
|
model: nn.Module,
|
|
32
31
|
state_dict: Union[str, Path],
|
|
33
32
|
device: str = 'cpu',
|
|
34
|
-
scaler: Optional[Union[
|
|
33
|
+
scaler: Optional[Union[DragonScaler, str, Path]] = None):
|
|
35
34
|
"""
|
|
36
35
|
Initializes the handler.
|
|
37
36
|
|
|
@@ -39,15 +38,21 @@ class _BaseInferenceHandler(ABC):
|
|
|
39
38
|
model (nn.Module): An instantiated PyTorch model.
|
|
40
39
|
state_dict (str | Path): Path to the saved .pth model state_dict file.
|
|
41
40
|
device (str): The device to run inference on ('cpu', 'cuda', 'mps').
|
|
42
|
-
scaler (
|
|
41
|
+
scaler (DragonScaler | str | Path | None): An optional scaler or path to a saved scaler state.
|
|
43
42
|
"""
|
|
44
43
|
self.model = model
|
|
45
44
|
self.device = self._validate_device(device)
|
|
45
|
+
self._classification_threshold = 0.5
|
|
46
|
+
self._loaded_threshold: bool = False
|
|
47
|
+
self._loaded_class_map: bool = False
|
|
48
|
+
self._class_map: Optional[dict[str,int]] = None
|
|
49
|
+
self._idx_to_class: Optional[Dict[int, str]] = None
|
|
50
|
+
self._loaded_data_dict: Dict[str, Any] = {} #Store whatever is in the finalized file
|
|
46
51
|
|
|
47
52
|
# Load the scaler if a path is provided
|
|
48
53
|
if scaler is not None:
|
|
49
54
|
if isinstance(scaler, (str, Path)):
|
|
50
|
-
self.scaler =
|
|
55
|
+
self.scaler = DragonScaler.load(scaler)
|
|
51
56
|
else:
|
|
52
57
|
self.scaler = scaler
|
|
53
58
|
else:
|
|
@@ -58,13 +63,45 @@ class _BaseInferenceHandler(ABC):
|
|
|
58
63
|
try:
|
|
59
64
|
# Load whatever is in the file
|
|
60
65
|
loaded_data = torch.load(model_p, map_location=self.device)
|
|
61
|
-
|
|
62
|
-
# Check if it's
|
|
63
|
-
if isinstance(loaded_data, dict)
|
|
64
|
-
|
|
65
|
-
|
|
66
|
+
|
|
67
|
+
# Check if it's dictionary or a old weights-only file
|
|
68
|
+
if isinstance(loaded_data, dict):
|
|
69
|
+
self._loaded_data_dict = loaded_data # Store the dict
|
|
70
|
+
|
|
71
|
+
if PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
|
|
72
|
+
# It's a new training checkpoint, extract the weights
|
|
73
|
+
self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
|
|
74
|
+
|
|
75
|
+
# attempt to get a custom classification threshold
|
|
76
|
+
if PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD in loaded_data:
|
|
77
|
+
try:
|
|
78
|
+
self._classification_threshold = float(loaded_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD])
|
|
79
|
+
except Exception as e_int:
|
|
80
|
+
_LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}' but an error occurred when retrieving it:\n{e_int}")
|
|
81
|
+
self._classification_threshold = 0.5
|
|
82
|
+
else:
|
|
83
|
+
_LOGGER.info(f"'{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}' found and set to {self._classification_threshold}")
|
|
84
|
+
self._loaded_threshold = True
|
|
85
|
+
|
|
86
|
+
# attempt to get a class map
|
|
87
|
+
if PyTorchCheckpointKeys.CLASS_MAP in loaded_data:
|
|
88
|
+
try:
|
|
89
|
+
self._class_map = loaded_data[PyTorchCheckpointKeys.CLASS_MAP]
|
|
90
|
+
except Exception as e_int:
|
|
91
|
+
_LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASS_MAP}' but an error occurred when retrieving it:\n{e_int}")
|
|
92
|
+
else:
|
|
93
|
+
if isinstance(self._class_map, dict):
|
|
94
|
+
self._loaded_class_map = True
|
|
95
|
+
self.set_class_map(self._class_map)
|
|
96
|
+
else:
|
|
97
|
+
_LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASS_MAP}' but it is not a dict: '{type(self._class_map)}'.")
|
|
98
|
+
self._class_map = None
|
|
99
|
+
else:
|
|
100
|
+
# It's a state_dict, load it directly
|
|
101
|
+
self.model.load_state_dict(loaded_data)
|
|
102
|
+
|
|
66
103
|
else:
|
|
67
|
-
|
|
104
|
+
# It's an old state_dict (just weights), load it directly
|
|
68
105
|
self.model.load_state_dict(loaded_data)
|
|
69
106
|
|
|
70
107
|
_LOGGER.info(f"Model state loaded from '{model_p.name}'.")
|
|
@@ -85,21 +122,37 @@ class _BaseInferenceHandler(ABC):
|
|
|
85
122
|
_LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
|
|
86
123
|
device_lower = "cpu"
|
|
87
124
|
return torch.device(device_lower)
|
|
88
|
-
|
|
89
|
-
def
|
|
90
|
-
"""
|
|
91
|
-
Converts input to a torch.Tensor, applies scaling if a scaler is
|
|
92
|
-
present, and moves it to the correct device.
|
|
125
|
+
|
|
126
|
+
def set_class_map(self, class_map: Dict[str, int], force_overwrite: bool = False):
|
|
93
127
|
"""
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
128
|
+
Sets the class name mapping to translate predicted integer labels back into string names.
|
|
129
|
+
|
|
130
|
+
If a class_map was previously loaded from a model configuration, this
|
|
131
|
+
method will log a warning and refuse to update the value. This
|
|
132
|
+
prevents accidentally overriding a setting from a loaded checkpoint.
|
|
133
|
+
|
|
134
|
+
To bypass this safety check set `force_overwrite` to `True`.
|
|
101
135
|
|
|
102
|
-
|
|
136
|
+
Args:
|
|
137
|
+
class_map (Dict[str, int]): The class_to_idx dictionary (e.g., {'cat': 0, 'dog': 1}).
|
|
138
|
+
force_overwrite (bool): If True, allows overwriting a map that was loaded from a configuration file.
|
|
139
|
+
"""
|
|
140
|
+
if self._loaded_class_map:
|
|
141
|
+
warning_message = f"A '{PyTorchCheckpointKeys.CLASS_MAP}' was loaded from the model configuration file."
|
|
142
|
+
if not force_overwrite:
|
|
143
|
+
warning_message += " Use 'force_overwrite=True' if you are sure you want to modify it. This will not affect the value from the file."
|
|
144
|
+
_LOGGER.warning(warning_message)
|
|
145
|
+
return
|
|
146
|
+
else:
|
|
147
|
+
warning_message += " Overwriting it for this inference instance."
|
|
148
|
+
_LOGGER.warning(warning_message)
|
|
149
|
+
|
|
150
|
+
# Store the map and invert it for fast lookup
|
|
151
|
+
self._class_map = class_map
|
|
152
|
+
self._idx_to_class = {v: k for k, v in class_map.items()}
|
|
153
|
+
# Mark as 'loaded' by the user, even if it wasn't from a file, to prevent accidental changes later if _loaded_class_map was false.
|
|
154
|
+
self._loaded_class_map = True # Protect this newly set map
|
|
155
|
+
_LOGGER.info("Class map set for label-to-name translation.")
|
|
103
156
|
|
|
104
157
|
@abstractmethod
|
|
105
158
|
def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
@@ -112,46 +165,67 @@ class _BaseInferenceHandler(ABC):
|
|
|
112
165
|
pass
|
|
113
166
|
|
|
114
167
|
|
|
115
|
-
class
|
|
168
|
+
class DragonInferenceHandler(_BaseInferenceHandler):
|
|
116
169
|
"""
|
|
117
|
-
Handles loading a PyTorch model's state dictionary and performing inference
|
|
118
|
-
for single-target regression or classification tasks.
|
|
170
|
+
Handles loading a PyTorch model's state dictionary and performing inference.
|
|
119
171
|
"""
|
|
120
172
|
def __init__(self,
|
|
121
173
|
model: nn.Module,
|
|
122
174
|
state_dict: Union[str, Path],
|
|
123
|
-
task: Literal["classification", "regression"],
|
|
175
|
+
task: Literal["regression", "binary classification", "multiclass classification", "multitarget regression", "multilabel binary classification"],
|
|
124
176
|
device: str = 'cpu',
|
|
125
|
-
|
|
126
|
-
scaler: Optional[Union[
|
|
177
|
+
target_ids: Optional[list[str]] = None,
|
|
178
|
+
scaler: Optional[Union[DragonScaler, str, Path]] = None):
|
|
127
179
|
"""
|
|
128
180
|
Initializes the handler for single-target tasks.
|
|
129
181
|
|
|
130
182
|
Args:
|
|
131
183
|
model (nn.Module): An instantiated PyTorch model architecture.
|
|
132
184
|
state_dict (str | Path): Path to the saved .pth model state_dict file.
|
|
133
|
-
task (str): The type of task
|
|
185
|
+
task (str): The type of task.
|
|
134
186
|
device (str): The device to run inference on ('cpu', 'cuda', 'mps').
|
|
135
187
|
target_id (str | None): An optional identifier for the target.
|
|
136
|
-
scaler (
|
|
188
|
+
scaler (DragonScaler | str | Path | None): A DragonScaler instance or the file path to a saved DragonScaler state.
|
|
189
|
+
|
|
190
|
+
Note: class_map (Dict[int, str]) will be loaded from the model file, to set or override it use `.set_class_map()`.
|
|
137
191
|
"""
|
|
138
192
|
# Call the parent constructor to handle model loading, device, and scaler
|
|
139
193
|
super().__init__(model, state_dict, device, scaler)
|
|
140
194
|
|
|
141
|
-
if task not in [
|
|
142
|
-
|
|
195
|
+
if task not in [MLTaskKeys.REGRESSION,
|
|
196
|
+
MLTaskKeys.BINARY_CLASSIFICATION,
|
|
197
|
+
MLTaskKeys.MULTICLASS_CLASSIFICATION,
|
|
198
|
+
MLTaskKeys.MULTITARGET_REGRESSION,
|
|
199
|
+
MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
|
|
200
|
+
_LOGGER.error(f"'task' not recognized: '{task}'.")
|
|
201
|
+
raise ValueError()
|
|
143
202
|
self.task = task
|
|
144
|
-
self.
|
|
203
|
+
self.target_ids = target_ids
|
|
204
|
+
|
|
205
|
+
def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
|
|
206
|
+
"""
|
|
207
|
+
Converts input to a torch.Tensor, applies scaling if a scaler is
|
|
208
|
+
present, and moves it to the correct device.
|
|
209
|
+
"""
|
|
210
|
+
if isinstance(features, np.ndarray):
|
|
211
|
+
features_tensor = torch.from_numpy(features).float()
|
|
212
|
+
else:
|
|
213
|
+
features_tensor = features.float()
|
|
214
|
+
|
|
215
|
+
if self.scaler:
|
|
216
|
+
features_tensor = self.scaler.transform(features_tensor)
|
|
217
|
+
|
|
218
|
+
return features_tensor.to(self.device)
|
|
145
219
|
|
|
146
220
|
def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
147
221
|
"""
|
|
148
|
-
Core batch prediction method
|
|
222
|
+
Core batch prediction method.
|
|
149
223
|
|
|
150
224
|
Args:
|
|
151
225
|
features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
|
|
152
226
|
|
|
153
227
|
Returns:
|
|
154
|
-
A dictionary containing the raw output tensors from the model.
|
|
228
|
+
Dict: A dictionary containing the raw output tensors from the model.
|
|
155
229
|
"""
|
|
156
230
|
if features.ndim != 2:
|
|
157
231
|
_LOGGER.error("Input for batch prediction must be a 2D array or tensor.")
|
|
@@ -162,16 +236,48 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
|
|
|
162
236
|
with torch.no_grad():
|
|
163
237
|
output = self.model(input_tensor)
|
|
164
238
|
|
|
165
|
-
if self.task ==
|
|
239
|
+
if self.task == MLTaskKeys.MULTICLASS_CLASSIFICATION:
|
|
166
240
|
probs = torch.softmax(output, dim=1)
|
|
167
241
|
labels = torch.argmax(probs, dim=1)
|
|
168
242
|
return {
|
|
169
243
|
PyTorchInferenceKeys.LABELS: labels,
|
|
170
244
|
PyTorchInferenceKeys.PROBABILITIES: probs
|
|
171
245
|
}
|
|
172
|
-
|
|
246
|
+
|
|
247
|
+
elif self.task == MLTaskKeys.BINARY_CLASSIFICATION:
|
|
248
|
+
# Assumes model output is [N, 1] (a single logit)
|
|
249
|
+
# Squeeze output from [N, 1] to [N] if necessary
|
|
250
|
+
if output.ndim == 2 and output.shape[1] == 1:
|
|
251
|
+
output = output.squeeze(1)
|
|
252
|
+
|
|
253
|
+
probs = torch.sigmoid(output) # Probability of positive class
|
|
254
|
+
labels = (probs >= self._classification_threshold).int()
|
|
255
|
+
return {
|
|
256
|
+
PyTorchInferenceKeys.LABELS: labels,
|
|
257
|
+
PyTorchInferenceKeys.PROBABILITIES: probs
|
|
258
|
+
}
|
|
259
|
+
|
|
260
|
+
elif self.task == MLTaskKeys.REGRESSION:
|
|
173
261
|
# For single-target regression, ensure output is flattened
|
|
174
262
|
return {PyTorchInferenceKeys.PREDICTIONS: output.flatten()}
|
|
263
|
+
|
|
264
|
+
elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
|
|
265
|
+
probs = torch.sigmoid(output)
|
|
266
|
+
# Get binary predictions based on the threshold
|
|
267
|
+
labels = (probs >= self._classification_threshold).int()
|
|
268
|
+
return {
|
|
269
|
+
PyTorchInferenceKeys.LABELS: labels,
|
|
270
|
+
PyTorchInferenceKeys.PROBABILITIES: probs
|
|
271
|
+
}
|
|
272
|
+
|
|
273
|
+
elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
|
|
274
|
+
# The output is already in the correct [batch_size, n_targets] shape
|
|
275
|
+
return {PyTorchInferenceKeys.PREDICTIONS: output}
|
|
276
|
+
|
|
277
|
+
else:
|
|
278
|
+
# should never happen
|
|
279
|
+
_LOGGER.error(f"Unrecognized task '{self.task}'.")
|
|
280
|
+
raise ValueError()
|
|
175
281
|
|
|
176
282
|
def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
177
283
|
"""
|
|
@@ -181,7 +287,7 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
|
|
|
181
287
|
features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
|
|
182
288
|
|
|
183
289
|
Returns:
|
|
184
|
-
A dictionary containing the raw output tensors for a single sample.
|
|
290
|
+
Dict: A dictionary containing the raw output tensors for a single sample.
|
|
185
291
|
"""
|
|
186
292
|
if features.ndim == 1:
|
|
187
293
|
features = features.reshape(1, -1) # Reshape to a batch of one
|
|
@@ -195,193 +301,129 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
|
|
|
195
301
|
# Extract the first (and only) result from the batch output
|
|
196
302
|
single_results = {key: value[0] for key, value in batch_results.items()}
|
|
197
303
|
return single_results
|
|
198
|
-
|
|
304
|
+
|
|
199
305
|
# --- NumPy Convenience Wrappers (on CPU) ---
|
|
200
306
|
|
|
201
307
|
def predict_batch_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, np.ndarray]:
|
|
202
308
|
"""
|
|
203
|
-
Convenience wrapper for predict_batch that returns NumPy arrays
|
|
309
|
+
Convenience wrapper for predict_batch that returns NumPy arrays
|
|
310
|
+
and adds string labels for classification tasks if a class_map is set.
|
|
204
311
|
"""
|
|
205
312
|
tensor_results = self.predict_batch(features)
|
|
206
313
|
numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
|
|
314
|
+
|
|
315
|
+
# Add string names for classification if map exists
|
|
316
|
+
is_classification = self.task in [
|
|
317
|
+
MLTaskKeys.BINARY_CLASSIFICATION,
|
|
318
|
+
MLTaskKeys.MULTICLASS_CLASSIFICATION
|
|
319
|
+
]
|
|
320
|
+
|
|
321
|
+
if is_classification and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
|
|
322
|
+
int_labels = numpy_results[PyTorchInferenceKeys.LABELS] # This is a (B,) array
|
|
323
|
+
numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [ # type: ignore
|
|
324
|
+
self._idx_to_class.get(label_id, "Unknown")
|
|
325
|
+
for label_id in int_labels
|
|
326
|
+
]
|
|
327
|
+
|
|
207
328
|
return numpy_results
|
|
208
329
|
|
|
209
330
|
def predict_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
|
|
210
331
|
"""
|
|
211
|
-
Convenience wrapper for predict that returns NumPy arrays or scalars
|
|
332
|
+
Convenience wrapper for predict that returns NumPy arrays or scalars
|
|
333
|
+
and adds string labels for classification tasks if a class_map is set.
|
|
212
334
|
"""
|
|
213
335
|
tensor_results = self.predict(features)
|
|
214
336
|
|
|
215
|
-
if self.task ==
|
|
337
|
+
if self.task == MLTaskKeys.REGRESSION:
|
|
216
338
|
# .item() implicitly moves to CPU and returns a Python scalar
|
|
217
339
|
return {PyTorchInferenceKeys.PREDICTIONS: tensor_results[PyTorchInferenceKeys.PREDICTIONS].item()}
|
|
218
|
-
|
|
340
|
+
|
|
341
|
+
elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
|
|
342
|
+
int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
|
|
343
|
+
label_name = "Unknown"
|
|
344
|
+
if self._idx_to_class:
|
|
345
|
+
label_name = self._idx_to_class.get(int_label, "Unknown") # type: ignore
|
|
346
|
+
|
|
219
347
|
return {
|
|
220
|
-
PyTorchInferenceKeys.LABELS:
|
|
348
|
+
PyTorchInferenceKeys.LABELS: int_label,
|
|
349
|
+
PyTorchInferenceKeys.LABEL_NAMES: label_name,
|
|
221
350
|
PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
|
|
222
351
|
}
|
|
352
|
+
|
|
353
|
+
elif self.task in [MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION, MLTaskKeys.MULTITARGET_REGRESSION]:
|
|
354
|
+
# For multi-target models, the output is always an array.
|
|
355
|
+
numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
|
|
356
|
+
return numpy_results
|
|
357
|
+
else:
|
|
358
|
+
# should never happen
|
|
359
|
+
_LOGGER.error(f"Unrecognized task '{self.task}'.")
|
|
360
|
+
raise ValueError()
|
|
223
361
|
|
|
224
362
|
def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
|
|
225
363
|
"""
|
|
226
364
|
Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
|
|
227
365
|
|
|
228
|
-
`
|
|
366
|
+
`target_ids` must be implemented.
|
|
229
367
|
"""
|
|
230
|
-
if self.
|
|
231
|
-
_LOGGER.error(f"'
|
|
368
|
+
if self.target_ids is None:
|
|
369
|
+
_LOGGER.error(f"'target_ids' has not been implemented.")
|
|
232
370
|
raise AttributeError()
|
|
233
371
|
|
|
234
|
-
if self.task ==
|
|
372
|
+
if self.task == MLTaskKeys.REGRESSION:
|
|
235
373
|
result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS]
|
|
236
|
-
|
|
374
|
+
return {self.target_ids[0]: result}
|
|
375
|
+
|
|
376
|
+
elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
|
|
237
377
|
result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS]
|
|
378
|
+
return {self.target_ids[0]: result}
|
|
238
379
|
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
task: Literal["multi_target_regression", "multi_label_classification"],
|
|
251
|
-
device: str = 'cpu',
|
|
252
|
-
target_ids: Optional[list[str]] = None,
|
|
253
|
-
scaler: Optional[Union[PytorchScaler, str, Path]] = None):
|
|
254
|
-
"""
|
|
255
|
-
Initializes the handler for multi-target tasks.
|
|
256
|
-
|
|
257
|
-
Args:
|
|
258
|
-
model (nn.Module): An instantiated PyTorch model.
|
|
259
|
-
state_dict (str | Path): Path to the saved .pth model state_dict file.
|
|
260
|
-
task (str): The type of task, 'multi_target_regression' or 'multi_label_classification'.
|
|
261
|
-
device (str): The device to run inference on ('cpu', 'cuda', 'mps').
|
|
262
|
-
target_ids (list[str] | None): An optional identifier for the targets.
|
|
263
|
-
scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
|
|
264
|
-
"""
|
|
265
|
-
super().__init__(model, state_dict, device, scaler)
|
|
266
|
-
|
|
267
|
-
if task not in ["multi_target_regression", "multi_label_classification"]:
|
|
268
|
-
_LOGGER.error("`task` must be 'multi_target_regression' or 'multi_label_classification'.")
|
|
269
|
-
raise ValueError()
|
|
270
|
-
self.task = task
|
|
271
|
-
self.target_ids = target_ids
|
|
272
|
-
|
|
273
|
-
def predict_batch(self,
|
|
274
|
-
features: Union[np.ndarray, torch.Tensor],
|
|
275
|
-
classification_threshold: float = 0.5
|
|
276
|
-
) -> Dict[str, torch.Tensor]:
|
|
277
|
-
"""
|
|
278
|
-
Core batch prediction method for multi-target models.
|
|
279
|
-
|
|
280
|
-
Args:
|
|
281
|
-
features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
|
|
282
|
-
classification_threshold (float): The threshold to convert probabilities
|
|
283
|
-
into binary predictions for multi-label classification.
|
|
284
|
-
|
|
285
|
-
Returns:
|
|
286
|
-
A dictionary containing the raw output tensors from the model.
|
|
287
|
-
"""
|
|
288
|
-
if features.ndim != 2:
|
|
289
|
-
_LOGGER.error("Input for batch prediction must be a 2D array or tensor.")
|
|
290
|
-
raise ValueError()
|
|
291
|
-
|
|
292
|
-
input_tensor = self._preprocess_input(features)
|
|
293
|
-
|
|
294
|
-
with torch.no_grad():
|
|
295
|
-
output = self.model(input_tensor)
|
|
296
|
-
|
|
297
|
-
if self.task == "multi_label_classification":
|
|
298
|
-
probs = torch.sigmoid(output)
|
|
299
|
-
# Get binary predictions based on the threshold
|
|
300
|
-
labels = (probs >= classification_threshold).int()
|
|
301
|
-
return {
|
|
302
|
-
PyTorchInferenceKeys.LABELS: labels,
|
|
303
|
-
PyTorchInferenceKeys.PROBABILITIES: probs
|
|
304
|
-
}
|
|
305
|
-
else: # multi_target_regression
|
|
306
|
-
# The output is already in the correct [batch_size, n_targets] shape
|
|
307
|
-
return {PyTorchInferenceKeys.PREDICTIONS: output}
|
|
308
|
-
|
|
309
|
-
def predict(self,
|
|
310
|
-
features: Union[np.ndarray, torch.Tensor],
|
|
311
|
-
classification_threshold: float = 0.5
|
|
312
|
-
) -> Dict[str, torch.Tensor]:
|
|
313
|
-
"""
|
|
314
|
-
Core single-sample prediction method for multi-target models.
|
|
315
|
-
|
|
316
|
-
Args:
|
|
317
|
-
features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
|
|
318
|
-
classification_threshold (float): The threshold for multi-label tasks.
|
|
319
|
-
|
|
320
|
-
Returns:
|
|
321
|
-
A dictionary containing the raw output tensors for a single sample.
|
|
322
|
-
"""
|
|
323
|
-
if features.ndim == 1:
|
|
324
|
-
features = features.reshape(1, -1)
|
|
325
|
-
|
|
326
|
-
if features.shape[0] != 1:
|
|
327
|
-
_LOGGER.error("The 'predict()' method is for a single sample. 'Use predict_batch()' for multiple samples.")
|
|
380
|
+
elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
|
|
381
|
+
result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS].flatten().tolist()
|
|
382
|
+
return {key: value for key, value in zip(self.target_ids, result)}
|
|
383
|
+
|
|
384
|
+
elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
|
|
385
|
+
result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS].flatten().tolist()
|
|
386
|
+
return {key: value for key, value in zip(self.target_ids, result)}
|
|
387
|
+
|
|
388
|
+
else:
|
|
389
|
+
# should never happen
|
|
390
|
+
_LOGGER.error(f"Unrecognized task '{self.task}'.")
|
|
328
391
|
raise ValueError()
|
|
329
|
-
|
|
330
|
-
batch_results = self.predict_batch(features, classification_threshold)
|
|
331
|
-
|
|
332
|
-
single_results = {key: value[0] for key, value in batch_results.items()}
|
|
333
|
-
return single_results
|
|
334
|
-
|
|
335
|
-
# --- NumPy Convenience Wrappers (on CPU) ---
|
|
336
|
-
|
|
337
|
-
def predict_batch_numpy(self,
|
|
338
|
-
features: Union[np.ndarray, torch.Tensor],
|
|
339
|
-
classification_threshold: float = 0.5
|
|
340
|
-
) -> Dict[str, np.ndarray]:
|
|
341
|
-
"""
|
|
342
|
-
Convenience wrapper for predict_batch that returns NumPy arrays.
|
|
343
|
-
"""
|
|
344
|
-
tensor_results = self.predict_batch(features, classification_threshold)
|
|
345
|
-
numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
|
|
346
|
-
return numpy_results
|
|
347
|
-
|
|
348
|
-
def predict_numpy(self,
|
|
349
|
-
features: Union[np.ndarray, torch.Tensor],
|
|
350
|
-
classification_threshold: float = 0.5
|
|
351
|
-
) -> Dict[str, np.ndarray]:
|
|
352
|
-
"""
|
|
353
|
-
Convenience wrapper for predict that returns NumPy arrays for a single sample.
|
|
354
|
-
Note: For multi-target models, the output is always an array.
|
|
355
|
-
"""
|
|
356
|
-
tensor_results = self.predict(features, classification_threshold)
|
|
357
|
-
numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
|
|
358
|
-
return numpy_results
|
|
359
|
-
|
|
360
|
-
def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
|
|
361
|
-
"""
|
|
362
|
-
Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
|
|
363
392
|
|
|
364
|
-
|
|
393
|
+
def set_classification_threshold(self, threshold: float, force_overwrite: bool=False):
|
|
365
394
|
"""
|
|
366
|
-
|
|
367
|
-
_LOGGER.error(f"'target_id' has not been implemented.")
|
|
368
|
-
raise AttributeError()
|
|
395
|
+
Sets the classification threshold for the current inference instance.
|
|
369
396
|
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS].flatten().tolist()
|
|
397
|
+
If a threshold was previously loaded from a model configuration, this
|
|
398
|
+
method will log a warning and refuse to update the value. This
|
|
399
|
+
prevents accidentally overriding a setting from a loaded checkpoint.
|
|
374
400
|
|
|
375
|
-
|
|
401
|
+
To bypass this safety check set `force_overwrite` to `True`.
|
|
376
402
|
|
|
403
|
+
Args:
|
|
404
|
+
threshold (float): The new classification threshold value to set.
|
|
405
|
+
force_overwrite (bool): If True, allows overwriting a threshold that was loaded from a configuration file.
|
|
406
|
+
"""
|
|
407
|
+
if self._loaded_threshold:
|
|
408
|
+
warning_message = f"The current '{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}={self._classification_threshold}' was loaded and set from a model configuration file."
|
|
409
|
+
if not force_overwrite:
|
|
410
|
+
warning_message += " Use 'force_overwrite' if you are sure you want to modify it. This will not affect the value from the file."
|
|
411
|
+
_LOGGER.warning(warning_message)
|
|
412
|
+
return
|
|
413
|
+
else:
|
|
414
|
+
warning_message += f" Overwriting it to {threshold}."
|
|
415
|
+
_LOGGER.warning(warning_message)
|
|
416
|
+
|
|
417
|
+
self._classification_threshold = threshold
|
|
377
418
|
|
|
378
|
-
|
|
419
|
+
|
|
420
|
+
def multi_inference_regression(handlers: list[DragonInferenceHandler],
|
|
379
421
|
feature_vector: Union[np.ndarray, torch.Tensor],
|
|
380
422
|
output: Literal["numpy","torch"]="numpy") -> dict[str,Any]:
|
|
381
423
|
"""
|
|
382
424
|
Performs regression inference using multiple models on a single feature vector.
|
|
383
425
|
|
|
384
|
-
This function iterates through a list of
|
|
426
|
+
This function iterates through a list of DragonInferenceHandler objects,
|
|
385
427
|
each configured for a different regression target. It runs a prediction for
|
|
386
428
|
each handler using the same input feature vector and returns the results
|
|
387
429
|
in a dictionary.
|
|
@@ -391,7 +433,7 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
|
|
|
391
433
|
- 2D input: Returns a dictionary mapping target ID to a list of values.
|
|
392
434
|
|
|
393
435
|
Args:
|
|
394
|
-
handlers (list[
|
|
436
|
+
handlers (list[DragonInferenceHandler]): A list of initialized inference
|
|
395
437
|
handlers. Each handler must have a unique `target_id` and be configured with `task="regression"`.
|
|
396
438
|
feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D) or a batch of samples (2D) to be fed into each regression model.
|
|
397
439
|
output (Literal["numpy", "torch"], optional): The desired format for the output predictions.
|
|
@@ -421,11 +463,11 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
|
|
|
421
463
|
results: dict[str,Any] = dict()
|
|
422
464
|
for handler in handlers:
|
|
423
465
|
# validation
|
|
424
|
-
if handler.
|
|
425
|
-
_LOGGER.error("All inference handlers must have a '
|
|
466
|
+
if handler.target_ids is None:
|
|
467
|
+
_LOGGER.error("All inference handlers must have a 'target_ids' attribute.")
|
|
426
468
|
raise AttributeError()
|
|
427
|
-
if handler.task !=
|
|
428
|
-
_LOGGER.error(f"Invalid task type: The handler for target_id '{handler.
|
|
469
|
+
if handler.task != MLTaskKeys.REGRESSION:
|
|
470
|
+
_LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_ids[0]}' is for '{handler.task}', only single target regression tasks are supported.")
|
|
429
471
|
raise ValueError()
|
|
430
472
|
|
|
431
473
|
# inference
|
|
@@ -434,33 +476,33 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
|
|
|
434
476
|
numpy_result = handler.predict_batch_numpy(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
|
|
435
477
|
if is_single_sample:
|
|
436
478
|
# For a single sample, convert the 1-element array to a Python scalar
|
|
437
|
-
results[handler.
|
|
479
|
+
results[handler.target_ids[0]] = numpy_result.item()
|
|
438
480
|
else:
|
|
439
481
|
# For a batch, return the full NumPy array of predictions
|
|
440
|
-
results[handler.
|
|
482
|
+
results[handler.target_ids[0]] = numpy_result
|
|
441
483
|
|
|
442
484
|
else: # output == "torch"
|
|
443
485
|
# This path returns PyTorch tensors on the model's device
|
|
444
486
|
torch_result = handler.predict_batch(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
|
|
445
487
|
if is_single_sample:
|
|
446
488
|
# For a single sample, return the 0-dim tensor
|
|
447
|
-
results[handler.
|
|
489
|
+
results[handler.target_ids[0]] = torch_result[0]
|
|
448
490
|
else:
|
|
449
491
|
# For a batch, return the full tensor of predictions
|
|
450
|
-
results[handler.
|
|
492
|
+
results[handler.target_ids[0]] = torch_result
|
|
451
493
|
|
|
452
494
|
return results
|
|
453
495
|
|
|
454
496
|
|
|
455
497
|
def multi_inference_classification(
|
|
456
|
-
handlers: list[
|
|
498
|
+
handlers: list[DragonInferenceHandler],
|
|
457
499
|
feature_vector: Union[np.ndarray, torch.Tensor],
|
|
458
500
|
output: Literal["numpy","torch"]="numpy"
|
|
459
501
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
460
502
|
"""
|
|
461
503
|
Performs classification inference on a single sample or a batch.
|
|
462
504
|
|
|
463
|
-
This function iterates through a list of
|
|
505
|
+
This function iterates through a list of DragonInferenceHandler objects,
|
|
464
506
|
each configured for a different classification target. It returns two
|
|
465
507
|
dictionaries: one for the predicted labels and one for the probabilities.
|
|
466
508
|
|
|
@@ -469,7 +511,7 @@ def multi_inference_classification(
|
|
|
469
511
|
- 2D input: The dictionaries map target ID to an array of labels and an array of probability arrays.
|
|
470
512
|
|
|
471
513
|
Args:
|
|
472
|
-
handlers (list[
|
|
514
|
+
handlers (list[DragonInferenceHandler]): A list of initialized inference handlers. Each must have a unique `target_id` and be configured
|
|
473
515
|
with `task="classification"`.
|
|
474
516
|
feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D)
|
|
475
517
|
or a batch of samples (2D) for prediction.
|
|
@@ -502,11 +544,11 @@ def multi_inference_classification(
|
|
|
502
544
|
|
|
503
545
|
for handler in handlers:
|
|
504
546
|
# Validation
|
|
505
|
-
if handler.
|
|
547
|
+
if handler.target_ids is None:
|
|
506
548
|
_LOGGER.error("All inference handlers must have a 'target_id' attribute.")
|
|
507
549
|
raise AttributeError()
|
|
508
|
-
if handler.task
|
|
509
|
-
_LOGGER.error(f"Invalid task type: The handler for target_id '{handler.
|
|
550
|
+
if handler.task not in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
|
|
551
|
+
_LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_ids[0]}' is for '{handler.task}', but this function only supports binary and multiclass classification.")
|
|
510
552
|
raise ValueError()
|
|
511
553
|
|
|
512
554
|
# Inference
|
|
@@ -524,15 +566,15 @@ def multi_inference_classification(
|
|
|
524
566
|
# For "numpy", convert the single label to a Python int scalar.
|
|
525
567
|
# For "torch", get the 0-dim tensor label.
|
|
526
568
|
if output == "numpy":
|
|
527
|
-
labels_results[handler.
|
|
569
|
+
labels_results[handler.target_ids[0]] = labels.item()
|
|
528
570
|
else: # torch
|
|
529
|
-
labels_results[handler.
|
|
571
|
+
labels_results[handler.target_ids[0]] = labels[0]
|
|
530
572
|
|
|
531
573
|
# The probabilities are an array/tensor of values
|
|
532
|
-
probs_results[handler.
|
|
574
|
+
probs_results[handler.target_ids[0]] = probabilities[0]
|
|
533
575
|
else:
|
|
534
|
-
labels_results[handler.
|
|
535
|
-
probs_results[handler.
|
|
576
|
+
labels_results[handler.target_ids[0]] = labels
|
|
577
|
+
probs_results[handler.target_ids[0]] = probabilities
|
|
536
578
|
|
|
537
579
|
return labels_results, probs_results
|
|
538
580
|
|