dragon-ml-toolbox 13.1.0__py3-none-any.whl → 14.3.0__py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
Potentially problematic release: this version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-13.1.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/METADATA +11 -2
- dragon_ml_toolbox-14.3.0.dist-info/RECORD +48 -0
- {dragon_ml_toolbox-13.1.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_datasetmaster.py +63 -205
- ml_tools/ML_evaluation.py +23 -15
- ml_tools/ML_evaluation_multi.py +5 -6
- ml_tools/ML_inference.py +0 -1
- ml_tools/ML_models.py +22 -6
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_trainer.py +463 -20
- ml_tools/ML_utilities.py +302 -4
- ml_tools/ML_vision_datasetmaster.py +1352 -0
- ml_tools/ML_vision_evaluation.py +260 -0
- ml_tools/ML_vision_inference.py +428 -0
- ml_tools/ML_vision_models.py +627 -0
- ml_tools/ML_vision_transformers.py +58 -0
- ml_tools/_ML_vision_recipe.py +88 -0
- ml_tools/__init__.py +1 -0
- ml_tools/_schema.py +79 -2
- ml_tools/custom_logger.py +37 -14
- ml_tools/data_exploration.py +502 -93
- ml_tools/keys.py +42 -1
- ml_tools/math_utilities.py +1 -1
- ml_tools/serde.py +77 -15
- ml_tools/utilities.py +192 -3
- dragon_ml_toolbox-13.1.0.dist-info/RECORD +0 -41
- {dragon_ml_toolbox-13.1.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.1.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.1.0.dist-info → dragon_ml_toolbox-14.3.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_vision_evaluation.py

@@ -0,0 +1,260 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    jaccard_score,
    confusion_matrix,
    ConfusionMatrixDisplay
)
from pathlib import Path
from typing import Union, Optional, List, Dict
import json
from torchmetrics.detection import MeanAveragePrecision

from .path_manager import make_fullpath
from ._logger import _LOGGER
from ._script_info import _script_info
from .keys import VisionKeys


__all__ = [
    "segmentation_metrics",
    "object_detection_metrics"
]


def segmentation_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    save_dir: Union[str, Path],
    class_names: Optional[List[str]] = None
):
    """
    Calculates and saves pixel-level metrics for segmentation tasks.

    Metrics include Pixel Accuracy, Dice (F1-score), and IoU (Jaccard).
    It calculates 'micro', 'macro', and 'weighted' averages and saves
    a pixel-level confusion matrix and a metrics heatmap.

    Note: This function expects integer-based masks (e.g., shape [N, H, W] or [H, W]),
    not one-hot encoded masks.

    Args:
        y_true (np.ndarray): Ground truth masks (e.g., shape [N, H, W]).
        y_pred (np.ndarray): Predicted masks (e.g., shape [N, H, W]).
        save_dir (str | Path): Directory to save the metrics report and plots.
        class_names (List[str] | None): Names of the classes for the report.
    """
    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

    # Get all unique class labels present in either true or pred
    labels = np.unique(np.concatenate((np.unique(y_true), np.unique(y_pred)))).astype(int)

    # --- Setup Class Names ---
    display_names = []
    if class_names is None:
        display_names = [f"Class {i}" for i in labels]
    else:
        if len(class_names) != len(labels):
            _LOGGER.warning(f"Number of class_names ({len(class_names)}) does not match number of unique labels ({len(labels)}). Using default names.")
            display_names = [f"Class {i}" for i in labels]
        else:
            display_names = class_names

    # Flatten masks for sklearn metrics
    y_true_flat = y_true.ravel()
    y_pred_flat = y_pred.ravel()

    _LOGGER.info("--- Calculating Segmentation Metrics ---")

    # --- 1. Calculate Metrics ---
    pix_acc = accuracy_score(y_true_flat, y_pred_flat)

    # Calculate all average types
    dice_micro = f1_score(y_true_flat, y_pred_flat, average='micro', labels=labels)
    iou_micro = jaccard_score(y_true_flat, y_pred_flat, average='micro', labels=labels)

    dice_macro = f1_score(y_true_flat, y_pred_flat, average='macro', labels=labels, zero_division=0)
    iou_macro = jaccard_score(y_true_flat, y_pred_flat, average='macro', labels=labels, zero_division=0)

    dice_weighted = f1_score(y_true_flat, y_pred_flat, average='weighted', labels=labels, zero_division=0)
    iou_weighted = jaccard_score(y_true_flat, y_pred_flat, average='weighted', labels=labels, zero_division=0)

    # Per-class metrics
    dice_per_class = f1_score(y_true_flat, y_pred_flat, average=None, labels=labels, zero_division=0)
    iou_per_class = jaccard_score(y_true_flat, y_pred_flat, average=None, labels=labels, zero_division=0)

    # --- 2. Create and Save Report ---
    report_lines = [
        "--- Segmentation Report ---",
        f"\nOverall Pixel Accuracy: {pix_acc:.4f}\n",
        "--- Averaged Metrics ---",
        f"{'Average':<10} | {'Dice (F1)':<12} | {'IoU (Jaccard)':<12}",
        "-"*41,
        f"{'Micro':<10} | {dice_micro:<12.4f} | {iou_micro:<12.4f}",
        f"{'Macro':<10} | {dice_macro:<12.4f} | {iou_macro:<12.4f}",
        f"{'Weighted':<10} | {dice_weighted:<12.4f} | {iou_weighted:<12.4f}",
        "\n--- Per-Class Metrics ---",
    ]

    per_class_data = {
        'Class': display_names,
        'Dice': dice_per_class,
        'IoU': iou_per_class
    }
    per_class_df = pd.DataFrame(per_class_data)
    report_lines.append(per_class_df.to_string(index=False, float_format="%.4f"))

    report_string = "\n".join(report_lines)
    print(report_string)

    # Save text report
    save_filename = VisionKeys.SEGMENTATION_REPORT + ".txt"
    report_path = save_dir_path / save_filename
    report_path.write_text(report_string, encoding="utf-8")
    _LOGGER.info(f"📝 Segmentation report saved as '{report_path.name}'")

    # --- 3. Save Per-Class Metrics Heatmap ---
    try:
        plt.figure(figsize=(max(8, len(labels) * 0.5), 6), dpi=100)
        sns.heatmap(
            per_class_df.set_index('Class').T,
            annot=True,
            cmap='viridis',
            fmt='.3f',
            linewidths=0.5
        )
        plt.title("Per-Class Segmentation Metrics")
        plt.tight_layout()
        heatmap_filename = VisionKeys.SEGMENTATION_HEATMAP + ".svg"
        heatmap_path = save_dir_path / heatmap_filename
        plt.savefig(heatmap_path)
        _LOGGER.info(f"📊 Metrics heatmap saved as '{heatmap_path.name}'")
        plt.close()
    except Exception as e:
        _LOGGER.error(f"Could not generate segmentation metrics heatmap: {e}")

    # --- 4. Save Pixel-level Confusion Matrix ---
    try:
        # Calculate CM
        cm = confusion_matrix(y_true_flat, y_pred_flat, labels=labels)

        # Plot
        fig_cm, ax_cm = plt.subplots(figsize=(max(8, len(labels) * 0.8), max(8, len(labels) * 0.8)), dpi=100)
        disp = ConfusionMatrixDisplay(
            confusion_matrix=cm,
            display_labels=display_names
        )
        disp.plot(cmap='Blues', ax=ax_cm, xticks_rotation=45)

        ax_cm.set_title("Pixel-Level Confusion Matrix")
        plt.tight_layout()
        segmentation_cm_filename = VisionKeys.SEGMENTATION_CONFUSION_MATRIX + ".svg"
        cm_path = save_dir_path / segmentation_cm_filename
        plt.savefig(cm_path)
        _LOGGER.info(f"❇️ Pixel-level confusion matrix saved as '{cm_path.name}'")
        plt.close(fig_cm)
    except Exception as e:
        _LOGGER.error(f"Could not generate confusion matrix: {e}")


def object_detection_metrics(
    preds: List[Dict[str, torch.Tensor]],
    targets: List[Dict[str, torch.Tensor]],
    save_dir: Union[str, Path],
    class_names: Optional[List[str]] = None,
    print_output: bool = False
):
    """
    Calculates and saves object detection metrics (mAP) using torchmetrics.

    This function expects predictions and targets in the standard
    torchvision format (list of dictionaries).

    Args:
        preds (List[Dict[str, torch.Tensor]]): A list of predictions.
            Each dict must contain:
            - 'boxes': [N, 4] (xmin, ymin, xmax, ymax)
            - 'scores': [N]
            - 'labels': [N]
        targets (List[Dict[str, torch.Tensor]]): A list of ground truths.
            Each dict must contain:
            - 'boxes': [M, 4]
            - 'labels': [M]
        save_dir (str | Path): Directory to save the metrics report (as JSON).
        class_names (List[str] | None): A list of class names, including 'background'
            at index 0. Used to label per-class metrics in the report.
        print_output (bool): If True, prints the JSON report to the console.
    """
    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

    _LOGGER.info("--- Calculating Object Detection Metrics (mAP) ---")

    try:
        # Initialize the metric with standard COCO settings
        metric = MeanAveragePrecision(box_format='xyxy')

        # Move preds and targets to the same device (e.g., CPU for metric calculation)
        # This avoids device mismatches if model was on GPU
        device = torch.device("cpu")
        preds_cpu = [{k: v.to(device) for k, v in p.items()} for p in preds]
        targets_cpu = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Update the metric
        metric.update(preds_cpu, targets_cpu)

        # Compute the final metrics
        results = metric.compute()

        # --- Handle class names for per-class metrics ---
        report_class_names = None
        if class_names:
            if class_names[0].lower() in ['background', "bg"]:
                report_class_names = class_names[1:]  # Skip background (class 0)
            else:
                _LOGGER.warning("class_names provided to object_detection_metrics, but 'background' was not class 0. Using all provided names.")
                report_class_names = class_names

        # Convert all torch tensors in results to floats/lists for JSON serialization
        serializable_results = {}
        for key, value in results.items():
            if isinstance(value, torch.Tensor):
                if value.numel() == 1:
                    serializable_results[key] = value.item()
                # Check if it's a 1D tensor, we have class names, and it's a known per-class key
                elif value.ndim == 1 and report_class_names and key in ('map_per_class', 'mar_100_per_class', 'mar_1_per_class', 'mar_10_per_class'):
                    per_class_list = value.cpu().numpy().tolist()
                    # Map names to values
                    if len(per_class_list) == len(report_class_names):
                        serializable_results[key] = {name: val for name, val in zip(report_class_names, per_class_list)}
                    else:
                        _LOGGER.warning(f"Length mismatch for '{key}': {len(per_class_list)} values vs {len(report_class_names)} class names. Saving as raw list.")
                        serializable_results[key] = per_class_list
                else:
                    serializable_results[key] = value.cpu().numpy().tolist()
            else:
                serializable_results[key] = value

        # Pretty print to console
        if print_output:
            print(json.dumps(serializable_results, indent=4))

        # Save JSON report
        detection_report_filename = VisionKeys.OBJECT_DETECTION_REPORT + ".json"
        report_path = save_dir_path / detection_report_filename
        with open(report_path, 'w') as f:
            json.dump(serializable_results, f, indent=4)

        _LOGGER.info(f"📊 Object detection (mAP) report saved as '{report_path.name}'")

    except Exception as e:
        _LOGGER.error(f"Failed to compute mAP: {e}")
        raise


def info():
    _script_info(__all__)
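Usage sketch (editor's note, not part of the wheel): the snippet below shows how the two new evaluation helpers might be called. It assumes the module is importable as ml_tools.ML_vision_evaluation, that the output directories can be created by make_fullpath, and that masks and detections follow the shapes documented in the docstrings above; the data and paths are synthetic and purely illustrative.

# Illustrative example only; data, paths, and class names are placeholders.
import numpy as np
import torch
from ml_tools.ML_vision_evaluation import segmentation_metrics, object_detection_metrics

# Pixel-level metrics for a batch of 4 integer masks (3 classes, 64x64)
y_true = np.random.randint(0, 3, size=(4, 64, 64))
y_pred = np.random.randint(0, 3, size=(4, 64, 64))
segmentation_metrics(
    y_true, y_pred,
    save_dir="eval/segmentation",
    class_names=["background", "cat", "dog"],
)

# mAP for a single image in the standard torchvision detection format
preds = [{
    "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([1]),
}]
targets = [{
    "boxes": torch.tensor([[12.0, 8.0, 48.0, 52.0]]),
    "labels": torch.tensor([1]),
}]
object_detection_metrics(
    preds, targets,
    save_dir="eval/detection",
    class_names=["background", "cat", "dog"],
    print_output=True,
)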
ml_tools/ML_vision_inference.py

@@ -0,0 +1,428 @@
import torch
from torch import nn
import numpy as np  # numpy array return value
from pathlib import Path
from typing import Union, Literal, Dict, Any, List, Optional, Callable
from PIL import Image
from torchvision import transforms

from ._script_info import _script_info
from ._logger import _LOGGER
from .path_manager import make_fullpath
from .keys import PyTorchInferenceKeys, PyTorchCheckpointKeys
from ._ML_vision_recipe import load_recipe_and_build_transform


__all__ = [
    "PyTorchVisionInferenceHandler"
]


class PyTorchVisionInferenceHandler:
    """
    Handles loading a PyTorch vision model's state dictionary and performing inference.

    This class is specifically for vision models, which typically expect
    4D Tensors (B, C, H, W) or Lists of Tensors as input.
    It does NOT use a scaler, as preprocessing (e.g., normalization)
    is assumed to be part of the input transform pipeline.
    """
    def __init__(self,
                 model: nn.Module,
                 state_dict: Union[str, Path],
                 task: Literal["image_classification", "image_segmentation", "object_detection"],
                 device: str = 'cpu',
                 transform_source: Optional[Union[str, Path, Callable]] = None,
                 class_map: Optional[Dict[str, int]] = None):
        """
        Initializes the vision inference handler.

        Args:
            model (nn.Module): An instantiated PyTorch model from ML_vision_models.
            state_dict (str | Path): Path to the saved .pth model state_dict file.
            task (str): The type of vision task.
            device (str): The device to run inference on ('cpu', 'cuda', 'mps').
            transform_source (str | Path | Callable | None):
                - A path to a .json recipe file (str or Path).
                - A pre-built transformation pipeline (Callable).
                - None, in which case .set_transform() must be called explicitly to set transformations.
            class_map (Dict[str, int] | None): Class-to-index mapping used to translate predicted integer labels back into string names. (For image classification and object detection)
        """
        self._model = model
        self._device = self._validate_device(device)
        self._transform: Optional[Callable] = None
        self._is_transformed: bool = False
        self._idx_to_class: Optional[Dict[int, str]] = None
        if class_map is not None:
            self.set_class_map(class_map)

        if task not in ["image_classification", "image_segmentation", "object_detection"]:
            _LOGGER.error(f"`task` must be 'image_classification', 'image_segmentation', or 'object_detection'. Got '{task}'.")
            raise ValueError("Invalid task type.")
        self.task = task

        self.expected_in_channels: int = 3  # Default to RGB
        if hasattr(model, 'in_channels'):
            self.expected_in_channels = model.in_channels  # type: ignore
            _LOGGER.info(f"Model expects {self.expected_in_channels} input channels.")
        else:
            _LOGGER.warning("Could not determine 'in_channels' from model. Defaulting to 3 (RGB). Modify with '.expected_in_channels'.")

        if transform_source:
            self.set_transform(transform_source)
            self._is_transformed = True

        model_p = make_fullpath(state_dict, enforce="file")

        try:
            # Load whatever is in the file
            loaded_data = torch.load(model_p, map_location=self._device)

            # Check if it's a new checkpoint dictionary or an old weights-only file
            if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
                # It's a new training checkpoint, extract the weights
                self._model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
            else:
                # It's an old-style file (or just a state_dict), load it directly
                self._model.load_state_dict(loaded_data)

            _LOGGER.info(f"Model state loaded from '{model_p.name}'.")

            self._model.to(self._device)
            self._model.eval()  # Set the model to evaluation mode
        except Exception as e:
            _LOGGER.error(f"Failed to load model state from '{model_p}': {e}")
            raise

    def _validate_device(self, device: str) -> torch.device:
        """Validates the selected device and returns a torch.device object."""
        device_lower = device.lower()
        if "cuda" in device_lower and not torch.cuda.is_available():
            _LOGGER.warning("CUDA not available, switching to CPU.")
            device_lower = "cpu"
        elif device_lower == "mps" and not torch.backends.mps.is_available():
            _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
            device_lower = "cpu"
        return torch.device(device_lower)

    def _preprocess_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Validates input and moves it to the correct device.
        - For Classification/Segmentation: Expects 4D Tensor (B, C, H, W).
        - For Object Detection: Expects List[Tensor(C, H, W)].
        """
        if self.task == "object_detection":
            if not isinstance(inputs, list) or not all(isinstance(t, torch.Tensor) for t in inputs):
                _LOGGER.error("Input for object_detection must be a List[torch.Tensor].")
                raise ValueError("Invalid input type for object detection.")
            # Move each tensor in the list to the device
            return [t.float().to(self._device) for t in inputs]

        else:  # Classification or Segmentation
            if not isinstance(inputs, torch.Tensor):
                _LOGGER.error(f"Input for {self.task} must be a torch.Tensor.")
                raise ValueError(f"Invalid input type for {self.task}.")

            if inputs.ndim != 4:
                _LOGGER.error(f"Input tensor for {self.task} must be 4D (B, C, H, W). Got {inputs.ndim}D.")
                raise ValueError("Input tensor must be 4D.")

            return inputs.float().to(self._device)

    def set_transform(self, transform_source: Union[str, Path, Callable]):
        """
        Sets or updates the inference transformation pipeline from a recipe file or a direct Callable.

        Args:
            transform_source (str, Path, Callable):
                - A path to a .json recipe file (str or Path).
                - A pre-built transformation pipeline (Callable).
        """
        if self._is_transformed:
            _LOGGER.warning("Transformations were previously applied. Applying new transformations...")

        if isinstance(transform_source, (str, Path)):
            _LOGGER.info(f"Loading transform from recipe file: '{transform_source}'")
            try:
                # Use the new loader function
                self._transform = load_recipe_and_build_transform(transform_source)
            except Exception as e:
                _LOGGER.error(f"Failed to load transform from recipe '{transform_source}': {e}")
                raise
        elif isinstance(transform_source, Callable):
            _LOGGER.info("Inference transform has been set from a direct Callable.")
            self._transform = transform_source
        else:
            _LOGGER.error(f"Invalid transform_source type: {type(transform_source)}. Must be str, Path, or Callable.")
            raise TypeError("transform_source must be a file path or a Callable.")

    def set_class_map(self, class_map: Dict[str, int]):
        """
        Sets the class name mapping to translate predicted integer labels
        back into string names.

        Args:
            class_map (Dict[str, int]): The class_to_idx dictionary
                (e.g., {'cat': 0, 'dog': 1}) from the VisionDatasetMaker.
        """
        if self._idx_to_class is not None:
            _LOGGER.warning("Class to index mapping was previously given. Setting new mapping...")
        # Invert the dictionary for fast lookup
        self._idx_to_class = {v: k for k, v in class_map.items()}
        _LOGGER.info("Class map set for label-to-name translation.")

    def predict_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, Any]:
        """
        Core batch prediction method for vision models.
        All preprocessing (resizing, normalization) should be done *before*
        calling this method.

        Args:
            inputs (torch.Tensor | List[torch.Tensor]):
                - For 'image_classification' or 'image_segmentation',
                  a 4D torch.Tensor (B, C, H, W).
                - For 'object_detection', a List of 3D torch.Tensors
                  [(C, H, W), ...], each with its own size.

        Returns:
            A dictionary containing the output tensors.
            - Classification: {labels, probabilities}
            - Segmentation: {labels, probabilities} (labels is the mask)
            - Object Detection: {predictions} (List of dicts)
        """
        processed_inputs = self._preprocess_batch(inputs)

        with torch.no_grad():
            if self.task == "image_classification":
                # --- Image Classification ---
                # 1. Predict
                output = self._model(processed_inputs)  # (B, num_classes)

                # 2. Post-process
                probs = torch.softmax(output, dim=1)
                labels = torch.argmax(probs, dim=1)
                return {
                    PyTorchInferenceKeys.LABELS: labels,        # (B,)
                    PyTorchInferenceKeys.PROBABILITIES: probs   # (B, num_classes)
                }

            elif self.task == "image_segmentation":
                # --- Image Segmentation ---
                # 1. Predict
                output = self._model(processed_inputs)  # (B, num_classes, H, W)

                # 2. Post-process
                probs = torch.softmax(output, dim=1)   # Probs across class channel
                labels = torch.argmax(probs, dim=1)    # (B, H, W) segmentation map
                return {
                    PyTorchInferenceKeys.LABELS: labels,        # (B, H, W)
                    PyTorchInferenceKeys.PROBABILITIES: probs   # (B, num_classes, H, W)
                }

            elif self.task == "object_detection":
                # --- Object Detection ---
                # 1. Predict (model is in eval mode, expects only images)
                # Output is List[Dict[str, Tensor('boxes', 'labels', 'scores')]]
                predictions = self._model(processed_inputs)

                # 2. Post-process: Wrap in our standard key
                return {
                    PyTorchInferenceKeys.PREDICTIONS: predictions
                }

            else:
                # This should be unreachable due to __init__ check
                raise ValueError(f"Unknown task: {self.task}")

    def predict(self, single_input: torch.Tensor) -> Dict[str, Any]:
        """
        Core single-sample prediction method for vision models.
        All preprocessing (resizing, normalization) should be done *before*
        calling this method.

        Args:
            single_input (torch.Tensor):
                - A 3D torch.Tensor (C, H, W) for any task.

        Returns:
            A dictionary containing the output tensors for a single sample.
            - Classification: {labels, probabilities} (label is 0-dim)
            - Segmentation: {labels, probabilities} (label is 2D mask)
            - Object Detection: {boxes, labels, scores} (single dict)
        """
        if not isinstance(single_input, torch.Tensor) or single_input.ndim != 3:
            _LOGGER.error(f"Input for predict() must be a 3D tensor (C, H, W). Got {single_input.ndim}D.")
            raise ValueError("Input must be a 3D tensor.")

        # --- 1. Batch the input based on task ---
        if self.task == "object_detection":
            batched_input = [single_input]  # List of one tensor
        else:
            batched_input = single_input.unsqueeze(0)  # (1, C, H, W)

        # --- 2. Call batch prediction ---
        batch_results = self.predict_batch(batched_input)

        # --- 3. Un-batch the results ---
        if self.task == "object_detection":
            # batch_results['predictions'] is a List[Dict]. We want the first (and only) Dict.
            return batch_results[PyTorchInferenceKeys.PREDICTIONS][0]
        else:
            # 'labels' and 'probabilities' are tensors. Get the 0-th element.
            # (B, ...) -> (...)
            single_results = {key: value[0] for key, value in batch_results.items()}
            return single_results

    # --- NumPy Convenience Wrappers (on CPU) ---

    def predict_batch_numpy(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, Any]:
        """
        Convenience wrapper for predict_batch that returns NumPy arrays.
        Adds label names if a class map is set.

        Returns:
            Dict: A dictionary containing the outputs as NumPy arrays.
            - Obj. Detection: {predictions: List[Dict[str, np.ndarray]]}
            - Classification: {labels: np.ndarray, label_names: List[str], probabilities: np.ndarray}
            - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
        """
        tensor_results = self.predict_batch(inputs)

        if self.task == "object_detection":
            # Output is List[Dict[str, Tensor]]
            # Convert each tensor inside each dict to numpy
            numpy_results = []
            for pred_dict in tensor_results[PyTorchInferenceKeys.PREDICTIONS]:
                # Convert all tensors to numpy
                np_dict = {key: value.cpu().numpy() for key, value in pred_dict.items()}

                # Add string names if map exists
                if self._idx_to_class and PyTorchInferenceKeys.LABELS in np_dict:
                    np_dict[PyTorchInferenceKeys.LABEL_NAMES] = [
                        self._idx_to_class.get(label_id, "Unknown")
                        for label_id in np_dict[PyTorchInferenceKeys.LABELS]
                    ]
                numpy_results.append(np_dict)
            return {PyTorchInferenceKeys.PREDICTIONS: numpy_results}
        else:
            # Output is Dict[str, Tensor]
            numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}

            # Add string names for classification if map exists
            if self.task == "image_classification" and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
                int_labels = numpy_results[PyTorchInferenceKeys.LABELS]
                numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [
                    self._idx_to_class.get(label_id, "Unknown")
                    for label_id in int_labels
                ]

            return numpy_results

    def predict_numpy(self, single_input: torch.Tensor) -> Dict[str, Any]:
        """
        Convenience wrapper for predict that returns NumPy arrays/scalars.

        Returns:
            Dict: A dictionary containing the outputs as NumPy arrays/scalars.
            - Obj. Detection: {boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray}
            - Classification: {labels: int, label_names: str, probabilities: np.ndarray}
            - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
        """
        tensor_results = self.predict(single_input)

        if self.task == "object_detection":
            # Output is Dict[str, Tensor]
            # Convert each tensor to numpy
            numpy_results = {
                key: value.cpu().numpy() for key, value in tensor_results.items()
            }

            # Add string names if map exists
            if self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
                int_labels = numpy_results[PyTorchInferenceKeys.LABELS]

                numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [
                    self._idx_to_class.get(label_id, "Unknown")
                    for label_id in int_labels
                ]

            return numpy_results

        elif self.task == "image_classification":
            # Output is Dict[str, Tensor(0-dim) or Tensor(1-dim)]
            int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
            label_name = "Unknown"
            if self._idx_to_class:
                label_name = self._idx_to_class.get(int_label, "Unknown")

            return {
                PyTorchInferenceKeys.LABELS: int_label,
                PyTorchInferenceKeys.LABEL_NAMES: label_name,
                PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
            }
        else:  # image_segmentation
            # Output is Dict[str, Tensor(2D) or Tensor(3D)]
            return {
                PyTorchInferenceKeys.LABELS: tensor_results[PyTorchInferenceKeys.LABELS].cpu().numpy(),
                PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
            }

    def predict_from_file(self, image_path: Union[str, Path]) -> Dict[str, Any]:
        """
        Loads a single image from a file, applies the stored transform, and returns the prediction.

        Args:
            image_path (str | Path): The file path to the input image.

        Returns:
            Dict: A dictionary containing the prediction results. See `predict_numpy()` for task-specific output structures.
        """
        if self._transform is None:
            _LOGGER.error("Cannot predict from file: No transform has been set. Call .set_transform() or provide transform_source in __init__.")
            raise RuntimeError("Inference transform is not set.")

        try:
            # --- Use expected_in_channels to set PIL mode ---
            pil_mode: str
            if self.expected_in_channels == 1:
                pil_mode = "L"      # Grayscale
            elif self.expected_in_channels == 4:
                pil_mode = "RGBA"   # RGB + Alpha
            else:
                if self.expected_in_channels != 3:  # 2, 5+ channels not supported by PIL convert
                    _LOGGER.warning(f"Model expects {self.expected_in_channels} channels. PIL conversion is limited, defaulting to 3 channels (RGB). The transformations must convert it to the desired channel dimensions.")
                # Default to RGB. If 2 channels are needed, the transform recipe *must* be responsible for handling the conversion from a 3-channel PIL image.
                pil_mode = "RGB"

            image = Image.open(image_path).convert(pil_mode)
        except Exception as e:
            _LOGGER.error(f"Failed to load and convert image from '{image_path}': {e}")
            raise

        # Apply the transformation pipeline (e.g., resize, crop, ToTensor, normalize)
        try:
            transformed_image = self._transform(image)
        except Exception as e:
            _LOGGER.error(f"Error applying transform to image: {e}")
            raise

        # --- Validation ---
        if not isinstance(transformed_image, torch.Tensor):
            _LOGGER.error("The provided transform did not return a torch.Tensor. Does it include transforms.ToTensor()?")
            raise ValueError("Transform pipeline must output a torch.Tensor.")

        if transformed_image.ndim != 3:
            _LOGGER.warning(f"Expected transform to output a 3D (C, H, W) tensor, but got {transformed_image.ndim}D. Attempting to proceed.")
            # .predict() expects a 3D tensor
            if transformed_image.ndim == 4 and transformed_image.shape[0] == 1:
                transformed_image = transformed_image.squeeze(0)  # Fix if user's transform adds a batch dim
                _LOGGER.warning("Removed an extra batch dimension.")
            else:
                raise ValueError(f"Transform must output a 3D (C, H, W) tensor, got {transformed_image.shape}.")

        # Use the existing single-item predict method
        return self.predict_numpy(transformed_image)


def info():
    _script_info(__all__)
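Usage sketch (editor's note, not part of the wheel): a minimal end-to-end example of the new inference handler, assuming a classification checkpoint saved by this toolbox. The torchvision resnet18 stand-in, the file paths, and the class map are placeholders, and the transform recipe path is assumed to point to a .json recipe as described in the docstrings; only the handler API itself comes from the module above.

# Illustrative example only; the model, paths, and class map are placeholders.
import torch
from torchvision.models import resnet18
from ml_tools.ML_vision_inference import PyTorchVisionInferenceHandler
from ml_tools.keys import PyTorchInferenceKeys

model = resnet18(num_classes=3)  # any nn.Module whose architecture matches the checkpoint

handler = PyTorchVisionInferenceHandler(
    model=model,
    state_dict="artifacts/model.pth",                     # placeholder checkpoint path
    task="image_classification",
    device="cuda",                                        # falls back to CPU if unavailable
    transform_source="artifacts/transform_recipe.json",   # placeholder .json recipe
    class_map={"cat": 0, "dog": 1, "bird": 2},
)

# Single image from disk: load -> transform -> predict -> NumPy/str outputs
result = handler.predict_from_file("images/example.jpg")
print(result[PyTorchInferenceKeys.LABELS], result[PyTorchInferenceKeys.LABEL_NAMES])

# Pre-transformed batch: 4D tensor (B, C, H, W)
batch = torch.rand(8, 3, 224, 224)
batch_out = handler.predict_batch_numpy(batch)
print(batch_out[PyTorchInferenceKeys.PROBABILITIES].shape)  # (8, 3)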