dragon-ml-toolbox 14.7.0__py3-none-any.whl → 16.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/METADATA +9 -5
  2. dragon_ml_toolbox-16.2.1.dist-info/RECORD +51 -0
  3. ml_tools/ETL_cleaning.py +20 -20
  4. ml_tools/ETL_engineering.py +23 -25
  5. ml_tools/GUI_tools.py +20 -20
  6. ml_tools/MICE_imputation.py +3 -3
  7. ml_tools/ML_callbacks.py +43 -26
  8. ml_tools/ML_configuration.py +726 -32
  9. ml_tools/ML_datasetmaster.py +235 -280
  10. ml_tools/ML_evaluation.py +160 -42
  11. ml_tools/ML_evaluation_multi.py +103 -35
  12. ml_tools/ML_inference.py +290 -208
  13. ml_tools/ML_models.py +13 -102
  14. ml_tools/ML_models_advanced.py +1 -1
  15. ml_tools/ML_optimization.py +12 -12
  16. ml_tools/ML_scaler.py +11 -11
  17. ml_tools/ML_sequence_datasetmaster.py +341 -0
  18. ml_tools/ML_sequence_evaluation.py +219 -0
  19. ml_tools/ML_sequence_inference.py +391 -0
  20. ml_tools/ML_sequence_models.py +139 -0
  21. ml_tools/ML_trainer.py +1342 -386
  22. ml_tools/ML_utilities.py +1 -1
  23. ml_tools/ML_vision_datasetmaster.py +120 -72
  24. ml_tools/ML_vision_evaluation.py +30 -6
  25. ml_tools/ML_vision_inference.py +129 -152
  26. ml_tools/ML_vision_models.py +1 -1
  27. ml_tools/ML_vision_transformers.py +121 -40
  28. ml_tools/PSO_optimization.py +6 -6
  29. ml_tools/SQL.py +4 -4
  30. ml_tools/{keys.py → _keys.py} +45 -0
  31. ml_tools/_schema.py +1 -1
  32. ml_tools/ensemble_evaluation.py +1 -1
  33. ml_tools/ensemble_inference.py +7 -33
  34. ml_tools/ensemble_learning.py +1 -1
  35. ml_tools/optimization_tools.py +2 -2
  36. ml_tools/path_manager.py +5 -5
  37. ml_tools/utilities.py +1 -2
  38. dragon_ml_toolbox-14.7.0.dist-info/RECORD +0 -49
  39. ml_tools/RNN_forecast.py +0 -56
  40. ml_tools/_ML_vision_recipe.py +0 -88
  41. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/WHEEL +0 -0
  42. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE +0 -0
  43. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  44. {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/top_level.txt +0 -0
ml_tools/ML_inference.py CHANGED
@@ -5,16 +5,15 @@ from pathlib import Path
5
5
  from typing import Union, Literal, Dict, Any, Optional
6
6
  from abc import ABC, abstractmethod
7
7
 
8
- from .ML_scaler import PytorchScaler
8
+ from .ML_scaler import DragonScaler
9
9
  from ._script_info import _script_info
10
10
  from ._logger import _LOGGER
11
11
  from .path_manager import make_fullpath
12
- from .keys import PyTorchInferenceKeys, PyTorchCheckpointKeys
12
+ from ._keys import PyTorchInferenceKeys, PyTorchCheckpointKeys, MLTaskKeys
13
13
 
14
14
 
15
15
  __all__ = [
16
- "PyTorchInferenceHandler",
17
- "PyTorchInferenceHandlerMulti",
16
+ "DragonInferenceHandler",
18
17
  "multi_inference_regression",
19
18
  "multi_inference_classification"
20
19
  ]
@@ -31,7 +30,7 @@ class _BaseInferenceHandler(ABC):
31
30
  model: nn.Module,
32
31
  state_dict: Union[str, Path],
33
32
  device: str = 'cpu',
34
- scaler: Optional[Union[PytorchScaler, str, Path]] = None):
33
+ scaler: Optional[Union[DragonScaler, str, Path]] = None):
35
34
  """
36
35
  Initializes the handler.
37
36
 
@@ -39,15 +38,21 @@ class _BaseInferenceHandler(ABC):
39
38
  model (nn.Module): An instantiated PyTorch model.
40
39
  state_dict (str | Path): Path to the saved .pth model state_dict file.
41
40
  device (str): The device to run inference on ('cpu', 'cuda', 'mps').
42
- scaler (PytorchScaler | str | Path | None): An optional scaler or path to a saved scaler state.
41
+ scaler (DragonScaler | str | Path | None): An optional scaler or path to a saved scaler state.
43
42
  """
44
43
  self.model = model
45
44
  self.device = self._validate_device(device)
45
+ self._classification_threshold = 0.5
46
+ self._loaded_threshold: bool = False
47
+ self._loaded_class_map: bool = False
48
+ self._class_map: Optional[dict[str,int]] = None
49
+ self._idx_to_class: Optional[Dict[int, str]] = None
50
+ self._loaded_data_dict: Dict[str, Any] = {} #Store whatever is in the finalized file
46
51
 
47
52
  # Load the scaler if a path is provided
48
53
  if scaler is not None:
49
54
  if isinstance(scaler, (str, Path)):
50
- self.scaler = PytorchScaler.load(scaler)
55
+ self.scaler = DragonScaler.load(scaler)
51
56
  else:
52
57
  self.scaler = scaler
53
58
  else:
@@ -58,13 +63,45 @@ class _BaseInferenceHandler(ABC):
58
63
  try:
59
64
  # Load whatever is in the file
60
65
  loaded_data = torch.load(model_p, map_location=self.device)
61
-
62
- # Check if it's the new checkpoint dictionary or an old weights-only file
63
- if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
64
- # It's a new training checkpoint, extract the weights
65
- self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
66
+
67
+ # Check if it's dictionary or a old weights-only file
68
+ if isinstance(loaded_data, dict):
69
+ self._loaded_data_dict = loaded_data # Store the dict
70
+
71
+ if PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
72
+ # It's a new training checkpoint, extract the weights
73
+ self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
74
+
75
+ # attempt to get a custom classification threshold
76
+ if PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD in loaded_data:
77
+ try:
78
+ self._classification_threshold = float(loaded_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD])
79
+ except Exception as e_int:
80
+ _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}' but an error occurred when retrieving it:\n{e_int}")
81
+ self._classification_threshold = 0.5
82
+ else:
83
+ _LOGGER.info(f"'{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}' found and set to {self._classification_threshold}")
84
+ self._loaded_threshold = True
85
+
86
+ # attempt to get a class map
87
+ if PyTorchCheckpointKeys.CLASS_MAP in loaded_data:
88
+ try:
89
+ self._class_map = loaded_data[PyTorchCheckpointKeys.CLASS_MAP]
90
+ except Exception as e_int:
91
+ _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASS_MAP}' but an error occurred when retrieving it:\n{e_int}")
92
+ else:
93
+ if isinstance(self._class_map, dict):
94
+ self._loaded_class_map = True
95
+ self.set_class_map(self._class_map)
96
+ else:
97
+ _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASS_MAP}' but it is not a dict: '{type(self._class_map)}'.")
98
+ self._class_map = None
99
+ else:
100
+ # It's a state_dict, load it directly
101
+ self.model.load_state_dict(loaded_data)
102
+
66
103
  else:
67
- # It's an old-style file (or just a state_dict), load it directly
104
+ # It's an old state_dict (just weights), load it directly
68
105
  self.model.load_state_dict(loaded_data)
69
106
 
70
107
  _LOGGER.info(f"Model state loaded from '{model_p.name}'.")
@@ -85,21 +122,37 @@ class _BaseInferenceHandler(ABC):
85
122
  _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
86
123
  device_lower = "cpu"
87
124
  return torch.device(device_lower)
88
-
89
- def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
90
- """
91
- Converts input to a torch.Tensor, applies scaling if a scaler is
92
- present, and moves it to the correct device.
125
+
126
+ def set_class_map(self, class_map: Dict[str, int], force_overwrite: bool = False):
93
127
  """
94
- if isinstance(features, np.ndarray):
95
- features_tensor = torch.from_numpy(features).float()
96
- else:
97
- features_tensor = features.float()
98
-
99
- if self.scaler:
100
- features_tensor = self.scaler.transform(features_tensor)
128
+ Sets the class name mapping to translate predicted integer labels back into string names.
129
+
130
+ If a class_map was previously loaded from a model configuration, this
131
+ method will log a warning and refuse to update the value. This
132
+ prevents accidentally overriding a setting from a loaded checkpoint.
133
+
134
+ To bypass this safety check set `force_overwrite` to `True`.
101
135
 
102
- return features_tensor.to(self.device)
136
+ Args:
137
+ class_map (Dict[str, int]): The class_to_idx dictionary (e.g., {'cat': 0, 'dog': 1}).
138
+ force_overwrite (bool): If True, allows overwriting a map that was loaded from a configuration file.
139
+ """
140
+ if self._loaded_class_map:
141
+ warning_message = f"A '{PyTorchCheckpointKeys.CLASS_MAP}' was loaded from the model configuration file."
142
+ if not force_overwrite:
143
+ warning_message += " Use 'force_overwrite=True' if you are sure you want to modify it. This will not affect the value from the file."
144
+ _LOGGER.warning(warning_message)
145
+ return
146
+ else:
147
+ warning_message += " Overwriting it for this inference instance."
148
+ _LOGGER.warning(warning_message)
149
+
150
+ # Store the map and invert it for fast lookup
151
+ self._class_map = class_map
152
+ self._idx_to_class = {v: k for k, v in class_map.items()}
153
+ # Mark as 'loaded' by the user, even if it wasn't from a file, to prevent accidental changes later if _loaded_class_map was false.
154
+ self._loaded_class_map = True # Protect this newly set map
155
+ _LOGGER.info("Class map set for label-to-name translation.")
103
156
 
104
157
  @abstractmethod
105
158
  def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
@@ -112,46 +165,107 @@ class _BaseInferenceHandler(ABC):
112
165
  pass
113
166
 
114
167
 
115
- class PyTorchInferenceHandler(_BaseInferenceHandler):
168
+ class DragonInferenceHandler(_BaseInferenceHandler):
116
169
  """
117
- Handles loading a PyTorch model's state dictionary and performing inference
118
- for single-target regression or classification tasks.
170
+ Handles loading a PyTorch model's state dictionary and performing inference.
119
171
  """
120
172
  def __init__(self,
121
173
  model: nn.Module,
122
174
  state_dict: Union[str, Path],
123
- task: Literal["classification", "regression"],
175
+ task: Literal["regression", "binary classification", "multiclass classification", "multitarget regression", "multilabel binary classification"],
124
176
  device: str = 'cpu',
125
- target_id: Optional[str] = None,
126
- scaler: Optional[Union[PytorchScaler, str, Path]] = None):
177
+ scaler: Optional[Union[DragonScaler, str, Path]] = None):
127
178
  """
128
179
  Initializes the handler for single-target tasks.
129
180
 
130
181
  Args:
131
182
  model (nn.Module): An instantiated PyTorch model architecture.
132
183
  state_dict (str | Path): Path to the saved .pth model state_dict file.
133
- task (str): The type of task, 'regression' or 'classification'.
184
+ task (str): The type of task.
134
185
  device (str): The device to run inference on ('cpu', 'cuda', 'mps').
135
- target_id (str | None): An optional identifier for the target.
136
- scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
186
+ scaler (DragonScaler | str | Path | None): A DragonScaler instance or the file path to a saved DragonScaler state.
187
+
188
+ Note: class_map (Dict[int, str]) will be loaded from the model file, to set or override it use `.set_class_map()`.
137
189
  """
138
190
  # Call the parent constructor to handle model loading, device, and scaler
139
191
  super().__init__(model, state_dict, device, scaler)
140
192
 
141
- if task not in ["classification", "regression"]:
142
- raise ValueError("`task` must be 'classification' or 'regression'.")
193
+ if task not in [MLTaskKeys.REGRESSION,
194
+ MLTaskKeys.BINARY_CLASSIFICATION,
195
+ MLTaskKeys.MULTICLASS_CLASSIFICATION,
196
+ MLTaskKeys.MULTITARGET_REGRESSION,
197
+ MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
198
+ _LOGGER.error(f"'task' not recognized: '{task}'.")
199
+ raise ValueError()
143
200
  self.task = task
144
- self.target_id = target_id
201
+ self.target_ids: Optional[list[str]] = None
202
+ self._target_ids_set: bool = False
203
+
204
+ # attempt to get target name or target names
205
+ if PyTorchCheckpointKeys.TARGET_NAME in self._loaded_data_dict:
206
+ try:
207
+ target_from_dict = [self._loaded_data_dict[PyTorchCheckpointKeys.TARGET_NAME]]
208
+ except Exception as e_int:
209
+ _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.TARGET_NAME}' but an error occurred when retrieving it:\n{e_int}")
210
+ self.target_ids = None
211
+ else:
212
+ self.set_target_ids([target_from_dict]) # type: ignore
213
+ elif PyTorchCheckpointKeys.TARGET_NAMES in self._loaded_data_dict:
214
+ try:
215
+ targets_from_dict = self._loaded_data_dict[PyTorchCheckpointKeys.TARGET_NAMES]
216
+ except Exception as e_int:
217
+ _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.TARGET_NAMES}' but an error occurred when retrieving it:\n{e_int}")
218
+ self.target_ids = None
219
+ else:
220
+ self.set_target_ids(targets_from_dict) # type: ignore
221
+
222
+ def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
223
+ """
224
+ Converts input to a torch.Tensor, applies scaling if a scaler is
225
+ present, and moves it to the correct device.
226
+ """
227
+ if isinstance(features, np.ndarray):
228
+ features_tensor = torch.from_numpy(features).float()
229
+ else:
230
+ features_tensor = features.float()
231
+
232
+ if self.scaler:
233
+ features_tensor = self.scaler.transform(features_tensor)
234
+
235
+ return features_tensor.to(self.device)
236
+
237
+ def set_target_ids(self, target_names: list[str], force_overwrite: bool=False):
238
+ """
239
+ Assigns the provided list of strings as the target variable names.
240
+ If target IDs have already been set, this method will log a warning.
241
+
242
+ Args:
243
+ target_names (list[str]): A list of target names.
244
+ force_overwrite (bool): If True, allows the method to overwrite previously set target IDs.
245
+ """
246
+ if self._target_ids_set:
247
+ warning_message = "Target IDs was previously set."
248
+ if not force_overwrite:
249
+ warning_message += " Use `force_overwrite=True` to overwrite."
250
+ _LOGGER.warning(warning_message)
251
+ return
252
+ else:
253
+ warning_message += " Overwriting..."
254
+ _LOGGER.warning(warning_message)
255
+
256
+ self.target_ids = target_names
257
+ self._target_ids_set = True
258
+ _LOGGER.info("Target IDs set.")
145
259
 
146
260
  def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
147
261
  """
148
- Core batch prediction method for single-target models.
262
+ Core batch prediction method.
149
263
 
150
264
  Args:
151
265
  features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
152
266
 
153
267
  Returns:
154
- A dictionary containing the raw output tensors from the model.
268
+ Dict: A dictionary containing the raw output tensors from the model.
155
269
  """
156
270
  if features.ndim != 2:
157
271
  _LOGGER.error("Input for batch prediction must be a 2D array or tensor.")
@@ -162,16 +276,48 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
162
276
  with torch.no_grad():
163
277
  output = self.model(input_tensor)
164
278
 
165
- if self.task == "classification":
279
+ if self.task == MLTaskKeys.MULTICLASS_CLASSIFICATION:
166
280
  probs = torch.softmax(output, dim=1)
167
281
  labels = torch.argmax(probs, dim=1)
168
282
  return {
169
283
  PyTorchInferenceKeys.LABELS: labels,
170
284
  PyTorchInferenceKeys.PROBABILITIES: probs
171
285
  }
172
- else: # regression
286
+
287
+ elif self.task == MLTaskKeys.BINARY_CLASSIFICATION:
288
+ # Assumes model output is [N, 1] (a single logit)
289
+ # Squeeze output from [N, 1] to [N] if necessary
290
+ if output.ndim == 2 and output.shape[1] == 1:
291
+ output = output.squeeze(1)
292
+
293
+ probs = torch.sigmoid(output) # Probability of positive class
294
+ labels = (probs >= self._classification_threshold).int()
295
+ return {
296
+ PyTorchInferenceKeys.LABELS: labels,
297
+ PyTorchInferenceKeys.PROBABILITIES: probs
298
+ }
299
+
300
+ elif self.task == MLTaskKeys.REGRESSION:
173
301
  # For single-target regression, ensure output is flattened
174
302
  return {PyTorchInferenceKeys.PREDICTIONS: output.flatten()}
303
+
304
+ elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
305
+ probs = torch.sigmoid(output)
306
+ # Get binary predictions based on the threshold
307
+ labels = (probs >= self._classification_threshold).int()
308
+ return {
309
+ PyTorchInferenceKeys.LABELS: labels,
310
+ PyTorchInferenceKeys.PROBABILITIES: probs
311
+ }
312
+
313
+ elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
314
+ # The output is already in the correct [batch_size, n_targets] shape
315
+ return {PyTorchInferenceKeys.PREDICTIONS: output}
316
+
317
+ else:
318
+ # should never happen
319
+ _LOGGER.error(f"Unrecognized task '{self.task}'.")
320
+ raise ValueError()
175
321
 
176
322
  def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
177
323
  """
@@ -181,7 +327,7 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
181
327
  features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
182
328
 
183
329
  Returns:
184
- A dictionary containing the raw output tensors for a single sample.
330
+ Dict: A dictionary containing the raw output tensors for a single sample.
185
331
  """
186
332
  if features.ndim == 1:
187
333
  features = features.reshape(1, -1) # Reshape to a batch of one
@@ -195,193 +341,129 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
195
341
  # Extract the first (and only) result from the batch output
196
342
  single_results = {key: value[0] for key, value in batch_results.items()}
197
343
  return single_results
198
-
344
+
199
345
  # --- NumPy Convenience Wrappers (on CPU) ---
200
346
 
201
347
  def predict_batch_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, np.ndarray]:
202
348
  """
203
- Convenience wrapper for predict_batch that returns NumPy arrays.
349
+ Convenience wrapper for predict_batch that returns NumPy arrays
350
+ and adds string labels for classification tasks if a class_map is set.
204
351
  """
205
352
  tensor_results = self.predict_batch(features)
206
353
  numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
354
+
355
+ # Add string names for classification if map exists
356
+ is_classification = self.task in [
357
+ MLTaskKeys.BINARY_CLASSIFICATION,
358
+ MLTaskKeys.MULTICLASS_CLASSIFICATION
359
+ ]
360
+
361
+ if is_classification and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
362
+ int_labels = numpy_results[PyTorchInferenceKeys.LABELS] # This is a (B,) array
363
+ numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [ # type: ignore
364
+ self._idx_to_class.get(label_id, "Unknown")
365
+ for label_id in int_labels
366
+ ]
367
+
207
368
  return numpy_results
208
369
 
209
370
  def predict_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
210
371
  """
211
- Convenience wrapper for predict that returns NumPy arrays or scalars.
372
+ Convenience wrapper for predict that returns NumPy arrays or scalars
373
+ and adds string labels for classification tasks if a class_map is set.
212
374
  """
213
375
  tensor_results = self.predict(features)
214
376
 
215
- if self.task == "regression":
377
+ if self.task == MLTaskKeys.REGRESSION:
216
378
  # .item() implicitly moves to CPU and returns a Python scalar
217
379
  return {PyTorchInferenceKeys.PREDICTIONS: tensor_results[PyTorchInferenceKeys.PREDICTIONS].item()}
218
- else: # classification
380
+
381
+ elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
382
+ int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
383
+ label_name = "Unknown"
384
+ if self._idx_to_class:
385
+ label_name = self._idx_to_class.get(int_label, "Unknown") # type: ignore
386
+
219
387
  return {
220
- PyTorchInferenceKeys.LABELS: tensor_results[PyTorchInferenceKeys.LABELS].item(),
388
+ PyTorchInferenceKeys.LABELS: int_label,
389
+ PyTorchInferenceKeys.LABEL_NAMES: label_name,
221
390
  PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
222
391
  }
392
+
393
+ elif self.task in [MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION, MLTaskKeys.MULTITARGET_REGRESSION]:
394
+ # For multi-target models, the output is always an array.
395
+ numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
396
+ return numpy_results
397
+ else:
398
+ # should never happen
399
+ _LOGGER.error(f"Unrecognized task '{self.task}'.")
400
+ raise ValueError()
223
401
 
224
402
  def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
225
403
  """
226
404
  Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
227
405
 
228
- `target_id` must be implemented.
406
+ `target_ids` must be implemented.
229
407
  """
230
- if self.target_id is None:
231
- _LOGGER.error(f"'target_id' has not been implemented.")
408
+ if self.target_ids is None:
409
+ _LOGGER.error(f"'target_ids' has not been implemented.")
232
410
  raise AttributeError()
233
411
 
234
- if self.task == "regression":
412
+ if self.task == MLTaskKeys.REGRESSION:
235
413
  result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS]
236
- else:
414
+ return {self.target_ids[0]: result}
415
+
416
+ elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
237
417
  result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS]
418
+ return {self.target_ids[0]: result}
238
419
 
239
- return {self.target_id: result}
240
-
241
-
242
- class PyTorchInferenceHandlerMulti(_BaseInferenceHandler):
243
- """
244
- Handles loading a PyTorch model's state dictionary and performing inference
245
- for multi-target regression or multi-label classification tasks.
246
- """
247
- def __init__(self,
248
- model: nn.Module,
249
- state_dict: Union[str, Path],
250
- task: Literal["multi_target_regression", "multi_label_classification"],
251
- device: str = 'cpu',
252
- target_ids: Optional[list[str]] = None,
253
- scaler: Optional[Union[PytorchScaler, str, Path]] = None):
254
- """
255
- Initializes the handler for multi-target tasks.
256
-
257
- Args:
258
- model (nn.Module): An instantiated PyTorch model.
259
- state_dict (str | Path): Path to the saved .pth model state_dict file.
260
- task (str): The type of task, 'multi_target_regression' or 'multi_label_classification'.
261
- device (str): The device to run inference on ('cpu', 'cuda', 'mps').
262
- target_ids (list[str] | None): An optional identifier for the targets.
263
- scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
264
- """
265
- super().__init__(model, state_dict, device, scaler)
266
-
267
- if task not in ["multi_target_regression", "multi_label_classification"]:
268
- _LOGGER.error("`task` must be 'multi_target_regression' or 'multi_label_classification'.")
269
- raise ValueError()
270
- self.task = task
271
- self.target_ids = target_ids
272
-
273
- def predict_batch(self,
274
- features: Union[np.ndarray, torch.Tensor],
275
- classification_threshold: float = 0.5
276
- ) -> Dict[str, torch.Tensor]:
277
- """
278
- Core batch prediction method for multi-target models.
279
-
280
- Args:
281
- features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
282
- classification_threshold (float): The threshold to convert probabilities
283
- into binary predictions for multi-label classification.
284
-
285
- Returns:
286
- A dictionary containing the raw output tensors from the model.
287
- """
288
- if features.ndim != 2:
289
- _LOGGER.error("Input for batch prediction must be a 2D array or tensor.")
290
- raise ValueError()
291
-
292
- input_tensor = self._preprocess_input(features)
293
-
294
- with torch.no_grad():
295
- output = self.model(input_tensor)
296
-
297
- if self.task == "multi_label_classification":
298
- probs = torch.sigmoid(output)
299
- # Get binary predictions based on the threshold
300
- labels = (probs >= classification_threshold).int()
301
- return {
302
- PyTorchInferenceKeys.LABELS: labels,
303
- PyTorchInferenceKeys.PROBABILITIES: probs
304
- }
305
- else: # multi_target_regression
306
- # The output is already in the correct [batch_size, n_targets] shape
307
- return {PyTorchInferenceKeys.PREDICTIONS: output}
308
-
309
- def predict(self,
310
- features: Union[np.ndarray, torch.Tensor],
311
- classification_threshold: float = 0.5
312
- ) -> Dict[str, torch.Tensor]:
313
- """
314
- Core single-sample prediction method for multi-target models.
315
-
316
- Args:
317
- features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
318
- classification_threshold (float): The threshold for multi-label tasks.
319
-
320
- Returns:
321
- A dictionary containing the raw output tensors for a single sample.
322
- """
323
- if features.ndim == 1:
324
- features = features.reshape(1, -1)
325
-
326
- if features.shape[0] != 1:
327
- _LOGGER.error("The 'predict()' method is for a single sample. 'Use predict_batch()' for multiple samples.")
420
+ elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
421
+ result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS].flatten().tolist()
422
+ return {key: value for key, value in zip(self.target_ids, result)}
423
+
424
+ elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
425
+ result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS].flatten().tolist()
426
+ return {key: value for key, value in zip(self.target_ids, result)}
427
+
428
+ else:
429
+ # should never happen
430
+ _LOGGER.error(f"Unrecognized task '{self.task}'.")
328
431
  raise ValueError()
329
-
330
- batch_results = self.predict_batch(features, classification_threshold)
331
-
332
- single_results = {key: value[0] for key, value in batch_results.items()}
333
- return single_results
334
-
335
- # --- NumPy Convenience Wrappers (on CPU) ---
336
-
337
- def predict_batch_numpy(self,
338
- features: Union[np.ndarray, torch.Tensor],
339
- classification_threshold: float = 0.5
340
- ) -> Dict[str, np.ndarray]:
341
- """
342
- Convenience wrapper for predict_batch that returns NumPy arrays.
343
- """
344
- tensor_results = self.predict_batch(features, classification_threshold)
345
- numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
346
- return numpy_results
347
-
348
- def predict_numpy(self,
349
- features: Union[np.ndarray, torch.Tensor],
350
- classification_threshold: float = 0.5
351
- ) -> Dict[str, np.ndarray]:
352
- """
353
- Convenience wrapper for predict that returns NumPy arrays for a single sample.
354
- Note: For multi-target models, the output is always an array.
355
- """
356
- tensor_results = self.predict(features, classification_threshold)
357
- numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
358
- return numpy_results
359
-
360
- def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
361
- """
362
- Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
363
432
 
364
- `target_ids` must be implemented.
433
+ def set_classification_threshold(self, threshold: float, force_overwrite: bool=False):
365
434
  """
366
- if self.target_ids is None:
367
- _LOGGER.error(f"'target_id' has not been implemented.")
368
- raise AttributeError()
435
+ Sets the classification threshold for the current inference instance.
369
436
 
370
- if self.task == "multi_target_regression":
371
- result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS].flatten().tolist()
372
- else:
373
- result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS].flatten().tolist()
437
+ If a threshold was previously loaded from a model configuration, this
438
+ method will log a warning and refuse to update the value. This
439
+ prevents accidentally overriding a setting from a loaded checkpoint.
374
440
 
375
- return {key: value for key, value in zip(self.target_ids, result)}
441
+ To bypass this safety check set `force_overwrite` to `True`.
376
442
 
443
+ Args:
444
+ threshold (float): The new classification threshold value to set.
445
+ force_overwrite (bool): If True, allows overwriting a threshold that was loaded from a configuration file.
446
+ """
447
+ if self._loaded_threshold:
448
+ warning_message = f"The current '{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}={self._classification_threshold}' was loaded and set from a model configuration file."
449
+ if not force_overwrite:
450
+ warning_message += " Use 'force_overwrite' if you are sure you want to modify it. This will not affect the value from the file."
451
+ _LOGGER.warning(warning_message)
452
+ return
453
+ else:
454
+ warning_message += f" Overwriting it to {threshold}."
455
+ _LOGGER.warning(warning_message)
456
+
457
+ self._classification_threshold = threshold
377
458
 
378
- def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
459
+
460
+ def multi_inference_regression(handlers: list[DragonInferenceHandler],
379
461
  feature_vector: Union[np.ndarray, torch.Tensor],
380
462
  output: Literal["numpy","torch"]="numpy") -> dict[str,Any]:
381
463
  """
382
464
  Performs regression inference using multiple models on a single feature vector.
383
465
 
384
- This function iterates through a list of PyTorchInferenceHandler objects,
466
+ This function iterates through a list of DragonInferenceHandler objects,
385
467
  each configured for a different regression target. It runs a prediction for
386
468
  each handler using the same input feature vector and returns the results
387
469
  in a dictionary.
@@ -391,7 +473,7 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
391
473
  - 2D input: Returns a dictionary mapping target ID to a list of values.
392
474
 
393
475
  Args:
394
- handlers (list[PyTorchInferenceHandler]): A list of initialized inference
476
+ handlers (list[DragonInferenceHandler]): A list of initialized inference
395
477
  handlers. Each handler must have a unique `target_id` and be configured with `task="regression"`.
396
478
  feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D) or a batch of samples (2D) to be fed into each regression model.
397
479
  output (Literal["numpy", "torch"], optional): The desired format for the output predictions.
@@ -421,11 +503,11 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
421
503
  results: dict[str,Any] = dict()
422
504
  for handler in handlers:
423
505
  # validation
424
- if handler.target_id is None:
425
- _LOGGER.error("All inference handlers must have a 'target_id' attribute.")
506
+ if handler.target_ids is None:
507
+ _LOGGER.error("All inference handlers must have a 'target_ids' attribute.")
426
508
  raise AttributeError()
427
- if handler.task != "regression":
428
- _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_id}' is for '{handler.task}', but only 'regression' tasks are supported.")
509
+ if handler.task != MLTaskKeys.REGRESSION:
510
+ _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_ids[0]}' is for '{handler.task}', only single target regression tasks are supported.")
429
511
  raise ValueError()
430
512
 
431
513
  # inference
@@ -434,33 +516,33 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
434
516
  numpy_result = handler.predict_batch_numpy(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
435
517
  if is_single_sample:
436
518
  # For a single sample, convert the 1-element array to a Python scalar
437
- results[handler.target_id] = numpy_result.item()
519
+ results[handler.target_ids[0]] = numpy_result.item()
438
520
  else:
439
521
  # For a batch, return the full NumPy array of predictions
440
- results[handler.target_id] = numpy_result
522
+ results[handler.target_ids[0]] = numpy_result
441
523
 
442
524
  else: # output == "torch"
443
525
  # This path returns PyTorch tensors on the model's device
444
526
  torch_result = handler.predict_batch(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
445
527
  if is_single_sample:
446
528
  # For a single sample, return the 0-dim tensor
447
- results[handler.target_id] = torch_result[0]
529
+ results[handler.target_ids[0]] = torch_result[0]
448
530
  else:
449
531
  # For a batch, return the full tensor of predictions
450
- results[handler.target_id] = torch_result
532
+ results[handler.target_ids[0]] = torch_result
451
533
 
452
534
  return results
453
535
 
454
536
 
455
537
  def multi_inference_classification(
456
- handlers: list[PyTorchInferenceHandler],
538
+ handlers: list[DragonInferenceHandler],
457
539
  feature_vector: Union[np.ndarray, torch.Tensor],
458
540
  output: Literal["numpy","torch"]="numpy"
459
541
  ) -> tuple[dict[str, Any], dict[str, Any]]:
460
542
  """
461
543
  Performs classification inference on a single sample or a batch.
462
544
 
463
- This function iterates through a list of PyTorchInferenceHandler objects,
545
+ This function iterates through a list of DragonInferenceHandler objects,
464
546
  each configured for a different classification target. It returns two
465
547
  dictionaries: one for the predicted labels and one for the probabilities.
466
548
 
@@ -469,7 +551,7 @@ def multi_inference_classification(
469
551
  - 2D input: The dictionaries map target ID to an array of labels and an array of probability arrays.
470
552
 
471
553
  Args:
472
- handlers (list[PyTorchInferenceHandler]): A list of initialized inference handlers. Each must have a unique `target_id` and be configured
554
+ handlers (list[DragonInferenceHandler]): A list of initialized inference handlers. Each must have a unique `target_id` and be configured
473
555
  with `task="classification"`.
474
556
  feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D)
475
557
  or a batch of samples (2D) for prediction.
@@ -502,11 +584,11 @@ def multi_inference_classification(
502
584
 
503
585
  for handler in handlers:
504
586
  # Validation
505
- if handler.target_id is None:
587
+ if handler.target_ids is None:
506
588
  _LOGGER.error("All inference handlers must have a 'target_id' attribute.")
507
589
  raise AttributeError()
508
- if handler.task != "classification":
509
- _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_id}' is for '{handler.task}', but this function only supports 'classification'.")
590
+ if handler.task not in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
591
+ _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_ids[0]}' is for '{handler.task}', but this function only supports binary and multiclass classification.")
510
592
  raise ValueError()
511
593
 
512
594
  # Inference
@@ -524,15 +606,15 @@ def multi_inference_classification(
524
606
  # For "numpy", convert the single label to a Python int scalar.
525
607
  # For "torch", get the 0-dim tensor label.
526
608
  if output == "numpy":
527
- labels_results[handler.target_id] = labels.item()
609
+ labels_results[handler.target_ids[0]] = labels.item()
528
610
  else: # torch
529
- labels_results[handler.target_id] = labels[0]
611
+ labels_results[handler.target_ids[0]] = labels[0]
530
612
 
531
613
  # The probabilities are an array/tensor of values
532
- probs_results[handler.target_id] = probabilities[0]
614
+ probs_results[handler.target_ids[0]] = probabilities[0]
533
615
  else:
534
- labels_results[handler.target_id] = labels
535
- probs_results[handler.target_id] = probabilities
616
+ labels_results[handler.target_ids[0]] = labels
617
+ probs_results[handler.target_ids[0]] = probabilities
536
618
 
537
619
  return labels_results, probs_results
538
620