dragon-ml-toolbox 13.0.0__py3-none-any.whl → 14.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/METADATA +12 -2
  2. dragon_ml_toolbox-14.7.0.dist-info/RECORD +49 -0
  3. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/MICE_imputation.py +207 -5
  5. ml_tools/ML_configuration.py +108 -0
  6. ml_tools/ML_datasetmaster.py +241 -260
  7. ml_tools/ML_evaluation.py +229 -76
  8. ml_tools/ML_evaluation_multi.py +45 -16
  9. ml_tools/ML_inference.py +0 -1
  10. ml_tools/ML_models.py +135 -55
  11. ml_tools/ML_models_advanced.py +323 -0
  12. ml_tools/ML_optimization.py +49 -36
  13. ml_tools/ML_trainer.py +498 -29
  14. ml_tools/ML_utilities.py +351 -4
  15. ml_tools/ML_vision_datasetmaster.py +1492 -0
  16. ml_tools/ML_vision_evaluation.py +260 -0
  17. ml_tools/ML_vision_inference.py +428 -0
  18. ml_tools/ML_vision_models.py +641 -0
  19. ml_tools/ML_vision_transformers.py +203 -0
  20. ml_tools/PSO_optimization.py +5 -1
  21. ml_tools/_ML_vision_recipe.py +88 -0
  22. ml_tools/__init__.py +1 -0
  23. ml_tools/_schema.py +96 -0
  24. ml_tools/custom_logger.py +37 -14
  25. ml_tools/data_exploration.py +576 -138
  26. ml_tools/ensemble_evaluation.py +53 -10
  27. ml_tools/keys.py +43 -1
  28. ml_tools/math_utilities.py +1 -1
  29. ml_tools/optimization_tools.py +65 -86
  30. ml_tools/serde.py +78 -17
  31. ml_tools/utilities.py +192 -3
  32. dragon_ml_toolbox-13.0.0.dist-info/RECORD +0 -41
  33. ml_tools/ML_simple_optimization.py +0 -413
  34. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/WHEEL +0 -0
  35. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/licenses/LICENSE +0 -0
  36. {dragon_ml_toolbox-13.0.0.dist-info → dragon_ml_toolbox-14.7.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_vision_evaluation.py (new file)
@@ -0,0 +1,260 @@
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import torch
+ from sklearn.metrics import (
+     accuracy_score,
+     f1_score,
+     jaccard_score,
+     confusion_matrix,
+     ConfusionMatrixDisplay
+ )
+ from pathlib import Path
+ from typing import Union, Optional, List, Dict
+ import json
+ from torchmetrics.detection import MeanAveragePrecision
+
+ from .path_manager import make_fullpath
+ from ._logger import _LOGGER
+ from ._script_info import _script_info
+ from .keys import VisionKeys
+
+
+ __all__ = [
+     "segmentation_metrics",
+     "object_detection_metrics"
+ ]
+
+
+ def segmentation_metrics(
+     y_true: np.ndarray,
+     y_pred: np.ndarray,
+     save_dir: Union[str, Path],
+     class_names: Optional[List[str]] = None
+ ):
+     """
+     Calculates and saves pixel-level metrics for segmentation tasks.
+
+     Metrics include Pixel Accuracy, Dice (F1-score), and IoU (Jaccard).
+     It calculates 'micro', 'macro', and 'weighted' averages and saves
+     a pixel-level confusion matrix and a metrics heatmap.
+
+     Note: This function expects integer-based masks (e.g., shape [N, H, W] or [H, W]),
+     not one-hot encoded masks.
+
+     Args:
+         y_true (np.ndarray): Ground truth masks (e.g., shape [N, H, W]).
+         y_pred (np.ndarray): Predicted masks (e.g., shape [N, H, W]).
+         save_dir (str | Path): Directory to save the metrics report and plots.
+         class_names (List[str] | None): Names of the classes for the report.
+     """
+     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+     # Get all unique class labels present in either true or pred
+     labels = np.unique(np.concatenate((np.unique(y_true), np.unique(y_pred)))).astype(int)
+
+     # --- Setup Class Names ---
+     display_names = []
+     if class_names is None:
+         display_names = [f"Class {i}" for i in labels]
+     else:
+         if len(class_names) != len(labels):
+             _LOGGER.warning(f"Number of class_names ({len(class_names)}) does not match number of unique labels ({len(labels)}). Using default names.")
+             display_names = [f"Class {i}" for i in labels]
+         else:
+             display_names = class_names
+
+     # Flatten masks for sklearn metrics
+     y_true_flat = y_true.ravel()
+     y_pred_flat = y_pred.ravel()
+
+     _LOGGER.info("--- Calculating Segmentation Metrics ---")
+
+     # --- 1. Calculate Metrics ---
+     pix_acc = accuracy_score(y_true_flat, y_pred_flat)
+
+     # Calculate all average types
+     dice_micro = f1_score(y_true_flat, y_pred_flat, average='micro', labels=labels)
+     iou_micro = jaccard_score(y_true_flat, y_pred_flat, average='micro', labels=labels)
+
+     dice_macro = f1_score(y_true_flat, y_pred_flat, average='macro', labels=labels, zero_division=0)
+     iou_macro = jaccard_score(y_true_flat, y_pred_flat, average='macro', labels=labels, zero_division=0)
+
+     dice_weighted = f1_score(y_true_flat, y_pred_flat, average='weighted', labels=labels, zero_division=0)
+     iou_weighted = jaccard_score(y_true_flat, y_pred_flat, average='weighted', labels=labels, zero_division=0)
+
+     # Per-class metrics
+     dice_per_class = f1_score(y_true_flat, y_pred_flat, average=None, labels=labels, zero_division=0)
+     iou_per_class = jaccard_score(y_true_flat, y_pred_flat, average=None, labels=labels, zero_division=0)
+
+     # --- 2. Create and Save Report ---
+     report_lines = [
+         "--- Segmentation Report ---",
+         f"\nOverall Pixel Accuracy: {pix_acc:.4f}\n",
+         "--- Averaged Metrics ---",
+         f"{'Average':<10} | {'Dice (F1)':<12} | {'IoU (Jaccard)':<12}",
+         "-"*41,
+         f"{'Micro':<10} | {dice_micro:<12.4f} | {iou_micro:<12.4f}",
+         f"{'Macro':<10} | {dice_macro:<12.4f} | {iou_macro:<12.4f}",
+         f"{'Weighted':<10} | {dice_weighted:<12.4f} | {iou_weighted:<12.4f}",
+         "\n--- Per-Class Metrics ---",
+     ]
+
+     per_class_data = {
+         'Class': display_names,
+         'Dice': dice_per_class,
+         'IoU': iou_per_class
+     }
+     per_class_df = pd.DataFrame(per_class_data)
+     report_lines.append(per_class_df.to_string(index=False, float_format="%.4f"))
+
+     report_string = "\n".join(report_lines)
+     print(report_string)
+
+     # Save text report
+     save_filename = VisionKeys.SEGMENTATION_REPORT + ".txt"
+     report_path = save_dir_path / save_filename
+     report_path.write_text(report_string, encoding="utf-8")
+     _LOGGER.info(f"📝 Segmentation report saved as '{report_path.name}'")
+
+     # --- 3. Save Per-Class Metrics Heatmap ---
+     try:
+         plt.figure(figsize=(max(8, len(labels) * 0.5), 6), dpi=100)
+         sns.heatmap(
+             per_class_df.set_index('Class').T,
+             annot=True,
+             cmap='viridis',
+             fmt='.3f',
+             linewidths=0.5
+         )
+         plt.title("Per-Class Segmentation Metrics")
+         plt.tight_layout()
+         heatmap_filename = VisionKeys.SEGMENTATION_HEATMAP + ".svg"
+         heatmap_path = save_dir_path / heatmap_filename
+         plt.savefig(heatmap_path)
+         _LOGGER.info(f"📊 Metrics heatmap saved as '{heatmap_path.name}'")
+         plt.close()
+     except Exception as e:
+         _LOGGER.error(f"Could not generate segmentation metrics heatmap: {e}")
+
+     # --- 4. Save Pixel-level Confusion Matrix ---
+     try:
+         # Calculate CM
+         cm = confusion_matrix(y_true_flat, y_pred_flat, labels=labels)
+
+         # Plot
+         fig_cm, ax_cm = plt.subplots(figsize=(max(8, len(labels) * 0.8), max(8, len(labels) * 0.8)), dpi=100)
+         disp = ConfusionMatrixDisplay(
+             confusion_matrix=cm,
+             display_labels=display_names
+         )
+         disp.plot(cmap='Blues', ax=ax_cm, xticks_rotation=45)
+
+         ax_cm.set_title("Pixel-Level Confusion Matrix")
+         plt.tight_layout()
+         segmentation_cm_filename = VisionKeys.SEGMENTATION_CONFUSION_MATRIX + ".svg"
+         cm_path = save_dir_path / segmentation_cm_filename
+         plt.savefig(cm_path)
+         _LOGGER.info(f"❇️ Pixel-level confusion matrix saved as '{cm_path.name}'")
+         plt.close(fig_cm)
+     except Exception as e:
+         _LOGGER.error(f"Could not generate confusion matrix: {e}")
+
+
+ def object_detection_metrics(
+     preds: List[Dict[str, torch.Tensor]],
+     targets: List[Dict[str, torch.Tensor]],
+     save_dir: Union[str, Path],
+     class_names: Optional[List[str]] = None,
+     print_output: bool = False
+ ):
+     """
+     Calculates and saves object detection metrics (mAP) using torchmetrics.
+
+     This function expects predictions and targets in the standard
+     torchvision format (list of dictionaries).
+
+     Args:
+         preds (List[Dict[str, torch.Tensor]]): A list of predictions.
+             Each dict must contain:
+             - 'boxes': [N, 4] (xmin, ymin, xmax, ymax)
+             - 'scores': [N]
+             - 'labels': [N]
+         targets (List[Dict[str, torch.Tensor]]): A list of ground truths.
+             Each dict must contain:
+             - 'boxes': [M, 4]
+             - 'labels': [M]
+         save_dir (str | Path): Directory to save the metrics report (as JSON).
+         class_names (List[str] | None): A list of class names, including 'background'
+             at index 0. Used to label per-class metrics in the report.
+         print_output (bool): If True, prints the JSON report to the console.
+     """
+     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+     _LOGGER.info("--- Calculating Object Detection Metrics (mAP) ---")
+
+     try:
+         # Initialize the metric with standard COCO settings
+         metric = MeanAveragePrecision(box_format='xyxy')
+
+         # Move preds and targets to the same device (e.g., CPU for metric calculation)
+         # This avoids device mismatches if model was on GPU
+         device = torch.device("cpu")
+         preds_cpu = [{k: v.to(device) for k, v in p.items()} for p in preds]
+         targets_cpu = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+         # Update the metric
+         metric.update(preds_cpu, targets_cpu)
+
+         # Compute the final metrics
+         results = metric.compute()
+
+         # --- Handle class names for per-class metrics ---
+         report_class_names = None
+         if class_names:
+             if class_names[0].lower() in ['background', 'bg']:
+                 report_class_names = class_names[1:]  # Skip background (class 0)
+             else:
+                 _LOGGER.warning("class_names provided to object_detection_metrics, but 'background' was not class 0. Using all provided names.")
+                 report_class_names = class_names
+
+         # Convert all torch tensors in results to floats/lists for JSON serialization
+         serializable_results = {}
+         for key, value in results.items():
+             if isinstance(value, torch.Tensor):
+                 if value.numel() == 1:
+                     serializable_results[key] = value.item()
+                 # Check if it's a 1D tensor, we have class names, and it's a known per-class key
+                 elif value.ndim == 1 and report_class_names and key in ('map_per_class', 'mar_100_per_class', 'mar_1_per_class', 'mar_10_per_class'):
+                     per_class_list = value.cpu().numpy().tolist()
+                     # Map names to values
+                     if len(per_class_list) == len(report_class_names):
+                         serializable_results[key] = {name: val for name, val in zip(report_class_names, per_class_list)}
+                     else:
+                         _LOGGER.warning(f"Length mismatch for '{key}': {len(per_class_list)} values vs {len(report_class_names)} class names. Saving as raw list.")
+                         serializable_results[key] = per_class_list
+                 else:
+                     serializable_results[key] = value.cpu().numpy().tolist()
+             else:
+                 serializable_results[key] = value
+
+         # Pretty print to console
+         if print_output:
+             print(json.dumps(serializable_results, indent=4))
+
+         # Save JSON report
+         detection_report_filename = VisionKeys.OBJECT_DETECTION_REPORT + ".json"
+         report_path = save_dir_path / detection_report_filename
+         with open(report_path, 'w') as f:
+             json.dump(serializable_results, f, indent=4)
+
+         _LOGGER.info(f"📊 Object detection (mAP) report saved as '{report_path.name}'")
+
+     except Exception as e:
+         _LOGGER.error(f"Failed to compute mAP: {e}")
+         raise
+
+
+ def info():
+     _script_info(__all__)
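
For context, here is a minimal usage sketch of the two evaluation helpers above. The mask shapes, class names, and save directories are illustrative assumptions, not values taken from the package:

import numpy as np
import torch
from ml_tools.ML_vision_evaluation import segmentation_metrics, object_detection_metrics

# Segmentation: integer masks for a 2-image batch with 3 classes, shape (N, H, W)
y_true = np.random.randint(0, 3, size=(2, 64, 64))
y_pred = np.random.randint(0, 3, size=(2, 64, 64))
segmentation_metrics(
    y_true,
    y_pred,
    save_dir="reports/segmentation",                # hypothetical output directory
    class_names=["background", "road", "vehicle"],  # must match the number of unique labels
)

# Object detection: torchvision-style dicts with 'xyxy' boxes
preds = [{
    "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([1]),
}]
targets = [{
    "boxes": torch.tensor([[12.0, 8.0, 52.0, 48.0]]),
    "labels": torch.tensor([1]),
}]
object_detection_metrics(
    preds,
    targets,
    save_dir="reports/detection",            # hypothetical output directory
    class_names=["background", "vehicle"],   # 'background' at index 0 is skipped in the report
    print_output=True,
)
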
ml_tools/ML_vision_inference.py (new file)
@@ -0,0 +1,428 @@
+ import torch
+ from torch import nn
+ import numpy as np  # numpy arrays are returned by the convenience wrappers
+ from pathlib import Path
+ from typing import Union, Literal, Dict, Any, List, Optional, Callable
+ from PIL import Image
+ from torchvision import transforms
+
+ from ._script_info import _script_info
+ from ._logger import _LOGGER
+ from .path_manager import make_fullpath
+ from .keys import PyTorchInferenceKeys, PyTorchCheckpointKeys
+ from ._ML_vision_recipe import load_recipe_and_build_transform
+
+
+ __all__ = [
+     "PyTorchVisionInferenceHandler"
+ ]
+
+
+ class PyTorchVisionInferenceHandler:
+     """
+     Handles loading a PyTorch vision model's state dictionary and performing inference.
+
+     This class is specifically for vision models, which typically expect
+     4D Tensors (B, C, H, W) or Lists of Tensors as input.
+     It does NOT use a scaler, as preprocessing (e.g., normalization)
+     is assumed to be part of the input transform pipeline.
+     """
+     def __init__(self,
+                  model: nn.Module,
+                  state_dict: Union[str, Path],
+                  task: Literal["image_classification", "image_segmentation", "object_detection"],
+                  device: str = 'cpu',
+                  transform_source: Optional[Union[str, Path, Callable]] = None,
+                  class_map: Optional[Dict[str, int]] = None):
+         """
+         Initializes the vision inference handler.
+
+         Args:
+             model (nn.Module): An instantiated PyTorch model from ML_vision_models.
+             state_dict (str | Path): Path to the saved .pth model state_dict file.
+             task (str): The type of vision task.
+             device (str): The device to run inference on ('cpu', 'cuda', 'mps').
+             transform_source (str | Path | Callable | None):
+                 - A path to a .json recipe file (str or Path).
+                 - A pre-built transformation pipeline (Callable).
+                 - None, in which case .set_transform() must be called explicitly to set transformations.
+             class_map (Dict[str, int] | None): The class-to-index mapping used to translate predicted integer labels back into string names. (For image classification and object detection)
+         """
+         self._model = model
+         self._device = self._validate_device(device)
+         self._transform: Optional[Callable] = None
+         self._is_transformed: bool = False
+         self._idx_to_class: Optional[Dict[int, str]] = None
+         if class_map is not None:
+             self.set_class_map(class_map)
+
+         if task not in ["image_classification", "image_segmentation", "object_detection"]:
+             _LOGGER.error(f"`task` must be 'image_classification', 'image_segmentation', or 'object_detection'. Got '{task}'.")
+             raise ValueError("Invalid task type.")
+         self.task = task
+
+         self.expected_in_channels: int = 3  # Default to RGB
+         if hasattr(model, 'in_channels'):
+             self.expected_in_channels = model.in_channels  # type: ignore
+             _LOGGER.info(f"Model expects {self.expected_in_channels} input channels.")
+         else:
+             _LOGGER.warning("Could not determine 'in_channels' from model. Defaulting to 3 (RGB). Modify with '.expected_in_channels'.")
+
+         if transform_source:
+             self.set_transform(transform_source)
+
+         model_p = make_fullpath(state_dict, enforce="file")
+
+         try:
+             # Load whatever is in the file
+             loaded_data = torch.load(model_p, map_location=self._device)
+
+             # Check if it's a new checkpoint dictionary or an old weights-only file
+             if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
+                 # It's a new training checkpoint; extract the weights
+                 self._model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
+             else:
+                 # It's an old-style file (or just a state_dict); load it directly
+                 self._model.load_state_dict(loaded_data)
+
+             _LOGGER.info(f"Model state loaded from '{model_p.name}'.")
+
+             self._model.to(self._device)
+             self._model.eval()  # Set the model to evaluation mode
+         except Exception as e:
+             _LOGGER.error(f"Failed to load model state from '{model_p}': {e}")
+             raise
+
+     def _validate_device(self, device: str) -> torch.device:
+         """Validates the selected device and returns a torch.device object."""
+         device_lower = device.lower()
+         if "cuda" in device_lower and not torch.cuda.is_available():
+             _LOGGER.warning("CUDA not available, switching to CPU.")
+             device_lower = "cpu"
+         elif device_lower == "mps" and not torch.backends.mps.is_available():
+             _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
+             device_lower = "cpu"
+         return torch.device(device_lower)
+
+     def _preprocess_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Union[torch.Tensor, List[torch.Tensor]]:
+         """
+         Validates input and moves it to the correct device.
+
+         - For Classification/Segmentation: expects a 4D Tensor (B, C, H, W).
+         - For Object Detection: expects a List[Tensor(C, H, W)].
+         """
+         if self.task == "object_detection":
+             if not isinstance(inputs, list) or not all(isinstance(t, torch.Tensor) for t in inputs):
+                 _LOGGER.error("Input for object_detection must be a List[torch.Tensor].")
+                 raise ValueError("Invalid input type for object detection.")
+             # Move each tensor in the list to the device
+             return [t.float().to(self._device) for t in inputs]
+
+         else:  # Classification or Segmentation
+             if not isinstance(inputs, torch.Tensor):
+                 _LOGGER.error(f"Input for {self.task} must be a torch.Tensor.")
+                 raise ValueError(f"Invalid input type for {self.task}.")
+
+             if inputs.ndim != 4:
+                 _LOGGER.error(f"Input tensor for {self.task} must be 4D (B, C, H, W). Got {inputs.ndim}D.")
+                 raise ValueError("Input tensor must be 4D.")
+
+             return inputs.float().to(self._device)
+
+     def set_transform(self, transform_source: Union[str, Path, Callable]):
+         """
+         Sets or updates the inference transformation pipeline from a recipe file or a direct Callable.
+
+         Args:
+             transform_source (str | Path | Callable):
+                 - A path to a .json recipe file (str or Path).
+                 - A pre-built transformation pipeline (Callable).
+         """
+         if self._is_transformed:
+             _LOGGER.warning("A transform was previously set. Replacing it with the new transform...")
+
+         if isinstance(transform_source, (str, Path)):
+             _LOGGER.info(f"Loading transform from recipe file: '{transform_source}'")
+             try:
+                 # Use the recipe loader function
+                 self._transform = load_recipe_and_build_transform(transform_source)
+             except Exception as e:
+                 _LOGGER.error(f"Failed to load transform from recipe '{transform_source}': {e}")
+                 raise
+         elif isinstance(transform_source, Callable):
+             _LOGGER.info("Inference transform has been set from a direct Callable.")
+             self._transform = transform_source
+         else:
+             _LOGGER.error(f"Invalid transform_source type: {type(transform_source)}. Must be str, Path, or Callable.")
+             raise TypeError("transform_source must be a file path or a Callable.")
+
+         self._is_transformed = True
+
+     def set_class_map(self, class_map: Dict[str, int]):
+         """
+         Sets the class name mapping to translate predicted integer labels
+         back into string names.
+
+         Args:
+             class_map (Dict[str, int]): The class_to_idx dictionary
+                 (e.g., {'cat': 0, 'dog': 1}) from the VisionDatasetMaker.
+         """
+         if self._idx_to_class is not None:
+             _LOGGER.warning("A class-to-index mapping was previously given. Setting the new mapping...")
+         # Invert the dictionary for fast lookup
+         self._idx_to_class = {v: k for k, v in class_map.items()}
+         _LOGGER.info("Class map set for label-to-name translation.")
+
+     def predict_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, Any]:
+         """
+         Core batch prediction method for vision models.
+
+         All preprocessing (resizing, normalization) should be done *before*
+         calling this method.
+
+         Args:
+             inputs (torch.Tensor | List[torch.Tensor]):
+                 - For 'image_classification' or 'image_segmentation',
+                   a 4D torch.Tensor (B, C, H, W).
+                 - For 'object_detection', a List of 3D torch.Tensors
+                   [(C, H, W), ...], each with its own size.
+
+         Returns:
+             A dictionary containing the output tensors.
+             - Classification: {labels, probabilities}
+             - Segmentation: {labels, probabilities} (labels is the mask)
+             - Object Detection: {predictions} (List of dicts)
+         """
+         processed_inputs = self._preprocess_batch(inputs)
+
+         with torch.no_grad():
+             if self.task == "image_classification":
+                 # --- Image Classification ---
+                 # 1. Predict
+                 output = self._model(processed_inputs)  # (B, num_classes)
+
+                 # 2. Post-process
+                 probs = torch.softmax(output, dim=1)
+                 labels = torch.argmax(probs, dim=1)
+                 return {
+                     PyTorchInferenceKeys.LABELS: labels,       # (B,)
+                     PyTorchInferenceKeys.PROBABILITIES: probs  # (B, num_classes)
+                 }
+
+             elif self.task == "image_segmentation":
+                 # --- Image Segmentation ---
+                 # 1. Predict
+                 output = self._model(processed_inputs)  # (B, num_classes, H, W)
+
+                 # 2. Post-process
+                 probs = torch.softmax(output, dim=1)  # Probs across class channel
+                 labels = torch.argmax(probs, dim=1)   # (B, H, W) segmentation map
+                 return {
+                     PyTorchInferenceKeys.LABELS: labels,       # (B, H, W)
+                     PyTorchInferenceKeys.PROBABILITIES: probs  # (B, num_classes, H, W)
+                 }
+
+             elif self.task == "object_detection":
+                 # --- Object Detection ---
+                 # 1. Predict (model is in eval mode, expects only images)
+                 #    Output is List[Dict[str, Tensor('boxes', 'labels', 'scores')]]
+                 predictions = self._model(processed_inputs)
+
+                 # 2. Post-process: wrap in our standard key
+                 return {
+                     PyTorchInferenceKeys.PREDICTIONS: predictions
+                 }
+
+             else:
+                 # This should be unreachable due to the __init__ check
+                 raise ValueError(f"Unknown task: {self.task}")
+
+     def predict(self, single_input: torch.Tensor) -> Dict[str, Any]:
+         """
+         Core single-sample prediction method for vision models.
+
+         All preprocessing (resizing, normalization) should be done *before*
+         calling this method.
+
+         Args:
+             single_input (torch.Tensor):
+                 - A 3D torch.Tensor (C, H, W) for any task.
+
+         Returns:
+             A dictionary containing the output tensors for a single sample.
+             - Classification: {labels, probabilities} (label is 0-dim)
+             - Segmentation: {labels, probabilities} (label is a 2D mask)
+             - Object Detection: {boxes, labels, scores} (single dict)
+         """
+         if not isinstance(single_input, torch.Tensor) or single_input.ndim != 3:
+             _LOGGER.error(f"Input for predict() must be a 3D tensor (C, H, W). Got {single_input.ndim}D.")
+             raise ValueError("Input must be a 3D tensor.")
+
+         # --- 1. Batch the input based on task ---
+         if self.task == "object_detection":
+             batched_input = [single_input]  # List of one tensor
+         else:
+             batched_input = single_input.unsqueeze(0)  # (1, C, H, W)
+
+         # --- 2. Call batch prediction ---
+         batch_results = self.predict_batch(batched_input)
+
+         # --- 3. Un-batch the results ---
+         if self.task == "object_detection":
+             # batch_results['predictions'] is a List[Dict]. We want the first (and only) Dict.
+             return batch_results[PyTorchInferenceKeys.PREDICTIONS][0]
+         else:
+             # 'labels' and 'probabilities' are tensors. Get the 0-th element: (B, ...) -> (...)
+             single_results = {key: value[0] for key, value in batch_results.items()}
+             return single_results
+
+     # --- NumPy Convenience Wrappers (on CPU) ---
+
+     def predict_batch_numpy(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, Any]:
+         """
+         Convenience wrapper for predict_batch that returns NumPy arrays,
+         with label names added when a class map is set.
+
+         Returns:
+             Dict: A dictionary containing the outputs as NumPy arrays.
+             - Obj. Detection: {predictions: List[Dict[str, np.ndarray]]}
+             - Classification: {labels: np.ndarray, label_names: List[str], probabilities: np.ndarray}
+             - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
+         """
+         tensor_results = self.predict_batch(inputs)
+
+         if self.task == "object_detection":
+             # Output is List[Dict[str, Tensor]]
+             # Convert each tensor inside each dict to numpy
+             numpy_results = []
+             for pred_dict in tensor_results[PyTorchInferenceKeys.PREDICTIONS]:
+                 # Convert all tensors to numpy
+                 np_dict = {key: value.cpu().numpy() for key, value in pred_dict.items()}
+
+                 # Add string names if a map exists
+                 if self._idx_to_class and PyTorchInferenceKeys.LABELS in np_dict:
+                     np_dict[PyTorchInferenceKeys.LABEL_NAMES] = [
+                         self._idx_to_class.get(label_id, "Unknown")
+                         for label_id in np_dict[PyTorchInferenceKeys.LABELS]
+                     ]
+                 numpy_results.append(np_dict)
+             return {PyTorchInferenceKeys.PREDICTIONS: numpy_results}
+         else:
+             # Output is Dict[str, Tensor]
+             numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
+
+             # Add string names for classification if a map exists
+             if self.task == "image_classification" and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
+                 int_labels = numpy_results[PyTorchInferenceKeys.LABELS]
+                 numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [
+                     self._idx_to_class.get(label_id, "Unknown")
+                     for label_id in int_labels
+                 ]
+
+             return numpy_results
+
+     def predict_numpy(self, single_input: torch.Tensor) -> Dict[str, Any]:
+         """
+         Convenience wrapper for predict that returns NumPy arrays/scalars.
+
+         Returns:
+             Dict: A dictionary containing the outputs as NumPy arrays/scalars.
+             - Obj. Detection: {boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray}
+             - Classification: {labels: int, label_names: str, probabilities: np.ndarray}
+             - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
+         """
+         tensor_results = self.predict(single_input)
+
+         if self.task == "object_detection":
+             # Output is Dict[str, Tensor]
+             # Convert each tensor to numpy
+             numpy_results = {
+                 key: value.cpu().numpy() for key, value in tensor_results.items()
+             }
+
+             # Add string names if a map exists
+             if self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
+                 int_labels = numpy_results[PyTorchInferenceKeys.LABELS]
+
+                 numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [
+                     self._idx_to_class.get(label_id, "Unknown")
+                     for label_id in int_labels
+                 ]
+
+             return numpy_results
+
+         elif self.task == "image_classification":
+             # Output is Dict[str, Tensor(0-dim) or Tensor(1-dim)]
+             int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
+             label_name = "Unknown"
+             if self._idx_to_class:
+                 label_name = self._idx_to_class.get(int_label, "Unknown")
+
+             return {
+                 PyTorchInferenceKeys.LABELS: int_label,
+                 PyTorchInferenceKeys.LABEL_NAMES: label_name,
+                 PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
+             }
+         else:  # image_segmentation
+             # Output is Dict[str, Tensor(2D) or Tensor(3D)]
+             return {
+                 PyTorchInferenceKeys.LABELS: tensor_results[PyTorchInferenceKeys.LABELS].cpu().numpy(),
+                 PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
+             }
+
+     def predict_from_file(self, image_path: Union[str, Path]) -> Dict[str, Any]:
+         """
+         Loads a single image from a file, applies the stored transform, and returns the prediction.
+
+         Args:
+             image_path (str | Path): The file path to the input image.
+
+         Returns:
+             Dict: A dictionary containing the prediction results. See `predict_numpy()` for task-specific output structures.
+         """
+         if self._transform is None:
+             _LOGGER.error("Cannot predict from file: no transform has been set. Call .set_transform() or provide transform_source in __init__.")
+             raise RuntimeError("Inference transform is not set.")
+
+         try:
+             # --- Use expected_in_channels to set the PIL mode ---
+             pil_mode: str
+             if self.expected_in_channels == 1:
+                 pil_mode = "L"  # Grayscale
+             elif self.expected_in_channels == 4:
+                 pil_mode = "RGBA"  # RGB + Alpha
+             else:
+                 if self.expected_in_channels != 3:  # 2 or 5+ channels are not supported by PIL convert
+                     _LOGGER.warning(f"Model expects {self.expected_in_channels} channels. PIL conversion is limited, defaulting to 3 channels (RGB). The transform pipeline must convert the image to the desired number of channels.")
+                 # Default to RGB. If 2 channels are needed, the transform recipe must handle the conversion from a 3-channel PIL image.
+                 pil_mode = "RGB"
+
+             image = Image.open(image_path).convert(pil_mode)
+         except Exception as e:
+             _LOGGER.error(f"Failed to load and convert image from '{image_path}': {e}")
+             raise
+
+         # Apply the transformation pipeline (e.g., resize, crop, ToTensor, normalize)
+         try:
+             transformed_image = self._transform(image)
+         except Exception as e:
+             _LOGGER.error(f"Error applying transform to image: {e}")
+             raise
+
+         # --- Validation ---
+         if not isinstance(transformed_image, torch.Tensor):
+             _LOGGER.error("The provided transform did not return a torch.Tensor. Does it include transforms.ToTensor()?")
+             raise ValueError("Transform pipeline must output a torch.Tensor.")
+
+         if transformed_image.ndim != 3:
+             _LOGGER.warning(f"Expected transform to output a 3D (C, H, W) tensor, but got {transformed_image.ndim}D. Attempting to proceed.")
+             # predict() expects a 3D tensor, so strip a leading batch dimension if present
+             if transformed_image.ndim == 4 and transformed_image.shape[0] == 1:
+                 transformed_image = transformed_image.squeeze(0)  # Fix if the user's transform adds a batch dim
+                 _LOGGER.warning("Removed an extra batch dimension.")
+             else:
+                 raise ValueError(f"Transform must output a 3D (C, H, W) tensor, got {transformed_image.shape}.")
+
+         # Use the existing single-item predict method
+         return self.predict_numpy(transformed_image)
+
+
+ def info():
+     _script_info(__all__)
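
Likewise, a minimal usage sketch of the inference handler. The TinyClassifier stand-in, file paths, and class map are hypothetical; real use would pass a model from ml_tools.ML_vision_models (which exposes in_channels) and a .json transform recipe saved during training:

import torch
from torch import nn
from torchvision import transforms
from ml_tools.ML_vision_inference import PyTorchVisionInferenceHandler
from ml_tools.keys import PyTorchInferenceKeys

# Hypothetical stand-in model: a tiny 3-channel classifier
class TinyClassifier(nn.Module):
    def __init__(self, num_classes: int = 2):
        super().__init__()
        self.in_channels = 3  # picked up by the handler for PIL mode selection
        self.net = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

# Save a plain state_dict; a training checkpoint dict would also be accepted
torch.save(TinyClassifier().state_dict(), "tiny.pth")

handler = PyTorchVisionInferenceHandler(
    model=TinyClassifier(),
    state_dict="tiny.pth",
    task="image_classification",
    device="cuda",  # falls back to CPU with a warning if CUDA is unavailable
    transform_source=transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()]),
    class_map={"cat": 0, "dog": 1},  # class_to_idx from the dataset maker
)

# Pre-transformed batch: a 4D (B, C, H, W) tensor
batch = torch.rand(4, 3, 64, 64)
out = handler.predict_batch_numpy(batch)
print(out[PyTorchInferenceKeys.LABELS], out[PyTorchInferenceKeys.LABEL_NAMES])

# Or predict straight from an image file using the stored transform:
# result = handler.predict_from_file("images/example.jpg")  # hypothetical path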