dragon-ml-toolbox 14.3.1__py3-none-any.whl → 16.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dragon-ml-toolbox might be problematic; see the registry's advisory page for this release for more details.

Files changed (44)
  1. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +10 -5
  2. dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
  3. ml_tools/ETL_cleaning.py +20 -20
  4. ml_tools/ETL_engineering.py +23 -25
  5. ml_tools/GUI_tools.py +20 -20
  6. ml_tools/MICE_imputation.py +3 -3
  7. ml_tools/ML_callbacks.py +43 -26
  8. ml_tools/ML_configuration.py +309 -0
  9. ml_tools/ML_datasetmaster.py +220 -260
  10. ml_tools/ML_evaluation.py +317 -81
  11. ml_tools/ML_evaluation_multi.py +127 -36
  12. ml_tools/ML_inference.py +249 -207
  13. ml_tools/ML_models.py +13 -102
  14. ml_tools/ML_models_advanced.py +1 -1
  15. ml_tools/ML_optimization.py +12 -12
  16. ml_tools/ML_scaler.py +11 -11
  17. ml_tools/ML_sequence_datasetmaster.py +341 -0
  18. ml_tools/ML_sequence_evaluation.py +215 -0
  19. ml_tools/ML_sequence_inference.py +391 -0
  20. ml_tools/ML_sequence_models.py +139 -0
  21. ml_tools/ML_trainer.py +1247 -338
  22. ml_tools/ML_utilities.py +51 -2
  23. ml_tools/ML_vision_datasetmaster.py +262 -118
  24. ml_tools/ML_vision_evaluation.py +26 -6
  25. ml_tools/ML_vision_inference.py +117 -140
  26. ml_tools/ML_vision_models.py +15 -1
  27. ml_tools/ML_vision_transformers.py +233 -7
  28. ml_tools/PSO_optimization.py +6 -6
  29. ml_tools/SQL.py +4 -4
  30. ml_tools/{keys.py → _keys.py} +45 -1
  31. ml_tools/_schema.py +1 -1
  32. ml_tools/ensemble_evaluation.py +54 -11
  33. ml_tools/ensemble_inference.py +7 -33
  34. ml_tools/ensemble_learning.py +1 -1
  35. ml_tools/optimization_tools.py +2 -2
  36. ml_tools/path_manager.py +5 -5
  37. ml_tools/utilities.py +1 -2
  38. dragon_ml_toolbox-14.3.1.dist-info/RECORD +0 -48
  39. ml_tools/RNN_forecast.py +0 -56
  40. ml_tools/_ML_vision_recipe.py +0 -88
  41. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
  42. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
  43. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  44. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_inference.py CHANGED
@@ -5,16 +5,15 @@ from pathlib import Path
5
5
  from typing import Union, Literal, Dict, Any, Optional
6
6
  from abc import ABC, abstractmethod
7
7
 
8
- from .ML_scaler import PytorchScaler
8
+ from .ML_scaler import DragonScaler
9
9
  from ._script_info import _script_info
10
10
  from ._logger import _LOGGER
11
11
  from .path_manager import make_fullpath
12
- from .keys import PyTorchInferenceKeys, PyTorchCheckpointKeys
12
+ from ._keys import PyTorchInferenceKeys, PyTorchCheckpointKeys, MLTaskKeys
13
13
 
14
14
 
15
15
  __all__ = [
16
- "PyTorchInferenceHandler",
17
- "PyTorchInferenceHandlerMulti",
16
+ "DragonInferenceHandler",
18
17
  "multi_inference_regression",
19
18
  "multi_inference_classification"
20
19
  ]
@@ -31,7 +30,7 @@ class _BaseInferenceHandler(ABC):
31
30
  model: nn.Module,
32
31
  state_dict: Union[str, Path],
33
32
  device: str = 'cpu',
34
- scaler: Optional[Union[PytorchScaler, str, Path]] = None):
33
+ scaler: Optional[Union[DragonScaler, str, Path]] = None):
35
34
  """
36
35
  Initializes the handler.
37
36
 
@@ -39,15 +38,21 @@ class _BaseInferenceHandler(ABC):
39
38
  model (nn.Module): An instantiated PyTorch model.
40
39
  state_dict (str | Path): Path to the saved .pth model state_dict file.
41
40
  device (str): The device to run inference on ('cpu', 'cuda', 'mps').
42
- scaler (PytorchScaler | str | Path | None): An optional scaler or path to a saved scaler state.
41
+ scaler (DragonScaler | str | Path | None): An optional scaler or path to a saved scaler state.
43
42
  """
44
43
  self.model = model
45
44
  self.device = self._validate_device(device)
45
+ self._classification_threshold = 0.5
46
+ self._loaded_threshold: bool = False
47
+ self._loaded_class_map: bool = False
48
+ self._class_map: Optional[dict[str,int]] = None
49
+ self._idx_to_class: Optional[Dict[int, str]] = None
50
+ self._loaded_data_dict: Dict[str, Any] = {} #Store whatever is in the finalized file
46
51
 
47
52
  # Load the scaler if a path is provided
48
53
  if scaler is not None:
49
54
  if isinstance(scaler, (str, Path)):
50
- self.scaler = PytorchScaler.load(scaler)
55
+ self.scaler = DragonScaler.load(scaler)
51
56
  else:
52
57
  self.scaler = scaler
53
58
  else:
@@ -58,13 +63,45 @@ class _BaseInferenceHandler(ABC):
58
63
  try:
59
64
  # Load whatever is in the file
60
65
  loaded_data = torch.load(model_p, map_location=self.device)
61
-
62
- # Check if it's the new checkpoint dictionary or an old weights-only file
63
- if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
64
- # It's a new training checkpoint, extract the weights
65
- self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
66
+
67
+ # Check if it's dictionary or a old weights-only file
68
+ if isinstance(loaded_data, dict):
69
+ self._loaded_data_dict = loaded_data # Store the dict
70
+
71
+ if PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
72
+ # It's a new training checkpoint, extract the weights
73
+ self.model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
74
+
75
+ # attempt to get a custom classification threshold
76
+ if PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD in loaded_data:
77
+ try:
78
+ self._classification_threshold = float(loaded_data[PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD])
79
+ except Exception as e_int:
80
+ _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}' but an error occurred when retrieving it:\n{e_int}")
81
+ self._classification_threshold = 0.5
82
+ else:
83
+ _LOGGER.info(f"'{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}' found and set to {self._classification_threshold}")
84
+ self._loaded_threshold = True
85
+
86
+ # attempt to get a class map
87
+ if PyTorchCheckpointKeys.CLASS_MAP in loaded_data:
88
+ try:
89
+ self._class_map = loaded_data[PyTorchCheckpointKeys.CLASS_MAP]
90
+ except Exception as e_int:
91
+ _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASS_MAP}' but an error occurred when retrieving it:\n{e_int}")
92
+ else:
93
+ if isinstance(self._class_map, dict):
94
+ self._loaded_class_map = True
95
+ self.set_class_map(self._class_map)
96
+ else:
97
+ _LOGGER.warning(f"State Dictionary has the key '{PyTorchCheckpointKeys.CLASS_MAP}' but it is not a dict: '{type(self._class_map)}'.")
98
+ self._class_map = None
99
+ else:
100
+ # It's a state_dict, load it directly
101
+ self.model.load_state_dict(loaded_data)
102
+
66
103
  else:
67
- # It's an old-style file (or just a state_dict), load it directly
104
+ # It's an old state_dict (just weights), load it directly
68
105
  self.model.load_state_dict(loaded_data)
69
106
 
70
107
  _LOGGER.info(f"Model state loaded from '{model_p.name}'.")
@@ -85,21 +122,37 @@ class _BaseInferenceHandler(ABC):
85
122
  _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
86
123
  device_lower = "cpu"
87
124
  return torch.device(device_lower)
88
-
89
- def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
90
- """
91
- Converts input to a torch.Tensor, applies scaling if a scaler is
92
- present, and moves it to the correct device.
125
+
126
+ def set_class_map(self, class_map: Dict[str, int], force_overwrite: bool = False):
93
127
  """
94
- if isinstance(features, np.ndarray):
95
- features_tensor = torch.from_numpy(features).float()
96
- else:
97
- features_tensor = features.float()
98
-
99
- if self.scaler:
100
- features_tensor = self.scaler.transform(features_tensor)
128
+ Sets the class name mapping to translate predicted integer labels back into string names.
129
+
130
+ If a class_map was previously loaded from a model configuration, this
131
+ method will log a warning and refuse to update the value. This
132
+ prevents accidentally overriding a setting from a loaded checkpoint.
133
+
134
+ To bypass this safety check set `force_overwrite` to `True`.
101
135
 
102
- return features_tensor.to(self.device)
136
+ Args:
137
+ class_map (Dict[str, int]): The class_to_idx dictionary (e.g., {'cat': 0, 'dog': 1}).
138
+ force_overwrite (bool): If True, allows overwriting a map that was loaded from a configuration file.
139
+ """
140
+ if self._loaded_class_map:
141
+ warning_message = f"A '{PyTorchCheckpointKeys.CLASS_MAP}' was loaded from the model configuration file."
142
+ if not force_overwrite:
143
+ warning_message += " Use 'force_overwrite=True' if you are sure you want to modify it. This will not affect the value from the file."
144
+ _LOGGER.warning(warning_message)
145
+ return
146
+ else:
147
+ warning_message += " Overwriting it for this inference instance."
148
+ _LOGGER.warning(warning_message)
149
+
150
+ # Store the map and invert it for fast lookup
151
+ self._class_map = class_map
152
+ self._idx_to_class = {v: k for k, v in class_map.items()}
153
+ # Mark as 'loaded' by the user, even if it wasn't from a file, to prevent accidental changes later if _loaded_class_map was false.
154
+ self._loaded_class_map = True # Protect this newly set map
155
+ _LOGGER.info("Class map set for label-to-name translation.")
103
156
 
104
157
  @abstractmethod
105
158
  def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
@@ -112,46 +165,67 @@ class _BaseInferenceHandler(ABC):
112
165
  pass
113
166
 
114
167
 
115
- class PyTorchInferenceHandler(_BaseInferenceHandler):
168
+ class DragonInferenceHandler(_BaseInferenceHandler):
116
169
  """
117
- Handles loading a PyTorch model's state dictionary and performing inference
118
- for single-target regression or classification tasks.
170
+ Handles loading a PyTorch model's state dictionary and performing inference.
119
171
  """
120
172
  def __init__(self,
121
173
  model: nn.Module,
122
174
  state_dict: Union[str, Path],
123
- task: Literal["classification", "regression"],
175
+ task: Literal["regression", "binary classification", "multiclass classification", "multitarget regression", "multilabel binary classification"],
124
176
  device: str = 'cpu',
125
- target_id: Optional[str] = None,
126
- scaler: Optional[Union[PytorchScaler, str, Path]] = None):
177
+ target_ids: Optional[list[str]] = None,
178
+ scaler: Optional[Union[DragonScaler, str, Path]] = None):
127
179
  """
128
180
  Initializes the handler for single-target tasks.
129
181
 
130
182
  Args:
131
183
  model (nn.Module): An instantiated PyTorch model architecture.
132
184
  state_dict (str | Path): Path to the saved .pth model state_dict file.
133
- task (str): The type of task, 'regression' or 'classification'.
185
+ task (str): The type of task.
134
186
  device (str): The device to run inference on ('cpu', 'cuda', 'mps').
135
187
  target_id (str | None): An optional identifier for the target.
136
- scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
188
+ scaler (DragonScaler | str | Path | None): A DragonScaler instance or the file path to a saved DragonScaler state.
189
+
190
+ Note: class_map (Dict[int, str]) will be loaded from the model file, to set or override it use `.set_class_map()`.
137
191
  """
138
192
  # Call the parent constructor to handle model loading, device, and scaler
139
193
  super().__init__(model, state_dict, device, scaler)
140
194
 
141
- if task not in ["classification", "regression"]:
142
- raise ValueError("`task` must be 'classification' or 'regression'.")
195
+ if task not in [MLTaskKeys.REGRESSION,
196
+ MLTaskKeys.BINARY_CLASSIFICATION,
197
+ MLTaskKeys.MULTICLASS_CLASSIFICATION,
198
+ MLTaskKeys.MULTITARGET_REGRESSION,
199
+ MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION]:
200
+ _LOGGER.error(f"'task' not recognized: '{task}'.")
201
+ raise ValueError()
143
202
  self.task = task
144
- self.target_id = target_id
203
+ self.target_ids = target_ids
204
+
205
+ def _preprocess_input(self, features: Union[np.ndarray, torch.Tensor]) -> torch.Tensor:
206
+ """
207
+ Converts input to a torch.Tensor, applies scaling if a scaler is
208
+ present, and moves it to the correct device.
209
+ """
210
+ if isinstance(features, np.ndarray):
211
+ features_tensor = torch.from_numpy(features).float()
212
+ else:
213
+ features_tensor = features.float()
214
+
215
+ if self.scaler:
216
+ features_tensor = self.scaler.transform(features_tensor)
217
+
218
+ return features_tensor.to(self.device)
145
219
 
146
220
  def predict_batch(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
147
221
  """
148
- Core batch prediction method for single-target models.
222
+ Core batch prediction method.
149
223
 
150
224
  Args:
151
225
  features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
152
226
 
153
227
  Returns:
154
- A dictionary containing the raw output tensors from the model.
228
+ Dict: A dictionary containing the raw output tensors from the model.
155
229
  """
156
230
  if features.ndim != 2:
157
231
  _LOGGER.error("Input for batch prediction must be a 2D array or tensor.")
@@ -162,16 +236,48 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
162
236
  with torch.no_grad():
163
237
  output = self.model(input_tensor)
164
238
 
165
- if self.task == "classification":
239
+ if self.task == MLTaskKeys.MULTICLASS_CLASSIFICATION:
166
240
  probs = torch.softmax(output, dim=1)
167
241
  labels = torch.argmax(probs, dim=1)
168
242
  return {
169
243
  PyTorchInferenceKeys.LABELS: labels,
170
244
  PyTorchInferenceKeys.PROBABILITIES: probs
171
245
  }
172
- else: # regression
246
+
247
+ elif self.task == MLTaskKeys.BINARY_CLASSIFICATION:
248
+ # Assumes model output is [N, 1] (a single logit)
249
+ # Squeeze output from [N, 1] to [N] if necessary
250
+ if output.ndim == 2 and output.shape[1] == 1:
251
+ output = output.squeeze(1)
252
+
253
+ probs = torch.sigmoid(output) # Probability of positive class
254
+ labels = (probs >= self._classification_threshold).int()
255
+ return {
256
+ PyTorchInferenceKeys.LABELS: labels,
257
+ PyTorchInferenceKeys.PROBABILITIES: probs
258
+ }
259
+
260
+ elif self.task == MLTaskKeys.REGRESSION:
173
261
  # For single-target regression, ensure output is flattened
174
262
  return {PyTorchInferenceKeys.PREDICTIONS: output.flatten()}
263
+
264
+ elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
265
+ probs = torch.sigmoid(output)
266
+ # Get binary predictions based on the threshold
267
+ labels = (probs >= self._classification_threshold).int()
268
+ return {
269
+ PyTorchInferenceKeys.LABELS: labels,
270
+ PyTorchInferenceKeys.PROBABILITIES: probs
271
+ }
272
+
273
+ elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
274
+ # The output is already in the correct [batch_size, n_targets] shape
275
+ return {PyTorchInferenceKeys.PREDICTIONS: output}
276
+
277
+ else:
278
+ # should never happen
279
+ _LOGGER.error(f"Unrecognized task '{self.task}'.")
280
+ raise ValueError()
175
281
 
176
282
  def predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, torch.Tensor]:
177
283
  """
@@ -181,7 +287,7 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
181
287
  features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
182
288
 
183
289
  Returns:
184
- A dictionary containing the raw output tensors for a single sample.
290
+ Dict: A dictionary containing the raw output tensors for a single sample.
185
291
  """
186
292
  if features.ndim == 1:
187
293
  features = features.reshape(1, -1) # Reshape to a batch of one
@@ -195,193 +301,129 @@ class PyTorchInferenceHandler(_BaseInferenceHandler):
195
301
  # Extract the first (and only) result from the batch output
196
302
  single_results = {key: value[0] for key, value in batch_results.items()}
197
303
  return single_results
198
-
304
+
199
305
  # --- NumPy Convenience Wrappers (on CPU) ---
200
306
 
201
307
  def predict_batch_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, np.ndarray]:
202
308
  """
203
- Convenience wrapper for predict_batch that returns NumPy arrays.
309
+ Convenience wrapper for predict_batch that returns NumPy arrays
310
+ and adds string labels for classification tasks if a class_map is set.
204
311
  """
205
312
  tensor_results = self.predict_batch(features)
206
313
  numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
314
+
315
+ # Add string names for classification if map exists
316
+ is_classification = self.task in [
317
+ MLTaskKeys.BINARY_CLASSIFICATION,
318
+ MLTaskKeys.MULTICLASS_CLASSIFICATION
319
+ ]
320
+
321
+ if is_classification and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
322
+ int_labels = numpy_results[PyTorchInferenceKeys.LABELS] # This is a (B,) array
323
+ numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [ # type: ignore
324
+ self._idx_to_class.get(label_id, "Unknown")
325
+ for label_id in int_labels
326
+ ]
327
+
207
328
  return numpy_results
208
329
 
209
330
  def predict_numpy(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
210
331
  """
211
- Convenience wrapper for predict that returns NumPy arrays or scalars.
332
+ Convenience wrapper for predict that returns NumPy arrays or scalars
333
+ and adds string labels for classification tasks if a class_map is set.
212
334
  """
213
335
  tensor_results = self.predict(features)
214
336
 
215
- if self.task == "regression":
337
+ if self.task == MLTaskKeys.REGRESSION:
216
338
  # .item() implicitly moves to CPU and returns a Python scalar
217
339
  return {PyTorchInferenceKeys.PREDICTIONS: tensor_results[PyTorchInferenceKeys.PREDICTIONS].item()}
218
- else: # classification
340
+
341
+ elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
342
+ int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
343
+ label_name = "Unknown"
344
+ if self._idx_to_class:
345
+ label_name = self._idx_to_class.get(int_label, "Unknown") # type: ignore
346
+
219
347
  return {
220
- PyTorchInferenceKeys.LABELS: tensor_results[PyTorchInferenceKeys.LABELS].item(),
348
+ PyTorchInferenceKeys.LABELS: int_label,
349
+ PyTorchInferenceKeys.LABEL_NAMES: label_name,
221
350
  PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
222
351
  }
352
+
353
+ elif self.task in [MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION, MLTaskKeys.MULTITARGET_REGRESSION]:
354
+ # For multi-target models, the output is always an array.
355
+ numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
356
+ return numpy_results
357
+ else:
358
+ # should never happen
359
+ _LOGGER.error(f"Unrecognized task '{self.task}'.")
360
+ raise ValueError()
223
361
 
224
362
  def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
225
363
  """
226
364
  Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
227
365
 
228
- `target_id` must be implemented.
366
+ `target_ids` must be implemented.
229
367
  """
230
- if self.target_id is None:
231
- _LOGGER.error(f"'target_id' has not been implemented.")
368
+ if self.target_ids is None:
369
+ _LOGGER.error(f"'target_ids' has not been implemented.")
232
370
  raise AttributeError()
233
371
 
234
- if self.task == "regression":
372
+ if self.task == MLTaskKeys.REGRESSION:
235
373
  result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS]
236
- else:
374
+ return {self.target_ids[0]: result}
375
+
376
+ elif self.task in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
237
377
  result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS]
378
+ return {self.target_ids[0]: result}
238
379
 
239
- return {self.target_id: result}
240
-
241
-
242
- class PyTorchInferenceHandlerMulti(_BaseInferenceHandler):
243
- """
244
- Handles loading a PyTorch model's state dictionary and performing inference
245
- for multi-target regression or multi-label classification tasks.
246
- """
247
- def __init__(self,
248
- model: nn.Module,
249
- state_dict: Union[str, Path],
250
- task: Literal["multi_target_regression", "multi_label_classification"],
251
- device: str = 'cpu',
252
- target_ids: Optional[list[str]] = None,
253
- scaler: Optional[Union[PytorchScaler, str, Path]] = None):
254
- """
255
- Initializes the handler for multi-target tasks.
256
-
257
- Args:
258
- model (nn.Module): An instantiated PyTorch model.
259
- state_dict (str | Path): Path to the saved .pth model state_dict file.
260
- task (str): The type of task, 'multi_target_regression' or 'multi_label_classification'.
261
- device (str): The device to run inference on ('cpu', 'cuda', 'mps').
262
- target_ids (list[str] | None): An optional identifier for the targets.
263
- scaler (PytorchScaler | str | Path | None): A PytorchScaler instance or the file path to a saved PytorchScaler state.
264
- """
265
- super().__init__(model, state_dict, device, scaler)
266
-
267
- if task not in ["multi_target_regression", "multi_label_classification"]:
268
- _LOGGER.error("`task` must be 'multi_target_regression' or 'multi_label_classification'.")
269
- raise ValueError()
270
- self.task = task
271
- self.target_ids = target_ids
272
-
273
- def predict_batch(self,
274
- features: Union[np.ndarray, torch.Tensor],
275
- classification_threshold: float = 0.5
276
- ) -> Dict[str, torch.Tensor]:
277
- """
278
- Core batch prediction method for multi-target models.
279
-
280
- Args:
281
- features (np.ndarray | torch.Tensor): A 2D array/tensor of input features.
282
- classification_threshold (float): The threshold to convert probabilities
283
- into binary predictions for multi-label classification.
284
-
285
- Returns:
286
- A dictionary containing the raw output tensors from the model.
287
- """
288
- if features.ndim != 2:
289
- _LOGGER.error("Input for batch prediction must be a 2D array or tensor.")
290
- raise ValueError()
291
-
292
- input_tensor = self._preprocess_input(features)
293
-
294
- with torch.no_grad():
295
- output = self.model(input_tensor)
296
-
297
- if self.task == "multi_label_classification":
298
- probs = torch.sigmoid(output)
299
- # Get binary predictions based on the threshold
300
- labels = (probs >= classification_threshold).int()
301
- return {
302
- PyTorchInferenceKeys.LABELS: labels,
303
- PyTorchInferenceKeys.PROBABILITIES: probs
304
- }
305
- else: # multi_target_regression
306
- # The output is already in the correct [batch_size, n_targets] shape
307
- return {PyTorchInferenceKeys.PREDICTIONS: output}
308
-
309
- def predict(self,
310
- features: Union[np.ndarray, torch.Tensor],
311
- classification_threshold: float = 0.5
312
- ) -> Dict[str, torch.Tensor]:
313
- """
314
- Core single-sample prediction method for multi-target models.
315
-
316
- Args:
317
- features (np.ndarray | torch.Tensor): A 1D array/tensor of input features.
318
- classification_threshold (float): The threshold for multi-label tasks.
319
-
320
- Returns:
321
- A dictionary containing the raw output tensors for a single sample.
322
- """
323
- if features.ndim == 1:
324
- features = features.reshape(1, -1)
325
-
326
- if features.shape[0] != 1:
327
- _LOGGER.error("The 'predict()' method is for a single sample. 'Use predict_batch()' for multiple samples.")
380
+ elif self.task == MLTaskKeys.MULTITARGET_REGRESSION:
381
+ result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS].flatten().tolist()
382
+ return {key: value for key, value in zip(self.target_ids, result)}
383
+
384
+ elif self.task == MLTaskKeys.MULTILABEL_BINARY_CLASSIFICATION:
385
+ result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS].flatten().tolist()
386
+ return {key: value for key, value in zip(self.target_ids, result)}
387
+
388
+ else:
389
+ # should never happen
390
+ _LOGGER.error(f"Unrecognized task '{self.task}'.")
328
391
  raise ValueError()
329
-
330
- batch_results = self.predict_batch(features, classification_threshold)
331
-
332
- single_results = {key: value[0] for key, value in batch_results.items()}
333
- return single_results
334
-
335
- # --- NumPy Convenience Wrappers (on CPU) ---
336
-
337
- def predict_batch_numpy(self,
338
- features: Union[np.ndarray, torch.Tensor],
339
- classification_threshold: float = 0.5
340
- ) -> Dict[str, np.ndarray]:
341
- """
342
- Convenience wrapper for predict_batch that returns NumPy arrays.
343
- """
344
- tensor_results = self.predict_batch(features, classification_threshold)
345
- numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
346
- return numpy_results
347
-
348
- def predict_numpy(self,
349
- features: Union[np.ndarray, torch.Tensor],
350
- classification_threshold: float = 0.5
351
- ) -> Dict[str, np.ndarray]:
352
- """
353
- Convenience wrapper for predict that returns NumPy arrays for a single sample.
354
- Note: For multi-target models, the output is always an array.
355
- """
356
- tensor_results = self.predict(features, classification_threshold)
357
- numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
358
- return numpy_results
359
-
360
- def quick_predict(self, features: Union[np.ndarray, torch.Tensor]) -> Dict[str, Any]:
361
- """
362
- Convenience wrapper to get the mapping {target_name: prediction} or {target_name: label}
363
392
 
364
- `target_ids` must be implemented.
393
+ def set_classification_threshold(self, threshold: float, force_overwrite: bool=False):
365
394
  """
366
- if self.target_ids is None:
367
- _LOGGER.error(f"'target_id' has not been implemented.")
368
- raise AttributeError()
395
+ Sets the classification threshold for the current inference instance.
369
396
 
370
- if self.task == "multi_target_regression":
371
- result = self.predict_numpy(features)[PyTorchInferenceKeys.PREDICTIONS].flatten().tolist()
372
- else:
373
- result = self.predict_numpy(features)[PyTorchInferenceKeys.LABELS].flatten().tolist()
397
+ If a threshold was previously loaded from a model configuration, this
398
+ method will log a warning and refuse to update the value. This
399
+ prevents accidentally overriding a setting from a loaded checkpoint.
374
400
 
375
- return {key: value for key, value in zip(self.target_ids, result)}
401
+ To bypass this safety check set `force_overwrite` to `True`.
376
402
 
403
+ Args:
404
+ threshold (float): The new classification threshold value to set.
405
+ force_overwrite (bool): If True, allows overwriting a threshold that was loaded from a configuration file.
406
+ """
407
+ if self._loaded_threshold:
408
+ warning_message = f"The current '{PyTorchCheckpointKeys.CLASSIFICATION_THRESHOLD}={self._classification_threshold}' was loaded and set from a model configuration file."
409
+ if not force_overwrite:
410
+ warning_message += " Use 'force_overwrite' if you are sure you want to modify it. This will not affect the value from the file."
411
+ _LOGGER.warning(warning_message)
412
+ return
413
+ else:
414
+ warning_message += f" Overwriting it to {threshold}."
415
+ _LOGGER.warning(warning_message)
416
+
417
+ self._classification_threshold = threshold
377
418
 
378
- def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
419
+
420
+ def multi_inference_regression(handlers: list[DragonInferenceHandler],
379
421
  feature_vector: Union[np.ndarray, torch.Tensor],
380
422
  output: Literal["numpy","torch"]="numpy") -> dict[str,Any]:
381
423
  """
382
424
  Performs regression inference using multiple models on a single feature vector.
383
425
 
384
- This function iterates through a list of PyTorchInferenceHandler objects,
426
+ This function iterates through a list of DragonInferenceHandler objects,
385
427
  each configured for a different regression target. It runs a prediction for
386
428
  each handler using the same input feature vector and returns the results
387
429
  in a dictionary.
@@ -391,7 +433,7 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
391
433
  - 2D input: Returns a dictionary mapping target ID to a list of values.
392
434
 
393
435
  Args:
394
- handlers (list[PyTorchInferenceHandler]): A list of initialized inference
436
+ handlers (list[DragonInferenceHandler]): A list of initialized inference
395
437
  handlers. Each handler must have a unique `target_id` and be configured with `task="regression"`.
396
438
  feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D) or a batch of samples (2D) to be fed into each regression model.
397
439
  output (Literal["numpy", "torch"], optional): The desired format for the output predictions.
@@ -421,11 +463,11 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
421
463
  results: dict[str,Any] = dict()
422
464
  for handler in handlers:
423
465
  # validation
424
- if handler.target_id is None:
425
- _LOGGER.error("All inference handlers must have a 'target_id' attribute.")
466
+ if handler.target_ids is None:
467
+ _LOGGER.error("All inference handlers must have a 'target_ids' attribute.")
426
468
  raise AttributeError()
427
- if handler.task != "regression":
428
- _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_id}' is for '{handler.task}', but only 'regression' tasks are supported.")
469
+ if handler.task != MLTaskKeys.REGRESSION:
470
+ _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_ids[0]}' is for '{handler.task}', only single target regression tasks are supported.")
429
471
  raise ValueError()
430
472
 
431
473
  # inference
@@ -434,33 +476,33 @@ def multi_inference_regression(handlers: list[PyTorchInferenceHandler],
434
476
  numpy_result = handler.predict_batch_numpy(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
435
477
  if is_single_sample:
436
478
  # For a single sample, convert the 1-element array to a Python scalar
437
- results[handler.target_id] = numpy_result.item()
479
+ results[handler.target_ids[0]] = numpy_result.item()
438
480
  else:
439
481
  # For a batch, return the full NumPy array of predictions
440
- results[handler.target_id] = numpy_result
482
+ results[handler.target_ids[0]] = numpy_result
441
483
 
442
484
  else: # output == "torch"
443
485
  # This path returns PyTorch tensors on the model's device
444
486
  torch_result = handler.predict_batch(feature_vector)[PyTorchInferenceKeys.PREDICTIONS]
445
487
  if is_single_sample:
446
488
  # For a single sample, return the 0-dim tensor
447
- results[handler.target_id] = torch_result[0]
489
+ results[handler.target_ids[0]] = torch_result[0]
448
490
  else:
449
491
  # For a batch, return the full tensor of predictions
450
- results[handler.target_id] = torch_result
492
+ results[handler.target_ids[0]] = torch_result
451
493
 
452
494
  return results
453
495
 
454
496
 
455
497
  def multi_inference_classification(
456
- handlers: list[PyTorchInferenceHandler],
498
+ handlers: list[DragonInferenceHandler],
457
499
  feature_vector: Union[np.ndarray, torch.Tensor],
458
500
  output: Literal["numpy","torch"]="numpy"
459
501
  ) -> tuple[dict[str, Any], dict[str, Any]]:
460
502
  """
461
503
  Performs classification inference on a single sample or a batch.
462
504
 
463
- This function iterates through a list of PyTorchInferenceHandler objects,
505
+ This function iterates through a list of DragonInferenceHandler objects,
464
506
  each configured for a different classification target. It returns two
465
507
  dictionaries: one for the predicted labels and one for the probabilities.
466
508
 
@@ -469,7 +511,7 @@ def multi_inference_classification(
469
511
  - 2D input: The dictionaries map target ID to an array of labels and an array of probability arrays.
470
512
 
471
513
  Args:
472
- handlers (list[PyTorchInferenceHandler]): A list of initialized inference handlers. Each must have a unique `target_id` and be configured
514
+ handlers (list[DragonInferenceHandler]): A list of initialized inference handlers. Each must have a unique `target_id` and be configured
473
515
  with `task="classification"`.
474
516
  feature_vector (Union[np.ndarray, torch.Tensor]): An input sample (1D)
475
517
  or a batch of samples (2D) for prediction.
@@ -502,11 +544,11 @@ def multi_inference_classification(
502
544
 
503
545
  for handler in handlers:
504
546
  # Validation
505
- if handler.target_id is None:
547
+ if handler.target_ids is None:
506
548
  _LOGGER.error("All inference handlers must have a 'target_id' attribute.")
507
549
  raise AttributeError()
508
- if handler.task != "classification":
509
- _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_id}' is for '{handler.task}', but this function only supports 'classification'.")
550
+ if handler.task not in [MLTaskKeys.BINARY_CLASSIFICATION, MLTaskKeys.MULTICLASS_CLASSIFICATION]:
551
+ _LOGGER.error(f"Invalid task type: The handler for target_id '{handler.target_ids[0]}' is for '{handler.task}', but this function only supports binary and multiclass classification.")
510
552
  raise ValueError()
511
553
 
512
554
  # Inference
@@ -524,15 +566,15 @@ def multi_inference_classification(
524
566
  # For "numpy", convert the single label to a Python int scalar.
525
567
  # For "torch", get the 0-dim tensor label.
526
568
  if output == "numpy":
527
- labels_results[handler.target_id] = labels.item()
569
+ labels_results[handler.target_ids[0]] = labels.item()
528
570
  else: # torch
529
- labels_results[handler.target_id] = labels[0]
571
+ labels_results[handler.target_ids[0]] = labels[0]
530
572
 
531
573
  # The probabilities are an array/tensor of values
532
- probs_results[handler.target_id] = probabilities[0]
574
+ probs_results[handler.target_ids[0]] = probabilities[0]
533
575
  else:
534
- labels_results[handler.target_id] = labels
535
- probs_results[handler.target_id] = probabilities
576
+ labels_results[handler.target_ids[0]] = labels
577
+ probs_results[handler.target_ids[0]] = probabilities
536
578
 
537
579
  return labels_results, probs_results
538
580