dragon-ml-toolbox 14.8.0__py3-none-any.whl → 16.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +9 -5
- dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +204 -11
- ml_tools/ML_datasetmaster.py +198 -280
- ml_tools/ML_evaluation.py +132 -41
- ml_tools/ML_evaluation_multi.py +96 -35
- ml_tools/ML_inference.py +249 -207
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +215 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1237 -354
- ml_tools/ML_utilities.py +1 -1
- ml_tools/ML_vision_datasetmaster.py +73 -67
- ml_tools/ML_vision_evaluation.py +26 -6
- ml_tools/ML_vision_inference.py +117 -140
- ml_tools/ML_vision_models.py +1 -1
- ml_tools/ML_vision_transformers.py +121 -40
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +43 -0
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +1 -1
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.8.0.dist-info/RECORD +0 -49
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.8.0.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_vision_inference.py
CHANGED
@@ -8,32 +8,29 @@ from torchvision import transforms

 from ._script_info import _script_info
 from ._logger import _LOGGER
-from .
-from .
-from .
+from ._keys import PyTorchInferenceKeys, MLTaskKeys
+from .ML_vision_transformers import _load_recipe_and_build_transform
+from .ML_inference import _BaseInferenceHandler


 __all__ = [
-    "
+    "DragonVisionInferenceHandler"
 ]


-class
+class DragonVisionInferenceHandler(_BaseInferenceHandler):
     """
     Handles loading a PyTorch vision model's state dictionary and performing inference.

     This class is specifically for vision models, which typically expect
     4D Tensors (B, C, H, W) or Lists of Tensors as input.
-    It does NOT use a scaler, as preprocessing (e.g., normalization)
-    is assumed to be part of the input transform pipeline.
     """
     def __init__(self,
                  model: nn.Module,
                  state_dict: Union[str, Path],
-                 task: Literal["
+                 task: Literal["binary image classification", "multiclass image classification", "binary segmentation", "multiclass segmentation", "object detection"],
                  device: str = 'cpu',
-                 transform_source: Optional[Union[str, Path, Callable]] = None
-                 class_map: Optional[Dict[str, int]] = None):
+                 transform_source: Optional[Union[str, Path, Callable]] = None):
         """
         Initializes the vision inference handler.

@@ -46,19 +43,17 @@ class PyTorchVisionInferenceHandler:
             - A path to a .json recipe file (str or Path).
             - A pre-built transformation pipeline (Callable).
             - None, in which case .set_transform() must be called explicitly to set transformations.
-
+
+        Note: class_map (Dict[int, str]) will be loaded from the model file, to set or override it use `.set_class_map()`.
         """
-
-
+        super().__init__(model, state_dict, device, None)
+
         self._transform: Optional[Callable] = None
         self._is_transformed: bool = False
-        self._idx_to_class: Optional[Dict[int, str]] = None
-        if class_map is not None:
-            self.set_class_map(class_map)

-        if task not in [
-            _LOGGER.error(f"
-            raise ValueError(
+        if task not in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION, MLTaskKeys.OBJECT_DETECTION]:
+            _LOGGER.error(f"Unsupported task: '{task}'.")
+            raise ValueError()
         self.task = task

         self.expected_in_channels: int = 3 # Default to RGB

@@ -71,39 +66,6 @@ class PyTorchVisionInferenceHandler:
         if transform_source:
             self.set_transform(transform_source)
             self._is_transformed = True
-
-        model_p = make_fullpath(state_dict, enforce="file")
-
-        try:
-            # Load whatever is in the file
-            loaded_data = torch.load(model_p, map_location=self._device)
-
-            # Check if it's a new checkpoint dictionary or an old weights-only file
-            if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
-                # It's a new training checkpoint, extract the weights
-                self._model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
-            else:
-                # It's an old-style file (or just a state_dict), load it directly
-                self._model.load_state_dict(loaded_data)
-
-            _LOGGER.info(f"Model state loaded from '{model_p.name}'.")
-
-            self._model.to(self._device)
-            self._model.eval() # Set the model to evaluation mode
-        except Exception as e:
-            _LOGGER.error(f"Failed to load model state from '{model_p}': {e}")
-            raise
-
-    def _validate_device(self, device: str) -> torch.device:
-        """Validates the selected device and returns a torch.device object."""
-        device_lower = device.lower()
-        if "cuda" in device_lower and not torch.cuda.is_available():
-            _LOGGER.warning("CUDA not available, switching to CPU.")
-            device_lower = "cpu"
-        elif device_lower == "mps" and not torch.backends.mps.is_available():
-            _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
-            device_lower = "cpu"
-        return torch.device(device_lower)

     def _preprocess_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Union[torch.Tensor, List[torch.Tensor]]:
         """

@@ -111,23 +73,23 @@ class PyTorchVisionInferenceHandler:
         - For Classification/Segmentation: Expects 4D Tensor (B, C, H, W).
         - For Object Detection: Expects List[Tensor(C, H, W)].
         """
-        if self.task ==
+        if self.task == MLTaskKeys.OBJECT_DETECTION:
             if not isinstance(inputs, list) or not all(isinstance(t, torch.Tensor) for t in inputs):
                 _LOGGER.error("Input for object_detection must be a List[torch.Tensor].")
                 raise ValueError("Invalid input type for object detection.")
             # Move each tensor in the list to the device
-            return [t.float().to(self.
+            return [t.float().to(self.device) for t in inputs]

         else: # Classification or Segmentation
             if not isinstance(inputs, torch.Tensor):
                 _LOGGER.error(f"Input for {self.task} must be a torch.Tensor.")
                 raise ValueError(f"Invalid input type for {self.task}.")

-            if inputs.ndim != 4:
-                _LOGGER.error(f"Input tensor for {self.task} must be 4D (B, C, H, W). Got {inputs.ndim}D.")
+            if inputs.ndim != 4: # type: ignore
+                _LOGGER.error(f"Input tensor for {self.task} must be 4D (B, C, H, W). Got {inputs.ndim}D.") # type: ignore
                 raise ValueError("Input tensor must be 4D.")

-            return inputs.float().to(self.
+            return inputs.float().to(self.device)

     def set_transform(self, transform_source: Union[str, Path, Callable]):
         """

@@ -144,8 +106,8 @@ class PyTorchVisionInferenceHandler:
         if isinstance(transform_source, (str, Path)):
             _LOGGER.info(f"Loading transform from recipe file: '{transform_source}'")
             try:
-                # Use the
-                self._transform =
+                # Use the loader function
+                self._transform = _load_recipe_and_build_transform(transform_source)
             except Exception as e:
                 _LOGGER.error(f"Failed to load transform from recipe '{transform_source}': {e}")
                 raise

@@ -155,31 +117,15 @@ class PyTorchVisionInferenceHandler:
         else:
             _LOGGER.error(f"Invalid transform_source type: {type(transform_source)}. Must be str, Path, or Callable.")
             raise TypeError("transform_source must be a file path or a Callable.")
-
-    def set_class_map(self, class_map: Dict[str, int]):
-        """
-        Sets the class name mapping to translate predicted integer labels
-        back into string names.
-
-        Args:
-            class_map (Dict[str, int]): The class_to_idx dictionary
-                (e.g., {'cat': 0, 'dog': 1}) from the VisionDatasetMaker.
-        """
-        if self._idx_to_class is not None:
-            _LOGGER.warning("Class to index mapping was previously given. Setting new mapping...")
-        # Invert the dictionary for fast lookup
-        self._idx_to_class = {v: k for k, v in class_map.items()}
-        _LOGGER.info("Class map set for label-to-name translation.")

     def predict_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, Any]:
         """
         Core batch prediction method for vision models.
-        All preprocessing (resizing, normalization) should be done *before*
-        calling this method.
+        All preprocessing (resizing, normalization) should be done *before* calling this method.

         Args:
             inputs (torch.Tensor | List[torch.Tensor]):
-                - For
+                - For binary/multiclass image classification or binary/multiclass image segmentation tasks,
                   a 4D torch.Tensor (B, C, H, W).
                 - For 'object_detection', a List of 3D torch.Tensors
                   [(C, H, W), ...], each with its own size.

@@ -194,45 +140,55 @@ class PyTorchVisionInferenceHandler:
         processed_inputs = self._preprocess_batch(inputs)

         with torch.no_grad():
-
-
-
-
-
-            # 2. Post-process
+            # get outputs
+            output = self.model(processed_inputs)
+            if self.task == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION:
+                # process
                 probs = torch.softmax(output, dim=1)
                 labels = torch.argmax(probs, dim=1)
                 return {
                     PyTorchInferenceKeys.LABELS: labels, # (B,)
                     PyTorchInferenceKeys.PROBABILITIES: probs # (B, num_classes)
                 }
-
-            elif self.task ==
-                #
-                # 1
-                output
+
+            elif self.task == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION:
+                # Assumes model output is [N, 1] (a single logit)
+                # Squeeze output from [N, 1] to [N] if necessary
+                if output.ndim == 2 and output.shape[1] == 1:
+                    output = output.squeeze(1)
+
+                probs = torch.sigmoid(output) # Probability of positive class
+                labels = (probs >= self._classification_threshold).int()
+                return {
+                    PyTorchInferenceKeys.LABELS: labels,
+                    PyTorchInferenceKeys.PROBABILITIES: probs
+                }
+
+            elif self.task == MLTaskKeys.BINARY_SEGMENTATION:
+                # Assumes model output is [N, 1, H, W] (logits for positive class)
+                probs = torch.sigmoid(output) # Shape [N, 1, H, W]
+                labels = (probs >= self._classification_threshold).int() # Shape [N, 1, H, W]
+                return {
+                    PyTorchInferenceKeys.LABELS: labels,
+                    PyTorchInferenceKeys.PROBABILITIES: probs
+                }

-
-
-
+            elif self.task == MLTaskKeys.MULTICLASS_SEGMENTATION:
+                # output shape [N, C, H, W]
+                probs = torch.softmax(output, dim=1)
+                labels = torch.argmax(probs, dim=1) # shape [N, H, W]
                 return {
-                    PyTorchInferenceKeys.LABELS: labels, # (
-                    PyTorchInferenceKeys.PROBABILITIES: probs # (
+                    PyTorchInferenceKeys.LABELS: labels, # (N, H, W)
+                    PyTorchInferenceKeys.PROBABILITIES: probs # (N, num_classes, H, W)
                 }

-            elif self.task ==
-                # --- Object Detection ---
-                # 1. Predict (model is in eval mode, expects only images)
-                # Output is List[Dict[str, Tensor('boxes', 'labels', 'scores')]]
-                predictions = self._model(processed_inputs)
-
-                # 2. Post-process: Wrap in our standard key
+            elif self.task == MLTaskKeys.OBJECT_DETECTION:
                 return {
-                    PyTorchInferenceKeys.PREDICTIONS:
+                    PyTorchInferenceKeys.PREDICTIONS: output
                 }

             else:
-                # This should be unreachable due to
+                # This should be unreachable due to validation
                 raise ValueError(f"Unknown task: {self.task}")

     def predict(self, single_input: torch.Tensor) -> Dict[str, Any]:

@@ -248,24 +204,24 @@ class PyTorchVisionInferenceHandler:
         Returns:
             A dictionary containing the output tensors for a single sample.
             - Classification: {labels, probabilities} (label is 0-dim)
-            - Segmentation: {labels, probabilities} (label is 2D mask)
+            - Segmentation: {labels, probabilities} (label is a 2D (multiclass) or 3D (binary) mask)
             - Object Detection: {boxes, labels, scores} (single dict)
         """
         if not isinstance(single_input, torch.Tensor) or single_input.ndim != 3:
             _LOGGER.error(f"Input for predict() must be a 3D tensor (C, H, W). Got {single_input.ndim}D.")
-            raise ValueError(
+            raise ValueError()

         # --- 1. Batch the input based on task ---
-        if self.task ==
+        if self.task == MLTaskKeys.OBJECT_DETECTION:
             batched_input = [single_input] # List of one tensor
         else:
-            batched_input = single_input.unsqueeze(0)
+            batched_input = single_input.unsqueeze(0)

         # --- 2. Call batch prediction ---
         batch_results = self.predict_batch(batched_input)

         # --- 3. Un-batch the results ---
-        if self.task ==
+        if self.task == MLTaskKeys.OBJECT_DETECTION:
             # batch_results['predictions'] is a List[Dict]. We want the first (and only) Dict.
             return batch_results[PyTorchInferenceKeys.PREDICTIONS][0]
         else:

@@ -283,12 +239,12 @@ class PyTorchVisionInferenceHandler:
         Returns:
             Dict: A dictionary containing the outputs as NumPy arrays.
             - Obj. Detection: {predictions: List[Dict[str, np.ndarray]]}
-            - Classification: {labels:
+            - Classification: {labels: np.ndarray, label_names: List[str], probabilities: np.ndarray}
             - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
         """
         tensor_results = self.predict_batch(inputs)

-        if self.task ==
+        if self.task == MLTaskKeys.OBJECT_DETECTION:
             # Output is List[Dict[str, Tensor]]
             # Convert each tensor inside each dict to numpy
             numpy_results = []

@@ -304,13 +260,19 @@ class PyTorchVisionInferenceHandler:
             ]
             numpy_results.append(np_dict)
             return {PyTorchInferenceKeys.PREDICTIONS: numpy_results}
+
         else:
-            # Output is Dict[str, Tensor]
+            # Output is Dict[str, Tensor] (for Classification or Segmentation)
             numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}

             # Add string names for classification if map exists
-
-
+            is_image_classification = self.task in [
+                MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
+                MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
+            ]
+
+            if is_image_classification and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
+                int_labels = numpy_results[PyTorchInferenceKeys.LABELS] # This is a (B,) array
                 numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [
                     self._idx_to_class.get(label_id, "Unknown")
                     for label_id in int_labels

@@ -324,13 +286,13 @@ class PyTorchVisionInferenceHandler:

         Returns:
             Dict: A dictionary containing the outputs as NumPy arrays/scalars.
-            - Obj. Detection: {boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray}
+            - Obj. Detection: {boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray, label_names: List[str]}
             - Classification: {labels: int, label_names: str, probabilities: np.ndarray}
             - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
         """
         tensor_results = self.predict(single_input)

-        if self.task ==
+        if self.task == MLTaskKeys.OBJECT_DETECTION:
             # Output is Dict[str, Tensor]
             # Convert each tensor to numpy
             numpy_results = {

@@ -348,7 +310,7 @@ class PyTorchVisionInferenceHandler:

             return numpy_results

-        elif self.task
+        elif self.task in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
             # Output is Dict[str, Tensor(0-dim) or Tensor(1-dim)]
             int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
             label_name = "Unknown"

@@ -360,50 +322,32 @@ class PyTorchVisionInferenceHandler:
                 PyTorchInferenceKeys.LABEL_NAMES: label_name,
                 PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
             }
-        else: # image_segmentation
+        else: # image_segmentation (binary or multiclass)
             # Output is Dict[str, Tensor(2D) or Tensor(3D)]
             return {
                 PyTorchInferenceKeys.LABELS: tensor_results[PyTorchInferenceKeys.LABELS].cpu().numpy(),
                 PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
             }

-    def
+    def predict_from_pil(self, image: Image.Image) -> Dict[str, Any]:
         """
-
+        Applies the stored transform to a single PIL image and returns the prediction.

         Args:
-
+            image (PIL.Image.Image): The input PIL image.

         Returns:
             Dict: A dictionary containing the prediction results. See `predict_numpy()` for task-specific output structures.
         """
         if self._transform is None:
-            _LOGGER.error("Cannot predict from
+            _LOGGER.error("Cannot predict from PIL image: No transform has been set. Call .set_transform() or provide transform_source in __init__.")
             raise RuntimeError("Inference transform is not set.")
-
-        try:
-            # --- Use expected_in_channels to set PIL mode ---
-            pil_mode: str
-            if self.expected_in_channels == 1:
-                pil_mode = "L" # Grayscale
-            elif self.expected_in_channels == 4:
-                pil_mode = "RGBA" # RGB + Alpha
-            else:
-                if self.expected_in_channels != 3: # 2, 5+ channels not supported by PIL convert
-                    _LOGGER.warning(f"Model expects {self.expected_in_channels} channels. PIL conversion is limited, defaulting to 3 channels (RGB). The transformations must convert it to the desired channel dimensions.")
-                # Default to RGB. If 2-channels are needed, the transform recipe *must* be responsible for handling the conversion from a 3-channel PIL image.
-                pil_mode = "RGB"
-
-            image = Image.open(image_path).convert(pil_mode)
-        except Exception as e:
-            _LOGGER.error(f"Failed to load and convert image from '{image_path}': {e}")
-            raise

         # Apply the transformation pipeline (e.g., resize, crop, ToTensor, normalize)
         try:
             transformed_image = self._transform(image)
         except Exception as e:
-            _LOGGER.error(f"Error applying transform to image: {e}")
+            _LOGGER.error(f"Error applying transform to PIL image: {e}")
             raise

         # --- Validation ---

@@ -413,7 +357,7 @@ class PyTorchVisionInferenceHandler:

         if transformed_image.ndim != 3:
             _LOGGER.warning(f"Expected transform to output a 3D (C, H, W) tensor, but got {transformed_image.ndim}D. Attempting to proceed.")
-        # .predict() which expects a 3D tensor
+        # .predict_numpy() -> .predict() which expects a 3D tensor
         if transformed_image.ndim == 4 and transformed_image.shape[0] == 1:
             transformed_image = transformed_image.squeeze(0) # Fix if user's transform adds a batch dim
             _LOGGER.warning("Removed an extra batch dimension.")

@@ -423,6 +367,39 @@ class PyTorchVisionInferenceHandler:
         # Use the existing single-item predict method
         return self.predict_numpy(transformed_image)

+    def predict_from_file(self, image_path: Union[str, Path]) -> Dict[str, Any]:
+        """
+        Loads a single image from a file, applies the stored transform, and returns the prediction.
+
+        This is a convenience wrapper that loads the image and calls `predict_from_pil()`.
+
+        Args:
+            image_path (str | Path): The file path to the input image.
+
+        Returns:
+            Dict: A dictionary containing the prediction results. See `predict_numpy()` for task-specific output structures.
+        """
+        try:
+            # --- Use expected_in_channels to set PIL mode ---
+            pil_mode: str
+            if self.expected_in_channels == 1:
+                pil_mode = "L" # Grayscale
+            elif self.expected_in_channels == 4:
+                pil_mode = "RGBA" # RGB + Alpha
+            else:
+                if self.expected_in_channels != 3: # 2, 5+ channels not supported by PIL convert
+                    _LOGGER.warning(f"Model expects {self.expected_in_channels} channels. PIL conversion is limited, defaulting to 3 channels (RGB). The transformations must convert it to the desired channel dimensions.")
+                # Default to RGB. If 2-channels are needed, the transform recipe *must* be responsible for handling the conversion from a 3-channel PIL image.
+                pil_mode = "RGB"
+
+            image = Image.open(image_path).convert(pil_mode)
+        except Exception as e:
+            _LOGGER.error(f"Failed to load and convert image from '{image_path}': {e}")
+            raise
+
+        # Call the PIL-based prediction method
+        return self.predict_from_pil(image)
+

 def info():
     _script_info(__all__)
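For orientation, here is a minimal usage sketch of the refactored handler, based only on the signatures visible in this diff. The model construction, file names, and the chosen task string are placeholders and assumptions, not part of the package's documented workflow:

# Hypothetical usage sketch of DragonVisionInferenceHandler (all paths are placeholders).
from pathlib import Path

import torchvision
from ml_tools.ML_vision_inference import DragonVisionInferenceHandler

# Stand-in model: in practice, rebuild the same architecture that produced the checkpoint.
model = torchvision.models.resnet18(num_classes=3)

handler = DragonVisionInferenceHandler(
    model=model,
    state_dict=Path("model_checkpoint.pth"),            # checkpoint or plain state_dict file (placeholder name)
    task="multiclass image classification",             # one of the Literal options shown in __init__
    device="cpu",
    transform_source=Path("inference_transform.json"),  # saved transform recipe (placeholder name)
)

# The class-to-index map is loaded from the model file; use handler.set_class_map(...) to override it.

result = handler.predict_from_file("example.jpg")        # placeholder image path
print(result)  # classification output: {labels: int, label_names: str, probabilities: np.ndarray}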
ml_tools/ML_vision_models.py
CHANGED
@@ -530,7 +530,7 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
     This wrapper allows for customizing the model backbone, input channels,
     and the number of output classes for transfer learning.

-    NOTE:
+    NOTE: Use an Object Detection compatible trainer.
     """
     def __init__(self,
                  num_classes: int,
ml_tools/ML_vision_transformers.py
CHANGED

@@ -2,10 +2,11 @@ from typing import Union, Dict, Type, Callable, Optional, Any, List, Literal
 from PIL import ImageOps, Image
 from torchvision import transforms
 from pathlib import Path
+import json

 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .
+from ._keys import VisionTransformRecipeKeys
 from .path_manager import make_fullpath


@@ -52,49 +53,12 @@ class ResizeAspectFill:
         return ImageOps.expand(image, padding, fill=self.pad_color)


+#############################################################
 #NOTE: Add custom transforms.
 TRANSFORM_REGISTRY: Dict[str, Type[Callable]] = {
     "ResizeAspectFill": ResizeAspectFill,
 }
-
-
-def _build_transform_from_recipe(recipe: Dict[str, Any]) -> transforms.Compose:
-    """Internal helper to build a transform pipeline from a recipe dict."""
-    pipeline_steps: List[Callable] = []
-
-    if VisionTransformRecipeKeys.PIPELINE not in recipe:
-        _LOGGER.error("Recipe dict is invalid: missing 'pipeline' key.")
-        raise ValueError("Invalid recipe format.")
-
-    for step in recipe[VisionTransformRecipeKeys.PIPELINE]:
-        t_name = step.get(VisionTransformRecipeKeys.NAME)
-        t_kwargs = step.get(VisionTransformRecipeKeys.KWARGS, {})
-
-        if not t_name:
-            _LOGGER.error(f"Invalid transform step, missing 'name': {step}")
-            continue
-
-        transform_class: Any = None
-
-        # 1. Check standard torchvision transforms
-        if hasattr(transforms, t_name):
-            transform_class = getattr(transforms, t_name)
-        # 2. Check custom transforms
-        elif t_name in TRANSFORM_REGISTRY:
-            transform_class = TRANSFORM_REGISTRY[t_name]
-        # 3. Not found
-        else:
-            _LOGGER.error(f"Unknown transform '{t_name}' in recipe. Not found in torchvision.transforms or TRANSFORM_REGISTRY.")
-            raise ValueError(f"Unknown transform name: {t_name}")
-
-        # Instantiate the transform
-        try:
-            pipeline_steps.append(transform_class(**t_kwargs))
-        except Exception as e:
-            _LOGGER.error(f"Failed to instantiate transform '{t_name}' with kwargs {t_kwargs}: {e}")
-            raise
-
-    return transforms.Compose(pipeline_steps)
+#############################################################


 def create_offline_augmentations(

@@ -199,5 +163,122 @@ def create_offline_augmentations(
     _LOGGER.info(f"Offline augmentation complete. Saved {total_saved} new images.")


+def _build_transform_from_recipe(recipe: Dict[str, Any]) -> transforms.Compose:
+    """Internal helper to build a transform pipeline from a recipe dict."""
+    pipeline_steps: List[Callable] = []
+
+    if VisionTransformRecipeKeys.PIPELINE not in recipe:
+        _LOGGER.error("Recipe dict is invalid: missing 'pipeline' key.")
+        raise ValueError("Invalid recipe format.")
+
+    for step in recipe[VisionTransformRecipeKeys.PIPELINE]:
+        t_name = step.get(VisionTransformRecipeKeys.NAME)
+        t_kwargs = step.get(VisionTransformRecipeKeys.KWARGS, {})
+
+        if not t_name:
+            _LOGGER.error(f"Invalid transform step, missing 'name': {step}")
+            continue
+
+        transform_class: Any = None
+
+        # 1. Check standard torchvision transforms
+        if hasattr(transforms, t_name):
+            transform_class = getattr(transforms, t_name)
+        # 2. Check custom transforms
+        elif t_name in TRANSFORM_REGISTRY:
+            transform_class = TRANSFORM_REGISTRY[t_name]
+        # 3. Not found
+        else:
+            _LOGGER.error(f"Unknown transform '{t_name}' in recipe. Not found in torchvision.transforms or TRANSFORM_REGISTRY.")
+            raise ValueError(f"Unknown transform name: {t_name}")
+
+        # Instantiate the transform
+        try:
+            pipeline_steps.append(transform_class(**t_kwargs))
+        except Exception as e:
+            _LOGGER.error(f"Failed to instantiate transform '{t_name}' with kwargs {t_kwargs}: {e}")
+            raise
+
+    return transforms.Compose(pipeline_steps)
+
+
+def _save_recipe(recipe: Dict[str, Any], filepath: Path) -> None:
+    """
+    Saves a transform recipe dictionary to a JSON file.
+
+    Args:
+        recipe (Dict[str, Any]): The recipe dictionary to save.
+        filepath (str): The path to the output .json file.
+    """
+    final_filepath = filepath.with_suffix(".json")
+
+    try:
+        with open(final_filepath, 'w') as f:
+            json.dump(recipe, f, indent=4)
+        _LOGGER.info(f"Transform recipe saved as '{final_filepath.name}'.")
+    except Exception as e:
+        _LOGGER.error(f"Failed to save recipe to '{final_filepath}': {e}")
+        raise
+
+
+def _load_recipe_and_build_transform(filepath: Union[str,Path]) -> transforms.Compose:
+    """
+    Loads a transform recipe from a .json file and reconstructs the
+    torchvision.transforms.Compose pipeline.
+
+    Args:
+        filepath (str): Path to the saved transform recipe .json file.
+
+    Returns:
+        transforms.Compose: The reconstructed transformation pipeline.
+
+    Raises:
+        ValueError: If a transform name in the recipe is not found in
+            torchvision.transforms or the custom TRANSFORM_REGISTRY.
+    """
+    # validate filepath
+    final_filepath = make_fullpath(filepath, enforce="file")
+
+    try:
+        with open(final_filepath, 'r') as f:
+            recipe = json.load(f)
+    except Exception as e:
+        _LOGGER.error(f"Failed to load recipe from '{final_filepath}': {e}")
+        raise
+
+    pipeline_steps: List[Callable] = []
+
+    if VisionTransformRecipeKeys.PIPELINE not in recipe:
+        _LOGGER.error("Recipe file is invalid: missing 'pipeline' key.")
+        raise ValueError("Invalid recipe format.")
+
+    for step in recipe[VisionTransformRecipeKeys.PIPELINE]:
+        t_name = step[VisionTransformRecipeKeys.NAME]
+        t_kwargs = step[VisionTransformRecipeKeys.KWARGS]
+
+        transform_class: Any = None
+
+        # 1. Check standard torchvision transforms
+        if hasattr(transforms, t_name):
+            transform_class = getattr(transforms, t_name)
+        # 2. Check custom transforms
+        elif t_name in TRANSFORM_REGISTRY:
+            transform_class = TRANSFORM_REGISTRY[t_name]
+        # 3. Not found
+        else:
+            _LOGGER.error(f"Unknown transform '{t_name}' in recipe. Not found in torchvision.transforms or TRANSFORM_REGISTRY.")
+            raise ValueError(f"Unknown transform name: {t_name}")
+
+        # Instantiate the transform
+        try:
+            pipeline_steps.append(transform_class(**t_kwargs))
+        except Exception as e:
+            _LOGGER.error(f"Failed to instantiate transform '{t_name}' with kwargs {t_kwargs}: {e}")
+            raise
+
+    _LOGGER.info(f"Successfully loaded and built transform pipeline from '{final_filepath.name}'.")
+    return transforms.Compose(pipeline_steps)
+
+
 def info():
     _script_info(__all__)
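To illustrate the recipe format these helpers consume, here is a hedged sketch. It assumes the VisionTransformRecipeKeys constants resolve to the plain strings "pipeline", "name", and "kwargs" (the error messages above reference 'pipeline' and 'name'); the transform names map onto torchvision.transforms attributes or TRANSFORM_REGISTRY entries, and the kwargs values are illustrative only:

# Hypothetical transform recipe, written and read back with the standard json module.
import json

recipe = {
    "pipeline": [
        {"name": "Resize",    "kwargs": {"size": [224, 224]}},
        {"name": "ToTensor",  "kwargs": {}},
        {"name": "Normalize", "kwargs": {"mean": [0.485, 0.456, 0.406],
                                         "std": [0.229, 0.224, 0.225]}},
    ]
}

with open("inference_transform.json", "w") as f:  # placeholder filename
    json.dump(recipe, f, indent=4)

# Rebuilding the pipeline later (private helper used by the vision inference handler):
# from ml_tools.ML_vision_transformers import _load_recipe_and_build_transform
# transform = _load_recipe_and_build_transform("inference_transform.json")

A path to such a recipe file can then be passed as transform_source to DragonVisionInferenceHandler.__init__ or supplied later via .set_transform().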