dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
- dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +788 -0
- ml_tools/ML_datasetmaster.py +303 -448
- ml_tools/ML_evaluation.py +351 -93
- ml_tools/ML_evaluation_multi.py +139 -42
- ml_tools/ML_inference.py +290 -209
- ml_tools/ML_models.py +33 -106
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +219 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1604 -179
- ml_tools/ML_utilities.py +351 -4
- ml_tools/ML_vision_datasetmaster.py +1540 -0
- ml_tools/ML_vision_evaluation.py +284 -0
- ml_tools/ML_vision_inference.py +405 -0
- ml_tools/ML_vision_models.py +641 -0
- ml_tools/ML_vision_transformers.py +284 -0
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/_keys.py +171 -0
- ml_tools/_schema.py +1 -1
- ml_tools/custom_logger.py +37 -14
- ml_tools/data_exploration.py +502 -93
- ml_tools/ensemble_evaluation.py +54 -11
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/math_utilities.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/serde.py +2 -2
- ml_tools/utilities.py +192 -4
- dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/keys.py +0 -87
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_inference.py
CHANGED
|
@@ -5,16 +5,15 @@ from pathlib import Path
|
|
|
5
5
|
from typing import Union, Literal, Dict, Any, Optional
|
|
6
6
|
from abc import ABC, abstractmethod
|
|
7
7
|
|
|
8
|
-
from .ML_scaler import
|
|
8
|
+
from .ML_scaler import DragonScaler
|
|
9
9
|
from ._script_info import _script_info
|
|
10
10
|
from ._logger import _LOGGER
|
|
11
11
|
from .path_manager import make_fullpath
|
|
12
|
-
from .
|
|
12
|
+
from ._keys import PyTorchInferenceKeys, PyTorchCheckpointKeys, MLTaskKeys
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
__all__ = [
|
|
16
|
-
"
|
|
17
|
-
"PyTorchInferenceHandlerMulti",
|
|
16
|
+
"DragonInferenceHandler",
|
|
18
17
|
"multi_inference_regression",
|
|
19
18
|
"multi_inference_classification"
|
|
20
19
|
]
|
|
@@ -31,7 +30,7 @@ class _BaseInferenceHandler(ABC):
|
|
|
31
30
|
model: nn.Module,
|
|
32
31
|
state_dict: Union[str, Path],
|
|
33
32
|
device: str = 'cpu',
|
|
34
|
-
scaler: Optional[Union[
|
|
33
|
+
scaler: Optional[Union[DragonScaler, str, Path]] = None):
|
|
35
34
|
"""
|
|
36
35
|
Initializes the handler.
|
|
37
36
|
|
|
@@ -39,15 +38,21 @@ class _BaseInferenceHandler(ABC):
|
|
|
39
38
|
model (nn.Module): An instantiated PyTorch model.
|
|
40
39
|
state_dict (str | Path): Path to the saved .pth model state_dict file.
|
|
41
40
|
device (str): The device to run inference on ('cpu', 'cuda', 'mps').
|
|
42
|
-
scaler (
|
|
41
|
+
scaler (DragonScaler | str | Path | None): An optional scaler or path to a saved scaler state.
|
|
43
42
|
"""
|
|
44
43
|
self.model = model
|
|
45
44
|
self.device = self._validate_device(device)
|
|
45
|
+
self._classification_threshold = 0.5
|
|
46
|
+
self._loaded_threshold: bool = False
|
|
47
|
+
self._loaded_class_map: bool = False
|
|
48
|
+
self._class_map: Optional[dict[str,int]] = None
|
|
49
|
+
self._idx_to_class: Optional[Dict[int, str]] = None
|
|
50
|
+
self._loaded_data_dict: Dict[str, Any] = {} #Store whatever is in the finalized file
|
|
46
51
|
|
|
47
52
|
# Load the scaler if a path is provided
|
|
48
53
|
if scaler is not None:
|
|
49
54
|
if isinstance(scaler, (str, Path)):
|
|
50
|
-
self.scaler =
|
|
55
|
+
self.scaler = DragonScaler.load(scaler)
|
|
51
56
|
else:
|
|
52
57
|
self.scaler = scaler
|
|
53
58
|
else:
|
|
@@ -58,13 +63,45 @@ class _BaseInferenceHandler(ABC):
|
|
|
58
63
|
try:
|
|
59
64
|
# Load whatever is in the file
|
|
60
65
|
loaded_data = torch.load(model_p, map_location=self.device)
|
|
61
|
-
|
|
62
|
-
# Check if it's
|
|
63
|
-
if isinstance(loaded_data, dict)
|
|
64
|
-
|
|
65
|
-
|
|
66
|
+
|
|
67
|
+
# Check if it's dictionary or a old weights-only file
|
|
68
|
+
if isinstance(loaded_data, dict):
|
|
69
|
+
self._loaded_data_dict = loaded_data # Store the dict
|
|
70
|
+
|
|
71
|
+
if PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
|
|
72
|
+
# It's a new training checkpoint, extract the weights
|
|
73
|
+
self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
|
|
74
|
+
|
|
75
|
+
# attempt to get a custom classification threshold
|
|
76
|
+
if PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD in loaded_data:
|
|
77
|
+
try:
|
|
78
|
+
self._classification_threshold = float(loaded_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD])
|
|
79
|
+
except Exception as e_int:
|
|
80
|
+
_LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}' but an error occurred when retrieving it:\n{e_int}")
|
|
81
|
+
self._classification_threshold = 0.5
|
|
82
|
+
else:
|
|
83
|
+
_LOGGER.info(f"'{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}' found and set to {self._classification_threshold}")
|
|
84
|
+
self._loaded_threshold = True
|
|
85
|
+
|
|
86
|
+
# attempt to get a class map
|
|
87
|
+
if PyTorchCheckpointKeys.CLASS_MAP in loaded_data:
|
|
88
|
+
try:
|
|
89
|
+
self._class_map = loaded_data[PyTorchCheckpointKeys.CLASS_MAP]
|
|
90
|
+
except Exception as e_int:
|
|
91
|
+
_LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASS_MAP}' but an error occurred when retrieving it:\n{e_int}")
|
|
92
|
+
else:
|
|
93
|
+
if isinstance(self._class_map, dict):
|
|
94
|
+
self._loaded_class_map = True
|
|
95
|
+
self.set_class_map(self._class_map)
|
|
96
|
+
else:
|
|
97
|
+
_LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASS_MAP}' but it is not a dict: '{type(self._class_map)}'.")
|
|
98
|
+
self._class_map = None
|
|
99
|
+
else:
|
|
100
|
+
# It's a state_dict, load it directly
|
|
101
|
+
self.model.load_state_dict(loaded_data)
|
|
102
|
+
|
|
66
103
|
else:
|
|
67
|
-
|
|
104
|
+
# It's an old state_dict (just weights), load it directly
|
|
68
105
|
self.model.load_state_dict(loaded_data)
|
|
69
106
|
|
|
70
107
|
_LOGGER.info(f"Model state loaded from '{model_p.name}'.")
|
|
@@ -82,25 +119,40 @@ class _BaseInferenceHandler(ABC):
|
|
|
82
119
|
_LOGGER.warning("CUDA not available, switching to CPU.")
|
|
83
120
|
device_lower = "cpu"
|
|
84
121
|
elif device_lower == "mps" and not torch.backends.mps.is_available():
|
|
85
|
-
# Your M-series Mac will appreciate this check!
|
|
86
122
|
_LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
|
|
87
123
|
device_lower = "cpu"
|
|
88
124
|
return torch.device(device_lower)
|
|
89
|
-
|
|
90
|
-
def
|
|
91
|
-
"""
|
|
92
|
-
Converts input to a torch.Tensor, applies scaling if a scaler is
|
|
93
|
-
present, and moves it to the correct device.
|
|
125
|
+
|
|
126
|
+
def set_class_map(self, class_map: Dict[str, int], force_overwrite: bool = False):
|
|
94
127
|
"""
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
128
|
+
Sets the class name mapping to translate predicted integer labels back into string names.
|
|
129
|
+
|
|
130
|
+
If a class_map was previously loaded from a model configuration, this
|
|
131
|
+
method will log a warning and refuse to update the value. This
|
|
132
|
+
prevents accidentally overriding a setting from a loaded checkpoint.
|
|
133
|
+
|
|
134
|
+
To bypass this safety check set `force_overwrite` to `True`.
|
|
102
135
|
|
|
103
|
-
|
|
136
|
+
Args:
|
|
137
|
+
class_map (Dict[str, int]): The class_to_idx dictionary (e.g., {'cat': 0, 'dog': 1}).
|
|
138
|
+
force_overwrite (bool): If True, allows overwriting a map that was loaded from a configuration file.
|
|
139
|
+
"""
|
|
140
|
+
if self._loaded_class_map:
|
|
141
|
+
warning_message = f"A '{PyTorchCheckpointKeys.CLASS_MAP}' was loaded from the model configuration file."
|
|
142
|
+
if not force_overwrite:
|
|
143
|
+
warning_message += " Use 'force_overwrite=True' if you are sure you want to modify it. This will not affect the value from the file."
|
|
144
|
+
_LOGGER.warning(warning_message)
|
|
145
|
+
return
|
|
146
|
+
else:
|
|
147
|
+
warning_message += " Overwriting it for this inference instance."
|
|
148
|
+
_LOGGER.warning(warning_message)
|
|
149
|
+
|
|
150
|
+
# Store the map and invert it for fast lookup
|
|
151
|
+
self._class_map = class_map
|
|
152
|
+
self._idx_to_class = {v: k for k, v in class_map.items()}
|
|
153
|
+
# Mark as 'loaded' by the user, even if it wasn't from a file, to prevent accidental changes later if _loaded_class_map was false.
|
|
154
|
+
self._loaded_class_map = True # Protect this newly set map
|
|
155
|
+
_LOGGER.info("Class map set for label-to-name translation.")
|
|
104
156
|
|
|
105
157
|
@abstractmethod
|
|
106
158
|
def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
@@ -113,46 +165,107 @@ class _BaseInferenceHandler(ABC):
|
|
|
113
165
|
pass
|
|
114
166
|
|
|
115
167
|
|
|
116
|
-
class
|
|
168
|
+
class DragonInferenceHandler(_BaseInferenceHandler):
|
|
117
169
|
"""
|
|
118
|
-
Handles loading a PyTorch model's state dictionary and performing inference
|
|
119
|
-
for single-target regression or classification tasks.
|
|
170
|
+
Handles loading a PyTorch model's state dictionary and performing inference.
|
|
120
171
|
"""
|
|
121
172
|
def __init__(self,
|
|
122
173
|
model: nn.Module,
|
|
123
174
|
state_dict: Union[str, Path],
|
|
124
|
-
task: Literal["classification", "regression"],
|
|
175
|
+
task: Literal["regression", "binary classification", "multiclass classification", "multitarget regression", "multilabel binary classification"],
|
|
125
176
|
device: str = 'cpu',
|
|
126
|
-
|
|
127
|
-
scaler: Optional[Union[PytorchScaler, str, Path]] = None):
|
|
177
|
+
scaler: Optional[Union[DragonScaler, str, Path]] = None):
|
|
128
178
|
"""
|
|
129
179
|
Initializes the handler for single-target tasks.
|
|
130
180
|
|
|
131
181
|
Args:
|
|
132
182
|
model (nn.Module): An instantiated PyTorch model architecture.
|
|
133
183
|
state_dict (str | Path): Path to the saved .pth model state_dict file.
|
|
134
|
-
task (str): The type of task
|
|
184
|
+
task (str): The type of task.
|
|
135
185
|
device (str): The device to run inference on ('cpu', 'cuda', 'mps').
|
|
136
|
-
|
|
137
|
-
|
|
186
|
+
scaler (DragonScaler | str | Path | None): A DragonScaler instance or the file path to a saved DragonScaler state.
|
|
187
|
+
|
|
188
|
+
Note: class_map (Dict[int, str]) will be loaded from the model file, to set or override it use `.set_class_map()`.
|
|
138
189
|
"""
|
|
139
190
|
# Call the parent constructor to handle model loading, device, and scaler
|
|
140
191
|
super().__init__(model, state_dict, device, scaler)
|
|
141
192
|
|
|
142
|
-
if task not in [
|
|
143
|
-
|
|
193
|
+
if task not in [MLTaskKeys.REGRESSION,
|
|
194
|
+
MLTaskKeys.BINARY_CLASSIFICATION,
|
|
195
|
+
MLTaskKeys.MULTICLASS_CLASSIFICATION,
|
|
196
|
+
MLTaskKeys.MULTITARGET_REGRESSION,
|
|
197
|
+
MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
|
|
198
|
+
_LOGGER.error(f"'task' not recognized: '{task}'.")
|
|
199
|
+
raise ValueError()
|
|
144
200
|
self.task = task
|
|
145
|
-
self.
|
|
201
|
+
self.target_ids: Optional[list[str]] = None
|
|
202
|
+
self._target_ids_set: bool = False
|
|
203
|
+
|
|
204
|
+
# attempt to get target name or target names
|
|
205
|
+
if PyTorchCheckpointKeys.TARGET_NAME in self._loaded_data_dict:
|
|
206
|
+
try:
|
|
207
|
+
target_from_dict = [self._loaded_data_dict[PyTorchCheckpointKeys.TARGET_NAME]]
|
|
208
|
+
except Exception as e_int:
|
|
209
|
+
_LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.TARGET_NAME}' but an error occurred when retrieving it:\n{e_int}")
|
|
210
|
+
self.target_ids = None
|
|
211
|
+
else:
|
|
212
|
+
self.set_target_ids([target_from_dict]) # type: ignore
|
|
213
|
+
elif PyTorchCheckpointKeys.TARGET_NAMES in self._loaded_data_dict:
|
|
214
|
+
try:
|
|
215
|
+
targets_from_dict = self._loaded_data_dict[PyTorchCheckpointKeys.TARGET_NAMES]
|
|
216
|
+
except Exception as e_int:
|
|
217
|
+
_LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.TARGET_NAMES}' but an error occurred when retrieving it:\n{e_int}")
|
|
218
|
+
self.target_ids = None
|
|
219
|
+
else:
|
|
220
|
+
self.set_target_ids(targets_from_dict) # type: ignore
|
|
221
|
+
|
|
222
|
+
def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
|
|
223
|
+
"""
|
|
224
|
+
Converts input to a torch.Tensor, applies scaling if a scaler is
|
|
225
|
+
present, and moves it to the correct device.
|
|
226
|
+
"""
|
|
227
|
+
if isinstance(features, np.ndarray):
|
|
228
|
+
features_tensor = torch.from_numpy(features).float()
|
|
229
|
+
else:
|
|
230
|
+
features_tensor = features.float()
|
|
231
|
+
|
|
232
|
+
if self.scaler:
|
|
233
|
+
features_tensor = self.scaler.transform(features_tensor)
|
|
234
|
+
|
|
235
|
+
return features_tensor.to(self.device)
|
|
236
|
+
|
|
237
|
+
def set_target_ids(self, target_names: list[str], force_overwrite: bool=False):
|
|
238
|
+
"""
|
|
239
|
+
Assigns the provided list of strings as the target variable names.
|
|
240
|
+
If target IDs have already been set, this method will log a warning.
|
|
241
|
+
|
|
242
|
+
Args:
|
|
243
|
+
target_names (list[str]): A list of target names.
|
|
244
|
+
force_overwrite (bool): If True, allows the method to overwrite previously set target IDs.
|
|
245
|
+
"""
|
|
246
|
+
if self._target_ids_set:
|
|
247
|
+
warning_message = "Target IDs was previously set."
|
|
248
|
+
if not force_overwrite:
|
|
249
|
+
warning_message += " Use `force_overwrite=True` to overwrite."
|
|
250
|
+
_LOGGER.warning(warning_message)
|
|
251
|
+
return
|
|
252
|
+
else:
|
|
253
|
+
warning_message += " Overwriting..."
|
|
254
|
+
_LOGGER.warning(warning_message)
|
|
255
|
+
|
|
256
|
+
self.target_ids = target_names
|
|
257
|
+
self._target_ids_set = True
|
|
258
|
+
_LOGGER.info("Target IDs set.")
|
|
146
259
|
|
|
147
260
|
def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
148
261
|
"""
|
|
149
|
-
Core batch prediction method
|
|
262
|
+
Core batch prediction method.
|
|
150
263
|
|
|
151
264
|
Args:
|
|
152
265
|
features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
|
|
153
266
|
|
|
154
267
|
Returns:
|
|
155
|
-
A dictionary containing the raw output tensors from the model.
|
|
268
|
+
Dict: A dictionary containing the raw output tensors from the model.
|
|
156
269
|
"""
|
|
157
270
|
if features.ndim != 2:
|
|
158
271
|
_LOGGER.error("Input for batch prediction must be a 2D array or tensor.")
|
|
@@ -163,16 +276,48 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
|
|
|
163
276
|
with torch.no_grad():
|
|
164
277
|
output = self.model(input_tensor)
|
|
165
278
|
|
|
166
|
-
if self.task ==
|
|
279
|
+
if self.task == MLTaskKeys.MULTICLASS_CLASSIFICATION:
|
|
167
280
|
probs = torch.softmax(output, dim=1)
|
|
168
281
|
labels = torch.argmax(probs, dim=1)
|
|
169
282
|
return {
|
|
170
283
|
PyTorchInferenceKeys.LABELS: labels,
|
|
171
284
|
PyTorchInferenceKeys.PROBABILITIES: probs
|
|
172
285
|
}
|
|
173
|
-
|
|
286
|
+
|
|
287
|
+
elif self.task == MLTaskKeys.BINARY_CLASSIFICATION:
|
|
288
|
+
# Assumes model output is [N, 1] (a single logit)
|
|
289
|
+
# Squeeze output from [N, 1] to [N] if necessary
|
|
290
|
+
if output.ndim == 2 and output.shape[1] == 1:
|
|
291
|
+
output = output.squeeze(1)
|
|
292
|
+
|
|
293
|
+
probs = torch.sigmoid(output) # Probability of positive class
|
|
294
|
+
labels = (probs >= self._classification_threshold).int()
|
|
295
|
+
return {
|
|
296
|
+
PyTorchInferenceKeys.LABELS: labels,
|
|
297
|
+
PyTorchInferenceKeys.PROBABILITIES: probs
|
|
298
|
+
}
|
|
299
|
+
|
|
300
|
+
elif self.task == MLTaskKeys.REGRESSION:
|
|
174
301
|
# For single-target regression, ensure output is flattened
|
|
175
302
|
return {PyTorchInferenceKeys.PREDICTIONS: output.flatten()}
|
|
303
|
+
|
|
304
|
+
elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
|
|
305
|
+
probs = torch.sigmoid(output)
|
|
306
|
+
# Get binary predictions based on the threshold
|
|
307
|
+
labels = (probs >= self._classification_threshold).int()
|
|
308
|
+
return {
|
|
309
|
+
PyTorchInferenceKeys.LABELS: labels,
|
|
310
|
+
PyTorchInferenceKeys.PROBABILITIES: probs
|
|
311
|
+
}
|
|
312
|
+
|
|
313
|
+
elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
|
|
314
|
+
# The output is already in the correct [batch_size, n_targets] shape
|
|
315
|
+
return {PyTorchInferenceKeys.PREDICTIONS: output}
|
|
316
|
+
|
|
317
|
+
else:
|
|
318
|
+
# should never happen
|
|
319
|
+
_LOGGER.error(f"Unrecognized task '{self.task}'.")
|
|
320
|
+
raise ValueError()
|
|
176
321
|
|
|
177
322
|
def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
|
178
323
|
"""
|
|
@@ -182,7 +327,7 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
|
|
|
182
327
|
features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
|
|
183
328
|
|
|
184
329
|
Returns:
|
|
185
|
-
A dictionary containing the raw output tensors for a single sample.
|
|
330
|
+
Dict: A dictionary containing the raw output tensors for a single sample.
|
|
186
331
|
"""
|
|
187
332
|
if features.ndim == 1:
|
|
188
333
|
features = features.reshape(1, -1) # Reshape to a batch of one
|
|
@@ -196,193 +341,129 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
|
|
|
196
341
|
# Extract the first (and only) result from the batch output
|
|
197
342
|
single_results = {key: value[0] for key, value in batch_results.items()}
|
|
198
343
|
return single_results
|
|
199
|
-
|
|
344
|
+
|
|
200
345
|
# --- NumPy Convenience Wrappers (on CPU) ---
|
|
201
346
|
|
|
202
347
|
def predict_batch_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, np.ndarray]:
|
|
203
348
|
"""
|
|
204
|
-
Convenience wrapper for predict_batch that returns NumPy arrays
|
|
349
|
+
Convenience wrapper for predict_batch that returns NumPy arrays
|
|
350
|
+
and adds string labels for classification tasks if a class_map is set.
|
|
205
351
|
"""
|
|
206
352
|
tensor_results = self.predict_batch(features)
|
|
207
353
|
numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
|
|
354
|
+
|
|
355
|
+
# Add string names for classification if map exists
|
|
356
|
+
is_classification = self.task in [
|
|
357
|
+
MLTaskKeys.BINARY_CLASSIFICATION,
|
|
358
|
+
MLTaskKeys.MULTICLASS_CLASSIFICATION
|
|
359
|
+
]
|
|
360
|
+
|
|
361
|
+
if is_classification and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
|
|
362
|
+
int_labels = numpy_results[PyTorchInferenceKeys.LABELS] # This is a (B,) array
|
|
363
|
+
numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [ # type: ignore
|
|
364
|
+
self._idx_to_class.get(label_id, "Unknown")
|
|
365
|
+
for label_id in int_labels
|
|
366
|
+
]
|
|
367
|
+
|
|
208
368
|
return numpy_results
|
|
209
369
|
|
|
210
370
|
def predict_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
|
|
211
371
|
"""
|
|
212
|
-
Convenience wrapper for predict that returns NumPy arrays or scalars
|
|
372
|
+
Convenience wrapper for predict that returns NumPy arrays or scalars
|
|
373
|
+
and adds string labels for classification tasks if a class_map is set.
|
|
213
374
|
"""
|
|
214
375
|
tensor_results = self.predict(features)
|
|
215
376
|
|
|
216
|
-
if self.task ==
|
|
377
|
+
if self.task == MLTaskKeys.REGRESSION:
|
|
217
378
|
# .item() implicitly moves to CPU and returns a Python scalar
|
|
218
379
|
return {PyTorchInferenceKeys.PREDICTIONS: tensor_results[PyTorchInferenceKeys.PREDICTIONS].item()}
|
|
219
|
-
|
|
380
|
+
|
|
381
|
+
elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
|
|
382
|
+
int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
|
|
383
|
+
label_name = "Unknown"
|
|
384
|
+
if self._idx_to_class:
|
|
385
|
+
label_name = self._idx_to_class.get(int_label, "Unknown") # type: ignore
|
|
386
|
+
|
|
220
387
|
return {
|
|
221
|
-
PyTorchInferenceKeys.LABELS:
|
|
388
|
+
PyTorchInferenceKeys.LABELS: int_label,
|
|
389
|
+
PyTorchInferenceKeys.LABEL_NAMES: label_name,
|
|
222
390
|
PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
|
|
223
391
|
}
|
|
392
|
+
|
|
393
|
+
elif self.task in [MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION, MLTaskKeys.MULTITARGET_REGRESSION]:
|
|
394
|
+
# For multi-target models, the output is always an array.
|
|
395
|
+
numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
|
|
396
|
+
return numpy_results
|
|
397
|
+
else:
|
|
398
|
+
# should never happen
|
|
399
|
+
_LOGGER.error(f"Unrecognized task '{self.task}'.")
|
|
400
|
+
raise ValueError()
|
|
224
401
|
|
|
225
402
|
def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
|
|
226
403
|
"""
|
|
227
404
|
Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
|
|
228
405
|
|
|
229
|
-
`
|
|
406
|
+
`target_ids` must be implemented.
|
|
230
407
|
"""
|
|
231
|
-
if self.
|
|
232
|
-
_LOGGER.error(f"'
|
|
408
|
+
if self.target_ids is None:
|
|
409
|
+
_LOGGER.error(f"'target_ids' has not been implemented.")
|
|
233
410
|
raise AttributeError()
|
|
234
411
|
|
|
235
|
-
if self.task ==
|
|
412
|
+
if self.task == MLTaskKeys.REGRESSION:
|
|
236
413
|
result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS]
|
|
237
|
-
|
|
414
|
+
return {self.target_ids[0]: result}
|
|
415
|
+
|
|
416
|
+
elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
|
|
238
417
|
result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS]
|
|
418
|
+
return {self.target_ids[0]: result}
|
|
239
419
|
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
task: Literal["multi_target_regression", "multi_label_classification"],
|
|
252
|
-
device: str = 'cpu',
|
|
253
|
-
target_ids: Optional[list[str]] = None,
|
|
254
|
-
scaler: Optional[Union[PytorchScaler, str, Path]] = None):
|
|
255
|
-
"""
|
|
256
|
-
Initializes the handler for multi-target tasks.
|
|
257
|
-
|
|
258
|
-
Args:
|
|
259
|
-
model (nn.Module): An instantiated PyTorch model.
|
|
260
|
-
state_dict (str | Path): Path to the saved .pth model state_dict file.
|
|
261
|
-
task (str): The type of task, 'multi_target_regression' or 'multi_label_classification'.
|
|
262
|
-
device (str): The device to run inference on ('cpu', 'cuda', 'mps').
|
|
263
|
-
target_ids (list[str] | None): An optional identifier for the targets.
|
|
264
|
-
scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
|
|
265
|
-
"""
|
|
266
|
-
super().__init__(model, state_dict, device, scaler)
|
|
267
|
-
|
|
268
|
-
if task not in ["multi_target_regression", "multi_label_classification"]:
|
|
269
|
-
_LOGGER.error("`task` must be 'multi_target_regression' or 'multi_label_classification'.")
|
|
270
|
-
raise ValueError()
|
|
271
|
-
self.task = task
|
|
272
|
-
self.target_ids = target_ids
|
|
273
|
-
|
|
274
|
-
def predict_batch(self,
|
|
275
|
-
features: Union[np.ndarray, torch.Tensor],
|
|
276
|
-
classification_threshold: float = 0.5
|
|
277
|
-
) -> Dict[str, torch.Tensor]:
|
|
278
|
-
"""
|
|
279
|
-
Core batch prediction method for multi-target models.
|
|
280
|
-
|
|
281
|
-
Args:
|
|
282
|
-
features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
|
|
283
|
-
classification_threshold (float): The threshold to convert probabilities
|
|
284
|
-
into binary predictions for multi-label classification.
|
|
285
|
-
|
|
286
|
-
Returns:
|
|
287
|
-
A dictionary containing the raw output tensors from the model.
|
|
288
|
-
"""
|
|
289
|
-
if features.ndim != 2:
|
|
290
|
-
_LOGGER.error("Input for batch prediction must be a 2D array or tensor.")
|
|
291
|
-
raise ValueError()
|
|
292
|
-
|
|
293
|
-
input_tensor = self._preprocess_input(features)
|
|
294
|
-
|
|
295
|
-
with torch.no_grad():
|
|
296
|
-
output = self.model(input_tensor)
|
|
297
|
-
|
|
298
|
-
if self.task == "multi_label_classification":
|
|
299
|
-
probs = torch.sigmoid(output)
|
|
300
|
-
# Get binary predictions based on the threshold
|
|
301
|
-
labels = (probs >= classification_threshold).int()
|
|
302
|
-
return {
|
|
303
|
-
PyTorchInferenceKeys.LABELS: labels,
|
|
304
|
-
PyTorchInferenceKeys.PROBABILITIES: probs
|
|
305
|
-
}
|
|
306
|
-
else: # multi_target_regression
|
|
307
|
-
# The output is already in the correct [batch_size, n_targets] shape
|
|
308
|
-
return {PyTorchInferenceKeys.PREDICTIONS: output}
|
|
309
|
-
|
|
310
|
-
def predict(self,
|
|
311
|
-
features: Union[np.ndarray, torch.Tensor],
|
|
312
|
-
classification_threshold: float = 0.5
|
|
313
|
-
) -> Dict[str, torch.Tensor]:
|
|
314
|
-
"""
|
|
315
|
-
Core single-sample prediction method for multi-target models.
|
|
316
|
-
|
|
317
|
-
Args:
|
|
318
|
-
features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
|
|
319
|
-
classification_threshold (float): The threshold for multi-label tasks.
|
|
320
|
-
|
|
321
|
-
Returns:
|
|
322
|
-
A dictionary containing the raw output tensors for a single sample.
|
|
323
|
-
"""
|
|
324
|
-
if features.ndim == 1:
|
|
325
|
-
features = features.reshape(1, -1)
|
|
326
|
-
|
|
327
|
-
if features.shape[0] != 1:
|
|
328
|
-
_LOGGER.error("The 'predict()' method is for a single sample. 'Use predict_batch()' for multiple samples.")
|
|
420
|
+
elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
|
|
421
|
+
result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS].flatten().tolist()
|
|
422
|
+
return {key: value for key, value in zip(self.target_ids, result)}
|
|
423
|
+
|
|
424
|
+
elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
|
|
425
|
+
result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS].flatten().tolist()
|
|
426
|
+
return {key: value for key, value in zip(self.target_ids, result)}
|
|
427
|
+
|
|
428
|
+
else:
|
|
429
|
+
# should never happen
|
|
430
|
+
_LOGGER.error(f"Unrecognized task '{self.task}'.")
|
|
329
431
|
raise ValueError()
|
|
330
|
-
|
|
331
|
-
batch_results = self.predict_batch(features, classification_threshold)
|
|
332
|
-
|
|
333
|
-
single_results = {key: value[0] for key, value in batch_results.items()}
|
|
334
|
-
return single_results
|
|
335
|
-
|
|
336
|
-
# --- NumPy Convenience Wrappers (on CPU) ---
|
|
337
|
-
|
|
338
|
-
def predict_batch_numpy(self,
|
|
339
|
-
features: Union[np.ndarray, torch.Tensor],
|
|
340
|
-
classification_threshold: float = 0.5
|
|
341
|
-
) -> Dict[str, np.ndarray]:
|
|
342
|
-
"""
|
|
343
|
-
Convenience wrapper for predict_batch that returns NumPy arrays.
|
|
344
|
-
"""
|
|
345
|
-
tensor_results = self.predict_batch(features, classification_threshold)
|
|
346
|
-
numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
|
|
347
|
-
return numpy_results
|
|
348
|
-
|
|
349
|
-
def predict_numpy(self,
|
|
350
|
-
features: Union[np.ndarray, torch.Tensor],
|
|
351
|
-
classification_threshold: float = 0.5
|
|
352
|
-
) -> Dict[str, np.ndarray]:
|
|
353
|
-
"""
|
|
354
|
-
Convenience wrapper for predict that returns NumPy arrays for a single sample.
|
|
355
|
-
Note: For multi-target models, the output is always an array.
|
|
356
|
-
"""
|
|
357
|
-
tensor_results = self.predict(features, classification_threshold)
|
|
358
|
-
numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
|
|
359
|
-
return numpy_results
|
|
360
|
-
|
|
361
|
-
def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
|
|
362
|
-
"""
|
|
363
|
-
Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
|
|
364
432
|
|
|
365
|
-
|
|
433
|
+
def set_classification_threshold(self, threshold: float, force_overwrite: bool=False):
|
|
366
434
|
"""
|
|
367
|
-
|
|
368
|
-
_LOGGER.error(f"'target_id' has not been implemented.")
|
|
369
|
-
raise AttributeError()
|
|
435
|
+
Sets the classification threshold for the current inference instance.
|
|
370
436
|
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS].flatten().tolist()
|
|
437
|
+
If a threshold was previously loaded from a model configuration, this
|
|
438
|
+
method will log a warning and refuse to update the value. This
|
|
439
|
+
prevents accidentally overriding a setting from a loaded checkpoint.
|
|
375
440
|
|
|
376
|
-
|
|
441
|
+
To bypass this safety check set `force_overwrite` to `True`.
|
|
377
442
|
|
|
443
|
+
Args:
|
|
444
|
+
threshold (float): The new classification threshold value to set.
|
|
445
|
+
force_overwrite (bool): If True, allows overwriting a threshold that was loaded from a configuration file.
|
|
446
|
+
"""
|
|
447
|
+
if self._loaded_threshold:
|
|
448
|
+
warning_message = f"The current '{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}={self._classification_threshold}' was loaded and set from a model configuration file."
|
|
449
|
+
if not force_overwrite:
|
|
450
|
+
warning_message += " Use 'force_overwrite' if you are sure you want to modify it. This will not affect the value from the file."
|
|
451
|
+
_LOGGER.warning(warning_message)
|
|
452
|
+
return
|
|
453
|
+
else:
|
|
454
|
+
warning_message += f" Overwriting it to {threshold}."
|
|
455
|
+
_LOGGER.warning(warning_message)
|
|
456
|
+
|
|
457
|
+
self._classification_threshold = threshold
|
|
378
458
|
|
|
379
|
-
|
|
459
|
+
|
|
460
|
+
def multi_inference_regression(handlers: list[DragonInferenceHandler],
|
|
380
461
|
feature_vector: Union[np.ndarray, torch.Tensor],
|
|
381
462
|
output: Literal["numpy","torch"]="numpy") -> dict[str,Any]:
|
|
382
463
|
"""
|
|
383
464
|
Performs regression inference using multiple models on a single feature vector.
|
|
384
465
|
|
|
385
|
-
This function iterates through a list of
|
|
466
|
+
This function iterates through a list of DragonInferenceHandler objects,
|
|
386
467
|
each configured for a different regression target. It runs a prediction for
|
|
387
468
|
each handler using the same input feature vector and returns the results
|
|
388
469
|
in a dictionary.
|
|
@@ -392,7 +473,7 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
|
|
|
392
473
|
- 2D input: Returns a dictionary mapping target ID to a list of values.
|
|
393
474
|
|
|
394
475
|
Args:
|
|
395
|
-
handlers (list[
|
|
476
|
+
handlers (list[DragonInferenceHandler]): A list of initialized inference
|
|
396
477
|
handlers. Each handler must have a unique `target_id` and be configured with `task="regression"`.
|
|
397
478
|
feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D) or a batch of samples (2D) to be fed into each regression model.
|
|
398
479
|
output (Literal["numpy", "torch"], optional): The desired format for the output predictions.
|
|
@@ -422,11 +503,11 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
|
|
|
422
503
|
results: dict[str,Any] = dict()
|
|
423
504
|
for handler in handlers:
|
|
424
505
|
# validation
|
|
425
|
-
if handler.
|
|
426
|
-
_LOGGER.error("All inference handlers must have a '
|
|
506
|
+
if handler.target_ids is None:
|
|
507
|
+
_LOGGER.error("All inference handlers must have a 'target_ids' attribute.")
|
|
427
508
|
raise AttributeError()
|
|
428
|
-
if handler.task !=
|
|
429
|
-
_LOGGER.error(f"Invalid task type: The handler for target_id '{handler.
|
|
509
|
+
if handler.task != MLTaskKeys.REGRESSION:
|
|
510
|
+
_LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_ids[0]}' is for '{handler.task}', only single target regression tasks are supported.")
|
|
430
511
|
raise ValueError()
|
|
431
512
|
|
|
432
513
|
# inference
|
|
@@ -435,33 +516,33 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
|
|
|
435
516
|
numpy_result = handler.predict_batch_numpy(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
|
|
436
517
|
if is_single_sample:
|
|
437
518
|
# For a single sample, convert the 1-element array to a Python scalar
|
|
438
|
-
results[handler.
|
|
519
|
+
results[handler.target_ids[0]] = numpy_result.item()
|
|
439
520
|
else:
|
|
440
521
|
# For a batch, return the full NumPy array of predictions
|
|
441
|
-
results[handler.
|
|
522
|
+
results[handler.target_ids[0]] = numpy_result
|
|
442
523
|
|
|
443
524
|
else: # output == "torch"
|
|
444
525
|
# This path returns PyTorch tensors on the model's device
|
|
445
526
|
torch_result = handler.predict_batch(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
|
|
446
527
|
if is_single_sample:
|
|
447
528
|
# For a single sample, return the 0-dim tensor
|
|
448
|
-
results[handler.
|
|
529
|
+
results[handler.target_ids[0]] = torch_result[0]
|
|
449
530
|
else:
|
|
450
531
|
# For a batch, return the full tensor of predictions
|
|
451
|
-
results[handler.
|
|
532
|
+
results[handler.target_ids[0]] = torch_result
|
|
452
533
|
|
|
453
534
|
return results
|
|
454
535
|
|
|
455
536
|
|
|
456
537
|
def multi_inference_classification(
|
|
457
|
-
handlers: list[
|
|
538
|
+
handlers: list[DragonInferenceHandler],
|
|
458
539
|
feature_vector: Union[np.ndarray, torch.Tensor],
|
|
459
540
|
output: Literal["numpy","torch"]="numpy"
|
|
460
541
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
|
461
542
|
"""
|
|
462
543
|
Performs classification inference on a single sample or a batch.
|
|
463
544
|
|
|
464
|
-
This function iterates through a list of
|
|
545
|
+
This function iterates through a list of DragonInferenceHandler objects,
|
|
465
546
|
each configured for a different classification target. It returns two
|
|
466
547
|
dictionaries: one for the predicted labels and one for the probabilities.
|
|
467
548
|
|
|
@@ -470,7 +551,7 @@ def multi_inference_classification(
|
|
|
470
551
|
- 2D input: The dictionaries map target ID to an array of labels and an array of probability arrays.
|
|
471
552
|
|
|
472
553
|
Args:
|
|
473
|
-
handlers (list[
|
|
554
|
+
handlers (list[DragonInferenceHandler]): A list of initialized inference handlers. Each must have a unique `target_id` and be configured
|
|
474
555
|
with `task="classification"`.
|
|
475
556
|
feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D)
|
|
476
557
|
or a batch of samples (2D) for prediction.
|
|
@@ -503,11 +584,11 @@ def multi_inference_classification(
|
|
|
503
584
|
|
|
504
585
|
for handler in handlers:
|
|
505
586
|
# Validation
|
|
506
|
-
if handler.
|
|
587
|
+
if handler.target_ids is None:
|
|
507
588
|
_LOGGER.error("All inference handlers must have a 'target_id' attribute.")
|
|
508
589
|
raise AttributeError()
|
|
509
|
-
if handler.task
|
|
510
|
-
_LOGGER.error(f"Invalid task type: The handler for target_id '{handler.
|
|
590
|
+
if handler.task not in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
|
|
591
|
+
_LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_ids[0]}' is for '{handler.task}', but this function only supports binary and multiclass classification.")
|
|
511
592
|
raise ValueError()
|
|
512
593
|
|
|
513
594
|
# Inference
|
|
@@ -525,15 +606,15 @@ def multi_inference_classification(
|
|
|
525
606
|
# For "numpy", convert the single label to a Python int scalar.
|
|
526
607
|
# For "torch", get the 0-dim tensor label.
|
|
527
608
|
if output == "numpy":
|
|
528
|
-
labels_results[handler.
|
|
609
|
+
labels_results[handler.target_ids[0]] = labels.item()
|
|
529
610
|
else: # torch
|
|
530
|
-
labels_results[handler.
|
|
611
|
+
labels_results[handler.target_ids[0]] = labels[0]
|
|
531
612
|
|
|
532
613
|
# The probabilities are an array/tensor of values
|
|
533
|
-
probs_results[handler.
|
|
614
|
+
probs_results[handler.target_ids[0]] = probabilities[0]
|
|
534
615
|
else:
|
|
535
|
-
labels_results[handler.
|
|
536
|
-
probs_results[handler.
|
|
616
|
+
labels_results[handler.target_ids[0]] = labels
|
|
617
|
+
probs_results[handler.target_ids[0]] = probabilities
|
|
537
618
|
|
|
538
619
|
return labels_results, probs_results
|
|
539
620
|
|