dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl
This diff shows the changes between publicly available package versions as they appear in their public registry, and is provided for informational purposes only.
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
- dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
- ml_tools/ETL_cleaning.py +20 -20
- ml_tools/ETL_engineering.py +23 -25
- ml_tools/GUI_tools.py +20 -20
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_callbacks.py +43 -26
- ml_tools/ML_configuration.py +788 -0
- ml_tools/ML_datasetmaster.py +303 -448
- ml_tools/ML_evaluation.py +351 -93
- ml_tools/ML_evaluation_multi.py +139 -42
- ml_tools/ML_inference.py +290 -209
- ml_tools/ML_models.py +33 -106
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_optimization.py +12 -12
- ml_tools/ML_scaler.py +11 -11
- ml_tools/ML_sequence_datasetmaster.py +341 -0
- ml_tools/ML_sequence_evaluation.py +219 -0
- ml_tools/ML_sequence_inference.py +391 -0
- ml_tools/ML_sequence_models.py +139 -0
- ml_tools/ML_trainer.py +1604 -179
- ml_tools/ML_utilities.py +351 -4
- ml_tools/ML_vision_datasetmaster.py +1540 -0
- ml_tools/ML_vision_evaluation.py +284 -0
- ml_tools/ML_vision_inference.py +405 -0
- ml_tools/ML_vision_models.py +641 -0
- ml_tools/ML_vision_transformers.py +284 -0
- ml_tools/PSO_optimization.py +6 -6
- ml_tools/SQL.py +4 -4
- ml_tools/_keys.py +171 -0
- ml_tools/_schema.py +1 -1
- ml_tools/custom_logger.py +37 -14
- ml_tools/data_exploration.py +502 -93
- ml_tools/ensemble_evaluation.py +54 -11
- ml_tools/ensemble_inference.py +7 -33
- ml_tools/ensemble_learning.py +1 -1
- ml_tools/math_utilities.py +1 -1
- ml_tools/optimization_tools.py +2 -2
- ml_tools/path_manager.py +5 -5
- ml_tools/serde.py +2 -2
- ml_tools/utilities.py +192 -4
- dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
- ml_tools/RNN_forecast.py +0 -56
- ml_tools/keys.py +0 -87
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_vision_evaluation.py (new file, 284 lines added)

@@ -0,0 +1,284 @@
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    jaccard_score,
    confusion_matrix,
    ConfusionMatrixDisplay
)
from pathlib import Path
from typing import Union, Optional, List, Dict
import json
from torchmetrics.detection import MeanAveragePrecision

from .path_manager import make_fullpath
from ._logger import _LOGGER
from ._script_info import _script_info
from ._keys import VisionKeys
from .ML_configuration import (BinarySegmentationMetricsFormat,
                               MultiClassSegmentationMetricsFormat,
                               _BaseSegmentationFormat)


__all__ = [
    "segmentation_metrics",
    "object_detection_metrics"
]

DPI_value = 250


def segmentation_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    save_dir: Union[str, Path],
    class_names: Optional[List[str]] = None,
    config: Optional[Union[BinarySegmentationMetricsFormat, MultiClassSegmentationMetricsFormat]] = None
):
    """
    Calculates and saves pixel-level metrics for segmentation tasks.

    Metrics include Pixel Accuracy, Dice (F1-score), and IoU (Jaccard).
    It calculates 'micro', 'macro', and 'weighted' averages and saves
    a pixel-level confusion matrix and a metrics heatmap.

    Note: This function expects integer-based masks (e.g., shape [N, H, W] or [H, W]),
    not one-hot encoded masks.

    Args:
        y_true (np.ndarray): Ground truth masks (e.g., shape [N, H, W]).
        y_pred (np.ndarray): Predicted masks (e.g., shape [N, H, W]).
        save_dir (str | Path): Directory to save the metrics report and plots.
        class_names (List[str] | None): Names of the classes for the report.
        config (object): Formatting configuration object.
    """
    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

    # --- Parse Config or use defaults ---
    if config is None:
        format_config = _BaseSegmentationFormat()
    else:
        format_config = config

    # --- Set Matplotlib font size ---
    original_rc_params = plt.rcParams.copy()
    plt.rcParams.update({'font.size': format_config.font_size})

    # Get all unique class labels present in either true or pred
    labels = np.unique(np.concatenate((np.unique(y_true), np.unique(y_pred)))).astype(int)

    # --- Setup Class Names ---
    display_names = []
    if class_names is None:
        display_names = [f"Class {i}" for i in labels]
    else:
        if len(class_names) != len(labels):
            _LOGGER.warning(f"Number of class_names ({len(class_names)}) does not match number of unique labels ({len(labels)}). Using default names.")
            display_names = [f"Class {i}" for i in labels]
        else:
            display_names = class_names

    # Flatten masks for sklearn metrics
    y_true_flat = y_true.ravel()
    y_pred_flat = y_pred.ravel()

    _LOGGER.info("--- Calculating Segmentation Metrics ---")

    # --- 1. Calculate Metrics ---
    pix_acc = accuracy_score(y_true_flat, y_pred_flat)

    # Calculate all average types
    dice_micro = f1_score(y_true_flat, y_pred_flat, average='micro', labels=labels)
    iou_micro = jaccard_score(y_true_flat, y_pred_flat, average='micro', labels=labels)

    dice_macro = f1_score(y_true_flat, y_pred_flat, average='macro', labels=labels, zero_division=0)
    iou_macro = jaccard_score(y_true_flat, y_pred_flat, average='macro', labels=labels, zero_division=0)

    dice_weighted = f1_score(y_true_flat, y_pred_flat, average='weighted', labels=labels, zero_division=0)
    iou_weighted = jaccard_score(y_true_flat, y_pred_flat, average='weighted', labels=labels, zero_division=0)

    # Per-class metrics
    dice_per_class = f1_score(y_true_flat, y_pred_flat, average=None, labels=labels, zero_division=0)
    iou_per_class = jaccard_score(y_true_flat, y_pred_flat, average=None, labels=labels, zero_division=0)

    # --- 2. Create and Save Report ---
    report_lines = [
        "--- Segmentation Report ---",
        f"\nOverall Pixel Accuracy: {pix_acc:.4f}\n",
        "--- Averaged Metrics ---",
        f"{'Average':<10} | {'Dice (F1)':<12} | {'IoU (Jaccard)':<12}",
        "-"*41,
        f"{'Micro':<10} | {dice_micro:<12.4f} | {iou_micro:<12.4f}",
        f"{'Macro':<10} | {dice_macro:<12.4f} | {iou_macro:<12.4f}",
        f"{'Weighted':<10} | {dice_weighted:<12.4f} | {iou_weighted:<12.4f}",
        "\n--- Per-Class Metrics ---",
    ]

    per_class_data = {
        'Class': display_names,
        'Dice': dice_per_class,
        'IoU': iou_per_class
    }
    per_class_df = pd.DataFrame(per_class_data)
    report_lines.append(per_class_df.to_string(index=False, float_format="%.4f"))

    report_string = "\n".join(report_lines)
    # print(report_string) # <-- I removed the print(report_string)

    # Save text report
    save_filename = VisionKeys.SEGMENTATION_REPORT + ".txt"
    report_path = save_dir_path / save_filename
    report_path.write_text(report_string, encoding="utf-8")
    _LOGGER.info(f"📝 Segmentation report saved as '{report_path.name}'")

    # --- 3. Save Per-Class Metrics Heatmap ---
    try:
        plt.figure(figsize=(max(8, len(labels) * 0.5), 6), dpi=DPI_value)
        sns.heatmap(
            per_class_df.set_index('Class').T,
            annot=True,
            cmap=format_config.heatmap_cmap, # Use config cmap
            fmt='.3f',
            linewidths=0.5
        )
        plt.title("Per-Class Segmentation Metrics")
        plt.tight_layout()
        heatmap_filename = VisionKeys.SEGMENTATION_HEATMAP + ".svg"
        heatmap_path = save_dir_path / heatmap_filename
        plt.savefig(heatmap_path)
        _LOGGER.info(f"📊 Metrics heatmap saved as '{heatmap_path.name}'")
        plt.close()
    except Exception as e:
        _LOGGER.error(f"Could not generate segmentation metrics heatmap: {e}")

    # --- 4. Save Pixel-level Confusion Matrix ---
    try:
        # Calculate CM
        cm = confusion_matrix(y_true_flat, y_pred_flat, labels=labels)

        # Plot
        fig_cm, ax_cm = plt.subplots(figsize=(max(8, len(labels) * 0.8), max(8, len(labels) * 0.8)), dpi=100)
        disp = ConfusionMatrixDisplay(
            confusion_matrix=cm,
            display_labels=display_names
        )
        disp.plot(cmap=format_config.cm_cmap, ax=ax_cm, xticks_rotation=45) # Use config cmap

        # Manually update font size of cell texts
        for text in disp.text_.flatten(): # type: ignore
            text.set_fontsize(format_config.font_size)

        ax_cm.set_title("Pixel-Level Confusion Matrix")
        plt.tight_layout()
        segmentation_cm_filename = VisionKeys.SEGMENTATION_CONFUSION_MATRIX + ".svg"
        cm_path = save_dir_path / segmentation_cm_filename
        plt.savefig(cm_path)
        _LOGGER.info(f"❇️ Pixel-level confusion matrix saved as '{cm_path.name}'")
        plt.close(fig_cm)
    except Exception as e:
        _LOGGER.error(f"Could not generate confusion matrix: {e}")

    # --- Restore RC params ---
    plt.rcParams.update(original_rc_params)


def object_detection_metrics(
    preds: List[Dict[str, torch.Tensor]],
    targets: List[Dict[str, torch.Tensor]],
    save_dir: Union[str, Path],
    class_names: Optional[List[str]] = None,
    print_output: bool=False
):
    """
    Calculates and saves object detection metrics (mAP) using torchmetrics.

    This function expects predictions and targets in the standard
    torchvision format (list of dictionaries).

    Args:
        preds (List[Dict[str, torch.Tensor]]): A list of predictions.
            Each dict must contain:
            - 'boxes': [N, 4] (xmin, ymin, xmax, ymax)
            - 'scores': [N]
            - 'labels': [N]
        targets (List[Dict[str, torch.Tensor]]): A list of ground truths.
            Each dict must contain:
            - 'boxes': [M, 4]
            - 'labels': [M]
        save_dir (str | Path): Directory to save the metrics report (as JSON).
        class_names (List[str] | None): A list of class names, including 'background'
            at index 0. Used to label per-class metrics in the report.
        print_output (bool): If True, prints the JSON report to the console.
    """
    save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

    _LOGGER.info("--- Calculating Object Detection Metrics (mAP) ---")

    try:
        # Initialize the metric with standard COCO settings
        metric = MeanAveragePrecision(box_format='xyxy')

        # Move preds and targets to the same device (e.g., CPU for metric calculation)
        # This avoids device mismatches if model was on GPU
        device = torch.device("cpu")
        preds_cpu = [{k: v.to(device) for k, v in p.items()} for p in preds]
        targets_cpu = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Update the metric
        metric.update(preds_cpu, targets_cpu)

        # Compute the final metrics
        results = metric.compute()

        # --- Handle class names for per-class metrics ---
        report_class_names = None
        if class_names:
            if class_names[0].lower() in ['background', "bg"]:
                report_class_names = class_names[1:] # Skip background (class 0)
            else:
                _LOGGER.warning("class_names provided to object_detection_metrics, but 'background' was not class 0. Using all provided names.")
                report_class_names = class_names

        # Convert all torch tensors in results to floats/lists for JSON serialization
        serializable_results = {}
        for key, value in results.items():
            if isinstance(value, torch.Tensor):
                if value.numel() == 1:
                    serializable_results[key] = value.item()
                # Check if it's a 1D tensor, we have class names, and it's a known per-class key
                elif value.ndim == 1 and report_class_names and key in ('map_per_class', 'mar_100_per_class', 'mar_1_per_class', 'mar_10_per_class'):
                    per_class_list = value.cpu().numpy().tolist()
                    # Map names to values
                    if len(per_class_list) == len(report_class_names):
                        serializable_results[key] = {name: val for name, val in zip(report_class_names, per_class_list)}
                    else:
                        _LOGGER.warning(f"Length mismatch for '{key}': {len(per_class_list)} values vs {len(report_class_names)} class names. Saving as raw list.")
                        serializable_results[key] = per_class_list
                else:
                    serializable_results[key] = value.cpu().numpy().tolist()
            else:
                serializable_results[key] = value

        # Pretty print to console
        if print_output:
            print(json.dumps(serializable_results, indent=4))

        # Save JSON report
        detection_report_filename = VisionKeys.OBJECT_DETECTION_REPORT + ".json"
        report_path = save_dir_path / detection_report_filename
        with open(report_path, 'w') as f:
            json.dump(serializable_results, f, indent=4)

        _LOGGER.info(f"📊 Object detection (mAP) report saved as '{report_path.name}'")

    except Exception as e:
        _LOGGER.error(f"Failed to compute mAP: {e}")
        raise


def info():
    _script_info(__all__)
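For orientation, here is a minimal usage sketch of the two entry points added in this file. The mask shapes, class names, and `results/` output directories are illustrative placeholders, not values taken from the package; per the docstrings above, segmentation masks must be integer-encoded (not one-hot) and detection inputs follow the torchvision list-of-dicts format with 'background' expected at index 0 of `class_names`.

```python
# Illustrative sketch only: shapes, class names, and output paths are placeholders.
import numpy as np
import torch
from ml_tools.ML_vision_evaluation import segmentation_metrics, object_detection_metrics

# Pixel-level segmentation metrics from integer-encoded masks of shape (N, H, W).
y_true = np.random.randint(0, 3, size=(4, 64, 64))
y_pred = np.random.randint(0, 3, size=(4, 64, 64))
segmentation_metrics(
    y_true,
    y_pred,
    save_dir="results/segmentation",                 # .txt report, heatmap .svg, confusion matrix .svg
    class_names=["background", "road", "building"],  # one name per unique label
)

# Object detection mAP from torchvision-style predictions and targets (xyxy boxes).
preds = [{
    "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([1]),
}]
targets = [{
    "boxes": torch.tensor([[12.0, 8.0, 48.0, 52.0]]),
    "labels": torch.tensor([1]),
}]
object_detection_metrics(
    preds,
    targets,
    save_dir="results/detection",          # mAP report is written as JSON
    class_names=["background", "car"],     # 'background' at index 0 is skipped in the report
    print_output=True,
)
```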
ml_tools/ML_vision_inference.py (new file, 405 lines added)

@@ -0,0 +1,405 @@
import torch
from torch import nn
import numpy as np #numpy array return value
from pathlib import Path
from typing import Union, Literal, Dict, Any, List, Optional, Callable
from PIL import Image
from torchvision import transforms

from ._script_info import _script_info
from ._logger import _LOGGER
from ._keys import PyTorchInferenceKeys, MLTaskKeys
from .ML_vision_transformers import _load_recipe_and_build_transform
from .ML_inference import _BaseInferenceHandler


__all__ = [
    "DragonVisionInferenceHandler"
]


class DragonVisionInferenceHandler(_BaseInferenceHandler):
    """
    Handles loading a PyTorch vision model's state dictionary and performing inference.

    This class is specifically for vision models, which typically expect
    4D Tensors (B, C, H, W) or Lists of Tensors as input.
    """
    def __init__(self,
                 model: nn.Module,
                 state_dict: Union[str, Path],
                 task: Literal["binary image classification", "multiclass image classification", "binary segmentation", "multiclass segmentation", "object detection"],
                 device: str = 'cpu',
                 transform_source: Optional[Union[str, Path, Callable]] = None):
        """
        Initializes the vision inference handler.

        Args:
            model (nn.Module): An instantiated PyTorch model from ML_vision_models.
            state_dict (str | Path): Path to the saved .pth model state_dict file.
            task (str): The type of vision task.
            device (str): The device to run inference on ('cpu', 'cuda', 'mps').
            transform_source (str | Path | Callable | None):
                - A path to a .json recipe file (str or Path).
                - A pre-built transformation pipeline (Callable).
                - None, in which case .set_transform() must be called explicitly to set transformations.

        Note: class_map (Dict[int, str]) will be loaded from the model file, to set or override it use `.set_class_map()`.
        """
        super().__init__(model, state_dict, device, None)

        self._transform: Optional[Callable] = None
        self._is_transformed: bool = False

        if task not in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION, MLTaskKeys.OBJECT_DETECTION]:
            _LOGGER.error(f"Unsupported task: '{task}'.")
            raise ValueError()
        self.task = task

        self.expected_in_channels: int = 3 # Default to RGB
        if hasattr(model, 'in_channels'):
            self.expected_in_channels = model.in_channels # type: ignore
            _LOGGER.info(f"Model expects {self.expected_in_channels} input channels.")
        else:
            _LOGGER.warning("Could not determine 'in_channels' from model. Defaulting to 3 (RGB). Modify with '.expected_in_channels'.")

        if transform_source:
            self.set_transform(transform_source)
            self._is_transformed = True

    def _preprocess_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Validates input and moves it to the correct device.
        - For Classification/Segmentation: Expects 4D Tensor (B, C, H, W).
        - For Object Detection: Expects List[Tensor(C, H, W)].
        """
        if self.task == MLTaskKeys.OBJECT_DETECTION:
            if not isinstance(inputs, list) or not all(isinstance(t, torch.Tensor) for t in inputs):
                _LOGGER.error("Input for object_detection must be a List[torch.Tensor].")
                raise ValueError("Invalid input type for object detection.")
            # Move each tensor in the list to the device
            return [t.float().to(self.device) for t in inputs]

        else: # Classification or Segmentation
            if not isinstance(inputs, torch.Tensor):
                _LOGGER.error(f"Input for {self.task} must be a torch.Tensor.")
                raise ValueError(f"Invalid input type for {self.task}.")

            if inputs.ndim != 4: # type: ignore
                _LOGGER.error(f"Input tensor for {self.task} must be 4D (B, C, H, W). Got {inputs.ndim}D.") # type: ignore
                raise ValueError("Input tensor must be 4D.")

            return inputs.float().to(self.device)

    def set_transform(self, transform_source: Union[str, Path, Callable]):
        """
        Sets or updates the inference transformation pipeline from a recipe file or a direct Callable.

        Args:
            transform_source (str, Path, Callable):
                - A path to a .json recipe file (str or Path).
                - A pre-built transformation pipeline (Callable).
        """
        if self._is_transformed:
            _LOGGER.warning("Transformations were previously applied. Applying new transformations...")

        if isinstance(transform_source, (str, Path)):
            _LOGGER.info(f"Loading transform from recipe file: '{transform_source}'")
            try:
                # Use the loader function
                self._transform = _load_recipe_and_build_transform(transform_source)
            except Exception as e:
                _LOGGER.error(f"Failed to load transform from recipe '{transform_source}': {e}")
                raise
        elif isinstance(transform_source, Callable):
            _LOGGER.info("Inference transform has been set from a direct Callable.")
            self._transform = transform_source
        else:
            _LOGGER.error(f"Invalid transform_source type: {type(transform_source)}. Must be str, Path, or Callable.")
            raise TypeError("transform_source must be a file path or a Callable.")

    def predict_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, Any]:
        """
        Core batch prediction method for vision models.
        All preprocessing (resizing, normalization) should be done *before* calling this method.

        Args:
            inputs (torch.Tensor | List[torch.Tensor]):
                - For binary/multiclass image classification or binary/multiclass image segmentation tasks,
                  a 4D torch.Tensor (B, C, H, W).
                - For 'object_detection', a List of 3D torch.Tensors
                  [(C, H, W), ...], each with its own size.

        Returns:
            A dictionary containing the output tensors.
            - Classification: {labels, probabilities}
            - Segmentation: {labels, probabilities} (labels is the mask)
            - Object Detection: {predictions} (List of dicts)
        """

        processed_inputs = self._preprocess_batch(inputs)

        with torch.no_grad():
            # get outputs
            output = self.model(processed_inputs)
            if self.task == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION:
                # process
                probs = torch.softmax(output, dim=1)
                labels = torch.argmax(probs, dim=1)
                return {
                    PyTorchInferenceKeys.LABELS: labels, # (B,)
                    PyTorchInferenceKeys.PROBABILITIES: probs # (B, num_classes)
                }

            elif self.task == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION:
                # Assumes model output is [N, 1] (a single logit)
                # Squeeze output from [N, 1] to [N] if necessary
                if output.ndim == 2 and output.shape[1] == 1:
                    output = output.squeeze(1)

                probs = torch.sigmoid(output) # Probability of positive class
                labels = (probs >= self._classification_threshold).int()
                return {
                    PyTorchInferenceKeys.LABELS: labels,
                    PyTorchInferenceKeys.PROBABILITIES: probs
                }

            elif self.task == MLTaskKeys.BINARY_SEGMENTATION:
                # Assumes model output is [N, 1, H, W] (logits for positive class)
                probs = torch.sigmoid(output) # Shape [N, 1, H, W]
                labels = (probs >= self._classification_threshold).int() # Shape [N, 1, H, W]
                return {
                    PyTorchInferenceKeys.LABELS: labels,
                    PyTorchInferenceKeys.PROBABILITIES: probs
                }

            elif self.task == MLTaskKeys.MULTICLASS_SEGMENTATION:
                # output shape [N, C, H, W]
                probs = torch.softmax(output, dim=1)
                labels = torch.argmax(probs, dim=1) # shape [N, H, W]
                return {
                    PyTorchInferenceKeys.LABELS: labels, # (N, H, W)
                    PyTorchInferenceKeys.PROBABILITIES: probs # (N, num_classes, H, W)
                }

            elif self.task == MLTaskKeys.OBJECT_DETECTION:
                return {
                    PyTorchInferenceKeys.PREDICTIONS: output
                }

            else:
                # This should be unreachable due to validation
                raise ValueError(f"Unknown task: {self.task}")

    def predict(self, single_input: torch.Tensor) -> Dict[str, Any]:
        """
        Core single-sample prediction method for vision models.
        All preprocessing (resizing, normalization) should be done *before*
        calling this method.

        Args:
            single_input (torch.Tensor):
                - A 3D torch.Tensor (C, H, W) for any task.

        Returns:
            A dictionary containing the output tensors for a single sample.
            - Classification: {labels, probabilities} (label is 0-dim)
            - Segmentation: {labels, probabilities} (label is a 2D (multiclass) or 3D (binary) mask)
            - Object Detection: {boxes, labels, scores} (single dict)
        """
        if not isinstance(single_input, torch.Tensor) or single_input.ndim != 3:
            _LOGGER.error(f"Input for predict() must be a 3D tensor (C, H, W). Got {single_input.ndim}D.")
            raise ValueError()

        # --- 1. Batch the input based on task ---
        if self.task == MLTaskKeys.OBJECT_DETECTION:
            batched_input = [single_input] # List of one tensor
        else:
            batched_input = single_input.unsqueeze(0)

        # --- 2. Call batch prediction ---
        batch_results = self.predict_batch(batched_input)

        # --- 3. Un-batch the results ---
        if self.task == MLTaskKeys.OBJECT_DETECTION:
            # batch_results['predictions'] is a List[Dict]. We want the first (and only) Dict.
            return batch_results[PyTorchInferenceKeys.PREDICTIONS][0]
        else:
            # 'labels' and 'probabilities' are tensors. Get the 0-th element.
            # (B, ...) -> (...)
            single_results = {key: value[0] for key, value in batch_results.items()}
            return single_results

    # --- NumPy Convenience Wrappers (on CPU) ---

    def predict_batch_numpy(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, Any]:
        """
        Convenience wrapper for predict_batch that returns NumPy arrays. With Labels if set.

        Returns:
            Dict: A dictionary containing the outputs as NumPy arrays.
            - Obj. Detection: {predictions: List[Dict[str, np.ndarray]]}
            - Classification: {labels: np.ndarray, label_names: List[str], probabilities: np.ndarray}
            - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
        """
        tensor_results = self.predict_batch(inputs)

        if self.task == MLTaskKeys.OBJECT_DETECTION:
            # Output is List[Dict[str, Tensor]]
            # Convert each tensor inside each dict to numpy
            numpy_results = []
            for pred_dict in tensor_results[PyTorchInferenceKeys.PREDICTIONS]:
                # Convert all tensors to numpy
                np_dict = {key: value.cpu().numpy() for key, value in pred_dict.items()}

                # 3D pixel to string map unnecessary
                # if self._idx_to_class and PyTorchInferenceKeys.LABELS in np_dict:
                #     np_dict[PyTorchInferenceKeys.LABEL_NAMES] = [
                #         self._idx_to_class.get(label_id, "Unknown")
                #         for label_id in np_dict[PyTorchInferenceKeys.LABELS]
                #     ]
                numpy_results.append(np_dict)
            return {PyTorchInferenceKeys.PREDICTIONS: numpy_results}

        else:
            # Output is Dict[str, Tensor] (for Classification or Segmentation)
            numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}

            # Add string names for classification if map exists
            is_image_classification = self.task in [
                MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
                MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
            ]

            if is_image_classification and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
                int_labels = numpy_results[PyTorchInferenceKeys.LABELS] # This is a (B,) array
                numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [
                    self._idx_to_class.get(label_id, "Unknown")
                    for label_id in int_labels
                ]

            return numpy_results

    def predict_numpy(self, single_input: torch.Tensor) -> Dict[str, Any]:
        """
        Convenience wrapper for predict that returns NumPy arrays/scalars.

        Returns:
            Dict: A dictionary containing the outputs as NumPy arrays/scalars.
            - Obj. Detection: {boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray, label_names: List[str]}
            - Classification: {labels: int, label_names: str, probabilities: np.ndarray}
            - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
        """
        tensor_results = self.predict(single_input)

        if self.task == MLTaskKeys.OBJECT_DETECTION:
            # Output is Dict[str, Tensor]
            # Convert each tensor to numpy
            numpy_results = {
                key: value.cpu().numpy() for key, value in tensor_results.items()
            }

            # Add string names if map exists
            # if self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
            #     int_labels = numpy_results[PyTorchInferenceKeys.LABELS]

            #     numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [
            #         self._idx_to_class.get(label_id, "Unknown")
            #         for label_id in int_labels
            #     ]

            return numpy_results

        elif self.task in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
            # Output is Dict[str, Tensor(0-dim) or Tensor(1-dim)]
            int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
            label_name = "Unknown"
            if self._idx_to_class:
                label_name = self._idx_to_class.get(int_label, "Unknown")

            return {
                PyTorchInferenceKeys.LABELS: int_label,
                PyTorchInferenceKeys.LABEL_NAMES: label_name,
                PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
            }
        else: # image_segmentation (binary or multiclass)
            # Output is Dict[str, Tensor(2D) or Tensor(3D)]
            return {
                PyTorchInferenceKeys.LABELS: tensor_results[PyTorchInferenceKeys.LABELS].cpu().numpy(),
                PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
            }

    def predict_from_pil(self, image: Image.Image) -> Dict[str, Any]:
        """
        Applies the stored transform to a single PIL image and returns the prediction.

        Args:
            image (PIL.Image.Image): The input PIL image.

        Returns:
            Dict: A dictionary containing the prediction results. See `predict_numpy()` for task-specific output structures.
        """
        if self._transform is None:
            _LOGGER.error("Cannot predict from PIL image: No transform has been set. Call .set_transform() or provide transform_source in __init__.")
            raise RuntimeError("Inference transform is not set.")

        # Apply the transformation pipeline (e.g., resize, crop, ToTensor, normalize)
        try:
            transformed_image = self._transform(image)
        except Exception as e:
            _LOGGER.error(f"Error applying transform to PIL image: {e}")
            raise

        # --- Validation ---
        if not isinstance(transformed_image, torch.Tensor):
            _LOGGER.error("The provided transform did not return a torch.Tensor. Does it include transforms.ToTensor()?")
            raise ValueError("Transform pipeline must output a torch.Tensor.")

        if transformed_image.ndim != 3:
            _LOGGER.warning(f"Expected transform to output a 3D (C, H, W) tensor, but got {transformed_image.ndim}D. Attempting to proceed.")
            # .predict_numpy() -> .predict() which expects a 3D tensor
            if transformed_image.ndim == 4 and transformed_image.shape[0] == 1:
                transformed_image = transformed_image.squeeze(0) # Fix if user's transform adds a batch dim
                _LOGGER.warning("Removed an extra batch dimension.")
            else:
                raise ValueError(f"Transform must output a 3D (C, H, W) tensor, got {transformed_image.shape}.")

        # Use the existing single-item predict method
        return self.predict_numpy(transformed_image)

    def predict_from_file(self, image_path: Union[str, Path]) -> Dict[str, Any]:
        """
        Loads a single image from a file, applies the stored transform, and returns the prediction.

        This is a convenience wrapper that loads the image and calls `predict_from_pil()`.

        Args:
            image_path (str | Path): The file path to the input image.

        Returns:
            Dict: A dictionary containing the prediction results. See `predict_numpy()` for task-specific output structures.
        """
        try:
            # --- Use expected_in_channels to set PIL mode ---
            pil_mode: str
            if self.expected_in_channels == 1:
                pil_mode = "L" # Grayscale
            elif self.expected_in_channels == 4:
                pil_mode = "RGBA" # RGB + Alpha
            else:
                if self.expected_in_channels != 3: # 2, 5+ channels not supported by PIL convert
                    _LOGGER.warning(f"Model expects {self.expected_in_channels} channels. PIL conversion is limited, defaulting to 3 channels (RGB). The transformations must convert it to the desired channel dimensions.")
                # Default to RGB. If 2-channels are needed, the transform recipe *must* be responsible for handling the conversion from a 3-channel PIL image.
                pil_mode = "RGB"

            image = Image.open(image_path).convert(pil_mode)
        except Exception as e:
            _LOGGER.error(f"Failed to load and convert image from '{image_path}': {e}")
            raise

        # Call the PIL-based prediction method
        return self.predict_from_pil(image)


def info():
    _script_info(__all__)
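For orientation, a minimal usage sketch of the handler added in this file, assuming only the constructor arguments and methods shown above. The stand-in CNN, weight file, transform pipeline, class map, and dummy image are placeholders; real checkpoints produced by the toolbox may also embed the class map in the .pth file, as the `__init__` docstring notes.

```python
# Illustrative sketch only: the model, weight file, transform, and class map are placeholders.
import torch
from torch import nn
from PIL import Image
from torchvision import transforms
from ml_tools.ML_vision_inference import DragonVisionInferenceHandler

class TinyClassifier(nn.Module):
    """Stand-in multiclass classifier exposing the 'in_channels' attribute the handler reads."""
    def __init__(self, in_channels: int = 3, num_classes: int = 5):
        super().__init__()
        self.in_channels = in_channels
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 8, kernel_size=3, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(8, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

model = TinyClassifier()
torch.save(model.state_dict(), "tiny_classifier.pth")  # placeholder weights; real ones come from training

handler = DragonVisionInferenceHandler(
    model=model,
    state_dict="tiny_classifier.pth",
    task="multiclass image classification",
    device="cpu",
    transform_source=transforms.Compose([  # a pre-built Callable; a .json recipe path also works
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ]),
)
handler.set_class_map({0: "cat", 1: "dog", 2: "bird", 3: "fish", 4: "horse"})  # per the __init__ docstring

dummy = Image.new("RGB", (128, 128))      # stand-in for predict_from_file("photo.png")
result = handler.predict_from_pil(dummy)
print(result)  # keys per predict_numpy(): labels, label_names, probabilities
```

Any `nn.Module` works as long as its forward output matches the chosen task; the handler reads `in_channels` from the model when the attribute is present and otherwise assumes 3-channel RGB input.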