dragon-ml-toolbox 13.3.0__py3-none-any.whl → 16.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/METADATA +20 -6
  2. dragon_ml_toolbox-16.2.0.dist-info/RECORD +51 -0
  3. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +10 -0
  4. ml_tools/ETL_cleaning.py +20 -20
  5. ml_tools/ETL_engineering.py +23 -25
  6. ml_tools/GUI_tools.py +20 -20
  7. ml_tools/MICE_imputation.py +207 -5
  8. ml_tools/ML_callbacks.py +43 -26
  9. ml_tools/ML_configuration.py +788 -0
  10. ml_tools/ML_datasetmaster.py +303 -448
  11. ml_tools/ML_evaluation.py +351 -93
  12. ml_tools/ML_evaluation_multi.py +139 -42
  13. ml_tools/ML_inference.py +290 -209
  14. ml_tools/ML_models.py +33 -106
  15. ml_tools/ML_models_advanced.py +323 -0
  16. ml_tools/ML_optimization.py +12 -12
  17. ml_tools/ML_scaler.py +11 -11
  18. ml_tools/ML_sequence_datasetmaster.py +341 -0
  19. ml_tools/ML_sequence_evaluation.py +219 -0
  20. ml_tools/ML_sequence_inference.py +391 -0
  21. ml_tools/ML_sequence_models.py +139 -0
  22. ml_tools/ML_trainer.py +1604 -179
  23. ml_tools/ML_utilities.py +351 -4
  24. ml_tools/ML_vision_datasetmaster.py +1540 -0
  25. ml_tools/ML_vision_evaluation.py +284 -0
  26. ml_tools/ML_vision_inference.py +405 -0
  27. ml_tools/ML_vision_models.py +641 -0
  28. ml_tools/ML_vision_transformers.py +284 -0
  29. ml_tools/PSO_optimization.py +6 -6
  30. ml_tools/SQL.py +4 -4
  31. ml_tools/_keys.py +171 -0
  32. ml_tools/_schema.py +1 -1
  33. ml_tools/custom_logger.py +37 -14
  34. ml_tools/data_exploration.py +502 -93
  35. ml_tools/ensemble_evaluation.py +54 -11
  36. ml_tools/ensemble_inference.py +7 -33
  37. ml_tools/ensemble_learning.py +1 -1
  38. ml_tools/math_utilities.py +1 -1
  39. ml_tools/optimization_tools.py +2 -2
  40. ml_tools/path_manager.py +5 -5
  41. ml_tools/serde.py +2 -2
  42. ml_tools/utilities.py +192 -4
  43. dragon_ml_toolbox-13.3.0.dist-info/RECORD +0 -41
  44. ml_tools/RNN_forecast.py +0 -56
  45. ml_tools/keys.py +0 -87
  46. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/WHEEL +0 -0
  47. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/licenses/LICENSE +0 -0
  48. {dragon_ml_toolbox-13.3.0.dist-info → dragon_ml_toolbox-16.2.0.dist-info}/top_level.txt +0 -0
ml_tools/ML_vision_datasetmaster.py (new file)
@@ -0,0 +1,1540 @@
1
+ import torch
2
+ from torch.utils.data import Dataset, Subset
3
+ import numpy
4
+ from sklearn.model_selection import train_test_split
5
+ from typing import Union, Tuple, List, Optional, Callable, Dict, Any
6
+ from PIL import Image
7
+ from torchvision.datasets import ImageFolder
8
+ from torchvision import transforms
9
+ import torchvision.transforms.functional as TF
10
+ from pathlib import Path
11
+ import random
12
+ import json
13
+ import inspect
14
+
15
+ from .path_manager import make_fullpath
16
+ from ._logger import _LOGGER
17
+ from ._script_info import _script_info
18
+ from ._keys import VisionTransformRecipeKeys, ObjectDetectionKeys
19
+ from .ML_vision_transformers import TRANSFORM_REGISTRY, _save_recipe
20
+ from .custom_logger import custom_logger
21
+
22
+
23
+ __all__ = [
24
+ "DragonDatasetVision",
25
+ "DragonDatasetSegmentation",
26
+ "DragonDatasetObjectDetection"
27
+ ]
28
+
29
+
30
+ # --- Vision Maker ---
31
+ class DragonDatasetVision:
32
+ """
33
+ Creates processed PyTorch datasets for computer vision tasks from an
34
+ image folder directory.
35
+
36
+ Supports two modes:
37
+ 1. `from_folder()`: Loads from one directory and splits into train/val/test.
38
+ 2. `from_folders()`: Loads from pre-split train/val/test directories.
39
+
40
+ Uses online augmentations per epoch (image augmentation without creating new files).
41
+ """
42
+ def __init__(self):
43
+ """
44
+ Typically not called directly. Use the class methods `from_folder()` or `from_folders()` to create an instance.
45
+ """
46
+ self._train_dataset = None
47
+ self._test_dataset = None
48
+ self._val_dataset = None
49
+ self._full_dataset: Optional[ImageFolder] = None
50
+ self.labels: Optional[List[int]] = None
51
+ self.class_map: dict[str,int] = dict()
52
+
53
+ self._is_split = False
54
+ self._are_transforms_configured = False
55
+ self._val_recipe_components = None
56
+ self._has_mean_std: bool = False
57
+
58
+ @classmethod
59
+ def from_folder(cls, root_dir: Union[str,Path]) -> 'DragonDatasetVision':
60
+ """
61
+ Creates a maker instance from a single root directory of images.
62
+
63
+ This method assumes a single directory (e.g., 'data/') that
64
+ contains class subfolders (e.g., 'data/cat/', 'data/dog/').
65
+
66
+ The dataset will be loaded in its entirety, and you MUST call
67
+ `.split_data()` afterward to create train/validation/test sets.
68
+
69
+ Args:
70
+ root_dir (str | Path): The path to the root directory containing class subfolders.
71
+
72
+ Returns:
73
+ Instance: A new instance with the full dataset loaded.
74
+ """
75
+ root_path = make_fullpath(root_dir, enforce="directory")
76
+ # Load with NO transform. We get PIL Images.
77
+ full_dataset = ImageFolder(root=root_path, transform=None)
78
+ _LOGGER.info(f"Found {len(full_dataset)} images in {len(full_dataset.classes)} classes.")
79
+
80
+ maker = cls()
81
+ maker._full_dataset = full_dataset
82
+ maker.labels = [s[1] for s in full_dataset.samples]
83
+ maker.class_map = full_dataset.class_to_idx
84
+ return maker
85
+
86
+ @classmethod
87
+ def from_folders(cls,
88
+ train_dir: Union[str,Path],
89
+ val_dir: Union[str,Path],
90
+ test_dir: Optional[Union[str,Path]] = None) -> 'DragonDatasetVision':
91
+ """
92
+ Creates a maker instance from separate, pre-split directories.
93
+
94
+ This method is used when you already have 'train', 'val', and
95
+ optionally 'test' folders, each containing class subfolders.
96
+ It bypasses the need for `.split_data()`.
97
+
98
+ Args:
99
+ train_dir (str | Path): Path to the training data directory.
100
+ val_dir (str | Path): Path to the validation data directory.
101
+ test_dir (str | Path | None): Path to the test data directory.
102
+
103
+ Returns:
104
+ Instance: A new, pre-split instance.
105
+
106
+ Raises:
107
+ ValueError: If the classes found in train, val, or test directories are inconsistent.
108
+ """
109
+ train_path = make_fullpath(train_dir, enforce="directory")
110
+ val_path = make_fullpath(val_dir, enforce="directory")
111
+
112
+ _LOGGER.info("Loading data from separate directories.")
113
+ # Load with NO transform. We get PIL Images.
114
+ train_ds = ImageFolder(root=train_path, transform=None)
115
+ val_ds = ImageFolder(root=val_path, transform=None)
116
+
117
+ # Check for class consistency
118
+ if train_ds.class_to_idx != val_ds.class_to_idx:
119
+ _LOGGER.error("Train and validation directories have different or inconsistent classes.")
120
+ raise ValueError()
121
+
122
+ maker = cls()
123
+ maker._train_dataset = train_ds
124
+ maker._val_dataset = val_ds
125
+ maker.class_map = train_ds.class_to_idx
126
+
127
+ if test_dir:
128
+ test_path = make_fullpath(test_dir, enforce="directory")
129
+ test_ds = ImageFolder(root=test_path, transform=None)
130
+ if train_ds.class_to_idx != test_ds.class_to_idx:
131
+ _LOGGER.error("Train and test directories have different or inconsistent classes.")
132
+ raise ValueError()
133
+ maker._test_dataset = test_ds
134
+ _LOGGER.info(f"Loaded: {len(train_ds)} train, {len(val_ds)} val, {len(test_ds)} test images.")
135
+ else:
136
+ _LOGGER.info(f"Loaded: {len(train_ds)} train, {len(val_ds)} val images.")
137
+
138
+ maker._is_split = True # Mark as "split" since data is pre-split
139
+ return maker
140
+
141
+ @staticmethod
142
+ def inspect_folder(path: Union[str, Path]):
143
+ """
144
+ Logs a report of the types, sizes, and channels of image files
145
+ found in the directory and its subdirectories.
146
+
147
+ This is a utility method to help diagnose potential dataset
148
+ issues (e.g., mixed image modes, corrupted files) before loading.
149
+
150
+ Args:
151
+ path (str, Path): The directory path to inspect.
152
+ """
153
+ path_obj = make_fullpath(path)
154
+
155
+ non_image_files = set()
156
+ img_types = set()
157
+ img_sizes = set()
158
+ img_channels = set()
159
+ img_counter = 0
160
+
161
+ _LOGGER.info(f"Inspecting folder: {path_obj}...")
162
+ # Use rglob to recursively find all files
163
+ for filepath in path_obj.rglob('*'):
164
+ if filepath.is_file():
165
+ try:
166
+ # Using PIL to open is a more reliable check
167
+ with Image.open(filepath) as img:
168
+ img_types.add(img.format)
169
+ img_sizes.add(img.size)
170
+ img_channels.update(img.getbands())
171
+ img_counter += 1
172
+ except (IOError, SyntaxError):
173
+ non_image_files.add(filepath.name)
174
+
175
+ if non_image_files:
176
+ _LOGGER.warning(f"Non-image or corrupted files found and ignored: {non_image_files}")
177
+
178
+ report = (
179
+ f"\n--- Inspection Report for '{path_obj.name}' ---\n"
180
+ f"Total images found: {img_counter}\n"
181
+ f"Image formats: {img_types or 'None'}\n"
182
+ f"Image sizes (WxH): {img_sizes or 'None'}\n"
183
+ f"Image channels (bands): {img_channels or 'None'}\n"
184
+ f"--------------------------------------"
185
+ )
186
+ print(report)
187
+
188
+ def split_data(self, val_size: float = 0.2, test_size: float = 0.0,
189
+ stratify: bool = True, random_state: Optional[int] = None) -> 'DragonDatasetVision':
190
+ """
191
+ Splits the dataset into train, validation, and optional test sets.
192
+
193
+ This method MUST be called if you used `from_folder()`. It has no effect if you used `from_folders()`.
194
+
195
+ Args:
196
+ val_size (float): Proportion of the dataset to reserve for
197
+ validation (e.g., 0.2 for 20%).
198
+ test_size (float): Proportion of the dataset to reserve for
199
+ testing.
200
+ stratify (bool): If True, splits are performed in a stratified
201
+ fashion, preserving the class distribution.
202
+ random_state (int | None): Seed for the random number generator for reproducible splits.
203
+
204
+ Returns:
205
+ Self: The same instance, now with datasets split.
206
+
207
+ Raises:
208
+ ValueError: If `val_size` and `test_size` sum to 1.0 or more.
209
+ """
210
+ if self._is_split:
211
+ _LOGGER.warning("Data has already been split.")
212
+ return self
213
+
214
+ if val_size + test_size >= 1.0:
215
+ _LOGGER.error("The sum of val_size and test_size must be less than 1.")
216
+ raise ValueError()
217
+
218
+ if not self._full_dataset:
219
+ _LOGGER.error("There is no dataset to split.")
220
+ raise ValueError()
221
+
222
+ indices = list(range(len(self._full_dataset)))
223
+ labels_for_split = self.labels if stratify else None
224
+
225
+ train_indices, val_test_indices = train_test_split(
226
+ indices, test_size=(val_size + test_size), random_state=random_state, stratify=labels_for_split
227
+ )
228
+
229
+ if not self.labels:
230
+ _LOGGER.error("Error when getting full dataset labels.")
231
+ raise ValueError()
232
+
233
+ if test_size > 0:
234
+ val_test_labels = [self.labels[i] for i in val_test_indices]
235
+ stratify_val_test = val_test_labels if stratify else None
236
+ val_indices, test_indices = train_test_split(
237
+ val_test_indices, test_size=(test_size / (val_size + test_size)),
238
+ random_state=random_state, stratify=stratify_val_test
239
+ )
240
+ self._test_dataset = Subset(self._full_dataset, test_indices)
241
+ _LOGGER.info(f"Test set created with {len(self._test_dataset)} images.")
242
+ else:
243
+ val_indices = val_test_indices
244
+
245
+ self._train_dataset = Subset(self._full_dataset, train_indices)
246
+ self._val_dataset = Subset(self._full_dataset, val_indices)
247
+ self._is_split = True
248
+
249
+ _LOGGER.info(f"Data split into: \n- Training: {len(self._train_dataset)} images \n- Validation: {len(self._val_dataset)} images")
250
+ return self
251
+
252
+ def configure_transforms(self,
253
+ resize_size: int = 256,
254
+ crop_size: int = 224,
255
+ mean: Optional[List[float]] = [0.485, 0.456, 0.406],
256
+ std: Optional[List[float]] = [0.229, 0.224, 0.225],
257
+ pre_transforms: Optional[List[Callable]] = None,
258
+ extra_train_transforms: Optional[List[Callable]] = None) -> 'DragonDatasetVision':
259
+ """
260
+ Configures and applies the image transformations and augmentations.
261
+
262
+ This method must be called AFTER data is loaded and split.
263
+
264
+ It sets up two pipelines:
265
+ 1. **Training Pipeline:** Includes random augmentations:
266
+ `RandomResizedCrop(crop_size)`, `RandomHorizontalFlip(0.5)`, and `RandomRotation(90)` (plus any
267
+ `extra_train_transforms`) for online augmentation.
268
+ 2. **Validation/Test Pipeline:** A deterministic pipeline using `Resize` and `CenterCrop` for consistent evaluation.
269
+
270
+ Both pipelines finish with `ToTensor` and `Normalize`.
271
+
272
+ Args:
273
+ resize_size (int): The size to resize the smallest edge to
274
+ for validation/testing.
275
+ crop_size (int): The target size (square) for the final
276
+ cropped image.
277
+ mean (List[float] | None): The mean values for normalization (e.g., ImageNet mean).
278
+ std (List[float] | None): The standard deviation values for normalization (e.g., ImageNet std).
279
+ extra_train_transforms (List[Callable] | None): A list of additional torchvision transforms appended to the training augmentations, before the final ToTensor/Normalize steps.
280
+ pre_transforms (List[Callable] | None): A list of transforms applied at the very beginning of the pipeline for all splits (train, validation, and test).
281
+
282
+ Returns:
283
+ Self: The same instance, with transforms applied.
284
+
285
+ Raises:
286
+ RuntimeError: If called before data is split.
287
+ """
288
+ if not self._is_split:
289
+ _LOGGER.error("Transforms must be configured AFTER splitting data (or using `from_folders`). Call .split_data() first if using `from_folder`.")
290
+ raise RuntimeError()
291
+
292
+ if (mean is None and std is not None) or (mean is not None and std is None):
293
+ _LOGGER.error(f"'mean' and 'std' must be both None or both defined, but only one was provided.")
294
+ raise ValueError()
295
+
296
+ # --- Define Transform Pipelines ---
297
+ # These now MUST include ToTensor and Normalize, as the ImageFolder was loaded with transform=None.
298
+
299
+ # --- Store components for validation recipe ---
300
+ self._val_recipe_components = {
301
+ VisionTransformRecipeKeys.PRE_TRANSFORMS: pre_transforms or [],
302
+ VisionTransformRecipeKeys.RESIZE_SIZE: resize_size,
303
+ VisionTransformRecipeKeys.CROP_SIZE: crop_size,
304
+ }
305
+
306
+ if mean is not None and std is not None:
307
+ self._val_recipe_components.update({
308
+ VisionTransformRecipeKeys.MEAN: mean,
309
+ VisionTransformRecipeKeys.STD: std
310
+ })
311
+ self._has_mean_std = True
312
+
313
+ base_pipeline = []
314
+ if pre_transforms:
315
+ base_pipeline.extend(pre_transforms)
316
+
317
+ # Base augmentations for training
318
+ base_train_transforms = [
319
+ transforms.RandomResizedCrop(size=crop_size),
320
+ transforms.RandomHorizontalFlip(p=0.5),
321
+ transforms.RandomRotation(degrees=90)
322
+ ]
323
+ if extra_train_transforms:
324
+ base_train_transforms.extend(extra_train_transforms)
325
+
326
+ # Final conversion and normalization
327
+ final_transforms: list[Callable] = [
328
+ transforms.ToTensor()
329
+ ]
330
+
331
+ if self._has_mean_std:
332
+ final_transforms.append(transforms.Normalize(mean=mean, std=std))
333
+
334
+ # Build the val/test pipeline
335
+ val_transform_list = [
336
+ *base_pipeline, # Apply pre_transforms first
337
+ transforms.Resize(resize_size),
338
+ transforms.CenterCrop(crop_size),
339
+ *final_transforms
340
+ ]
341
+
342
+ # Build the train pipeline
343
+ train_transform_list = [
344
+ *base_pipeline, # Apply pre_transforms first
345
+ *base_train_transforms,
346
+ *final_transforms
347
+ ]
348
+
349
+ val_transform = transforms.Compose(val_transform_list)
350
+ train_transform = transforms.Compose(train_transform_list)
351
+
352
+ # --- Apply Transforms using the Wrapper ---
353
+ # This correctly assigns the transform regardless of whether the dataset is a Subset (from_folder) or an ImageFolder (from_folders).
354
+
355
+ self._train_dataset = _DatasetTransformer(self._train_dataset, train_transform, self.class_map) # type: ignore
356
+ self._val_dataset = _DatasetTransformer(self._val_dataset, val_transform, self.class_map) # type: ignore
357
+ if self._test_dataset:
358
+ self._test_dataset = _DatasetTransformer(self._test_dataset, val_transform, self.class_map) # type: ignore
359
+
360
+ self._are_transforms_configured = True
361
+ _LOGGER.info("Image transforms configured and applied.")
362
+ return self
363
+
364
+ def get_datasets(self) -> Tuple[Dataset, ...]:
365
+ """
366
+ Returns the final train, validation, and optional test datasets.
367
+
368
+ This is the final step, used to retrieve the datasets for use in
369
+ a `MLTrainer` or `DataLoader`.
370
+
371
+ Returns:
372
+ (Tuple[Dataset, ...]): A tuple containing the (train, val)
373
+ or (train, val, test) datasets.
374
+
375
+ Raises:
376
+ RuntimeError: If called before data is split.
377
+ UserWarning: If called before transforms are configured.
378
+ """
379
+ if not self._is_split:
380
+ _LOGGER.error("Data has not been split. Call .split_data() first.")
381
+ raise RuntimeError()
382
+ if not self._are_transforms_configured:
383
+ _LOGGER.warning("Transforms have not been configured.")
384
+
385
+ if self._test_dataset:
386
+ return self._train_dataset, self._val_dataset, self._test_dataset # type: ignore
387
+ return self._train_dataset, self._val_dataset # type: ignore
388
+
389
+ def save_transform_recipe(self, filepath: Union[str, Path]) -> None:
390
+ """
391
+ Saves the validation transform pipeline as a JSON recipe file.
392
+
393
+ This recipe can be loaded by the PyTorchVisionInferenceHandler
394
+ to ensure identical preprocessing.
395
+
396
+ Args:
397
+ filepath (str | Path): The path to save the .json recipe file.
398
+ """
399
+ if not self._are_transforms_configured:
400
+ _LOGGER.error("Transforms are not configured. Call .configure_transforms() first.")
401
+ raise RuntimeError()
402
+
403
+ recipe: Dict[str, Any] = {
404
+ VisionTransformRecipeKeys.TASK: "classification",
405
+ VisionTransformRecipeKeys.PIPELINE: []
406
+ }
407
+
408
+ components = self._val_recipe_components
409
+
410
+ if not components:
411
+ _LOGGER.error(f"Error getting the transformers recipe for validation set.")
412
+ raise ValueError()
413
+
414
+ # validate path
415
+ file_path = make_fullpath(filepath, make=True, enforce="file")
416
+
417
+ # Handle pre_transforms
418
+ for t in components[VisionTransformRecipeKeys.PRE_TRANSFORMS]:
419
+ t_name = t.__class__.__name__
420
+ t_class = t.__class__
421
+ kwargs = {}
422
+
423
+ # 1. Check custom registry first
424
+ if t_name in TRANSFORM_REGISTRY:
425
+ _LOGGER.debug(f"Found '{t_name}' in TRANSFORM_REGISTRY.")
426
+ kwargs = getattr(t, VisionTransformRecipeKeys.KWARGS, {})
427
+
428
+ # 2. Else, try to introspect for standard torchvision transforms
429
+ else:
430
+ _LOGGER.debug(f"'{t_name}' not in registry. Attempting introspection...")
431
+ try:
432
+ # Get the __init__ signature of the transform's class
433
+ sig = inspect.signature(t_class.__init__)
434
+
435
+ # Iterate over its __init__ parameters (e.g., 'num_output_channels')
436
+ for param in sig.parameters.values():
437
+ if param.name == 'self':
438
+ continue
439
+
440
+ # Check if the *instance* 't' has that parameter as an attribute
441
+ attr_name_public = param.name
442
+ attr_name_private = '_' + param.name
443
+
444
+ attr_to_get = ""
445
+
446
+ if hasattr(t, attr_name_public):
447
+ attr_to_get = attr_name_public
448
+ elif hasattr(t, attr_name_private):
449
+ attr_to_get = attr_name_private
450
+ else:
451
+ # Parameter in __init__ has no matching attribute
452
+ continue
453
+
454
+ # Store the value under the __init__ parameter's name
455
+ kwargs[param.name] = getattr(t, attr_to_get)
456
+
457
+ _LOGGER.debug(f"Introspection for '{t_name}' found kwargs: {kwargs}")
458
+
459
+ except (ValueError, TypeError):
460
+ # Fails on some built-ins or C-implemented __init__
461
+ _LOGGER.warning(f"Could not introspect parameters for '{t_name}'. If this transform has parameters, they will not be saved.")
462
+ kwargs = {}
463
+
464
+ # 3. Add to pipeline
465
+ recipe[VisionTransformRecipeKeys.PIPELINE].append({
466
+ VisionTransformRecipeKeys.NAME: t_name,
467
+ VisionTransformRecipeKeys.KWARGS: kwargs
468
+ })
469
+
470
+ # 2. Add standard transforms
471
+ recipe[VisionTransformRecipeKeys.PIPELINE].extend([
472
+ {VisionTransformRecipeKeys.NAME: "Resize", "kwargs": {"size": components[VisionTransformRecipeKeys.RESIZE_SIZE]}},
473
+ {VisionTransformRecipeKeys.NAME: "CenterCrop", "kwargs": {"size": components[VisionTransformRecipeKeys.CROP_SIZE]}},
474
+ {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}}
475
+ ])
476
+
477
+ if self._has_mean_std:
478
+ recipe[VisionTransformRecipeKeys.PIPELINE].append(
479
+ {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
480
+ "mean": components[VisionTransformRecipeKeys.MEAN],
481
+ "std": components[VisionTransformRecipeKeys.STD]
482
+ }}
483
+ )
484
+
485
+ # 3. Save the file
486
+ _save_recipe(recipe, file_path)
487
+
488
+ def save_class_map(self, save_dir: Union[str,Path]) -> dict[str,int]:
489
+ """
490
+ Saves the class to index mapping {str: int} to a directory.
491
+ """
492
+ if not self.class_map:
493
+ _LOGGER.error(f"Class to index mapping is empty.")
494
+ raise ValueError()
495
+
496
+ custom_logger(data=self.class_map,
497
+ save_directory=save_dir,
498
+ log_name="Class_to_Index",
499
+ add_timestamp=False,
500
+ dict_as="json")
501
+
502
+ return self.class_map
503
+
504
+ def images_per_dataset(self) -> str:
505
+ """
506
+ Get the number of images per dataset as a string.
507
+ """
508
+ if self._is_split:
509
+ train_len = len(self._train_dataset) if self._train_dataset else 0
510
+ val_len = len(self._val_dataset) if self._val_dataset else 0
511
+ test_len = len(self._test_dataset) if self._test_dataset else 0
512
+ return f"Train | Validation | Test: {train_len} | {val_len} | {test_len} images"
513
+ elif self._full_dataset:
514
+ return f"Full Dataset: {len(self._full_dataset)} images"
515
+ else:
516
+ _LOGGER.warning("No datasets found.")
517
+ return "No datasets found"
518
+
519
+ def __repr__(self) -> str:
520
+ s = f"<{self.__class__.__name__}>:\n"
521
+ s += f" Split: {self._is_split}\n"
522
+ s += f" Transforms Configured: {self._are_transforms_configured}\n"
523
+
524
+ if self.class_map:
525
+ s += f" Classes: {len(self.class_map)}\n"
526
+
527
+ if self._is_split:
528
+ train_len = len(self._train_dataset) if self._train_dataset else 0
529
+ val_len = len(self._val_dataset) if self._val_dataset else 0
530
+ test_len = len(self._test_dataset) if self._test_dataset else 0
531
+ s += f" Datasets (Train|Val|Test): {train_len} | {val_len} | {test_len}\n"
532
+ elif self._full_dataset:
533
+ s += f" Full Dataset Size: {len(self._full_dataset)} images\n"
534
+
535
+ return s
536
+
537
+
538
+ class _DatasetTransformer(Dataset):
539
+ """
540
+ Internal wrapper class to apply a specific transform pipeline to any
541
+ dataset (e.g., a full ImageFolder or a Subset).
542
+ """
543
+ def __init__(self, dataset: Dataset, transform: Optional[transforms.Compose] = None, class_map: dict[str,int]=dict()):
544
+ self.dataset = dataset
545
+ self.transform = transform
546
+ self.class_map = class_map
547
+
548
+ # --- Propagate attributes for inspection ---
549
+ # For ImageFolder
550
+ if hasattr(dataset, 'class_to_idx'):
551
+ self.class_to_idx = getattr(dataset, 'class_to_idx')
552
+ if hasattr(dataset, 'classes'):
553
+ self.classes = getattr(dataset, 'classes')
554
+ # For Subset
555
+ if hasattr(dataset, 'indices'):
556
+ self.indices = getattr(dataset, 'indices')
557
+ if hasattr(dataset, 'dataset'):
558
+ # This allows access to the *original* full dataset
559
+ self.original_dataset = getattr(dataset, 'dataset')
560
+
561
+ def __getitem__(self, index):
562
+ # Get the original data (e.g., PIL Image, label)
563
+ x, y = self.dataset[index]
564
+
565
+ # Apply the specific transform for this dataset
566
+ if self.transform:
567
+ x = self.transform(x)
568
+ return x, y
569
+
570
+ def __len__(self):
571
+ return len(self.dataset) # type: ignore
572
+
573
+
574
+ # --- Segmentation dataset ----
575
+ class _SegmentationDataset(Dataset):
576
+ """
577
+ Internal helper class to load image-mask pairs.
578
+
579
+ Loads images as RGB and masks as 'L' (grayscale, 8-bit integer pixels).
580
+ """
581
+ def __init__(self, image_paths: List[Path], mask_paths: List[Path], transform: Optional[Callable] = None):
582
+ self.image_paths = image_paths
583
+ self.mask_paths = mask_paths
584
+ self.transform = transform
585
+
586
+ # --- Propagate 'classes' if they exist for trainer ---
587
+ self.classes: List[str] = []
588
+
589
+ def __len__(self):
590
+ return len(self.image_paths)
591
+
592
+ def __getitem__(self, idx):
593
+ img_path = self.image_paths[idx]
594
+ mask_path = self.mask_paths[idx]
595
+
596
+ try:
597
+ # Open as PIL Images. Masks should be 'L'
598
+ image = Image.open(img_path).convert("RGB")
599
+ mask = Image.open(mask_path).convert("L")
600
+ except Exception as e:
601
+ _LOGGER.error(f"Error loading sample #{idx}: {img_path.name} / {mask_path.name}. Error: {e}")
602
+ # Return empty tensors
603
+ return torch.empty(3, 224, 224), torch.empty(224, 224, dtype=torch.long)
604
+
605
+ if self.transform:
606
+ image, mask = self.transform(image, mask)
607
+
608
+ return image, mask
609
+
610
+
611
+ # Internal Paired Transform Helpers
612
+ class _PairedCompose:
613
+ """A 'Compose' for paired image/mask transforms."""
614
+ def __init__(self, transforms: List[Callable]):
615
+ self.transforms = transforms
616
+
617
+ def __call__(self, image: Any, mask: Any) -> Tuple[Any, Any]:
618
+ for t in self.transforms:
619
+ image, mask = t(image, mask)
620
+ return image, mask
621
+
622
+ class _PairedToTensor:
623
+ """Converts a PIL Image pair (image, mask) to Tensors."""
624
+ def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[torch.Tensor, torch.Tensor]:
625
+ # Use new variable names to satisfy the linter
626
+ image_tensor = TF.to_tensor(image)
627
+ # Convert mask to LongTensor, not float.
628
+ # This creates a [H, W] tensor of integer class IDs.
629
+ mask_tensor = torch.from_numpy(numpy.array(mask, dtype=numpy.int64))
630
+ return image_tensor, mask_tensor
631
+
632
+ class _PairedNormalize:
633
+ """Normalizes the image tensor and leaves the mask untouched."""
634
+ def __init__(self, mean: List[float], std: List[float]):
635
+ self.normalize = transforms.Normalize(mean, std)
636
+
637
+ def __call__(self, image: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
638
+ image = self.normalize(image)
639
+ return image, mask
640
+
641
+ class _PairedResize:
642
+ """Resizes an image and mask to the same size."""
643
+ def __init__(self, size: int):
644
+ self.size = [size, size]
645
+
646
+ def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
647
+ # Use new variable names to avoid linter confusion
648
+ resized_image = TF.resize(image, self.size, interpolation=TF.InterpolationMode.BILINEAR) # type: ignore
649
+ # Use NEAREST for mask to avoid interpolating class IDs (e.g., 1.5)
650
+ resized_mask = TF.resize(mask, self.size, interpolation=TF.InterpolationMode.NEAREST) # type: ignore
651
+ return resized_image, resized_mask # type: ignore
652
+
653
+ class _PairedCenterCrop:
654
+ """Center-crops an image and mask to the same size."""
655
+ def __init__(self, size: int):
656
+ self.size = [size, size]
657
+
658
+ def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
659
+ cropped_image = TF.center_crop(image, self.size) # type: ignore
660
+ cropped_mask = TF.center_crop(mask, self.size) # type: ignore
661
+ return cropped_image, cropped_mask # type: ignore
662
+
663
+ class _PairedRandomHorizontalFlip:
664
+ """Applies the same random horizontal flip to both image and mask."""
665
+ def __init__(self, p: float = 0.5):
666
+ self.p = p
667
+
668
+ def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
669
+ if random.random() < self.p:
670
+ flipped_image = TF.hflip(image) # type: ignore
671
+ flipped_mask = TF.hflip(mask) # type: ignore
672
+ return flipped_image, flipped_mask # type: ignore
+ # No flip applied: return the original image and mask unchanged
+ return image, mask
673
+
674
+ class _PairedRandomResizedCrop:
675
+ """Applies the same random resized crop to both image and mask."""
676
+ def __init__(self, size: int, scale: Tuple[float, float]=(0.08, 1.0), ratio: Tuple[float, float]=(3./4., 4./3.)):
677
+ self.size = [size, size]
678
+ self.scale = scale
679
+ self.ratio = ratio
680
+ self.interpolation = TF.InterpolationMode.BILINEAR
681
+ self.mask_interpolation = TF.InterpolationMode.NEAREST
682
+
683
+ def __call__(self, image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
684
+ # Get parameters for the random crop
685
+ # Convert scale/ratio tuples to lists to satisfy the linter's type hint
686
+ i, j, h, w = transforms.RandomResizedCrop.get_params(image, list(self.scale), list(self.ratio)) # type: ignore
687
+
688
+ # Apply the crop with the SAME parameters and use new variable names
689
+ cropped_image = TF.resized_crop(image, i, j, h, w, self.size, self.interpolation) # type: ignore
690
+ cropped_mask = TF.resized_crop(mask, i, j, h, w, self.size, self.mask_interpolation) # type: ignore
691
+
692
+ return cropped_image, cropped_mask # type: ignore
693
+
694
+ # --- Segmentation Dataset ---
695
+ class DragonDatasetSegmentation:
696
+ """
697
+ Creates processed PyTorch datasets for segmentation from image and mask folders.
698
+
699
+ This maker finds all matching image-mask pairs from two directories,
700
+ splits them, and applies identical transformations (including augmentations)
701
+ to both the image and its corresponding mask.
702
+
703
+ Workflow:
704
+ 1. `maker = DragonDatasetSegmentation.from_folders(img_dir, mask_dir)`
705
+ 2. `maker.set_class_map({'background': 0, 'road': 1})`
706
+ 3. `maker.split_data(val_size=0.2)`
707
+ 4. `maker.configure_transforms(crop_size=256)`
708
+ 5. `train_ds, val_ds = maker.get_datasets()`
709
+ """
710
+ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
711
+
712
+ def __init__(self):
713
+ """
714
+ Typically not called directly. Use the class method `from_folders()` to create an instance.
715
+ """
716
+ self._train_dataset = None
717
+ self._test_dataset = None
718
+ self._val_dataset = None
719
+ self.image_paths: List[Path] = []
720
+ self.mask_paths: List[Path] = []
721
+ self.class_map: Dict[str, int] = {}
722
+
723
+ self._is_split = False
724
+ self._are_transforms_configured = False
725
+ self.train_transform: Optional[Callable] = None
726
+ self.val_transform: Optional[Callable] = None
727
+ self._has_mean_std: bool = False
728
+
729
+ @classmethod
730
+ def from_folders(cls, image_dir: Union[str, Path], mask_dir: Union[str, Path]) -> 'DragonDatasetSegmentation':
731
+ """
732
+ Creates a maker instance by loading all matching image-mask pairs
733
+ from two corresponding directories.
734
+
735
+ This method assumes that for an image `images/img_001.png`, there
736
+ is a corresponding mask `masks/img_001.png`.
737
+
738
+ Args:
739
+ image_dir (str | Path): Path to the directory containing input images.
740
+ mask_dir (str | Path): Path to the directory containing segmentation masks.
741
+
742
+ Returns:
743
+ DragonDatasetSegmentation: A new instance with all pairs loaded.
744
+ """
745
+ maker = cls()
746
+ img_path_obj = make_fullpath(image_dir, enforce="directory")
747
+ msk_path_obj = make_fullpath(mask_dir, enforce="directory")
748
+
749
+ # Find all images
750
+ image_files = sorted([
751
+ p for p in img_path_obj.glob('*.*')
752
+ if p.suffix.lower() in cls.IMG_EXTENSIONS
753
+ ])
754
+
755
+ if not image_files:
756
+ _LOGGER.error(f"No images with extensions {cls.IMG_EXTENSIONS} found in {image_dir}")
757
+ raise FileNotFoundError()
758
+
759
+ _LOGGER.info(f"Found {len(image_files)} images. Searching for matching masks in {mask_dir}...")
760
+
761
+ good_img_paths = []
762
+ good_mask_paths = []
763
+
764
+ for img_file in image_files:
765
+ mask_file = None
766
+
767
+ # 1. Try to find mask with the exact same name
768
+ mask_file_primary = msk_path_obj / img_file.name
769
+ if mask_file_primary.exists():
770
+ mask_file = mask_file_primary
771
+
772
+ # 2. If not, try to find mask with same stem + common mask extension
773
+ if mask_file is None:
774
+ for ext in cls.IMG_EXTENSIONS: # Masks are often .png
775
+ mask_file_secondary = msk_path_obj / (img_file.stem + ext)
776
+ if mask_file_secondary.exists():
777
+ mask_file = mask_file_secondary
778
+ break
779
+
780
+ # 3. If a match is found, add the pair
781
+ if mask_file:
782
+ good_img_paths.append(img_file)
783
+ good_mask_paths.append(mask_file)
784
+ else:
785
+ _LOGGER.warning(f"No corresponding mask found for image: {img_file.name}")
786
+
787
+ if not good_img_paths:
788
+ _LOGGER.error("No matching image-mask pairs were found.")
789
+ raise FileNotFoundError()
790
+
791
+ _LOGGER.info(f"Successfully found {len(good_img_paths)} image-mask pairs.")
792
+ maker.image_paths = good_img_paths
793
+ maker.mask_paths = good_mask_paths
794
+
795
+ return maker
796
+
797
+ @staticmethod
798
+ def inspect_folder(path: Union[str, Path]):
799
+ """
800
+ Logs a report of the types, sizes, and channels of image files
801
+ found in the directory. Useful for checking masks.
802
+ """
803
+ DragonDatasetVision.inspect_folder(path)
804
+
805
+ def set_class_map(self, class_map: Dict[str, int]) -> 'DragonDatasetSegmentation':
806
+ """
807
+ Sets a map of class_name -> pixel value. This is used by the Trainer for clear evaluation reports.
808
+
809
+ Args:
810
+ class_map (Dict[str, int]): A dictionary mapping each class name to its
811
+ integer pixel value in the mask.
812
+ Example: {'background': 0, 'road': 1, 'car': 2}
813
+ """
814
+ self.class_map = class_map
815
+ _LOGGER.info(f"Class map set: {class_map}")
816
+ return self
817
+
818
+ @property
819
+ def classes(self) -> List[str]:
820
+ """Returns the list of class names, if set."""
821
+ if self.class_map:
822
+ return list(self.class_map.keys())
823
+ return []
824
+
825
+ def split_data(self, val_size: float = 0.2, test_size: float = 0.0,
826
+ random_state: Optional[int] = 42) -> 'DragonDatasetSegmentation':
827
+ """
828
+ Splits the loaded image-mask pairs into train, validation, and test sets.
829
+
830
+ Args:
831
+ val_size (float): Proportion of the dataset to reserve for validation.
832
+ test_size (float): Proportion of the dataset to reserve for testing.
833
+ random_state (int | None): Seed for reproducible splits.
834
+
835
+ Returns:
836
+ DragonDatasetSegmentation: The same instance, now with datasets created.
837
+ """
838
+ if self._is_split:
839
+ _LOGGER.warning("Data has already been split.")
840
+ return self
841
+
842
+ if val_size + test_size >= 1.0:
843
+ _LOGGER.error("The sum of val_size and test_size must be less than 1.")
844
+ raise ValueError()
845
+
846
+ if not self.image_paths:
847
+ _LOGGER.error("There is no data to split. Use .from_folders() first.")
848
+ raise RuntimeError()
849
+
850
+ indices = list(range(len(self.image_paths)))
851
+
852
+ # Split indices
853
+ train_indices, val_test_indices = train_test_split(
854
+ indices, test_size=(val_size + test_size), random_state=random_state
855
+ )
856
+
857
+ # Helper to get paths from indices
858
+ def get_paths(idx_list):
859
+ return [self.image_paths[i] for i in idx_list], [self.mask_paths[i] for i in idx_list]
860
+
861
+ train_imgs, train_masks = get_paths(train_indices)
862
+
863
+ if test_size > 0:
864
+ val_indices, test_indices = train_test_split(
865
+ val_test_indices, test_size=(test_size / (val_size + test_size)),
866
+ random_state=random_state
867
+ )
868
+ val_imgs, val_masks = get_paths(val_indices)
869
+ test_imgs, test_masks = get_paths(test_indices)
870
+
871
+ self._test_dataset = _SegmentationDataset(test_imgs, test_masks, transform=None)
872
+ self._test_dataset.classes = self.classes # type: ignore
873
+ _LOGGER.info(f"Test set created with {len(self._test_dataset)} images.")
874
+ else:
875
+ val_imgs, val_masks = get_paths(val_test_indices)
876
+
877
+ self._train_dataset = _SegmentationDataset(train_imgs, train_masks, transform=None)
878
+ self._val_dataset = _SegmentationDataset(val_imgs, val_masks, transform=None)
879
+
880
+ # Propagate class names to datasets for trainer
881
+ self._train_dataset.classes = self.classes # type: ignore
882
+ self._val_dataset.classes = self.classes # type: ignore
883
+
884
+ self._is_split = True
885
+ _LOGGER.info(f"Data split into: \n- Training: {len(self._train_dataset)} images \n- Validation: {len(self._val_dataset)} images")
886
+ return self
887
+
888
+ def configure_transforms(self,
889
+ resize_size: int = 256,
890
+ crop_size: int = 224,
891
+ mean: Optional[List[float]] = [0.485, 0.456, 0.406],
892
+ std: Optional[List[float]] = [0.229, 0.224, 0.225]) -> 'DragonDatasetSegmentation':
893
+ """
894
+ Configures and applies the image and mask transformations.
895
+
896
+ This method must be called AFTER data is split.
897
+
898
+ Args:
899
+ resize_size (int): The size to resize the smallest edge to
900
+ for validation/testing.
901
+ crop_size (int): The target size (square) for the final
902
+ cropped image.
903
+ mean (List[float] | None): The mean values for image normalization.
904
+ std (List[float] | None): The std dev values for image normalization.
905
+
906
+ Returns:
907
+ DragonDatasetSegmentation: The same instance, with transforms applied.
908
+ """
909
+ if not self._is_split:
910
+ _LOGGER.error("Transforms must be configured AFTER splitting data. Call .split_data() first.")
911
+ raise RuntimeError()
912
+
913
+ if (mean is None and std is not None) or (mean is not None and std is None):
914
+ _LOGGER.error(f"'mean' and 'std' must be both None or both defined, but only one was provided.")
915
+ raise ValueError()
916
+
917
+ # --- Store components for validation recipe ---
918
+ self.val_recipe_components: dict[str,Any] = {
919
+ VisionTransformRecipeKeys.RESIZE_SIZE: resize_size,
920
+ VisionTransformRecipeKeys.CROP_SIZE: crop_size,
921
+ }
922
+
923
+ if mean is not None and std is not None:
924
+ self.val_recipe_components.update({
925
+ VisionTransformRecipeKeys.MEAN: mean,
926
+ VisionTransformRecipeKeys.STD: std
927
+ })
928
+ self._has_mean_std = True
929
+
930
+ # --- Validation/Test Pipeline (Deterministic) ---
931
+ if self._has_mean_std:
932
+ self.val_transform = _PairedCompose([
933
+ _PairedResize(resize_size),
934
+ _PairedCenterCrop(crop_size),
935
+ _PairedToTensor(),
936
+ _PairedNormalize(mean, std) # type: ignore
937
+ ])
938
+ # --- Training Pipeline (Augmentation) ---
939
+ self.train_transform = _PairedCompose([
940
+ _PairedRandomResizedCrop(crop_size),
941
+ _PairedRandomHorizontalFlip(p=0.5),
942
+ _PairedToTensor(),
943
+ _PairedNormalize(mean, std) # type: ignore
944
+ ])
945
+ else:
946
+ self.val_transform = _PairedCompose([
947
+ _PairedResize(resize_size),
948
+ _PairedCenterCrop(crop_size),
949
+ _PairedToTensor()
950
+ ])
951
+ # --- Training Pipeline (Augmentation) ---
952
+ self.train_transform = _PairedCompose([
953
+ _PairedRandomResizedCrop(crop_size),
954
+ _PairedRandomHorizontalFlip(p=0.5),
955
+ _PairedToTensor()
956
+ ])
957
+
958
+ # --- Apply Transforms to the Datasets ---
959
+ self._train_dataset.transform = self.train_transform # type: ignore
960
+ self._val_dataset.transform = self.val_transform # type: ignore
961
+ if self._test_dataset:
962
+ self._test_dataset.transform = self.val_transform # type: ignore
963
+
964
+ self._are_transforms_configured = True
965
+ _LOGGER.info("Paired segmentation transforms configured and applied.")
966
+ return self
967
+
968
+ def get_datasets(self) -> Tuple[Dataset, ...]:
969
+ """
970
+ Returns the final train, validation, and optional test datasets.
971
+
972
+ Raises:
973
+ RuntimeError: If called before data is split.
974
+ RuntimeError: If called before transforms are configured.
975
+ """
976
+ if not self._is_split:
977
+ _LOGGER.error("Data has not been split. Call .split_data() first.")
978
+ raise RuntimeError()
979
+ if not self._are_transforms_configured:
980
+ _LOGGER.error("Transforms have not been configured. Call .configure_transforms() first.")
981
+ raise RuntimeError()
982
+
983
+ if self._test_dataset:
984
+ return self._train_dataset, self._val_dataset, self._test_dataset # type: ignore
985
+ return self._train_dataset, self._val_dataset # type: ignore
986
+
987
+ def save_transform_recipe(self, filepath: Union[str, Path]) -> None:
988
+ """
989
+ Saves the validation transform pipeline as a JSON recipe file.
990
+
991
+ This recipe can be loaded by the PyTorchVisionInferenceHandler
992
+ to ensure identical preprocessing.
993
+
994
+ Args:
995
+ filepath (str | Path): The path to save the .json recipe file.
996
+ """
997
+ if not self._are_transforms_configured:
998
+ _LOGGER.error("Transforms are not configured. Call .configure_transforms() first.")
999
+ raise RuntimeError()
1000
+
1001
+ components = self.val_recipe_components
1002
+
1003
+ if not components:
1004
+ _LOGGER.error(f"Error getting the transformers recipe for validation set.")
1005
+ raise ValueError()
1006
+
1007
+ # validate path
1008
+ file_path = make_fullpath(filepath, make=True, enforce="file")
1009
+
1010
+ # Add standard transforms
1011
+ recipe: Dict[str, Any] = {
1012
+ VisionTransformRecipeKeys.TASK: "segmentation",
1013
+ VisionTransformRecipeKeys.PIPELINE: [
1014
+ {VisionTransformRecipeKeys.NAME: "Resize", "kwargs": {"size": components[VisionTransformRecipeKeys.RESIZE_SIZE]}},
1015
+ {VisionTransformRecipeKeys.NAME: "CenterCrop", "kwargs": {"size": components[VisionTransformRecipeKeys.CROP_SIZE]}},
1016
+ {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}}
1017
+ ]
1018
+ }
1019
+
1020
+ if self._has_mean_std:
1021
+ recipe[VisionTransformRecipeKeys.PIPELINE].append(
1022
+ {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
1023
+ "mean": components[VisionTransformRecipeKeys.MEAN],
1024
+ "std": components[VisionTransformRecipeKeys.STD]
1025
+ }}
1026
+ )
1027
+
1028
+ # Save the file
1029
+ _save_recipe(recipe, file_path)
1030
+
1031
+ def images_per_dataset(self) -> str:
1032
+ """
1033
+ Get the number of images per dataset as a string.
1034
+ """
1035
+ if self._is_split:
1036
+ train_len = len(self._train_dataset) if self._train_dataset else 0
1037
+ val_len = len(self._val_dataset) if self._val_dataset else 0
1038
+ test_len = len(self._test_dataset) if self._test_dataset else 0
1039
+ return f"Train | Validation | Test: {train_len} | {val_len} | {test_len} images"
1040
+ else:
1041
+ _LOGGER.warning("No datasets found.")
1042
+ return "No datasets found"
1043
+
1044
+ def __repr__(self) -> str:
1045
+ s = f"<{self.__class__.__name__}>:\n"
1046
+ s += f" Total Image-Mask Pairs: {len(self.image_paths)}\n"
1047
+ s += f" Split: {self._is_split}\n"
1048
+ s += f" Transforms Configured: {self._are_transforms_configured}\n"
1049
+
1050
+ if self.class_map:
1051
+ s += f" Classes: {list(self.class_map.keys())}\n"
1052
+
1053
+ if self._is_split:
1054
+ train_len = len(self._train_dataset) if self._train_dataset else 0
1055
+ val_len = len(self._val_dataset) if self._val_dataset else 0
1056
+ test_len = len(self._test_dataset) if self._test_dataset else 0
1057
+ s += f" Datasets (Train|Val|Test): {train_len} | {val_len} | {test_len}\n"
1058
+
1059
+ return s
1060
+
1061
+
1062
+ # Object detection
1063
+ def _od_collate_fn(batch: List[Tuple[torch.Tensor, Dict[str, torch.Tensor]]]) -> Tuple[List[torch.Tensor], List[Dict[str, torch.Tensor]]]:
1064
+ """
1065
+ Custom collate function for object detection.
1066
+
1067
+ Takes a list of (image, target) tuples and zips them into two lists:
1068
+ (list_of_images, list_of_targets).
1069
+ This is required for models like Faster R-CNN, which accept a list
1070
+ of images of varying sizes.
1071
+ """
1072
+ return tuple(zip(*batch)) # type: ignore
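# Sketch of how this collate function is used (batch_size is illustrative): each
# batch is a pair of same-length sequences, one of image tensors (possibly of
# different sizes) and one of target dicts with "boxes" and "labels".
#
#   loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=maker.collate_fn)
#   images, targets = next(iter(loader))   # len(images) == len(targets) == 4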
1073
+
1074
+
1075
+ class _ObjectDetectionDataset(Dataset):
1076
+ """
1077
+ Internal helper class to load image-annotation pairs.
1078
+
1079
+ Loads an image as 'RGB' and parses its corresponding JSON annotation file
1080
+ to create the required target dictionary (boxes, labels).
1081
+ """
1082
+ def __init__(self, image_paths: List[Path], annotation_paths: List[Path], transform: Optional[Callable] = None):
1083
+ self.image_paths = image_paths
1084
+ self.annotation_paths = annotation_paths
1085
+ self.transform = transform
1086
+
1087
+ # --- Propagate 'classes' if they exist ---
1088
+ self.classes: List[str] = []
1089
+
1090
+ def __len__(self):
1091
+ return len(self.image_paths)
1092
+
1093
+ def __getitem__(self, idx):
1094
+ img_path = self.image_paths[idx]
1095
+ ann_path = self.annotation_paths[idx]
1096
+
1097
+ try:
1098
+ # Open image
1099
+ image = Image.open(img_path).convert("RGB")
1100
+
1101
+ # Load and parse annotation
1102
+ with open(ann_path, 'r') as f:
1103
+ ann_data = json.load(f)
1104
+
1105
+ # Get boxes and labels from JSON
1106
+ boxes = ann_data[ObjectDetectionKeys.BOXES] # Expected: [[x1, y1, x2, y2], ...]
1107
+ labels = ann_data[ObjectDetectionKeys.LABELS] # Expected: [1, 2, 1, ...]
1108
+
1109
+ # Convert to tensors
1110
+ target: Dict[str, Any] = {}
1111
+ target[ObjectDetectionKeys.BOXES] = torch.as_tensor(boxes, dtype=torch.float32)
1112
+ target[ObjectDetectionKeys.LABELS] = torch.as_tensor(labels, dtype=torch.int64)
1113
+
1114
+ except Exception as e:
1115
+ _LOGGER.error(f"Error loading sample #{idx}: {img_path.name} / {ann_path.name}. Error: {e}")
1116
+ # Return empty/dummy data
1117
+ return torch.empty(3, 224, 224), {ObjectDetectionKeys.BOXES: torch.empty((0, 4)), ObjectDetectionKeys.LABELS: torch.empty(0, dtype=torch.long)}
1118
+
1119
+ if self.transform:
1120
+ image, target = self.transform(image, target)
1121
+
1122
+ return image, target
1123
+
1124
+ # Internal Paired Transform Helpers for Object Detection
1125
+ class _OD_PairedCompose:
1126
+ """A 'Compose' for paired image/target_dict transforms."""
1127
+ def __init__(self, transforms: List[Callable]):
1128
+ self.transforms = transforms
1129
+
1130
+ def __call__(self, image: Any, target: Any) -> Tuple[Any, Any]:
1131
+ for t in self.transforms:
1132
+ image, target = t(image, target)
1133
+ return image, target
1134
+
1135
+ class _OD_PairedToTensor:
1136
+ """Converts a PIL Image to Tensor, passes targets dict through."""
1137
+ def __call__(self, image: Image.Image, target: Dict[str, Any]) -> Tuple[torch.Tensor, Dict[str, Any]]:
1138
+ return TF.to_tensor(image), target
1139
+
1140
+ class _OD_PairedNormalize:
1141
+ """Normalizes the image tensor and leaves the target dict untouched."""
1142
+ def __init__(self, mean: List[float], std: List[float]):
1143
+ self.normalize = transforms.Normalize(mean, std)
1144
+
1145
+ def __call__(self, image: torch.Tensor, target: Dict[str, Any]) -> Tuple[torch.Tensor, Dict[str, Any]]:
1146
+ image_normalized = self.normalize(image)
1147
+ return image_normalized, target
1148
+
1149
+ class _OD_PairedRandomHorizontalFlip:
1150
+ """Applies the same random horizontal flip to both image and targets['boxes']."""
1151
+ def __init__(self, p: float = 0.5):
1152
+ self.p = p
1153
+
1154
+ def __call__(self, image: Image.Image, target: Dict[str, Any]) -> Tuple[Image.Image, Dict[str, Any]]:
1155
+ if random.random() < self.p:
1156
+ w, h = image.size
1157
+ # Use new variable names to avoid linter confusion
1158
+ flipped_image = TF.hflip(image) # type: ignore
1159
+
1160
+ # Flip boxes
1161
+ boxes = target[ObjectDetectionKeys.BOXES].clone() # [N, 4]
1162
+
1163
+ # xmin' = w - xmax
1164
+ # xmax' = w - xmin
1165
+ boxes[:, 0] = w - target[ObjectDetectionKeys.BOXES][:, 2] # xmin'
1166
+ boxes[:, 2] = w - target[ObjectDetectionKeys.BOXES][:, 0] # xmax'
1167
+ target[ObjectDetectionKeys.BOXES] = boxes
1168
+
1169
+ return flipped_image, target # type: ignore
1170
+
1171
+ return image, target
1172
+
1173
+
1174
+ class DragonDatasetObjectDetection:
1175
+ """
1176
+ Creates processed PyTorch datasets for object detection from image
1177
+ and JSON annotation folders.
1178
+
1179
+ This maker finds all matching image-annotation pairs from two directories,
1180
+ splits them, and applies identical transformations (including augmentations)
1181
+ to both the image and its corresponding target dictionary.
1182
+
1183
+ The `DragonFastRCNN` model expects a list of images and a list of targets,
1184
+ so this class provides a `collate_fn` to be used with a DataLoader.
1185
+
1186
+ Workflow:
1187
+ 1. `maker = DragonDatasetObjectDetection.from_folders(img_dir, ann_dir)`
1188
+ 2. `maker.set_class_map({'background': 0, 'person': 1, 'car': 2})`
1189
+ 3. `maker.split_data(val_size=0.2)`
1190
+ 4. `maker.configure_transforms()`
1191
+ 5. `train_ds, val_ds = maker.get_datasets()`
1192
+ 6. `collate_fn = maker.collate_fn`
1193
+ 7. `train_loader = DataLoader(train_ds, ..., collate_fn=collate_fn)`
1194
+ """
1195
+ IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
1196
+
1197
+ def __init__(self):
1198
+ """
1199
+ Typically not called directly. Use the class method `from_folders()` to create an instance.
1200
+ """
1201
+ self._train_dataset = None
1202
+ self._test_dataset = None
1203
+ self._val_dataset = None
1204
+ self.image_paths: List[Path] = []
1205
+ self.annotation_paths: List[Path] = []
1206
+ self.class_map: Dict[str, int] = {}
1207
+
1208
+ self._is_split = False
1209
+ self._are_transforms_configured = False
1210
+ self.train_transform: Optional[Callable] = None
1211
+ self.val_transform: Optional[Callable] = None
1212
+ self._val_recipe_components: Optional[Dict[str, Any]] = None
1213
+ self._has_mean_std: bool = False
1214
+
1215
+ @classmethod
1216
+ def from_folders(cls, image_dir: Union[str, Path], annotation_dir: Union[str, Path]) -> 'DragonDatasetObjectDetection':
1217
+ """
1218
+ Creates a maker instance by loading all matching image-annotation pairs
1219
+ from two corresponding directories.
1220
+
1221
+ This method assumes that for an image `images/img_001.png`, there
1222
+ is a corresponding annotation `annotations/img_001.json`.
1223
+
1224
+ The JSON file must contain "boxes" and "labels" keys:
1225
+ `{"boxes": [[x1,y1,x2,y2], ...], "labels": [1, 2, ...]}`
1226
+
1227
+ Args:
1228
+ image_dir (str | Path): Path to the directory containing input images.
1229
+ annotation_dir (str | Path): Path to the directory containing .json
1230
+ annotation files.
1231
+
1232
+ Returns:
1233
+ DragonDatasetObjectDetection: A new instance with all pairs loaded.
1234
+ """
1235
+ maker = cls()
1236
+ img_path_obj = make_fullpath(image_dir, enforce="directory")
1237
+ ann_path_obj = make_fullpath(annotation_dir, enforce="directory")
1238
+
1239
+ # Find all images
1240
+ image_files = sorted([
1241
+ p for p in img_path_obj.glob('*.*')
1242
+ if p.suffix.lower() in cls.IMG_EXTENSIONS
1243
+ ])
1244
+
1245
+ if not image_files:
1246
+ _LOGGER.error(f"No images with extensions {cls.IMG_EXTENSIONS} found in {image_dir}")
1247
+ raise FileNotFoundError()
1248
+
1249
+ _LOGGER.info(f"Found {len(image_files)} images. Searching for matching .json annotations in {annotation_dir}...")
1250
+
1251
+ good_img_paths = []
1252
+ good_ann_paths = []
1253
+
1254
+ for img_file in image_files:
1255
+ # Find annotation with same stem + .json
1256
+ ann_file = ann_path_obj / (img_file.stem + ".json")
1257
+
1258
+ if ann_file.exists():
1259
+ good_img_paths.append(img_file)
1260
+ good_ann_paths.append(ann_file)
1261
+ else:
1262
+ _LOGGER.warning(f"No corresponding .json annotation found for image: {img_file.name}")
1263
+
1264
+ if not good_img_paths:
1265
+ _LOGGER.error("No matching image-annotation pairs were found.")
1266
+ raise FileNotFoundError()
1267
+
1268
+ _LOGGER.info(f"Successfully found {len(good_img_paths)} image-annotation pairs.")
1269
+ maker.image_paths = good_img_paths
1270
+ maker.annotation_paths = good_ann_paths
1271
+
1272
+ return maker
1273
+
1274
+ @staticmethod
1275
+ def inspect_folder(path: Union[str, Path]):
1276
+ """
1277
+ Logs a report of the types, sizes, and channels of image files
1278
+ found in the directory.
1279
+ """
1280
+ DragonDatasetVision.inspect_folder(path)
1281
+
1282
+ def set_class_map(self, class_map: Dict[str, int]) -> 'DragonDatasetObjectDetection':
1283
+ """
1284
+ Sets a map of class_name -> integer label. This is used by the
1285
+ trainer for clear evaluation reports.
1286
+
1287
+ **Important:** For object detection models, 'background' MUST
1288
+ be included as class 0.
1289
+ Example: `{'background': 0, 'person': 1, 'car': 2}`
1290
+
1291
+ Args:
1292
+ class_map (Dict[str, int]): A dictionary mapping the string name
1293
+ to its integer label.
1294
+ """
1295
+ if 'background' not in class_map or class_map['background'] != 0:
1296
+ _LOGGER.warning("Object detection class map should include 'background' mapped to 0.")
1297
+
1298
+ self.class_map = class_map
1299
+ _LOGGER.info(f"Class map set: {class_map}")
1300
+ return self
1301
+
1302
+ @property
1303
+ def classes(self) -> List[str]:
1304
+ """Returns the list of class names, if set."""
1305
+ if self.class_map:
1306
+ return list(self.class_map.keys())
1307
+ return []
1308
+
1309
+ def split_data(self, val_size: float = 0.2, test_size: float = 0.0,
1310
+ random_state: Optional[int] = 42) -> 'DragonDatasetObjectDetection':
1311
+ """
1312
+ Splits the loaded image-annotation pairs into train, validation, and test sets.
1313
+
1314
+ Args:
1315
+ val_size (float): Proportion of the dataset to reserve for validation.
1316
+ test_size (float): Proportion of the dataset to reserve for testing.
1317
+ random_state (int | None): Seed for reproducible splits.
1318
+
1319
+ Returns:
1320
+ DragonDatasetObjectDetection: The same instance, now with datasets created.
1321
+ """
1322
+ if self._is_split:
1323
+ _LOGGER.warning("Data has already been split.")
1324
+ return self
1325
+
1326
+ if val_size + test_size >= 1.0:
1327
+ _LOGGER.error("The sum of val_size and test_size must be less than 1.")
1328
+ raise ValueError()
1329
+
1330
+ if not self.image_paths:
1331
+ _LOGGER.error("There is no data to split. Use .from_folders() first.")
1332
+ raise RuntimeError()
1333
+
1334
+ indices = list(range(len(self.image_paths)))
1335
+
1336
+ # Split indices
1337
+ train_indices, val_test_indices = train_test_split(
1338
+ indices, test_size=(val_size + test_size), random_state=random_state
1339
+ )
1340
+
1341
+ # Helper to get paths from indices
1342
+ def get_paths(idx_list):
1343
+ return [self.image_paths[i] for i in idx_list], [self.annotation_paths[i] for i in idx_list]
1344
+
1345
+ train_imgs, train_anns = get_paths(train_indices)
1346
+
1347
+ if test_size > 0:
1348
+ val_indices, test_indices = train_test_split(
1349
+ val_test_indices, test_size=(test_size / (val_size + test_size)),
1350
+ random_state=random_state
1351
+ )
1352
+ val_imgs, val_anns = get_paths(val_indices)
1353
+ test_imgs, test_anns = get_paths(test_indices)
1354
+
1355
+ self._test_dataset = _ObjectDetectionDataset(test_imgs, test_anns, transform=None)
1356
+ self._test_dataset.classes = self.classes # type: ignore
1357
+ _LOGGER.info(f"Test set created with {len(self._test_dataset)} images.")
1358
+ else:
1359
+ val_imgs, val_anns = get_paths(val_test_indices)
1360
+
1361
+ self._train_dataset = _ObjectDetectionDataset(train_imgs, train_anns, transform=None)
1362
+ self._val_dataset = _ObjectDetectionDataset(val_imgs, val_anns, transform=None)
1363
+
1364
+ # Propagate class names to datasets for MLTrainer
1365
+ self._train_dataset.classes = self.classes # type: ignore
1366
+ self._val_dataset.classes = self.classes # type: ignore
1367
+
1368
+ self._is_split = True
1369
+ _LOGGER.info(f"Data split into: \n- Training: {len(self._train_dataset)} images \n- Validation: {len(self._val_dataset)} images")
1370
+ return self
1371
+
1372
+ def configure_transforms(self,
1373
+ mean: Optional[List[float]] = [0.485, 0.456, 0.406],
1374
+ std: Optional[List[float]] = [0.229, 0.224, 0.225]) -> 'DragonDatasetObjectDetection':
1375
+ """
1376
+ Configures and applies the image and target transformations.
1377
+
1378
+ This method must be called AFTER data is split.
1379
+
1380
+ For object detection models like Faster R-CNN, images are NOT
1381
+ resized or cropped, as the model handles variable input sizes.
1382
+ Transforms are limited to augmentation (flip), ToTensor, and Normalize.
1383
+
1384
+ Args:
1385
+ mean (List[float] | None): The mean values for image normalization.
1386
+ std (List[float] | None): The std dev values for image normalization.
1387
+
1388
+ Returns:
1389
+ DragonDatasetObjectDetection: The same instance, with transforms applied.
1390
+ """
1391
+ if not self._is_split:
1392
+ _LOGGER.error("Transforms must be configured AFTER splitting data. Call .split_data() first.")
1393
+ raise RuntimeError()
1394
+
1395
+ if (mean is None and std is not None) or (mean is not None and std is None):
1396
+ _LOGGER.error(f"'mean' and 'std' must be both None or both defined, but only one was provided.")
1397
+ raise ValueError()
1398
+
1399
+ if mean is not None and std is not None:
1400
+ # --- Store components for validation recipe ---
1401
+ self._val_recipe_components = {
1402
+ VisionTransformRecipeKeys.MEAN: mean,
1403
+ VisionTransformRecipeKeys.STD: std
1404
+ }
1405
+ self._has_mean_std = True
1406
+
1407
+ if self._has_mean_std:
1408
+ # --- Validation/Test Pipeline (Deterministic) ---
1409
+ self.val_transform = _OD_PairedCompose([
1410
+ _OD_PairedToTensor(),
1411
+ _OD_PairedNormalize(mean, std) # type: ignore
1412
+ ])
1413
+
1414
+ # --- Training Pipeline (Augmentation) ---
1415
+ self.train_transform = _OD_PairedCompose([
1416
+ _OD_PairedRandomHorizontalFlip(p=0.5),
1417
+ _OD_PairedToTensor(),
1418
+ _OD_PairedNormalize(mean, std) # type: ignore
1419
+ ])
1420
+ else:
1421
+ # --- Validation/Test Pipeline (Deterministic) ---
1422
+ self.val_transform = _OD_PairedCompose([
1423
+ _OD_PairedToTensor()
1424
+ ])
1425
+
1426
+ # --- Training Pipeline (Augmentation) ---
1427
+ self.train_transform = _OD_PairedCompose([
1428
+ _OD_PairedRandomHorizontalFlip(p=0.5),
1429
+ _OD_PairedToTensor()
1430
+ ])
1431
+
1432
+ # --- Apply Transforms to the Datasets ---
1433
+ self._train_dataset.transform = self.train_transform # type: ignore
1434
+ self._val_dataset.transform = self.val_transform # type: ignore
1435
+ if self._test_dataset:
1436
+ self._test_dataset.transform = self.val_transform # type: ignore
1437
+
1438
+ self._are_transforms_configured = True
1439
+ _LOGGER.info("Paired object detection transforms configured and applied.")
1440
+ return self
1441
+
1442
+ def get_datasets(self) -> Tuple[Dataset, ...]:
1443
+ """
1444
+ Returns the final train, validation, and optional test datasets.
1445
+
1446
+ Raises:
1447
+ RuntimeError: If called before data is split.
1448
+ RuntimeError: If called before transforms are configured.
1449
+ """
1450
+ if not self._is_split:
1451
+ _LOGGER.error("Data has not been split. Call .split_data() first.")
1452
+ raise RuntimeError()
1453
+ if not self._are_transforms_configured:
1454
+ _LOGGER.error("Transforms have not been configured. Call .configure_transforms() first.")
1455
+ raise RuntimeError()
1456
+
1457
+ if self._test_dataset:
1458
+ return self._train_dataset, self._val_dataset, self._test_dataset # type: ignore
1459
+ return self._train_dataset, self._val_dataset # type: ignore
1460
+
1461
+ @property
1462
+ def collate_fn(self) -> Callable:
1463
+ """
1464
+ Returns the collate function required by a DataLoader for this
1465
+ dataset. This function ensures that images and targets are
1466
+ batched as separate lists.
1467
+ """
1468
+ return _od_collate_fn
1469
+
1470
+ def save_transform_recipe(self, filepath: Union[str, Path]) -> None:
1471
+ """
1472
+ Saves the validation transform pipeline as a JSON recipe file.
1473
+
1474
+ For object detection, this recipe only includes ToTensor and
1475
+ Normalize, as resizing is handled by the model.
1476
+
1477
+ Args:
1478
+ filepath (str | Path): The path to save the .json recipe file.
1479
+ """
1480
+ if not self._are_transforms_configured:
1481
+ _LOGGER.error("Transforms are not configured. Call .configure_transforms() first.")
1482
+ raise RuntimeError()
1483
+
1484
+ components = self._val_recipe_components
1485
+
1486
+ # validate path
1487
+ file_path = make_fullpath(filepath, make=True, enforce="file")
1488
+
1489
+ # Add standard transforms
1490
+ recipe: Dict[str, Any] = {
1491
+ VisionTransformRecipeKeys.TASK: "object_detection",
1492
+ VisionTransformRecipeKeys.PIPELINE: [
1493
+ {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}},
1494
+ ]
1495
+ }
1496
+
1497
+ if self._has_mean_std and components:
1498
+ recipe[VisionTransformRecipeKeys.PIPELINE].append(
1499
+ {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
1500
+ "mean": components[VisionTransformRecipeKeys.MEAN],
1501
+ "std": components[VisionTransformRecipeKeys.STD]
1502
+ }}
1503
+ )
1504
+
1505
+ # Save the file
1506
+ _save_recipe(recipe, file_path)
1507
+
1508
+ def images_per_dataset(self) -> str:
1509
+ """
1510
+ Get the number of images per dataset as a string.
1511
+ """
1512
+ if self._is_split:
1513
+ train_len = len(self._train_dataset) if self._train_dataset else 0
1514
+ val_len = len(self._val_dataset) if self._val_dataset else 0
1515
+ test_len = len(self._test_dataset) if self._test_dataset else 0
1516
+ return f"Train | Validation | Test: {train_len} | {val_len} | {test_len} images"
1517
+ else:
1518
+ _LOGGER.warning("No datasets found.")
1519
+ return "No datasets found"
1520
+
1521
+ def __repr__(self) -> str:
1522
+ s = f"<{self.__class__.__name__}>:\n"
1523
+ s += f" Total Image-Annotation Pairs: {len(self.image_paths)}\n"
1524
+ s += f" Split: {self._is_split}\n"
1525
+ s += f" Transforms Configured: {self._are_transforms_configured}\n"
1526
+
1527
+ if self.class_map:
1528
+ s += f" Classes ({len(self.class_map)}): {list(self.class_map.keys())}\n"
1529
+
1530
+ if self._is_split:
1531
+ train_len = len(self._train_dataset) if self._train_dataset else 0
1532
+ val_len = len(self._val_dataset) if self._val_dataset else 0
1533
+ test_len = len(self._test_dataset) if self._test_dataset else 0
1534
+ s += f" Datasets (Train|Val|Test): {train_len} | {val_len} | {test_len}\n"
1535
+
1536
+ return s
1537
+
1538
+
1539
+ def info():
1540
+ _script_info(__all__)