dragon-ml-toolbox 14.7.0__py3-none-any.whl → 16.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/METADATA +9 -5
- dragon_ml_toolbox-16.2.1.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +726 -32
- ml_tools/ML_datasetmaster.py +235 -280
- ml_tools/ML_evaluation.py +160 -42
- ml_tools/ML_evaluation_multi.py +103 -35
- ml_tools/ML_inference.py +290 -208
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +219 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1342 -386
- ml_tools/ML_utilities.py +1 -1
- ml_tools/ML_vision_datasetmaster.py +120 -72
- ml_tools/ML_vision_evaluation.py +30 -6
- ml_tools/ML_vision_inference.py +129 -152
- ml_tools/ML_vision_models.py +1 -1
- ml_tools/ML_vision_transformers.py +121 -40
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +45 -0
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +1 -1
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.7.0.dist-info/RECORD +0 -49
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/top_level.txt +0 -0
ml_tools/ML_inference.py
CHANGED
@@ -5,16 +5,15 @@ from pathlib import Path
 from typing import Union, Literal, Dict, Any, Optional
 from abc import ABC, abstractmethod
 
-from .ML_scaler import
+from .ML_scaler import DragonScaler
 from ._script_info import _script_info
 from ._logger import _LOGGER
 from .path_manager import make_fullpath
-from .
+from ._keys import PyTorchInferenceKeys, PyTorchCheckpointKeys, MLTaskKeys
 
 
 __all__ = [
-    "
-    "PyTorchInferenceHandlerMulti",
+    "DragonInferenceHandler",
    "multi_inference_regression",
    "multi_inference_classification"
 ]
@@ -31,7 +30,7 @@ class _BaseInferenceHandler(ABC):
                  model: nn.Module,
                  state_dict: Union[str, Path],
                  device: str = 'cpu',
-                 scaler: Optional[Union[
+                 scaler: Optional[Union[DragonScaler, str, Path]] = None):
         """
         Initializes the handler.
 
@@ -39,15 +38,21 @@ class _BaseInferenceHandler(ABC):
             model (nn.Module): An instantiated PyTorch model.
             state_dict (str | Path): Path to the saved .pth model state_dict file.
             device (str): The device to run inference on ('cpu', 'cuda', 'mps').
-            scaler (
+            scaler (DragonScaler | str | Path | None): An optional scaler or path to a saved scaler state.
         """
         self.model = model
         self.device = self._validate_device(device)
+        self._classification_threshold = 0.5
+        self._loaded_threshold: bool = False
+        self._loaded_class_map: bool = False
+        self._class_map: Optional[dict[str,int]] = None
+        self._idx_to_class: Optional[Dict[int, str]] = None
+        self._loaded_data_dict: Dict[str, Any] = {} #Store whatever is in the finalized file
 
         # Load the scaler if a path is provided
         if scaler is not None:
             if isinstance(scaler, (str, Path)):
-                self.scaler =
+                self.scaler = DragonScaler.load(scaler)
             else:
                 self.scaler = scaler
         else:
@@ -58,13 +63,45 @@ class _BaseInferenceHandler(ABC):
         try:
             # Load whatever is in the file
             loaded_data = torch.load(model_p, map_location=self.device)
-
-            # Check if it's
-            if isinstance(loaded_data, dict)
-
-
+
+            # Check if it's dictionary or a old weights-only file
+            if isinstance(loaded_data, dict):
+                self._loaded_data_dict = loaded_data # Store the dict
+
+                if PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
+                    # It's a new training checkpoint, extract the weights
+                    self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
+
+                    # attempt to get a custom classification threshold
+                    if PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD in loaded_data:
+                        try:
+                            self._classification_threshold = float(loaded_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD])
+                        except Exception as e_int:
+                            _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}' but an error occurred when retrieving it:\n{e_int}")
+                            self._classification_threshold = 0.5
+                        else:
+                            _LOGGER.info(f"'{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}' found and set to {self._classification_threshold}")
+                            self._loaded_threshold = True
+
+                    # attempt to get a class map
+                    if PyTorchCheckpointKeys.CLASS_MAP in loaded_data:
+                        try:
+                            self._class_map = loaded_data[PyTorchCheckpointKeys.CLASS_MAP]
+                        except Exception as e_int:
+                            _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASS_MAP}' but an error occurred when retrieving it:\n{e_int}")
+                        else:
+                            if isinstance(self._class_map, dict):
+                                self._loaded_class_map = True
+                                self.set_class_map(self._class_map)
+                            else:
+                                _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASS_MAP}' but it is not a dict: '{type(self._class_map)}'.")
+                                self._class_map = None
+                else:
+                    # It's a state_dict, load it directly
+                    self.model.load_state_dict(loaded_data)
+
             else:
-
+                # It's an old state_dict (just weights), load it directly
                 self.model.load_state_dict(loaded_data)
 
             _LOGGER.info(f"Model state loaded from '{model_p.name}'.")
@@ -85,21 +122,37 @@ class _BaseInferenceHandler(ABC):
                 _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
                 device_lower = "cpu"
         return torch.device(device_lower)
-
-    def
-        """
-        Converts input to a torch.Tensor, applies scaling if a scaler is
-        present, and moves it to the correct device.
+
+    def set_class_map(self, class_map: Dict[str, int], force_overwrite: bool = False):
         """
-
-
-
-
-
-
-
+        Sets the class name mapping to translate predicted integer labels back into string names.
+
+        If a class_map was previously loaded from a model configuration, this
+        method will log a warning and refuse to update the value. This
+        prevents accidentally overriding a setting from a loaded checkpoint.
+
+        To bypass this safety check set `force_overwrite` to `True`.
 
-
+        Args:
+            class_map (Dict[str, int]): The class_to_idx dictionary (e.g., {'cat': 0, 'dog': 1}).
+            force_overwrite (bool): If True, allows overwriting a map that was loaded from a configuration file.
+        """
+        if self._loaded_class_map:
+            warning_message = f"A '{PyTorchCheckpointKeys.CLASS_MAP}' was loaded from the model configuration file."
+            if not force_overwrite:
+                warning_message += " Use 'force_overwrite=True' if you are sure you want to modify it. This will not affect the value from the file."
+                _LOGGER.warning(warning_message)
+                return
+            else:
+                warning_message += " Overwriting it for this inference instance."
+                _LOGGER.warning(warning_message)
+
+        # Store the map and invert it for fast lookup
+        self._class_map = class_map
+        self._idx_to_class = {v: k for k, v in class_map.items()}
+        # Mark as 'loaded' by the user, even if it wasn't from a file, to prevent accidental changes later if _loaded_class_map was false.
+        self._loaded_class_map = True # Protect this newly set map
+        _LOGGER.info("Class map set for label-to-name translation.")
 
     @abstractmethod
     def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
@@ -112,46 +165,107 @@ class _BaseInferenceHandler(ABC):
         pass
 
 
-class
+class DragonInferenceHandler(_BaseInferenceHandler):
     """
-    Handles loading a PyTorch model's state dictionary and performing inference
-    for single-target regression or classification tasks.
+    Handles loading a PyTorch model's state dictionary and performing inference.
     """
     def __init__(self,
                  model: nn.Module,
                  state_dict: Union[str, Path],
-                 task: Literal["classification", "regression"],
+                 task: Literal["regression", "binary classification", "multiclass classification", "multitarget regression", "multilabel binary classification"],
                  device: str = 'cpu',
-
-                 scaler: Optional[Union[PytorchScaler, str, Path]] = None):
+                 scaler: Optional[Union[DragonScaler, str, Path]] = None):
         """
         Initializes the handler for single-target tasks.
 
         Args:
             model (nn.Module): An instantiated PyTorch model architecture.
             state_dict (str | Path): Path to the saved .pth model state_dict file.
-            task (str): The type of task
+            task (str): The type of task.
             device (str): The device to run inference on ('cpu', 'cuda', 'mps').
-
-
+            scaler (DragonScaler | str | Path | None): A DragonScaler instance or the file path to a saved DragonScaler state.
+
+            Note: class_map (Dict[int, str]) will be loaded from the model file, to set or override it use `.set_class_map()`.
         """
         # Call the parent constructor to handle model loading, device, and scaler
         super().__init__(model, state_dict, device, scaler)
 
-        if task not in [
-
+        if task not in [MLTaskKeys.REGRESSION,
+                        MLTaskKeys.BINARY_CLASSIFICATION,
+                        MLTaskKeys.MULTICLASS_CLASSIFICATION,
+                        MLTaskKeys.MULTITARGET_REGRESSION,
+                        MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
+            _LOGGER.error(f"'task' not recognized: '{task}'.")
+            raise ValueError()
         self.task = task
-        self.
+        self.target_ids: Optional[list[str]] = None
+        self._target_ids_set: bool = False
+
+        # attempt to get target name or target names
+        if PyTorchCheckpointKeys.TARGET_NAME in self._loaded_data_dict:
+            try:
+                target_from_dict = [self._loaded_data_dict[PyTorchCheckpointKeys.TARGET_NAME]]
+            except Exception as e_int:
+                _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.TARGET_NAME}' but an error occurred when retrieving it:\n{e_int}")
+                self.target_ids = None
+            else:
+                self.set_target_ids([target_from_dict]) # type: ignore
+        elif PyTorchCheckpointKeys.TARGET_NAMES in self._loaded_data_dict:
+            try:
+                targets_from_dict = self._loaded_data_dict[PyTorchCheckpointKeys.TARGET_NAMES]
+            except Exception as e_int:
+                _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.TARGET_NAMES}' but an error occurred when retrieving it:\n{e_int}")
+                self.target_ids = None
+            else:
+                self.set_target_ids(targets_from_dict) # type: ignore
+
+    def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
+        """
+        Converts input to a torch.Tensor, applies scaling if a scaler is
+        present, and moves it to the correct device.
+        """
+        if isinstance(features, np.ndarray):
+            features_tensor = torch.from_numpy(features).float()
+        else:
+            features_tensor = features.float()
+
+        if self.scaler:
+            features_tensor = self.scaler.transform(features_tensor)
+
+        return features_tensor.to(self.device)
+
+    def set_target_ids(self, target_names: list[str], force_overwrite: bool=False):
+        """
+        Assigns the provided list of strings as the target variable names.
+        If target IDs have already been set, this method will log a warning.
+
+        Args:
+            target_names (list[str]): A list of target names.
+            force_overwrite (bool): If True, allows the method to overwrite previously set target IDs.
+        """
+        if self._target_ids_set:
+            warning_message = "Target IDs was previously set."
+            if not force_overwrite:
+                warning_message += " Use `force_overwrite=True` to overwrite."
+                _LOGGER.warning(warning_message)
+                return
+            else:
+                warning_message += " Overwriting..."
+                _LOGGER.warning(warning_message)
+
+        self.target_ids = target_names
+        self._target_ids_set = True
+        _LOGGER.info("Target IDs set.")
 
     def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """
-        Core batch prediction method
+        Core batch prediction method.
 
         Args:
             features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
 
         Returns:
-            A dictionary containing the raw output tensors from the model.
+            Dict: A dictionary containing the raw output tensors from the model.
         """
         if features.ndim != 2:
             _LOGGER.error("Input for batch prediction must be a 2D array or tensor.")
@@ -162,16 +276,48 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
         with torch.no_grad():
             output = self.model(input_tensor)
 
-        if self.task ==
+        if self.task == MLTaskKeys.MULTICLASS_CLASSIFICATION:
             probs = torch.softmax(output, dim=1)
             labels = torch.argmax(probs, dim=1)
             return {
                 PyTorchInferenceKeys.LABELS: labels,
                 PyTorchInferenceKeys.PROBABILITIES: probs
             }
-
+
+        elif self.task == MLTaskKeys.BINARY_CLASSIFICATION:
+            # Assumes model output is [N, 1] (a single logit)
+            # Squeeze output from [N, 1] to [N] if necessary
+            if output.ndim == 2 and output.shape[1] == 1:
+                output = output.squeeze(1)
+
+            probs = torch.sigmoid(output) # Probability of positive class
+            labels = (probs >= self._classification_threshold).int()
+            return {
+                PyTorchInferenceKeys.LABELS: labels,
+                PyTorchInferenceKeys.PROBABILITIES: probs
+            }
+
+        elif self.task == MLTaskKeys.REGRESSION:
             # For single-target regression, ensure output is flattened
             return {PyTorchInferenceKeys.PREDICTIONS: output.flatten()}
+
+        elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
+            probs = torch.sigmoid(output)
+            # Get binary predictions based on the threshold
+            labels = (probs >= self._classification_threshold).int()
+            return {
+                PyTorchInferenceKeys.LABELS: labels,
+                PyTorchInferenceKeys.PROBABILITIES: probs
+            }
+
+        elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
+            # The output is already in the correct [batch_size, n_targets] shape
+            return {PyTorchInferenceKeys.PREDICTIONS: output}
+
+        else:
+            # should never happen
+            _LOGGER.error(f"Unrecognized task '{self.task}'.")
+            raise ValueError()
 
     def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
         """
@@ -181,7 +327,7 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
             features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
 
         Returns:
-            A dictionary containing the raw output tensors for a single sample.
+            Dict: A dictionary containing the raw output tensors for a single sample.
         """
         if features.ndim == 1:
             features = features.reshape(1, -1) # Reshape to a batch of one
@@ -195,193 +341,129 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
         # Extract the first (and only) result from the batch output
         single_results = {key: value[0] for key, value in batch_results.items()}
         return single_results
-
+
     # --- NumPy Convenience Wrappers (on CPU) ---
 
     def predict_batch_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, np.ndarray]:
         """
-        Convenience wrapper for predict_batch that returns NumPy arrays
+        Convenience wrapper for predict_batch that returns NumPy arrays
+        and adds string labels for classification tasks if a class_map is set.
         """
         tensor_results = self.predict_batch(features)
         numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
+
+        # Add string names for classification if map exists
+        is_classification = self.task in [
+            MLTaskKeys.BINARY_CLASSIFICATION,
+            MLTaskKeys.MULTICLASS_CLASSIFICATION
+        ]
+
+        if is_classification and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
+            int_labels = numpy_results[PyTorchInferenceKeys.LABELS] # This is a (B,) array
+            numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [ # type: ignore
+                self._idx_to_class.get(label_id, "Unknown")
+                for label_id in int_labels
+            ]
+
         return numpy_results
 
     def predict_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
         """
-        Convenience wrapper for predict that returns NumPy arrays or scalars
+        Convenience wrapper for predict that returns NumPy arrays or scalars
+        and adds string labels for classification tasks if a class_map is set.
         """
         tensor_results = self.predict(features)
 
-        if self.task ==
+        if self.task == MLTaskKeys.REGRESSION:
             # .item() implicitly moves to CPU and returns a Python scalar
             return {PyTorchInferenceKeys.PREDICTIONS: tensor_results[PyTorchInferenceKeys.PREDICTIONS].item()}
-
+
+        elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
+            int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
+            label_name = "Unknown"
+            if self._idx_to_class:
+                label_name = self._idx_to_class.get(int_label, "Unknown") # type: ignore
+
             return {
-                PyTorchInferenceKeys.LABELS:
+                PyTorchInferenceKeys.LABELS: int_label,
+                PyTorchInferenceKeys.LABEL_NAMES: label_name,
                 PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
             }
+
+        elif self.task in [MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION, MLTaskKeys.MULTITARGET_REGRESSION]:
+            # For multi-target models, the output is always an array.
+            numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
+            return numpy_results
+        else:
+            # should never happen
+            _LOGGER.error(f"Unrecognized task '{self.task}'.")
+            raise ValueError()
 
     def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
         """
         Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
 
-        `
+        `target_ids` must be implemented.
         """
-        if self.
-        _LOGGER.error(f"'
+        if self.target_ids is None:
+            _LOGGER.error(f"'target_ids' has not been implemented.")
             raise AttributeError()
 
-        if self.task ==
+        if self.task == MLTaskKeys.REGRESSION:
             result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS]
-
+            return {self.target_ids[0]: result}
+
+        elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
             result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS]
+            return {self.target_ids[0]: result}
 
-
-
-
-
-
-
-
-
-
-
-
-                 task: Literal["multi_target_regression", "multi_label_classification"],
-                 device: str = 'cpu',
-                 target_ids: Optional[list[str]] = None,
-                 scaler: Optional[Union[PytorchScaler, str, Path]] = None):
-        """
-        Initializes the handler for multi-target tasks.
-
-        Args:
-            model (nn.Module): An instantiated PyTorch model.
-            state_dict (str | Path): Path to the saved .pth model state_dict file.
-            task (str): The type of task, 'multi_target_regression' or 'multi_label_classification'.
-            device (str): The device to run inference on ('cpu', 'cuda', 'mps').
-            target_ids (list[str] | None): An optional identifier for the targets.
-            scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
-        """
-        super().__init__(model, state_dict, device, scaler)
-
-        if task not in ["multi_target_regression", "multi_label_classification"]:
-            _LOGGER.error("`task` must be 'multi_target_regression' or 'multi_label_classification'.")
-            raise ValueError()
-        self.task = task
-        self.target_ids = target_ids
-
-    def predict_batch(self,
-                      features: Union[np.ndarray, torch.Tensor],
-                      classification_threshold: float = 0.5
-                      ) -> Dict[str, torch.Tensor]:
-        """
-        Core batch prediction method for multi-target models.
-
-        Args:
-            features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
-            classification_threshold (float): The threshold to convert probabilities
-                into binary predictions for multi-label classification.
-
-        Returns:
-            A dictionary containing the raw output tensors from the model.
-        """
-        if features.ndim != 2:
-            _LOGGER.error("Input for batch prediction must be a 2D array or tensor.")
-            raise ValueError()
-
-        input_tensor = self._preprocess_input(features)
-
-        with torch.no_grad():
-            output = self.model(input_tensor)
-
-        if self.task == "multi_label_classification":
-            probs = torch.sigmoid(output)
-            # Get binary predictions based on the threshold
-            labels = (probs >= classification_threshold).int()
-            return {
-                PyTorchInferenceKeys.LABELS: labels,
-                PyTorchInferenceKeys.PROBABILITIES: probs
-            }
-        else: # multi_target_regression
-            # The output is already in the correct [batch_size, n_targets] shape
-            return {PyTorchInferenceKeys.PREDICTIONS: output}
-
-    def predict(self,
-                features: Union[np.ndarray, torch.Tensor],
-                classification_threshold: float = 0.5
-                ) -> Dict[str, torch.Tensor]:
-        """
-        Core single-sample prediction method for multi-target models.
-
-        Args:
-            features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
-            classification_threshold (float): The threshold for multi-label tasks.
-
-        Returns:
-            A dictionary containing the raw output tensors for a single sample.
-        """
-        if features.ndim == 1:
-            features = features.reshape(1, -1)
-
-        if features.shape[0] != 1:
-            _LOGGER.error("The 'predict()' method is for a single sample. 'Use predict_batch()' for multiple samples.")
+        elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
+            result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS].flatten().tolist()
+            return {key: value for key, value in zip(self.target_ids, result)}
+
+        elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
+            result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS].flatten().tolist()
+            return {key: value for key, value in zip(self.target_ids, result)}
+
+        else:
+            # should never happen
+            _LOGGER.error(f"Unrecognized task '{self.task}'.")
             raise ValueError()
-
-        batch_results = self.predict_batch(features, classification_threshold)
-
-        single_results = {key: value[0] for key, value in batch_results.items()}
-        return single_results
-
-    # --- NumPy Convenience Wrappers (on CPU) ---
-
-    def predict_batch_numpy(self,
-                            features: Union[np.ndarray, torch.Tensor],
-                            classification_threshold: float = 0.5
-                            ) -> Dict[str, np.ndarray]:
-        """
-        Convenience wrapper for predict_batch that returns NumPy arrays.
-        """
-        tensor_results = self.predict_batch(features, classification_threshold)
-        numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
-        return numpy_results
-
-    def predict_numpy(self,
-                      features: Union[np.ndarray, torch.Tensor],
-                      classification_threshold: float = 0.5
-                      ) -> Dict[str, np.ndarray]:
-        """
-        Convenience wrapper for predict that returns NumPy arrays for a single sample.
-        Note: For multi-target models, the output is always an array.
-        """
-        tensor_results = self.predict(features, classification_threshold)
-        numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
-        return numpy_results
-
-    def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
-        """
-        Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
 
-
+    def set_classification_threshold(self, threshold: float, force_overwrite: bool=False):
         """
-
-            _LOGGER.error(f"'target_id' has not been implemented.")
-            raise AttributeError()
+        Sets the classification threshold for the current inference instance.
 
-
-
-
-            result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS].flatten().tolist()
+        If a threshold was previously loaded from a model configuration, this
+        method will log a warning and refuse to update the value. This
+        prevents accidentally overriding a setting from a loaded checkpoint.
 
-
+        To bypass this safety check set `force_overwrite` to `True`.
 
+        Args:
+            threshold (float): The new classification threshold value to set.
+            force_overwrite (bool): If True, allows overwriting a threshold that was loaded from a configuration file.
+        """
+        if self._loaded_threshold:
+            warning_message = f"The current '{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}={self._classification_threshold}' was loaded and set from a model configuration file."
+            if not force_overwrite:
+                warning_message += " Use 'force_overwrite' if you are sure you want to modify it. This will not affect the value from the file."
+                _LOGGER.warning(warning_message)
+                return
+            else:
+                warning_message += f" Overwriting it to {threshold}."
+                _LOGGER.warning(warning_message)
+
+        self._classification_threshold = threshold
 
-
+
+def multi_inference_regression(handlers: list[DragonInferenceHandler],
                                feature_vector: Union[np.ndarray, torch.Tensor],
                                output: Literal["numpy","torch"]="numpy") -> dict[str,Any]:
     """
     Performs regression inference using multiple models on a single feature vector.
 
-    This function iterates through a list of
+    This function iterates through a list of DragonInferenceHandler objects,
     each configured for a different regression target. It runs a prediction for
     each handler using the same input feature vector and returns the results
     in a dictionary.
@@ -391,7 +473,7 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
     - 2D input: Returns a dictionary mapping target ID to a list of values.
 
     Args:
-        handlers (list[
+        handlers (list[DragonInferenceHandler]): A list of initialized inference
             handlers. Each handler must have a unique `target_id` and be configured with `task="regression"`.
         feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D) or a batch of samples (2D) to be fed into each regression model.
         output (Literal["numpy", "torch"], optional): The desired format for the output predictions.
@@ -421,11 +503,11 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
    results: dict[str,Any] = dict()
    for handler in handlers:
        # validation
-        if handler.
-            _LOGGER.error("All inference handlers must have a '
+        if handler.target_ids is None:
+            _LOGGER.error("All inference handlers must have a 'target_ids' attribute.")
            raise AttributeError()
-        if handler.task !=
-            _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.
+        if handler.task != MLTaskKeys.REGRESSION:
+            _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_ids[0]}' is for '{handler.task}', only single target regression tasks are supported.")
            raise ValueError()
 
        # inference
@@ -434,33 +516,33 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
            numpy_result = handler.predict_batch_numpy(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
            if is_single_sample:
                # For a single sample, convert the 1-element array to a Python scalar
-                results[handler.
+                results[handler.target_ids[0]] = numpy_result.item()
            else:
                # For a batch, return the full NumPy array of predictions
-                results[handler.
+                results[handler.target_ids[0]] = numpy_result
 
        else: # output == "torch"
            # This path returns PyTorch tensors on the model's device
            torch_result = handler.predict_batch(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
            if is_single_sample:
                # For a single sample, return the 0-dim tensor
-                results[handler.
+                results[handler.target_ids[0]] = torch_result[0]
            else:
                # For a batch, return the full tensor of predictions
-                results[handler.
+                results[handler.target_ids[0]] = torch_result
 
    return results
 
 
 def multi_inference_classification(
-        handlers: list[
+        handlers: list[DragonInferenceHandler],
        feature_vector: Union[np.ndarray, torch.Tensor],
        output: Literal["numpy","torch"]="numpy"
        ) -> tuple[dict[str, Any], dict[str, Any]]:
    """
    Performs classification inference on a single sample or a batch.
 
-    This function iterates through a list of
+    This function iterates through a list of DragonInferenceHandler objects,
    each configured for a different classification target. It returns two
    dictionaries: one for the predicted labels and one for the probabilities.
 
@@ -469,7 +551,7 @@ def multi_inference_classification(
    - 2D input: The dictionaries map target ID to an array of labels and an array of probability arrays.
 
    Args:
-        handlers (list[
+        handlers (list[DragonInferenceHandler]): A list of initialized inference handlers. Each must have a unique `target_id` and be configured
            with `task="classification"`.
        feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D)
            or a batch of samples (2D) for prediction.
@@ -502,11 +584,11 @@ def multi_inference_classification(
 
    for handler in handlers:
        # Validation
-        if handler.
+        if handler.target_ids is None:
            _LOGGER.error("All inference handlers must have a 'target_id' attribute.")
            raise AttributeError()
-        if handler.task
-            _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.
+        if handler.task not in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
+            _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_ids[0]}' is for '{handler.task}', but this function only supports binary and multiclass classification.")
            raise ValueError()
 
        # Inference
@@ -524,15 +606,15 @@ def multi_inference_classification(
            # For "numpy", convert the single label to a Python int scalar.
            # For "torch", get the 0-dim tensor label.
            if output == "numpy":
-                labels_results[handler.
+                labels_results[handler.target_ids[0]] = labels.item()
            else: # torch
-                labels_results[handler.
+                labels_results[handler.target_ids[0]] = labels[0]
 
            # The probabilities are an array/tensor of values
-            probs_results[handler.
+            probs_results[handler.target_ids[0]] = probabilities[0]
        else:
-            labels_results[handler.
-            probs_results[handler.
+            labels_results[handler.target_ids[0]] = labels
+            probs_results[handler.target_ids[0]] = probabilities
 
    return labels_results, probs_results
 
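The following is a minimal usage sketch assembled from the signatures introduced in this diff; it is not taken from the package's documentation. The demo model, the file name demo_model.pth, and the "churn" target name are invented for illustration, and the ml_tools._keys import path is assumed from the wheel layout (keys.py was renamed to _keys.py in this release). A checkpoint produced by the 16.x trainer may additionally carry a classification threshold, class map, and target names under the PyTorchCheckpointKeys entries, in which case the handler restores them automatically instead of the manual setters shown here.

import numpy as np
import torch
import torch.nn as nn

from ml_tools.ML_inference import DragonInferenceHandler
from ml_tools._keys import MLTaskKeys  # assumed import path for the task keys

# Stand-in binary classifier: 4 features in, a single logit out (hypothetical architecture).
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
torch.save(model.state_dict(), "demo_model.pth")  # plain weights-only file

handler = DragonInferenceHandler(
    model=model,
    state_dict="demo_model.pth",
    task=MLTaskKeys.BINARY_CLASSIFICATION,
    device="cpu",
)

# With a plain weights file nothing extra is restored, so set the optional pieces manually;
# force_overwrite=True is only needed when a loaded checkpoint already provided them.
handler.set_classification_threshold(0.7)
handler.set_class_map({"negative": 0, "positive": 1})
handler.set_target_ids(["churn"])  # hypothetical target name

sample = np.random.rand(4).astype(np.float32)  # one 1D input sample
print(handler.predict_numpy(sample))   # integer label, label name, probability
print(handler.quick_predict(sample))   # {"churn": predicted_label}

For fanning several single-target regression handlers over the same feature vector, multi_inference_regression(handlers, feature_vector) returns a {target_id: prediction} mapping, and multi_inference_classification returns separate label and probability dictionaries, as shown in the hunks above.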