dragon-ml-toolbox 13.1.0__py3-none-any.whl → 14.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dragon-ml-toolbox might be problematic.
- {dragon_ml_toolbox-13.1.0.dist-info → dragon_ml_toolbox-14.3.1.dist-info}/METADATA +11 -2
- dragon_ml_toolbox-14.3.1.dist-info/RECORD +48 -0
- {dragon_ml_toolbox-13.1.0.dist-info → dragon_ml_toolbox-14.3.1.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
- ml_tools/MICE_imputation.py +207 -5
- ml_tools/ML_datasetmaster.py +63 -205
- ml_tools/ML_evaluation.py +23 -15
- ml_tools/ML_evaluation_multi.py +5 -6
- ml_tools/ML_inference.py +0 -1
- ml_tools/ML_models.py +22 -6
- ml_tools/ML_models_advanced.py +323 -0
- ml_tools/ML_trainer.py +463 -20
- ml_tools/ML_utilities.py +302 -4
- ml_tools/ML_vision_datasetmaster.py +1395 -0
- ml_tools/ML_vision_evaluation.py +260 -0
- ml_tools/ML_vision_inference.py +428 -0
- ml_tools/ML_vision_models.py +627 -0
- ml_tools/ML_vision_transformers.py +58 -0
- ml_tools/_ML_vision_recipe.py +88 -0
- ml_tools/__init__.py +1 -0
- ml_tools/_schema.py +79 -2
- ml_tools/custom_logger.py +37 -14
- ml_tools/data_exploration.py +502 -93
- ml_tools/keys.py +42 -1
- ml_tools/math_utilities.py +1 -1
- ml_tools/serde.py +77 -15
- ml_tools/utilities.py +192 -3
- dragon_ml_toolbox-13.1.0.dist-info/RECORD +0 -41
- {dragon_ml_toolbox-13.1.0.dist-info → dragon_ml_toolbox-14.3.1.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-13.1.0.dist-info → dragon_ml_toolbox-14.3.1.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-13.1.0.dist-info → dragon_ml_toolbox-14.3.1.dist-info}/top_level.txt +0 -0
ml_tools/ML_vision_datasetmaster.py
@@ -0,0 +1,1395 @@
import torch
from torch.utils.data import Dataset, Subset
import numpy
from sklearn.model_selection import train_test_split
from typing import Union, Tuple, List, Optional, Callable, Dict, Any
from PIL import Image
from torchvision.datasets import ImageFolder
from torchvision import transforms
import torchvision.transforms.functional as TF
from pathlib import Path
import random
import json
import inspect

from .ML_datasetmaster import _BaseMaker
from .path_manager import make_fullpath
from ._logger import _LOGGER
from ._script_info import _script_info
from .keys import VisionTransformRecipeKeys, ObjectDetectionKeys
from ._ML_vision_recipe import save_recipe
from .ML_vision_transformers import TRANSFORM_REGISTRY
from .custom_logger import custom_logger


__all__ = [
    "VisionDatasetMaker",
    "SegmentationDatasetMaker",
    "ObjectDetectionDatasetMaker"
]


# --- VisionDatasetMaker ---
class VisionDatasetMaker(_BaseMaker):
    """
    Creates processed PyTorch datasets for computer vision tasks from an
    image folder directory.

    Supports two modes:
    1. `from_folder()`: Loads from one directory and splits into train/val/test.
    2. `from_folders()`: Loads from pre-split train/val/test directories.

    Uses online augmentations per epoch (image augmentation without creating new files).
    """
    def __init__(self):
        """
        Typically not called directly. Use the class methods `from_folder()` or `from_folders()` to create an instance.
        """
        super().__init__()
        self._full_dataset: Optional[ImageFolder] = None
        self.labels: Optional[List[int]] = None
        self.class_map: Optional[dict[str,int]] = None

        self._is_split = False
        self._are_transforms_configured = False
        self._val_recipe_components = None
        self._has_mean_std: bool = False

    @classmethod
    def from_folder(cls, root_dir: Union[str,Path]) -> 'VisionDatasetMaker':
        """
        Creates a maker instance from a single root directory of images.

        This method assumes a single directory (e.g., 'data/') that
        contains class subfolders (e.g., 'data/cat/', 'data/dog/').

        The dataset will be loaded in its entirety, and you MUST call
        `.split_data()` afterward to create train/validation/test sets.

        Args:
            root_dir (str | Path): The path to the root directory containing class subfolders.

        Returns:
            VisionDatasetMaker: A new instance with the full dataset loaded.
        """
        root_path = make_fullpath(root_dir, enforce="directory")
        # Load with NO transform. We get PIL Images.
        full_dataset = ImageFolder(root=root_path, transform=None)
        _LOGGER.info(f"Found {len(full_dataset)} images in {len(full_dataset.classes)} classes.")

        maker = cls()
        maker._full_dataset = full_dataset
        maker.labels = [s[1] for s in full_dataset.samples]
        maker.class_map = full_dataset.class_to_idx
        return maker

    @classmethod
    def from_folders(cls,
                     train_dir: Union[str,Path],
                     val_dir: Union[str,Path],
                     test_dir: Optional[Union[str,Path]] = None) -> 'VisionDatasetMaker':
        """
        Creates a maker instance from separate, pre-split directories.

        This method is used when you already have 'train', 'val', and
        optionally 'test' folders, each containing class subfolders.
        It bypasses the need for `.split_data()`.

        Args:
            train_dir (str | Path): Path to the training data directory.
            val_dir (str | Path): Path to the validation data directory.
            test_dir (str | Path | None): Path to the test data directory.

        Returns:
            VisionDatasetMaker: A new, pre-split instance.

        Raises:
            ValueError: If the classes found in train, val, or test directories are inconsistent.
        """
        train_path = make_fullpath(train_dir, enforce="directory")
        val_path = make_fullpath(val_dir, enforce="directory")

        _LOGGER.info("Loading data from separate directories.")
        # Load with NO transform. We get PIL Images.
        train_ds = ImageFolder(root=train_path, transform=None)
        val_ds = ImageFolder(root=val_path, transform=None)

        # Check for class consistency
        if train_ds.class_to_idx != val_ds.class_to_idx:
            _LOGGER.error("Train and validation directories have different or inconsistent classes.")
            raise ValueError()

        maker = cls()
        maker._train_dataset = train_ds
        maker._val_dataset = val_ds
        maker.class_map = train_ds.class_to_idx

        if test_dir:
            test_path = make_fullpath(test_dir, enforce="directory")
            test_ds = ImageFolder(root=test_path, transform=None)
            if train_ds.class_to_idx != test_ds.class_to_idx:
                _LOGGER.error("Train and test directories have different or inconsistent classes.")
                raise ValueError()
            maker._test_dataset = test_ds
            _LOGGER.info(f"Loaded: {len(train_ds)} train, {len(val_ds)} val, {len(test_ds)} test images.")
        else:
            _LOGGER.info(f"Loaded: {len(train_ds)} train, {len(val_ds)} val images.")

        maker._is_split = True # Mark as "split" since data is pre-split
        return maker

    @staticmethod
    def inspect_folder(path: Union[str, Path]):
        """
        Logs a report of the types, sizes, and channels of image files
        found in the directory and its subdirectories.

        This is a utility method to help diagnose potential dataset
        issues (e.g., mixed image modes, corrupted files) before loading.

        Args:
            path (str, Path): The directory path to inspect.
        """
        path_obj = make_fullpath(path)

        non_image_files = set()
        img_types = set()
        img_sizes = set()
        img_channels = set()
        img_counter = 0

        _LOGGER.info(f"Inspecting folder: {path_obj}...")
        # Use rglob to recursively find all files
        for filepath in path_obj.rglob('*'):
            if filepath.is_file():
                try:
                    # Using PIL to open is a more reliable check
                    with Image.open(filepath) as img:
                        img_types.add(img.format)
                        img_sizes.add(img.size)
                        img_channels.update(img.getbands())
                        img_counter += 1
                except (IOError, SyntaxError):
                    non_image_files.add(filepath.name)

        if non_image_files:
            _LOGGER.warning(f"Non-image or corrupted files found and ignored: {non_image_files}")

        report = (
            f"\n--- Inspection Report for '{path_obj.name}' ---\n"
            f"Total images found: {img_counter}\n"
            f"Image formats: {img_types or 'None'}\n"
            f"Image sizes (WxH): {img_sizes or 'None'}\n"
            f"Image channels (bands): {img_channels or 'None'}\n"
            f"--------------------------------------"
        )
        print(report)

    def split_data(self, val_size: float = 0.2, test_size: float = 0.0,
                   stratify: bool = True, random_state: Optional[int] = None) -> 'VisionDatasetMaker':
        """
        Splits the dataset into train, validation, and optional test sets.

        This method MUST be called if you used `from_folder()`. It has no effect if you used `from_folders()`.

        Args:
            val_size (float): Proportion of the dataset to reserve for
                validation (e.g., 0.2 for 20%).
            test_size (float): Proportion of the dataset to reserve for
                testing.
            stratify (bool): If True, splits are performed in a stratified
                fashion, preserving the class distribution.
            random_state (int | None): Seed for the random number generator for reproducible splits.

        Returns:
            VisionDatasetMaker: The same instance, now with datasets split.

        Raises:
            ValueError: If `val_size` and `test_size` sum to 1.0 or more.
        """
        if self._is_split:
            _LOGGER.warning("Data has already been split.")
            return self

        if val_size + test_size >= 1.0:
            _LOGGER.error("The sum of val_size and test_size must be less than 1.")
            raise ValueError()

        if not self._full_dataset:
            _LOGGER.error("There is no dataset to split.")
            raise ValueError()

        indices = list(range(len(self._full_dataset)))
        labels_for_split = self.labels if stratify else None

        train_indices, val_test_indices = train_test_split(
            indices, test_size=(val_size + test_size), random_state=random_state, stratify=labels_for_split
        )

        if not self.labels:
            _LOGGER.error("Error when getting full dataset labels.")
            raise ValueError()

        if test_size > 0:
            val_test_labels = [self.labels[i] for i in val_test_indices]
            stratify_val_test = val_test_labels if stratify else None
            val_indices, test_indices = train_test_split(
                val_test_indices, test_size=(test_size / (val_size + test_size)),
                random_state=random_state, stratify=stratify_val_test
            )
            self._test_dataset = Subset(self._full_dataset, test_indices)
            _LOGGER.info(f"Test set created with {len(self._test_dataset)} images.")
        else:
            val_indices = val_test_indices

        self._train_dataset = Subset(self._full_dataset, train_indices)
        self._val_dataset = Subset(self._full_dataset, val_indices)
        self._is_split = True

        _LOGGER.info(f"Data split into: \n- Training: {len(self._train_dataset)} images \n- Validation: {len(self._val_dataset)} images")
        return self

    def configure_transforms(self, resize_size: int = 256, crop_size: int = 224,
                             mean: Optional[List[float]] = [0.485, 0.456, 0.406],
                             std: Optional[List[float]] = [0.229, 0.224, 0.225],
                             pre_transforms: Optional[List[Callable]] = None,
                             extra_train_transforms: Optional[List[Callable]] = None) -> 'VisionDatasetMaker':
        """
        Configures and applies the image transformations and augmentations.

        This method must be called AFTER data is loaded and split.

        It sets up two pipelines:
        1. **Training Pipeline:** Includes random augmentations like
           `RandomResizedCrop` and `RandomHorizontalFlip` (plus any
           `extra_train_transforms`) for online augmentation.
        2. **Validation/Test Pipeline:** A deterministic pipeline using
           `Resize` and `CenterCrop` for consistent evaluation.

        Both pipelines finish with `ToTensor` and `Normalize`.

        Args:
            resize_size (int): The size to resize the smallest edge to
                for validation/testing.
            crop_size (int): The target size (square) for the final
                cropped image.
            mean (List[float]): The mean values for normalization (e.g., ImageNet mean).
            std (List[float]): The standard deviation values for normalization (e.g., ImageNet std).
            pre_transforms (List[Callable] | None): A list of transforms applied at the very beginning of the pipeline for all sets.
            extra_train_transforms (List[Callable] | None): A list of additional torchvision transforms added to the end of the training augmentations.

        Returns:
            VisionDatasetMaker: The same instance, with transforms applied.

        Raises:
            RuntimeError: If called before data is split.
        """
        if not self._is_split:
            _LOGGER.error("Transforms must be configured AFTER splitting data (or using `from_folders`). Call .split_data() first if using `from_folder`.")
            raise RuntimeError()

        if (mean is None and std is not None) or (mean is not None and std is None):
            _LOGGER.error(f"'mean' and 'std' must be both None or both defined, but only one was provided.")
            raise ValueError()

        # --- Define Transform Pipelines ---
        # These now MUST include ToTensor and Normalize, as the ImageFolder was loaded with transform=None.

        # --- Store components for validation recipe ---
        self._val_recipe_components = {
            VisionTransformRecipeKeys.PRE_TRANSFORMS: pre_transforms or [],
            VisionTransformRecipeKeys.RESIZE_SIZE: resize_size,
            VisionTransformRecipeKeys.CROP_SIZE: crop_size,
        }

        if mean is not None and std is not None:
            self._val_recipe_components.update({
                VisionTransformRecipeKeys.MEAN: mean,
                VisionTransformRecipeKeys.STD: std
            })
            self._has_mean_std = True

        base_pipeline = []
        if pre_transforms:
            base_pipeline.extend(pre_transforms)

        # Base augmentations for training
        base_train_transforms = [
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip()
        ]
        if extra_train_transforms:
            base_train_transforms.extend(extra_train_transforms)

        # Final conversion and normalization
        final_transforms: list[Callable] = [
            transforms.ToTensor()
        ]

        if self._has_mean_std:
            final_transforms.append(transforms.Normalize(mean=mean, std=std))

        # Build the val/test pipeline
        val_transform_list = [
            *base_pipeline, # Apply pre_transforms first
            transforms.Resize(resize_size),
            transforms.CenterCrop(crop_size),
            *final_transforms
        ]

        # Build the train pipeline
        train_transform_list = [
            *base_pipeline, # Apply pre_transforms first
            *base_train_transforms,
            *final_transforms
        ]

        val_transform = transforms.Compose(val_transform_list)
        train_transform = transforms.Compose(train_transform_list)

        # --- Apply Transforms using the Wrapper ---
        # This correctly assigns the transform regardless of whether the dataset is a Subset (from_folder) or an ImageFolder (from_folders).

        self._train_dataset = _DatasetTransformer(self._train_dataset, train_transform) # type: ignore
        self._val_dataset = _DatasetTransformer(self._val_dataset, val_transform) # type: ignore
        if self._test_dataset:
            self._test_dataset = _DatasetTransformer(self._test_dataset, val_transform) # type: ignore

        self._are_transforms_configured = True
        _LOGGER.info("Image transforms configured and applied.")
        return self

    def get_datasets(self) -> Tuple[Dataset, ...]:
        """
        Returns the final train, validation, and optional test datasets.

        This is the final step, used to retrieve the datasets for use in
        a `MLTrainer` or `DataLoader`.

        Returns:
            (Tuple[Dataset, ...]): A tuple containing the (train, val)
                or (train, val, test) datasets.

        Raises:
            RuntimeError: If called before data is split. A warning is
                logged if transforms have not been configured.
        """
        if not self._is_split:
            _LOGGER.error("Data has not been split. Call .split_data() first.")
            raise RuntimeError()
        if not self._are_transforms_configured:
            _LOGGER.warning("Transforms have not been configured.")

        if self._test_dataset:
            return self._train_dataset, self._val_dataset, self._test_dataset
        return self._train_dataset, self._val_dataset

    def save_transform_recipe(self, filepath: Union[str, Path]) -> None:
        """
        Saves the validation transform pipeline as a JSON recipe file.

        This recipe can be loaded by the PyTorchVisionInferenceHandler
        to ensure identical preprocessing.

        Args:
            filepath (str | Path): The path to save the .json recipe file.
        """
        if not self._are_transforms_configured:
            _LOGGER.error("Transforms are not configured. Call .configure_transforms() first.")
            raise RuntimeError()

        recipe: Dict[str, Any] = {
            VisionTransformRecipeKeys.TASK: "classification",
            VisionTransformRecipeKeys.PIPELINE: []
        }

        components = self._val_recipe_components

        if not components:
            _LOGGER.error(f"Error getting the transformers recipe for validation set.")
            raise ValueError()

        # validate path
        file_path = make_fullpath(filepath, make=True, enforce="file")

        # Handle pre_transforms
        for t in components[VisionTransformRecipeKeys.PRE_TRANSFORMS]:
            t_name = t.__class__.__name__
            t_class = t.__class__
            kwargs = {}

            # 1. Check custom registry first
            if t_name in TRANSFORM_REGISTRY:
                _LOGGER.debug(f"Found '{t_name}' in TRANSFORM_REGISTRY.")
                kwargs = getattr(t, VisionTransformRecipeKeys.KWARGS, {})

            # 2. Else, try to introspect for standard torchvision transforms
            else:
                _LOGGER.debug(f"'{t_name}' not in registry. Attempting introspection...")
                try:
                    # Get the __init__ signature of the transform's class
                    sig = inspect.signature(t_class.__init__)

                    # Iterate over its __init__ parameters (e.g., 'num_output_channels')
                    for param in sig.parameters.values():
                        if param.name == 'self':
                            continue

                        # Check if the *instance* 't' has that parameter as an attribute
                        attr_name_public = param.name
                        attr_name_private = '_' + param.name

                        attr_to_get = ""

                        if hasattr(t, attr_name_public):
                            attr_to_get = attr_name_public
                        elif hasattr(t, attr_name_private):
                            attr_to_get = attr_name_private
                        else:
                            # Parameter in __init__ has no matching attribute
                            continue

                        # Store the value under the __init__ parameter's name
                        kwargs[param.name] = getattr(t, attr_to_get)

                    _LOGGER.debug(f"Introspection for '{t_name}' found kwargs: {kwargs}")

                except (ValueError, TypeError):
                    # Fails on some built-ins or C-implemented __init__
                    _LOGGER.warning(f"Could not introspect parameters for '{t_name}'. If this transform has parameters, they will not be saved.")
                    kwargs = {}

            # 3. Add to pipeline
            recipe[VisionTransformRecipeKeys.PIPELINE].append({
                VisionTransformRecipeKeys.NAME: t_name,
                VisionTransformRecipeKeys.KWARGS: kwargs
            })

        # Add standard transforms
        recipe[VisionTransformRecipeKeys.PIPELINE].extend([
            {VisionTransformRecipeKeys.NAME: "Resize", "kwargs": {"size": components[VisionTransformRecipeKeys.RESIZE_SIZE]}},
            {VisionTransformRecipeKeys.NAME: "CenterCrop", "kwargs": {"size": components[VisionTransformRecipeKeys.CROP_SIZE]}},
            {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}}
        ])

        if self._has_mean_std:
            recipe[VisionTransformRecipeKeys.PIPELINE].append(
                {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
                    "mean": components[VisionTransformRecipeKeys.MEAN],
                    "std": components[VisionTransformRecipeKeys.STD]
                }}
            )

        # Save the file
        save_recipe(recipe, file_path)

    def save_class_map(self, save_dir: Union[str,Path]) -> dict[str,int]:
        """
        Saves the class to index mapping {str: int} to a directory.
        """
        if not self.class_map:
            _LOGGER.error(f"Class to index mapping is empty.")
            raise ValueError()

        custom_logger(data=self.class_map,
                      save_directory=save_dir,
                      log_name="Class_to_Index",
                      add_timestamp=False,
                      dict_as="json")

        return self.class_map
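

# Illustrative usage sketch (not part of the original module): a minimal run of the
# classification workflow described in the VisionDatasetMaker docstring above. The
# "data/" and "artifacts/" paths, split ratios, and batch size are assumptions.
def _example_classification_workflow():
    from torch.utils.data import DataLoader

    maker = (VisionDatasetMaker.from_folder("data/")  # expects data/<class_name>/*.jpg
             .split_data(val_size=0.2, test_size=0.1, stratify=True, random_state=42)
             .configure_transforms(resize_size=256, crop_size=224))

    train_ds, val_ds, test_ds = maker.get_datasets()
    maker.save_class_map("artifacts/")                            # class -> index mapping as JSON
    maker.save_transform_recipe("artifacts/val_transforms.json")  # preprocessing recipe for inference

    return DataLoader(train_ds, batch_size=32, shuffle=True), DataLoader(val_ds, batch_size=32)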


class _DatasetTransformer(Dataset):
    """
    Internal wrapper class to apply a specific transform pipeline to any
    dataset (e.g., a full ImageFolder or a Subset).
    """
    def __init__(self, dataset: Dataset, transform: Optional[transforms.Compose] = None):
        self.dataset = dataset
        self.transform = transform

        # --- Propagate attributes for inspection ---
        # For ImageFolder
        if hasattr(dataset, 'class_to_idx'):
            self.class_to_idx = getattr(dataset, 'class_to_idx')
        if hasattr(dataset, 'classes'):
            self.classes = getattr(dataset, 'classes')
        # For Subset
        if hasattr(dataset, 'indices'):
            self.indices = getattr(dataset, 'indices')
        if hasattr(dataset, 'dataset'):
            # This allows access to the *original* full dataset
            self.original_dataset = getattr(dataset, 'dataset')

    def __getitem__(self, index):
        # Get the original data (e.g., PIL Image, label)
        x, y = self.dataset[index]

        # Apply the specific transform for this dataset
        if self.transform:
            x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.dataset) # type: ignore
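

# Illustrative sketch (not part of the original module): why the wrapper exists. Two
# Subsets of the same ImageFolder can carry different transform pipelines without any
# files being duplicated. The "data/" path and index ranges are placeholders.
def _example_dataset_transformer():
    full = ImageFolder(root="data/", transform=None)   # yields PIL images, no transform yet
    train_subset = Subset(full, list(range(0, 800)))
    val_subset = Subset(full, list(range(800, 1000)))

    train_ds = _DatasetTransformer(train_subset, transforms.Compose([
        transforms.RandomResizedCrop(224), transforms.ToTensor()]))
    val_ds = _DatasetTransformer(val_subset, transforms.Compose([
        transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor()]))
    return train_ds, val_ds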


# --- Segmentation dataset ----
class _SegmentationDataset(Dataset):
    """
    Internal helper class to load image-mask pairs.

    Loads images as RGB and masks as 'L' (grayscale, 8-bit integer pixels).
    """
    def __init__(self, image_paths: List[Path], mask_paths: List[Path], transform: Optional[Callable] = None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

        # --- Propagate 'classes' if they exist (for MLTrainer) ---
        self.classes: List[str] = []

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        try:
            # Open as PIL Images. Masks should be 'L'
            image = Image.open(img_path).convert("RGB")
            mask = Image.open(mask_path).convert("L")
        except Exception as e:
            _LOGGER.error(f"Error loading sample #{idx}: {img_path.name} / {mask_path.name}. Error: {e}")
            # Return empty tensors
            return torch.empty(3, 224, 224), torch.empty(224, 224, dtype=torch.long)

        if self.transform:
            image, mask = self.transform(image, mask)

        return image, mask


# Internal Paired Transform Helpers
class _PairedCompose:
    """A 'Compose' for paired image/mask transforms."""
    def __init__(self, transforms: List[Callable]):
        self.transforms = transforms

    def __call__(self, image: Any, mask: Any) -> Tuple[Any, Any]:
        for t in self.transforms:
            image, mask = t(image, mask)
        return image, mask

class _PairedToTensor:
    """Converts a PIL Image pair (image, mask) to Tensors."""
    def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[torch.Tensor, torch.Tensor]:
        # Use new variable names to satisfy the linter
        image_tensor = TF.to_tensor(image)
        # Convert mask to LongTensor, not float.
        # This creates a [H, W] tensor of integer class IDs.
        mask_tensor = torch.from_numpy(numpy.array(mask, dtype=numpy.int64))
        return image_tensor, mask_tensor

class _PairedNormalize:
    """Normalizes the image tensor and leaves the mask untouched."""
    def __init__(self, mean: List[float], std: List[float]):
        self.normalize = transforms.Normalize(mean, std)

    def __call__(self, image: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        image = self.normalize(image)
        return image, mask

class _PairedResize:
    """Resizes an image and mask to the same size."""
    def __init__(self, size: int):
        self.size = [size, size]

    def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
        # Use new variable names to avoid linter confusion
        resized_image = TF.resize(image, self.size, interpolation=TF.InterpolationMode.BILINEAR) # type: ignore
        # Use NEAREST for mask to avoid interpolating class IDs (e.g., 1.5)
        resized_mask = TF.resize(mask, self.size, interpolation=TF.InterpolationMode.NEAREST) # type: ignore
        return resized_image, resized_mask # type: ignore

class _PairedCenterCrop:
    """Center-crops an image and mask to the same size."""
    def __init__(self, size: int):
        self.size = [size, size]

    def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
        cropped_image = TF.center_crop(image, self.size) # type: ignore
        cropped_mask = TF.center_crop(mask, self.size) # type: ignore
        return cropped_image, cropped_mask # type: ignore

class _PairedRandomHorizontalFlip:
    """Applies the same random horizontal flip to both image and mask."""
    def __init__(self, p: float = 0.5):
        self.p = p

    def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
        if random.random() < self.p:
            flipped_image = TF.hflip(image) # type: ignore
            flipped_mask = TF.hflip(mask) # type: ignore
            return flipped_image, flipped_mask # type: ignore
        # No flip applied: return the pair unchanged
        return image, mask

class _PairedRandomResizedCrop:
    """Applies the same random resized crop to both image and mask."""
    def __init__(self, size: int, scale: Tuple[float, float]=(0.08, 1.0), ratio: Tuple[float, float]=(3./4., 4./3.)):
        self.size = [size, size]
        self.scale = scale
        self.ratio = ratio
        self.interpolation = TF.InterpolationMode.BILINEAR
        self.mask_interpolation = TF.InterpolationMode.NEAREST

    def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
        # Get parameters for the random crop
        # Convert scale/ratio tuples to lists to satisfy the linter's type hint
        i, j, h, w = transforms.RandomResizedCrop.get_params(image, list(self.scale), list(self.ratio)) # type: ignore

        # Apply the crop with the SAME parameters and use new variable names
        cropped_image = TF.resized_crop(image, i, j, h, w, self.size, self.interpolation) # type: ignore
        cropped_mask = TF.resized_crop(mask, i, j, h, w, self.size, self.mask_interpolation) # type: ignore

        return cropped_image, cropped_mask # type: ignore
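

# Illustrative sketch (not part of the original module): the paired helpers apply the
# same spatial operation to an image and its mask, while only the image is normalized.
# The PIL images are created in memory, so this runs without any files on disk.
def _example_paired_transforms():
    image = Image.new("RGB", (320, 240))
    mask = Image.new("L", (320, 240))

    pipeline = _PairedCompose([
        _PairedResize(256),
        _PairedCenterCrop(224),
        _PairedToTensor(),
        _PairedNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    img_t, mask_t = pipeline(image, mask)
    # img_t: float tensor [3, 224, 224]; mask_t: integer class-ID tensor [224, 224]
    return img_t.shape, mask_t.dtype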

# --- SegmentationDatasetMaker ---
class SegmentationDatasetMaker(_BaseMaker):
    """
    Creates processed PyTorch datasets for segmentation from image and mask folders.

    This maker finds all matching image-mask pairs from two directories,
    splits them, and applies identical transformations (including augmentations)
    to both the image and its corresponding mask.

    Workflow:
    1. `maker = SegmentationDatasetMaker.from_folders(img_dir, mask_dir)`
    2. `maker.set_class_map({'background': 0, 'road': 1})`
    3. `maker.split_data(val_size=0.2)`
    4. `maker.configure_transforms(crop_size=256)`
    5. `train_ds, val_ds = maker.get_datasets()`
    """
    IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

    def __init__(self):
        """
        Typically not called directly. Use the class method `from_folders()` to create an instance.
        """
        super().__init__()
        self.image_paths: List[Path] = []
        self.mask_paths: List[Path] = []
        self.class_map: Dict[str, int] = {}

        self._is_split = False
        self._are_transforms_configured = False
        self.train_transform: Optional[Callable] = None
        self.val_transform: Optional[Callable] = None

    @classmethod
    def from_folders(cls, image_dir: Union[str, Path], mask_dir: Union[str, Path]) -> 'SegmentationDatasetMaker':
        """
        Creates a maker instance by loading all matching image-mask pairs
        from two corresponding directories.

        This method assumes that for an image `images/img_001.png`, there
        is a corresponding mask `masks/img_001.png`.

        Args:
            image_dir (str | Path): Path to the directory containing input images.
            mask_dir (str | Path): Path to the directory containing segmentation masks.

        Returns:
            SegmentationDatasetMaker: A new instance with all pairs loaded.
        """
        maker = cls()
        img_path_obj = make_fullpath(image_dir, enforce="directory")
        msk_path_obj = make_fullpath(mask_dir, enforce="directory")

        # Find all images
        image_files = sorted([
            p for p in img_path_obj.glob('*.*')
            if p.suffix.lower() in cls.IMG_EXTENSIONS
        ])

        if not image_files:
            _LOGGER.error(f"No images with extensions {cls.IMG_EXTENSIONS} found in {image_dir}")
            raise FileNotFoundError()

        _LOGGER.info(f"Found {len(image_files)} images. Searching for matching masks in {mask_dir}...")

        good_img_paths = []
        good_mask_paths = []

        for img_file in image_files:
            mask_file = None

            # 1. Try to find mask with the exact same name
            mask_file_primary = msk_path_obj / img_file.name
            if mask_file_primary.exists():
                mask_file = mask_file_primary

            # 2. If not, try to find mask with same stem + common mask extension
            if mask_file is None:
                for ext in cls.IMG_EXTENSIONS: # Masks are often .png
                    mask_file_secondary = msk_path_obj / (img_file.stem + ext)
                    if mask_file_secondary.exists():
                        mask_file = mask_file_secondary
                        break

            # 3. If a match is found, add the pair
            if mask_file:
                good_img_paths.append(img_file)
                good_mask_paths.append(mask_file)
            else:
                _LOGGER.warning(f"No corresponding mask found for image: {img_file.name}")

        if not good_img_paths:
            _LOGGER.error("No matching image-mask pairs were found.")
            raise FileNotFoundError()

        _LOGGER.info(f"Successfully found {len(good_img_paths)} image-mask pairs.")
        maker.image_paths = good_img_paths
        maker.mask_paths = good_mask_paths

        return maker

    @staticmethod
    def inspect_folder(path: Union[str, Path]):
        """
        Logs a report of the types, sizes, and channels of image files
        found in the directory. Useful for checking masks.
        """
        VisionDatasetMaker.inspect_folder(path)

    def set_class_map(self, class_map: Dict[str, int]) -> 'SegmentationDatasetMaker':
        """
        Sets a map of class_name -> pixel_value. This is used by the MLTrainer for clear evaluation reports.

        Args:
            class_map (Dict[str, int]): A dictionary mapping each class name to
                the integer pixel value used for it in the masks.
                Example: {'background': 0, 'road': 1, 'car': 2}
        """
        self.class_map = class_map
        _LOGGER.info(f"Class map set: {class_map}")
        return self

    @property
    def classes(self) -> List[str]:
        """Returns the list of class names, if set."""
        if self.class_map:
            return list(self.class_map.keys())
        return []

    def split_data(self, val_size: float = 0.2, test_size: float = 0.0,
                   random_state: Optional[int] = 42) -> 'SegmentationDatasetMaker':
        """
        Splits the loaded image-mask pairs into train, validation, and test sets.

        Args:
            val_size (float): Proportion of the dataset to reserve for validation.
            test_size (float): Proportion of the dataset to reserve for testing.
            random_state (int | None): Seed for reproducible splits.

        Returns:
            SegmentationDatasetMaker: The same instance, now with datasets created.
        """
        if self._is_split:
            _LOGGER.warning("Data has already been split.")
            return self

        if val_size + test_size >= 1.0:
            _LOGGER.error("The sum of val_size and test_size must be less than 1.")
            raise ValueError()

        if not self.image_paths:
            _LOGGER.error("There is no data to split. Use .from_folders() first.")
            raise RuntimeError()

        indices = list(range(len(self.image_paths)))

        # Split indices
        train_indices, val_test_indices = train_test_split(
            indices, test_size=(val_size + test_size), random_state=random_state
        )

        # Helper to get paths from indices
        def get_paths(idx_list):
            return [self.image_paths[i] for i in idx_list], [self.mask_paths[i] for i in idx_list]

        train_imgs, train_masks = get_paths(train_indices)

        if test_size > 0:
            val_indices, test_indices = train_test_split(
                val_test_indices, test_size=(test_size / (val_size + test_size)),
                random_state=random_state
            )
            val_imgs, val_masks = get_paths(val_indices)
            test_imgs, test_masks = get_paths(test_indices)

            self._test_dataset = _SegmentationDataset(test_imgs, test_masks, transform=None)
            self._test_dataset.classes = self.classes # type: ignore
            _LOGGER.info(f"Test set created with {len(self._test_dataset)} images.")
        else:
            val_imgs, val_masks = get_paths(val_test_indices)

        self._train_dataset = _SegmentationDataset(train_imgs, train_masks, transform=None)
        self._val_dataset = _SegmentationDataset(val_imgs, val_masks, transform=None)

        # Propagate class names to datasets for MLTrainer
        self._train_dataset.classes = self.classes # type: ignore
        self._val_dataset.classes = self.classes # type: ignore

        self._is_split = True
        _LOGGER.info(f"Data split into: \n- Training: {len(self._train_dataset)} images \n- Validation: {len(self._val_dataset)} images")
        return self

    def configure_transforms(self,
                             resize_size: int = 256,
                             crop_size: int = 224,
                             mean: List[float] = [0.485, 0.456, 0.406],
                             std: List[float] = [0.229, 0.224, 0.225]) -> 'SegmentationDatasetMaker':
        """
        Configures and applies the image and mask transformations.

        This method must be called AFTER data is split.

        Args:
            resize_size (int): The size to resize the smallest edge to
                for validation/testing.
            crop_size (int): The target size (square) for the final
                cropped image.
            mean (List[float]): The mean values for image normalization.
            std (List[float]): The std dev values for image normalization.

        Returns:
            SegmentationDatasetMaker: The same instance, with transforms applied.
        """
        if not self._is_split:
            _LOGGER.error("Transforms must be configured AFTER splitting data. Call .split_data() first.")
            raise RuntimeError()

        # --- Store components for validation recipe ---
        self.val_recipe_components = {
            VisionTransformRecipeKeys.RESIZE_SIZE: resize_size,
            VisionTransformRecipeKeys.CROP_SIZE: crop_size,
            VisionTransformRecipeKeys.MEAN: mean,
            VisionTransformRecipeKeys.STD: std
        }

        # --- Validation/Test Pipeline (Deterministic) ---
        self.val_transform = _PairedCompose([
            _PairedResize(resize_size),
            _PairedCenterCrop(crop_size),
            _PairedToTensor(),
            _PairedNormalize(mean, std)
        ])

        # --- Training Pipeline (Augmentation) ---
        self.train_transform = _PairedCompose([
            _PairedRandomResizedCrop(crop_size),
            _PairedRandomHorizontalFlip(p=0.5),
            _PairedToTensor(),
            _PairedNormalize(mean, std)
        ])

        # --- Apply Transforms to the Datasets ---
        self._train_dataset.transform = self.train_transform # type: ignore
        self._val_dataset.transform = self.val_transform # type: ignore
        if self._test_dataset:
            self._test_dataset.transform = self.val_transform # type: ignore

        self._are_transforms_configured = True
        _LOGGER.info("Paired segmentation transforms configured and applied.")
        return self

    def get_datasets(self) -> Tuple[Dataset, ...]:
        """
        Returns the final train, validation, and optional test datasets.

        Raises:
            RuntimeError: If called before data is split.
            RuntimeError: If called before transforms are configured.
        """
        if not self._is_split:
            _LOGGER.error("Data has not been split. Call .split_data() first.")
            raise RuntimeError()
        if not self._are_transforms_configured:
            _LOGGER.error("Transforms have not been configured. Call .configure_transforms() first.")
            raise RuntimeError()

        if self._test_dataset:
            return self._train_dataset, self._val_dataset, self._test_dataset
        return self._train_dataset, self._val_dataset

    def save_transform_recipe(self, filepath: Union[str, Path]) -> None:
        """
        Saves the validation transform pipeline as a JSON recipe file.

        This recipe can be loaded by the PyTorchVisionInferenceHandler
        to ensure identical preprocessing.

        Args:
            filepath (str | Path): The path to save the .json recipe file.
        """
        if not self._are_transforms_configured:
            _LOGGER.error("Transforms are not configured. Call .configure_transforms() first.")
            raise RuntimeError()

        components = self.val_recipe_components

        if not components:
            _LOGGER.error(f"Error getting the transformers recipe for validation set.")
            raise ValueError()

        # validate path
        file_path = make_fullpath(filepath, make=True, enforce="file")

        # Add standard transforms
        recipe: Dict[str, Any] = {
            VisionTransformRecipeKeys.TASK: "segmentation",
            VisionTransformRecipeKeys.PIPELINE: [
                {VisionTransformRecipeKeys.NAME: "Resize", "kwargs": {"size": components["resize_size"]}},
                {VisionTransformRecipeKeys.NAME: "CenterCrop", "kwargs": {"size": components["crop_size"]}},
                {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}},
                {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
                    "mean": components["mean"],
                    "std": components["std"]
                }}
            ]
        }

        # Save the file
        save_recipe(recipe, file_path)
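

# Illustrative usage sketch (not part of the original module): a minimal run of the
# segmentation workflow from the class docstring above. The image/mask directories,
# the class map, and the recipe path are assumptions for illustration only.
def _example_segmentation_workflow():
    maker = (SegmentationDatasetMaker.from_folders("images/", "masks/")
             .set_class_map({"background": 0, "road": 1})
             .split_data(val_size=0.2, random_state=42)
             .configure_transforms(resize_size=256, crop_size=224))

    train_ds, val_ds = maker.get_datasets()
    maker.save_transform_recipe("artifacts/segmentation_transforms.json")
    return train_ds, val_ds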


# Object detection
def _od_collate_fn(batch: List[Tuple[torch.Tensor, Dict[str, torch.Tensor]]]) -> Tuple[List[torch.Tensor], List[Dict[str, torch.Tensor]]]:
    """
    Custom collate function for object detection.

    Takes a list of (image, target) tuples and zips them into two lists:
    (list_of_images, list_of_targets).
    This is required for models like Faster R-CNN, which accept a list
    of images of varying sizes.
    """
    return tuple(zip(*batch)) # type: ignore
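

# Illustrative sketch (not part of the original module): why the collate function
# matters. Detection images can have different sizes, so batches stay as plain lists
# rather than stacked tensors. `detection_dataset` is a placeholder for any dataset
# yielding (image, target) pairs; the batch size is an assumption.
def _example_detection_loader(detection_dataset: Dataset):
    from torch.utils.data import DataLoader

    loader = DataLoader(detection_dataset, batch_size=4, shuffle=True,
                        collate_fn=_od_collate_fn)
    images, targets = next(iter(loader))  # list of 4 image tensors, list of 4 target dicts
    return images, targets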


class _ObjectDetectionDataset(Dataset):
    """
    Internal helper class to load image-annotation pairs.

    Loads an image as 'RGB' and parses its corresponding JSON annotation file
    to create the required target dictionary (boxes, labels).
    """
    def __init__(self, image_paths: List[Path], annotation_paths: List[Path], transform: Optional[Callable] = None):
        self.image_paths = image_paths
        self.annotation_paths = annotation_paths
        self.transform = transform

        # --- Propagate 'classes' if they exist (for MLTrainer) ---
        self.classes: List[str] = []

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        ann_path = self.annotation_paths[idx]

        try:
            # Open image
            image = Image.open(img_path).convert("RGB")

            # Load and parse annotation
            with open(ann_path, 'r') as f:
                ann_data = json.load(f)

            # Get boxes and labels from JSON
            boxes = ann_data[ObjectDetectionKeys.BOXES] # Expected: [[x1, y1, x2, y2], ...]
            labels = ann_data[ObjectDetectionKeys.LABELS] # Expected: [1, 2, 1, ...]

            # Convert to tensors
            target: Dict[str, Any] = {}
            target[ObjectDetectionKeys.BOXES] = torch.as_tensor(boxes, dtype=torch.float32)
            target[ObjectDetectionKeys.LABELS] = torch.as_tensor(labels, dtype=torch.int64)

        except Exception as e:
            _LOGGER.error(f"Error loading sample #{idx}: {img_path.name} / {ann_path.name}. Error: {e}")
            # Return empty/dummy data
            return torch.empty(3, 224, 224), {ObjectDetectionKeys.BOXES: torch.empty((0, 4)), ObjectDetectionKeys.LABELS: torch.empty(0, dtype=torch.long)}

        if self.transform:
            image, target = self.transform(image, target)

        return image, target
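

# Illustrative sketch (not part of the original module): the annotation layout this
# dataset expects, one JSON file per image with pixel-coordinate boxes [x1, y1, x2, y2]
# and integer labels. The output file name "img_001.json" is an assumption.
def _example_write_annotation():
    annotation = {
        ObjectDetectionKeys.BOXES: [[10.0, 20.0, 110.0, 220.0], [50.0, 60.0, 90.0, 100.0]],
        ObjectDetectionKeys.LABELS: [1, 2],
    }
    with open("img_001.json", "w") as f:
        json.dump(annotation, f)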

# Internal Paired Transform Helpers for Object Detection
class _OD_PairedCompose:
    """A 'Compose' for paired image/target_dict transforms."""
    def __init__(self, transforms: List[Callable]):
        self.transforms = transforms

    def __call__(self, image: Any, target: Any) -> Tuple[Any, Any]:
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

class _OD_PairedToTensor:
    """Converts a PIL Image to Tensor, passes targets dict through."""
    def __call__(self, image: Image.Image, target: Dict[str, Any]) -> Tuple[torch.Tensor, Dict[str, Any]]:
        return TF.to_tensor(image), target

class _OD_PairedNormalize:
    """Normalizes the image tensor and leaves the target dict untouched."""
    def __init__(self, mean: List[float], std: List[float]):
        self.normalize = transforms.Normalize(mean, std)

    def __call__(self, image: torch.Tensor, target: Dict[str, Any]) -> Tuple[torch.Tensor, Dict[str, Any]]:
        image_normalized = self.normalize(image)
        return image_normalized, target

class _OD_PairedRandomHorizontalFlip:
    """Applies the same random horizontal flip to both image and targets['boxes']."""
    def __init__(self, p: float = 0.5):
        self.p = p

    def __call__(self, image: Image.Image, target: Dict[str, Any]) -> Tuple[Image.Image, Dict[str, Any]]:
        if random.random() < self.p:
            w, h = image.size
            # Use new variable names to avoid linter confusion
            flipped_image = TF.hflip(image) # type: ignore

            # Flip boxes
            boxes = target[ObjectDetectionKeys.BOXES].clone() # [N, 4]

            # xmin' = w - xmax
            # xmax' = w - xmin
            boxes[:, 0] = w - target[ObjectDetectionKeys.BOXES][:, 2] # xmin'
            boxes[:, 2] = w - target[ObjectDetectionKeys.BOXES][:, 0] # xmax'
            target[ObjectDetectionKeys.BOXES] = boxes

            return flipped_image, target # type: ignore

        return image, target
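

# Illustrative sketch (not part of the original module): a small numeric check of the
# flip formula used above. For an image of width w, a box [x1, y1, x2, y2] maps to
# [w - x2, y1, w - x1, y2]; the vertical coordinates are unchanged.
def _example_hflip_box_math():
    w = 100
    box = torch.tensor([[10.0, 5.0, 40.0, 50.0]])
    flipped = box.clone()
    flipped[:, 0] = w - box[:, 2]   # new x1 = 100 - 40 = 60
    flipped[:, 2] = w - box[:, 0]   # new x2 = 100 - 10 = 90
    return flipped                  # tensor([[60., 5., 90., 50.]])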
|
|
1078
|
+
|
|
1079
|
+
|
|
1080
|
+
class ObjectDetectionDatasetMaker(_BaseMaker):
|
|
1081
|
+
"""
|
|
1082
|
+
Creates processed PyTorch datasets for object detection from image
|
|
1083
|
+
and JSON annotation folders.
|
|
1084
|
+
|
|
1085
|
+
This maker finds all matching image-annotation pairs from two directories,
|
|
1086
|
+
splits them, and applies identical transformations (including augmentations)
|
|
1087
|
+
to both the image and its corresponding target dictionary.
|
|
1088
|
+
|
|
1089
|
+
The `DragonFastRCNN` model expects a list of images and a list of targets,
|
|
1090
|
+
so this class provides a `collate_fn` to be used with a DataLoader.
|
|
1091
|
+
|
|
1092
|
+
Workflow:
|
|
1093
|
+
1. `maker = ObjectDetectionDatasetMaker.from_folders(img_dir, ann_dir)`
|
|
1094
|
+
2. `maker.set_class_map({'background': 0, 'person': 1, 'car': 2})`
|
|
1095
|
+
3. `maker.split_data(val_size=0.2)`
|
|
1096
|
+
4. `maker.configure_transforms()`
|
|
1097
|
+
5. `train_ds, val_ds = maker.get_datasets()`
|
|
1098
|
+
6. `collate_fn = maker.collate_fn`
|
|
1099
|
+
7. `train_loader = DataLoader(train_ds, ..., collate_fn=collate_fn)`
|
|
1100
|
+
"""
|
|
1101
|
+
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
|
|
1102
|
+
|
|
1103
|
+
def __init__(self):
|
|
1104
|
+
"""
|
|
1105
|
+
Typically not called directly. Use the class method `from_folders()` to create an instance.
|
|
1106
|
+
"""
|
|
1107
|
+
super().__init__()
|
|
1108
|
+
self.image_paths: List[Path] = []
|
|
1109
|
+
self.annotation_paths: List[Path] = []
|
|
1110
|
+
self.class_map: Dict[str, int] = {}
|
|
1111
|
+
|
|
1112
|
+
self._is_split = False
|
|
1113
|
+
self._are_transforms_configured = False
|
|
1114
|
+
self.train_transform: Optional[Callable] = None
|
|
1115
|
+
self.val_transform: Optional[Callable] = None
|
|
1116
|
+
self._val_recipe_components: Optional[Dict[str, Any]] = None
|
|
1117
|
+
|
|
1118
|
+
@classmethod
|
|
1119
|
+
def from_folders(cls, image_dir: Union[str, Path], annotation_dir: Union[str, Path]) -> 'ObjectDetectionDatasetMaker':
|
|
1120
|
+
"""
|
|
1121
|
+
Creates a maker instance by loading all matching image-annotation pairs
|
|
1122
|
+
from two corresponding directories.
|
|
1123
|
+
|
|
1124
|
+
This method assumes that for an image `images/img_001.png`, there
|
|
1125
|
+
is a corresponding annotation `annotations/img_001.json`.
|
|
1126
|
+
|
|
1127
|
+
The JSON file must contain "boxes" and "labels" keys:
|
|
1128
|
+
`{"boxes": [[x1,y1,x2,y2], ...], "labels": [1, 2, ...]}`
|
|
1129
|
+
|
|
1130
|
+
Args:
|
|
1131
|
+
image_dir (str | Path): Path to the directory containing input images.
|
|
1132
|
+
annotation_dir (str | Path): Path to the directory containing .json
|
|
1133
|
+
annotation files.
|
|
1134
|
+
|
|
1135
|
+
Returns:
|
|
1136
|
+
ObjectDetectionDatasetMaker: A new instance with all pairs loaded.
|
|
1137
|
+
"""
|
|
1138
|
+
maker = cls()
|
|
1139
|
+
img_path_obj = make_fullpath(image_dir, enforce="directory")
|
|
1140
|
+
ann_path_obj = make_fullpath(annotation_dir, enforce="directory")
|
|
1141
|
+
|
|
1142
|
+
# Find all images
|
|
1143
|
+
image_files = sorted([
|
|
1144
|
+
p for p in img_path_obj.glob('*.*')
|
|
1145
|
+
if p.suffix.lower() in cls.IMG_EXTENSIONS
|
|
1146
|
+
])
|
|
1147
|
+
|
|
1148
|
+
if not image_files:
|
|
1149
|
+
_LOGGER.error(f"No images with extensions {cls.IMG_EXTENSIONS} found in {image_dir}")
|
|
1150
|
+
raise FileNotFoundError()
|
|
1151
|
+
|
|
1152
|
+
_LOGGER.info(f"Found {len(image_files)} images. Searching for matching .json annotations in {annotation_dir}...")
|
|
1153
|
+
|
|
1154
|
+
good_img_paths = []
|
|
1155
|
+
good_ann_paths = []
|
|
1156
|
+
|
|
1157
|
+
for img_file in image_files:
|
|
1158
|
+
# Find annotation with same stem + .json
|
|
1159
|
+
ann_file = ann_path_obj / (img_file.stem + ".json")
|
|
1160
|
+
|
|
1161
|
+
if ann_file.exists():
|
|
1162
|
+
good_img_paths.append(img_file)
|
|
1163
|
+
good_ann_paths.append(ann_file)
|
|
1164
|
+
else:
|
|
1165
|
+
_LOGGER.warning(f"No corresponding .json annotation found for image: {img_file.name}")
|
|
1166
|
+
|
|
1167
|
+
if not good_img_paths:
|
|
1168
|
+
_LOGGER.error("No matching image-annotation pairs were found.")
|
|
1169
|
+
raise FileNotFoundError()
|
|
1170
|
+
|
|
1171
|
+
_LOGGER.info(f"Successfully found {len(good_img_paths)} image-annotation pairs.")
|
|
1172
|
+
maker.image_paths = good_img_paths
|
|
1173
|
+
maker.annotation_paths = good_ann_paths
|
|
1174
|
+
|
|
1175
|
+
return maker
|
|
1176
|
+
|
|
1177
|
+
@staticmethod
|
|
1178
|
+
def inspect_folder(path: Union[str, Path]):
|
|
1179
|
+
"""
|
|
1180
|
+
Logs a report of the types, sizes, and channels of image files
|
|
1181
|
+
found in the directory.
|
|
1182
|
+
"""
|
|
1183
|
+
VisionDatasetMaker.inspect_folder(path)
|
|
1184
|
+
|
|
1185
|
+
def set_class_map(self, class_map: Dict[str, int]) -> 'ObjectDetectionDatasetMaker':
|
|
1186
|
+
"""
|
|
1187
|
+
Sets a map of class_name -> pixel_value. This is used by the
|
|
1188
|
+
MLTrainer for clear evaluation reports.
|
|
1189
|
+
|
|
1190
|
+
**Important:** For object detection models, 'background' MUST
|
|
1191
|
+
be included as class 0.
|
|
1192
|
+
Example: `{'background': 0, 'person': 1, 'car': 2}`
|
|
1193
|
+
|
|
1194
|
+
Args:
|
|
1195
|
+
class_map (Dict[str, int]): A dictionary mapping the string name
|
|
1196
|
+
to its integer label.
|
|
1197
|
+
"""
|
|
1198
|
+
if 'background' not in class_map or class_map['background'] != 0:
|
|
1199
|
+
_LOGGER.warning("Object detection class map should include 'background' mapped to 0.")
|
|
1200
|
+
|
|
1201
|
+
self.class_map = class_map
|
|
1202
|
+
_LOGGER.info(f"Class map set: {class_map}")
|
|
1203
|
+
return self
|
|
1204
|
+
|
|
1205
|
+
@property
|
|
1206
|
+
def classes(self) -> List[str]:
|
|
1207
|
+
"""Returns the list of class names, if set."""
|
|
1208
|
+
if self.class_map:
|
|
1209
|
+
return list(self.class_map.keys())
|
|
1210
|
+
return []
|
|
1211
|
+
|
|
1212
|
+
    def split_data(self, val_size: float = 0.2, test_size: float = 0.0,
                   random_state: Optional[int] = 42) -> 'ObjectDetectionDatasetMaker':
        """
        Splits the loaded image-annotation pairs into train, validation, and test sets.

        Args:
            val_size (float): Proportion of the dataset to reserve for validation.
            test_size (float): Proportion of the dataset to reserve for testing.
            random_state (int | None): Seed for reproducible splits.

        Returns:
            ObjectDetectionDatasetMaker: The same instance, now with datasets created.
        """
        if self._is_split:
            _LOGGER.warning("Data has already been split.")
            return self

        if val_size + test_size >= 1.0:
            _LOGGER.error("The sum of val_size and test_size must be less than 1.")
            raise ValueError()

        if not self.image_paths:
            _LOGGER.error("There is no data to split. Use .from_folders() first.")
            raise RuntimeError()

        indices = list(range(len(self.image_paths)))

        # Split indices
        train_indices, val_test_indices = train_test_split(
            indices, test_size=(val_size + test_size), random_state=random_state
        )

        # Helper to get paths from indices
        def get_paths(idx_list):
            return [self.image_paths[i] for i in idx_list], [self.annotation_paths[i] for i in idx_list]

        train_imgs, train_anns = get_paths(train_indices)

        if test_size > 0:
            val_indices, test_indices = train_test_split(
                val_test_indices, test_size=(test_size / (val_size + test_size)),
                random_state=random_state
            )
            val_imgs, val_anns = get_paths(val_indices)
            test_imgs, test_anns = get_paths(test_indices)

            self._test_dataset = _ObjectDetectionDataset(test_imgs, test_anns, transform=None)
            self._test_dataset.classes = self.classes # type: ignore
            _LOGGER.info(f"Test set created with {len(self._test_dataset)} images.")
        else:
            val_imgs, val_anns = get_paths(val_test_indices)

        self._train_dataset = _ObjectDetectionDataset(train_imgs, train_anns, transform=None)
        self._val_dataset = _ObjectDetectionDataset(val_imgs, val_anns, transform=None)

        # Propagate class names to datasets for MLTrainer
        self._train_dataset.classes = self.classes # type: ignore
        self._val_dataset.classes = self.classes # type: ignore

        self._is_split = True
        _LOGGER.info(f"Data split into: \n- Training: {len(self._train_dataset)} images \n- Validation: {len(self._val_dataset)} images")
        return self

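    # Worked example of the two-stage split above (values are illustrative):
    # with val_size=0.2 and test_size=0.1, the first train_test_split holds out
    # 0.3 of all pairs, and the second split carves 0.1 / 0.3 = 1/3 of that
    # hold-out into the test set, so the final proportions are ~70% train,
    # ~20% validation, ~10% test:
    #
    #   maker.split_data(val_size=0.2, test_size=0.1, random_state=42)
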
    def configure_transforms(self,
                             mean: List[float] = [0.485, 0.456, 0.406],
                             std: List[float] = [0.229, 0.224, 0.225]) -> 'ObjectDetectionDatasetMaker':
        """
        Configures and applies the image and target transformations.

        This method must be called AFTER data is split.

        For object detection models like Faster R-CNN, images are NOT
        resized or cropped, as the model handles variable input sizes.
        Transforms are limited to augmentation (flip), ToTensor, and Normalize.

        Args:
            mean (List[float]): The mean values for image normalization.
            std (List[float]): The std dev values for image normalization.

        Returns:
            ObjectDetectionDatasetMaker: The same instance, with transforms applied.
        """
        if not self._is_split:
            _LOGGER.error("Transforms must be configured AFTER splitting data. Call .split_data() first.")
            raise RuntimeError()

        # --- Store components for validation recipe ---
        self._val_recipe_components = {
            VisionTransformRecipeKeys.MEAN: mean,
            VisionTransformRecipeKeys.STD: std
        }

        # --- Validation/Test Pipeline (Deterministic) ---
        self.val_transform = _OD_PairedCompose([
            _OD_PairedToTensor(),
            _OD_PairedNormalize(mean, std)
        ])

        # --- Training Pipeline (Augmentation) ---
        self.train_transform = _OD_PairedCompose([
            _OD_PairedRandomHorizontalFlip(p=0.5),
            _OD_PairedToTensor(),
            _OD_PairedNormalize(mean, std)
        ])

        # --- Apply Transforms to the Datasets ---
        self._train_dataset.transform = self.train_transform # type: ignore
        self._val_dataset.transform = self.val_transform # type: ignore
        if self._test_dataset:
            self._test_dataset.transform = self.val_transform # type: ignore

        self._are_transforms_configured = True
        _LOGGER.info("Paired object detection transforms configured and applied.")
        return self

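    # The "paired" transforms above must move the bounding boxes together with
    # the image. A minimal sketch of that idea for a horizontal flip (the actual
    # _OD_PairedRandomHorizontalFlip defined earlier in this module may differ
    # in details, including the target key names):
    #
    #   import torchvision.transforms.functional as TF
    #
    #   def flip_pair(image, target):
    #       width, _ = image.size                      # PIL image size is (W, H)
    #       image = TF.hflip(image)
    #       boxes = target["boxes"].clone()            # [N, 4] as (x1, y1, x2, y2)
    #       boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
    #       target["boxes"] = boxes
    #       return image, target
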
    def get_datasets(self) -> Tuple[Dataset, ...]:
        """
        Returns the final train, validation, and optional test datasets.

        Raises:
            RuntimeError: If called before data is split.
            RuntimeError: If called before transforms are configured.
        """
        if not self._is_split:
            _LOGGER.error("Data has not been split. Call .split_data() first.")
            raise RuntimeError()
        if not self._are_transforms_configured:
            _LOGGER.error("Transforms have not been configured. Call .configure_transforms() first.")
            raise RuntimeError()

        if self._test_dataset:
            return self._train_dataset, self._val_dataset, self._test_dataset
        return self._train_dataset, self._val_dataset

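    # Illustrative retrieval of the finished datasets (two or three are returned,
    # depending on whether test_size > 0 was used in split_data):
    #
    #   train_ds, val_ds, test_ds = maker.get_datasets()   # with a test split
    #   train_ds, val_ds = maker.get_datasets()            # without one
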
    @property
    def collate_fn(self) -> Callable:
        """
        Returns the collate function required by a DataLoader for this
        dataset. This function ensures that images and targets are
        batched as separate lists.
        """
        return _od_collate_fn

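    # Detection batches hold a variable number of boxes per image, so the default
    # stacking collate cannot be used. Illustrative DataLoader wiring (a common
    # implementation of such a collate function is `tuple(zip(*batch))`, though
    # the exact body of _od_collate_fn is defined elsewhere in this module):
    #
    #   from torch.utils.data import DataLoader
    #
    #   train_loader = DataLoader(
    #       train_ds,
    #       batch_size=4,
    #       shuffle=True,
    #       collate_fn=maker.collate_fn,   # batches become (list of images, list of targets)
    #   )
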
    def save_transform_recipe(self, filepath: Union[str, Path]) -> None:
        """
        Saves the validation transform pipeline as a JSON recipe file.

        For object detection, this recipe only includes ToTensor and
        Normalize, as resizing is handled by the model.

        Args:
            filepath (str | Path): The path to save the .json recipe file.
        """
        if not self._are_transforms_configured:
            _LOGGER.error("Transforms are not configured. Call .configure_transforms() first.")
            raise RuntimeError()

        components = self._val_recipe_components

        if not components:
            _LOGGER.error("Error getting the transform recipe for the validation set.")
            raise ValueError()

        # validate path
        file_path = make_fullpath(filepath, make=True, enforce="file")

        # Add standard transforms
        recipe: Dict[str, Any] = {
            VisionTransformRecipeKeys.TASK: "object_detection",
            VisionTransformRecipeKeys.PIPELINE: [
                {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}},
                {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
                    "mean": components["mean"],
                    "std": components["std"]
                }}
            ]
        }

        # Save the file
        save_recipe(recipe, file_path)


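# For reference, a recipe saved by ObjectDetectionDatasetMaker.save_transform_recipe
# with the default normalization would look roughly like the JSON below. The literal
# key strings depend on the values behind VisionTransformRecipeKeys and are shown
# here only as a plausible rendering:
#
#   {
#     "task": "object_detection",
#     "pipeline": [
#       {"name": "ToTensor", "kwargs": {}},
#       {"name": "Normalize", "kwargs": {"mean": [0.485, 0.456, 0.406],
#                                        "std": [0.229, 0.224, 0.225]}}
#     ]
#   }
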
def info():
    _script_info(__all__)