dragon-ml-toolbox 14.3.1__py3-none-any.whl → 16.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of dragon-ml-toolbox has been flagged as potentially problematic by the registry.
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +10 -5
- dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +3 -3
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +309 -0
- ml_tools/ML_datasetmaster.py +220 -260
- ml_tools/ML_evaluation.py +317 -81
- ml_tools/ML_evaluation_multi.py +127 -36
- ml_tools/ML_inference.py +249 -207
- ml_tools/ML_models.py +13 -102
- ml_tools/ML_models_advanced.py +1 -1
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +215 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1247 -338
- ml_tools/ML_utilities.py +51 -2
- ml_tools/ML_vision_datasetmaster.py +262 -118
- ml_tools/ML_vision_evaluation.py +26 -6
- ml_tools/ML_vision_inference.py +117 -140
- ml_tools/ML_vision_models.py +15 -1
- ml_tools/ML_vision_transformers.py +233 -7
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/{keys.py → _keys.py} +45 -1
- ml_tools/_schema.py +1 -1
- ml_tools/ensemble_evaluation.py +54 -11
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/utilities.py +1 -2
- dragon_ml_toolbox-14.3.1.dist-info/RECORD +0 -48
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/_ML_vision_recipe.py +0 -88
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_vision_evaluation.py
CHANGED

@@ -18,7 +18,8 @@ from torchmetrics.detection import MeanAveragePrecision
 from .path_manager import make_fullpath
 from ._logger import _LOGGER
 from ._script_info import _script_info
-from .
+from ._keys import VisionKeys
+from .ML_configuration import SegmentationMetricsFormat


 __all__ = [
@@ -26,12 +27,15 @@ __all__ = [
     "object_detection_metrics"
 ]

+DPI_value = 250
+

 def segmentation_metrics(
     y_true: np.ndarray,
     y_pred: np.ndarray,
     save_dir: Union[str, Path],
-    class_names: Optional[List[str]] = None
+    class_names: Optional[List[str]] = None,
+    config: Optional[SegmentationMetricsFormat] = None # Add config object
 ):
     """
     Calculates and saves pixel-level metrics for segmentation tasks.
@@ -48,9 +52,18 @@ def segmentation_metrics(
         y_pred (np.ndarray): Predicted masks (e.g., shape [N, H, W]).
         save_dir (str | Path): Directory to save the metrics report and plots.
         class_names (List[str] | None): Names of the classes for the report.
+        config (SegmentationMetricsFormat, optional): Formatting configuration object.
     """
     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

+    # --- Parse Config or use defaults ---
+    if config is None:
+        config = SegmentationMetricsFormat()
+
+    # --- Set Matplotlib font size ---
+    original_rc_params = plt.rcParams.copy()
+    plt.rcParams.update({'font.size': config.font_size})
+
     # Get all unique class labels present in either true or pred
     labels = np.unique(np.concatenate((np.unique(y_true), np.unique(y_pred)))).astype(int)

@@ -110,7 +123,7 @@ def segmentation_metrics(
     report_lines.append(per_class_df.to_string(index=False, float_format="%.4f"))

     report_string = "\n".join(report_lines)
-    print(report_string)
+    # print(report_string) # <-- I removed the print(report_string)

     # Save text report
     save_filename = VisionKeys.SEGMENTATION_REPORT + ".txt"
@@ -120,11 +133,11 @@ def segmentation_metrics(

     # --- 3. Save Per-Class Metrics Heatmap ---
     try:
-        plt.figure(figsize=(max(8, len(labels) * 0.5), 6), dpi=
+        plt.figure(figsize=(max(8, len(labels) * 0.5), 6), dpi=DPI_value)
         sns.heatmap(
             per_class_df.set_index('Class').T,
             annot=True,
-            cmap=
+            cmap=config.heatmap_cmap, # Use config cmap
             fmt='.3f',
             linewidths=0.5
         )
@@ -149,7 +162,11 @@ def segmentation_metrics(
             confusion_matrix=cm,
             display_labels=display_names
         )
-        disp.plot(cmap=
+        disp.plot(cmap=config.cm_cmap, ax=ax_cm, xticks_rotation=45) # Use config cmap
+
+        # Manually update font size of cell texts
+        for text in disp.text_.flatten(): # type: ignore
+            text.set_fontsize(config.font_size)

         ax_cm.set_title("Pixel-Level Confusion Matrix")
         plt.tight_layout()
@@ -160,6 +177,9 @@ def segmentation_metrics(
         plt.close(fig_cm)
     except Exception as e:
         _LOGGER.error(f"Could not generate confusion matrix: {e}")
+
+    # --- Restore RC params ---
+    plt.rcParams.update(original_rc_params)


 def object_detection_metrics(
ml_tools/ML_vision_inference.py
CHANGED

@@ -8,32 +8,29 @@ from torchvision import transforms

 from ._script_info import _script_info
 from ._logger import _LOGGER
-from .
-from .
-from .
+from ._keys import PyTorchInferenceKeys, MLTaskKeys
+from .ML_vision_transformers import _load_recipe_and_build_transform
+from .ML_inference import _BaseInferenceHandler


 __all__ = [
-    "
+    "DragonVisionInferenceHandler"
 ]


-class
+class DragonVisionInferenceHandler(_BaseInferenceHandler):
     """
     Handles loading a PyTorch vision model's state dictionary and performing inference.

     This class is specifically for vision models, which typically expect
     4D Tensors (B, C, H, W) or Lists of Tensors as input.
-    It does NOT use a scaler, as preprocessing (e.g., normalization)
-    is assumed to be part of the input transform pipeline.
     """
     def __init__(self,
                  model: nn.Module,
                  state_dict: Union[str, Path],
-                 task: Literal["
+                 task: Literal["binary image classification", "multiclass image classification", "binary segmentation", "multiclass segmentation", "object detection"],
                  device: str = 'cpu',
-                 transform_source: Optional[Union[str, Path, Callable]] = None
-                 class_map: Optional[Dict[str, int]] = None):
+                 transform_source: Optional[Union[str, Path, Callable]] = None):
         """
         Initializes the vision inference handler.

@@ -46,19 +43,17 @@ class PyTorchVisionInferenceHandler:
             - A path to a .json recipe file (str or Path).
             - A pre-built transformation pipeline (Callable).
             - None, in which case .set_transform() must be called explicitly to set transformations.
-
+
+        Note: class_map (Dict[int, str]) will be loaded from the model file, to set or override it use `.set_class_map()`.
         """
-
-
+        super().__init__(model, state_dict, device, None)
+
         self._transform: Optional[Callable] = None
         self._is_transformed: bool = False
-        self._idx_to_class: Optional[Dict[int, str]] = None
-        if class_map is not None:
-            self.set_class_map(class_map)

-        if task not in [
-            _LOGGER.error(f"
-            raise ValueError(
+        if task not in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION, MLTaskKeys.OBJECT_DETECTION]:
+            _LOGGER.error(f"Unsupported task: '{task}'.")
+            raise ValueError()
         self.task = task

         self.expected_in_channels: int = 3 # Default to RGB
@@ -71,39 +66,6 @@ class PyTorchVisionInferenceHandler:
         if transform_source:
             self.set_transform(transform_source)
             self._is_transformed = True
-
-        model_p = make_fullpath(state_dict, enforce="file")
-
-        try:
-            # Load whatever is in the file
-            loaded_data = torch.load(model_p, map_location=self._device)
-
-            # Check if it's a new checkpoint dictionary or an old weights-only file
-            if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
-                # It's a new training checkpoint, extract the weights
-                self._model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
-            else:
-                # It's an old-style file (or just a state_dict), load it directly
-                self._model.load_state_dict(loaded_data)
-
-            _LOGGER.info(f"Model state loaded from '{model_p.name}'.")
-
-            self._model.to(self._device)
-            self._model.eval() # Set the model to evaluation mode
-        except Exception as e:
-            _LOGGER.error(f"Failed to load model state from '{model_p}': {e}")
-            raise
-
-    def _validate_device(self, device: str) -> torch.device:
-        """Validates the selected device and returns a torch.device object."""
-        device_lower = device.lower()
-        if "cuda" in device_lower and not torch.cuda.is_available():
-            _LOGGER.warning("CUDA not available, switching to CPU.")
-            device_lower = "cpu"
-        elif device_lower == "mps" and not torch.backends.mps.is_available():
-            _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
-            device_lower = "cpu"
-        return torch.device(device_lower)

     def _preprocess_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Union[torch.Tensor, List[torch.Tensor]]:
         """
@@ -111,23 +73,23 @@ class PyTorchVisionInferenceHandler:
         - For Classification/Segmentation: Expects 4D Tensor (B, C, H, W).
         - For Object Detection: Expects List[Tensor(C, H, W)].
         """
-        if self.task ==
+        if self.task == MLTaskKeys.OBJECT_DETECTION:
             if not isinstance(inputs, list) or not all(isinstance(t, torch.Tensor) for t in inputs):
                 _LOGGER.error("Input for object_detection must be a List[torch.Tensor].")
                 raise ValueError("Invalid input type for object detection.")
             # Move each tensor in the list to the device
-            return [t.float().to(self.
+            return [t.float().to(self.device) for t in inputs]

         else: # Classification or Segmentation
             if not isinstance(inputs, torch.Tensor):
                 _LOGGER.error(f"Input for {self.task} must be a torch.Tensor.")
                 raise ValueError(f"Invalid input type for {self.task}.")

-            if inputs.ndim != 4:
-                _LOGGER.error(f"Input tensor for {self.task} must be 4D (B, C, H, W). Got {inputs.ndim}D.")
+            if inputs.ndim != 4: # type: ignore
+                _LOGGER.error(f"Input tensor for {self.task} must be 4D (B, C, H, W). Got {inputs.ndim}D.") # type: ignore
                 raise ValueError("Input tensor must be 4D.")

-            return inputs.float().to(self.
+            return inputs.float().to(self.device)

     def set_transform(self, transform_source: Union[str, Path, Callable]):
         """
@@ -144,8 +106,8 @@ class PyTorchVisionInferenceHandler:
         if isinstance(transform_source, (str, Path)):
             _LOGGER.info(f"Loading transform from recipe file: '{transform_source}'")
             try:
-                # Use the
-                self._transform =
+                # Use the loader function
+                self._transform = _load_recipe_and_build_transform(transform_source)
             except Exception as e:
                 _LOGGER.error(f"Failed to load transform from recipe '{transform_source}': {e}")
                 raise
@@ -155,31 +117,15 @@ class PyTorchVisionInferenceHandler:
         else:
             _LOGGER.error(f"Invalid transform_source type: {type(transform_source)}. Must be str, Path, or Callable.")
             raise TypeError("transform_source must be a file path or a Callable.")
-
-    def set_class_map(self, class_map: Dict[str, int]):
-        """
-        Sets the class name mapping to translate predicted integer labels
-        back into string names.
-
-        Args:
-            class_map (Dict[str, int]): The class_to_idx dictionary
-                (e.g., {'cat': 0, 'dog': 1}) from the VisionDatasetMaker.
-        """
-        if self._idx_to_class is not None:
-            _LOGGER.warning("Class to index mapping was previously given. Setting new mapping...")
-        # Invert the dictionary for fast lookup
-        self._idx_to_class = {v: k for k, v in class_map.items()}
-        _LOGGER.info("Class map set for label-to-name translation.")

     def predict_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, Any]:
         """
         Core batch prediction method for vision models.
-        All preprocessing (resizing, normalization) should be done *before*
-        calling this method.
+        All preprocessing (resizing, normalization) should be done *before* calling this method.

         Args:
             inputs (torch.Tensor | List[torch.Tensor]):
-                - For
+                - For binary/multiclass image classification or binary/multiclass image segmentation tasks,
                   a 4D torch.Tensor (B, C, H, W).
                 - For 'object_detection', a List of 3D torch.Tensors
                   [(C, H, W), ...], each with its own size.
@@ -194,45 +140,55 @@ class PyTorchVisionInferenceHandler:
         processed_inputs = self._preprocess_batch(inputs)

         with torch.no_grad():
-
-
-
-
-
-            # 2. Post-process
+            # get outputs
+            output = self.model(processed_inputs)
+            if self.task == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION:
+                # process
                 probs = torch.softmax(output, dim=1)
                 labels = torch.argmax(probs, dim=1)
                 return {
                     PyTorchInferenceKeys.LABELS: labels, # (B,)
                     PyTorchInferenceKeys.PROBABILITIES: probs # (B, num_classes)
                 }
-
-            elif self.task ==
-                #
-                # 1
-                output
+
+            elif self.task == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION:
+                # Assumes model output is [N, 1] (a single logit)
+                # Squeeze output from [N, 1] to [N] if necessary
+                if output.ndim == 2 and output.shape[1] == 1:
+                    output = output.squeeze(1)
+
+                probs = torch.sigmoid(output) # Probability of positive class
+                labels = (probs >= self._classification_threshold).int()
+                return {
+                    PyTorchInferenceKeys.LABELS: labels,
+                    PyTorchInferenceKeys.PROBABILITIES: probs
+                }
+
+            elif self.task == MLTaskKeys.BINARY_SEGMENTATION:
+                # Assumes model output is [N, 1, H, W] (logits for positive class)
+                probs = torch.sigmoid(output) # Shape [N, 1, H, W]
+                labels = (probs >= self._classification_threshold).int() # Shape [N, 1, H, W]
+                return {
+                    PyTorchInferenceKeys.LABELS: labels,
+                    PyTorchInferenceKeys.PROBABILITIES: probs
+                }

-
-
-
+            elif self.task == MLTaskKeys.MULTICLASS_SEGMENTATION:
+                # output shape [N, C, H, W]
+                probs = torch.softmax(output, dim=1)
+                labels = torch.argmax(probs, dim=1) # shape [N, H, W]
                 return {
-                    PyTorchInferenceKeys.LABELS: labels, # (
-                    PyTorchInferenceKeys.PROBABILITIES: probs # (
+                    PyTorchInferenceKeys.LABELS: labels, # (N, H, W)
+                    PyTorchInferenceKeys.PROBABILITIES: probs # (N, num_classes, H, W)
                 }

-            elif self.task ==
-                # --- Object Detection ---
-                # 1. Predict (model is in eval mode, expects only images)
-                # Output is List[Dict[str, Tensor('boxes', 'labels', 'scores')]]
-                predictions = self._model(processed_inputs)
-
-                # 2. Post-process: Wrap in our standard key
+            elif self.task == MLTaskKeys.OBJECT_DETECTION:
                 return {
-                    PyTorchInferenceKeys.PREDICTIONS:
+                    PyTorchInferenceKeys.PREDICTIONS: output
                 }

             else:
-                # This should be unreachable due to
+                # This should be unreachable due to validation
                 raise ValueError(f"Unknown task: {self.task}")

     def predict(self, single_input: torch.Tensor) -> Dict[str, Any]:
@@ -248,24 +204,24 @@ class PyTorchVisionInferenceHandler:
         Returns:
             A dictionary containing the output tensors for a single sample.
             - Classification: {labels, probabilities} (label is 0-dim)
-            - Segmentation: {labels, probabilities} (label is 2D mask)
+            - Segmentation: {labels, probabilities} (label is a 2D (multiclass) or 3D (binary) mask)
             - Object Detection: {boxes, labels, scores} (single dict)
         """
         if not isinstance(single_input, torch.Tensor) or single_input.ndim != 3:
             _LOGGER.error(f"Input for predict() must be a 3D tensor (C, H, W). Got {single_input.ndim}D.")
-            raise ValueError(
+            raise ValueError()

         # --- 1. Batch the input based on task ---
-        if self.task ==
+        if self.task == MLTaskKeys.OBJECT_DETECTION:
             batched_input = [single_input] # List of one tensor
         else:
-            batched_input = single_input.unsqueeze(0)
+            batched_input = single_input.unsqueeze(0)

         # --- 2. Call batch prediction ---
         batch_results = self.predict_batch(batched_input)

         # --- 3. Un-batch the results ---
-        if self.task ==
+        if self.task == MLTaskKeys.OBJECT_DETECTION:
             # batch_results['predictions'] is a List[Dict]. We want the first (and only) Dict.
             return batch_results[PyTorchInferenceKeys.PREDICTIONS][0]
         else:
@@ -283,12 +239,12 @@ class PyTorchVisionInferenceHandler:
         Returns:
             Dict: A dictionary containing the outputs as NumPy arrays.
             - Obj. Detection: {predictions: List[Dict[str, np.ndarray]]}
-            - Classification: {labels:
+            - Classification: {labels: np.ndarray, label_names: List[str], probabilities: np.ndarray}
             - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
         """
         tensor_results = self.predict_batch(inputs)

-        if self.task ==
+        if self.task == MLTaskKeys.OBJECT_DETECTION:
             # Output is List[Dict[str, Tensor]]
             # Convert each tensor inside each dict to numpy
             numpy_results = []
@@ -304,13 +260,19 @@ class PyTorchVisionInferenceHandler:
                 ]
                 numpy_results.append(np_dict)
             return {PyTorchInferenceKeys.PREDICTIONS: numpy_results}
+
         else:
-            # Output is Dict[str, Tensor]
+            # Output is Dict[str, Tensor] (for Classification or Segmentation)
             numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}

             # Add string names for classification if map exists
-
-
+            is_image_classification = self.task in [
+                MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
+                MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
+            ]
+
+            if is_image_classification and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
+                int_labels = numpy_results[PyTorchInferenceKeys.LABELS] # This is a (B,) array
                 numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [
                     self._idx_to_class.get(label_id, "Unknown")
                     for label_id in int_labels
@@ -324,13 +286,13 @@ class PyTorchVisionInferenceHandler:

         Returns:
             Dict: A dictionary containing the outputs as NumPy arrays/scalars.
-            - Obj. Detection: {boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray}
+            - Obj. Detection: {boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray, label_names: List[str]}
             - Classification: {labels: int, label_names: str, probabilities: np.ndarray}
             - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
         """
         tensor_results = self.predict(single_input)

-        if self.task ==
+        if self.task == MLTaskKeys.OBJECT_DETECTION:
             # Output is Dict[str, Tensor]
             # Convert each tensor to numpy
             numpy_results = {
@@ -348,7 +310,7 @@ class PyTorchVisionInferenceHandler:

             return numpy_results

-        elif self.task
+        elif self.task in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
             # Output is Dict[str, Tensor(0-dim) or Tensor(1-dim)]
             int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
             label_name = "Unknown"
@@ -360,50 +322,32 @@ class PyTorchVisionInferenceHandler:
                 PyTorchInferenceKeys.LABEL_NAMES: label_name,
                 PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
             }
-        else: # image_segmentation
+        else: # image_segmentation (binary or multiclass)
             # Output is Dict[str, Tensor(2D) or Tensor(3D)]
             return {
                 PyTorchInferenceKeys.LABELS: tensor_results[PyTorchInferenceKeys.LABELS].cpu().numpy(),
                 PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
             }

-    def
+    def predict_from_pil(self, image: Image.Image) -> Dict[str, Any]:
         """
-
+        Applies the stored transform to a single PIL image and returns the prediction.

         Args:
-
+            image (PIL.Image.Image): The input PIL image.

         Returns:
             Dict: A dictionary containing the prediction results. See `predict_numpy()` for task-specific output structures.
         """
         if self._transform is None:
-            _LOGGER.error("Cannot predict from
+            _LOGGER.error("Cannot predict from PIL image: No transform has been set. Call .set_transform() or provide transform_source in __init__.")
             raise RuntimeError("Inference transform is not set.")
-
-        try:
-            # --- Use expected_in_channels to set PIL mode ---
-            pil_mode: str
-            if self.expected_in_channels == 1:
-                pil_mode = "L" # Grayscale
-            elif self.expected_in_channels == 4:
-                pil_mode = "RGBA" # RGB + Alpha
-            else:
-                if self.expected_in_channels != 3: # 2, 5+ channels not supported by PIL convert
-                    _LOGGER.warning(f"Model expects {self.expected_in_channels} channels. PIL conversion is limited, defaulting to 3 channels (RGB). The transformations must convert it to the desired channel dimensions.")
-                # Default to RGB. If 2-channels are needed, the transform recipe *must* be responsible for handling the conversion from a 3-channel PIL image.
-                pil_mode = "RGB"
-
-            image = Image.open(image_path).convert(pil_mode)
-        except Exception as e:
-            _LOGGER.error(f"Failed to load and convert image from '{image_path}': {e}")
-            raise

         # Apply the transformation pipeline (e.g., resize, crop, ToTensor, normalize)
         try:
             transformed_image = self._transform(image)
         except Exception as e:
-            _LOGGER.error(f"Error applying transform to image: {e}")
+            _LOGGER.error(f"Error applying transform to PIL image: {e}")
             raise

         # --- Validation ---
@@ -413,7 +357,7 @@ class PyTorchVisionInferenceHandler:

         if transformed_image.ndim != 3:
             _LOGGER.warning(f"Expected transform to output a 3D (C, H, W) tensor, but got {transformed_image.ndim}D. Attempting to proceed.")
-            # .predict() which expects a 3D tensor
+            # .predict_numpy() -> .predict() which expects a 3D tensor
             if transformed_image.ndim == 4 and transformed_image.shape[0] == 1:
                 transformed_image = transformed_image.squeeze(0) # Fix if user's transform adds a batch dim
                 _LOGGER.warning("Removed an extra batch dimension.")
@@ -423,6 +367,39 @@ class PyTorchVisionInferenceHandler:
         # Use the existing single-item predict method
         return self.predict_numpy(transformed_image)

+    def predict_from_file(self, image_path: Union[str, Path]) -> Dict[str, Any]:
+        """
+        Loads a single image from a file, applies the stored transform, and returns the prediction.
+
+        This is a convenience wrapper that loads the image and calls `predict_from_pil()`.
+
+        Args:
+            image_path (str | Path): The file path to the input image.
+
+        Returns:
+            Dict: A dictionary containing the prediction results. See `predict_numpy()` for task-specific output structures.
+        """
+        try:
+            # --- Use expected_in_channels to set PIL mode ---
+            pil_mode: str
+            if self.expected_in_channels == 1:
+                pil_mode = "L" # Grayscale
+            elif self.expected_in_channels == 4:
+                pil_mode = "RGBA" # RGB + Alpha
+            else:
+                if self.expected_in_channels != 3: # 2, 5+ channels not supported by PIL convert
+                    _LOGGER.warning(f"Model expects {self.expected_in_channels} channels. PIL conversion is limited, defaulting to 3 channels (RGB). The transformations must convert it to the desired channel dimensions.")
+                # Default to RGB. If 2-channels are needed, the transform recipe *must* be responsible for handling the conversion from a 3-channel PIL image.
+                pil_mode = "RGB"
+
+            image = Image.open(image_path).convert(pil_mode)
+        except Exception as e:
+            _LOGGER.error(f"Failed to load and convert image from '{image_path}': {e}")
+            raise
+
+        # Call the PIL-based prediction method
+        return self.predict_from_pil(image)
+

 def info():
     _script_info(__all__)
ml_tools/ML_vision_models.py
CHANGED

@@ -47,12 +47,17 @@ class _BaseVisionWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.model_name = model_name
+        self._pretrained_default_transforms = None

         # --- 2. Instantiate the base model ---
         if init_with_pretrained:
             weights_enum = getattr(vision_models, weights_enum_name, None) if weights_enum_name else None
             weights = weights_enum.IMAGENET1K_V1 if weights_enum else None

+            # Save transformations for pretrained models
+            if weights:
+                self._pretrained_default_transforms = weights.transforms()
+
             if weights is None and init_with_pretrained:
                 _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
                 self.model = getattr(vision_models, model_name)(pretrained=True)
@@ -331,6 +336,7 @@ class _BaseSegmentationWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.model_name = model_name
+        self._pretrained_default_transforms = None

         # --- 2. Instantiate the base model ---
         model_kwargs = {
@@ -343,6 +349,10 @@ class _BaseSegmentationWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
             weights_enum = getattr(vision_models.segmentation, weights_enum_name, None) if weights_enum_name else None
             weights = weights_enum.DEFAULT if weights_enum else None

+            # save pretrained model transformations
+            if weights:
+                self._pretrained_default_transforms = weights.transforms()
+
             if weights is None:
                 _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
                 # Legacy models used 'pretrained=True' and num_classes was separate
@@ -520,7 +530,7 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
     This wrapper allows for customizing the model backbone, input channels,
     and the number of output classes for transfer learning.

-    NOTE:
+    NOTE: Use an Object Detection compatible trainer.
     """
     def __init__(self,
                  num_classes: int,
@@ -550,6 +560,7 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.model_name = model_name
+        self._pretrained_default_transforms = None

         # --- 2. Instantiate the base model ---
         model_constructor = getattr(detection_models, model_name)
@@ -560,6 +571,9 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):

         weights_enum = getattr(detection_models, weights_enum_name, None) if weights_enum_name else None
         weights = weights_enum.DEFAULT if weights_enum and init_with_pretrained else None
+
+        if weights:
+            self._pretrained_default_transforms = weights.transforms()

         self.model = model_constructor(weights=weights, weights_backbone=weights)

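
All three wrappers now cache the preprocessing preset of the pretrained weights in `_pretrained_default_transforms`. This relies on torchvision's weights enums, whose `.transforms()` method returns the resize/crop/normalize pipeline matching the weights. A short illustration with a stock torchvision enum follows; reading the otherwise private `_pretrained_default_transforms` attribute afterwards is an assumption about how downstream code might consume it.

# Illustration of the torchvision API the wrappers cache; ResNet18_Weights is a
# stock torchvision enum, not part of dragon-ml-toolbox.
from torchvision.models import ResNet18_Weights

weights = ResNet18_Weights.IMAGENET1K_V1
preprocess = weights.transforms()  # callable preset: resize, center-crop, ToTensor, normalize
print(preprocess)

# In the diff above, the same preset is stored on the wrapper as
# self._pretrained_default_transforms when pretrained weights are found.
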