autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
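Note: the bulk of this release is a module reorganization: `optimization/` becomes `optim/` (with losses, lr schedules, and metrics split into subpackages), `configs/environment` becomes `configs/env`, and top-level helpers such as `presets.py`, `problem_types.py`, and `registry.py` move under `utils/`. A minimal sketch of how deep imports would need to change, assuming no backward-compatibility shims are provided for the old paths:

    # Old paths (1.2.1b20250303):
    # from autogluon.multimodal.optimization import lit_module
    # from autogluon.multimodal import presets, problem_types

    # New paths (1.2.1b20250305), following the renames listed above
    from autogluon.multimodal.optim import lit_module
    from autogluon.multimodal.utils import presets, problem_types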
autogluon/multimodal/data/process_image.py

@@ -1,4 +1,7 @@
+import ast
+import copy
 import logging
+import random
 import warnings
 from io import BytesIO
 from typing import Callable, Dict, List, Optional, Union
@@ -6,24 +9,25 @@ from typing import Callable, Dict, List, Optional, Union
 import numpy as np
 import PIL
 import torch
+from omegaconf import ListConfig
 from PIL import ImageFile
 from torch import nn
 from torchvision import transforms

-from .utils import construct_image_processor, image_mean_std
+from .randaug import RandAugment
+from .trivial_augmenter import TrivialAugment

 try:
     from torchvision.transforms import InterpolationMode

     BICUBIC = InterpolationMode.BICUBIC
+    NEAREST = InterpolationMode.NEAREST
 except ImportError:
     BICUBIC = PIL.Image.BICUBIC
+    NEAREST = PIL.Image.NEAREST

-from ..constants import CLIP, COLUMN, IMAGE, IMAGE_BASE64_STR, IMAGE_BYTEARRAY, IMAGE_VALID_NUM, TIMM_IMAGE
-from ..models.clip import CLIPForImageText
-from ..models.timm_image import TimmAutoModelForImagePrediction
+from ..constants import COLUMN, IMAGE, IMAGE_BASE64_STR, IMAGE_BYTEARRAY, IMAGE_VALID_NUM
 from .collator import PadCollator, StackCollator
-from .utils import extract_value_from_config

 logger = logging.getLogger(__name__)
 ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -40,11 +44,10 @@ class ImageProcessor:
         model: nn.Module,
         train_transforms: Union[List[str], Callable, List[Callable]],
         val_transforms: Union[List[str], Callable, List[Callable]],
-        norm_type: Optional[str] = None,
-        size: Optional[int] = None,
-        max_img_num_per_col: Optional[int] = 1,
+        max_image_num_per_column: Optional[int] = 1,
         missing_value_strategy: Optional[str] = "zero",
-        requires_column_info: bool = False,
+        requires_column_info: Optional[bool] = False,
+        dropout: Optional[float] = 0,
     ):
         """
         Parameters
@@ -55,18 +58,7 @@ class ImageProcessor:
             A list of image transforms used in training. Note that the transform order matters.
         val_transforms
             A list of image transforms used in validation/test/prediction. Note that the transform order matters.
-        norm_type
-            How to normalize an image. We now support:
-            - inception
-                Normalize image by IMAGENET_INCEPTION_MEAN and IMAGENET_INCEPTION_STD from timm
-            - imagenet
-                Normalize image by IMAGENET_DEFAULT_MEAN and IMAGENET_DEFAULT_STD from timm
-            - clip
-                Normalize image by mean (0.48145466, 0.4578275, 0.40821073) and
-                std (0.26862954, 0.26130258, 0.27577711), used for CLIP.
-        size
-            The provided width / height of a square image.
-        max_img_num_per_col
+        max_image_num_per_column
             The maximum number of images one sample can have.
         missing_value_strategy
             How to deal with a missing image. We now support:
@@ -77,6 +69,7 @@ class ImageProcessor:
         requires_column_info
             Whether to require feature column information in dataloader.
         """
+        logger.debug(f"initializing image processor for model {model.prefix}")
         self.train_transforms = train_transforms
         self.val_transforms = val_transforms
         logger.debug(f"image training transforms: {self.train_transforms}")
@@ -85,62 +78,25 @@ class ImageProcessor:
         self.prefix = model.prefix
         self.missing_value_strategy = missing_value_strategy
         self.requires_column_info = requires_column_info
-        self.size = None
-        self.mean = None
-        self.std = None
-        if isinstance(model, CLIPForImageText):
-            config = model.model.vision_model.config
-        else:
-            config = model.config
-
-        if model is not None:
-            self.size, self.mean, self.std = self.extract_default(config)
-            if isinstance(model, TimmAutoModelForImagePrediction):
-                if model.support_variable_input_size() and size is not None:
-                    # We have detected that the model supports using an image size that is
-                    # different from the pretrained model, e.g., ConvNets with global pooling
-                    if size < self.size:
-                        logger.warning(
-                            f"The provided image size={size} is smaller than the default size "
-                            f"of the pretrained backbone, which is {self.size}. "
-                            f"Detailed configuration of the backbone is in {config}. "
-                            f"You may like to double check your configuration."
-                        )
-                    self.size = size
-                elif size is not None and size != self.size:
-                    logger.warning(
-                        f"The model does not support using an image size that is different from the default size. "
-                        f"Provided image size={size}. Default size={self.size}. "
-                        f"Detailed model configuration={config}. We have ignored the provided image size."
-                    )
-        if self.size is None:
-            if size is not None:
-                self.size = size
-                logger.debug(f"using provided image size: {self.size}")
-            else:
-                raise ValueError("image size is missing")
-        else:
-            logger.debug(f"using detected image size: {self.size}")
-        if self.mean is None or self.std is None:
-            if norm_type is not None:
-                self.mean, self.std = image_mean_std(norm_type)
-                logger.debug(f"using provided normalization: {norm_type}")
-            else:
-                raise ValueError("image normalization mean and std are missing")
-        else:
-            logger.debug(f"using detected image normalization: {self.mean} and {self.std}")
+        assert 0 <= dropout <= 1
+        if dropout > 0:
+            logger.debug(f"image dropout probability: {dropout}")
+        self.dropout = dropout
+        self.size = model.image_size
+        self.mean = model.image_mean
+        self.std = model.image_std
+
         self.normalization = transforms.Normalize(self.mean, self.std)
-        self.max_img_num_per_col = max_img_num_per_col
-        if max_img_num_per_col <= 0:
-            logger.debug(f"max_img_num_per_col {max_img_num_per_col} is reset to 1")
-            max_img_num_per_col = 1
-        self.max_img_num_per_col = max_img_num_per_col
-        logger.debug(f"max_img_num_per_col: {max_img_num_per_col}")
-
-        self.train_processor = construct_image_processor(
+        if max_image_num_per_column <= 0:
+            logger.debug(f"max_image_num_per_column {max_image_num_per_column} is reset to 1")
+            max_image_num_per_column = 1
+        self.max_image_num_per_column = max_image_num_per_column
+        logger.debug(f"max_image_num_per_column: {max_image_num_per_column}")
+
+        self.train_processor = self.construct_image_processor(
             image_transforms=self.train_transforms, size=self.size, normalization=self.normalization
         )
-        self.val_processor = construct_image_processor(
+        self.val_processor = self.construct_image_processor(
            image_transforms=self.val_transforms, size=self.size, normalization=self.normalization
         )

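The constructor above no longer probes timm/CLIP configs for the image size and normalization statistics; it now expects every image model to expose `image_size`, `image_mean`, and `image_std` directly. A minimal sketch of the new contract, using a hypothetical stub in place of a real backbone:

    class FakeImageBackbone:
        # Hypothetical stand-in; only the attributes read by the new __init__ are stubbed.
        prefix = "timm_image"
        image_size = 224
        image_mean = (0.485, 0.456, 0.406)  # ImageNet defaults, for illustration only
        image_std = (0.229, 0.224, 0.225)

    processor = ImageProcessor(
        model=FakeImageBackbone(),
        train_transforms=["resize_shorter_side", "center_crop", "trivial_augment"],
        val_transforms=["resize_shorter_side", "center_crop"],
        dropout=0.1,
    )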
@@ -159,7 +115,7 @@ class ImageProcessor:
     def collate_fn(self, image_column_names: Optional[List] = None, per_gpu_batch_size: Optional[int] = None) -> Dict:
         """
         Collate images into a batch. Here it pads images since the image number may
-        vary from sample to sample. Samples with less images will be padded zeros.
+        vary from sample to sample. Samples with fewer images will be padded zeros.
         The valid image numbers of samples will be stacked into a vector.
         This function will be used when creating Pytorch DataLoader.

@@ -182,65 +138,23 @@ class ImageProcessor:

         return fn

-    def extract_default(self, config=None):
-        """
-        Extract some default hyper-parameters, e.g., image size, mean, and std,
-        from a pre-trained (timm or huggingface) checkpoint.
-
-        Parameters
-        ----------
-        config
-            Config of a pre-trained checkpoint.
-
-        Returns
-        -------
-        image_size
-            Image width/height.
-        mean
-            Image normalization mean.
-        std
-            Image normalization std.
-        """
-        if self.prefix.lower().startswith(TIMM_IMAGE):
-            image_size = config["input_size"][-1]
-            mean = config["mean"]
-            std = config["std"]
-        elif self.prefix.lower().startswith(CLIP):
-            extracted = extract_value_from_config(
-                config=config.to_diff_dict(),
-                keys=("image_size",),
-            )
-            if len(extracted) == 0:
-                image_size = None
-            elif len(extracted) >= 1:
-                image_size = extracted[0]
-                if isinstance(image_size, tuple):
-                    image_size = image_size[-1]
-            else:
-                raise ValueError(f" more than one image_size values are detected: {extracted}")
-            mean = None
-            std = None
-        else:
-            raise ValueError(f"Unknown image processor prefix: {self.prefix}")
-        return image_size, mean, std
-
     def process_one_sample(
         self,
-        image_features: Dict[str, Union[List[str], List[bytearray]]],
-        feature_modalities: Dict[str, List[str]],
+        images: Dict[str, Union[List[str], List[bytearray]]],
+        sub_dtypes: Dict[str, str],
         is_training: bool,
         image_mode: Optional[str] = "RGB",
     ) -> Dict:
         """
         Read images, process them, and stack them. One sample can have multiple images,
-        resulting in a tensor of (n, 3, size, size), where n <= max_img_num_per_col is the available image number.
+        resulting in a tensor of (n, 3, size, size), where n <= max_image_num_per_column is the available image number.

         Parameters
         ----------
-        image_features
+        images
             One sample may have multiple image columns in a pd.DataFrame and multiple images
             inside each image column.
-        feature_modalities
+        sub_dtypes
             What modality each column belongs to.
         is_training
             Whether to process images in the training mode.
@@ -252,35 +166,38 @@ class ImageProcessor:
         -------
         A dictionary containing one sample's images and their number.
         """
-        images = []
+        valid_images = []
         zero_images = []
         ret = {}
         column_start = 0

-        for per_col_name, per_col_image_features in image_features.items():
-            for img_feature in per_col_image_features[: self.max_img_num_per_col]:
-                with warnings.catch_warnings():
-                    warnings.filterwarnings(
-                        "ignore",
-                        message=(
-                            "Palette images with Transparency expressed in bytes should be converted to RGBA images"
-                        ),
-                    )
-                    is_zero_img = False
-                    try:
-                        if feature_modalities.get(per_col_name) in [IMAGE_BYTEARRAY, IMAGE_BASE64_STR]:
-                            image_feature = BytesIO(img_feature)
-                        else:
-                            image_feature = img_feature
-                        with PIL.Image.open(image_feature) as img:
-                            img = img.convert(image_mode)
-                    except Exception as e:
-                        if self.missing_value_strategy.lower() == "zero":
-                            logger.debug(f"Using a zero image due to '{e}'")
-                            img = PIL.Image.new(image_mode, (self.size, self.size), color=0)
-                            is_zero_img = True
-                        else:
-                            raise e
+        for per_col_name, per_col_image_raw in images.items():
+            for img_raw in per_col_image_raw[: self.max_image_num_per_column]:
+                if is_training and self.dropout > 0 and random.uniform(0, 1) <= self.dropout:
+                    img = PIL.Image.new(image_mode, (self.size, self.size), color=0)
+                    is_zero_img = True
+                else:
+                    with warnings.catch_warnings():
+                        warnings.filterwarnings(
+                            "ignore",
+                            message=(
+                                "Palette images with Transparency expressed in bytes should be converted to RGBA images"
+                            ),
+                        )
+                        is_zero_img = False
+                        try:
+                            if sub_dtypes.get(per_col_name) in [IMAGE_BYTEARRAY, IMAGE_BASE64_STR]:
+                                img_raw = BytesIO(img_raw)
+
+                            with PIL.Image.open(img_raw) as img:
+                                img = img.convert(image_mode)
+                        except Exception as e:
+                            if self.missing_value_strategy.lower() == "zero":
+                                logger.debug(f"Using a zero image due to '{e}'")
+                                img = PIL.Image.new(image_mode, (self.size, self.size), color=0)
+                                is_zero_img = True
+                            else:
+                                raise e
                 if is_training:
                     img = self.train_processor(img)
                 else:
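The new branch at the top of the loop implements training-time image dropout: with probability `dropout`, the raw image is replaced by an all-zero placeholder before any transform runs, and it is counted as a zero image rather than a valid one. A standalone restatement of that branch (the helper name is hypothetical):

    import random
    import PIL.Image

    def maybe_drop_image(img_raw, dropout: float, size: int, is_training: bool):
        # During training, drop the image with probability `dropout` by
        # substituting a black square of the processor's target size.
        if is_training and dropout > 0 and random.uniform(0, 1) <= dropout:
            return PIL.Image.new("RGB", (size, size), color=0), True  # (image, is_zero_img)
        return img_raw, False  # load and transform the real image as usual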
@@ -289,29 +206,152 @@ class ImageProcessor:
             if is_zero_img:
                 zero_images.append(img)
             else:
-                images.append(img)
+                valid_images.append(img)

             if self.requires_column_info:
                 # only count the valid images since they are put ahead of the zero images in the below returning
                 ret[f"{self.image_column_prefix}_{per_col_name}"] = np.array(
-                    [column_start, len(images)], dtype=np.int64
+                    [column_start, len(valid_images)], dtype=np.int64
                 )
-                column_start = len(images)
+                column_start = len(valid_images)

         ret.update(
             {
                 self.image_key: torch.tensor([])
-                if len(images + zero_images) == 0
-                else torch.stack(images + zero_images, dim=0),
-                self.image_valid_num_key: len(images),
+                if len(valid_images + zero_images) == 0
+                else torch.stack(valid_images + zero_images, dim=0),
+                self.image_valid_num_key: len(valid_images),
             }
         )
         return ret

+    @staticmethod
+    def get_image_transform_funcs(transform_types: Union[List[str], ListConfig, List[Callable]], size: int):
+        """
+        Parse a list of transform strings into callable objects.
+
+        Parameters
+        ----------
+        transform_types
+            A list of transforms, which can be strings or callable objects.
+        size
+            Image size.
+
+        Returns
+        -------
+        A list of transform objects.
+        """
+        image_transforms = []
+
+        if not transform_types:
+            return image_transforms
+
+        if isinstance(transform_types, ListConfig):
+            transform_types = list(transform_types)
+        elif not isinstance(transform_types, list):
+            transform_types = [transform_types]
+
+        if all([isinstance(trans_type, str) for trans_type in transform_types]):
+            pass
+        elif all([isinstance(trans_type, Callable) for trans_type in transform_types]):
+            return copy.copy(transform_types)
+        else:
+            raise ValueError(
+                f"transform_types {transform_types} contain neither all strings nor all callable objects."
+            )
+
+        for trans_type in transform_types:
+            args = None
+            kargs = None
+            if "(" in trans_type:
+                trans_mode = trans_type[0 : trans_type.find("(")]
+                if "{" in trans_type:
+                    kargs = ast.literal_eval(trans_type[trans_type.find("{") : trans_type.rfind(")")])
+                else:
+                    args = ast.literal_eval(trans_type[trans_type.find("(") :])
+            else:
+                trans_mode = trans_type
+
+            if trans_mode == "resize_to_square":
+                image_transforms.append(transforms.Resize((size, size), interpolation=BICUBIC))
+            elif trans_mode == "resize_gt_to_square":
+                image_transforms.append(transforms.Resize((size, size), interpolation=NEAREST))
+            elif trans_mode == "resize_shorter_side":
+                image_transforms.append(transforms.Resize(size, interpolation=BICUBIC))
+            elif trans_mode == "center_crop":
+                image_transforms.append(transforms.CenterCrop(size))
+            elif trans_mode == "random_resize_crop":
+                image_transforms.append(transforms.RandomResizedCrop(size))
+            elif trans_mode == "random_horizontal_flip":
+                image_transforms.append(transforms.RandomHorizontalFlip())
+            elif trans_mode == "random_vertical_flip":
+                image_transforms.append(transforms.RandomVerticalFlip())
+            elif trans_mode == "color_jitter":
+                if kargs is not None:
+                    image_transforms.append(transforms.ColorJitter(**kargs))
+                elif args is not None:
+                    image_transforms.append(transforms.ColorJitter(*args))
+                else:
+                    image_transforms.append(transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1))
+            elif trans_mode == "affine":
+                if kargs is not None:
+                    image_transforms.append(transforms.RandomAffine(**kargs))
+                elif args is not None:
+                    image_transforms.append(transforms.RandomAffine(*args))
+                else:
+                    image_transforms.append(
+                        transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1))
+                    )
+            elif trans_mode == "randaug":
+                if kargs is not None:
+                    image_transforms.append(RandAugment(**kargs))
+                elif args is not None:
+                    image_transforms.append(RandAugment(*args))
+                else:
+                    image_transforms.append(RandAugment(2, 9))
+            elif trans_mode == "trivial_augment":
+                image_transforms.append(TrivialAugment(IMAGE, 30))
+            else:
+                raise ValueError(f"unknown transform type: {trans_mode}")
+
+        return image_transforms
+
+    def construct_image_processor(
+        self,
+        image_transforms: Union[List[Callable], List[str]],
+        size: int,
+        normalization,
+    ) -> transforms.Compose:
+        """
+        Build up an image processor from the provided list of transform types.
+
+        Parameters
+        ----------
+        image_transforms
+            A list of image transform types.
+        size
+            Image size.
+        normalization
+            A transforms.Normalize object. When the image is ground truth image, 'normalization=None' should be specified.
+
+        Returns
+        -------
+        A transforms.Compose object.
+        """
+        image_transforms = self.get_image_transform_funcs(transform_types=image_transforms, size=size)
+        if not any([isinstance(trans, transforms.ToTensor) for trans in image_transforms]):
+            image_transforms.append(transforms.ToTensor())
+        if (
+            not any([isinstance(trans, transforms.Normalize) for trans in image_transforms])
+            and normalization is not None
+        ):
+            image_transforms.append(normalization)
+        return transforms.Compose(image_transforms)
+
     def __call__(
         self,
         images: Dict[str, List[str]],
-        feature_modalities: Dict[str, Union[int, float, list]],
+        sub_dtypes: Dict[str, str],
         is_training: bool,
     ) -> Dict:
         """
@@ -321,8 +361,8 @@ class ImageProcessor:
         ----------
         images
             Images of one sample.
-        feature_modalities
-            The modality of the feature columns.
+        sub_dtypes
+            The sub data types of all image columns.
         is_training
             Whether to process images in the training mode.

@@ -332,7 +372,7 @@ class ImageProcessor:
         """
         images = {k: [v] if isinstance(v, str) else v for k, v in images.items()}

-        return self.process_one_sample(images, feature_modalities, is_training)
+        return self.process_one_sample(images=images, sub_dtypes=sub_dtypes, is_training=is_training)

     def __getstate__(self):
         odict = self.__dict__.copy()  # get attribute dictionary
@@ -341,12 +381,7 @@ class ImageProcessor:

     def __setstate__(self, state):
         self.__dict__ = state
-        if "train_transform_types" in state:  # backward compatible
-            self.train_transforms = list(self.train_transform_types)
-        if "val_transform_types" in state:
-            self.val_transforms = list(self.val_transform_types)
-
-        self.train_processor = construct_image_processor(
+        self.train_processor = self.construct_image_processor(
             image_transforms=self.train_transforms,
             size=self.size,
             normalization=self.normalization,
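`get_image_transform_funcs` replaces the old module-level `construct_image_processor` helper and parses transform strings, including inline argument literals, via `ast.literal_eval`. A short usage sketch based on the string grammar shown above:

    # Bare names use defaults; "(...)" supplies positional args and "({...})"
    # supplies keyword args, both parsed with ast.literal_eval.
    funcs = ImageProcessor.get_image_transform_funcs(
        transform_types=[
            "resize_shorter_side",               # -> transforms.Resize(224, interpolation=BICUBIC)
            "center_crop",                       # -> transforms.CenterCrop(224)
            "randaug(2, 9)",                     # -> RandAugment(2, 9) via the args branch
            "color_jitter({'brightness': 0.2})", # -> ColorJitter(brightness=0.2) via the kargs branch
        ],
        size=224,
    )
    assert all(callable(f) for f in funcs)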
autogluon/multimodal/data/process_label.py

@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Dict, List, Optional, Union

 from torch import nn
@@ -5,6 +6,8 @@ from torch import nn
 from ..constants import LABEL, MMDET_IMAGE
 from .collator import ListCollator, StackCollator

+logger = logging.getLogger(__name__)
+

 class LabelProcessor:
     """
@@ -23,6 +26,7 @@ class LabelProcessor:
         model
             The model for which this processor would be created.
         """
+        logger.debug(f"initializing label processor for model {model.prefix}")
         self.prefix = model.prefix

     @property
@@ -68,7 +72,7 @@ class LabelProcessor:
     def __call__(
         self,
         labels: Dict[str, Union[int, float]],
-        feature_modalities: Dict[str, Union[int, float, list]],
+        sub_dtypes: Dict[str, str],
         is_training: bool,
         load_only: bool = False,  # TODO: refactor mmdet_image and remove this
     ) -> Dict:
@@ -79,8 +83,8 @@ class LabelProcessor:
         ----------
         labels
             Labels of one sample.
-        feature_modalities
-            The modality of the feature columns.
+        sub_dtypes
+            The sub data types of all label columns.
         is_training
             Whether to do processing in the training mode. This unused flag is for the API compatibility.
         load_only
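Across the data processors in this release, `feature_modalities` has been renamed to `sub_dtypes: Dict[str, str]`. A minimal sketch of the shared per-sample interface after the rename (`DummyProcessor` is hypothetical and only illustrates the signature):

    from typing import Dict, Union

    class DummyProcessor:
        def __call__(
            self,
            labels: Dict[str, Union[int, float]],
            sub_dtypes: Dict[str, str],  # was: feature_modalities
            is_training: bool,
            load_only: bool = False,
        ) -> Dict:
            # A real processor keys its outputs by model prefix; this stub
            # just echoes the inputs to illustrate the call signature.
            return {"labels": labels, "sub_dtypes": sub_dtypes}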
autogluon/multimodal/data/process_mmlab/process_mmdet.py

@@ -6,14 +6,7 @@ import PIL
 from PIL import ImageFile
 from torch import nn

-try:
-    from torchvision.transforms import InterpolationMode
-
-    BICUBIC = InterpolationMode.BICUBIC
-except ImportError:
-    BICUBIC = PIL.Image.BICUBIC
-
-from ..utils import is_rois_input
+from ..infer_types import is_rois_input
 from .process_mmlab_base import MMLabProcessor

 try:
autogluon/multimodal/data/process_mmlab/process_mmlab_base.py

@@ -7,16 +7,9 @@ import PIL
 from PIL import ImageFile
 from torch import nn

-try:
-    from torchvision.transforms import InterpolationMode
-
-    BICUBIC = InterpolationMode.BICUBIC
-except ImportError:
-    BICUBIC = PIL.Image.BICUBIC
-
-from ...constants import AUTOMM, COLUMN, IMAGE, IMAGE_VALID_NUM, MMDET_IMAGE
+from ...constants import COLUMN, IMAGE, IMAGE_VALID_NUM, MMDET_IMAGE
 from ..collator import StackCollator
-from ..utils import is_rois_input
+from ..infer_types import is_rois_input

 try:
     with warnings.catch_warnings():
autogluon/multimodal/data/process_mmlab/process_mmocr.py

@@ -6,15 +6,7 @@ import PIL
 from PIL import ImageFile
 from torch import nn

-try:
-    from torchvision.transforms import InterpolationMode
-
-    BICUBIC = InterpolationMode.BICUBIC
-except ImportError:
-    BICUBIC = PIL.Image.BICUBIC
-
-from ...constants import AUTOMM
-from ..utils import is_rois_input
+from ..infer_types import is_rois_input
 from .process_mmlab_base import MMLabProcessor

 logger = logging.getLogger(__name__)
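All three mmlab processors now import `is_rois_input` from `data/infer_types.py` instead of `data/utils.py`, and drop their redundant local `BICUBIC` fallbacks. Resolving the relative import shown above from `autogluon/multimodal/data/process_mmlab/`, external code would use:

    # Old location: autogluon.multimodal.data.utils
    from autogluon.multimodal.data.infer_types import is_rois_input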