autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
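Several of the changes above are renames rather than in-place edits: configs/environment becomes configs/env, configs/optimization becomes configs/optim, the optimization package becomes optim, and presets.py, problem_types.py, and registry.py move under utils/. Downstream code that imports these modules by their old dotted paths would have to follow the new layout; a minimal sketch of the adjustment, assuming no compatibility aliases are kept for the removed paths:

# Hypothetical downstream import updates implied by the renamed paths listed above.
# 1.2.1b20250303:
#   from autogluon.multimodal.optimization import lit_module
#   from autogluon.multimodal import presets, problem_types
# 1.2.1b20250305:
from autogluon.multimodal.optim import lit_module
from autogluon.multimodal.utils import presets, problem_types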
autogluon/multimodal/data/process_image.py

@@ -1,4 +1,7 @@
+import ast
+import copy
 import logging
+import random
 import warnings
 from io import BytesIO
 from typing import Callable, Dict, List, Optional, Union
@@ -6,24 +9,25 @@ from typing import Callable, Dict, List, Optional, Union
 import numpy as np
 import PIL
 import torch
+from omegaconf import ListConfig
 from PIL import ImageFile
 from torch import nn
 from torchvision import transforms
 
-from .
+from .randaug import RandAugment
+from .trivial_augmenter import TrivialAugment
 
 try:
     from torchvision.transforms import InterpolationMode
 
     BICUBIC = InterpolationMode.BICUBIC
+    NEAREST = InterpolationMode.NEAREST
 except ImportError:
     BICUBIC = PIL.Image.BICUBIC
+    NEAREST = PIL.Image.NEAREST
 
-from ..constants import
-from ..models.clip import CLIPForImageText
-from ..models.timm_image import TimmAutoModelForImagePrediction
+from ..constants import COLUMN, IMAGE, IMAGE_BASE64_STR, IMAGE_BYTEARRAY, IMAGE_VALID_NUM
 from .collator import PadCollator, StackCollator
-from .utils import extract_value_from_config
 
 logger = logging.getLogger(__name__)
 ImageFile.LOAD_TRUNCATED_IMAGES = True
@@ -40,11 +44,10 @@ class ImageProcessor:
         model: nn.Module,
         train_transforms: Union[List[str], Callable, List[Callable]],
         val_transforms: Union[List[str], Callable, List[Callable]],
-
-        size: Optional[int] = None,
-        max_img_num_per_col: Optional[int] = 1,
+        max_image_num_per_column: Optional[int] = 1,
         missing_value_strategy: Optional[str] = "zero",
-        requires_column_info: bool = False,
+        requires_column_info: Optional[bool] = False,
+        dropout: Optional[float] = 0,
     ):
         """
         Parameters
@@ -55,18 +58,7 @@ class ImageProcessor:
             A list of image transforms used in training. Note that the transform order matters.
         val_transforms
             A list of image transforms used in validation/test/prediction. Note that the transform order matters.
-
-            How to normalize an image. We now support:
-            - inception
-                Normalize image by IMAGENET_INCEPTION_MEAN and IMAGENET_INCEPTION_STD from timm
-            - imagenet
-                Normalize image by IMAGENET_DEFAULT_MEAN and IMAGENET_DEFAULT_STD from timm
-            - clip
-                Normalize image by mean (0.48145466, 0.4578275, 0.40821073) and
-                std (0.26862954, 0.26130258, 0.27577711), used for CLIP.
-        size
-            The provided width / height of a square image.
-        max_img_num_per_col
+        max_image_num_per_column
             The maximum number of images one sample can have.
         missing_value_strategy
             How to deal with a missing image. We now support:
@@ -77,6 +69,7 @@ class ImageProcessor:
         requires_column_info
             Whether to require feature column information in dataloader.
         """
+        logger.debug(f"initializing image processor for model {model.prefix}")
         self.train_transforms = train_transforms
         self.val_transforms = val_transforms
         logger.debug(f"image training transforms: {self.train_transforms}")
@@ -85,62 +78,25 @@ class ImageProcessor:
         self.prefix = model.prefix
         self.missing_value_strategy = missing_value_strategy
         self.requires_column_info = requires_column_info
-
-
-
-
-
-
-
-
-        if model is not None:
-            self.size, self.mean, self.std = self.extract_default(config)
-            if isinstance(model, TimmAutoModelForImagePrediction):
-                if model.support_variable_input_size() and size is not None:
-                    # We have detected that the model supports using an image size that is
-                    # different from the pretrained model, e.g., ConvNets with global pooling
-                    if size < self.size:
-                        logger.warning(
-                            f"The provided image size={size} is smaller than the default size "
-                            f"of the pretrained backbone, which is {self.size}. "
-                            f"Detailed configuration of the backbone is in {config}. "
-                            f"You may like to double check your configuration."
-                        )
-                    self.size = size
-                elif size is not None and size != self.size:
-                    logger.warning(
-                        f"The model does not support using an image size that is different from the default size. "
-                        f"Provided image size={size}. Default size={self.size}. "
-                        f"Detailed model configuration={config}. We have ignored the provided image size."
-                    )
-            if self.size is None:
-                if size is not None:
-                    self.size = size
-                    logger.debug(f"using provided image size: {self.size}")
-                else:
-                    raise ValueError("image size is missing")
-            else:
-                logger.debug(f"using detected image size: {self.size}")
-            if self.mean is None or self.std is None:
-                if norm_type is not None:
-                    self.mean, self.std = image_mean_std(norm_type)
-                    logger.debug(f"using provided normalization: {norm_type}")
-                else:
-                    raise ValueError("image normalization mean and std are missing")
-            else:
-                logger.debug(f"using detected image normalization: {self.mean} and {self.std}")
+        assert 0 <= dropout <= 1
+        if dropout > 0:
+            logger.debug(f"image dropout probability: {dropout}")
+        self.dropout = dropout
+        self.size = model.image_size
+        self.mean = model.image_mean
+        self.std = model.image_std
+
         self.normalization = transforms.Normalize(self.mean, self.std)
-
-
-
-
-
-
-
-        self.train_processor = construct_image_processor(
+        if max_image_num_per_column <= 0:
+            logger.debug(f"max_image_num_per_column {max_image_num_per_column} is reset to 1")
+            max_image_num_per_column = 1
+        self.max_image_num_per_column = max_image_num_per_column
+        logger.debug(f"max_image_num_per_column: {max_image_num_per_column}")
+
+        self.train_processor = self.construct_image_processor(
             image_transforms=self.train_transforms, size=self.size, normalization=self.normalization
         )
-        self.val_processor = construct_image_processor(
+        self.val_processor = self.construct_image_processor(
             image_transforms=self.val_transforms, size=self.size, normalization=self.normalization
         )
 
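With this change the processor no longer receives size or normalization type directly; it reads image_size, image_mean, and image_std from the model object, and it gains an image dropout probability plus the renamed max_image_num_per_column argument. A minimal construction sketch, assuming the constructor only touches the model attributes visible in the hunk above (the stub model and its values are illustrative, not AutoGluon API):

from torch import nn

from autogluon.multimodal.data.process_image import ImageProcessor


class StubImageModel(nn.Module):
    # Only the attributes the new __init__ appears to read are provided here (assumed sufficient).
    prefix = "timm_image"
    image_size = 224
    image_mean = (0.485, 0.456, 0.406)
    image_std = (0.229, 0.224, 0.225)


processor = ImageProcessor(
    model=StubImageModel(),
    train_transforms=["resize_shorter_side", "center_crop", "random_horizontal_flip"],
    val_transforms=["resize_shorter_side", "center_crop"],
    max_image_num_per_column=2,  # renamed from max_img_num_per_col
    dropout=0.1,  # new: chance of replacing a training image with an all-zero image
)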
@@ -159,7 +115,7 @@ class ImageProcessor:
     def collate_fn(self, image_column_names: Optional[List] = None, per_gpu_batch_size: Optional[int] = None) -> Dict:
         """
         Collate images into a batch. Here it pads images since the image number may
-        vary from sample to sample. Samples with
+        vary from sample to sample. Samples with fewer images will be padded zeros.
         The valid image numbers of samples will be stacked into a vector.
         This function will be used when creating Pytorch DataLoader.
 
@@ -182,65 +138,23 @@ class ImageProcessor:
 
         return fn
 
-    def extract_default(self, config=None):
-        """
-        Extract some default hyper-parameters, e.g., image size, mean, and std,
-        from a pre-trained (timm or huggingface) checkpoint.
-
-        Parameters
-        ----------
-        config
-            Config of a pre-trained checkpoint.
-
-        Returns
-        -------
-        image_size
-            Image width/height.
-        mean
-            Image normalization mean.
-        std
-            Image normalizaiton std.
-        """
-        if self.prefix.lower().startswith(TIMM_IMAGE):
-            image_size = config["input_size"][-1]
-            mean = config["mean"]
-            std = config["std"]
-        elif self.prefix.lower().startswith(CLIP):
-            extracted = extract_value_from_config(
-                config=config.to_diff_dict(),
-                keys=("image_size",),
-            )
-            if len(extracted) == 0:
-                image_size = None
-            elif len(extracted) >= 1:
-                image_size = extracted[0]
-                if isinstance(image_size, tuple):
-                    image_size = image_size[-1]
-            else:
-                raise ValueError(f" more than one image_size values are detected: {extracted}")
-            mean = None
-            std = None
-        else:
-            raise ValueError(f"Unknown image processor prefix: {self.prefix}")
-        return image_size, mean, std
-
     def process_one_sample(
         self,
-
-
+        images: Dict[str, Union[List[str], List[bytearray]]],
+        sub_dtypes: Dict[str, str],
         is_training: bool,
         image_mode: Optional[str] = "RGB",
     ) -> Dict:
         """
         Read images, process them, and stack them. One sample can have multiple images,
-        resulting in a tensor of (n, 3, size, size), where n <=
+        resulting in a tensor of (n, 3, size, size), where n <= max_image_num_per_column is the available image number.
 
         Parameters
         ----------
-
+        images
             One sample may have multiple image columns in a pd.DataFrame and multiple images
             inside each image column.
-
+        sub_dtypes
             What modality each column belongs to.
         is_training
             Whether to process images in the training mode.
@@ -252,35 +166,38 @@ class ImageProcessor:
         -------
         A dictionary containing one sample's images and their number.
         """
-
+        valid_images = []
         zero_images = []
         ret = {}
         column_start = 0
 
-        for per_col_name,
-            for
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        for per_col_name, per_col_image_raw in images.items():
+            for img_raw in per_col_image_raw[: self.max_image_num_per_column]:
+                if is_training and self.dropout > 0 and random.uniform(0, 1) <= self.dropout:
+                    img = PIL.Image.new(image_mode, (self.size, self.size), color=0)
+                    is_zero_img = True
+                else:
+                    with warnings.catch_warnings():
+                        warnings.filterwarnings(
+                            "ignore",
+                            message=(
+                                "Palette images with Transparency expressed in bytes should be converted to RGBA images"
+                            ),
+                        )
+                        is_zero_img = False
+                        try:
+                            if sub_dtypes.get(per_col_name) in [IMAGE_BYTEARRAY, IMAGE_BASE64_STR]:
+                                img_raw = BytesIO(img_raw)
+
+                            with PIL.Image.open(img_raw) as img:
+                                img = img.convert(image_mode)
+                        except Exception as e:
+                            if self.missing_value_strategy.lower() == "zero":
+                                logger.debug(f"Using a zero image due to '{e}'")
+                                img = PIL.Image.new(image_mode, (self.size, self.size), color=0)
+                                is_zero_img = True
+                            else:
+                                raise e
                 if is_training:
                     img = self.train_processor(img)
                 else:
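The new loading branch distinguishes bytes-like image columns (IMAGE_BYTEARRAY / IMAGE_BASE64_STR sub-dtypes) from path columns: raw bytes are wrapped in BytesIO before PIL.Image.open, while paths are opened directly. A standalone sketch of just that branch (the helper name is illustrative, not part of the processor's API):

from io import BytesIO

import PIL.Image


def open_image(img_raw, is_bytes: bool, image_mode: str = "RGB"):
    # Bytes-like payloads (bytearray / decoded base64) go through BytesIO;
    # path-like payloads are handed to PIL directly, mirroring the branch above.
    source = BytesIO(img_raw) if is_bytes else img_raw
    with PIL.Image.open(source) as img:
        return img.convert(image_mode)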
@@ -289,29 +206,152 @@ class ImageProcessor:
                 if is_zero_img:
                     zero_images.append(img)
                 else:
-
+                    valid_images.append(img)
 
             if self.requires_column_info:
                 # only count the valid images since they are put ahead of the zero images in the below returning
                 ret[f"{self.image_column_prefix}_{per_col_name}"] = np.array(
-                    [column_start, len(
+                    [column_start, len(valid_images)], dtype=np.int64
                 )
-                column_start = len(
+                column_start = len(valid_images)
 
         ret.update(
             {
                 self.image_key: torch.tensor([])
-                if len(
-                else torch.stack(
-                self.image_valid_num_key: len(
+                if len(valid_images + zero_images) == 0
+                else torch.stack(valid_images + zero_images, dim=0),
+                self.image_valid_num_key: len(valid_images),
             }
         )
         return ret
 
+    @staticmethod
+    def get_image_transform_funcs(transform_types: Union[List[str], ListConfig, List[Callable]], size: int):
+        """
+        Parse a list of transform strings into callable objects.
+
+        Parameters
+        ----------
+        transform_types
+            A list of transforms, which can be strings or callable objects.
+        size
+            Image size.
+
+        Returns
+        -------
+        A list of transform objects.
+        """
+        image_transforms = []
+
+        if not transform_types:
+            return image_transforms
+
+        if isinstance(transform_types, ListConfig):
+            transform_types = list(transform_types)
+        elif not isinstance(transform_types, list):
+            transform_types = [transform_types]
+
+        if all([isinstance(trans_type, str) for trans_type in transform_types]):
+            pass
+        elif all([isinstance(trans_type, Callable) for trans_type in transform_types]):
+            return copy.copy(transform_types)
+        else:
+            raise ValueError(
+                f"transform_types {transform_types} contain neither all strings nor all callable objects."
+            )
+
+        for trans_type in transform_types:
+            args = None
+            kargs = None
+            if "(" in trans_type:
+                trans_mode = trans_type[0 : trans_type.find("(")]
+                if "{" in trans_type:
+                    kargs = ast.literal_eval(trans_type[trans_type.find("{") : trans_type.rfind(")")])
+                else:
+                    args = ast.literal_eval(trans_type[trans_type.find("(") :])
+            else:
+                trans_mode = trans_type
+
+            if trans_mode == "resize_to_square":
+                image_transforms.append(transforms.Resize((size, size), interpolation=BICUBIC))
+            elif trans_mode == "resize_gt_to_square":
+                image_transforms.append(transforms.Resize((size, size), interpolation=NEAREST))
+            elif trans_mode == "resize_shorter_side":
+                image_transforms.append(transforms.Resize(size, interpolation=BICUBIC))
+            elif trans_mode == "center_crop":
+                image_transforms.append(transforms.CenterCrop(size))
+            elif trans_mode == "random_resize_crop":
+                image_transforms.append(transforms.RandomResizedCrop(size))
+            elif trans_mode == "random_horizontal_flip":
+                image_transforms.append(transforms.RandomHorizontalFlip())
+            elif trans_mode == "random_vertical_flip":
+                image_transforms.append(transforms.RandomVerticalFlip())
+            elif trans_mode == "color_jitter":
+                if kargs is not None:
+                    image_transforms.append(transforms.ColorJitter(**kargs))
+                elif args is not None:
+                    image_transforms.append(transforms.ColorJitter(*args))
+                else:
+                    image_transforms.append(transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1))
+            elif trans_mode == "affine":
+                if kargs is not None:
+                    image_transforms.append(transforms.RandomAffine(**kargs))
+                elif args is not None:
+                    image_transforms.append(transforms.RandomAffine(*args))
+                else:
+                    image_transforms.append(
+                        transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1))
+                    )
+            elif trans_mode == "randaug":
+                if kargs is not None:
+                    image_transforms.append(RandAugment(**kargs))
+                elif args is not None:
+                    image_transforms.append(RandAugment(*args))
+                else:
+                    image_transforms.append(RandAugment(2, 9))
+            elif trans_mode == "trivial_augment":
+                image_transforms.append(TrivialAugment(IMAGE, 30))
+            else:
+                raise ValueError(f"unknown transform type: {trans_mode}")
+
+        return image_transforms
+
+    def construct_image_processor(
+        self,
+        image_transforms: Union[List[Callable], List[str]],
+        size: int,
+        normalization,
+    ) -> transforms.Compose:
+        """
+        Build up an image processor from the provided list of transform types.
+
+        Parameters
+        ----------
+        image_transforms
+            A list of image transform types.
+        size
+            Image size.
+        normalization
+            A transforms.Normalize object. When the image is ground truth image, 'normalization=None' should be specified.
+
+        Returns
+        -------
+        A transforms.Compose object.
+        """
+        image_transforms = self.get_image_transform_funcs(transform_types=image_transforms, size=size)
+        if not any([isinstance(trans, transforms.ToTensor) for trans in image_transforms]):
+            image_transforms.append(transforms.ToTensor())
+        if (
+            not any([isinstance(trans, transforms.Normalize) for trans in image_transforms])
+            and normalization is not None
+        ):
+            image_transforms.append(normalization)
+        return transforms.Compose(image_transforms)
+
     def __call__(
         self,
         images: Dict[str, List[str]],
-
+        sub_dtypes: Dict[str, str],
         is_training: bool,
     ) -> Dict:
         """
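get_image_transform_funcs accepts both bare transform names and parameterized strings: a {...} payload is parsed with ast.literal_eval into keyword arguments, a plain (...) payload into positional arguments. A small usage sketch based on the parsing rules in the hunk above (the transform values are illustrative):

from autogluon.multimodal.data.process_image import ImageProcessor

# Bare names use built-in defaults; parenthesized forms are parsed with ast.literal_eval.
transform_funcs = ImageProcessor.get_image_transform_funcs(
    transform_types=[
        "resize_shorter_side",
        "center_crop",
        "color_jitter({'brightness': 0.2, 'contrast': 0.2})",  # dict -> ColorJitter kwargs
        "affine(15, (0.1, 0.1), (0.9, 1.1))",  # tuple -> RandomAffine positional args
    ],
    size=224,
)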
@@ -321,8 +361,8 @@ class ImageProcessor:
         ----------
         images
             Images of one sample.
-
-            The
+        sub_dtypes
+            The sub data types of all image columns.
         is_training
             Whether to process images in the training mode.
 
@@ -332,7 +372,7 @@ class ImageProcessor:
         """
         images = {k: [v] if isinstance(v, str) else v for k, v in images.items()}
 
-        return self.process_one_sample(images,
+        return self.process_one_sample(images=images, sub_dtypes=sub_dtypes, is_training=is_training)
 
     def __getstate__(self):
         odict = self.__dict__.copy()  # get attribute dictionary
@@ -341,12 +381,7 @@ class ImageProcessor:
 
     def __setstate__(self, state):
         self.__dict__ = state
-
-            self.train_transforms = list(self.train_transform_types)
-        if "val_transform_types" in state:
-            self.val_transforms = list(self.val_transform_types)
-
-        self.train_processor = construct_image_processor(
+        self.train_processor = self.construct_image_processor(
            image_transforms=self.train_transforms,
            size=self.size,
            normalization=self.normalization,
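construct_image_processor is now an instance method, and __setstate__ rebuilds the torchvision pipelines from the stored transform lists instead of the legacy *_transform_types fields, so a processor should survive a pickle round trip. A hedged usage sketch, reusing the processor from the construction example above:

import pickle

# __getstate__ copies the attribute dict and __setstate__ reconstructs the
# train/val processors, so serialization round-trips (assuming the configured
# transforms themselves are picklable).
restored = pickle.loads(pickle.dumps(processor))
assert isinstance(restored, type(processor))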
autogluon/multimodal/data/process_label.py

@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Dict, List, Optional, Union
 
 from torch import nn
@@ -5,6 +6,8 @@ from torch import nn
 from ..constants import LABEL, MMDET_IMAGE
 from .collator import ListCollator, StackCollator
 
+logger = logging.getLogger(__name__)
+
 
 class LabelProcessor:
     """
@@ -23,6 +26,7 @@ class LabelProcessor:
         model
             The model for which this processor would be created.
         """
+        logger.debug(f"initializing label processor for model {model.prefix}")
         self.prefix = model.prefix
 
     @property
@@ -68,7 +72,7 @@ class LabelProcessor:
     def __call__(
         self,
         labels: Dict[str, Union[int, float]],
-
+        sub_dtypes: Dict[str, str],
         is_training: bool,
         load_only: bool = False,  # TODO: refactor mmdet_image and remove this
     ) -> Dict:
@@ -79,8 +83,8 @@ class LabelProcessor:
         ----------
         labels
             Labels of one sample.
-
-            The
+        sub_dtypes
+            The sub data types of all label columns.
         is_training
             Whether to do processing in the training mode. This unused flag is for the API compatibility.
         load_only
autogluon/multimodal/data/process_mmlab/process_mmdet.py

@@ -6,14 +6,7 @@ import PIL
 from PIL import ImageFile
 from torch import nn
 
-
-    from torchvision.transforms import InterpolationMode
-
-    BICUBIC = InterpolationMode.BICUBIC
-except ImportError:
-    BICUBIC = PIL.Image.BICUBIC
-
-from ..utils import is_rois_input
+from ..infer_types import is_rois_input
 from .process_mmlab_base import MMLabProcessor
 
 try:
autogluon/multimodal/data/process_mmlab/process_mmlab_base.py

@@ -7,16 +7,9 @@ import PIL
 from PIL import ImageFile
 from torch import nn
 
-
-    from torchvision.transforms import InterpolationMode
-
-    BICUBIC = InterpolationMode.BICUBIC
-except ImportError:
-    BICUBIC = PIL.Image.BICUBIC
-
-from ...constants import AUTOMM, COLUMN, IMAGE, IMAGE_VALID_NUM, MMDET_IMAGE
+from ...constants import COLUMN, IMAGE, IMAGE_VALID_NUM, MMDET_IMAGE
 from ..collator import StackCollator
-from ..
+from ..infer_types import is_rois_input
 
 try:
     with warnings.catch_warnings():
autogluon/multimodal/data/process_mmlab/process_mmocr.py

@@ -6,15 +6,7 @@ import PIL
 from PIL import ImageFile
 from torch import nn
 
-
-    from torchvision.transforms import InterpolationMode
-
-    BICUBIC = InterpolationMode.BICUBIC
-except ImportError:
-    BICUBIC = PIL.Image.BICUBIC
-
-from ...constants import AUTOMM
-from ..utils import is_rois_input
+from ..infer_types import is_rois_input
 from .process_mmlab_base import MMLabProcessor
 
 logger = logging.getLogger(__name__)
|