dragon-ml-toolbox 14.7.0__py3-none-any.whl → 16.2.1__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/METADATA +9 -5
- dragon_ml_toolbox-16.2.1.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +726 -32
- ml_tools/ML_datasetmaster.py +235 -280
- ml_tools/ML_evaluation.py +160 -42
- ml_tools/ML_evaluation_multi.py +103 -35
- ml_tools/ML_inference.py +290 -208
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +219 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1342 -386
- ml_tools/ML_utilities.py +1 -1
- ml_tools/ML_vision_datasetmaster.py +120 -72
- ml_tools/ML_vision_evaluation.py +30 -6
- ml_tools/ML_vision_inference.py +129 -152
- ml_tools/ML_vision_models.py +1 -1
- ml_tools/ML_vision_transformers.py +121 -40
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +45 -0
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +1 -1
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.7.0.dist-info/RECORD +0 -49
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.7.0.dist-info → dragon_ml_toolbox-16.2.1.dist-info}/top_level.txt +0 -0
ml_tools/ML_vision_inference.py
CHANGED
@@ -8,32 +8,29 @@ from torchvision import transforms
 
 from ._script_info import _script_info
 from ._logger import _LOGGER
-from .
-from .
-from .
+from ._keys import PyTorchInferenceKeys, MLTaskKeys
+from .ML_vision_transformers import _load_recipe_and_build_transform
+from .ML_inference import _BaseInferenceHandler
 
 
 __all__ = [
-    "
+    "DragonVisionInferenceHandler"
 ]
 
 
-class
+class DragonVisionInferenceHandler(_BaseInferenceHandler):
     """
     Handles loading a PyTorch vision model's state dictionary and performing inference.
 
     This class is specifically for vision models, which typically expect
     4D Tensors (B, C, H, W) or Lists of Tensors as input.
-    It does NOT use a scaler, as preprocessing (e.g., normalization)
-    is assumed to be part of the input transform pipeline.
     """
     def __init__(self,
                  model: nn.Module,
                  state_dict: Union[str, Path],
-                 task: Literal["
+                 task: Literal["binary image classification", "multiclass image classification", "binary segmentation", "multiclass segmentation", "object detection"],
                  device: str = 'cpu',
-                 transform_source: Optional[Union[str, Path, Callable]] = None
-                 class_map: Optional[Dict[str, int]] = None):
+                 transform_source: Optional[Union[str, Path, Callable]] = None):
         """
         Initializes the vision inference handler.
 
@@ -46,19 +43,17 @@ class PyTorchVisionInferenceHandler:
                 - A path to a .json recipe file (str or Path).
                 - A pre-built transformation pipeline (Callable).
                 - None, in which case .set_transform() must be called explicitly to set transformations.
-
+
+            Note: class_map (Dict[int, str]) will be loaded from the model file, to set or override it use `.set_class_map()`.
         """
-
-
+        super().__init__(model, state_dict, device, None)
+
         self._transform: Optional[Callable] = None
         self._is_transformed: bool = False
-        self._idx_to_class: Optional[Dict[int, str]] = None
-        if class_map is not None:
-            self.set_class_map(class_map)
 
-        if task not in [
-            _LOGGER.error(f"
-            raise ValueError(
+        if task not in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION, MLTaskKeys.OBJECT_DETECTION]:
+            _LOGGER.error(f"Unsupported task: '{task}'.")
+            raise ValueError()
         self.task = task
 
         self.expected_in_channels: int = 3 # Default to RGB
@@ -71,39 +66,6 @@ class PyTorchVisionInferenceHandler:
         if transform_source:
             self.set_transform(transform_source)
             self._is_transformed = True
-
-        model_p = make_fullpath(state_dict, enforce="file")
-
-        try:
-            # Load whatever is in the file
-            loaded_data = torch.load(model_p, map_location=self._device)
-
-            # Check if it's a new checkpoint dictionary or an old weights-only file
-            if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
-                # It's a new training checkpoint, extract the weights
-                self._model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
-            else:
-                # It's an old-style file (or just a state_dict), load it directly
-                self._model.load_state_dict(loaded_data)
-
-            _LOGGER.info(f"Model state loaded from '{model_p.name}'.")
-
-            self._model.to(self._device)
-            self._model.eval() # Set the model to evaluation mode
-        except Exception as e:
-            _LOGGER.error(f"Failed to load model state from '{model_p}': {e}")
-            raise
-
-    def _validate_device(self, device: str) -> torch.device:
-        """Validates the selected device and returns a torch.device object."""
-        device_lower = device.lower()
-        if "cuda" in device_lower and not torch.cuda.is_available():
-            _LOGGER.warning("CUDA not available, switching to CPU.")
-            device_lower = "cpu"
-        elif device_lower == "mps" and not torch.backends.mps.is_available():
-            _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
-            device_lower = "cpu"
-        return torch.device(device_lower)
 
     def _preprocess_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Union[torch.Tensor, List[torch.Tensor]]:
         """
@@ -111,23 +73,23 @@ class PyTorchVisionInferenceHandler:
         - For Classification/Segmentation: Expects 4D Tensor (B, C, H, W).
         - For Object Detection: Expects List[Tensor(C, H, W)].
         """
-        if self.task ==
+        if self.task == MLTaskKeys.OBJECT_DETECTION:
             if not isinstance(inputs, list) or not all(isinstance(t, torch.Tensor) for t in inputs):
                 _LOGGER.error("Input for object_detection must be a List[torch.Tensor].")
                 raise ValueError("Invalid input type for object detection.")
             # Move each tensor in the list to the device
-            return [t.float().to(self.
+            return [t.float().to(self.device) for t in inputs]
 
         else: # Classification or Segmentation
             if not isinstance(inputs, torch.Tensor):
                 _LOGGER.error(f"Input for {self.task} must be a torch.Tensor.")
                 raise ValueError(f"Invalid input type for {self.task}.")
 
-            if inputs.ndim != 4:
-                _LOGGER.error(f"Input tensor for {self.task} must be 4D (B, C, H, W). Got {inputs.ndim}D.")
+            if inputs.ndim != 4: # type: ignore
+                _LOGGER.error(f"Input tensor for {self.task} must be 4D (B, C, H, W). Got {inputs.ndim}D.") # type: ignore
                 raise ValueError("Input tensor must be 4D.")
 
-            return inputs.float().to(self.
+            return inputs.float().to(self.device)
 
     def set_transform(self, transform_source: Union[str, Path, Callable]):
         """
@@ -144,8 +106,8 @@ class PyTorchVisionInferenceHandler:
         if isinstance(transform_source, (str, Path)):
             _LOGGER.info(f"Loading transform from recipe file: '{transform_source}'")
             try:
-                # Use the
-                self._transform =
+                # Use the loader function
+                self._transform = _load_recipe_and_build_transform(transform_source)
             except Exception as e:
                 _LOGGER.error(f"Failed to load transform from recipe '{transform_source}': {e}")
                 raise
@@ -155,31 +117,15 @@ class PyTorchVisionInferenceHandler:
         else:
             _LOGGER.error(f"Invalid transform_source type: {type(transform_source)}. Must be str, Path, or Callable.")
             raise TypeError("transform_source must be a file path or a Callable.")
-
-    def set_class_map(self, class_map: Dict[str, int]):
-        """
-        Sets the class name mapping to translate predicted integer labels
-        back into string names.
-
-        Args:
-            class_map (Dict[str, int]): The class_to_idx dictionary
-                (e.g., {'cat': 0, 'dog': 1}) from the VisionDatasetMaker.
-        """
-        if self._idx_to_class is not None:
-            _LOGGER.warning("Class to index mapping was previously given. Setting new mapping...")
-        # Invert the dictionary for fast lookup
-        self._idx_to_class = {v: k for k, v in class_map.items()}
-        _LOGGER.info("Class map set for label-to-name translation.")
 
     def predict_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, Any]:
         """
         Core batch prediction method for vision models.
-        All preprocessing (resizing, normalization) should be done *before*
-        calling this method.
+        All preprocessing (resizing, normalization) should be done *before* calling this method.
 
         Args:
             inputs (torch.Tensor | List[torch.Tensor]):
-                - For
+                - For binary/multiclass image classification or binary/multiclass image segmentation tasks,
                   a 4D torch.Tensor (B, C, H, W).
                 - For 'object_detection', a List of 3D torch.Tensors
                   [(C, H, W), ...], each with its own size.
@@ -194,45 +140,55 @@ class PyTorchVisionInferenceHandler:
         processed_inputs = self._preprocess_batch(inputs)
 
         with torch.no_grad():
-
-
-
-
-
-            # 2. Post-process
+            # get outputs
+            output = self.model(processed_inputs)
+            if self.task == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION:
+                # process
                 probs = torch.softmax(output, dim=1)
                 labels = torch.argmax(probs, dim=1)
                 return {
                     PyTorchInferenceKeys.LABELS: labels, # (B,)
                     PyTorchInferenceKeys.PROBABILITIES: probs # (B, num_classes)
                 }
-
-            elif self.task ==
-                #
-                # 1
-                output
+
+            elif self.task == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION:
+                # Assumes model output is [N, 1] (a single logit)
+                # Squeeze output from [N, 1] to [N] if necessary
+                if output.ndim == 2 and output.shape[1] == 1:
+                    output = output.squeeze(1)
+
+                probs = torch.sigmoid(output) # Probability of positive class
+                labels = (probs >= self._classification_threshold).int()
+                return {
+                    PyTorchInferenceKeys.LABELS: labels,
+                    PyTorchInferenceKeys.PROBABILITIES: probs
+                }
+
+            elif self.task == MLTaskKeys.BINARY_SEGMENTATION:
+                # Assumes model output is [N, 1, H, W] (logits for positive class)
+                probs = torch.sigmoid(output) # Shape [N, 1, H, W]
+                labels = (probs >= self._classification_threshold).int() # Shape [N, 1, H, W]
+                return {
+                    PyTorchInferenceKeys.LABELS: labels,
+                    PyTorchInferenceKeys.PROBABILITIES: probs
+                }
 
-
-
-
+            elif self.task == MLTaskKeys.MULTICLASS_SEGMENTATION:
+                # output shape [N, C, H, W]
+                probs = torch.softmax(output, dim=1)
+                labels = torch.argmax(probs, dim=1) # shape [N, H, W]
                 return {
-                    PyTorchInferenceKeys.LABELS: labels, # (
-                    PyTorchInferenceKeys.PROBABILITIES: probs # (
+                    PyTorchInferenceKeys.LABELS: labels, # (N, H, W)
+                    PyTorchInferenceKeys.PROBABILITIES: probs # (N, num_classes, H, W)
                 }
 
-            elif self.task ==
-                # --- Object Detection ---
-                # 1. Predict (model is in eval mode, expects only images)
-                # Output is List[Dict[str, Tensor('boxes', 'labels', 'scores')]]
-                predictions = self._model(processed_inputs)
-
-                # 2. Post-process: Wrap in our standard key
+            elif self.task == MLTaskKeys.OBJECT_DETECTION:
                 return {
-                    PyTorchInferenceKeys.PREDICTIONS:
+                    PyTorchInferenceKeys.PREDICTIONS: output
                 }
 
             else:
-                # This should be unreachable due to
+                # This should be unreachable due to validation
                 raise ValueError(f"Unknown task: {self.task}")
 
     def predict(self, single_input: torch.Tensor) -> Dict[str, Any]:
@@ -248,24 +204,24 @@ class PyTorchVisionInferenceHandler:
         Returns:
             A dictionary containing the output tensors for a single sample.
             - Classification: {labels, probabilities} (label is 0-dim)
-            - Segmentation: {labels, probabilities} (label is 2D mask)
+            - Segmentation: {labels, probabilities} (label is a 2D (multiclass) or 3D (binary) mask)
            - Object Detection: {boxes, labels, scores} (single dict)
         """
         if not isinstance(single_input, torch.Tensor) or single_input.ndim != 3:
             _LOGGER.error(f"Input for predict() must be a 3D tensor (C, H, W). Got {single_input.ndim}D.")
-            raise ValueError(
+            raise ValueError()
 
         # --- 1. Batch the input based on task ---
-        if self.task ==
+        if self.task == MLTaskKeys.OBJECT_DETECTION:
             batched_input = [single_input] # List of one tensor
         else:
-            batched_input = single_input.unsqueeze(0)
+            batched_input = single_input.unsqueeze(0)
 
         # --- 2. Call batch prediction ---
         batch_results = self.predict_batch(batched_input)
 
         # --- 3. Un-batch the results ---
-        if self.task ==
+        if self.task == MLTaskKeys.OBJECT_DETECTION:
             # batch_results['predictions'] is a List[Dict]. We want the first (and only) Dict.
             return batch_results[PyTorchInferenceKeys.PREDICTIONS][0]
         else:
@@ -283,12 +239,12 @@ class PyTorchVisionInferenceHandler:
         Returns:
             Dict: A dictionary containing the outputs as NumPy arrays.
             - Obj. Detection: {predictions: List[Dict[str, np.ndarray]]}
-            - Classification: {labels:
+            - Classification: {labels: np.ndarray, label_names: List[str], probabilities: np.ndarray}
            - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
         """
         tensor_results = self.predict_batch(inputs)
 
-        if self.task ==
+        if self.task == MLTaskKeys.OBJECT_DETECTION:
             # Output is List[Dict[str, Tensor]]
             # Convert each tensor inside each dict to numpy
             numpy_results = []
@@ -296,21 +252,27 @@ class PyTorchVisionInferenceHandler:
                 # Convert all tensors to numpy
                 np_dict = {key: value.cpu().numpy() for key, value in pred_dict.items()}
 
-                #
-                if self._idx_to_class and PyTorchInferenceKeys.LABELS in np_dict:
-
-
-
-
+                # 3D pixel to string map unnecessary
+                # if self._idx_to_class and PyTorchInferenceKeys.LABELS in np_dict:
+                #     np_dict[PyTorchInferenceKeys.LABEL_NAMES] = [
+                #         self._idx_to_class.get(label_id, "Unknown")
+                #         for label_id in np_dict[PyTorchInferenceKeys.LABELS]
+                #     ]
                 numpy_results.append(np_dict)
             return {PyTorchInferenceKeys.PREDICTIONS: numpy_results}
+
         else:
-            # Output is Dict[str, Tensor]
+            # Output is Dict[str, Tensor] (for Classification or Segmentation)
             numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
 
             # Add string names for classification if map exists
-
-
+            is_image_classification = self.task in [
+                MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
+                MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
+            ]
+
+            if is_image_classification and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
+                int_labels = numpy_results[PyTorchInferenceKeys.LABELS] # This is a (B,) array
                 numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [
                     self._idx_to_class.get(label_id, "Unknown")
                     for label_id in int_labels
@@ -324,13 +286,13 @@ class PyTorchVisionInferenceHandler:
 
         Returns:
             Dict: A dictionary containing the outputs as NumPy arrays/scalars.
-            - Obj. Detection: {boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray}
+            - Obj. Detection: {boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray, label_names: List[str]}
            - Classification: {labels: int, label_names: str, probabilities: np.ndarray}
             - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
         """
         tensor_results = self.predict(single_input)
 
-        if self.task ==
+        if self.task == MLTaskKeys.OBJECT_DETECTION:
             # Output is Dict[str, Tensor]
             # Convert each tensor to numpy
             numpy_results = {
@@ -338,17 +300,17 @@ class PyTorchVisionInferenceHandler:
             }
 
             # Add string names if map exists
-            if self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
-
+            # if self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
+            #     int_labels = numpy_results[PyTorchInferenceKeys.LABELS]
 
-
-
-
-
+            #     numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [
+            #         self._idx_to_class.get(label_id, "Unknown")
+            #         for label_id in int_labels
+            #     ]
 
             return numpy_results
 
-        elif self.task
+        elif self.task in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
             # Output is Dict[str, Tensor(0-dim) or Tensor(1-dim)]
             int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
             label_name = "Unknown"
@@ -360,50 +322,32 @@ class PyTorchVisionInferenceHandler:
                 PyTorchInferenceKeys.LABEL_NAMES: label_name,
                 PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
             }
-        else: # image_segmentation
+        else: # image_segmentation (binary or multiclass)
             # Output is Dict[str, Tensor(2D) or Tensor(3D)]
             return {
                 PyTorchInferenceKeys.LABELS: tensor_results[PyTorchInferenceKeys.LABELS].cpu().numpy(),
                 PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
             }
 
-    def
+    def predict_from_pil(self, image: Image.Image) -> Dict[str, Any]:
         """
-
+        Applies the stored transform to a single PIL image and returns the prediction.
 
         Args:
-
+            image (PIL.Image.Image): The input PIL image.
 
         Returns:
             Dict: A dictionary containing the prediction results. See `predict_numpy()` for task-specific output structures.
         """
         if self._transform is None:
-            _LOGGER.error("Cannot predict from
+            _LOGGER.error("Cannot predict from PIL image: No transform has been set. Call .set_transform() or provide transform_source in __init__.")
             raise RuntimeError("Inference transform is not set.")
-
-        try:
-            # --- Use expected_in_channels to set PIL mode ---
-            pil_mode: str
-            if self.expected_in_channels == 1:
-                pil_mode = "L" # Grayscale
-            elif self.expected_in_channels == 4:
-                pil_mode = "RGBA" # RGB + Alpha
-            else:
-                if self.expected_in_channels != 3: # 2, 5+ channels not supported by PIL convert
-                    _LOGGER.warning(f"Model expects {self.expected_in_channels} channels. PIL conversion is limited, defaulting to 3 channels (RGB). The transformations must convert it to the desired channel dimensions.")
-                # Default to RGB. If 2-channels are needed, the transform recipe *must* be responsible for handling the conversion from a 3-channel PIL image.
-                pil_mode = "RGB"
-
-            image = Image.open(image_path).convert(pil_mode)
-        except Exception as e:
-            _LOGGER.error(f"Failed to load and convert image from '{image_path}': {e}")
-            raise
 
         # Apply the transformation pipeline (e.g., resize, crop, ToTensor, normalize)
         try:
             transformed_image = self._transform(image)
         except Exception as e:
-            _LOGGER.error(f"Error applying transform to image: {e}")
+            _LOGGER.error(f"Error applying transform to PIL image: {e}")
             raise
 
         # --- Validation ---
@@ -413,7 +357,7 @@ class PyTorchVisionInferenceHandler:
 
         if transformed_image.ndim != 3:
             _LOGGER.warning(f"Expected transform to output a 3D (C, H, W) tensor, but got {transformed_image.ndim}D. Attempting to proceed.")
-            # .predict() which expects a 3D tensor
+            # .predict_numpy() -> .predict() which expects a 3D tensor
        if transformed_image.ndim == 4 and transformed_image.shape[0] == 1:
             transformed_image = transformed_image.squeeze(0) # Fix if user's transform adds a batch dim
             _LOGGER.warning("Removed an extra batch dimension.")
@@ -423,6 +367,39 @@ class PyTorchVisionInferenceHandler:
         # Use the existing single-item predict method
         return self.predict_numpy(transformed_image)
 
+    def predict_from_file(self, image_path: Union[str, Path]) -> Dict[str, Any]:
+        """
+        Loads a single image from a file, applies the stored transform, and returns the prediction.
+
+        This is a convenience wrapper that loads the image and calls `predict_from_pil()`.
+
+        Args:
+            image_path (str | Path): The file path to the input image.
+
+        Returns:
+            Dict: A dictionary containing the prediction results. See `predict_numpy()` for task-specific output structures.
+        """
+        try:
+            # --- Use expected_in_channels to set PIL mode ---
+            pil_mode: str
+            if self.expected_in_channels == 1:
+                pil_mode = "L" # Grayscale
+            elif self.expected_in_channels == 4:
+                pil_mode = "RGBA" # RGB + Alpha
+            else:
+                if self.expected_in_channels != 3: # 2, 5+ channels not supported by PIL convert
+                    _LOGGER.warning(f"Model expects {self.expected_in_channels} channels. PIL conversion is limited, defaulting to 3 channels (RGB). The transformations must convert it to the desired channel dimensions.")
+                # Default to RGB. If 2-channels are needed, the transform recipe *must* be responsible for handling the conversion from a 3-channel PIL image.
+                pil_mode = "RGB"
+
+            image = Image.open(image_path).convert(pil_mode)
+        except Exception as e:
+            _LOGGER.error(f"Failed to load and convert image from '{image_path}': {e}")
+            raise
+
+        # Call the PIL-based prediction method
+        return self.predict_from_pil(image)
+
 
 def info():
     _script_info(__all__)
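To make the reworked API concrete, here is a minimal usage sketch of the new `DragonVisionInferenceHandler`, based only on the signatures and docstrings visible in the diff above. The model, the weights file, the transform recipe, and the image path are placeholders (the tiny `nn.Sequential` stands in for a real vision model); only the class name, the task strings, and the method names come from this release.

```python
from pathlib import Path

import torch.nn as nn
from ml_tools.ML_vision_inference import DragonVisionInferenceHandler

# Placeholder model: any nn.Module whose saved state_dict matches.
# In practice this would be one of the architectures from ml_tools.ML_vision_models.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 2),
)

handler = DragonVisionInferenceHandler(
    model=model,
    state_dict=Path("artifacts/model.pth"),          # placeholder checkpoint path
    task="multiclass image classification",          # one of the Literal task strings above
    device="cpu",                                    # default
    transform_source=Path("artifacts/recipe.json"),  # .json recipe, a Callable, or None
)

# Per the docstring note, the class map is loaded from the model file;
# .set_class_map() can override it if needed.
# handler.set_class_map(...)

# Convenience path added in this version: load a file, apply the stored
# transform, and predict in one call.
result = handler.predict_from_file("images/example.jpg")
# Classification output per the docstring:
# {labels: int, label_names: str, probabilities: np.ndarray}
print(result)
```

For lower-level use, `predict_batch()` still works on already-transformed tensors, and `predict_numpy()` returns NumPy outputs for a single transformed tensor.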
ml_tools/ML_vision_models.py
CHANGED
@@ -530,7 +530,7 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
     This wrapper allows for customizing the model backbone, input channels,
     and the number of output classes for transfer learning.
 
-    NOTE:
+    NOTE: Use an Object Detection compatible trainer.
     """
     def __init__(self,
                  num_classes: int,