dragon-ml-toolbox 13.7.0__py3-none-any.whl → 14.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -0,0 +1,1315 @@
1
+ import torch
2
+ from torch.utils.data import Dataset, Subset
3
+ import numpy
4
+ from sklearn.model_selection import train_test_split
5
+ from typing import Union, Tuple, List, Optional, Callable, Dict, Any
6
+ from PIL import Image
7
+ from torchvision.datasets import ImageFolder
8
+ from torchvision import transforms
9
+ import torchvision.transforms.functional as TF
10
+ from pathlib import Path
11
+ import random
12
+ import json
13
+
14
+ from .ML_datasetmaster import _BaseMaker
15
+ from .path_manager import make_fullpath
16
+ from ._logger import _LOGGER
17
+ from ._script_info import _script_info
18
+ from .keys import VisionTransformRecipeKeys, ObjectDetectionKeys
19
+ from ._ML_vision_recipe import save_recipe
20
+ from .ML_vision_transformers import TRANSFORM_REGISTRY
21
+
22
+
23
+ __all__ = [
24
+ "VisionDatasetMaker",
25
+ "SegmentationDatasetMaker",
26
+ "ObjectDetectionDatasetMaker"
27
+ ]
28
+
29
+
30
+ # --- VisionDatasetMaker ---
31
+ class VisionDatasetMaker(_BaseMaker):
32
+ """
33
+ Creates processed PyTorch datasets for computer vision tasks from an
34
+ image folder directory.
35
+
36
+ Supports two modes:
37
+ 1. `from_folder()`: Loads from one directory and splits into train/val/test.
38
+ 2. `from_folders()`: Loads from pre-split train/val/test directories.
39
+
40
+ Uses online augmentations per epoch (image augmentation without creating new files).
41
+ """
42
+ def __init__(self):
43
+ """
44
+ Typically not called directly. Use the class methods `from_folder()` or `from_folders()` to create an instance.
45
+ """
46
+ super().__init__()
47
+ self._full_dataset: Optional[ImageFolder] = None
48
+ self.labels: Optional[List[int]] = None
49
+ self.class_map: Optional[Dict[str, int]] = None
50
+
51
+ self._is_split = False
52
+ self._are_transforms_configured = False
53
+ self._val_recipe_components = None
54
+
55
+ @classmethod
56
+ def from_folder(cls, root_dir: str) -> 'VisionDatasetMaker':
57
+ """
58
+ Creates a maker instance from a single root directory of images.
59
+
60
+ This method assumes a single directory (e.g., 'data/') that
61
+ contains class subfolders (e.g., 'data/cat/', 'data/dog/').
62
+
63
+ The dataset will be loaded in its entirety, and you MUST call
64
+ `.split_data()` afterward to create train/validation/test sets.
65
+
66
+ Args:
67
+ root_dir (str): The path to the root directory containing
68
+ class subfolders.
69
+
70
+ Returns:
71
+ VisionDatasetMaker: A new instance with the full dataset loaded.
72
+ """
73
+ # Load with NO transform. We get PIL Images.
74
+ full_dataset = ImageFolder(root=root_dir, transform=None)
75
+ _LOGGER.info(f"Found {len(full_dataset)} images in {len(full_dataset.classes)} classes.")
76
+
77
+ maker = cls()
78
+ maker._full_dataset = full_dataset
79
+ maker.labels = [s[1] for s in full_dataset.samples]
80
+ maker.class_map = full_dataset.class_to_idx
81
+ return maker
82
+
83
+ @classmethod
84
+ def from_folders(cls,
85
+ train_dir: str,
86
+ val_dir: str,
87
+ test_dir: Optional[str] = None) -> 'VisionDatasetMaker':
88
+ """
89
+ Creates a maker instance from separate, pre-split directories.
90
+
91
+ This method is used when you already have 'train', 'val', and
92
+ optionally 'test' folders, each containing class subfolders.
93
+ It bypasses the need for `.split_data()`.
94
+
95
+ Args:
96
+ train_dir (str): Path to the training data directory.
97
+ val_dir (str): Path to the validation data directory.
98
+ test_dir (str | None): Optional path to the test data directory.
99
+
100
+ Returns:
101
+ VisionDatasetMaker: A new, pre-split instance.
102
+
103
+ Raises:
104
+ ValueError: If the classes found in train, val, or test directories are inconsistent.
105
+ """
106
+ _LOGGER.info("Loading data from separate directories.")
107
+ # Load with NO transform. We get PIL Images.
108
+ train_ds = ImageFolder(root=train_dir, transform=None)
109
+ val_ds = ImageFolder(root=val_dir, transform=None)
110
+
111
+ # Check for class consistency
112
+ if train_ds.class_to_idx != val_ds.class_to_idx:
113
+ _LOGGER.error("Train and validation directories have different or inconsistent classes.")
114
+ raise ValueError()
115
+
116
+ maker = cls()
117
+ maker._train_dataset = train_ds
118
+ maker._val_dataset = val_ds
119
+ maker.class_map = train_ds.class_to_idx
120
+
121
+ if test_dir:
122
+ test_ds = ImageFolder(root=test_dir, transform=None)
123
+ if train_ds.class_to_idx != test_ds.class_to_idx:
124
+ _LOGGER.error("Train and test directories have different or inconsistent classes.")
125
+ raise ValueError()
126
+ maker._test_dataset = test_ds
127
+ _LOGGER.info(f"Loaded: {len(train_ds)} train, {len(val_ds)} val, {len(test_ds)} test images.")
128
+ else:
129
+ _LOGGER.info(f"Loaded: {len(train_ds)} train, {len(val_ds)} val images.")
130
+
131
+ maker._is_split = True # Mark as "split" since data is pre-split
132
+ return maker
133
+
134
+ @staticmethod
135
+ def inspect_folder(path: Union[str, Path]):
136
+ """
137
+ Logs a report of the types, sizes, and channels of image files
138
+ found in the directory and its subdirectories.
139
+
140
+ This is a utility method to help diagnose potential dataset
141
+ issues (e.g., mixed image modes, corrupted files) before loading.
142
+
143
+ Args:
144
+ path (str | Path): The directory path to inspect.
145
+ """
146
+ path_obj = make_fullpath(path)
147
+
148
+ non_image_files = set()
149
+ img_types = set()
150
+ img_sizes = set()
151
+ img_channels = set()
152
+ img_counter = 0
153
+
154
+ _LOGGER.info(f"Inspecting folder: {path_obj}...")
155
+ # Use rglob to recursively find all files
156
+ for filepath in path_obj.rglob('*'):
157
+ if filepath.is_file():
158
+ try:
159
+ # Using PIL to open is a more reliable check
160
+ with Image.open(filepath) as img:
161
+ img_types.add(img.format)
162
+ img_sizes.add(img.size)
163
+ img_channels.update(img.getbands())
164
+ img_counter += 1
165
+ except (IOError, SyntaxError):
166
+ non_image_files.add(filepath.name)
167
+
168
+ if non_image_files:
169
+ _LOGGER.warning(f"Non-image or corrupted files found and ignored: {non_image_files}")
170
+
171
+ report = (
172
+ f"\n--- Inspection Report for '{path_obj.name}' ---\n"
173
+ f"Total images found: {img_counter}\n"
174
+ f"Image formats: {img_types or 'None'}\n"
175
+ f"Image sizes (WxH): {img_sizes or 'None'}\n"
176
+ f"Image channels (bands): {img_channels or 'None'}\n"
177
+ f"--------------------------------------"
178
+ )
179
+ print(report)
180
+
181
+ def split_data(self, val_size: float = 0.2, test_size: float = 0.0,
182
+ stratify: bool = True, random_state: Optional[int] = None) -> 'VisionDatasetMaker':
183
+ """
184
+ Splits the dataset into train, validation, and optional test sets.
185
+
186
+ This method MUST be called if you used `from_folder()`. It has no effect if you used `from_folders()`.
187
+
188
+ Args:
189
+ val_size (float): Proportion of the dataset to reserve for
190
+ validation (e.g., 0.2 for 20%).
191
+ test_size (float): Proportion of the dataset to reserve for
192
+ testing.
193
+ stratify (bool): If True, splits are performed in a stratified
194
+ fashion, preserving the class distribution.
195
+ random_state (int | None): Seed for the random number generator for reproducible splits.
196
+
197
+ Returns:
198
+ VisionDatasetMaker: The same instance, now with datasets split.
199
+
200
+ Raises:
201
+ ValueError: If `val_size` and `test_size` sum to 1.0 or more.
202
+ """
203
+ if self._is_split:
204
+ _LOGGER.warning("Data has already been split.")
205
+ return self
206
+
207
+ if val_size + test_size >= 1.0:
208
+ _LOGGER.error("The sum of val_size and test_size must be less than 1.")
209
+ raise ValueError()
210
+
211
+ if not self._full_dataset:
212
+ _LOGGER.error("There is no dataset to split.")
213
+ raise ValueError()
214
+
215
+ indices = list(range(len(self._full_dataset)))
216
+ labels_for_split = self.labels if stratify else None
217
+
218
+ train_indices, val_test_indices = train_test_split(
219
+ indices, test_size=(val_size + test_size), random_state=random_state, stratify=labels_for_split
220
+ )
221
+
222
+ if not self.labels:
223
+ _LOGGER.error("Error when getting full dataset labels.")
224
+ raise ValueError()
225
+
226
+ if test_size > 0:
227
+ val_test_labels = [self.labels[i] for i in val_test_indices]
228
+ stratify_val_test = val_test_labels if stratify else None
229
+ val_indices, test_indices = train_test_split(
230
+ val_test_indices, test_size=(test_size / (val_size + test_size)),
231
+ random_state=random_state, stratify=stratify_val_test
232
+ )
233
+ self._test_dataset = Subset(self._full_dataset, test_indices)
234
+ _LOGGER.info(f"Test set created with {len(self._test_dataset)} images.")
235
+ else:
236
+ val_indices = val_test_indices
237
+
238
+ self._train_dataset = Subset(self._full_dataset, train_indices)
239
+ self._val_dataset = Subset(self._full_dataset, val_indices)
240
+ self._is_split = True
241
+
242
+ _LOGGER.info(f"Data split into: \n- Training: {len(self._train_dataset)} images \n- Validation: {len(self._val_dataset)} images")
243
+ return self
244
+
245
+ def configure_transforms(self, resize_size: int = 256, crop_size: int = 224,
246
+ mean: List[float] = [0.485, 0.456, 0.406],
247
+ std: List[float] = [0.229, 0.224, 0.225],
248
+ pre_transforms: Optional[List[Callable]] = None,
249
+ extra_train_transforms: Optional[List[Callable]] = None) -> 'VisionDatasetMaker':
250
+ """
251
+ Configures and applies the image transformations and augmentations.
252
+
253
+ This method must be called AFTER data is loaded and split.
254
+
255
+ It sets up two pipelines:
256
+ 1. **Training Pipeline:** Includes random augmentations like
257
+ `RandomResizedCrop` and `RandomHorizontalFlip` (plus any
258
+ `extra_train_transforms`) for online augmentation.
259
+ 2. **Validation/Test Pipeline:** A deterministic pipeline using
260
+ `Resize` and `CenterCrop` for consistent evaluation.
261
+
262
+ Both pipelines finish with `ToTensor` and `Normalize`.
263
+
264
+ Args:
265
+ resize_size (int): The size to resize the smallest edge to
266
+ for validation/testing.
267
+ crop_size (int): The target size (square) for the final
268
+ cropped image.
269
+ mean (List[float]): The mean values for normalization (e.g., ImageNet mean).
270
+ std (List[float]): The standard deviation values for normalization (e.g., ImageNet std).
271
+ extra_train_transforms (List[Callable] | None): A list of additional torchvision transforms to add to the end of the training transformations.
272
+ pre_transforms (List[Callable] | None): A list of transforms applied at the very beginning of the pipeline for all datasets.
273
+
274
+ Returns:
275
+ VisionDatasetMaker: The same instance, with transforms applied.
276
+
277
+ Raises:
278
+ RuntimeError: If called before data is split.
279
+ """
280
+ if not self._is_split:
281
+ _LOGGER.error("Transforms must be configured AFTER splitting data (or using `from_folders`). Call .split_data() first if using `from_folder`.")
282
+ raise RuntimeError()
283
+
284
+ # --- Define Transform Pipelines ---
285
+ # These now MUST include ToTensor and Normalize, as the ImageFolder was loaded with transform=None.
286
+
287
+ # --- Store components for validation recipe ---
288
+ self._val_recipe_components = {
289
+ VisionTransformRecipeKeys.PRE_TRANSFORMS: pre_transforms or [],
290
+ VisionTransformRecipeKeys.RESIZE_SIZE: resize_size,
291
+ VisionTransformRecipeKeys.CROP_SIZE: crop_size,
292
+ VisionTransformRecipeKeys.MEAN: mean,
293
+ VisionTransformRecipeKeys.STD: std
294
+ }
295
+
296
+ base_pipeline = []
297
+ if pre_transforms:
298
+ base_pipeline.extend(pre_transforms)
299
+
300
+ # Base augmentations for training
301
+ base_train_transforms = [
302
+ transforms.RandomResizedCrop(crop_size),
303
+ transforms.RandomHorizontalFlip()
304
+ ]
305
+ if extra_train_transforms:
306
+ base_train_transforms.extend(extra_train_transforms)
307
+
308
+ # Final conversion and normalization
309
+ final_transforms = [
310
+ transforms.ToTensor(),
311
+ transforms.Normalize(mean=mean, std=std)
312
+ ]
313
+
314
+ # Build the val/test pipeline
315
+ val_transform_list = [
316
+ *base_pipeline, # Apply pre_transforms first
317
+ transforms.Resize(resize_size),
318
+ transforms.CenterCrop(crop_size),
319
+ *final_transforms
320
+ ]
321
+
322
+ # Build the train pipeline
323
+ train_transform_list = [
324
+ *base_pipeline, # Apply pre_transforms first
325
+ *base_train_transforms,
326
+ *final_transforms
327
+ ]
328
+
329
+ val_transform = transforms.Compose(val_transform_list)
330
+ train_transform = transforms.Compose(train_transform_list)
331
+
332
+ # --- Apply Transforms using the Wrapper ---
333
+ # This correctly assigns the transform regardless of whether the dataset is a Subset (from_folder) or an ImageFolder (from_folders).
334
+
335
+ self._train_dataset = _DatasetTransformer(self._train_dataset, train_transform) # type: ignore
336
+ self._val_dataset = _DatasetTransformer(self._val_dataset, val_transform) # type: ignore
337
+ if self._test_dataset:
338
+ self._test_dataset = _DatasetTransformer(self._test_dataset, val_transform) # type: ignore
339
+
340
+ self._are_transforms_configured = True
341
+ _LOGGER.info("Image transforms configured and applied.")
342
+ return self
343
+
344
+ def get_datasets(self) -> Tuple[Dataset, ...]:
345
+ """
346
+ Returns the final train, validation, and optional test datasets.
347
+
348
+ This is the final step, used to retrieve the datasets for use in
349
+ a `MLTrainer` or `DataLoader`.
350
+
351
+ Returns:
352
+ (Tuple[Dataset, ...]): A tuple containing the (train, val)
353
+ or (train, val, test) datasets.
354
+
355
+ Raises:
356
+ RuntimeError: If called before data is split.
357
+ UserWarning (logged, not raised): If transforms have not been configured.
358
+ """
359
+ if not self._is_split:
360
+ _LOGGER.error("Data has not been split. Call .split_data() first.")
361
+ raise RuntimeError()
362
+ if not self._are_transforms_configured:
363
+ _LOGGER.warning("Transforms have not been configured.")
364
+
365
+ if self._test_dataset:
366
+ return self._train_dataset, self._val_dataset, self._test_dataset
367
+ return self._train_dataset, self._val_dataset
368
+
369
+ def save_transform_recipe(self, filepath: Union[str, Path]) -> None:
370
+ """
371
+ Saves the validation transform pipeline as a JSON recipe file.
372
+
373
+ This recipe can be loaded by the PyTorchVisionInferenceHandler
374
+ to ensure identical preprocessing.
375
+
376
+ Args:
377
+ filepath (str | Path): The path to save the .json recipe file.
378
+ """
379
+ if not self._are_transforms_configured:
380
+ _LOGGER.error("Transforms are not configured. Call .configure_transforms() first.")
381
+ raise RuntimeError()
382
+
383
+ recipe: Dict[str, Any] = {
384
+ VisionTransformRecipeKeys.TASK: "classification",
385
+ VisionTransformRecipeKeys.PIPELINE: []
386
+ }
387
+
388
+ components = self._val_recipe_components
389
+
390
+ if not components:
391
+ _LOGGER.error(f"Error getting the transformers recipe for validation set.")
392
+ raise ValueError()
393
+
394
+ # validate path
395
+ file_path = make_fullpath(filepath, make=True, enforce="file")
396
+
397
+ # 1. Handle pre_transforms
398
+ for t in components[VisionTransformRecipeKeys.PRE_TRANSFORMS]:
399
+ t_name = t.__class__.__name__
400
+ if t_name in TRANSFORM_REGISTRY:
401
+ recipe[VisionTransformRecipeKeys.PIPELINE].append({
402
+ VisionTransformRecipeKeys.NAME: t_name,
403
+ VisionTransformRecipeKeys.KWARGS: getattr(t, VisionTransformRecipeKeys.KWARGS, {})
404
+ })
405
+ else:
406
+ _LOGGER.warning(f"Skipping unknown pre_transform '{t_name}' in recipe. Not in TRANSFORM_REGISTRY.")
407
+
408
+ # 2. Add standard transforms
409
+ recipe[VisionTransformRecipeKeys.PIPELINE].extend([
410
+ {VisionTransformRecipeKeys.NAME: "Resize", "kwargs": {"size": components[VisionTransformRecipeKeys.RESIZE_SIZE]}},
411
+ {VisionTransformRecipeKeys.NAME: "CenterCrop", "kwargs": {"size": components[VisionTransformRecipeKeys.CROP_SIZE]}},
412
+ {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}},
413
+ {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
414
+ "mean": components[VisionTransformRecipeKeys.MEAN],
415
+ "std": components[VisionTransformRecipeKeys.STD]
416
+ }}
417
+ ])
418
+
419
+ # 3. Save the file
420
+ save_recipe(recipe, file_path)
421
+
422
+
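+ # --- Usage sketch (illustrative, not executed) ---
+ # A minimal end-to-end sketch of the VisionDatasetMaker workflow described in the
+ # class docstring above. The 'data/animals/' folder and the recipe path are
+ # hypothetical placeholders, not part of this package.
+ def _example_vision_maker_usage():
+     """Illustrative only: classification dataset preparation from a single folder."""
+     from torch.utils.data import DataLoader
+
+     maker = (
+         VisionDatasetMaker.from_folder("data/animals/")  # hypothetical path
+         .split_data(val_size=0.2, test_size=0.1, random_state=42)
+         .configure_transforms(resize_size=256, crop_size=224)
+     )
+     train_ds, val_ds, test_ds = maker.get_datasets()
+
+     # Persist the deterministic validation pipeline for inference-time parity.
+     maker.save_transform_recipe("artifacts/vision_recipe.json")  # hypothetical path
+
+     # Standard DataLoaders; batch size is arbitrary in this sketch.
+     train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
+     val_loader = DataLoader(val_ds, batch_size=32)
+     return train_loader, val_loader, test_ds
+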
423
+ class _DatasetTransformer(Dataset):
424
+ """
425
+ Internal wrapper class to apply a specific transform pipeline to any
426
+ dataset (e.g., a full ImageFolder or a Subset).
427
+ """
428
+ def __init__(self, dataset: Dataset, transform: Optional[transforms.Compose] = None):
429
+ self.dataset = dataset
430
+ self.transform = transform
431
+
432
+ # --- Propagate attributes for inspection ---
433
+ # For ImageFolder
434
+ if hasattr(dataset, 'class_to_idx'):
435
+ self.class_to_idx = getattr(dataset, 'class_to_idx')
436
+ if hasattr(dataset, 'classes'):
437
+ self.classes = getattr(dataset, 'classes')
438
+ # For Subset
439
+ if hasattr(dataset, 'indices'):
440
+ self.indices = getattr(dataset, 'indices')
441
+ if hasattr(dataset, 'dataset'):
442
+ # This allows access to the *original* full dataset
443
+ self.original_dataset = getattr(dataset, 'dataset')
444
+
445
+ def __getitem__(self, index):
446
+ # Get the original data (e.g., PIL Image, label)
447
+ x, y = self.dataset[index]
448
+
449
+ # Apply the specific transform for this dataset
450
+ if self.transform:
451
+ x = self.transform(x)
452
+ return x, y
453
+
454
+ def __len__(self):
455
+ return len(self.dataset) # type: ignore
456
+
457
+
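+ # Minimal sketch (illustrative) of how the wrapper above is used internally:
+ # different pipelines can be attached to subsets of the same underlying
+ # ImageFolder without copying any data. The 'data/animals/' path is hypothetical.
+ def _example_dataset_transformer():
+     """Illustrative only: wrap a Subset with its own transform pipeline."""
+     base = ImageFolder(root="data/animals/", transform=None)  # yields PIL images
+     train_subset = Subset(base, indices=list(range(0, len(base), 2)))
+     train_pipeline = transforms.Compose([
+         transforms.RandomResizedCrop(224),
+         transforms.ToTensor(),
+     ])
+     # The wrapper leaves the Subset untouched and applies the pipeline per item.
+     return _DatasetTransformer(train_subset, train_pipeline)
+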
458
+ # --- Segmentation dataset ----
459
+ class _SegmentationDataset(Dataset):
460
+ """
461
+ Internal helper class to load image-mask pairs.
462
+
463
+ Loads images as RGB and masks as 'L' (grayscale, 8-bit integer pixels).
464
+ """
465
+ def __init__(self, image_paths: List[Path], mask_paths: List[Path], transform: Optional[Callable] = None):
466
+ self.image_paths = image_paths
467
+ self.mask_paths = mask_paths
468
+ self.transform = transform
469
+
470
+ # --- Propagate 'classes' if they exist (for MLTrainer) ---
471
+ self.classes: List[str] = []
472
+
473
+ def __len__(self):
474
+ return len(self.image_paths)
475
+
476
+ def __getitem__(self, idx):
477
+ img_path = self.image_paths[idx]
478
+ mask_path = self.mask_paths[idx]
479
+
480
+ try:
481
+ # Open as PIL Images. Masks should be 'L'
482
+ image = Image.open(img_path).convert("RGB")
483
+ mask = Image.open(mask_path).convert("L")
484
+ except Exception as e:
485
+ _LOGGER.error(f"Error loading sample #{idx}: {img_path.name} / {mask_path.name}. Error: {e}")
486
+ # Return empty tensors
487
+ return torch.empty(3, 224, 224), torch.empty(224, 224, dtype=torch.long)
488
+
489
+ if self.transform:
490
+ image, mask = self.transform(image, mask)
491
+
492
+ return image, mask
493
+
494
+
495
+ # Internal Paired Transform Helpers
496
+ class _PairedCompose:
497
+ """A 'Compose' for paired image/mask transforms."""
498
+ def __init__(self, transforms: List[Callable]):
499
+ self.transforms = transforms
500
+
501
+ def __call__(self, image: Any, mask: Any) -> Tuple[Any, Any]:
502
+ for t in self.transforms:
503
+ image, mask = t(image, mask)
504
+ return image, mask
505
+
506
+ class _PairedToTensor:
507
+ """Converts a PIL Image pair (image, mask) to Tensors."""
508
+ def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[torch.Tensor, torch.Tensor]:
509
+ # Use new variable names to satisfy the linter
510
+ image_tensor = TF.to_tensor(image)
511
+ # Convert mask to LongTensor, not float.
512
+ # This creates a [H, W] tensor of integer class IDs.
513
+ mask_tensor = torch.from_numpy(numpy.array(mask, dtype=numpy.int64))
514
+ return image_tensor, mask_tensor
515
+
516
+ class _PairedNormalize:
517
+ """Normalizes the image tensor and leaves the mask untouched."""
518
+ def __init__(self, mean: List[float], std: List[float]):
519
+ self.normalize = transforms.Normalize(mean, std)
520
+
521
+ def __call__(self, image: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
522
+ image = self.normalize(image)
523
+ return image, mask
524
+
525
+ class _PairedResize:
526
+ """Resizes an image and mask to the same size."""
527
+ def __init__(self, size: int):
528
+ self.size = [size, size]
529
+
530
+ def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
531
+ # Use new variable names to avoid linter confusion
532
+ resized_image = TF.resize(image, self.size, interpolation=TF.InterpolationMode.BILINEAR) # type: ignore
533
+ # Use NEAREST for mask to avoid interpolating class IDs (e.g., 1.5)
534
+ resized_mask = TF.resize(mask, self.size, interpolation=TF.InterpolationMode.NEAREST) # type: ignore
535
+ return resized_image, resized_mask # type: ignore
536
+
537
+ class _PairedCenterCrop:
538
+ """Center-crops an image and mask to the same size."""
539
+ def __init__(self, size: int):
540
+ self.size = [size, size]
541
+
542
+ def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
543
+ cropped_image = TF.center_crop(image, self.size) # type: ignore
544
+ cropped_mask = TF.center_crop(mask, self.size) # type: ignore
545
+ return cropped_image, cropped_mask # type: ignore
546
+
547
+ class _PairedRandomHorizontalFlip:
548
+ """Applies the same random horizontal flip to both image and mask."""
549
+ def __init__(self, p: float = 0.5):
550
+ self.p = p
551
+
552
+ def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
553
+ if random.random() < self.p:
554
+ flipped_image = TF.hflip(image) # type: ignore
555
+ flipped_mask = TF.hflip(mask) # type: ignore
556
+ # No flip applied: return the original pair unchanged.
+ return image, mask
+
557
+
558
+ class _PairedRandomResizedCrop:
559
+ """Applies the same random resized crop to both image and mask."""
560
+ def __init__(self, size: int, scale: Tuple[float, float]=(0.08, 1.0), ratio: Tuple[float, float]=(3./4., 4./3.)):
561
+ self.size = [size, size]
562
+ self.scale = scale
563
+ self.ratio = ratio
564
+ self.interpolation = TF.InterpolationMode.BILINEAR
565
+ self.mask_interpolation = TF.InterpolationMode.NEAREST
566
+
567
+ def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
568
+ # Get parameters for the random crop
569
+ # Convert scale/ratio tuples to lists to satisfy the linter's type hint
570
+ i, j, h, w = transforms.RandomResizedCrop.get_params(image, list(self.scale), list(self.ratio)) # type: ignore
571
+
572
+ # Apply the crop with the SAME parameters and use new variable names
573
+ cropped_image = TF.resized_crop(image, i, j, h, w, self.size, self.interpolation) # type: ignore
574
+ cropped_mask = TF.resized_crop(mask, i, j, h, w, self.size, self.mask_interpolation) # type: ignore
575
+
576
+ return cropped_image, cropped_mask # type: ignore
577
+
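+ # Minimal sketch (illustrative): composing the paired helpers above so that the
+ # image and mask receive identical geometric transforms, with NEAREST
+ # interpolation preserving the integer class IDs in the mask. Inputs are dummies.
+ def _example_paired_pipeline():
+     """Illustrative only: a training pipeline for (image, mask) pairs."""
+     pipeline = _PairedCompose([
+         _PairedRandomResizedCrop(256),
+         _PairedRandomHorizontalFlip(p=0.5),
+         _PairedToTensor(),
+         _PairedNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+     image = Image.new("RGB", (320, 240))  # dummy inputs for the sketch
+     mask = Image.new("L", (320, 240))
+     image_tensor, mask_tensor = pipeline(image, mask)
+     return image_tensor.shape, mask_tensor.shape  # torch.Size([3, 256, 256]), torch.Size([256, 256])
+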
578
+ # --- SegmentationDatasetMaker ---
579
+ class SegmentationDatasetMaker(_BaseMaker):
580
+ """
581
+ Creates processed PyTorch datasets for segmentation from image and mask folders.
582
+
583
+ This maker finds all matching image-mask pairs from two directories,
584
+ splits them, and applies identical transformations (including augmentations)
585
+ to both the image and its corresponding mask.
586
+
587
+ Workflow:
588
+ 1. `maker = SegmentationDatasetMaker.from_folders(img_dir, mask_dir)`
589
+ 2. `maker.set_class_map({'background': 0, 'road': 1})`
590
+ 3. `maker.split_data(val_size=0.2)`
591
+ 4. `maker.configure_transforms(crop_size=256)`
592
+ 5. `train_ds, val_ds = maker.get_datasets()`
593
+ """
594
+ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
595
+
596
+ def __init__(self):
597
+ """
598
+ Typically not called directly. Use the class method `from_folders()` to create an instance.
599
+ """
600
+ super().__init__()
601
+ self.image_paths: List[Path] = []
602
+ self.mask_paths: List[Path] = []
603
+ self.class_map: Dict[str, int] = {}
604
+
605
+ self._is_split = False
606
+ self._are_transforms_configured = False
607
+ self.train_transform: Optional[Callable] = None
608
+ self.val_transform: Optional[Callable] = None
609
+ self.val_recipe_components: Optional[Dict[str, Any]] = None
+
610
+ @classmethod
611
+ def from_folders(cls, image_dir: Union[str, Path], mask_dir: Union[str, Path]) -> 'SegmentationDatasetMaker':
612
+ """
613
+ Creates a maker instance by loading all matching image-mask pairs
614
+ from two corresponding directories.
615
+
616
+ This method assumes that for an image `images/img_001.png`, there
617
+ is a corresponding mask `masks/img_001.png`.
618
+
619
+ Args:
620
+ image_dir (str | Path): Path to the directory containing input images.
621
+ mask_dir (str | Path): Path to the directory containing segmentation masks.
622
+
623
+ Returns:
624
+ SegmentationDatasetMaker: A new instance with all pairs loaded.
625
+ """
626
+ maker = cls()
627
+ img_path_obj = make_fullpath(image_dir, enforce="directory")
628
+ msk_path_obj = make_fullpath(mask_dir, enforce="directory")
629
+
630
+ # Find all images
631
+ image_files = sorted([
632
+ p for p in img_path_obj.glob('*.*')
633
+ if p.suffix.lower() in cls.IMG_EXTENSIONS
634
+ ])
635
+
636
+ if not image_files:
637
+ _LOGGER.error(f"No images with extensions {cls.IMG_EXTENSIONS} found in {image_dir}")
638
+ raise FileNotFoundError()
639
+
640
+ _LOGGER.info(f"Found {len(image_files)} images. Searching for matching masks in {mask_dir}...")
641
+
642
+ good_img_paths = []
643
+ good_mask_paths = []
644
+
645
+ for img_file in image_files:
646
+ mask_file = None
647
+
648
+ # 1. Try to find mask with the exact same name
649
+ mask_file_primary = msk_path_obj / img_file.name
650
+ if mask_file_primary.exists():
651
+ mask_file = mask_file_primary
652
+
653
+ # 2. If not, try to find mask with same stem + common mask extension
654
+ if mask_file is None:
655
+ for ext in cls.IMG_EXTENSIONS: # Masks are often .png
656
+ mask_file_secondary = msk_path_obj / (img_file.stem + ext)
657
+ if mask_file_secondary.exists():
658
+ mask_file = mask_file_secondary
659
+ break
660
+
661
+ # 3. If a match is found, add the pair
662
+ if mask_file:
663
+ good_img_paths.append(img_file)
664
+ good_mask_paths.append(mask_file)
665
+ else:
666
+ _LOGGER.warning(f"No corresponding mask found for image: {img_file.name}")
667
+
668
+ if not good_img_paths:
669
+ _LOGGER.error("No matching image-mask pairs were found.")
670
+ raise FileNotFoundError()
671
+
672
+ _LOGGER.info(f"Successfully found {len(good_img_paths)} image-mask pairs.")
673
+ maker.image_paths = good_img_paths
674
+ maker.mask_paths = good_mask_paths
675
+
676
+ return maker
677
+
678
+ @staticmethod
679
+ def inspect_folder(path: Union[str, Path]):
680
+ """
681
+ Logs a report of the types, sizes, and channels of image files
682
+ found in the directory. Useful for checking masks.
683
+ """
684
+ VisionDatasetMaker.inspect_folder(path)
685
+
686
+ def set_class_map(self, class_map: Dict[str, int]) -> 'SegmentationDatasetMaker':
687
+ """
688
+ Sets a map of class_name -> pixel_value. This is used by the MLTrainer for clear evaluation reports.
689
+
690
+ Args:
691
+ class_map (Dict[str, int]): A dictionary mapping each class name
693
+ to its integer pixel value in the mask.
693
+ Example: {'background': 0, 'road': 1, 'car': 2}
694
+ """
695
+ self.class_map = class_map
696
+ _LOGGER.info(f"Class map set: {class_map}")
697
+ return self
698
+
699
+ @property
700
+ def classes(self) -> List[str]:
701
+ """Returns the list of class names, if set."""
702
+ if self.class_map:
703
+ return list(self.class_map.keys())
704
+ return []
705
+
706
+ def split_data(self, val_size: float = 0.2, test_size: float = 0.0,
707
+ random_state: Optional[int] = 42) -> 'SegmentationDatasetMaker':
708
+ """
709
+ Splits the loaded image-mask pairs into train, validation, and test sets.
710
+
711
+ Args:
712
+ val_size (float): Proportion of the dataset to reserve for validation.
713
+ test_size (float): Proportion of the dataset to reserve for testing.
714
+ random_state (int | None): Seed for reproducible splits.
715
+
716
+ Returns:
717
+ SegmentationDatasetMaker: The same instance, now with datasets created.
718
+ """
719
+ if self._is_split:
720
+ _LOGGER.warning("Data has already been split.")
721
+ return self
722
+
723
+ if val_size + test_size >= 1.0:
724
+ _LOGGER.error("The sum of val_size and test_size must be less than 1.")
725
+ raise ValueError()
726
+
727
+ if not self.image_paths:
728
+ _LOGGER.error("There is no data to split. Use .from_folders() first.")
729
+ raise RuntimeError()
730
+
731
+ indices = list(range(len(self.image_paths)))
732
+
733
+ # Split indices
734
+ train_indices, val_test_indices = train_test_split(
735
+ indices, test_size=(val_size + test_size), random_state=random_state
736
+ )
737
+
738
+ # Helper to get paths from indices
739
+ def get_paths(idx_list):
740
+ return [self.image_paths[i] for i in idx_list], [self.mask_paths[i] for i in idx_list]
741
+
742
+ train_imgs, train_masks = get_paths(train_indices)
743
+
744
+ if test_size > 0:
745
+ val_indices, test_indices = train_test_split(
746
+ val_test_indices, test_size=(test_size / (val_size + test_size)),
747
+ random_state=random_state
748
+ )
749
+ val_imgs, val_masks = get_paths(val_indices)
750
+ test_imgs, test_masks = get_paths(test_indices)
751
+
752
+ self._test_dataset = _SegmentationDataset(test_imgs, test_masks, transform=None)
753
+ self._test_dataset.classes = self.classes # type: ignore
754
+ _LOGGER.info(f"Test set created with {len(self._test_dataset)} images.")
755
+ else:
756
+ val_imgs, val_masks = get_paths(val_test_indices)
757
+
758
+ self._train_dataset = _SegmentationDataset(train_imgs, train_masks, transform=None)
759
+ self._val_dataset = _SegmentationDataset(val_imgs, val_masks, transform=None)
760
+
761
+ # Propagate class names to datasets for MLTrainer
762
+ self._train_dataset.classes = self.classes # type: ignore
763
+ self._val_dataset.classes = self.classes # type: ignore
764
+
765
+ self._is_split = True
766
+ _LOGGER.info(f"Data split into: \n- Training: {len(self._train_dataset)} images \n- Validation: {len(self._val_dataset)} images")
767
+ return self
768
+
769
+ def configure_transforms(self,
770
+ resize_size: int = 256,
771
+ crop_size: int = 224,
772
+ mean: List[float] = [0.485, 0.456, 0.406],
773
+ std: List[float] = [0.229, 0.224, 0.225]) -> 'SegmentationDatasetMaker':
774
+ """
775
+ Configures and applies the image and mask transformations.
776
+
777
+ This method must be called AFTER data is split.
778
+
779
+ Args:
780
+ resize_size (int): The size to resize the smallest edge to
781
+ for validation/testing.
782
+ crop_size (int): The target size (square) for the final
783
+ cropped image.
784
+ mean (List[float]): The mean values for image normalization.
785
+ std (List[float]): The std dev values for image normalization.
786
+
787
+ Returns:
788
+ SegmentationDatasetMaker: The same instance, with transforms applied.
789
+ """
790
+ if not self._is_split:
791
+ _LOGGER.error("Transforms must be configured AFTER splitting data. Call .split_data() first.")
792
+ raise RuntimeError()
793
+
794
+ # --- Store components for validation recipe ---
795
+ self.val_recipe_components = {
796
+ VisionTransformRecipeKeys.RESIZE_SIZE: resize_size,
797
+ VisionTransformRecipeKeys.CROP_SIZE: crop_size,
798
+ VisionTransformRecipeKeys.MEAN: mean,
799
+ VisionTransformRecipeKeys.STD: std
800
+ }
801
+
802
+ # --- Validation/Test Pipeline (Deterministic) ---
803
+ self.val_transform = _PairedCompose([
804
+ _PairedResize(resize_size),
805
+ _PairedCenterCrop(crop_size),
806
+ _PairedToTensor(),
807
+ _PairedNormalize(mean, std)
808
+ ])
809
+
810
+ # --- Training Pipeline (Augmentation) ---
811
+ self.train_transform = _PairedCompose([
812
+ _PairedRandomResizedCrop(crop_size),
813
+ _PairedRandomHorizontalFlip(p=0.5),
814
+ _PairedToTensor(),
815
+ _PairedNormalize(mean, std)
816
+ ])
817
+
818
+ # --- Apply Transforms to the Datasets ---
819
+ self._train_dataset.transform = self.train_transform # type: ignore
820
+ self._val_dataset.transform = self.val_transform # type: ignore
821
+ if self._test_dataset:
822
+ self._test_dataset.transform = self.val_transform # type: ignore
823
+
824
+ self._are_transforms_configured = True
825
+ _LOGGER.info("Paired segmentation transforms configured and applied.")
826
+ return self
827
+
828
+ def get_datasets(self) -> Tuple[Dataset, ...]:
829
+ """
830
+ Returns the final train, validation, and optional test datasets.
831
+
832
+ Raises:
833
+ RuntimeError: If called before data is split.
834
+ RuntimeError: If called before transforms are configured.
835
+ """
836
+ if not self._is_split:
837
+ _LOGGER.error("Data has not been split. Call .split_data() first.")
838
+ raise RuntimeError()
839
+ if not self._are_transforms_configured:
840
+ _LOGGER.error("Transforms have not been configured. Call .configure_transforms() first.")
841
+ raise RuntimeError()
842
+
843
+ if self._test_dataset:
844
+ return self._train_dataset, self._val_dataset, self._test_dataset
845
+ return self._train_dataset, self._val_dataset
846
+
847
+ def save_transform_recipe(self, filepath: Union[str, Path]) -> None:
848
+ """
849
+ Saves the validation transform pipeline as a JSON recipe file.
850
+
851
+ This recipe can be loaded by the PyTorchVisionInferenceHandler
852
+ to ensure identical preprocessing.
853
+
854
+ Args:
855
+ filepath (str | Path): The path to save the .json recipe file.
856
+ """
857
+ if not self._are_transforms_configured:
858
+ _LOGGER.error("Transforms are not configured. Call .configure_transforms() first.")
859
+ raise RuntimeError()
860
+
861
+ components = self.val_recipe_components
862
+
863
+ if not components:
864
+ _LOGGER.error(f"Error getting the transformers recipe for validation set.")
865
+ raise ValueError()
866
+
867
+ # validate path
868
+ file_path = make_fullpath(filepath, make=True, enforce="file")
869
+
870
+ # Add standard transforms
871
+ recipe: Dict[str, Any] = {
872
+ VisionTransformRecipeKeys.TASK: "segmentation",
873
+ VisionTransformRecipeKeys.PIPELINE: [
874
+ {VisionTransformRecipeKeys.NAME: "Resize", "kwargs": {"size": components["resize_size"]}},
875
+ {VisionTransformRecipeKeys.NAME: "CenterCrop", "kwargs": {"size": components["crop_size"]}},
876
+ {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}},
877
+ {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
878
+ "mean": components["mean"],
879
+ "std": components["std"]
880
+ }}
881
+ ]
882
+ }
883
+
884
+ # Save the file
885
+ save_recipe(recipe, file_path)
886
+
887
+
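+ # --- Usage sketch (illustrative, not executed) ---
+ # A minimal walkthrough of the SegmentationDatasetMaker workflow from the class
+ # docstring. The 'data/images/' and 'data/masks/' paths are hypothetical.
+ def _example_segmentation_maker_usage():
+     """Illustrative only: paired image/mask dataset preparation."""
+     from torch.utils.data import DataLoader
+
+     maker = (
+         SegmentationDatasetMaker.from_folders("data/images/", "data/masks/")  # hypothetical paths
+         .set_class_map({"background": 0, "road": 1})
+         .split_data(val_size=0.2)
+         .configure_transforms(crop_size=256)
+     )
+     train_ds, val_ds = maker.get_datasets()
+     train_loader = DataLoader(train_ds, batch_size=8, shuffle=True)
+     return train_loader, val_ds
+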
888
+ # Object detection
889
+ def _od_collate_fn(batch: List[Tuple[torch.Tensor, Dict[str, torch.Tensor]]]) -> Tuple[List[torch.Tensor], List[Dict[str, torch.Tensor]]]:
890
+ """
891
+ Custom collate function for object detection.
892
+
893
+ Takes a list of (image, target) tuples and zips them into two lists:
894
+ (list_of_images, list_of_targets).
895
+ This is required for models like Faster R-CNN, which accept a list
896
+ of images of varying sizes.
897
+ """
898
+ return tuple(zip(*batch)) # type: ignore
899
+
900
+
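+ # Tiny illustration (assumed values) of the collate behaviour: a batch of
+ # (image, target) pairs is re-grouped into (images, targets), which detection
+ # models such as Faster R-CNN consume as parallel sequences.
+ def _example_od_collate():
+     """Illustrative only: output structure of _od_collate_fn."""
+     batch = [
+         (torch.rand(3, 200, 300), {ObjectDetectionKeys.BOXES: torch.zeros((0, 4)),
+                                    ObjectDetectionKeys.LABELS: torch.zeros(0, dtype=torch.long)}),
+         (torch.rand(3, 240, 320), {ObjectDetectionKeys.BOXES: torch.tensor([[10.0, 20.0, 50.0, 80.0]]),
+                                    ObjectDetectionKeys.LABELS: torch.tensor([1])}),
+     ]
+     images, targets = _od_collate_fn(batch)
+     return len(images), len(targets)  # 2, 2
+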
901
+ class _ObjectDetectionDataset(Dataset):
902
+ """
903
+ Internal helper class to load image-annotation pairs.
904
+
905
+ Loads an image as 'RGB' and parses its corresponding JSON annotation file
906
+ to create the required target dictionary (boxes, labels).
907
+ """
908
+ def __init__(self, image_paths: List[Path], annotation_paths: List[Path], transform: Optional[Callable] = None):
909
+ self.image_paths = image_paths
910
+ self.annotation_paths = annotation_paths
911
+ self.transform = transform
912
+
913
+ # --- Propagate 'classes' if they exist (for MLTrainer) ---
914
+ self.classes: List[str] = []
915
+
916
+ def __len__(self):
917
+ return len(self.image_paths)
918
+
919
+ def __getitem__(self, idx):
920
+ img_path = self.image_paths[idx]
921
+ ann_path = self.annotation_paths[idx]
922
+
923
+ try:
924
+ # Open image
925
+ image = Image.open(img_path).convert("RGB")
926
+
927
+ # Load and parse annotation
928
+ with open(ann_path, 'r') as f:
929
+ ann_data = json.load(f)
930
+
931
+ # Get boxes and labels from JSON
932
+ boxes = ann_data[ObjectDetectionKeys.BOXES] # Expected: [[x1, y1, x2, y2], ...]
933
+ labels = ann_data[ObjectDetectionKeys.LABELS] # Expected: [1, 2, 1, ...]
934
+
935
+ # Convert to tensors
936
+ target: Dict[str, Any] = {}
937
+ target[ObjectDetectionKeys.BOXES] = torch.as_tensor(boxes, dtype=torch.float32)
938
+ target[ObjectDetectionKeys.LABELS] = torch.as_tensor(labels, dtype=torch.int64)
939
+
940
+ except Exception as e:
941
+ _LOGGER.error(f"Error loading sample #{idx}: {img_path.name} / {ann_path.name}. Error: {e}")
942
+ # Return empty/dummy data
943
+ return torch.empty(3, 224, 224), {ObjectDetectionKeys.BOXES: torch.empty((0, 4)), ObjectDetectionKeys.LABELS: torch.empty(0, dtype=torch.long)}
944
+
945
+ if self.transform:
946
+ image, target = self.transform(image, target)
947
+
948
+ return image, target
949
+
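+ # Illustrative only: the expected on-disk annotation payload and how it maps to
+ # the target dictionary built in __getitem__ above. All values are made up.
+ def _example_annotation_to_target():
+     """Sketch: one 'annotations/<stem>.json' payload converted to tensors."""
+     ann_data = {
+         ObjectDetectionKeys.BOXES: [[10.0, 20.0, 50.0, 80.0], [30.0, 40.0, 60.0, 90.0]],
+         ObjectDetectionKeys.LABELS: [1, 2],
+     }
+     target = {
+         ObjectDetectionKeys.BOXES: torch.as_tensor(ann_data[ObjectDetectionKeys.BOXES], dtype=torch.float32),
+         ObjectDetectionKeys.LABELS: torch.as_tensor(ann_data[ObjectDetectionKeys.LABELS], dtype=torch.int64),
+     }
+     return target
+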
950
+ # Internal Paired Transform Helpers for Object Detection
951
+ class _OD_PairedCompose:
952
+ """A 'Compose' for paired image/target_dict transforms."""
953
+ def __init__(self, transforms: List[Callable]):
954
+ self.transforms = transforms
955
+
956
+ def __call__(self, image: Any, target: Any) -> Tuple[Any, Any]:
957
+ for t in self.transforms:
958
+ image, target = t(image, target)
959
+ return image, target
960
+
961
+ class _OD_PairedToTensor:
962
+ """Converts a PIL Image to Tensor, passes targets dict through."""
963
+ def __call__(self, image: Image.Image, target: Dict[str, Any]) -> Tuple[torch.Tensor, Dict[str, Any]]:
964
+ return TF.to_tensor(image), target
965
+
966
+ class _OD_PairedNormalize:
967
+ """Normalizes the image tensor and leaves the target dict untouched."""
968
+ def __init__(self, mean: List[float], std: List[float]):
969
+ self.normalize = transforms.Normalize(mean, std)
970
+
971
+ def __call__(self, image: torch.Tensor, target: Dict[str, Any]) -> Tuple[torch.Tensor, Dict[str, Any]]:
972
+ image_normalized = self.normalize(image)
973
+ return image_normalized, target
974
+
975
+ class _OD_PairedRandomHorizontalFlip:
976
+ """Applies the same random horizontal flip to both image and targets['boxes']."""
977
+ def __init__(self, p: float = 0.5):
978
+ self.p = p
979
+
980
+ def __call__(self, image: Image.Image, target: Dict[str, Any]) -> Tuple[Image.Image, Dict[str, Any]]:
981
+ if random.random() < self.p:
982
+ w, h = image.size
983
+ # Use new variable names to avoid linter confusion
984
+ flipped_image = TF.hflip(image) # type: ignore
985
+
986
+ # Flip boxes
987
+ boxes = target[ObjectDetectionKeys.BOXES].clone() # [N, 4]
988
+
989
+ # xmin' = w - xmax
990
+ # xmax' = w - xmin
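+ # Worked example (illustrative): with image width w=100, a box
+ # [10, 20, 40, 80] flips to [100-40, 20, 100-10, 80] = [60, 20, 90, 80];
+ # the y-coordinates are unchanged.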
991
+ boxes[:, 0] = w - target[ObjectDetectionKeys.BOXES][:, 2] # xmin'
992
+ boxes[:, 2] = w - target[ObjectDetectionKeys.BOXES][:, 0] # xmax'
993
+ target[ObjectDetectionKeys.BOXES] = boxes
994
+
995
+ return flipped_image, target # type: ignore
996
+
997
+ return image, target
998
+
999
+
1000
+ class ObjectDetectionDatasetMaker(_BaseMaker):
1001
+ """
1002
+ Creates processed PyTorch datasets for object detection from image
1003
+ and JSON annotation folders.
1004
+
1005
+ This maker finds all matching image-annotation pairs from two directories,
1006
+ splits them, and applies identical transformations (including augmentations)
1007
+ to both the image and its corresponding target dictionary.
1008
+
1009
+ The `DragonFastRCNN` model expects a list of images and a list of targets,
1010
+ so this class provides a `collate_fn` to be used with a DataLoader.
1011
+
1012
+ Workflow:
1013
+ 1. `maker = ObjectDetectionDatasetMaker.from_folders(img_dir, ann_dir)`
1014
+ 2. `maker.set_class_map({'background': 0, 'person': 1, 'car': 2})`
1015
+ 3. `maker.split_data(val_size=0.2)`
1016
+ 4. `maker.configure_transforms()`
1017
+ 5. `train_ds, val_ds = maker.get_datasets()`
1018
+ 6. `collate_fn = maker.collate_fn`
1019
+ 7. `train_loader = DataLoader(train_ds, ..., collate_fn=collate_fn)`
1020
+ """
1021
+ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
1022
+
1023
+ def __init__(self):
1024
+ """
1025
+ Typically not called directly. Use the class method `from_folders()` to create an instance.
1026
+ """
1027
+ super().__init__()
1028
+ self.image_paths: List[Path] = []
1029
+ self.annotation_paths: List[Path] = []
1030
+ self.class_map: Dict[str, int] = {}
1031
+
1032
+ self._is_split = False
1033
+ self._are_transforms_configured = False
1034
+ self.train_transform: Optional[Callable] = None
1035
+ self.val_transform: Optional[Callable] = None
1036
+ self._val_recipe_components: Optional[Dict[str, Any]] = None
1037
+
1038
+ @classmethod
1039
+ def from_folders(cls, image_dir: Union[str, Path], annotation_dir: Union[str, Path]) -> 'ObjectDetectionDatasetMaker':
1040
+ """
1041
+ Creates a maker instance by loading all matching image-annotation pairs
1042
+ from two corresponding directories.
1043
+
1044
+ This method assumes that for an image `images/img_001.png`, there
1045
+ is a corresponding annotation `annotations/img_001.json`.
1046
+
1047
+ The JSON file must contain "boxes" and "labels" keys:
1048
+ `{"boxes": [[x1,y1,x2,y2], ...], "labels": [1, 2, ...]}`
1049
+
1050
+ Args:
1051
+ image_dir (str | Path): Path to the directory containing input images.
1052
+ annotation_dir (str | Path): Path to the directory containing .json
1053
+ annotation files.
1054
+
1055
+ Returns:
1056
+ ObjectDetectionDatasetMaker: A new instance with all pairs loaded.
1057
+ """
1058
+ maker = cls()
1059
+ img_path_obj = make_fullpath(image_dir, enforce="directory")
1060
+ ann_path_obj = make_fullpath(annotation_dir, enforce="directory")
1061
+
1062
+ # Find all images
1063
+ image_files = sorted([
1064
+ p for p in img_path_obj.glob('*.*')
1065
+ if p.suffix.lower() in cls.IMG_EXTENSIONS
1066
+ ])
1067
+
1068
+ if not image_files:
1069
+ _LOGGER.error(f"No images with extensions {cls.IMG_EXTENSIONS} found in {image_dir}")
1070
+ raise FileNotFoundError()
1071
+
1072
+ _LOGGER.info(f"Found {len(image_files)} images. Searching for matching .json annotations in {annotation_dir}...")
1073
+
1074
+ good_img_paths = []
1075
+ good_ann_paths = []
1076
+
1077
+ for img_file in image_files:
1078
+ # Find annotation with same stem + .json
1079
+ ann_file = ann_path_obj / (img_file.stem + ".json")
1080
+
1081
+ if ann_file.exists():
1082
+ good_img_paths.append(img_file)
1083
+ good_ann_paths.append(ann_file)
1084
+ else:
1085
+ _LOGGER.warning(f"No corresponding .json annotation found for image: {img_file.name}")
1086
+
1087
+ if not good_img_paths:
1088
+ _LOGGER.error("No matching image-annotation pairs were found.")
1089
+ raise FileNotFoundError()
1090
+
1091
+ _LOGGER.info(f"Successfully found {len(good_img_paths)} image-annotation pairs.")
1092
+ maker.image_paths = good_img_paths
1093
+ maker.annotation_paths = good_ann_paths
1094
+
1095
+ return maker
1096
+
1097
+ @staticmethod
1098
+ def inspect_folder(path: Union[str, Path]):
1099
+ """
1100
+ Logs a report of the types, sizes, and channels of image files
1101
+ found in the directory.
1102
+ """
1103
+ VisionDatasetMaker.inspect_folder(path)
1104
+
1105
+ def set_class_map(self, class_map: Dict[str, int]) -> 'ObjectDetectionDatasetMaker':
1106
+ """
1107
+ Sets a map of class_name -> integer label. This is used by the
1108
+ MLTrainer for clear evaluation reports.
1109
+
1110
+ **Important:** For object detection models, 'background' MUST
1111
+ be included as class 0.
1112
+ Example: `{'background': 0, 'person': 1, 'car': 2}`
1113
+
1114
+ Args:
1115
+ class_map (Dict[str, int]): A dictionary mapping the string name
1116
+ to its integer label.
1117
+ """
1118
+ if 'background' not in class_map or class_map['background'] != 0:
1119
+ _LOGGER.warning("Object detection class map should include 'background' mapped to 0.")
1120
+
1121
+ self.class_map = class_map
1122
+ _LOGGER.info(f"Class map set: {class_map}")
1123
+ return self
1124
+
1125
+ @property
1126
+ def classes(self) -> List[str]:
1127
+ """Returns the list of class names, if set."""
1128
+ if self.class_map:
1129
+ return list(self.class_map.keys())
1130
+ return []
1131
+
1132
+ def split_data(self, val_size: float = 0.2, test_size: float = 0.0,
1133
+ random_state: Optional[int] = 42) -> 'ObjectDetectionDatasetMaker':
1134
+ """
1135
+ Splits the loaded image-annotation pairs into train, validation, and test sets.
1136
+
1137
+ Args:
1138
+ val_size (float): Proportion of the dataset to reserve for validation.
1139
+ test_size (float): Proportion of the dataset to reserve for testing.
1140
+ random_state (int | None): Seed for reproducible splits.
1141
+
1142
+ Returns:
1143
+ ObjectDetectionDatasetMaker: The same instance, now with datasets created.
1144
+ """
1145
+ if self._is_split:
1146
+ _LOGGER.warning("Data has already been split.")
1147
+ return self
1148
+
1149
+ if val_size + test_size >= 1.0:
1150
+ _LOGGER.error("The sum of val_size and test_size must be less than 1.")
1151
+ raise ValueError()
1152
+
1153
+ if not self.image_paths:
1154
+ _LOGGER.error("There is no data to split. Use .from_folders() first.")
1155
+ raise RuntimeError()
1156
+
1157
+ indices = list(range(len(self.image_paths)))
1158
+
1159
+ # Split indices
1160
+ train_indices, val_test_indices = train_test_split(
1161
+ indices, test_size=(val_size + test_size), random_state=random_state
1162
+ )
1163
+
1164
+ # Helper to get paths from indices
1165
+ def get_paths(idx_list):
1166
+ return [self.image_paths[i] for i in idx_list], [self.annotation_paths[i] for i in idx_list]
1167
+
1168
+ train_imgs, train_anns = get_paths(train_indices)
1169
+
1170
+ if test_size > 0:
1171
+ val_indices, test_indices = train_test_split(
1172
+ val_test_indices, test_size=(test_size / (val_size + test_size)),
1173
+ random_state=random_state
1174
+ )
1175
+ val_imgs, val_anns = get_paths(val_indices)
1176
+ test_imgs, test_anns = get_paths(test_indices)
1177
+
1178
+ self._test_dataset = _ObjectDetectionDataset(test_imgs, test_anns, transform=None)
1179
+ self._test_dataset.classes = self.classes # type: ignore
1180
+ _LOGGER.info(f"Test set created with {len(self._test_dataset)} images.")
1181
+ else:
1182
+ val_imgs, val_anns = get_paths(val_test_indices)
1183
+
1184
+ self._train_dataset = _ObjectDetectionDataset(train_imgs, train_anns, transform=None)
1185
+ self._val_dataset = _ObjectDetectionDataset(val_imgs, val_anns, transform=None)
1186
+
1187
+ # Propagate class names to datasets for MLTrainer
1188
+ self._train_dataset.classes = self.classes # type: ignore
1189
+ self._val_dataset.classes = self.classes # type: ignore
1190
+
1191
+ self._is_split = True
1192
+ _LOGGER.info(f"Data split into: \n- Training: {len(self._train_dataset)} images \n- Validation: {len(self._val_dataset)} images")
1193
+ return self
1194
+
1195
+ def configure_transforms(self,
1196
+ mean: List[float] = [0.485, 0.456, 0.406],
1197
+ std: List[float] = [0.229, 0.224, 0.225]) -> 'ObjectDetectionDatasetMaker':
1198
+ """
1199
+ Configures and applies the image and target transformations.
1200
+
1201
+ This method must be called AFTER data is split.
1202
+
1203
+ For object detection models like Faster R-CNN, images are NOT
1204
+ resized or cropped, as the model handles variable input sizes.
1205
+ Transforms are limited to augmentation (flip), ToTensor, and Normalize.
1206
+
1207
+ Args:
1208
+ mean (List[float]): The mean values for image normalization.
1209
+ std (List[float]): The std dev values for image normalization.
1210
+
1211
+ Returns:
1212
+ ObjectDetectionDatasetMaker: The same instance, with transforms applied.
1213
+ """
1214
+ if not self._is_split:
1215
+ _LOGGER.error("Transforms must be configured AFTER splitting data. Call .split_data() first.")
1216
+ raise RuntimeError()
1217
+
1218
+ # --- Store components for validation recipe ---
1219
+ self._val_recipe_components = {
1220
+ VisionTransformRecipeKeys.MEAN: mean,
1221
+ VisionTransformRecipeKeys.STD: std
1222
+ }
1223
+
1224
+ # --- Validation/Test Pipeline (Deterministic) ---
1225
+ self.val_transform = _OD_PairedCompose([
1226
+ _OD_PairedToTensor(),
1227
+ _OD_PairedNormalize(mean, std)
1228
+ ])
1229
+
1230
+ # --- Training Pipeline (Augmentation) ---
1231
+ self.train_transform = _OD_PairedCompose([
1232
+ _OD_PairedRandomHorizontalFlip(p=0.5),
1233
+ _OD_PairedToTensor(),
1234
+ _OD_PairedNormalize(mean, std)
1235
+ ])
1236
+
1237
+ # --- Apply Transforms to the Datasets ---
1238
+ self._train_dataset.transform = self.train_transform # type: ignore
1239
+ self._val_dataset.transform = self.val_transform # type: ignore
1240
+ if self._test_dataset:
1241
+ self._test_dataset.transform = self.val_transform # type: ignore
1242
+
1243
+ self._are_transforms_configured = True
1244
+ _LOGGER.info("Paired object detection transforms configured and applied.")
1245
+ return self
1246
+
1247
+ def get_datasets(self) -> Tuple[Dataset, ...]:
1248
+ """
1249
+ Returns the final train, validation, and optional test datasets.
1250
+
1251
+ Raises:
1252
+ RuntimeError: If called before data is split.
1253
+ RuntimeError: If called before transforms are configured.
1254
+ """
1255
+ if not self._is_split:
1256
+ _LOGGER.error("Data has not been split. Call .split_data() first.")
1257
+ raise RuntimeError()
1258
+ if not self._are_transforms_configured:
1259
+ _LOGGER.error("Transforms have not been configured. Call .configure_transforms() first.")
1260
+ raise RuntimeError()
1261
+
1262
+ if self._test_dataset:
1263
+ return self._train_dataset, self._val_dataset, self._test_dataset
1264
+ return self._train_dataset, self._val_dataset
1265
+
1266
+ @property
1267
+ def collate_fn(self) -> Callable:
1268
+ """
1269
+ Returns the collate function required by a DataLoader for this
1270
+ dataset. This function ensures that images and targets are
1271
+ batched as separate lists.
1272
+ """
1273
+ return _od_collate_fn
1274
+
1275
+ def save_transform_recipe(self, filepath: Union[str, Path]) -> None:
1276
+ """
1277
+ Saves the validation transform pipeline as a JSON recipe file.
1278
+
1279
+ For object detection, this recipe only includes ToTensor and
1280
+ Normalize, as resizing is handled by the model.
1281
+
1282
+ Args:
1283
+ filepath (str | Path): The path to save the .json recipe file.
1284
+ """
1285
+ if not self._are_transforms_configured:
1286
+ _LOGGER.error("Transforms are not configured. Call .configure_transforms() first.")
1287
+ raise RuntimeError()
1288
+
1289
+ components = self._val_recipe_components
1290
+
1291
+ if not components:
1292
+ _LOGGER.error(f"Error getting the transformers recipe for validation set.")
1293
+ raise ValueError()
1294
+
1295
+ # validate path
1296
+ file_path = make_fullpath(filepath, make=True, enforce="file")
1297
+
1298
+ # Add standard transforms
1299
+ recipe: Dict[str, Any] = {
1300
+ VisionTransformRecipeKeys.TASK: "object_detection",
1301
+ VisionTransformRecipeKeys.PIPELINE: [
1302
+ {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}},
1303
+ {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
1304
+ "mean": components["mean"],
1305
+ "std": components["std"]
1306
+ }}
1307
+ ]
1308
+ }
1309
+
1310
+ # Save the file
1311
+ save_recipe(recipe, file_path)
1312
+
1313
+
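+ # --- Usage sketch (illustrative, not executed) ---
+ # A minimal walkthrough of the ObjectDetectionDatasetMaker workflow from the
+ # class docstring, including the custom collate function needed for
+ # variable-size images. All paths are hypothetical placeholders.
+ def _example_object_detection_maker_usage():
+     """Illustrative only: detection dataset preparation with the custom collate_fn."""
+     from torch.utils.data import DataLoader
+
+     maker = (
+         ObjectDetectionDatasetMaker.from_folders("data/images/", "data/annotations/")  # hypothetical paths
+         .set_class_map({"background": 0, "person": 1, "car": 2})
+         .split_data(val_size=0.2)
+         .configure_transforms()
+     )
+     train_ds, val_ds = maker.get_datasets()
+
+     # Lists of images/targets require the custom collate function.
+     train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=maker.collate_fn)
+     return train_loader, val_ds
+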
1314
+ def info():
1315
+ _script_info(__all__)