dragon-ml-toolbox 14.3.1__py3-none-any.whl → 14.8.0__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/METADATA +2 -1
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/RECORD +17 -16
- ml_tools/ML_configuration.py +116 -0
- ml_tools/ML_datasetmaster.py +42 -0
- ml_tools/ML_evaluation.py +208 -63
- ml_tools/ML_evaluation_multi.py +40 -10
- ml_tools/ML_trainer.py +38 -12
- ml_tools/ML_utilities.py +50 -1
- ml_tools/ML_vision_datasetmaster.py +198 -60
- ml_tools/ML_vision_models.py +15 -1
- ml_tools/ML_vision_transformers.py +151 -6
- ml_tools/ensemble_evaluation.py +53 -10
- ml_tools/keys.py +2 -1
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/WHEEL +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/licenses/LICENSE +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/licenses/LICENSE-THIRD-PARTY.md +0 -0
- {dragon_ml_toolbox-14.3.1.dist-info → dragon_ml_toolbox-14.8.0.dist-info}/top_level.txt +0 -0
@@ -273,8 +273,8 @@ class VisionDatasetMaker(_BaseMaker):
                 for validation/testing.
             crop_size (int): The target size (square) for the final
                 cropped image.
-            mean (List[float]): The mean values for normalization (e.g., ImageNet mean).
-            std (List[float]): The standard deviation values for normalization (e.g., ImageNet std).
+            mean (List[float] | None): The mean values for normalization (e.g., ImageNet mean).
+            std (List[float] | None): The standard deviation values for normalization (e.g., ImageNet std).
             extra_train_transforms (List[Callable] | None): A list of additional torchvision transforms to add to the end of the training transformations.
             pre_transforms (List[Callable] | None): An list of transforms to be applied at the very beginning of the transformations for all sets.
 
@@ -499,6 +499,39 @@ class VisionDatasetMaker(_BaseMaker):
 
         return self.class_map
 
+    def images_per_dataset(self) -> str:
+        """
+        Get the number of images per dataset as a string.
+        """
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            return f"Train | Validation | Test: {train_len} | {val_len} | {test_len} images\n"
+        elif self._full_dataset:
+            return f"Full Dataset: {len(self._full_dataset)} images\n"
+        else:
+            _LOGGER.warning("No datasets found.")
+            return "No datasets found\n"
+
+    def __repr__(self) -> str:
+        s = f"<{self.__class__.__name__}>:\n"
+        s += f" Split: {self._is_split}\n"
+        s += f" Transforms Configured: {self._are_transforms_configured}\n"
+
+        if self.class_map:
+            s += f" Classes: {len(self.class_map)}\n"
+
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            s += f" Datasets (Train|Val|Test): {train_len} | {val_len} | {test_len}\n"
+        elif self._full_dataset:
+            s += f" Full Dataset Size: {len(self._full_dataset)} images\n"
+
+        return s
+
 
 class _DatasetTransformer(Dataset):
     """
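The images_per_dataset() and __repr__ helpers added above are read-only reporting utilities. A minimal usage sketch (the construction and split calls below are assumed for illustration; only the two new methods come from this diff):

from ml_tools.ML_vision_datasetmaster import VisionDatasetMaker

maker = VisionDatasetMaker.from_folder("data/images")  # hypothetical constructor, not shown in this diff
maker.split_data(val_size=0.2, test_size=0.1)          # assumed signature
print(maker.images_per_dataset())  # e.g. "Train | Validation | Test: 700 | 200 | 100 images"
print(maker)                       # __repr__ summarizes split state, transforms, classes and sizes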
@@ -686,6 +719,7 @@ class SegmentationDatasetMaker(_BaseMaker):
         self._are_transforms_configured = False
         self.train_transform: Optional[Callable] = None
         self.val_transform: Optional[Callable] = None
+        self._has_mean_std: bool = False
 
     @classmethod
     def from_folders(cls, image_dir: Union[str, Path], mask_dir: Union[str, Path]) -> 'SegmentationDatasetMaker':
@@ -849,8 +883,8 @@ class SegmentationDatasetMaker(_BaseMaker):
     def configure_transforms(self,
                              resize_size: int = 256,
                              crop_size: int = 224,
-                             mean: List[float] = [0.485, 0.456, 0.406],
-                             std: List[float] = [0.229, 0.224, 0.225]) -> 'SegmentationDatasetMaker':
+                             mean: Optional[List[float]] = [0.485, 0.456, 0.406],
+                             std: Optional[List[float]] = [0.229, 0.224, 0.225]) -> 'SegmentationDatasetMaker':
         """
         Configures and applies the image and mask transformations.
 
@@ -861,8 +895,8 @@ class SegmentationDatasetMaker(_BaseMaker):
                 for validation/testing.
             crop_size (int): The target size (square) for the final
                 cropped image.
-            mean (List[float]): The mean values for image normalization.
-            std (List[float]): The std dev values for image normalization.
+            mean (List[float] | None): The mean values for image normalization.
+            std (List[float] | None): The std dev values for image normalization.
 
         Returns:
             SegmentationDatasetMaker: The same instance, with transforms applied.
@@ -871,29 +905,50 @@ class SegmentationDatasetMaker(_BaseMaker):
             _LOGGER.error("Transforms must be configured AFTER splitting data. Call .split_data() first.")
             raise RuntimeError()
 
+        if (mean is None and std is not None) or (mean is not None and std is None):
+            _LOGGER.error(f"'mean' and 'std' must be both None or both defined, but only one was provided.")
+            raise ValueError()
+
         # --- Store components for validation recipe ---
-        self.val_recipe_components = {
+        self.val_recipe_components: dict[str,Any] = {
             VisionTransformRecipeKeys.RESIZE_SIZE: resize_size,
             VisionTransformRecipeKeys.CROP_SIZE: crop_size,
-            VisionTransformRecipeKeys.MEAN: mean,
-            VisionTransformRecipeKeys.STD: std
         }
+
+        if mean is not None and std is not None:
+            self.val_recipe_components.update({
+                VisionTransformRecipeKeys.MEAN: mean,
+                VisionTransformRecipeKeys.STD: std
+            })
+            self._has_mean_std = True
 
         # --- Validation/Test Pipeline (Deterministic) ---
-        self.
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if self._has_mean_std:
+            self.val_transform = _PairedCompose([
+                _PairedResize(resize_size),
+                _PairedCenterCrop(crop_size),
+                _PairedToTensor(),
+                _PairedNormalize(mean, std) # type: ignore
+            ])
+            # --- Training Pipeline (Augmentation) ---
+            self.train_transform = _PairedCompose([
+                _PairedRandomResizedCrop(crop_size),
+                _PairedRandomHorizontalFlip(p=0.5),
+                _PairedToTensor(),
+                _PairedNormalize(mean, std) # type: ignore
+            ])
+        else:
+            self.val_transform = _PairedCompose([
+                _PairedResize(resize_size),
+                _PairedCenterCrop(crop_size),
+                _PairedToTensor()
+            ])
+            # --- Training Pipeline (Augmentation) ---
+            self.train_transform = _PairedCompose([
+                _PairedRandomResizedCrop(crop_size),
+                _PairedRandomHorizontalFlip(p=0.5),
+                _PairedToTensor()
+            ])
 
         # --- Apply Transforms to the Datasets ---
         self._train_dataset.transform = self.train_transform # type: ignore
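With the branch above, normalization in SegmentationDatasetMaker.configure_transforms() is now optional: pass both mean and std (the ImageNet defaults remain), or set both to None to drop the _PairedNormalize step; passing only one of the two raises ValueError. A minimal sketch, assuming a split_data() signature that is not shown in this diff:

from ml_tools.ML_vision_datasetmaster import SegmentationDatasetMaker

maker = SegmentationDatasetMaker.from_folders("data/images", "data/masks")
maker.split_data(val_size=0.2)  # assumed signature; transforms must be configured after splitting

# Either keep the ImageNet defaults...
maker.configure_transforms(resize_size=256, crop_size=224)
# ...or skip normalization entirely (new in this release), e.g. for models that normalize internally:
maker.configure_transforms(resize_size=256, crop_size=224, mean=None, std=None)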
@@ -946,23 +1001,57 @@ class SegmentationDatasetMaker(_BaseMaker):
 
         # validate path
         file_path = make_fullpath(filepath, make=True, enforce="file")
-
+
         # Add standard transforms
         recipe: Dict[str, Any] = {
             VisionTransformRecipeKeys.TASK: "segmentation",
             VisionTransformRecipeKeys.PIPELINE: [
-                {VisionTransformRecipeKeys.NAME: "Resize", "kwargs": {"size": components[
-                {VisionTransformRecipeKeys.NAME: "CenterCrop", "kwargs": {"size": components[
-                {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}}
-                {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
-                    "mean": components["mean"],
-                    "std": components["std"]
-                }}
+                {VisionTransformRecipeKeys.NAME: "Resize", "kwargs": {"size": components[VisionTransformRecipeKeys.RESIZE_SIZE]}},
+                {VisionTransformRecipeKeys.NAME: "CenterCrop", "kwargs": {"size": components[VisionTransformRecipeKeys.CROP_SIZE]}},
+                {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}}
             ]
         }
 
+        if self._has_mean_std:
+            recipe[VisionTransformRecipeKeys.PIPELINE].append(
+                {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
+                    "mean": components[VisionTransformRecipeKeys.MEAN],
+                    "std": components[VisionTransformRecipeKeys.STD]
+                }}
+            )
+
         # Save the file
         save_recipe(recipe, file_path)
+
+    def images_per_dataset(self) -> str:
+        """
+        Get the number of images per dataset as a string.
+        """
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            return f"Train | Validation | Test: {train_len} | {val_len} | {test_len} images\n"
+        else:
+            _LOGGER.warning("No datasets found.")
+            return "No datasets found\n"
+
+    def __repr__(self) -> str:
+        s = f"<{self.__class__.__name__}>:\n"
+        s += f" Total Image-Mask Pairs: {len(self.image_paths)}\n"
+        s += f" Split: {self._is_split}\n"
+        s += f" Transforms Configured: {self._are_transforms_configured}\n"
+
+        if self.class_map:
+            s += f" Classes: {list(self.class_map.keys())}\n"
+
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            s += f" Datasets (Train|Val|Test): {train_len} | {val_len} | {test_len}\n"
+
+        return s
 
 
 # Object detection
@@ -1114,6 +1203,7 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
         self.train_transform: Optional[Callable] = None
         self.val_transform: Optional[Callable] = None
         self._val_recipe_components: Optional[Dict[str, Any]] = None
+        self._has_mean_std: bool = False
 
     @classmethod
     def from_folders(cls, image_dir: Union[str, Path], annotation_dir: Union[str, Path]) -> 'ObjectDetectionDatasetMaker':
@@ -1273,8 +1363,8 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
         return self
 
     def configure_transforms(self,
-                             mean: List[float] = [0.485, 0.456, 0.406],
-                             std: List[float] = [0.229, 0.224, 0.225]) -> 'ObjectDetectionDatasetMaker':
+                             mean: Optional[List[float]] = [0.485, 0.456, 0.406],
+                             std: Optional[List[float]] = [0.229, 0.224, 0.225]) -> 'ObjectDetectionDatasetMaker':
         """
         Configures and applies the image and target transformations.
 
@@ -1285,8 +1375,8 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
         Transforms are limited to augmentation (flip), ToTensor, and Normalize.
 
         Args:
-            mean (List[float]): The mean values for image normalization.
-            std (List[float]): The std dev values for image normalization.
+            mean (List[float] | None): The mean values for image normalization.
+            std (List[float] | None): The std dev values for image normalization.
 
         Returns:
             ObjectDetectionDatasetMaker: The same instance, with transforms applied.
@@ -1295,24 +1385,42 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
             _LOGGER.error("Transforms must be configured AFTER splitting data. Call .split_data() first.")
             raise RuntimeError()
 
-
-
-
-            VisionTransformRecipeKeys.STD: std
-        }
-
-        # --- Validation/Test Pipeline (Deterministic) ---
-        self.val_transform = _OD_PairedCompose([
-            _OD_PairedToTensor(),
-            _OD_PairedNormalize(mean, std)
-        ])
+        if (mean is None and std is not None) or (mean is not None and std is None):
+            _LOGGER.error(f"'mean' and 'std' must be both None or both defined, but only one was provided.")
+            raise ValueError()
 
-
-
-
-
-
+        if mean is not None and std is not None:
+            # --- Store components for validation recipe ---
+            self._val_recipe_components = {
+                VisionTransformRecipeKeys.MEAN: mean,
+                VisionTransformRecipeKeys.STD: std
+            }
+            self._has_mean_std = True
+
+        if self._has_mean_std:
+            # --- Validation/Test Pipeline (Deterministic) ---
+            self.val_transform = _OD_PairedCompose([
+                _OD_PairedToTensor(),
+                _OD_PairedNormalize(mean, std) # type: ignore
+            ])
+
+            # --- Training Pipeline (Augmentation) ---
+            self.train_transform = _OD_PairedCompose([
+                _OD_PairedRandomHorizontalFlip(p=0.5),
+                _OD_PairedToTensor(),
+                _OD_PairedNormalize(mean, std) # type: ignore
+            ])
+        else:
+            # --- Validation/Test Pipeline (Deterministic) ---
+            self.val_transform = _OD_PairedCompose([
+                _OD_PairedToTensor()
+            ])
+
+            # --- Training Pipeline (Augmentation) ---
+            self.train_transform = _OD_PairedCompose([
+                _OD_PairedRandomHorizontalFlip(p=0.5),
+                _OD_PairedToTensor()
+            ])
 
         # --- Apply Transforms to the Datasets ---
         self._train_dataset.transform = self.train_transform # type: ignore
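The same optional-normalization contract now applies to ObjectDetectionDatasetMaker: both mean and std, or both None. A minimal sketch under the same assumptions as above:

from ml_tools.ML_vision_datasetmaster import ObjectDetectionDatasetMaker

maker = ObjectDetectionDatasetMaker.from_folders("data/images", "data/annotations")
maker.split_data(val_size=0.2)                   # assumed signature
maker.configure_transforms(mean=None, std=None)  # flip + ToTensor only, no Normalize step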
@@ -1368,10 +1476,6 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
 
         components = self._val_recipe_components
 
-        if not components:
-            _LOGGER.error(f"Error getting the transformers recipe for validation set.")
-            raise ValueError()
-
         # validate path
         file_path = make_fullpath(filepath, make=True, enforce="file")
 
@@ -1380,15 +1484,49 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
             VisionTransformRecipeKeys.TASK: "object_detection",
             VisionTransformRecipeKeys.PIPELINE: [
                 {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}},
-                {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
-                    "mean": components["mean"],
-                    "std": components["std"]
-                }}
             ]
         }
 
+        if self._has_mean_std and components:
+            recipe[VisionTransformRecipeKeys.PIPELINE].append(
+                {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
+                    "mean": components[VisionTransformRecipeKeys.MEAN],
+                    "std": components[VisionTransformRecipeKeys.STD]
+                }}
+            )
+
         # Save the file
         save_recipe(recipe, file_path)
+
+    def images_per_dataset(self) -> str:
+        """
+        Get the number of images per dataset as a string.
+        """
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            return f"Train | Validation | Test: {train_len} | {val_len} | {test_len} images\n"
+        else:
+            _LOGGER.warning("No datasets found.")
+            return "No datasets found\n"
+
+    def __repr__(self) -> str:
+        s = f"<{self.__class__.__name__}>:\n"
+        s += f" Total Image-Annotation Pairs: {len(self.image_paths)}\n"
+        s += f" Split: {self._is_split}\n"
+        s += f" Transforms Configured: {self._are_transforms_configured}\n"
+
+        if self.class_map:
+            s += f" Classes ({len(self.class_map)}): {list(self.class_map.keys())}\n"
+
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            s += f" Datasets (Train|Val|Test): {train_len} | {val_len} | {test_len}\n"
+
+        return s
 
 
 def info():
ml_tools/ML_vision_models.py CHANGED
@@ -47,12 +47,17 @@ class _BaseVisionWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.model_name = model_name
+        self._pretrained_default_transforms = None
 
         # --- 2. Instantiate the base model ---
         if init_with_pretrained:
             weights_enum = getattr(vision_models, weights_enum_name, None) if weights_enum_name else None
             weights = weights_enum.IMAGENET1K_V1 if weights_enum else None
 
+            # Save transformations for pretrained models
+            if weights:
+                self._pretrained_default_transforms = weights.transforms()
+
             if weights is None and init_with_pretrained:
                 _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
                 self.model = getattr(vision_models, model_name)(pretrained=True)
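The new _pretrained_default_transforms attribute caches the preprocessing preset that torchvision ships with each weights enum. For reference, this is plain torchvision behaviour, not an API of this package:

from torchvision import models as vision_models

weights = vision_models.ResNet18_Weights.IMAGENET1K_V1
preprocess = weights.transforms()  # resize/crop/normalize preset matching the pretrained weights
# The wrappers above store this same object when init_with_pretrained=True and matching weights are found.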
@@ -331,6 +336,7 @@ class _BaseSegmentationWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.model_name = model_name
+        self._pretrained_default_transforms = None
 
         # --- 2. Instantiate the base model ---
         model_kwargs = {
@@ -343,6 +349,10 @@ class _BaseSegmentationWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
             weights_enum = getattr(vision_models.segmentation, weights_enum_name, None) if weights_enum_name else None
             weights = weights_enum.DEFAULT if weights_enum else None
 
+            # save pretrained model transformations
+            if weights:
+                self._pretrained_default_transforms = weights.transforms()
+
             if weights is None:
                 _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
                 # Legacy models used 'pretrained=True' and num_classes was separate
@@ -520,7 +530,7 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
     This wrapper allows for customizing the model backbone, input channels,
     and the number of output classes for transfer learning.
 
-    NOTE: This model is NOT compatible with the MLTrainer class.
+    NOTE: This model is NOT compatible with the MLTrainer class. Use the ObjectDetectionTrainer instead.
     """
     def __init__(self,
                  num_classes: int,
@@ -550,6 +560,7 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.model_name = model_name
+        self._pretrained_default_transforms = None
 
         # --- 2. Instantiate the base model ---
         model_constructor = getattr(detection_models, model_name)
@@ -560,6 +571,9 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
 
         weights_enum = getattr(detection_models, weights_enum_name, None) if weights_enum_name else None
         weights = weights_enum.DEFAULT if weights_enum and init_with_pretrained else None
+
+        if weights:
+            self._pretrained_default_transforms = weights.transforms()
 
         self.model = model_constructor(weights=weights, weights_backbone=weights)
 
ml_tools/ML_vision_transformers.py CHANGED
@@ -1,14 +1,18 @@
-from typing import Union, Dict, Type, Callable
+from typing import Union, Dict, Type, Callable, Optional, Any, List, Literal
 from PIL import ImageOps, Image
+from torchvision import transforms
+from pathlib import Path
 
 from ._logger import _LOGGER
 from ._script_info import _script_info
 from .keys import VisionTransformRecipeKeys
+from .path_manager import make_fullpath
 
 
 __all__ = [
     "TRANSFORM_REGISTRY",
-    "ResizeAspectFill"
+    "ResizeAspectFill",
+    "create_offline_augmentations"
 ]
 
 # --- Custom Vision Transform Class ---
@@ -23,9 +27,8 @@ class ResizeAspectFill:
     """
     def __init__(self, pad_color: Union[str, int] = "black") -> None:
        self.pad_color = pad_color
-        # Store kwargs to allow for
+        # Store kwargs to allow for re-creation
        self.__setattr__(VisionTransformRecipeKeys.KWARGS, {"pad_color": pad_color})
-        # self._kwargs = {"pad_color": pad_color}
 
    def __call__(self, image: Image.Image) -> Image.Image:
        if not isinstance(image, Image.Image):
@@ -47,12 +50,154 @@ class ResizeAspectFill:
         padding = (left_padding, 0, right_padding, 0)
 
         return ImageOps.expand(image, padding, fill=self.pad_color)
-
 
-
+
+#NOTE: Add custom transforms.
 TRANSFORM_REGISTRY: Dict[str, Type[Callable]] = {
     "ResizeAspectFill": ResizeAspectFill,
 }
 
+
+def _build_transform_from_recipe(recipe: Dict[str, Any]) -> transforms.Compose:
+    """Internal helper to build a transform pipeline from a recipe dict."""
+    pipeline_steps: List[Callable] = []
+
+    if VisionTransformRecipeKeys.PIPELINE not in recipe:
+        _LOGGER.error("Recipe dict is invalid: missing 'pipeline' key.")
+        raise ValueError("Invalid recipe format.")
+
+    for step in recipe[VisionTransformRecipeKeys.PIPELINE]:
+        t_name = step.get(VisionTransformRecipeKeys.NAME)
+        t_kwargs = step.get(VisionTransformRecipeKeys.KWARGS, {})
+
+        if not t_name:
+            _LOGGER.error(f"Invalid transform step, missing 'name': {step}")
+            continue
+
+        transform_class: Any = None
+
+        # 1. Check standard torchvision transforms
+        if hasattr(transforms, t_name):
+            transform_class = getattr(transforms, t_name)
+        # 2. Check custom transforms
+        elif t_name in TRANSFORM_REGISTRY:
+            transform_class = TRANSFORM_REGISTRY[t_name]
+        # 3. Not found
+        else:
+            _LOGGER.error(f"Unknown transform '{t_name}' in recipe. Not found in torchvision.transforms or TRANSFORM_REGISTRY.")
+            raise ValueError(f"Unknown transform name: {t_name}")
+
+        # Instantiate the transform
+        try:
+            pipeline_steps.append(transform_class(**t_kwargs))
+        except Exception as e:
+            _LOGGER.error(f"Failed to instantiate transform '{t_name}' with kwargs {t_kwargs}: {e}")
+            raise
+
+    return transforms.Compose(pipeline_steps)
+
+
+def create_offline_augmentations(
+    input_directory: Union[str, Path],
+    output_directory: Union[str, Path],
+    results_per_image: int,
+    recipe: Optional[Dict[str, Any]] = None,
+    save_format: Literal["WEBP", "JPEG", "PNG", "BMP", "TIF"] = "WEBP",
+    save_quality: int = 80
+) -> None:
+    """
+    Reads all valid images from an input directory, applies augmentations,
+    and saves the new images to an output directory (offline augmentation).
+
+    Skips subdirectories in the input path.
+
+    Args:
+        input_directory (Union[str, Path]): Path to the directory of source images.
+        output_directory (Union[str, Path]): Path to save the augmented images.
+        results_per_image (int): The number of augmented versions to create
+            for each source image.
+        recipe (Optional[Dict[str, Any]]): A transform recipe dictionary. If None,
+            a default set of strong, random
+            augmentations will be used.
+        save_format (str): The format to save images (e.g., "WEBP", "JPEG", "PNG").
+            Defaults to "WEBP" for good compression.
+        save_quality (int): The quality for lossy formats (1-100). Defaults to 80.
+    """
+    VALID_IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp', '.tif', '.tiff')
+
+    # --- 1. Validate Paths ---
+    in_path = make_fullpath(input_directory, enforce="directory")
+    out_path = make_fullpath(output_directory, make=True, enforce="directory")
+
+    _LOGGER.info(f"Starting offline augmentation:\n\tInput: {in_path}\n\tOutput: {out_path}")
+
+    # --- 2. Find Images ---
+    image_files = [
+        f for f in in_path.iterdir()
+        if f.is_file() and f.suffix.lower() in VALID_IMG_EXTENSIONS
+    ]
+
+    if not image_files:
+        _LOGGER.warning(f"No valid image files found in {in_path}.")
+        return
+
+    _LOGGER.info(f"Found {len(image_files)} images to process.")
+
+    # --- 3. Define Transform Pipeline ---
+    transform_pipeline: transforms.Compose
+
+    if recipe:
+        _LOGGER.info("Building transformations from provided recipe.")
+        try:
+            transform_pipeline = _build_transform_from_recipe(recipe)
+        except Exception as e:
+            _LOGGER.error(f"Failed to build transform from recipe: {e}")
+            return
+    else:
+        _LOGGER.info("No recipe provided. Using default random augmentation pipeline.")
+        # Default "random" pipeline
+        transform_pipeline = transforms.Compose([
+            transforms.RandomResizedCrop(256, scale=(0.4, 1.0)),
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.RandomRotation(degrees=90),
+            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.15),
+            transforms.RandomPerspective(distortion_scale=0.2, p=0.4),
+            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
+            transforms.RandomApply([
+                transforms.GaussianBlur(kernel_size=3)
+            ], p=0.3)
+        ])
+
+    # --- 4. Process Images ---
+    total_saved = 0
+    format_upper = save_format.upper()
+
+    for img_path in image_files:
+        _LOGGER.debug(f"Processing {img_path.name}...")
+        try:
+            original_image = Image.open(img_path).convert("RGB")
+
+            for i in range(results_per_image):
+                new_stem = f"{img_path.stem}_aug_{i+1:03d}"
+                output_path = out_path / f"{new_stem}.{format_upper.lower()}"
+
+                # Apply transform
+                transformed_image = transform_pipeline(original_image)
+
+                # Save
+                transformed_image.save(
+                    output_path,
+                    format=format_upper,
+                    quality=save_quality,
+                    optimize=True # Add optimize flag
+                )
+                total_saved += 1
+
+        except Exception as e:
+            _LOGGER.warning(f"Failed to process or save augmentations for {img_path.name}: {e}")
+
+    _LOGGER.info(f"Offline augmentation complete. Saved {total_saved} new images.")
+
+
 def info():
     _script_info(__all__)
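A usage sketch for the new create_offline_augmentations() entry point; the recipe below is illustrative and uses the VisionTransformRecipeKeys constants exactly as the builder above reads them (paths and transform choices are assumptions):

from ml_tools.ML_vision_transformers import create_offline_augmentations
from ml_tools.keys import VisionTransformRecipeKeys

# Default pipeline: strong random augmentations, five variants per source image
create_offline_augmentations("data/raw", "data/augmented", results_per_image=5)

# Or drive it with an explicit recipe; names may be torchvision transforms or TRANSFORM_REGISTRY entries
recipe = {
    VisionTransformRecipeKeys.PIPELINE: [
        {VisionTransformRecipeKeys.NAME: "RandomHorizontalFlip", VisionTransformRecipeKeys.KWARGS: {"p": 0.5}},
        {VisionTransformRecipeKeys.NAME: "ResizeAspectFill", VisionTransformRecipeKeys.KWARGS: {"pad_color": "black"}},
    ]
}
create_offline_augmentations("data/raw", "data/augmented", results_per_image=3, recipe=recipe)

Note that a recipe used here should keep PIL-to-PIL transforms (as in the sketch); a step such as ToTensor would produce a tensor that cannot be saved by the image writer above.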