dragon-ml-toolbox 14.3.0__py3-none-any.whl → 14.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



@@ -10,6 +10,7 @@ import torchvision.transforms.functional as TF
 from pathlib import Path
 import random
 import json
+import inspect
 
 from .ML_datasetmaster import _BaseMaker
 from .path_manager import make_fullpath
@@ -272,8 +273,8 @@ class VisionDatasetMaker(_BaseMaker):
                 for validation/testing.
             crop_size (int): The target size (square) for the final
                 cropped image.
-            mean (List[float]): The mean values for normalization (e.g., ImageNet mean).
-            std (List[float]): The standard deviation values for normalization (e.g., ImageNet std).
+            mean (List[float] | None): The mean values for normalization (e.g., ImageNet mean).
+            std (List[float] | None): The standard deviation values for normalization (e.g., ImageNet std).
             extra_train_transforms (List[Callable] | None): A list of additional torchvision transforms to add to the end of the training transformations.
             pre_transforms (List[Callable] | None): An list of transforms to be applied at the very beginning of the transformations for all sets.
 
@@ -411,16 +412,58 @@ class VisionDatasetMaker(_BaseMaker):
         # validate path
         file_path = make_fullpath(filepath, make=True, enforce="file")
 
-        # 1. Handle pre_transforms
+        # Handle pre_transforms
         for t in components[VisionTransformRecipeKeys.PRE_TRANSFORMS]:
             t_name = t.__class__.__name__
+            t_class = t.__class__
+            kwargs = {}
+
+            # 1. Check custom registry first
             if t_name in TRANSFORM_REGISTRY:
-                recipe[VisionTransformRecipeKeys.PIPELINE].append({
-                    VisionTransformRecipeKeys.NAME: t_name,
-                    VisionTransformRecipeKeys.KWARGS: getattr(t, VisionTransformRecipeKeys.KWARGS, {})
-                })
+                _LOGGER.debug(f"Found '{t_name}' in TRANSFORM_REGISTRY.")
+                kwargs = getattr(t, VisionTransformRecipeKeys.KWARGS, {})
+
+            # 2. Else, try to introspect for standard torchvision transforms
             else:
-                _LOGGER.warning(f"Skipping unknown pre_transform '{t_name}' in recipe. Not in TRANSFORM_REGISTRY.")
+                _LOGGER.debug(f"'{t_name}' not in registry. Attempting introspection...")
+                try:
+                    # Get the __init__ signature of the transform's class
+                    sig = inspect.signature(t_class.__init__)
+
+                    # Iterate over its __init__ parameters (e.g., 'num_output_channels')
+                    for param in sig.parameters.values():
+                        if param.name == 'self':
+                            continue
+
+                        # Check if the *instance* 't' has that parameter as an attribute
+                        attr_name_public = param.name
+                        attr_name_private = '_' + param.name
+
+                        attr_to_get = ""
+
+                        if hasattr(t, attr_name_public):
+                            attr_to_get = attr_name_public
+                        elif hasattr(t, attr_name_private):
+                            attr_to_get = attr_name_private
+                        else:
+                            # Parameter in __init__ has no matching attribute
+                            continue
+
+                        # Store the value under the __init__ parameter's name
+                        kwargs[param.name] = getattr(t, attr_to_get)
+
+                    _LOGGER.debug(f"Introspection for '{t_name}' found kwargs: {kwargs}")
+
+                except (ValueError, TypeError):
+                    # Fails on some built-ins or C-implemented __init__
+                    _LOGGER.warning(f"Could not introspect parameters for '{t_name}'. If this transform has parameters, they will not be saved.")
+                    kwargs = {}
+
+            # 3. Add to pipeline
+            recipe[VisionTransformRecipeKeys.PIPELINE].append({
+                VisionTransformRecipeKeys.NAME: t_name,
+                VisionTransformRecipeKeys.KWARGS: kwargs
+            })
 
         # 2. Add standard transforms
         recipe[VisionTransformRecipeKeys.PIPELINE].extend([
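
The introspection fallback added above rebuilds a transform's constructor kwargs from matching instance attributes. A minimal, self-contained sketch of the same pattern with a stock torchvision transform (the helper name and the example transform are illustrative, not part of the package):

import inspect
from torchvision import transforms

def guess_init_kwargs(transform) -> dict:
    """Best-effort recovery of the kwargs a transform instance was built with."""
    kwargs = {}
    try:
        sig = inspect.signature(type(transform).__init__)
    except (ValueError, TypeError):
        return kwargs  # some C-implemented __init__ methods cannot be inspected
    for param in sig.parameters.values():
        if param.name == "self":
            continue
        # Look for a matching public or private attribute on the instance
        for attr in (param.name, "_" + param.name):
            if hasattr(transform, attr):
                kwargs[param.name] = getattr(transform, attr)
                break
    return kwargs

print(guess_init_kwargs(transforms.Grayscale(num_output_channels=3)))
# {'num_output_channels': 3}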
@@ -456,6 +499,39 @@ class VisionDatasetMaker(_BaseMaker):
 
         return self.class_map
 
+    def images_per_dataset(self) -> str:
+        """
+        Get the number of images per dataset as a string.
+        """
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            return f"Train | Validation | Test: {train_len} | {val_len} | {test_len} images\n"
+        elif self._full_dataset:
+            return f"Full Dataset: {len(self._full_dataset)} images\n"
+        else:
+            _LOGGER.warning("No datasets found.")
+            return "No datasets found\n"
+
+    def __repr__(self) -> str:
+        s = f"<{self.__class__.__name__}>:\n"
+        s += f" Split: {self._is_split}\n"
+        s += f" Transforms Configured: {self._are_transforms_configured}\n"
+
+        if self.class_map:
+            s += f" Classes: {len(self.class_map)}\n"
+
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            s += f" Datasets (Train|Val|Test): {train_len} | {val_len} | {test_len}\n"
+        elif self._full_dataset:
+            s += f" Full Dataset Size: {len(self._full_dataset)} images\n"
+
+        return s
+
 
 class _DatasetTransformer(Dataset):
     """
@@ -643,6 +719,7 @@ class SegmentationDatasetMaker(_BaseMaker):
         self._are_transforms_configured = False
         self.train_transform: Optional[Callable] = None
         self.val_transform: Optional[Callable] = None
+        self._has_mean_std: bool = False
 
     @classmethod
     def from_folders(cls, image_dir: Union[str, Path], mask_dir: Union[str, Path]) -> 'SegmentationDatasetMaker':
@@ -806,8 +883,8 @@ class SegmentationDatasetMaker(_BaseMaker):
     def configure_transforms(self,
                              resize_size: int = 256,
                              crop_size: int = 224,
-                             mean: List[float] = [0.485, 0.456, 0.406],
-                             std: List[float] = [0.229, 0.224, 0.225]) -> 'SegmentationDatasetMaker':
+                             mean: Optional[List[float]] = [0.485, 0.456, 0.406],
+                             std: Optional[List[float]] = [0.229, 0.224, 0.225]) -> 'SegmentationDatasetMaker':
         """
         Configures and applies the image and mask transformations.
 
@@ -818,8 +895,8 @@ class SegmentationDatasetMaker(_BaseMaker):
                 for validation/testing.
             crop_size (int): The target size (square) for the final
                 cropped image.
-            mean (List[float]): The mean values for image normalization.
-            std (List[float]): The std dev values for image normalization.
+            mean (List[float] | None): The mean values for image normalization.
+            std (List[float] | None): The std dev values for image normalization.
 
         Returns:
             SegmentationDatasetMaker: The same instance, with transforms applied.
@@ -828,29 +905,50 @@ class SegmentationDatasetMaker(_BaseMaker):
             _LOGGER.error("Transforms must be configured AFTER splitting data. Call .split_data() first.")
             raise RuntimeError()
 
+        if (mean is None and std is not None) or (mean is not None and std is None):
+            _LOGGER.error(f"'mean' and 'std' must be both None or both defined, but only one was provided.")
+            raise ValueError()
+
         # --- Store components for validation recipe ---
-        self.val_recipe_components = {
+        self.val_recipe_components: dict[str,Any] = {
             VisionTransformRecipeKeys.RESIZE_SIZE: resize_size,
             VisionTransformRecipeKeys.CROP_SIZE: crop_size,
-            VisionTransformRecipeKeys.MEAN: mean,
-            VisionTransformRecipeKeys.STD: std
         }
+
+        if mean is not None and std is not None:
+            self.val_recipe_components.update({
+                VisionTransformRecipeKeys.MEAN: mean,
+                VisionTransformRecipeKeys.STD: std
+            })
+            self._has_mean_std = True
 
         # --- Validation/Test Pipeline (Deterministic) ---
-        self.val_transform = _PairedCompose([
-            _PairedResize(resize_size),
-            _PairedCenterCrop(crop_size),
-            _PairedToTensor(),
-            _PairedNormalize(mean, std)
-        ])
-
-        # --- Training Pipeline (Augmentation) ---
-        self.train_transform = _PairedCompose([
-            _PairedRandomResizedCrop(crop_size),
-            _PairedRandomHorizontalFlip(p=0.5),
-            _PairedToTensor(),
-            _PairedNormalize(mean, std)
-        ])
+        if self._has_mean_std:
+            self.val_transform = _PairedCompose([
+                _PairedResize(resize_size),
+                _PairedCenterCrop(crop_size),
+                _PairedToTensor(),
+                _PairedNormalize(mean, std) # type: ignore
+            ])
+            # --- Training Pipeline (Augmentation) ---
+            self.train_transform = _PairedCompose([
+                _PairedRandomResizedCrop(crop_size),
+                _PairedRandomHorizontalFlip(p=0.5),
+                _PairedToTensor(),
+                _PairedNormalize(mean, std) # type: ignore
+            ])
+        else:
+            self.val_transform = _PairedCompose([
+                _PairedResize(resize_size),
+                _PairedCenterCrop(crop_size),
+                _PairedToTensor()
+            ])
+            # --- Training Pipeline (Augmentation) ---
+            self.train_transform = _PairedCompose([
+                _PairedRandomResizedCrop(crop_size),
+                _PairedRandomHorizontalFlip(p=0.5),
+                _PairedToTensor()
+            ])
 
         # --- Apply Transforms to the Datasets ---
         self._train_dataset.transform = self.train_transform # type: ignore
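
The same normalization switch can be exercised from user code. A sketch, assuming placeholder paths and split arguments (only from_folders, split_data, and configure_transforms are confirmed by this diff):

maker = SegmentationDatasetMaker.from_folders("data/images", "data/masks")
maker.split_data(...)  # split arguments omitted; not shown in this diff

# Default behaviour: ImageNet mean/std, so the pipelines end with _PairedNormalize
maker.configure_transforms(resize_size=256, crop_size=224)

# Opting out of normalization: both must be None, otherwise a ValueError is raised
# maker.configure_transforms(resize_size=256, crop_size=224, mean=None, std=None)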
@@ -903,23 +1001,57 @@ class SegmentationDatasetMaker(_BaseMaker):
 
         # validate path
         file_path = make_fullpath(filepath, make=True, enforce="file")
-
+
         # Add standard transforms
         recipe: Dict[str, Any] = {
             VisionTransformRecipeKeys.TASK: "segmentation",
             VisionTransformRecipeKeys.PIPELINE: [
-                {VisionTransformRecipeKeys.NAME: "Resize", "kwargs": {"size": components["resize_size"]}},
-                {VisionTransformRecipeKeys.NAME: "CenterCrop", "kwargs": {"size": components["crop_size"]}},
-                {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}},
-                {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
-                    "mean": components["mean"],
-                    "std": components["std"]
-                }}
+                {VisionTransformRecipeKeys.NAME: "Resize", "kwargs": {"size": components[VisionTransformRecipeKeys.RESIZE_SIZE]}},
+                {VisionTransformRecipeKeys.NAME: "CenterCrop", "kwargs": {"size": components[VisionTransformRecipeKeys.CROP_SIZE]}},
+                {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}}
             ]
         }
 
+        if self._has_mean_std:
+            recipe[VisionTransformRecipeKeys.PIPELINE].append(
+                {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
+                    "mean": components[VisionTransformRecipeKeys.MEAN],
+                    "std": components[VisionTransformRecipeKeys.STD]
+                }}
+            )
+
         # Save the file
         save_recipe(recipe, file_path)
+
+    def images_per_dataset(self) -> str:
+        """
+        Get the number of images per dataset as a string.
+        """
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            return f"Train | Validation | Test: {train_len} | {val_len} | {test_len} images\n"
+        else:
+            _LOGGER.warning("No datasets found.")
+            return "No datasets found\n"
+
+    def __repr__(self) -> str:
+        s = f"<{self.__class__.__name__}>:\n"
+        s += f" Total Image-Mask Pairs: {len(self.image_paths)}\n"
+        s += f" Split: {self._is_split}\n"
+        s += f" Transforms Configured: {self._are_transforms_configured}\n"
+
+        if self.class_map:
+            s += f" Classes: {list(self.class_map.keys())}\n"
+
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            s += f" Datasets (Train|Val|Test): {train_len} | {val_len} | {test_len}\n"
+
+        return s
 
 
 # Object detection
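
For orientation, the saved validation recipe now ends with a Normalize entry only when mean/std were configured. A sketch of the resulting structure, assuming the VisionTransformRecipeKeys constants resolve to plain lowercase strings (the actual constants are not shown in this diff):

recipe_with_normalization = {
    "task": "segmentation",
    "pipeline": [
        {"name": "Resize",     "kwargs": {"size": 256}},
        {"name": "CenterCrop", "kwargs": {"size": 224}},
        {"name": "ToTensor",   "kwargs": {}},
        {"name": "Normalize",  "kwargs": {"mean": [0.485, 0.456, 0.406],
                                          "std": [0.229, 0.224, 0.225]}},
    ],
}
# With mean=None and std=None, the pipeline simply stops after ToTensor.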
@@ -1071,6 +1203,7 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
         self.train_transform: Optional[Callable] = None
         self.val_transform: Optional[Callable] = None
         self._val_recipe_components: Optional[Dict[str, Any]] = None
+        self._has_mean_std: bool = False
 
     @classmethod
     def from_folders(cls, image_dir: Union[str, Path], annotation_dir: Union[str, Path]) -> 'ObjectDetectionDatasetMaker':
@@ -1230,8 +1363,8 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
         return self
 
     def configure_transforms(self,
-                             mean: List[float] = [0.485, 0.456, 0.406],
-                             std: List[float] = [0.229, 0.224, 0.225]) -> 'ObjectDetectionDatasetMaker':
+                             mean: Optional[List[float]] = [0.485, 0.456, 0.406],
+                             std: Optional[List[float]] = [0.229, 0.224, 0.225]) -> 'ObjectDetectionDatasetMaker':
         """
         Configures and applies the image and target transformations.
 
@@ -1242,8 +1375,8 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
         Transforms are limited to augmentation (flip), ToTensor, and Normalize.
 
         Args:
-            mean (List[float]): The mean values for image normalization.
-            std (List[float]): The std dev values for image normalization.
+            mean (List[float] | None): The mean values for image normalization.
+            std (List[float] | None): The std dev values for image normalization.
 
         Returns:
             ObjectDetectionDatasetMaker: The same instance, with transforms applied.
@@ -1252,24 +1385,42 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
             _LOGGER.error("Transforms must be configured AFTER splitting data. Call .split_data() first.")
             raise RuntimeError()
 
-        # --- Store components for validation recipe ---
-        self._val_recipe_components = {
-            VisionTransformRecipeKeys.MEAN: mean,
-            VisionTransformRecipeKeys.STD: std
-        }
-
-        # --- Validation/Test Pipeline (Deterministic) ---
-        self.val_transform = _OD_PairedCompose([
-            _OD_PairedToTensor(),
-            _OD_PairedNormalize(mean, std)
-        ])
+        if (mean is None and std is not None) or (mean is not None and std is None):
+            _LOGGER.error(f"'mean' and 'std' must be both None or both defined, but only one was provided.")
+            raise ValueError()
 
-        # --- Training Pipeline (Augmentation) ---
-        self.train_transform = _OD_PairedCompose([
-            _OD_PairedRandomHorizontalFlip(p=0.5),
-            _OD_PairedToTensor(),
-            _OD_PairedNormalize(mean, std)
-        ])
+        if mean is not None and std is not None:
+            # --- Store components for validation recipe ---
+            self._val_recipe_components = {
+                VisionTransformRecipeKeys.MEAN: mean,
+                VisionTransformRecipeKeys.STD: std
+            }
+            self._has_mean_std = True
+
+        if self._has_mean_std:
+            # --- Validation/Test Pipeline (Deterministic) ---
+            self.val_transform = _OD_PairedCompose([
+                _OD_PairedToTensor(),
+                _OD_PairedNormalize(mean, std) # type: ignore
+            ])
+
+            # --- Training Pipeline (Augmentation) ---
+            self.train_transform = _OD_PairedCompose([
+                _OD_PairedRandomHorizontalFlip(p=0.5),
+                _OD_PairedToTensor(),
+                _OD_PairedNormalize(mean, std) # type: ignore
+            ])
+        else:
+            # --- Validation/Test Pipeline (Deterministic) ---
+            self.val_transform = _OD_PairedCompose([
+                _OD_PairedToTensor()
+            ])
+
+            # --- Training Pipeline (Augmentation) ---
+            self.train_transform = _OD_PairedCompose([
+                _OD_PairedRandomHorizontalFlip(p=0.5),
+                _OD_PairedToTensor()
+            ])
 
         # --- Apply Transforms to the Datasets ---
         self._train_dataset.transform = self.train_transform # type: ignore
@@ -1325,10 +1476,6 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
 
         components = self._val_recipe_components
 
-        if not components:
-            _LOGGER.error(f"Error getting the transformers recipe for validation set.")
-            raise ValueError()
-
         # validate path
         file_path = make_fullpath(filepath, make=True, enforce="file")
 
@@ -1337,15 +1484,49 @@ class ObjectDetectionDatasetMaker(_BaseMaker):
             VisionTransformRecipeKeys.TASK: "object_detection",
             VisionTransformRecipeKeys.PIPELINE: [
                 {VisionTransformRecipeKeys.NAME: "ToTensor", "kwargs": {}},
-                {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
-                    "mean": components["mean"],
-                    "std": components["std"]
-                }}
             ]
         }
 
+        if self._has_mean_std and components:
+            recipe[VisionTransformRecipeKeys.PIPELINE].append(
+                {VisionTransformRecipeKeys.NAME: "Normalize", "kwargs": {
+                    "mean": components[VisionTransformRecipeKeys.MEAN],
+                    "std": components[VisionTransformRecipeKeys.STD]
+                }}
+            )
+
         # Save the file
         save_recipe(recipe, file_path)
+
+    def images_per_dataset(self) -> str:
+        """
+        Get the number of images per dataset as a string.
+        """
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            return f"Train | Validation | Test: {train_len} | {val_len} | {test_len} images\n"
+        else:
+            _LOGGER.warning("No datasets found.")
+            return "No datasets found\n"
+
+    def __repr__(self) -> str:
+        s = f"<{self.__class__.__name__}>:\n"
+        s += f" Total Image-Annotation Pairs: {len(self.image_paths)}\n"
+        s += f" Split: {self._is_split}\n"
+        s += f" Transforms Configured: {self._are_transforms_configured}\n"
+
+        if self.class_map:
+            s += f" Classes ({len(self.class_map)}): {list(self.class_map.keys())}\n"
+
+        if self._is_split:
+            train_len = len(self._train_dataset) if self._train_dataset else 0
+            val_len = len(self._val_dataset) if self._val_dataset else 0
+            test_len = len(self._test_dataset) if self._test_dataset else 0
+            s += f" Datasets (Train|Val|Test): {train_len} | {val_len} | {test_len}\n"
+
+        return s
 
 
 def info():
@@ -47,12 +47,17 @@ class _BaseVisionWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.model_name = model_name
+        self._pretrained_default_transforms = None
 
         # --- 2. Instantiate the base model ---
         if init_with_pretrained:
             weights_enum = getattr(vision_models, weights_enum_name, None) if weights_enum_name else None
             weights = weights_enum.IMAGENET1K_V1 if weights_enum else None
 
+            # Save transformations for pretrained models
+            if weights:
+                self._pretrained_default_transforms = weights.transforms()
+
             if weights is None and init_with_pretrained:
                 _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
                 self.model = getattr(vision_models, model_name)(pretrained=True)
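
The stored weights.transforms() value is the preprocessing preset that torchvision ships alongside each weights enum. A standalone illustration with plain torchvision (the model choice here is arbitrary):

from torchvision.models import ResNet18_Weights

weights = ResNet18_Weights.IMAGENET1K_V1
preprocess = weights.transforms()  # resize/crop/normalize preset matching these weights
print(preprocess)                  # e.g. ImageClassification(crop_size=[224], ...)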
@@ -331,6 +336,7 @@ class _BaseSegmentationWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.model_name = model_name
+        self._pretrained_default_transforms = None
 
         # --- 2. Instantiate the base model ---
         model_kwargs = {
@@ -343,6 +349,10 @@ class _BaseSegmentationWrapper(nn.Module, _ArchitectureHandlerMixin, ABC):
             weights_enum = getattr(vision_models.segmentation, weights_enum_name, None) if weights_enum_name else None
             weights = weights_enum.DEFAULT if weights_enum else None
 
+            # save pretrained model transformations
+            if weights:
+                self._pretrained_default_transforms = weights.transforms()
+
             if weights is None:
                 _LOGGER.warning(f"Could not find modern weights for {model_name}. Using 'pretrained=True' legacy fallback.")
                 # Legacy models used 'pretrained=True' and num_classes was separate
@@ -520,7 +530,7 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
     This wrapper allows for customizing the model backbone, input channels,
     and the number of output classes for transfer learning.
 
-    NOTE: This model is NOT compatible with the MLTrainer class.
+    NOTE: This model is NOT compatible with the MLTrainer class. Use the ObjectDetectionTrainer instead.
     """
     def __init__(self,
                  num_classes: int,
@@ -550,6 +560,7 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
         self.num_classes = num_classes
         self.in_channels = in_channels
         self.model_name = model_name
+        self._pretrained_default_transforms = None
 
         # --- 2. Instantiate the base model ---
         model_constructor = getattr(detection_models, model_name)
@@ -560,6 +571,9 @@ class DragonFastRCNN(nn.Module, _ArchitectureHandlerMixin):
 
         weights_enum = getattr(detection_models, weights_enum_name, None) if weights_enum_name else None
         weights = weights_enum.DEFAULT if weights_enum and init_with_pretrained else None
+
+        if weights:
+            self._pretrained_default_transforms = weights.transforms()
 
         self.model = model_constructor(weights=weights, weights_backbone=weights)