dragon-ml-toolbox 14.3.1__py3-none-any.whl → 14.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dragon-ml-toolbox might be problematic.

@@ -273,8 +273,8 @@ class VisionDatasetMaker(_BaseMaker):
                 for validation/testing.
             crop_size (int): The target size (square) for the final
                 cropped image.
-            mean (List[float]): The mean values for normalization (e.g., ImageNet mean).
-            std (List[float]): The standard deviation values for normalization (e.g., ImageNet std).
+            mean (List[float] | None): The mean values for normalization (e.g., ImageNet mean).
+            std (List[float] | None): The standard deviation values for normalization (e.g., ImageNet std).
             extra_train_transforms (List[Callable] | None): A list of additional torchvision transforms to add to the end of the training transformations.
             pre_transforms (List[Callable] | None): An list of transforms to be applied at the very beginning of the transformations for all sets.
 
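Both arguments moving to `List[float] | None` means normalization can now be skipped entirely. A minimal usage sketch, assuming the signature matches the updated docstring and that the maker has already been split (construction and split arguments are omitted here and are not part of this hunk):

    # Hypothetical maker instance; pass both mean and std as None to skip the Normalize step.
    maker.configure_transforms(
        resize_size=256,
        crop_size=224,
        mean=None,
        std=None,
    )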
@@ -499,6 +499,39 @@ class VisionDatasetMaker(_BaseMaker):
 
         return self.class_map
 
+    def images_per_dataset(self) -> str:
+        """
+        Get the number of images per dataset as a string.
+        """
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            return f"Train | Validation | Test: {train_len} | {val_len} | {test_len} images\n"
+        elif self._full_dataset:
+            return f"Full Dataset: {len(self._full_dataset)} images\n"
+        else:
+            _LOGGER.warning("No datasets found.")
+            return "No datasets found\n"
+
+    def __repr__(self) -> str:
+        s = f"<{self.__class__.__name__}>:\n"
+        s += f" Split: {self._is_split}\n"
+        s += f" Transforms Configured: {self._are_transforms_configured}\n"
+
+        if self.class_map:
+            s += f" Classes: {len(self.class_map)}\n"
+
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            s += f" Datasets (Train|Val|Test): {train_len} | {val_len} | {test_len}\n"
+        elif self._full_dataset:
+            s += f" Full Dataset Size: {len(self._full_dataset)} images\n"
+
+        return s
+
 
 class _DatasetTransformer(Dataset):
     """
@@ -686,6 +719,7 @@ class SegmentationDatasetMaker(_BaseMaker):
         self._are_transforms_configured = False
         self.train_transform: Optional[Callable] = None
         self.val_transform: Optional[Callable] = None
+        self._has_mean_std: bool = False
 
     @classmethod
     def from_folders(cls, image_dir: Union[str, Path], mask_dir: Union[str, Path]) -> 'SegmentationDatasetMaker':
@@ -849,8 +883,8 @@ class SegmentationDatasetMaker(_BaseMaker):
     def configure_transforms(self,
                              resize_size: int = 256,
                              crop_size: int = 224,
-                             mean: List[float] = [0.485, 0.456, 0.406],
-                             std: List[float] = [0.229, 0.224, 0.225]) -> 'SegmentationDatasetMaker':
+                             mean: Optional[List[float]] = [0.485, 0.456, 0.406],
+                             std: Optional[List[float]] = [0.229, 0.224, 0.225]) -> 'SegmentationDatasetMaker':
         """
         Configures and applies the image and mask transformations.
 
@@ -861,8 +895,8 @@ class SegmentationDatasetMaker(_BaseMaker):
                 for validation/testing.
             crop_size (int): The target size (square) for the final
                 cropped image.
-            mean (List[float]): The mean values for image normalization.
-            std (List[float]): The std dev values for image normalization.
+            mean (List[float] | None): The mean values for image normalization.
+            std (List[float] | None): The std dev values for image normalization.
 
         Returns:
             SegmentationDatasetMaker: The same instance, with transforms applied.
@@ -871,29 +905,50 @@ class SegmentationDatasetMaker(_BaseMaker):
             _LOGGER.error("Transforms must be configured AFTER splitting data. Call .split_data() first.")
             raise RuntimeError()
 
+        if (mean is None and std is not None) or (mean is not None and std is None):
+            _LOGGER.error(f"'mean' and 'std' must be both None or both defined, but only one was provided.")
+            raise ValueError()
+
         # --- Store components for validation recipe ---
-        self.val_recipe_components = {
+        self.val_recipe_components: dict[str,Any] = {
             VisionTransformRecipeKeys.RESIZE_SIZE: resize_size,
             VisionTransformRecipeKeys.CROP_SIZE: crop_size,
-            VisionTransformRecipeKeys.MEAN: mean,
-            VisionTransformRecipeKeys.STD: std
         }
+
+        if mean is not None and std is not None:
+            self.val_recipe_components.update({
+                VisionTransformRecipeKeys.MEAN: mean,
+                VisionTransformRecipeKeys.STD: std
+            })
+            self._has_mean_std = True
 
         # --- Validation/Test Pipeline (Deterministic) ---
-        self.val_transform = _PairedCompose([
-            _PairedResize(resize_size),
-            _PairedCenterCrop(crop_size),
-            _PairedToTensor(),
-            _PairedNormalize(mean, std)
-        ])
-
-        # --- Training Pipeline (Augmentation) ---
-        self.train_transform = _PairedCompose([
-            _PairedRandomResizedCrop(crop_size),
-            _PairedRandomHorizontalFlip(p=0.5),
-            _PairedToTensor(),
-            _PairedNormalize(mean, std)
-        ])
+        if self._has_mean_std:
+            self.val_transform = _PairedCompose([
+                _PairedResize(resize_size),
+                _PairedCenterCrop(crop_size),
+                _PairedToTensor(),
+                _PairedNormalize(mean, std) # type: ignore
+            ])
+            # --- Training Pipeline (Augmentation) ---
+            self.train_transform = _PairedCompose([
+                _PairedRandomResizedCrop(crop_size),
+                _PairedRandomHorizontalFlip(p=0.5),
+                _PairedToTensor(),
+                _PairedNormalize(mean, std) # type: ignore
+            ])
+        else:
+            self.val_transform = _PairedCompose([
+                _PairedResize(resize_size),
+                _PairedCenterCrop(crop_size),
+                _PairedToTensor()
+            ])
+            # --- Training Pipeline (Augmentation) ---
+            self.train_transform = _PairedCompose([
+                _PairedRandomResizedCrop(crop_size),
+                _PairedRandomHorizontalFlip(p=0.5),
+                _PairedToTensor()
+            ])
 
         # --- Apply Transforms to the Datasets ---
         self._train_dataset.transform = self.train_transform # type: ignore
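The new guard makes the mean/std pairing explicit: either both statistics are provided (Normalize is added to both pipelines) or both are None (the pipelines end at ToTensor). A sketch of the contract, assuming a split SegmentationDatasetMaker instance:

    maker.configure_transforms(mean=None, std=None)             # OK: no _PairedNormalize step is added
    maker.configure_transforms(mean=[0.5, 0.5, 0.5], std=None)  # raises ValueError: set both or neither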
@@ -946,23 +1001,57 @@ class SegmentationDatasetMaker(_BaseMaker):
 
         # validate path
         file_path = make_fullpath(filepath, make=True, enforce="file")
-
+
         # Add standard transforms
         recipe: Dict[str, Any] = {
             VisionTransformRecipeKeys.TASK: "segmentation",
             VisionTransformRecipeKeys.PIPELINE: [
-                {VisionTransformRecipeKeys.NAME: "Resize", "kwargs": {"size": components["resize_size"]}},
-                {VisionTransformRecipeKeys.NAME: "CenterCrop", "kwargs": {"size": components["crop_size"]}},
-                {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}},
-                {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
-                    "mean": components["mean"],
-                    "std": components["std"]
-                }}
+                {VisionTransformRecipeKeys.NAME: "Resize", "kwargs": {"size": components[VisionTransformRecipeKeys.RESIZE_SIZE]}},
+                {VisionTransformRecipeKeys.NAME: "CenterCrop", "kwargs": {"size": components[VisionTransformRecipeKeys.CROP_SIZE]}},
+                {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}}
             ]
         }
 
+        if self._has_mean_std:
+            recipe[VisionTransformRecipeKeys.PIPELINE].append(
+                {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
+                    "mean": components[VisionTransformRecipeKeys.MEAN],
+                    "std": components[VisionTransformRecipeKeys.STD]
+                }}
+            )
+
         # Save the file
         save_recipe(recipe, file_path)
+
+    def images_per_dataset(self) -> str:
+        """
+        Get the number of images per dataset as a string.
+        """
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            return f"Train | Validation | Test: {train_len} | {val_len} | {test_len} images\n"
+        else:
+            _LOGGER.warning("No datasets found.")
+            return "No datasets found\n"
+
+    def __repr__(self) -> str:
+        s = f"<{self.__class__.__name__}>:\n"
+        s += f" Total Image-Mask Pairs: {len(self.image_paths)}\n"
+        s += f" Split: {self._is_split}\n"
+        s += f" Transforms Configured: {self._are_transforms_configured}\n"
+
+        if self.class_map:
+            s += f" Classes: {list(self.class_map.keys())}\n"
+
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            s += f" Datasets (Train|Val|Test): {train_len} | {val_len} | {test_len}\n"
+
+        return s
 
 
 # Object detection
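The saved validation recipe now only carries a Normalize entry when normalization was configured. An approximate sketch of the resulting dictionary; the literal strings behind the VisionTransformRecipeKeys constants are assumptions, only the structure is taken from this hunk:

    recipe_without_norm = {
        "task": "segmentation",
        "pipeline": [
            {"name": "Resize", "kwargs": {"size": 256}},
            {"name": "CenterCrop", "kwargs": {"size": 224}},
            {"name": "ToTensor", "kwargs": {}},
        ],
    }
    # When mean/std were configured, one more step is appended:
    # {"name": "Normalize", "kwargs": {"mean": [...], "std": [...]}}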
@@ -1114,6 +1203,7 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
         self.train_transform: Optional[Callable] = None
         self.val_transform: Optional[Callable] = None
         self._val_recipe_components: Optional[Dict[str, Any]] = None
+        self._has_mean_std: bool = False
 
     @classmethod
     def from_folders(cls, image_dir: Union[str, Path], annotation_dir: Union[str, Path]) -> 'ObjectDetectionDatasetMaker':
@@ -1273,8 +1363,8 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
         return self
 
     def configure_transforms(self,
-                             mean: List[float] = [0.485, 0.456, 0.406],
-                             std: List[float] = [0.229, 0.224, 0.225]) -> 'ObjectDetectionDatasetMaker':
+                             mean: Optional[List[float]] = [0.485, 0.456, 0.406],
+                             std: Optional[List[float]] = [0.229, 0.224, 0.225]) -> 'ObjectDetectionDatasetMaker':
         """
         Configures and applies the image and target transformations.
 
@@ -1285,8 +1375,8 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
         Transforms are limited to augmentation (flip), ToTensor, and Normalize.
         Args:
 
-            mean (List[float]): The mean values for image normalization.
-            std (List[float]): The std dev values for image normalization.
+            mean (List[float] | None): The mean values for image normalization.
+            std (List[float] | None): The std dev values for image normalization.
 
         Returns:
             ObjectDetectionDatasetMaker: The same instance, with transforms applied.
@@ -1295,24 +1385,42 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
             _LOGGER.error("Transforms must be configured AFTER splitting data. Call .split_data() first.")
             raise RuntimeError()
 
-        # --- Store components for validation recipe ---
-        self._val_recipe_components = {
-            VisionTransformRecipeKeys.MEAN: mean,
-            VisionTransformRecipeKeys.STD: std
-        }
-
-        # --- Validation/Test Pipeline (Deterministic) ---
-        self.val_transform = _OD_PairedCompose([
-            _OD_PairedToTensor(),
-            _OD_PairedNormalize(mean, std)
-        ])
+        if (mean is None and std is not None) or (mean is not None and std is None):
+            _LOGGER.error(f"'mean' and 'std' must be both None or both defined, but only one was provided.")
+            raise ValueError()
 
-        # --- Training Pipeline (Augmentation) ---
-        self.train_transform = _OD_PairedCompose([
-            _OD_PairedRandomHorizontalFlip(p=0.5),
-            _OD_PairedToTensor(),
-            _OD_PairedNormalize(mean, std)
-        ])
+        if mean is not None and std is not None:
+            # --- Store components for validation recipe ---
+            self._val_recipe_components = {
+                VisionTransformRecipeKeys.MEAN: mean,
+                VisionTransformRecipeKeys.STD: std
+            }
+            self._has_mean_std = True
+
+        if self._has_mean_std:
+            # --- Validation/Test Pipeline (Deterministic) ---
+            self.val_transform = _OD_PairedCompose([
+                _OD_PairedToTensor(),
+                _OD_PairedNormalize(mean, std) # type: ignore
+            ])
+
+            # --- Training Pipeline (Augmentation) ---
+            self.train_transform = _OD_PairedCompose([
+                _OD_PairedRandomHorizontalFlip(p=0.5),
+                _OD_PairedToTensor(),
+                _OD_PairedNormalize(mean, std) # type: ignore
+            ])
+        else:
+            # --- Validation/Test Pipeline (Deterministic) ---
+            self.val_transform = _OD_PairedCompose([
+                _OD_PairedToTensor()
+            ])
+
+            # --- Training Pipeline (Augmentation) ---
+            self.train_transform = _OD_PairedCompose([
+                _OD_PairedRandomHorizontalFlip(p=0.5),
+                _OD_PairedToTensor()
+            ])
 
         # --- Apply Transforms to the Datasets ---
         self._train_dataset.transform = self.train_transform # type: ignore
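The object-detection maker follows the same pattern: flip and ToTensor always, Normalize only when both statistics are given. A minimal sketch, using the from_folders constructor shown earlier in this diff (split arguments are omitted and may be required):

    od_maker = ObjectDetectionDatasetMaker.from_folders("data/images", "data/annotations")
    od_maker.split_data()                                # arguments omitted; split before configuring
    od_maker.configure_transforms(mean=None, std=None)   # no _OD_PairedNormalize in either pipeline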
@@ -1368,10 +1476,6 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
 
         components = self._val_recipe_components
 
-        if not components:
-            _LOGGER.error(f"Error getting the transformers recipe for validation set.")
-            raise ValueError()
-
         # validate path
         file_path = make_fullpath(filepath, make=True, enforce="file")
 
@@ -1380,15 +1484,49 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
             VisionTransformRecipeKeys.TASK: "object_detection",
             VisionTransformRecipeKeys.PIPELINE: [
                 {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}},
-                {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
-                    "mean": components["mean"],
-                    "std": components["std"]
-                }}
             ]
         }
 
+        if self._has_mean_std and components:
+            recipe[VisionTransformRecipeKeys.PIPELINE].append(
+                {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
+                    "mean": components[VisionTransformRecipeKeys.MEAN],
+                    "std": components[VisionTransformRecipeKeys.STD]
+                }}
+            )
+
         # Save the file
         save_recipe(recipe, file_path)
+
+    def images_per_dataset(self) -> str:
+        """
+        Get the number of images per dataset as a string.
+        """
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            return f"Train | Validation | Test: {train_len} | {val_len} | {test_len} images\n"
+        else:
+            _LOGGER.warning("No datasets found.")
+            return "No datasets found\n"
+
+    def __repr__(self) -> str:
+        s = f"<{self.__class__.__name__}>:\n"
+        s += f" Total Image-Annotation Pairs: {len(self.image_paths)}\n"
+        s += f" Split: {self._is_split}\n"
+        s += f" Transforms Configured: {self._are_transforms_configured}\n"
+
+        if self.class_map:
+            s += f" Classes ({len(self.class_map)}): {list(self.class_map.keys())}\n"
+
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            s += f" Datasets (Train|Val|Test): {train_len} | {val_len} | {test_len}\n"
+
+        return s
 
 
 def info():
@@ -47,12 +47,17 @@ class _BaseVisionWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.model_name = model_name
+        self._pretrained_default_transforms = None
 
         # --- 2. Instantiate the base model ---
         if init_with_pretrained:
             weights_enum = getattr(vision_models, weights_enum_name, None) if weights_enum_name else None
             weights = weights_enum.IMAGENET1K_V1 if weights_enum else None
 
+            # Save transformations for pretrained models
+            if weights:
+                self._pretrained_default_transforms = weights.transforms()
+
             if weights is None and init_with_pretrained:
                 _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
                 self.model = getattr(vision_models, model_name)(pretrained=True)
@@ -331,6 +336,7 @@ class _BaseSegmentationWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.model_name = model_name
+        self._pretrained_default_transforms = None
 
         # --- 2. Instantiate the base model ---
         model_kwargs = {
@@ -343,6 +349,10 @@ class _BaseSegmentationWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
             weights_enum = getattr(vision_models.segmentation, weights_enum_name, None) if weights_enum_name else None
             weights = weights_enum.DEFAULT if weights_enum else None
 
+            # save pretrained model transformations
+            if weights:
+                self._pretrained_default_transforms = weights.transforms()
+
             if weights is None:
                 _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
                 # Legacy models used 'pretrained=True' and num_classes was separate
@@ -520,7 +530,7 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
     This wrapper allows for customizing the model backbone, input channels,
     and the number of output classes for transfer learning.
 
-    NOTE: This model is NOT compatible with the MLTrainer class.
+    NOTE: This model is NOT compatible with the MLTrainer class. Use the ObjectDetectionTrainer instead.
     """
     def __init__(self,
                  num_classes: int,
@@ -550,6 +560,7 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.model_name = model_name
+        self._pretrained_default_transforms = None
 
         # --- 2. Instantiate the base model ---
         model_constructor = getattr(detection_models, model_name)
@@ -560,6 +571,9 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
 
         weights_enum = getattr(detection_models, weights_enum_name, None) if weights_enum_name else None
         weights = weights_enum.DEFAULT if weights_enum and init_with_pretrained else None
+
+        if weights:
+            self._pretrained_default_transforms = weights.transforms()
 
         self.model = model_constructor(weights=weights, weights_backbone=weights)
 
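All three wrappers now capture the preprocessing preset that ships with their pretrained weights. A short sketch of what the stored object is, using the standard torchvision weights API; the wrappers only keep it on a private attribute and this diff shows no public accessor:

    from torchvision import models as vision_models

    weights = vision_models.ResNet18_Weights.IMAGENET1K_V1
    preset = weights.transforms()   # the resize/crop/normalize preset matching these weights
    # Inside the wrappers, this same object is kept as self._pretrained_default_transforms.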
@@ -1,14 +1,18 @@
-from typing import Union, Dict, Type, Callable
+from typing import Union, Dict, Type, Callable, Optional, Any, List, Literal
 from PIL import ImageOps, Image
+from torchvision import transforms
+from pathlib import Path
 
 from ._logger import _LOGGER
 from ._script_info import _script_info
 from .keys import VisionTransformRecipeKeys
+from .path_manager import make_fullpath
 
 
 __all__ = [
     "TRANSFORM_REGISTRY",
-    "ResizeAspectFill"
+    "ResizeAspectFill",
+    "create_offline_augmentations"
 ]
 
 # --- Custom Vision Transform Class ---
@@ -23,9 +27,8 @@ class ResizeAspectFill:
     """
     def __init__(self, pad_color: Union[str, int] = "black") -> None:
         self.pad_color = pad_color
-        # Store kwargs to allow for recreation
+        # Store kwargs to allow for re-creation
         self.__setattr__(VisionTransformRecipeKeys.KWARGS, {"pad_color": pad_color})
-        # self._kwargs = {"pad_color": pad_color}
 
     def __call__(self, image: Image.Image) -> Image.Image:
         if not isinstance(image, Image.Image):
@@ -47,12 +50,154 @@ class ResizeAspectFill:
         padding = (left_padding, 0, right_padding, 0)
 
         return ImageOps.expand(image, padding, fill=self.pad_color)
-
 
-#NOTE: Add custom transforms here.
+
+#NOTE: Add custom transforms.
 TRANSFORM_REGISTRY: Dict[str, Type[Callable]] = {
     "ResizeAspectFill": ResizeAspectFill,
 }
 
+
+def _build_transform_from_recipe(recipe: Dict[str, Any]) -> transforms.Compose:
+    """Internal helper to build a transform pipeline from a recipe dict."""
+    pipeline_steps: List[Callable] = []
+
+    if VisionTransformRecipeKeys.PIPELINE not in recipe:
+        _LOGGER.error("Recipe dict is invalid: missing 'pipeline' key.")
+        raise ValueError("Invalid recipe format.")
+
+    for step in recipe[VisionTransformRecipeKeys.PIPELINE]:
+        t_name = step.get(VisionTransformRecipeKeys.NAME)
+        t_kwargs = step.get(VisionTransformRecipeKeys.KWARGS, {})
+
+        if not t_name:
+            _LOGGER.error(f"Invalid transform step, missing 'name': {step}")
+            continue
+
+        transform_class: Any = None
+
+        # 1. Check standard torchvision transforms
+        if hasattr(transforms, t_name):
+            transform_class = getattr(transforms, t_name)
+        # 2. Check custom transforms
+        elif t_name in TRANSFORM_REGISTRY:
+            transform_class = TRANSFORM_REGISTRY[t_name]
+        # 3. Not found
+        else:
+            _LOGGER.error(f"Unknown transform '{t_name}' in recipe. Not found in torchvision.transforms or TRANSFORM_REGISTRY.")
+            raise ValueError(f"Unknown transform name: {t_name}")
+
+        # Instantiate the transform
+        try:
+            pipeline_steps.append(transform_class(**t_kwargs))
+        except Exception as e:
+            _LOGGER.error(f"Failed to instantiate transform '{t_name}' with kwargs {t_kwargs}: {e}")
+            raise
+
+    return transforms.Compose(pipeline_steps)
+
+
+def create_offline_augmentations(
+        input_directory: Union[str, Path],
+        output_directory: Union[str, Path],
+        results_per_image: int,
+        recipe: Optional[Dict[str, Any]] = None,
+        save_format: Literal["WEBP", "JPEG", "PNG", "BMP", "TIF"] = "WEBP",
+        save_quality: int = 80
+) -> None:
+    """
+    Reads all valid images from an input directory, applies augmentations,
+    and saves the new images to an output directory (offline augmentation).
+
+    Skips subdirectories in the input path.
+
+    Args:
+        input_directory (Union[str, Path]): Path to the directory of source images.
+        output_directory (Union[str, Path]): Path to save the augmented images.
+        results_per_image (int): The number of augmented versions to create
+            for each source image.
+        recipe (Optional[Dict[str, Any]]): A transform recipe dictionary. If None,
+            a default set of strong, random
+            augmentations will be used.
+        save_format (str): The format to save images (e.g., "WEBP", "JPEG", "PNG").
+            Defaults to "WEBP" for good compression.
+        save_quality (int): The quality for lossy formats (1-100). Defaults to 80.
+    """
+    VALID_IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff')
+
+    # --- 1. Validate Paths ---
+    in_path = make_fullpath(input_directory, enforce="directory")
+    out_path = make_fullpath(output_directory, make=True, enforce="directory")
+
+    _LOGGER.info(f"Starting offline augmentation:\n\tInput: {in_path}\n\tOutput: {out_path}")
+
+    # --- 2. Find Images ---
+    image_files = [
+        f for f in in_path.iterdir()
+        if f.is_file() and f.suffix.lower() in VALID_IMG_EXTENSIONS
+    ]
+
+    if not image_files:
+        _LOGGER.warning(f"No valid image files found in {in_path}.")
+        return
+
+    _LOGGER.info(f"Found {len(image_files)} images to process.")
+
+    # --- 3. Define Transform Pipeline ---
+    transform_pipeline: transforms.Compose
+
+    if recipe:
+        _LOGGER.info("Building transformations from provided recipe.")
+        try:
+            transform_pipeline = _build_transform_from_recipe(recipe)
+        except Exception as e:
+            _LOGGER.error(f"Failed to build transform from recipe: {e}")
+            return
+    else:
+        _LOGGER.info("No recipe provided. Using default random augmentation pipeline.")
+        # Default "random" pipeline
+        transform_pipeline = transforms.Compose([
+            transforms.RandomResizedCrop(256, scale=(0.4, 1.0)),
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.RandomRotation(degrees=90),
+            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),
+            transforms.RandomPerspective(distortion_scale=0.2, p=0.4),
+            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
+            transforms.RandomApply([
+                transforms.GaussianBlur(kernel_size=3)
+            ], p=0.3)
+        ])
+
+    # --- 4. Process Images ---
+    total_saved = 0
+    format_upper = save_format.upper()
+
+    for img_path in image_files:
+        _LOGGER.debug(f"Processing {img_path.name}...")
+        try:
+            original_image = Image.open(img_path).convert("RGB")
+
+            for i in range(results_per_image):
+                new_stem = f"{img_path.stem}_aug_{i+1:03d}"
+                output_path = out_path / f"{new_stem}.{format_upper.lower()}"
+
+                # Apply transform
+                transformed_image = transform_pipeline(original_image)
+
+                # Save
+                transformed_image.save(
+                    output_path,
+                    format=format_upper,
+                    quality=save_quality,
+                    optimize=True # Add optimize flag
+                )
+                total_saved += 1
+
+        except Exception as e:
+            _LOGGER.warning(f"Failed to process or save augmentations for {img_path.name}: {e}")
+
+    _LOGGER.info(f"Offline augmentation complete. Saved {total_saved} new images.")
+
+
 def info():
     _script_info(__all__)
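Finally, a minimal usage sketch for the new create_offline_augmentations helper, based on the signature added above; the paths are placeholders:

    from pathlib import Path

    create_offline_augmentations(
        input_directory=Path("data/raw_images"),
        output_directory=Path("data/augmented"),
        results_per_image=5,   # five augmented copies per source image
        recipe=None,           # None selects the built-in random augmentation pipeline
        save_format="WEBP",
        save_quality=80,
    )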