dragon-ml-toolbox 14.3.1__py3-none-any.whl → 16.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (44)
  1. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/METADATA +10 -5
  2. dragon_ml_toolbox-16.0.0.dist-info/RECORD +51 -0
  3. ml_tools/ETL_cleaning.py +20 -20
  4. ml_tools/ETL_engineering.py +23 -25
  5. ml_tools/GUI_tools.py +20 -20
  6. ml_tools/MICE_imputation.py +3 -3
  7. ml_tools/ML_callbacks.py +43 -26
  8. ml_tools/ML_configuration.py +309 -0
  9. ml_tools/ML_datasetmaster.py +220 -260
  10. ml_tools/ML_evaluation.py +317 -81
  11. ml_tools/ML_evaluation_multi.py +127 -36
  12. ml_tools/ML_inference.py +249 -207
  13. ml_tools/ML_models.py +13 -102
  14. ml_tools/ML_models_advanced.py +1 -1
  15. ml_tools/ML_optimization.py +12 -12
  16. ml_tools/ML_scaler.py +11 -11
  17. ml_tools/ML_sequence_datasetmaster.py +341 -0
  18. ml_tools/ML_sequence_evaluation.py +215 -0
  19. ml_tools/ML_sequence_inference.py +391 -0
  20. ml_tools/ML_sequence_models.py +139 -0
  21. ml_tools/ML_trainer.py +1247 -338
  22. ml_tools/ML_utilities.py +51 -2
  23. ml_tools/ML_vision_datasetmaster.py +262 -118
  24. ml_tools/ML_vision_evaluation.py +26 -6
  25. ml_tools/ML_vision_inference.py +117 -140
  26. ml_tools/ML_vision_models.py +15 -1
  27. ml_tools/ML_vision_transformers.py +233 -7
  28. ml_tools/PSO_optimization.py +6 -6
  29. ml_tools/SQL.py +4 -4
  30. ml_tools/{keys.py → _keys.py} +45 -1
  31. ml_tools/_schema.py +1 -1
  32. ml_tools/ensemble_evaluation.py +54 -11
  33. ml_tools/ensemble_inference.py +7 -33
  34. ml_tools/ensemble_learning.py +1 -1
  35. ml_tools/optimization_tools.py +2 -2
  36. ml_tools/path_manager.py +5 -5
  37. ml_tools/utilities.py +1 -2
  38. dragon_ml_toolbox-14.3.1.dist-info/RECORD +0 -48
  39. ml_tools/RNN_forecast.py +0 -56
  40. ml_tools/_ML_vision_recipe.py +0 -88
  41. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/WHEEL +0 -0
  42. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE +0 -0
  43. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
  44. {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-16.0.0.dist-info}/top_level.txt +0 -0
@@ -18,7 +18,8 @@ from torchmetrics.detection import MeanAveragePrecision
  from .path_manager import make_fullpath
  from ._logger import _LOGGER
  from ._script_info import _script_info
- from .keys import VisionKeys
+ from ._keys import VisionKeys
+ from .ML_configuration import SegmentationMetricsFormat


  __all__ = [
@@ -26,12 +27,15 @@ __all__ = [
  "object_detection_metrics"
  ]

+ DPI_value = 250
+

  def segmentation_metrics(
  y_true: np.ndarray,
  y_pred: np.ndarray,
  save_dir: Union[str, Path],
- class_names: Optional[List[str]] = None
+ class_names: Optional[List[str]] = None,
+ config: Optional[SegmentationMetricsFormat] = None # Add config object
  ):
  """
  Calculates and saves pixel-level metrics for segmentation tasks.
@@ -48,9 +52,18 @@ def segmentation_metrics(
  y_pred (np.ndarray): Predicted masks (e.g., shape [N, H, W]).
  save_dir (str | Path): Directory to save the metrics report and plots.
  class_names (List[str] | None): Names of the classes for the report.
+ config (SegmentationMetricsFormat, optional): Formatting configuration object.
  """
  save_dir_path = make_fullpath(save_dir, make=True, enforce="directory")

+ # --- Parse Config or use defaults ---
+ if config is None:
+ config = SegmentationMetricsFormat()
+
+ # --- Set Matplotlib font size ---
+ original_rc_params = plt.rcParams.copy()
+ plt.rcParams.update({'font.size': config.font_size})
+
  # Get all unique class labels present in either true or pred
  labels = np.unique(np.concatenate((np.unique(y_true), np.unique(y_pred)))).astype(int)

@@ -110,7 +123,7 @@
  report_lines.append(per_class_df.to_string(index=False, float_format="%.4f"))

  report_string = "\n".join(report_lines)
- print(report_string)
+ # print(report_string) # <-- I removed the print(report_string)

  # Save text report
  save_filename = VisionKeys.SEGMENTATION_REPORT + ".txt"
@@ -120,11 +133,11 @@

  # --- 3. Save Per-Class Metrics Heatmap ---
  try:
- plt.figure(figsize=(max(8, len(labels) * 0.5), 6), dpi=100)
+ plt.figure(figsize=(max(8, len(labels) * 0.5), 6), dpi=DPI_value)
  sns.heatmap(
  per_class_df.set_index('Class').T,
  annot=True,
- cmap='viridis',
+ cmap=config.heatmap_cmap, # Use config cmap
  fmt='.3f',
  linewidths=0.5
  )
@@ -149,7 +162,11 @@
  confusion_matrix=cm,
  display_labels=display_names
  )
- disp.plot(cmap='Blues', ax=ax_cm, xticks_rotation=45)
+ disp.plot(cmap=config.cm_cmap, ax=ax_cm, xticks_rotation=45) # Use config cmap
+
+ # Manually update font size of cell texts
+ for text in disp.text_.flatten(): # type: ignore
+ text.set_fontsize(config.font_size)

  ax_cm.set_title("Pixel-Level Confusion Matrix")
  plt.tight_layout()
@@ -160,6 +177,9 @@
  plt.close(fig_cm)
  except Exception as e:
  _LOGGER.error(f"Could not generate confusion matrix: {e}")
+
+ # --- Restore RC params ---
+ plt.rcParams.update(original_rc_params)


  def object_detection_metrics(
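For orientation, here is a hedged usage sketch of the reworked segmentation_metrics signature. The parameter names and the config fields (font_size, heatmap_cmap, cm_cmap) come from the hunks above; the keyword-argument constructor of SegmentationMetricsFormat and the absolute ml_tools import paths are assumptions.

import numpy as np
from ml_tools.ML_configuration import SegmentationMetricsFormat
from ml_tools.ML_vision_evaluation import segmentation_metrics

# Dummy multiclass masks: 4 samples, 3 classes, 64x64 pixels
y_true = np.random.randint(0, 3, size=(4, 64, 64))
y_pred = np.random.randint(0, 3, size=(4, 64, 64))

# Assumed keyword constructor; library defaults are used when config is omitted
fmt = SegmentationMetricsFormat(font_size=12, heatmap_cmap="viridis", cm_cmap="Blues")

segmentation_metrics(
    y_true,
    y_pred,
    save_dir="segmentation_report",            # report, heatmap and confusion matrix are written here
    class_names=["background", "cat", "dog"],
    config=fmt,
)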
@@ -8,32 +8,29 @@ from torchvision import transforms

  from ._script_info import _script_info
  from ._logger import _LOGGER
- from .path_manager import make_fullpath
- from .keys import PyTorchInferenceKeys, PyTorchCheckpointKeys
- from ._ML_vision_recipe import load_recipe_and_build_transform
+ from ._keys import PyTorchInferenceKeys, MLTaskKeys
+ from .ML_vision_transformers import _load_recipe_and_build_transform
+ from .ML_inference import _BaseInferenceHandler


  __all__ = [
- "PyTorchVisionInferenceHandler"
+ "DragonVisionInferenceHandler"
  ]


- class PyTorchVisionInferenceHandler:
+ class DragonVisionInferenceHandler(_BaseInferenceHandler):
  """
  Handles loading a PyTorch vision model's state dictionary and performing inference.

  This class is specifically for vision models, which typically expect
  4D Tensors (B, C, H, W) or Lists of Tensors as input.
- It does NOT use a scaler, as preprocessing (e.g., normalization)
- is assumed to be part of the input transform pipeline.
  """
  def __init__(self,
  model: nn.Module,
  state_dict: Union[str, Path],
- task: Literal["image_classification", "image_segmentation", "object_detection"],
+ task: Literal["binary image classification", "multiclass image classification", "binary segmentation", "multiclass segmentation", "object detection"],
  device: str = 'cpu',
- transform_source: Optional[Union[str, Path, Callable]] = None,
- class_map: Optional[Dict[str, int]] = None):
+ transform_source: Optional[Union[str, Path, Callable]] = None):
  """
  Initializes the vision inference handler.

@@ -46,19 +43,17 @@ class PyTorchVisionInferenceHandler:
  - A path to a .json recipe file (str or Path).
  - A pre-built transformation pipeline (Callable).
  - None, in which case .set_transform() must be called explicitly to set transformations.
- idx_to_class (Dict[int, str] | None): Sets the class name mapping to translate predicted integer labels back into string names. (For image classification and object detection)
+
+ Note: class_map (Dict[int, str]) will be loaded from the model file, to set or override it use `.set_class_map()`.
  """
- self._model = model
- self._device = self._validate_device(device)
+ super().__init__(model, state_dict, device, None)
+
  self._transform: Optional[Callable] = None
  self._is_transformed: bool = False
- self._idx_to_class: Optional[Dict[int, str]] = None
- if class_map is not None:
- self.set_class_map(class_map)

- if task not in ["image_classification", "image_segmentation", "object_detection"]:
- _LOGGER.error(f"`task` must be 'image_classification', 'image_segmentation', or 'object_detection'. Got '{task}'.")
- raise ValueError("Invalid task type.")
+ if task not in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION, MLTaskKeys.BINARY_SEGMENTATION, MLTaskKeys.MULTICLASS_SEGMENTATION, MLTaskKeys.OBJECT_DETECTION]:
+ _LOGGER.error(f"Unsupported task: '{task}'.")
+ raise ValueError()
  self.task = task

  self.expected_in_channels: int = 3 # Default to RGB
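A hedged construction sketch for the renamed handler follows. The class name, the new task literals, and the note about the class map are taken from the hunks above; the model, checkpoint path, and recipe path are placeholders.

from pathlib import Path
import torch.nn as nn
from ml_tools.ML_vision_inference import DragonVisionInferenceHandler

my_model: nn.Module = ...                                  # the trained vision model architecture (placeholder)

handler = DragonVisionInferenceHandler(
    model=my_model,
    state_dict=Path("checkpoints/vision_model.pth"),       # hypothetical checkpoint path
    task="multiclass image classification",                # one of the new task literals
    device="cuda",
    transform_source=Path("transform_recipe.json"),        # hypothetical .json recipe
)
# The class map is loaded from the model file; call handler.set_class_map(...) to override it.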
@@ -71,39 +66,6 @@ class PyTorchVisionInferenceHandler:
  if transform_source:
  self.set_transform(transform_source)
  self._is_transformed = True
-
- model_p = make_fullpath(state_dict, enforce="file")
-
- try:
- # Load whatever is in the file
- loaded_data = torch.load(model_p, map_location=self._device)
-
- # Check if it's a new checkpoint dictionary or an old weights-only file
- if isinstance(loaded_data, dict) and PyTorchCheckpointKeys.MODEL_STATE in loaded_data:
- # It's a new training checkpoint, extract the weights
- self._model.load_state_dict(loaded_data[PyTorchCheckpointKeys.MODEL_STATE])
- else:
- # It's an old-style file (or just a state_dict), load it directly
- self._model.load_state_dict(loaded_data)
-
- _LOGGER.info(f"Model state loaded from '{model_p.name}'.")
-
- self._model.to(self._device)
- self._model.eval() # Set the model to evaluation mode
- except Exception as e:
- _LOGGER.error(f"Failed to load model state from '{model_p}': {e}")
- raise
-
- def _validate_device(self, device: str) -> torch.device:
- """Validates the selected device and returns a torch.device object."""
- device_lower = device.lower()
- if "cuda" in device_lower and not torch.cuda.is_available():
- _LOGGER.warning("CUDA not available, switching to CPU.")
- device_lower = "cpu"
- elif device_lower == "mps" and not torch.backends.mps.is_available():
- _LOGGER.warning("Apple Metal Performance Shaders (MPS) not available, switching to CPU.")
- device_lower = "cpu"
- return torch.device(device_lower)

  def _preprocess_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Union[torch.Tensor, List[torch.Tensor]]:
  """
@@ -111,23 +73,23 @@ class PyTorchVisionInferenceHandler:
  - For Classification/Segmentation: Expects 4D Tensor (B, C, H, W).
  - For Object Detection: Expects List[Tensor(C, H, W)].
  """
- if self.task == "object_detection":
+ if self.task == MLTaskKeys.OBJECT_DETECTION:
  if not isinstance(inputs, list) or not all(isinstance(t, torch.Tensor) for t in inputs):
  _LOGGER.error("Input for object_detection must be a List[torch.Tensor].")
  raise ValueError("Invalid input type for object detection.")
  # Move each tensor in the list to the device
- return [t.float().to(self._device) for t in inputs]
+ return [t.float().to(self.device) for t in inputs]

  else: # Classification or Segmentation
  if not isinstance(inputs, torch.Tensor):
  _LOGGER.error(f"Input for {self.task} must be a torch.Tensor.")
  raise ValueError(f"Invalid input type for {self.task}.")

- if inputs.ndim != 4:
- _LOGGER.error(f"Input tensor for {self.task} must be 4D (B, C, H, W). Got {inputs.ndim}D.")
+ if inputs.ndim != 4: # type: ignore
+ _LOGGER.error(f"Input tensor for {self.task} must be 4D (B, C, H, W). Got {inputs.ndim}D.") # type: ignore
  raise ValueError("Input tensor must be 4D.")

- return inputs.float().to(self._device)
+ return inputs.float().to(self.device)

  def set_transform(self, transform_source: Union[str, Path, Callable]):
  """
@@ -144,8 +106,8 @@
  if isinstance(transform_source, (str, Path)):
  _LOGGER.info(f"Loading transform from recipe file: '{transform_source}'")
  try:
- # Use the new loader function
- self._transform = load_recipe_and_build_transform(transform_source)
+ # Use the loader function
+ self._transform = _load_recipe_and_build_transform(transform_source)
  except Exception as e:
  _LOGGER.error(f"Failed to load transform from recipe '{transform_source}': {e}")
  raise
@@ -155,31 +117,15 @@
  else:
  _LOGGER.error(f"Invalid transform_source type: {type(transform_source)}. Must be str, Path, or Callable.")
  raise TypeError("transform_source must be a file path or a Callable.")
-
- def set_class_map(self, class_map: Dict[str, int]):
- """
- Sets the class name mapping to translate predicted integer labels
- back into string names.
-
- Args:
- class_map (Dict[str, int]): The class_to_idx dictionary
- (e.g., {'cat': 0, 'dog': 1}) from the VisionDatasetMaker.
- """
- if self._idx_to_class is not None:
- _LOGGER.warning("Class to index mapping was previously given. Setting new mapping...")
- # Invert the dictionary for fast lookup
- self._idx_to_class = {v: k for k, v in class_map.items()}
- _LOGGER.info("Class map set for label-to-name translation.")

  def predict_batch(self, inputs: Union[torch.Tensor, List[torch.Tensor]]) -> Dict[str, Any]:
  """
  Core batch prediction method for vision models.
- All preprocessing (resizing, normalization) should be done *before*
- calling this method.
+ All preprocessing (resizing, normalization) should be done *before* calling this method.

  Args:
  inputs (torch.Tensor | List[torch.Tensor]):
- - For 'image_classification' or 'image_segmentation',
+ - For binary/multiclass image classification or binary/multiclass image segmentation tasks,
  a 4D torch.Tensor (B, C, H, W).
  - For 'object_detection', a List of 3D torch.Tensors
  [(C, H, W), ...], each with its own size.
@@ -194,45 +140,55 @@
  processed_inputs = self._preprocess_batch(inputs)

  with torch.no_grad():
- if self.task == "image_classification":
- # --- Image Classification ---
- # 1. Predict
- output = self._model(processed_inputs) # (B, num_classes)
-
- # 2. Post-process
+ # get outputs
+ output = self.model(processed_inputs)
+ if self.task == MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION:
+ # process
  probs = torch.softmax(output, dim=1)
  labels = torch.argmax(probs, dim=1)
  return {
  PyTorchInferenceKeys.LABELS: labels, # (B,)
  PyTorchInferenceKeys.PROBABILITIES: probs # (B, num_classes)
  }
-
- elif self.task == "image_segmentation":
- # --- Image Segmentation ---
- # 1. Predict
- output = self._model(processed_inputs) # (B, num_classes, H, W)
+
+ elif self.task == MLTaskKeys.BINARY_IMAGE_CLASSIFICATION:
+ # Assumes model output is [N, 1] (a single logit)
+ # Squeeze output from [N, 1] to [N] if necessary
+ if output.ndim == 2 and output.shape[1] == 1:
+ output = output.squeeze(1)
+
+ probs = torch.sigmoid(output) # Probability of positive class
+ labels = (probs >= self._classification_threshold).int()
+ return {
+ PyTorchInferenceKeys.LABELS: labels,
+ PyTorchInferenceKeys.PROBABILITIES: probs
+ }
+
+ elif self.task == MLTaskKeys.BINARY_SEGMENTATION:
+ # Assumes model output is [N, 1, H, W] (logits for positive class)
+ probs = torch.sigmoid(output) # Shape [N, 1, H, W]
+ labels = (probs >= self._classification_threshold).int() # Shape [N, 1, H, W]
+ return {
+ PyTorchInferenceKeys.LABELS: labels,
+ PyTorchInferenceKeys.PROBABILITIES: probs
+ }

- # 2. Post-process
- probs = torch.softmax(output, dim=1) # Probs across class channel
- labels = torch.argmax(probs, dim=1) # (B, H, W) segmentation map
+ elif self.task == MLTaskKeys.MULTICLASS_SEGMENTATION:
+ # output shape [N, C, H, W]
+ probs = torch.softmax(output, dim=1)
+ labels = torch.argmax(probs, dim=1) # shape [N, H, W]
  return {
- PyTorchInferenceKeys.LABELS: labels, # (B, H, W)
- PyTorchInferenceKeys.PROBABILITIES: probs # (B, num_classes, H, W)
+ PyTorchInferenceKeys.LABELS: labels, # (N, H, W)
+ PyTorchInferenceKeys.PROBABILITIES: probs # (N, num_classes, H, W)
  }

- elif self.task == "object_detection":
- # --- Object Detection ---
- # 1. Predict (model is in eval mode, expects only images)
- # Output is List[Dict[str, Tensor('boxes', 'labels', 'scores')]]
- predictions = self._model(processed_inputs)
-
- # 2. Post-process: Wrap in our standard key
+ elif self.task == MLTaskKeys.OBJECT_DETECTION:
  return {
- PyTorchInferenceKeys.PREDICTIONS: predictions
+ PyTorchInferenceKeys.PREDICTIONS: output
  }

  else:
- # This should be unreachable due to __init__ check
+ # This should be unreachable due to validation
  raise ValueError(f"Unknown task: {self.task}")

  def predict(self, single_input: torch.Tensor) -> Dict[str, Any]:
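The branch added for binary tasks thresholds a sigmoid probability instead of taking an argmax. A minimal sketch of batch inference under that branch, assuming the handler above was built with task="binary image classification" and that the inherited _classification_threshold defaults to 0.5 (an assumption; it is not shown in this diff):

import torch
from ml_tools._keys import PyTorchInferenceKeys      # assumed absolute import of the renamed _keys module

batch = torch.rand(8, 3, 224, 224)                   # (B, C, H, W), already resized and normalized
results = handler.predict_batch(batch)

labels = results[PyTorchInferenceKeys.LABELS]        # (B,) int tensor of 0/1 decisions
probs = results[PyTorchInferenceKeys.PROBABILITIES]  # (B,) sigmoid probabilities of the positive class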
@@ -248,24 +204,24 @@
  Returns:
  A dictionary containing the output tensors for a single sample.
  - Classification: {labels, probabilities} (label is 0-dim)
- - Segmentation: {labels, probabilities} (label is 2D mask)
+ - Segmentation: {labels, probabilities} (label is a 2D (multiclass) or 3D (binary) mask)
  - Object Detection: {boxes, labels, scores} (single dict)
  """
  if not isinstance(single_input, torch.Tensor) or single_input.ndim != 3:
  _LOGGER.error(f"Input for predict() must be a 3D tensor (C, H, W). Got {single_input.ndim}D.")
- raise ValueError("Input must be a 3D tensor.")
+ raise ValueError()

  # --- 1. Batch the input based on task ---
- if self.task == "object_detection":
+ if self.task == MLTaskKeys.OBJECT_DETECTION:
  batched_input = [single_input] # List of one tensor
  else:
- batched_input = single_input.unsqueeze(0) # (1, C, H, W)
+ batched_input = single_input.unsqueeze(0)

  # --- 2. Call batch prediction ---
  batch_results = self.predict_batch(batched_input)

  # --- 3. Un-batch the results ---
- if self.task == "object_detection":
+ if self.task == MLTaskKeys.OBJECT_DETECTION:
  # batch_results['predictions'] is a List[Dict]. We want the first (and only) Dict.
  return batch_results[PyTorchInferenceKeys.PREDICTIONS][0]
  else:
@@ -283,12 +239,12 @@
  Returns:
  Dict: A dictionary containing the outputs as NumPy arrays.
  - Obj. Detection: {predictions: List[Dict[str, np.ndarray]]}
- - Classification: {labels: int, label_names: str, probabilities: np.ndarray}
+ - Classification: {labels: np.ndarray, label_names: List[str], probabilities: np.ndarray}
  - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
  """
  tensor_results = self.predict_batch(inputs)

- if self.task == "object_detection":
+ if self.task == MLTaskKeys.OBJECT_DETECTION:
  # Output is List[Dict[str, Tensor]]
  # Convert each tensor inside each dict to numpy
  numpy_results = []
@@ -304,13 +260,19 @@
  ]
  numpy_results.append(np_dict)
  return {PyTorchInferenceKeys.PREDICTIONS: numpy_results}
+
  else:
- # Output is Dict[str, Tensor]
+ # Output is Dict[str, Tensor] (for Classification or Segmentation)
  numpy_results = {key: value.cpu().numpy() for key, value in tensor_results.items()}

  # Add string names for classification if map exists
- if self.task == "image_classification" and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
- int_labels = numpy_results[PyTorchInferenceKeys.LABELS]
+ is_image_classification = self.task in [
+ MLTaskKeys.BINARY_IMAGE_CLASSIFICATION,
+ MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION
+ ]
+
+ if is_image_classification and self._idx_to_class and PyTorchInferenceKeys.LABELS in numpy_results:
+ int_labels = numpy_results[PyTorchInferenceKeys.LABELS] # This is a (B,) array
  numpy_results[PyTorchInferenceKeys.LABEL_NAMES] = [
  self._idx_to_class.get(label_id, "Unknown")
  for label_id in int_labels
@@ -324,13 +286,13 @@

  Returns:
  Dict: A dictionary containing the outputs as NumPy arrays/scalars.
- - Obj. Detection: {boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray}
+ - Obj. Detection: {boxes: np.ndarray, labels: np.ndarray, scores: np.ndarray, label_names: List[str]}
  - Classification: {labels: int, label_names: str, probabilities: np.ndarray}
  - Segmentation: {labels: np.ndarray, probabilities: np.ndarray}
  """
  tensor_results = self.predict(single_input)

- if self.task == "object_detection":
+ if self.task == MLTaskKeys.OBJECT_DETECTION:
  # Output is Dict[str, Tensor]
  # Convert each tensor to numpy
  numpy_results = {
@@ -348,7 +310,7 @@

  return numpy_results

- elif self.task == "image_classification":
+ elif self.task in [MLTaskKeys.BINARY_IMAGE_CLASSIFICATION, MLTaskKeys.MULTICLASS_IMAGE_CLASSIFICATION]:
  # Output is Dict[str, Tensor(0-dim) or Tensor(1-dim)]
  int_label = tensor_results[PyTorchInferenceKeys.LABELS].item()
  label_name = "Unknown"
@@ -360,50 +322,32 @@
  PyTorchInferenceKeys.LABEL_NAMES: label_name,
  PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
  }
- else: # image_segmentation
+ else: # image_segmentation (binary or multiclass)
  # Output is Dict[str, Tensor(2D) or Tensor(3D)]
  return {
  PyTorchInferenceKeys.LABELS: tensor_results[PyTorchInferenceKeys.LABELS].cpu().numpy(),
  PyTorchInferenceKeys.PROBABILITIES: tensor_results[PyTorchInferenceKeys.PROBABILITIES].cpu().numpy()
  }

- def predict_from_file(self, image_path: Union[str, Path]) -> Dict[str, Any]:
+ def predict_from_pil(self, image: Image.Image) -> Dict[str, Any]:
  """
- Loads a single image from a file, applies the stored transform, and returns the prediction.
+ Applies the stored transform to a single PIL image and returns the prediction.

  Args:
- image_path (str | Path): The file path to the input image.
+ image (PIL.Image.Image): The input PIL image.

  Returns:
  Dict: A dictionary containing the prediction results. See `predict_numpy()` for task-specific output structures.
  """
  if self._transform is None:
- _LOGGER.error("Cannot predict from file: No transform has been set. Call .set_transform() or provide transform_source in __init__.")
+ _LOGGER.error("Cannot predict from PIL image: No transform has been set. Call .set_transform() or provide transform_source in __init__.")
  raise RuntimeError("Inference transform is not set.")
-
- try:
- # --- Use expected_in_channels to set PIL mode ---
- pil_mode: str
- if self.expected_in_channels == 1:
- pil_mode = "L" # Grayscale
- elif self.expected_in_channels == 4:
- pil_mode = "RGBA" # RGB + Alpha
- else:
- if self.expected_in_channels != 3: # 2, 5+ channels not supported by PIL convert
- _LOGGER.warning(f"Model expects {self.expected_in_channels} channels. PIL conversion is limited, defaulting to 3 channels (RGB). The transformations must convert it to the desired channel dimensions.")
- # Default to RGB. If 2-channels are needed, the transform recipe *must* be responsible for handling the conversion from a 3-channel PIL image.
- pil_mode = "RGB"
-
- image = Image.open(image_path).convert(pil_mode)
- except Exception as e:
- _LOGGER.error(f"Failed to load and convert image from '{image_path}': {e}")
- raise

  # Apply the transformation pipeline (e.g., resize, crop, ToTensor, normalize)
  try:
  transformed_image = self._transform(image)
  except Exception as e:
- _LOGGER.error(f"Error applying transform to image: {e}")
+ _LOGGER.error(f"Error applying transform to PIL image: {e}")
  raise

  # --- Validation ---
@@ -413,7 +357,7 @@

  if transformed_image.ndim != 3:
  _LOGGER.warning(f"Expected transform to output a 3D (C, H, W) tensor, but got {transformed_image.ndim}D. Attempting to proceed.")
- # .predict() which expects a 3D tensor
+ # .predict_numpy() -> .predict() which expects a 3D tensor
  if transformed_image.ndim == 4 and transformed_image.shape[0] == 1:
  transformed_image = transformed_image.squeeze(0) # Fix if user's transform adds a batch dim
  _LOGGER.warning("Removed an extra batch dimension.")
@@ -423,6 +367,39 @@
  # Use the existing single-item predict method
  return self.predict_numpy(transformed_image)

+ def predict_from_file(self, image_path: Union[str, Path]) -> Dict[str, Any]:
+ """
+ Loads a single image from a file, applies the stored transform, and returns the prediction.
+
+ This is a convenience wrapper that loads the image and calls `predict_from_pil()`.
+
+ Args:
+ image_path (str | Path): The file path to the input image.
+
+ Returns:
+ Dict: A dictionary containing the prediction results. See `predict_numpy()` for task-specific output structures.
+ """
+ try:
+ # --- Use expected_in_channels to set PIL mode ---
+ pil_mode: str
+ if self.expected_in_channels == 1:
+ pil_mode = "L" # Grayscale
+ elif self.expected_in_channels == 4:
+ pil_mode = "RGBA" # RGB + Alpha
+ else:
+ if self.expected_in_channels != 3: # 2, 5+ channels not supported by PIL convert
+ _LOGGER.warning(f"Model expects {self.expected_in_channels} channels. PIL conversion is limited, defaulting to 3 channels (RGB). The transformations must convert it to the desired channel dimensions.")
+ # Default to RGB. If 2-channels are needed, the transform recipe *must* be responsible for handling the conversion from a 3-channel PIL image.
+ pil_mode = "RGB"
+
+ image = Image.open(image_path).convert(pil_mode)
+ except Exception as e:
+ _LOGGER.error(f"Failed to load and convert image from '{image_path}': {e}")
+ raise
+
+ # Call the PIL-based prediction method
+ return self.predict_from_pil(image)
+

  def info():
  _script_info(__all__)
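predict_from_file() is now a thin wrapper around the new predict_from_pil(), so callers that already hold a decoded image can skip the disk read. A hedged usage sketch, with a placeholder image path:

from PIL import Image
from ml_tools._keys import PyTorchInferenceKeys

result = handler.predict_from_file("samples/cat_01.png")       # loads, converts and transforms the image
print(result[PyTorchInferenceKeys.LABEL_NAMES])                 # class name string, if a class map is set

pil_image = Image.open("samples/cat_01.png").convert("RGB")    # same image, decoded by the caller
result_pil = handler.predict_from_pil(pil_image)                # identical output, no file I/O inside the handler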
@@ -47,12 +47,17 @@ class _BaseVisionWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
  self.num_classes = num_classes
  self.in_channels = in_channels
  self.model_name = model_name
+ self._pretrained_default_transforms = None

  # --- 2. Instantiate the base model ---
  if init_with_pretrained:
  weights_enum = getattr(vision_models, weights_enum_name, None) if weights_enum_name else None
  weights = weights_enum.IMAGENET1K_V1 if weights_enum else None

+ # Save transformations for pretrained models
+ if weights:
+ self._pretrained_default_transforms = weights.transforms()
+
  if weights is None and init_with_pretrained:
  _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
  self.model = getattr(vision_models, model_name)(pretrained=True)
@@ -331,6 +336,7 @@ class _BaseSegmentationWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
  self.num_classes = num_classes
  self.in_channels = in_channels
  self.model_name = model_name
+ self._pretrained_default_transforms = None

  # --- 2. Instantiate the base model ---
  model_kwargs = {
@@ -343,6 +349,10 @@ class _BaseSegmentationWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
  weights_enum = getattr(vision_models.segmentation, weights_enum_name, None) if weights_enum_name else None
  weights = weights_enum.DEFAULT if weights_enum else None

+ # save pretrained model transformations
+ if weights:
+ self._pretrained_default_transforms = weights.transforms()
+
  if weights is None:
  _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
  # Legacy models used 'pretrained=True' and num_classes was separate
@@ -520,7 +530,7 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
  This wrapper allows for customizing the model backbone, input channels,
  and the number of output classes for transfer learning.

- NOTE: This model is NOT compatible with the MLTrainer class.
+ NOTE: Use an Object Detection compatible trainer.
  """
  def __init__(self,
  num_classes: int,
@@ -550,6 +560,7 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
  self.num_classes = num_classes
  self.in_channels = in_channels
  self.model_name = model_name
+ self._pretrained_default_transforms = None

  # --- 2. Instantiate the base model ---
  model_constructor = getattr(detection_models, model_name)
@@ -560,6 +571,9 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):

  weights_enum = getattr(detection_models, weights_enum_name, None) if weights_enum_name else None
  weights = weights_enum.DEFAULT if weights_enum and init_with_pretrained else None
+
+ if weights:
+ self._pretrained_default_transforms = weights.transforms()

  self.model = model_constructor(weights=weights, weights_backbone=weights)
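The model wrappers now cache the preprocessing pipeline that ships with the pretrained torchvision weights in _pretrained_default_transforms (None when training from scratch). A hedged sketch of one way this could be reused at inference time; whether the toolbox exposes a public accessor for the attribute is not shown in this diff, so the private attribute is read directly here:

# `vision_model` stands for any wrapper built with init_with_pretrained=True (placeholder)
default_tf = getattr(vision_model, "_pretrained_default_transforms", None)

if default_tf is not None:
    # set_transform() accepts any Callable, per the handler docstring above
    handler.set_transform(default_tf)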