dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
  2. dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
  3. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/ETL_cleaning.py +20 -20
  5. ml_tools/ETL_engineering.py +23 -25
  6. ml_tools/GUI_tools.py +20 -20
  7. ml_tools/MICE_imputation.py +207 -5
  8. ml_tools/ML_callbacks.py +43 -26
  9. ml_tools/ML_configuration.py +788 -0
  10. ml_tools/ML_datasetmaster.py +303 -448
  11. ml_tools/ML_evaluation.py +351 -93
  12. ml_tools/ML_evaluation_multi.py +139 -42
  13. ml_tools/ML_inference.py +290 -209
  14. ml_tools/ML_models.py +33 -106
  15. ml_tools/ML_models_advanced.py +323 -0
  16. ml_tools/ML_optimization.py +12 -12
  17. ml_tools/ML_scaler.py +11 -11
  18. ml_tools/ML_sequence_datasetmaster.py +341 -0
  19. ml_tools/ML_sequence_evaluation.py +219 -0
  20. ml_tools/ML_sequence_inference.py +391 -0
  21. ml_tools/ML_sequence_models.py +139 -0
  22. ml_tools/ML_trainer.py +1604 -179
  23. ml_tools/ML_utilities.py +351 -4
  24. ml_tools/ML_vision_datasetmaster.py +1540 -0
  25. ml_tools/ML_vision_evaluation.py +284 -0
  26. ml_tools/ML_vision_inference.py +405 -0
  27. ml_tools/ML_vision_models.py +641 -0
  28. ml_tools/ML_vision_transformers.py +284 -0
  29. ml_tools/PSO_optimization.py +6 -6
  30. ml_tools/SQL.py +4 -4
  31. ml_tools/_keys.py +171 -0
  32. ml_tools/_schema.py +1 -1
  33. ml_tools/custom_logger.py +37 -14
  34. ml_tools/data_exploration.py +502 -93
  35. ml_tools/ensemble_evaluation.py +54 -11
  36. ml_tools/ensemble_inference.py +7 -33
  37. ml_tools/ensemble_learning.py +1 -1
  38. ml_tools/math_utilities.py +1 -1
  39. ml_tools/optimization_tools.py +2 -2
  40. ml_tools/path_manager.py +5 -5
  41. ml_tools/serde.py +2 -2
  42. ml_tools/utilities.py +192 -4
  43. dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
  44. ml_tools/RNN_forecast.py +0 -56
  45. ml_tools/keys.py +0 -87
  46. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
  47. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
  48. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_inference.py CHANGED
@@ -5,16 +5,15 @@ from pathlib import Path
5
5
  from typing import Union, Literal, Dict, Any, Optional
6
6
  from abc import ABC, abstractmethod
7
7
 
8
- from .ML_scaler import PytorchScaler
8
+ from .ML_scaler import DragonScaler
9
9
  from ._script_info import _script_info
10
10
  from ._logger import _LOGGER
11
11
  from .path_manager import make_fullpath
12
- from .keys import PyTorchInferenceKeys, PyTorchCheckpointKeys
12
+ from ._keys import PyTorchInferenceKeys, PyTorchCheckpointKeys, MLTaskKeys
13
13
 
14
14
 
15
15
  __all__ = [
16
- "PyTorchInferenceHandler",
17
- "PyTorchInferenceHandlerMulti",
16
+ "DragonInferenceHandler",
18
17
  "multi_inference_regression",
19
18
  "multi_inference_classification"
20
19
  ]
@@ -31,7 +30,7 @@ class _BaseInferenceHandler(ABC):
31
30
  model: nn.Module,
32
31
  state_dict: Union[str, Path],
33
32
  device: str = 'cpu',
34
- scaler: Optional[Union[PytorchScaler, str, Path]] = None):
33
+ scaler: Optional[Union[DragonScaler, str, Path]] = None):
35
34
  """
36
35
  Initializes the handler.
37
36
 
@@ -39,15 +38,21 @@ class _BaseInferenceHandler(ABC):
39
38
  model (nn.Module): An instantiated PyTorch model.
40
39
  state_dict (str | Path): Path to the saved .pth model state_dict file.
41
40
  device (str): The device to run inference on ('cpu', 'cuda', 'mps').
42
- scaler (PytorchScaler | str | Path | None): An optional scaler or path to a saved scaler state.
41
+ scaler (DragonScaler | str | Path | None): An optional scaler or path to a saved scaler state.
43
42
  """
44
43
  self.model = model
45
44
  self.device = self._validate_device(device)
45
+ self._classification_threshold = 0.5
46
+ self._loaded_threshold: bool = False
47
+ self._loaded_class_map: bool = False
48
+ self._class_map: Optional[dict[str,int]] = None
49
+ self._idx_to_class: Optional[Dict[int, str]] = None
50
+ self._loaded_data_dict: Dict[str, Any] = {} #Store whatever is in the finalized file
46
51
 
47
52
  # Load the scaler if a path is provided
48
53
  if scaler is not None:
49
54
  if isinstance(scaler, (str, Path)):
50
- self.scaler = PytorchScaler.load(scaler)
55
+ self.scaler = DragonScaler.load(scaler)
51
56
  else:
52
57
  self.scaler = scaler
53
58
  else:
@@ -58,13 +63,45 @@ class _BaseInferenceHandler(ABC):
58
63
  try:
59
64
  # Load whatever is in the file
60
65
  loaded_data = torch.load(model_p, map_location=self.device)
61
-
62
- # Check if it's the new checkpoint dictionary or an old weights-only file
63
- if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
64
- # It's a new training checkpoint, extract the weights
65
- self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
66
+
67
+ # Check if it's a dictionary or an old weights-only file
68
+ if isinstance(loaded_data, dict):
69
+ self._loaded_data_dict = loaded_data # Store the dict
70
+
71
+ if PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
72
+ # It's a new training checkpoint, extract the weights
73
+ self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
74
+
75
+ # attempt to get a custom classification threshold
76
+ if PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD in loaded_data:
77
+ try:
78
+ self._classification_threshold = float(loaded_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD])
79
+ except Exception as e_int:
80
+ _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}' but an error occurred when retrieving it:\n{e_int}")
81
+ self._classification_threshold = 0.5
82
+ else:
83
+ _LOGGER.info(f"'{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}' found and set to {self._classification_threshold}")
84
+ self._loaded_threshold = True
85
+
86
+ # attempt to get a class map
87
+ if PyTorchCheckpointKeys.CLASS_MAP in loaded_data:
88
+ try:
89
+ self._class_map = loaded_data[PyTorchCheckpointKeys.CLASS_MAP]
90
+ except Exception as e_int:
91
+ _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASS_MAP}' but an error occurred when retrieving it:\n{e_int}")
92
+ else:
93
+ if isinstance(self._class_map, dict):
94
+ self._loaded_class_map = True
95
+ self.set_class_map(self._class_map)
96
+ else:
97
+ _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASS_MAP}' but it is not a dict: '{type(self._class_map)}'.")
98
+ self._class_map = None
99
+ else:
100
+ # It's a state_dict, load it directly
101
+ self.model.load_state_dict(loaded_data)
102
+
66
103
  else:
67
- # It's an old-style file (or just a state_dict), load it directly
104
+ # It's an old state_dict (just weights), load it directly
68
105
  self.model.load_state_dict(loaded_data)
69
106
 
70
107
  _LOGGER.info(f"Model state loaded from '{model_p.name}'.")
@@ -82,25 +119,40 @@ class _BaseInferenceHandler(ABC):
82
119
  _LOGGER.warning("CUDA not available, switching to CPU.")
83
120
  device_lower = "cpu"
84
121
  elif device_lower == "mps" and not torch.backends.mps.is_available():
85
- # Your M-series Mac will appreciate this check!
86
122
  _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
87
123
  device_lower = "cpu"
88
124
  return torch.device(device_lower)
89
-
90
- def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
91
- """
92
- Converts input to a torch.Tensor, applies scaling if a scaler is
93
- present, and moves it to the correct device.
125
+
126
+ def set_class_map(self, class_map: Dict[str, int], force_overwrite: bool = False):
94
127
  """
95
- if isinstance(features, np.ndarray):
96
- features_tensor = torch.from_numpy(features).float()
97
- else:
98
- features_tensor = features.float()
99
-
100
- if self.scaler:
101
- features_tensor = self.scaler.transform(features_tensor)
128
+ Sets the class name mapping to translate predicted integer labels back into string names.
129
+
130
+ If a class_map was previously loaded from a model configuration, this
131
+ method will log a warning and refuse to update the value. This
132
+ prevents accidentally overriding a setting from a loaded checkpoint.
133
+
134
+ To bypass this safety check set `force_overwrite` to `True`.
102
135
 
103
- return features_tensor.to(self.device)
136
+ Args:
137
+ class_map (Dict[str, int]): The class_to_idx dictionary (e.g., {'cat': 0, 'dog': 1}).
138
+ force_overwrite (bool): If True, allows overwriting a map that was loaded from a configuration file.
139
+ """
140
+ if self._loaded_class_map:
141
+ warning_message = f"A '{PyTorchCheckpointKeys.CLASS_MAP}' was loaded from the model configuration file."
142
+ if not force_overwrite:
143
+ warning_message += " Use 'force_overwrite=True' if you are sure you want to modify it. This will not affect the value from the file."
144
+ _LOGGER.warning(warning_message)
145
+ return
146
+ else:
147
+ warning_message += " Overwriting it for this inference instance."
148
+ _LOGGER.warning(warning_message)
149
+
150
+ # Store the map and invert it for fast lookup
151
+ self._class_map = class_map
152
+ self._idx_to_class = {v: k for k, v in class_map.items()}
153
+ # Mark as 'loaded' by the user, even if it wasn't from a file, to prevent accidental changes later if _loaded_class_map was false.
154
+ self._loaded_class_map = True # Protect this newly set map
155
+ _LOGGER.info("Class map set for label-to-name translation.")
104
156
 
105
157
  @abstractmethod
106
158
  def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
@@ -113,46 +165,107 @@ class _BaseInferenceHandler(ABC):
113
165
  pass
114
166
 
115
167
 
116
- class PyTorchInferenceHandler(_BaseInferenceHandler):
168
+ class DragonInferenceHandler(_BaseInferenceHandler):
117
169
  """
118
- Handles loading a PyTorch model's state dictionary and performing inference
119
- for single-target regression or classification tasks.
170
+ Handles loading a PyTorch model's state dictionary and performing inference.
120
171
  """
121
172
  def __init__(self,
122
173
  model: nn.Module,
123
174
  state_dict: Union[str, Path],
124
- task: Literal["classification", "regression"],
175
+ task: Literal["regression", "binary classification", "multiclass classification", "multitarget regression", "multilabel binary classification"],
125
176
  device: str = 'cpu',
126
- target_id: Optional[str] = None,
127
- scaler: Optional[Union[PytorchScaler, str, Path]] = None):
177
+ scaler: Optional[Union[DragonScaler, str, Path]] = None):
128
178
  """
129
179
  Initializes the handler for single-target tasks.
130
180
 
131
181
  Args:
132
182
  model (nn.Module): An instantiated PyTorch model architecture.
133
183
  state_dict (str | Path): Path to the saved .pth model state_dict file.
134
- task (str): The type of task, 'regression' or 'classification'.
184
+ task (str): The type of task.
135
185
  device (str): The device to run inference on ('cpu', 'cuda', 'mps').
136
- target_id (str | None): An optional identifier for the target.
137
- scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
186
+ scaler (DragonScaler | str | Path | None): A DragonScaler instance or the file path to a saved DragonScaler state.
187
+
188
+ Note: class_map (Dict[str, int]) will be loaded from the model file; to set or override it, use `.set_class_map()`.
138
189
  """
139
190
  # Call the parent constructor to handle model loading, device, and scaler
140
191
  super().__init__(model, state_dict, device, scaler)
141
192
 
142
- if task not in ["classification", "regression"]:
143
- raise ValueError("`task` must be 'classification' or 'regression'.")
193
+ if task not in [MLTaskKeys.REGRESSION,
194
+ MLTaskKeys.BINARY_CLASSIFICATION,
195
+ MLTaskKeys.MULTICLASS_CLASSIFICATION,
196
+ MLTaskKeys.MULTITARGET_REGRESSION,
197
+ MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
198
+ _LOGGER.error(f"'task' not recognized: '{task}'.")
199
+ raise ValueError()
144
200
  self.task = task
145
- self.target_id = target_id
201
+ self.target_ids: Optional[list[str]] = None
202
+ self._target_ids_set: bool = False
203
+
204
+ # attempt to get target name or target names
205
+ if PyTorchCheckpointKeys.TARGET_NAME in self._loaded_data_dict:
206
+ try:
207
+ target_from_dict = [self._loaded_data_dict[PyTorchCheckpointKeys.TARGET_NAME]]
208
+ except Exception as e_int:
209
+ _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.TARGET_NAME}' but an error occurred when retrieving it:\n{e_int}")
210
+ self.target_ids = None
211
+ else:
212
+ self.set_target_ids([target_from_dict]) # type: ignore
213
+ elif PyTorchCheckpointKeys.TARGET_NAMES in self._loaded_data_dict:
214
+ try:
215
+ targets_from_dict = self._loaded_data_dict[PyTorchCheckpointKeys.TARGET_NAMES]
216
+ except Exception as e_int:
217
+ _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.TARGET_NAMES}' but an error occurred when retrieving it:\n{e_int}")
218
+ self.target_ids = None
219
+ else:
220
+ self.set_target_ids(targets_from_dict) # type: ignore
221
+
222
+ def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
223
+ """
224
+ Converts input to a torch.Tensor, applies scaling if a scaler is
225
+ present, and moves it to the correct device.
226
+ """
227
+ if isinstance(features, np.ndarray):
228
+ features_tensor = torch.from_numpy(features).float()
229
+ else:
230
+ features_tensor = features.float()
231
+
232
+ if self.scaler:
233
+ features_tensor = self.scaler.transform(features_tensor)
234
+
235
+ return features_tensor.to(self.device)
236
+
237
+ def set_target_ids(self, target_names: list[str], force_overwrite: bool=False):
238
+ """
239
+ Assigns the provided list of strings as the target variable names.
240
+ If target IDs have already been set, this method will log a warning.
241
+
242
+ Args:
243
+ target_names (list[str]): A list of target names.
244
+ force_overwrite (bool): If True, allows the method to overwrite previously set target IDs.
245
+ """
246
+ if self._target_ids_set:
247
+ warning_message = "Target IDs was previously set."
248
+ if not force_overwrite:
249
+ warning_message += " Use `force_overwrite=True` to overwrite."
250
+ _LOGGER.warning(warning_message)
251
+ return
252
+ else:
253
+ warning_message += " Overwriting..."
254
+ _LOGGER.warning(warning_message)
255
+
256
+ self.target_ids = target_names
257
+ self._target_ids_set = True
258
+ _LOGGER.info("Target IDs set.")
146
259
 
147
260
  def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
148
261
  """
149
- Core batch prediction method for single-target models.
262
+ Core batch prediction method.
150
263
 
151
264
  Args:
152
265
  features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
153
266
 
154
267
  Returns:
155
- A dictionary containing the raw output tensors from the model.
268
+ Dict: A dictionary containing the raw output tensors from the model.
156
269
  """
157
270
  if features.ndim != 2:
158
271
  _LOGGER.error("Input for batch prediction must be a 2D array or tensor.")
@@ -163,16 +276,48 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
163
276
  with torch.no_grad():
164
277
  output = self.model(input_tensor)
165
278
 
166
- if self.task == "classification":
279
+ if self.task == MLTaskKeys.MULTICLASS_CLASSIFICATION:
167
280
  probs = torch.softmax(output, dim=1)
168
281
  labels = torch.argmax(probs, dim=1)
169
282
  return {
170
283
  PyTorchInferenceKeys.LABELS: labels,
171
284
  PyTorchInferenceKeys.PROBABILITIES: probs
172
285
  }
173
- else: # regression
286
+
287
+ elif self.task == MLTaskKeys.BINARY_CLASSIFICATION:
288
+ # Assumes model output is [N, 1] (a single logit)
289
+ # Squeeze output from [N, 1] to [N] if necessary
290
+ if output.ndim == 2 and output.shape[1] == 1:
291
+ output = output.squeeze(1)
292
+
293
+ probs = torch.sigmoid(output) # Probability of positive class
294
+ labels = (probs >= self._classification_threshold).int()
295
+ return {
296
+ PyTorchInferenceKeys.LABELS: labels,
297
+ PyTorchInferenceKeys.PROBABILITIES: probs
298
+ }
299
+
300
+ elif self.task == MLTaskKeys.REGRESSION:
174
301
  # For single-target regression, ensure output is flattened
175
302
  return {PyTorchInferenceKeys.PREDICTIONS: output.flatten()}
303
+
304
+ elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
305
+ probs = torch.sigmoid(output)
306
+ # Get binary predictions based on the threshold
307
+ labels = (probs >= self._classification_threshold).int()
308
+ return {
309
+ PyTorchInferenceKeys.LABELS: labels,
310
+ PyTorchInferenceKeys.PROBABILITIES: probs
311
+ }
312
+
313
+ elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
314
+ # The output is already in the correct [batch_size, n_targets] shape
315
+ return {PyTorchInferenceKeys.PREDICTIONS: output}
316
+
317
+ else:
318
+ # should never happen
319
+ _LOGGER.error(f"Unrecognized task '{self.task}'.")
320
+ raise ValueError()
176
321
 
177
322
  def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
178
323
  """
@@ -182,7 +327,7 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
182
327
  features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
183
328
 
184
329
  Returns:
185
- A dictionary containing the raw output tensors for a single sample.
330
+ Dict: A dictionary containing the raw output tensors for a single sample.
186
331
  """
187
332
  if features.ndim == 1:
188
333
  features = features.reshape(1, -1) # Reshape to a batch of one
@@ -196,193 +341,129 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
196
341
  # Extract the first (and only) result from the batch output
197
342
  single_results = {key: value[0] for key, value in batch_results.items()}
198
343
  return single_results
199
-
344
+
200
345
  # --- NumPy Convenience Wrappers (on CPU) ---
201
346
 
202
347
  def predict_batch_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, np.ndarray]:
203
348
  """
204
- Convenience wrapper for predict_batch that returns NumPy arrays.
349
+ Convenience wrapper for predict_batch that returns NumPy arrays
350
+ and adds string labels for classification tasks if a class_map is set.
205
351
  """
206
352
  tensor_results = self.predict_batch(features)
207
353
  numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
354
+
355
+ # Add string names for classification if map exists
356
+ is_classification = self.task in [
357
+ MLTaskKeys.BINARY_CLASSIFICATION,
358
+ MLTaskKeys.MULTICLASS_CLASSIFICATION
359
+ ]
360
+
361
+ if is_classification and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
362
+ int_labels = numpy_results[PyTorchInferenceKeys.LABELS] # This is a (B,) array
363
+ numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [ # type: ignore
364
+ self._idx_to_class.get(label_id, "Unknown")
365
+ for label_id in int_labels
366
+ ]
367
+
208
368
  return numpy_results
209
369
 
210
370
  def predict_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
211
371
  """
212
- Convenience wrapper for predict that returns NumPy arrays or scalars.
372
+ Convenience wrapper for predict that returns NumPy arrays or scalars
373
+ and adds string labels for classification tasks if a class_map is set.
213
374
  """
214
375
  tensor_results = self.predict(features)
215
376
 
216
- if self.task == "regression":
377
+ if self.task == MLTaskKeys.REGRESSION:
217
378
  # .item() implicitly moves to CPU and returns a Python scalar
218
379
  return {PyTorchInferenceKeys.PREDICTIONS: tensor_results[PyTorchInferenceKeys.PREDICTIONS].item()}
219
- else: # classification
380
+
381
+ elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
382
+ int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
383
+ label_name = "Unknown"
384
+ if self._idx_to_class:
385
+ label_name = self._idx_to_class.get(int_label, "Unknown") # type: ignore
386
+
220
387
  return {
221
- PyTorchInferenceKeys.LABELS: tensor_results[PyTorchInferenceKeys.LABELS].item(),
388
+ PyTorchInferenceKeys.LABELS: int_label,
389
+ PyTorchInferenceKeys.LABEL_NAMES: label_name,
222
390
  PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
223
391
  }
392
+
393
+ elif self.task in [MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION, MLTaskKeys.MULTITARGET_REGRESSION]:
394
+ # For multi-target models, the output is always an array.
395
+ numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
396
+ return numpy_results
397
+ else:
398
+ # should never happen
399
+ _LOGGER.error(f"Unrecognized task '{self.task}'.")
400
+ raise ValueError()
224
401
 
225
402
  def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
226
403
  """
227
404
  Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
228
405
 
229
- `target_id` must be implemented.
406
+ `target_ids` must be implemented.
230
407
  """
231
- if self.target_id is None:
232
- _LOGGER.error(f"'target_id' has not been implemented.")
408
+ if self.target_ids is None:
409
+ _LOGGER.error(f"'target_ids' has not been implemented.")
233
410
  raise AttributeError()
234
411
 
235
- if self.task == "regression":
412
+ if self.task == MLTaskKeys.REGRESSION:
236
413
  result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS]
237
- else:
414
+ return {self.target_ids[0]: result}
415
+
416
+ elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
238
417
  result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS]
418
+ return {self.target_ids[0]: result}
239
419
 
240
- return {self.target_id: result}
241
-
242
-
243
- class PyTorchInferenceHandlerMulti(_BaseInferenceHandler):
244
- """
245
- Handles loading a PyTorch model's state dictionary and performing inference
246
- for multi-target regression or multi-label classification tasks.
247
- """
248
- def __init__(self,
249
- model: nn.Module,
250
- state_dict: Union[str, Path],
251
- task: Literal["multi_target_regression", "multi_label_classification"],
252
- device: str = 'cpu',
253
- target_ids: Optional[list[str]] = None,
254
- scaler: Optional[Union[PytorchScaler, str, Path]] = None):
255
- """
256
- Initializes the handler for multi-target tasks.
257
-
258
- Args:
259
- model (nn.Module): An instantiated PyTorch model.
260
- state_dict (str | Path): Path to the saved .pth model state_dict file.
261
- task (str): The type of task, 'multi_target_regression' or 'multi_label_classification'.
262
- device (str): The device to run inference on ('cpu', 'cuda', 'mps').
263
- target_ids (list[str] | None): An optional identifier for the targets.
264
- scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
265
- """
266
- super().__init__(model, state_dict, device, scaler)
267
-
268
- if task not in ["multi_target_regression", "multi_label_classification"]:
269
- _LOGGER.error("`task` must be 'multi_target_regression' or 'multi_label_classification'.")
270
- raise ValueError()
271
- self.task = task
272
- self.target_ids = target_ids
273
-
274
- def predict_batch(self,
275
- features: Union[np.ndarray, torch.Tensor],
276
- classification_threshold: float = 0.5
277
- ) -> Dict[str, torch.Tensor]:
278
- """
279
- Core batch prediction method for multi-target models.
280
-
281
- Args:
282
- features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
283
- classification_threshold (float): The threshold to convert probabilities
284
- into binary predictions for multi-label classification.
285
-
286
- Returns:
287
- A dictionary containing the raw output tensors from the model.
288
- """
289
- if features.ndim != 2:
290
- _LOGGER.error("Input for batch prediction must be a 2D array or tensor.")
291
- raise ValueError()
292
-
293
- input_tensor = self._preprocess_input(features)
294
-
295
- with torch.no_grad():
296
- output = self.model(input_tensor)
297
-
298
- if self.task == "multi_label_classification":
299
- probs = torch.sigmoid(output)
300
- # Get binary predictions based on the threshold
301
- labels = (probs >= classification_threshold).int()
302
- return {
303
- PyTorchInferenceKeys.LABELS: labels,
304
- PyTorchInferenceKeys.PROBABILITIES: probs
305
- }
306
- else: # multi_target_regression
307
- # The output is already in the correct [batch_size, n_targets] shape
308
- return {PyTorchInferenceKeys.PREDICTIONS: output}
309
-
310
- def predict(self,
311
- features: Union[np.ndarray, torch.Tensor],
312
- classification_threshold: float = 0.5
313
- ) -> Dict[str, torch.Tensor]:
314
- """
315
- Core single-sample prediction method for multi-target models.
316
-
317
- Args:
318
- features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
319
- classification_threshold (float): The threshold for multi-label tasks.
320
-
321
- Returns:
322
- A dictionary containing the raw output tensors for a single sample.
323
- """
324
- if features.ndim == 1:
325
- features = features.reshape(1, -1)
326
-
327
- if features.shape[0] != 1:
328
- _LOGGER.error("The 'predict()' method is for a single sample. 'Use predict_batch()' for multiple samples.")
420
+ elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
421
+ result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS].flatten().tolist()
422
+ return {key: value for key, value in zip(self.target_ids, result)}
423
+
424
+ elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
425
+ result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS].flatten().tolist()
426
+ return {key: value for key, value in zip(self.target_ids, result)}
427
+
428
+ else:
429
+ # should never happen
430
+ _LOGGER.error(f"Unrecognized task '{self.task}'.")
329
431
  raise ValueError()
330
-
331
- batch_results = self.predict_batch(features, classification_threshold)
332
-
333
- single_results = {key: value[0] for key, value in batch_results.items()}
334
- return single_results
335
-
336
- # --- NumPy Convenience Wrappers (on CPU) ---
337
-
338
- def predict_batch_numpy(self,
339
- features: Union[np.ndarray, torch.Tensor],
340
- classification_threshold: float = 0.5
341
- ) -> Dict[str, np.ndarray]:
342
- """
343
- Convenience wrapper for predict_batch that returns NumPy arrays.
344
- """
345
- tensor_results = self.predict_batch(features, classification_threshold)
346
- numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
347
- return numpy_results
348
-
349
- def predict_numpy(self,
350
- features: Union[np.ndarray, torch.Tensor],
351
- classification_threshold: float = 0.5
352
- ) -> Dict[str, np.ndarray]:
353
- """
354
- Convenience wrapper for predict that returns NumPy arrays for a single sample.
355
- Note: For multi-target models, the output is always an array.
356
- """
357
- tensor_results = self.predict(features, classification_threshold)
358
- numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
359
- return numpy_results
360
-
361
- def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
362
- """
363
- Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
364
432
 
365
- `target_ids` must be implemented.
433
+ def set_classification_threshold(self, threshold: float, force_overwrite: bool=False):
366
434
  """
367
- if self.target_ids is None:
368
- _LOGGER.error(f"'target_id' has not been implemented.")
369
- raise AttributeError()
435
+ Sets the classification threshold for the current inference instance.
370
436
 
371
- if self.task == "multi_target_regression":
372
- result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS].flatten().tolist()
373
- else:
374
- result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS].flatten().tolist()
437
+ If a threshold was previously loaded from a model configuration, this
438
+ method will log a warning and refuse to update the value. This
439
+ prevents accidentally overriding a setting from a loaded checkpoint.
375
440
 
376
- return {key: value for key, value in zip(self.target_ids, result)}
441
+ To bypass this safety check set `force_overwrite` to `True`.
377
442
 
443
+ Args:
444
+ threshold (float): The new classification threshold value to set.
445
+ force_overwrite (bool): If True, allows overwriting a threshold that was loaded from a configuration file.
446
+ """
447
+ if self._loaded_threshold:
448
+ warning_message = f"The current '{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}={self._classification_threshold}' was loaded and set from a model configuration file."
449
+ if not force_overwrite:
450
+ warning_message += " Use 'force_overwrite' if you are sure you want to modify it. This will not affect the value from the file."
451
+ _LOGGER.warning(warning_message)
452
+ return
453
+ else:
454
+ warning_message += f" Overwriting it to {threshold}."
455
+ _LOGGER.warning(warning_message)
456
+
457
+ self._classification_threshold = threshold
378
458
 
379
- def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
459
+
460
+ def multi_inference_regression(handlers: list[DragonInferenceHandler],
380
461
  feature_vector: Union[np.ndarray, torch.Tensor],
381
462
  output: Literal["numpy","torch"]="numpy") -> dict[str,Any]:
382
463
  """
383
464
  Performs regression inference using multiple models on a single feature vector.
384
465
 
385
- This function iterates through a list of PyTorchInferenceHandler objects,
466
+ This function iterates through a list of DragonInferenceHandler objects,
386
467
  each configured for a different regression target. It runs a prediction for
387
468
  each handler using the same input feature vector and returns the results
388
469
  in a dictionary.
@@ -392,7 +473,7 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
392
473
  - 2D input: Returns a dictionary mapping target ID to a list of values.
393
474
 
394
475
  Args:
395
- handlers (list[PyTorchInferenceHandler]): A list of initialized inference
476
+ handlers (list[DragonInferenceHandler]): A list of initialized inference
396
477
  handlers. Each handler must have a unique `target_id` and be configured with `task="regression"`.
397
478
  feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D) or a batch of samples (2D) to be fed into each regression model.
398
479
  output (Literal["numpy", "torch"], optional): The desired format for the output predictions.
@@ -422,11 +503,11 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
422
503
  results: dict[str,Any] = dict()
423
504
  for handler in handlers:
424
505
  # validation
425
- if handler.target_id is None:
426
- _LOGGER.error("All inference handlers must have a 'target_id' attribute.")
506
+ if handler.target_ids is None:
507
+ _LOGGER.error("All inference handlers must have a 'target_ids' attribute.")
427
508
  raise AttributeError()
428
- if handler.task != "regression":
429
- _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_id}' is for '{handler.task}', but only 'regression' tasks are supported.")
509
+ if handler.task != MLTaskKeys.REGRESSION:
510
+ _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_ids[0]}' is for '{handler.task}', only single target regression tasks are supported.")
430
511
  raise ValueError()
431
512
 
432
513
  # inference
@@ -435,33 +516,33 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
435
516
  numpy_result = handler.predict_batch_numpy(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
436
517
  if is_single_sample:
437
518
  # For a single sample, convert the 1-element array to a Python scalar
438
- results[handler.target_id] = numpy_result.item()
519
+ results[handler.target_ids[0]] = numpy_result.item()
439
520
  else:
440
521
  # For a batch, return the full NumPy array of predictions
441
- results[handler.target_id] = numpy_result
522
+ results[handler.target_ids[0]] = numpy_result
442
523
 
443
524
  else: # output == "torch"
444
525
  # This path returns PyTorch tensors on the model's device
445
526
  torch_result = handler.predict_batch(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
446
527
  if is_single_sample:
447
528
  # For a single sample, return the 0-dim tensor
448
- results[handler.target_id] = torch_result[0]
529
+ results[handler.target_ids[0]] = torch_result[0]
449
530
  else:
450
531
  # For a batch, return the full tensor of predictions
451
- results[handler.target_id] = torch_result
532
+ results[handler.target_ids[0]] = torch_result
452
533
 
453
534
  return results
454
535
 
455
536
 
456
537
  def multi_inference_classification(
457
- handlers: list[PyTorchInferenceHandler],
538
+ handlers: list[DragonInferenceHandler],
458
539
  feature_vector: Union[np.ndarray, torch.Tensor],
459
540
  output: Literal["numpy","torch"]="numpy"
460
541
  ) -> tuple[dict[str, Any], dict[str, Any]]:
461
542
  """
462
543
  Performs classification inference on a single sample or a batch.
463
544
 
464
- This function iterates through a list of PyTorchInferenceHandler objects,
545
+ This function iterates through a list of DragonInferenceHandler objects,
465
546
  each configured for a different classification target. It returns two
466
547
  dictionaries: one for the predicted labels and one for the probabilities.
467
548
 
@@ -470,7 +551,7 @@ def multi_inference_classification(
470
551
  - 2D input: The dictionaries map target ID to an array of labels and an array of probability arrays.
471
552
 
472
553
  Args:
473
- handlers (list[PyTorchInferenceHandler]): A list of initialized inference handlers. Each must have a unique `target_id` and be configured
554
+ handlers (list[DragonInferenceHandler]): A list of initialized inference handlers. Each must have a unique `target_ids[0]` (used as the key in the returned dictionaries) and be configured
474
555
  with `task="classification"`.
475
556
  feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D)
476
557
  or a batch of samples (2D) for prediction.
@@ -503,11 +584,11 @@ def multi_inference_classification(
503
584
 
504
585
  for handler in handlers:
505
586
  # Validation
506
- if handler.target_id is None:
587
+ if handler.target_ids is None:
507
588
  _LOGGER.error("All inference handlers must have a 'target_id' attribute.")
508
589
  raise AttributeError()
509
- if handler.task != "classification":
510
- _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_id}' is for '{handler.task}', but this function only supports 'classification'.")
590
+ if handler.task not in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
591
+ _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_ids[0]}' is for '{handler.task}', but this function only supports binary and multiclass classification.")
511
592
  raise ValueError()
512
593
 
513
594
  # Inference
@@ -525,15 +606,15 @@ def multi_inference_classification(
525
606
  # For "numpy", convert the single label to a Python int scalar.
526
607
  # For "torch", get the 0-dim tensor label.
527
608
  if output == "numpy":
528
- labels_results[handler.target_id] = labels.item()
609
+ labels_results[handler.target_ids[0]] = labels.item()
529
610
  else: # torch
530
- labels_results[handler.target_id] = labels[0]
611
+ labels_results[handler.target_ids[0]] = labels[0]
531
612
 
532
613
  # The probabilities are an array/tensor of values
533
- probs_results[handler.target_id] = probabilities[0]
614
+ probs_results[handler.target_ids[0]] = probabilities[0]
534
615
  else:
535
- labels_results[handler.target_id] = labels
536
- probs_results[handler.target_id] = probabilities
616
+ labels_results[handler.target_ids[0]] = labels
617
+ probs_results[handler.target_ids[0]] = probabilities
537
618
 
538
619
  return labels_results, probs_results
539
620