dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
  2. dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
  3. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/ETL_cleaning.py +20 -20
  5. ml_tools/ETL_engineering.py +23 -25
  6. ml_tools/GUI_tools.py +20 -20
  7. ml_tools/MICE_imputation.py +207 -5
  8. ml_tools/ML_callbacks.py +43 -26
  9. ml_tools/ML_configuration.py +788 -0
  10. ml_tools/ML_datasetmaster.py +303 -448
  11. ml_tools/ML_evaluation.py +351 -93
  12. ml_tools/ML_evaluation_multi.py +139 -42
  13. ml_tools/ML_inference.py +290 -209
  14. ml_tools/ML_models.py +33 -106
  15. ml_tools/ML_models_advanced.py +323 -0
  16. ml_tools/ML_optimization.py +12 -12
  17. ml_tools/ML_scaler.py +11 -11
  18. ml_tools/ML_sequence_datasetmaster.py +341 -0
  19. ml_tools/ML_sequence_evaluation.py +219 -0
  20. ml_tools/ML_sequence_inference.py +391 -0
  21. ml_tools/ML_sequence_models.py +139 -0
  22. ml_tools/ML_trainer.py +1604 -179
  23. ml_tools/ML_utilities.py +351 -4
  24. ml_tools/ML_vision_datasetmaster.py +1540 -0
  25. ml_tools/ML_vision_evaluation.py +284 -0
  26. ml_tools/ML_vision_inference.py +405 -0
  27. ml_tools/ML_vision_models.py +641 -0
  28. ml_tools/ML_vision_transformers.py +284 -0
  29. ml_tools/PSO_optimization.py +6 -6
  30. ml_tools/SQL.py +4 -4
  31. ml_tools/_keys.py +171 -0
  32. ml_tools/_schema.py +1 -1
  33. ml_tools/custom_logger.py +37 -14
  34. ml_tools/data_exploration.py +502 -93
  35. ml_tools/ensemble_evaluation.py +54 -11
  36. ml_tools/ensemble_inference.py +7 -33
  37. ml_tools/ensemble_learning.py +1 -1
  38. ml_tools/math_utilities.py +1 -1
  39. ml_tools/optimization_tools.py +2 -2
  40. ml_tools/path_manager.py +5 -5
  41. ml_tools/serde.py +2 -2
  42. ml_tools/utilities.py +192 -4
  43. dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
  44. ml_tools/RNN_forecast.py +0 -56
  45. ml_tools/keys.py +0 -87
  46. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
  47. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
  48. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_vision_evaluation.py
@@ -0,0 +1,284 @@
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import torch
+ from sklearn.metrics import (
+     accuracy_score,
+     f1_score,
+     jaccard_score,
+     confusion_matrix,
+     ConfusionMatrixDisplay
+ )
+ from pathlib import Path
+ from typing import Union, Optional, List, Dict
+ import json
+ from torchmetrics.detection import MeanAveragePrecision
+
+ from .path_manager import make_fullpath
+ from ._logger import _LOGGER
+ from ._script_info import _script_info
+ from ._keys import VisionKeys
+ from .ML_configuration import (BinarySegmentationMetricsFormat,
+                                MultiClassSegmentationMetricsFormat,
+                                _BaseSegmentationFormat)
+
+
+ __all__ = [
+     "segmentation_metrics",
+     "object_detection_metrics"
+ ]
+
+ DPI_value = 250
+
+
+ def segmentation_metrics(
+     y_true: np.ndarray,
+     y_pred: np.ndarray,
+     save_dir: Union[str, Path],
+     class_names: Optional[List[str]] = None,
+     config: Optional[Union[BinarySegmentationMetricsFormat, MultiClassSegmentationMetricsFormat]] = None
+ ):
+     """
+     Calculates and saves pixel-level metrics for segmentation tasks.
+
+     Metrics include Pixel Accuracy, Dice (F1-score), and IoU (Jaccard).
+     It calculates 'micro', 'macro', and 'weighted' averages and saves
+     a pixel-level confusion matrix and a metrics heatmap.
+
+     Note: This function expects integer-based masks (e.g., shape [N, H, W] or [H, W]),
+     not one-hot encoded masks.
+
+     Args:
+         y_true (np.ndarray): Ground truth masks (e.g., shape [N, H, W]).
+         y_pred (np.ndarray): Predicted masks (e.g., shape [N, H, W]).
+         save_dir (str | Path): Directory to save the metrics report and plots.
+         class_names (List[str] | None): Names of the classes for the report.
+         config (object): Formatting configuration object.
+     """
+     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+     # --- Parse Config or use defaults ---
+     if config is None:
+         format_config = _BaseSegmentationFormat()
+     else:
+         format_config = config
+
+     # --- Set Matplotlib font size ---
+     original_rc_params = plt.rcParams.copy()
+     plt.rcParams.update({'font.size': format_config.font_size})
+
+     # Get all unique class labels present in either true or pred
+     labels = np.unique(np.concatenate((np.unique(y_true), np.unique(y_pred)))).astype(int)
+
+     # --- Setup Class Names ---
+     display_names = []
+     if class_names is None:
+         display_names = [f"Class {i}" for i in labels]
+     else:
+         if len(class_names) != len(labels):
+             _LOGGER.warning(f"Number of class_names ({len(class_names)}) does not match number of unique labels ({len(labels)}). Using default names.")
+             display_names = [f"Class {i}" for i in labels]
+         else:
+             display_names = class_names
+
+     # Flatten masks for sklearn metrics
+     y_true_flat = y_true.ravel()
+     y_pred_flat = y_pred.ravel()
+
+     _LOGGER.info("--- Calculating Segmentation Metrics ---")
+
+     # --- 1. Calculate Metrics ---
+     pix_acc = accuracy_score(y_true_flat, y_pred_flat)
+
+     # Calculate all average types
+     dice_micro = f1_score(y_true_flat, y_pred_flat, average='micro', labels=labels)
+     iou_micro = jaccard_score(y_true_flat, y_pred_flat, average='micro', labels=labels)
+
+     dice_macro = f1_score(y_true_flat, y_pred_flat, average='macro', labels=labels, zero_division=0)
+     iou_macro = jaccard_score(y_true_flat, y_pred_flat, average='macro', labels=labels, zero_division=0)
+
+     dice_weighted = f1_score(y_true_flat, y_pred_flat, average='weighted', labels=labels, zero_division=0)
+     iou_weighted = jaccard_score(y_true_flat, y_pred_flat, average='weighted', labels=labels, zero_division=0)
+
+     # Per-class metrics
+     dice_per_class = f1_score(y_true_flat, y_pred_flat, average=None, labels=labels, zero_division=0)
+     iou_per_class = jaccard_score(y_true_flat, y_pred_flat, average=None, labels=labels, zero_division=0)
+
+     # --- 2. Create and Save Report ---
+     report_lines = [
+         "--- Segmentation Report ---",
+         f"\nOverall Pixel Accuracy: {pix_acc:.4f}\n",
+         "--- Averaged Metrics ---",
+         f"{'Average':<10} | {'Dice (F1)':<12} | {'IoU (Jaccard)':<12}",
+         "-" * 41,
+         f"{'Micro':<10} | {dice_micro:<12.4f} | {iou_micro:<12.4f}",
+         f"{'Macro':<10} | {dice_macro:<12.4f} | {iou_macro:<12.4f}",
+         f"{'Weighted':<10} | {dice_weighted:<12.4f} | {iou_weighted:<12.4f}",
+         "\n--- Per-Class Metrics ---",
+     ]
+
+     per_class_data = {
+         'Class': display_names,
+         'Dice': dice_per_class,
+         'IoU': iou_per_class
+     }
+     per_class_df = pd.DataFrame(per_class_data)
+     report_lines.append(per_class_df.to_string(index=False, float_format="%.4f"))
+
+     report_string = "\n".join(report_lines)
+     # print(report_string)  # console output intentionally omitted; the report is saved to disk below
+
+     # Save text report
+     save_filename = VisionKeys.SEGMENTATION_REPORT + ".txt"
+     report_path = save_dir_path / save_filename
+     report_path.write_text(report_string, encoding="utf-8")
+     _LOGGER.info(f"📝 Segmentation report saved as '{report_path.name}'")
+
+     # --- 3. Save Per-Class Metrics Heatmap ---
+     try:
+         plt.figure(figsize=(max(8, len(labels) * 0.5), 6), dpi=DPI_value)
+         sns.heatmap(
+             per_class_df.set_index('Class').T,
+             annot=True,
+             cmap=format_config.heatmap_cmap,  # Use config cmap
+             fmt='.3f',
+             linewidths=0.5
+         )
+         plt.title("Per-Class Segmentation Metrics")
+         plt.tight_layout()
+         heatmap_filename = VisionKeys.SEGMENTATION_HEATMAP + ".svg"
+         heatmap_path = save_dir_path / heatmap_filename
+         plt.savefig(heatmap_path)
+         _LOGGER.info(f"📊 Metrics heatmap saved as '{heatmap_path.name}'")
+         plt.close()
+     except Exception as e:
+         _LOGGER.error(f"Could not generate segmentation metrics heatmap: {e}")
+
+     # --- 4. Save Pixel-level Confusion Matrix ---
+     try:
+         # Calculate CM
+         cm = confusion_matrix(y_true_flat, y_pred_flat, labels=labels)
+
+         # Plot
+         fig_cm, ax_cm = plt.subplots(figsize=(max(8, len(labels) * 0.8), max(8, len(labels) * 0.8)), dpi=100)
+         disp = ConfusionMatrixDisplay(
+             confusion_matrix=cm,
+             display_labels=display_names
+         )
+         disp.plot(cmap=format_config.cm_cmap, ax=ax_cm, xticks_rotation=45)  # Use config cmap
+
+         # Manually update font size of cell texts
+         for text in disp.text_.flatten():  # type: ignore
+             text.set_fontsize(format_config.font_size)
+
+         ax_cm.set_title("Pixel-Level Confusion Matrix")
+         plt.tight_layout()
+         segmentation_cm_filename = VisionKeys.SEGMENTATION_CONFUSION_MATRIX + ".svg"
+         cm_path = save_dir_path / segmentation_cm_filename
+         plt.savefig(cm_path)
+         _LOGGER.info(f"❇️ Pixel-level confusion matrix saved as '{cm_path.name}'")
+         plt.close(fig_cm)
+     except Exception as e:
+         _LOGGER.error(f"Could not generate confusion matrix: {e}")
+
+     # --- Restore RC params ---
+     plt.rcParams.update(original_rc_params)
+
+
+ def object_detection_metrics(
+     preds: List[Dict[str, torch.Tensor]],
+     targets: List[Dict[str, torch.Tensor]],
+     save_dir: Union[str, Path],
+     class_names: Optional[List[str]] = None,
+     print_output: bool = False
+ ):
+     """
+     Calculates and saves object detection metrics (mAP) using torchmetrics.
+
+     This function expects predictions and targets in the standard
+     torchvision format (list of dictionaries).
+
+     Args:
+         preds (List[Dict[str, torch.Tensor]]): A list of predictions.
+             Each dict must contain:
+             - 'boxes': [N, 4] (xmin, ymin, xmax, ymax)
+             - 'scores': [N]
+             - 'labels': [N]
+         targets (List[Dict[str, torch.Tensor]]): A list of ground truths.
+             Each dict must contain:
+             - 'boxes': [M, 4]
+             - 'labels': [M]
+         save_dir (str | Path): Directory to save the metrics report (as JSON).
+         class_names (List[str] | None): A list of class names, including 'background'
+             at index 0. Used to label per-class metrics in the report.
+         print_output (bool): If True, prints the JSON report to the console.
+     """
+     save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")
+
+     _LOGGER.info("--- Calculating Object Detection Metrics (mAP) ---")
+
+     try:
+         # Initialize the metric with standard COCO settings
+         metric = MeanAveragePrecision(box_format='xyxy')
+
+         # Move preds and targets to the same device (e.g., CPU for metric calculation)
+         # This avoids device mismatches if the model was on GPU
+         device = torch.device("cpu")
+         preds_cpu = [{k: v.to(device) for k, v in p.items()} for p in preds]
+         targets_cpu = [{k: v.to(device) for k, v in t.items()} for t in targets]
+
+         # Update the metric
+         metric.update(preds_cpu, targets_cpu)
+
+         # Compute the final metrics
+         results = metric.compute()
+
+         # --- Handle class names for per-class metrics ---
+         report_class_names = None
+         if class_names:
+             if class_names[0].lower() in ['background', "bg"]:
+                 report_class_names = class_names[1:]  # Skip background (class 0)
+             else:
+                 _LOGGER.warning("class_names provided to object_detection_metrics, but 'background' was not class 0. Using all provided names.")
+                 report_class_names = class_names
+
+         # Convert all torch tensors in results to floats/lists for JSON serialization
+         serializable_results = {}
+         for key, value in results.items():
+             if isinstance(value, torch.Tensor):
+                 if value.numel() == 1:
+                     serializable_results[key] = value.item()
+                 # Check if it's a 1D tensor, we have class names, and it's a known per-class key
+                 elif value.ndim == 1 and report_class_names and key in ('map_per_class', 'mar_100_per_class', 'mar_1_per_class', 'mar_10_per_class'):
+                     per_class_list = value.cpu().numpy().tolist()
+                     # Map names to values
+                     if len(per_class_list) == len(report_class_names):
+                         serializable_results[key] = {name: val for name, val in zip(report_class_names, per_class_list)}
+                     else:
+                         _LOGGER.warning(f"Length mismatch for '{key}': {len(per_class_list)} values vs {len(report_class_names)} class names. Saving as raw list.")
+                         serializable_results[key] = per_class_list
+                 else:
+                     serializable_results[key] = value.cpu().numpy().tolist()
+             else:
+                 serializable_results[key] = value
+
+         # Pretty print to console
+         if print_output:
+             print(json.dumps(serializable_results, indent=4))
+
+         # Save JSON report
+         detection_report_filename = VisionKeys.OBJECT_DETECTION_REPORT + ".json"
+         report_path = save_dir_path / detection_report_filename
+         with open(report_path, 'w') as f:
+             json.dump(serializable_results, f, indent=4)
+
+         _LOGGER.info(f"📊 Object detection (mAP) report saved as '{report_path.name}'")
+
+     except Exception as e:
+         _LOGGER.error(f"Failed to compute mAP: {e}")
+         raise
+
+
+ def info():
+     _script_info(__all__)
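
For orientation, a minimal usage sketch of the two evaluation helpers added above; it is illustrative only and not part of the diff, and the random masks, box tensors, class names, and the ./vision_eval output directory are made-up stand-ins (the optional formatting config objects are omitted, so defaults apply):

    import numpy as np
    import torch
    from ml_tools.ML_vision_evaluation import segmentation_metrics, object_detection_metrics

    # Segmentation expects integer masks of shape [N, H, W], not one-hot tensors.
    y_true = np.random.randint(0, 3, size=(4, 64, 64))
    y_pred = np.random.randint(0, 3, size=(4, 64, 64))
    segmentation_metrics(y_true, y_pred, save_dir="./vision_eval",
                         class_names=["background", "cat", "dog"])

    # Object detection expects torchvision-style lists of dicts; 'background' is class 0.
    preds = [{"boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
              "scores": torch.tensor([0.9]),
              "labels": torch.tensor([1])}]
    targets = [{"boxes": torch.tensor([[12.0, 8.0, 48.0, 52.0]]),
                "labels": torch.tensor([1])}]
    object_detection_metrics(preds, targets, save_dir="./vision_eval",
                             class_names=["background", "cat"], print_output=True)
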
ml_tools/ML_vision_inference.py
@@ -0,0 +1,405 @@
+ import torch
+ from torch import nn
+ import numpy as np  # numpy array return values
+ from pathlib import Path
+ from typing import Union, Literal, Dict, Any, List, Optional, Callable
+ from PIL import Image
+ from torchvision import transforms
+
+ from ._script_info import _script_info
+ from ._logger import _LOGGER
+ from ._keys import PyTorchInferenceKeys, MLTaskKeys
+ from .ML_vision_transformers import _load_recipe_and_build_transform
+ from .ML_inference import _BaseInferenceHandler
+
+
+ __all__ = [
+     "DragonVisionInferenceHandler"
+ ]
+
+
+ class DragonVisionInferenceHandler(_BaseInferenceHandler):
+     """
+     Handles loading a PyTorch vision model's state dictionary and performing inference.
+
+     This class is specifically for vision models, which typically expect
+     4D Tensors (B, C, H, W) or Lists of Tensors as input.
+     """
+     def __init__(self,
+                  model: nn.Module,
+                  state_dict: Union[str, Path],
+                  task: Literal["binary image classification", "multiclass image classification", "binary segmentation", "multiclass segmentation", "object detection"],
+                  device: str = 'cpu',
+                  transform_source: Optional[Union[str, Path, Callable]] = None):
+         """
+         Initializes the vision inference handler.
+
+         Args:
+             model (nn.Module): An instantiated PyTorch model from ML_vision_models.
+             state_dict (str | Path): Path to the saved .pth model state_dict file.
+             task (str): The type of vision task.
+             device (str): The device to run inference on ('cpu', 'cuda', 'mps').
+             transform_source (str | Path | Callable | None):
+                 - A path to a .json recipe file (str or Path).
+                 - A pre-built transformation pipeline (Callable).
+                 - None, in which case .set_transform() must be called explicitly to set transformations.
+
+         Note: class_map (Dict[int, str]) is loaded from the model file; to set or override it, use `.set_class_map()`.
+         """
+         super().__init__(model, state_dict, device, None)
+
+         self._transform: Optional[Callable] = None
+         self._is_transformed: bool = False
+
+         if task not in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION, MLTaskKeys.OBJECT_DETECTION]:
+             _LOGGER.error(f"Unsupported task: '{task}'.")
+             raise ValueError()
+         self.task = task
+
+         self.expected_in_channels: int = 3  # Default to RGB
+         if hasattr(model, 'in_channels'):
+             self.expected_in_channels = model.in_channels  # type: ignore
+             _LOGGER.info(f"Model expects {self.expected_in_channels} input channels.")
+         else:
+             _LOGGER.warning("Could not determine 'in_channels' from model. Defaulting to 3 (RGB). Modify with '.expected_in_channels'.")
+
+         if transform_source:
+             self.set_transform(transform_source)
+             self._is_transformed = True
+
+     def _preprocess_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Union[torch.Tensor, List[torch.Tensor]]:
+         """
+         Validates input and moves it to the correct device.
+         - For Classification/Segmentation: Expects 4D Tensor (B, C, H, W).
+         - For Object Detection: Expects List[Tensor(C, H, W)].
+         """
+         if self.task == MLTaskKeys.OBJECT_DETECTION:
+             if not isinstance(inputs, list) or not all(isinstance(t, torch.Tensor) for t in inputs):
+                 _LOGGER.error("Input for object_detection must be a List[torch.Tensor].")
+                 raise ValueError("Invalid input type for object detection.")
+             # Move each tensor in the list to the device
+             return [t.float().to(self.device) for t in inputs]
+
+         else:  # Classification or Segmentation
+             if not isinstance(inputs, torch.Tensor):
+                 _LOGGER.error(f"Input for {self.task} must be a torch.Tensor.")
+                 raise ValueError(f"Invalid input type for {self.task}.")
+
+             if inputs.ndim != 4:  # type: ignore
+                 _LOGGER.error(f"Input tensor for {self.task} must be 4D (B, C, H, W). Got {inputs.ndim}D.")  # type: ignore
+                 raise ValueError("Input tensor must be 4D.")
+
+             return inputs.float().to(self.device)
+
+     def set_transform(self, transform_source: Union[str, Path, Callable]):
+         """
+         Sets or updates the inference transformation pipeline from a recipe file or a direct Callable.
+
+         Args:
+             transform_source (str, Path, Callable):
+                 - A path to a .json recipe file (str or Path).
+                 - A pre-built transformation pipeline (Callable).
+         """
+         if self._is_transformed:
+             _LOGGER.warning("Transformations were previously applied. Applying new transformations...")
+
+         if isinstance(transform_source, (str, Path)):
+             _LOGGER.info(f"Loading transform from recipe file: '{transform_source}'")
+             try:
+                 # Use the loader function
+                 self._transform = _load_recipe_and_build_transform(transform_source)
+             except Exception as e:
+                 _LOGGER.error(f"Failed to load transform from recipe '{transform_source}': {e}")
+                 raise
+         elif isinstance(transform_source, Callable):
+             _LOGGER.info("Inference transform has been set from a direct Callable.")
+             self._transform = transform_source
+         else:
+             _LOGGER.error(f"Invalid transform_source type: {type(transform_source)}. Must be str, Path, or Callable.")
+             raise TypeError("transform_source must be a file path or a Callable.")
+
+     def predict_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, Any]:
+         """
+         Core batch prediction method for vision models.
+         All preprocessing (resizing, normalization) should be done *before* calling this method.
+
+         Args:
+             inputs (torch.Tensor | List[torch.Tensor]):
+                 - For binary/multiclass image classification or binary/multiclass image segmentation tasks,
+                   a 4D torch.Tensor (B, C, H, W).
+                 - For 'object_detection', a List of 3D torch.Tensors
+                   [(C, H, W), ...], each with its own size.
+
+         Returns:
+             A dictionary containing the output tensors.
+             - Classification: {labels, probabilities}
+             - Segmentation: {labels, probabilities} (labels is the mask)
+             - Object Detection: {predictions} (List of dicts)
+         """
+
+         processed_inputs = self._preprocess_batch(inputs)
+
+         with torch.no_grad():
+             # Get outputs
+             output = self.model(processed_inputs)
+             if self.task == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION:
+                 # Process
+                 probs = torch.softmax(output, dim=1)
+                 labels = torch.argmax(probs, dim=1)
+                 return {
+                     PyTorchInferenceKeys.LABELS: labels,       # (B,)
+                     PyTorchInferenceKeys.PROBABILITIES: probs  # (B, num_classes)
+                 }
+
+             elif self.task == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION:
+                 # Assumes model output is [N, 1] (a single logit)
+                 # Squeeze output from [N, 1] to [N] if necessary
+                 if output.ndim == 2 and output.shape[1] == 1:
+                     output = output.squeeze(1)
+
+                 probs = torch.sigmoid(output)  # Probability of positive class
+                 labels = (probs >= self._classification_threshold).int()
+                 return {
+                     PyTorchInferenceKeys.LABELS: labels,
+                     PyTorchInferenceKeys.PROBABILITIES: probs
+                 }
+
+             elif self.task == MLTaskKeys.BINARY_SEGMENTATION:
+                 # Assumes model output is [N, 1, H, W] (logits for positive class)
+                 probs = torch.sigmoid(output)  # Shape [N, 1, H, W]
+                 labels = (probs >= self._classification_threshold).int()  # Shape [N, 1, H, W]
+                 return {
+                     PyTorchInferenceKeys.LABELS: labels,
+                     PyTorchInferenceKeys.PROBABILITIES: probs
+                 }
+
+             elif self.task == MLTaskKeys.MULTICLASS_SEGMENTATION:
+                 # Output shape [N, C, H, W]
+                 probs = torch.softmax(output, dim=1)
+                 labels = torch.argmax(probs, dim=1)  # shape [N, H, W]
+                 return {
+                     PyTorchInferenceKeys.LABELS: labels,       # (N, H, W)
+                     PyTorchInferenceKeys.PROBABILITIES: probs  # (N, num_classes, H, W)
+                 }
+
+             elif self.task == MLTaskKeys.OBJECT_DETECTION:
+                 return {
+                     PyTorchInferenceKeys.PREDICTIONS: output
+                 }
+
+             else:
+                 # This should be unreachable due to validation
+                 raise ValueError(f"Unknown task: {self.task}")
+
+     def predict(self, single_input: torch.Tensor) -> Dict[str, Any]:
+         """
+         Core single-sample prediction method for vision models.
+         All preprocessing (resizing, normalization) should be done *before*
+         calling this method.
+
+         Args:
+             single_input (torch.Tensor):
+                 - A 3D torch.Tensor (C, H, W) for any task.
+
+         Returns:
+             A dictionary containing the output tensors for a single sample.
+             - Classification: {labels, probabilities} (label is 0-dim)
+             - Segmentation: {labels, probabilities} (label is a 2D (multiclass) or 3D (binary) mask)
+             - Object Detection: {boxes, labels, scores} (single dict)
+         """
+         if not isinstance(single_input, torch.Tensor) or single_input.ndim != 3:
+             _LOGGER.error(f"Input for predict() must be a 3D tensor (C, H, W). Got {single_input.ndim}D.")
+             raise ValueError()
+
+         # --- 1. Batch the input based on task ---
+         if self.task == MLTaskKeys.OBJECT_DETECTION:
+             batched_input = [single_input]  # List of one tensor
+         else:
+             batched_input = single_input.unsqueeze(0)
+
+         # --- 2. Call batch prediction ---
+         batch_results = self.predict_batch(batched_input)
+
+         # --- 3. Un-batch the results ---
+         if self.task == MLTaskKeys.OBJECT_DETECTION:
+             # batch_results['predictions'] is a List[Dict]. We want the first (and only) Dict.
+             return batch_results[PyTorchInferenceKeys.PREDICTIONS][0]
+         else:
+             # 'labels' and 'probabilities' are tensors. Get the 0-th element.
+             # (B, ...) -> (...)
+             single_results = {key: value[0] for key, value in batch_results.items()}
+             return single_results
+
+     # --- NumPy Convenience Wrappers (on CPU) ---
+
+     def predict_batch_numpy(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, Any]:
+         """
+         Convenience wrapper for predict_batch that returns NumPy arrays, with label names if a class map is set.
+
+         Returns:
+             Dict: A dictionary containing the outputs as NumPy arrays.
+             - Obj. Detection: {predictions: List[Dict[str, np.ndarray]]}
+             - Classification: {labels: np.ndarray, label_names: List[str], probabilities: np.ndarray}
+             - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
+         """
+         tensor_results = self.predict_batch(inputs)
+
+         if self.task == MLTaskKeys.OBJECT_DETECTION:
+             # Output is List[Dict[str, Tensor]]
+             # Convert each tensor inside each dict to numpy
+             numpy_results = []
+             for pred_dict in tensor_results[PyTorchInferenceKeys.PREDICTIONS]:
+                 # Convert all tensors to numpy
+                 np_dict = {key: value.cpu().numpy() for key, value in pred_dict.items()}
+
+                 # A pixel-to-string map is unnecessary here
+                 # if self._idx_to_class and PyTorchInferenceKeys.LABELS in np_dict:
+                 #     np_dict[PyTorchInferenceKeys.LABEL_NAMES] = [
+                 #         self._idx_to_class.get(label_id, "Unknown")
+                 #         for label_id in np_dict[PyTorchInferenceKeys.LABELS]
+                 #     ]
+                 numpy_results.append(np_dict)
+             return {PyTorchInferenceKeys.PREDICTIONS: numpy_results}
+
+         else:
+             # Output is Dict[str, Tensor] (for Classification or Segmentation)
+             numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}
+
+             # Add string names for classification if map exists
+             is_image_classification = self.task in [
+                 MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
+                 MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
+             ]
+
+             if is_image_classification and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
+                 int_labels = numpy_results[PyTorchInferenceKeys.LABELS]  # This is a (B,) array
+                 numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [
+                     self._idx_to_class.get(label_id, "Unknown")
+                     for label_id in int_labels
+                 ]
+
+             return numpy_results
+
+     def predict_numpy(self, single_input: torch.Tensor) -> Dict[str, Any]:
+         """
+         Convenience wrapper for predict that returns NumPy arrays/scalars.
+
+         Returns:
+             Dict: A dictionary containing the outputs as NumPy arrays/scalars.
+             - Obj. Detection: {boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray, label_names: List[str]}
+             - Classification: {labels: int, label_names: str, probabilities: np.ndarray}
+             - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
+         """
+         tensor_results = self.predict(single_input)
+
+         if self.task == MLTaskKeys.OBJECT_DETECTION:
+             # Output is Dict[str, Tensor]
+             # Convert each tensor to numpy
+             numpy_results = {
+                 key: value.cpu().numpy() for key, value in tensor_results.items()
+             }
+
+             # Add string names if map exists
+             # if self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
+             #     int_labels = numpy_results[PyTorchInferenceKeys.LABELS]
+
+             #     numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [
+             #         self._idx_to_class.get(label_id, "Unknown")
+             #         for label_id in int_labels
+             #     ]
+
+             return numpy_results
+
+         elif self.task in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
+             # Output is Dict[str, Tensor(0-dim) or Tensor(1-dim)]
+             int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
+             label_name = "Unknown"
+             if self._idx_to_class:
+                 label_name = self._idx_to_class.get(int_label, "Unknown")
+
+             return {
+                 PyTorchInferenceKeys.LABELS: int_label,
+                 PyTorchInferenceKeys.LABEL_NAMES: label_name,
+                 PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
+             }
+         else:  # image segmentation (binary or multiclass)
+             # Output is Dict[str, Tensor(2D) or Tensor(3D)]
+             return {
+                 PyTorchInferenceKeys.LABELS: tensor_results[PyTorchInferenceKeys.LABELS].cpu().numpy(),
+                 PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
+             }
+
+     def predict_from_pil(self, image: Image.Image) -> Dict[str, Any]:
+         """
+         Applies the stored transform to a single PIL image and returns the prediction.
+
+         Args:
+             image (PIL.Image.Image): The input PIL image.
+
+         Returns:
+             Dict: A dictionary containing the prediction results. See `predict_numpy()` for task-specific output structures.
+         """
+         if self._transform is None:
+             _LOGGER.error("Cannot predict from PIL image: No transform has been set. Call .set_transform() or provide transform_source in __init__.")
+             raise RuntimeError("Inference transform is not set.")
+
+         # Apply the transformation pipeline (e.g., resize, crop, ToTensor, normalize)
+         try:
+             transformed_image = self._transform(image)
+         except Exception as e:
+             _LOGGER.error(f"Error applying transform to PIL image: {e}")
+             raise
+
+         # --- Validation ---
+         if not isinstance(transformed_image, torch.Tensor):
+             _LOGGER.error("The provided transform did not return a torch.Tensor. Does it include transforms.ToTensor()?")
+             raise ValueError("Transform pipeline must output a torch.Tensor.")
+
+         if transformed_image.ndim != 3:
+             _LOGGER.warning(f"Expected transform to output a 3D (C, H, W) tensor, but got {transformed_image.ndim}D. Attempting to proceed.")
+             # .predict_numpy() -> .predict(), which expects a 3D tensor
+             if transformed_image.ndim == 4 and transformed_image.shape[0] == 1:
+                 transformed_image = transformed_image.squeeze(0)  # Fix if the user's transform adds a batch dim
+                 _LOGGER.warning("Removed an extra batch dimension.")
+             else:
+                 raise ValueError(f"Transform must output a 3D (C, H, W) tensor, got {transformed_image.shape}.")
+
+         # Use the existing single-item predict method
+         return self.predict_numpy(transformed_image)
+
+     def predict_from_file(self, image_path: Union[str, Path]) -> Dict[str, Any]:
+         """
+         Loads a single image from a file, applies the stored transform, and returns the prediction.
+
+         This is a convenience wrapper that loads the image and calls `predict_from_pil()`.
+
+         Args:
+             image_path (str | Path): The file path to the input image.
+
+         Returns:
+             Dict: A dictionary containing the prediction results. See `predict_numpy()` for task-specific output structures.
+         """
+         try:
+             # --- Use expected_in_channels to set the PIL mode ---
+             pil_mode: str
+             if self.expected_in_channels == 1:
+                 pil_mode = "L"  # Grayscale
+             elif self.expected_in_channels == 4:
+                 pil_mode = "RGBA"  # RGB + Alpha
+             else:
+                 if self.expected_in_channels != 3:  # 2 or 5+ channels are not supported by PIL convert
+                     _LOGGER.warning(f"Model expects {self.expected_in_channels} channels. PIL conversion is limited, defaulting to 3 channels (RGB). The transformations must convert it to the desired channel dimensions.")
+                 # Default to RGB. If 2 channels are needed, the transform recipe *must* handle the conversion from a 3-channel PIL image.
+                 pil_mode = "RGB"
+
+             image = Image.open(image_path).convert(pil_mode)
+         except Exception as e:
+             _LOGGER.error(f"Failed to load and convert image from '{image_path}': {e}")
+             raise
+
+         # Call the PIL-based prediction method
+         return self.predict_from_pil(image)
+
+
+ def info():
+     _script_info(__all__)
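
For orientation, a hedged usage sketch of the inference handler added above; it is not part of the diff. `vision_model` and the `artifacts/model.pth` path are placeholders for an instantiated model from ml_tools.ML_vision_models and a state_dict saved by the toolbox's training utilities, and the torchvision pipeline stands in for a .json transform recipe:

    from pathlib import Path
    import torch
    from torchvision import transforms
    from ml_tools.ML_vision_inference import DragonVisionInferenceHandler

    vision_model = ...                          # placeholder: a model built with ml_tools.ML_vision_models
    weights_path = Path("artifacts/model.pth")  # placeholder: state_dict produced during training

    # A direct Callable can be passed instead of a .json recipe file.
    inference_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    handler = DragonVisionInferenceHandler(
        model=vision_model,
        state_dict=weights_path,
        task="multiclass image classification",
        device="cpu",
        transform_source=inference_transform,
    )

    # Pre-processed tensors go straight to the batch API: (B, C, H, W).
    batch_out = handler.predict_batch_numpy(torch.rand(8, 3, 224, 224))

    # Raw image files are transformed first, then routed through predict_numpy().
    single_out = handler.predict_from_file("images/sample.png")
    print(single_out)  # e.g. {labels: int, label_names: str, probabilities: np.ndarray}
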