autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
```diff
--- autogluon/multimodal/data/process_ner.py (1.2.1b20250303)
+++ autogluon/multimodal/data/process_ner.py (1.2.1b20250304)
@@ -1,16 +1,18 @@
 import logging
 import os
+import re
 import warnings
 from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 from numpy.typing import NDArray
 from omegaconf import DictConfig
+from tokenizers import pre_tokenizers
 from torch import nn
 
-from ..constants import
+from ..constants import NER_ANNOTATION, NER_TEXT, TEXT, TEXT_NER
+from ..models.utils import get_pretrained_tokenizer
 from .collator import PadCollator, StackCollator
-from .utils import process_ner_annotations, tokenize_ner_text
 
 logger = logging.getLogger(__name__)
 
@@ -124,12 +126,12 @@ class NerProcessor:
         ner_text = all_features[text_column]
         if is_training or annotation_column is not None:
             ner_annotation = all_features[annotation_column]
-            label, col_tokens, token_to_word_mappings, word_offsets = process_ner_annotations(
+            label, col_tokens, token_to_word_mappings, word_offsets = self.process_ner_annotations(
                 ner_annotation, ner_text, self.entity_map, self.tokenizer
             )
             ret.update({self.label_key: label})
         else:
-            col_tokens, token_to_word_mappings, word_offsets = tokenize_ner_text(ner_text, self.tokenizer)
+            col_tokens, token_to_word_mappings, word_offsets = self.tokenize_ner_text(ner_text, self.tokenizer)
             ret.update({self.label_key: np.array([], dtype=np.int32)})
 
         ret.update(
@@ -144,6 +146,192 @@ class NerProcessor:
 
         return ret
 
+    @classmethod
+    def process_ner_annotations(cls, ner_annotations, ner_text, entity_map, tokenizer, is_eval=False):
+        """
+        Generate token-level/word-level labels with given text and NER annotations.
+
+        Parameters
+        ----------
+        ner_annotations
+            The NER annotations.
+        ner_text
+            The corresponding raw text.
+        entity_map
+            The map between tags and tag indexes. e.g., {"PER":2, "LOC":3}.
+        tokenizer
+            The tokenizer to be used.
+        is_eval
+            Whether it is for evaluation or not, default: False
+
+        Returns
+        -------
+        Token-level/word-level labels and text features.
+        """
+        col_tokens, token_to_word_mappings, word_offsets = cls.tokenize_ner_text(ner_text, tokenizer)
+        num_words = len(set(token_to_word_mappings)) - 1
+        word_label = [1] * num_words
+        # TODO: Potentially optimize word label generation via binary search
+        b_prefix = "B-"
+        i_prefix = "I-"
+        for annot in ner_annotations:
+            custom_offset = annot[0]
+            custom_label = annot[1]
+            is_start_word = True
+            for idx, word_offset in enumerate(word_offsets[:num_words, :]):
+                # support multiple words in an annotated offset range.
+                # Allow partial overlapping between custom annotations and pretokenized words.
+                if (word_offset[0] < custom_offset[1]) and (custom_offset[0] < word_offset[1]):
+                    if not (
+                        re.match(b_prefix, custom_label, re.IGNORECASE)
+                        or re.match(i_prefix, custom_label, re.IGNORECASE)
+                    ):
+                        if is_start_word and b_prefix + custom_label in entity_map:
+                            word_label[idx] = entity_map[b_prefix + custom_label]
+                            is_start_word = False
+                        elif i_prefix + custom_label in entity_map:
+                            word_label[idx] = entity_map[i_prefix + custom_label]
+                    else:
+                        if custom_label in entity_map:
+                            word_label[idx] = entity_map[custom_label]
+
+        token_label = [0] * len(col_tokens.input_ids)
+        temp = set()
+        counter = 0
+        for idx, token_to_word in enumerate(token_to_word_mappings):
+            if token_to_word != -1 and token_to_word not in temp:
+                temp.add(token_to_word)
+                token_label[idx] = word_label[counter]
+                counter += 1
+        if not is_eval:
+            label = token_label  # return token-level labels for training
+        else:
+            label = word_label  # return word-level labels for evaluation
+
+        return label, col_tokens, token_to_word_mappings, word_offsets
+
+    @classmethod
+    def tokenize_ner_text(cls, text, tokenizer):
+        """
+        Tokenization process for the NER task. It will be used for the token-level label generation
+        and the input text tokenization.
+
+        Parameters
+        ----------
+        text
+            The raw text data.
+        tokenizer
+            The tokenizer to be used.
+
+        Returns
+        -------
+        The output of tokenizer and word offsets.
+        """
+        # pre-tokenization is required for NER token-level label generation.
+        words_with_offsets = pre_tokenizers.BertPreTokenizer().pre_tokenize_str(text)
+        words_with_offsets = (
+            cls.is_space_counted(words_with_offsets) if len(words_with_offsets) > 1 else words_with_offsets
+        )
+        words = [word for word, offset in words_with_offsets]
+        word_offsets = np.array([[offset[0], offset[1]] for word, offset in words_with_offsets], dtype=np.int32)
+        col_tokens = tokenizer(
+            words,
+            is_split_into_words=True,
+            return_offsets_mapping=True,
+            padding="max_length",
+            truncation=True,
+            max_length=tokenizer.model_max_length,
+            return_token_type_ids=True,
+        )
+        offset_mapping = np.array(col_tokens.offset_mapping, dtype=np.int32)
+        if len(words_with_offsets) > 1:
+            if offset_mapping.shape[0] > len(words):
+                word_offsets = np.pad(word_offsets, ((0, offset_mapping.shape[0] - len(words)), (0, 0)), "constant")
+            # token to word mappings: it will tell us which token belongs to which word.
+            token_to_word_mappings = [i if i != None else -1 for i in col_tokens.word_ids()]
+            if len(set(token_to_word_mappings)) != len(words) + 1:
+                warnings.warn(f"The token to word mappings are incorrect!")
+        else:
+            # If pre_tokenizer does not give word offsets, use word_ids and offset_mappings instead.
+            word_offsets = np.append(offset_mapping[1:], [[0, 0]], axis=0)
+            word_idx = np.arange(len(col_tokens.word_ids()) - col_tokens.word_ids().count(None))
+            token_to_word_mappings = [
+                val + word_idx[idx - 1] if val != None else -1 for idx, val in enumerate(col_tokens.word_ids())
+            ]
+
+        return col_tokens, token_to_word_mappings, word_offsets
+
+    @staticmethod
+    def is_space_counted(words_with_offsets):
+        """
+        Some tokenizers will count space into words for example.
+        Given text: 'hello world', normal bert will output: [('hello', (0, 5)), ('world', (6, 11))]
+        while some checkpoint will output: [('▁hello', (0, 5)), ('▁world', (5, 11))]
+        This will lead to inconsistency issue during labelling, details can be found here:
+        https://github.com/huggingface/transformers/issues/18111
+
+        This function will check whether space is counted or not and realign the offset.
+        """
+        offset0, offset1 = [], []
+        for word, offset in words_with_offsets:
+            offset0.append(offset[0])
+            offset1.append(offset[1])
+
+        realign = []
+        if offset0[1:] == offset1[:-1]:  # space are counted
+            realign = [words_with_offsets[0]]
+            for word, offset in words_with_offsets[1:]:
+                if word.startswith("▁"):  # it is "Lower One Eighth Block" (U+2581) rather than lower line (U+005F).
+                    realign.append((word, (offset[0] + 1, offset[1])))
+                else:
+                    realign.append((word, offset))
+
+        if realign:
+            return realign
+        else:
+            return words_with_offsets
+
+    def save_tokenizer(
+        self,
+        path: str,
+    ):
+        """
+        Save the text tokenizer and record its relative paths, e.g, hf_text.
+
+        Parameters
+        ----------
+        path
+            The root path of saving.
+
+        """
+        save_path = os.path.join(path, self.prefix)
+        self.tokenizer.save_pretrained(save_path)
+        self.tokenizer = self.prefix
+
+    def load_tokenizer(
+        self,
+        path: str,
+    ):
+        """
+        Load saved text tokenizers. If text/ner processors already have tokenizers,
+        then do nothing.
+
+        Parameters
+        ----------
+        path
+            The root path of loading.
+
+        Returns
+        -------
+        A list of text/ner processors with tokenizers loaded.
+        """
+        if isinstance(self.tokenizer, str):
+            load_path = os.path.join(path, self.tokenizer)
+            self.tokenizer = get_pretrained_tokenizer(
+                tokenizer_name=self.tokenizer_name,
+                checkpoint_name=load_path,
+            )
+
     def __call__(
         self,
         all_features: Dict[str, Union[NDArray, list]],
```
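The bulk of the new code moves `process_ner_annotations` and `tokenize_ner_text` from `data/utils.py` onto `NerProcessor` itself. The mechanism is a two-stage tokenization: a Hugging Face pre-tokenizer first splits the raw text into words with character offsets, the model's subword tokenizer then runs over the pre-split words, and `word_ids()` maps each subword token back to its source word so a word-level NER label can be propagated to every one of its subword tokens. Below is a minimal, self-contained sketch of that flow; the `bert-base-cased` checkpoint is illustrative only (any fast tokenizer exposes the same API):

```python
# Sketch of the pre-tokenize -> subword-tokenize -> word-mapping pipeline that
# NerProcessor.tokenize_ner_text builds on. Checkpoint choice is illustrative.
from tokenizers import pre_tokenizers
from transformers import AutoTokenizer

text = "Seattle is rainy"

# Stage 1: pre-tokenization returns (word, (start, end)) pairs over the raw text.
words_with_offsets = pre_tokenizers.BertPreTokenizer().pre_tokenize_str(text)
# -> [('Seattle', (0, 7)), ('is', (8, 10)), ('rainy', (11, 16))]
words = [word for word, _ in words_with_offsets]

# Stage 2: subword tokenization of the pre-split words (fast tokenizer required
# for offset mappings; bert-base-cased loads one by default).
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
encoding = tokenizer(words, is_split_into_words=True, return_offsets_mapping=True)

# word_ids() yields one entry per token: the index of its source word, or None
# for special tokens. Mapping None to -1 mirrors token_to_word_mappings above.
token_to_word = [i if i is not None else -1 for i in encoding.word_ids()]
# e.g. [-1, 0, 1, 2, -1] for [CLS] Seattle is rainy [SEP]; the exact split
# depends on the vocabulary.
```

The `is_space_counted` helper exists because SentencePiece-style tokenizers prepend "▁" (U+2581) to words and fold the preceding space into the word's offset; realigning those offsets keeps the character-span comparison in `process_ner_annotations` consistent across checkpoints.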
```diff
--- autogluon/multimodal/data/process_numerical.py (1.2.1b20250303)
+++ autogluon/multimodal/data/process_numerical.py (1.2.1b20250304)
@@ -1,3 +1,5 @@
+import logging
+import random
 from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
@@ -6,6 +8,8 @@ from torch import nn
 from ..constants import COLUMN, NUMERICAL
 from .collator import StackCollator
 
+logger = logging.getLogger(__name__)
+
 
 class NumericalProcessor:
     """
@@ -19,6 +23,7 @@ class NumericalProcessor:
         model: nn.Module,
         merge: Optional[str] = "concat",
         requires_column_info: bool = False,
+        dropout: Optional[float] = 0,
     ):
         """
         Parameters
@@ -33,9 +38,16 @@ class NumericalProcessor:
         requires_column_info
             Whether to require feature column information in dataloader.
         """
+        logger.debug(f"initializing numerical processor for model {model.prefix}")
         self.prefix = model.prefix
         self.merge = merge
         self.requires_column_info = requires_column_info
+        self.numerical_fill_values = model.numerical_fill_values
+        self.dropout = dropout
+        assert 0 <= self.dropout <= 1
+        if self.dropout > 0:
+            logger.debug(f"numerical value dropout probability: {self.dropout}")
+            logger.debug(f"dropped values will be replaced by {self.numerical_fill_values}")
 
     @property
     def numerical_key(self):
@@ -67,6 +79,7 @@ class NumericalProcessor:
     def process_one_sample(
         self,
         numerical_features: Dict[str, float],
+        is_training: bool,
     ) -> Dict:
         """
         Process one sample's numerical features.
@@ -76,6 +89,8 @@ class NumericalProcessor:
         ----------
         numerical_features
            Numerical features of one sample.
+        is_training
+            Whether to do processing in the training mode.
 
         Returns
         -------
@@ -87,6 +102,15 @@ class NumericalProcessor:
             for i, col_name in enumerate(numerical_features.keys()):
                 ret[f"{self.numerical_column_prefix}_{col_name}"] = i
 
+        if is_training and self.dropout > 0:
+            numerical_features_copy = dict()
+            for k, v in numerical_features.items():
+                if random.uniform(0, 1) <= self.dropout:
+                    numerical_features_copy[k] = self.numerical_fill_values[k]
+                else:
+                    numerical_features_copy[k] = v
+            numerical_features = numerical_features_copy
+
         if self.merge == "concat":
             ret[self.numerical_key] = np.array(list(numerical_features.values()), dtype=np.float32)
         else:
```
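The new `dropout` option is an input-level augmentation: during training, each numerical value is independently replaced, with probability `dropout`, by that column's fill value (taken from `model.numerical_fill_values`, presumably a column statistic computed during preprocessing). A standalone sketch of the same logic, with made-up fill values:

```python
import random

def drop_numerical(features: dict, fill_values: dict, dropout: float) -> dict:
    """Replace each value by its column's fill value with probability `dropout`."""
    assert 0 <= dropout <= 1
    return {
        col: fill_values[col] if random.uniform(0, 1) <= dropout else val
        for col, val in features.items()
    }

random.seed(0)
features = {"age": 42.0, "income": 55000.0}
fill_values = {"age": 35.0, "income": 48000.0}  # illustrative column means
print(drop_numerical(features, fill_values, dropout=0.3))
```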
The `__call__` signature and docstring change accordingly (two of the removed lines were truncated in the rendered diff and are reproduced as shown):

```diff
--- autogluon/multimodal/data/process_numerical.py (1.2.1b20250303)
+++ autogluon/multimodal/data/process_numerical.py (1.2.1b20250304)
@@ -97,7 +121,7 @@ class NumericalProcessor:
     def __call__(
         self,
         numerical_features: Dict[str, float],
-
+        sub_dtypes: Dict[str, str],
         is_training: bool,
     ) -> Dict:
         """
@@ -107,13 +131,16 @@ class NumericalProcessor:
         ----------
         numerical_features
             Numerical features of one sample.
-
-            The
+        sub_dtypes
+            The sub data types of all numerical columns.
         is_training
-            Whether to do processing in the training mode.
+            Whether to do processing in the training mode.
 
         Returns
         -------
         A dictionary containing one sample's processed numerical features.
         """
-        return self.process_one_sample(
+        return self.process_one_sample(
+            numerical_features=numerical_features,
+            is_training=is_training,
+        )
```
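Note that `__call__` now accepts `sub_dtypes`, evidently to match a shared processor interface, but does not forward it to `process_one_sample`. A hypothetical end-to-end call, using a stub in place of a real model (the diff shows the constructor reading only `prefix` and `numerical_fill_values`, so a plain namespace suffices for illustration; the attribute names and output key are assumptions):

```python
from types import SimpleNamespace

from autogluon.multimodal.data.process_numerical import NumericalProcessor

# Hypothetical stub standing in for a model; attribute names follow the diff above.
model = SimpleNamespace(prefix="numerical_mlp", numerical_fill_values={"age": 35.0})

processor = NumericalProcessor(model=model, dropout=0.1)
ret = processor(
    numerical_features={"age": 42.0},
    sub_dtypes={"age": "float"},
    is_training=True,
)
# ret maps the processor's numerical key to a float32 array holding either the
# raw value or, with probability 0.1, the fill value.
```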
```diff
--- autogluon/multimodal/data/process_semantic_seg_img.py (1.2.1b20250303)
+++ autogluon/multimodal/data/process_semantic_seg_img.py (1.2.1b20250304)
@@ -1,44 +1,32 @@
 import logging
 import random
-import warnings
-from io import BytesIO
 from typing import Dict, List, Optional, Union
 
 import numpy as np
 import PIL
 import torch
-from omegaconf import DictConfig
 from PIL import Image, ImageFile
 from torch import nn
 from torchvision import transforms
 
-from .utils import construct_image_processor, image_mean_std
-
-try:
-    from torchvision.transforms import InterpolationMode
-
-    BICUBIC = InterpolationMode.BICUBIC
-except ImportError:
-    BICUBIC = PIL.Image.BICUBIC
-
 from ..constants import (
     CLASS_LABEL,
     COLUMN,
     IMAGE,
-    IMAGE_BYTEARRAY,
     IMAGE_VALID_NUM,
     LABEL,
     MASK_LABEL,
     SEMANTIC_SEGMENTATION_GT,
     SEMANTIC_SEGMENTATION_IMG,
 )
-from .collator import ListCollator, PadCollator
+from .collator import ListCollator, PadCollator
+from .process_image import ImageProcessor
 
 logger = logging.getLogger(__name__)
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 
-class SemanticSegImageProcessor:
+class SemanticSegImageProcessor(ImageProcessor):
     """
     Prepare image data for the model specified by "prefix". For multiple models requiring image data,
     we need to create a ImageProcessor for each related model so that they will have independent input.
@@ -51,7 +39,6 @@ class SemanticSegImageProcessor:
         gt_transforms: List[str],
         train_transforms: Optional[List[str]] = None,
         val_transforms: Optional[List[str]] = None,
-        norm_type: Optional[str] = None,
         max_img_num_per_col: Optional[int] = 1,
         missing_value_strategy: Optional[str] = "skip",
         requires_column_info: bool = False,
@@ -70,15 +57,6 @@ class SemanticSegImageProcessor:
            A list of image transforms used in training for data augmentation. Note that the transform order matters.
         val_transforms
             A list of image transforms used in validation/test/prediction. Note that the transform order matters.
-        norm_type
-            How to normalize an image. We now support:
-            - inception
-                Normalize image by IMAGENET_INCEPTION_MEAN and IMAGENET_INCEPTION_STD from timm
-            - imagenet
-                Normalize image by IMAGENET_DEFAULT_MEAN and IMAGENET_DEFAULT_STD from timm
-            - clip
-                Normalize image by mean (0.48145466, 0.4578275, 0.40821073) and
-                std (0.26862954, 0.26130258, 0.27577711), used for CLIP.
         max_img_num_per_col
             The maximum number of images one sample can have.
         missing_value_strategy
@@ -98,7 +76,8 @@ class SemanticSegImageProcessor:
         self.requires_column_info = requires_column_info
 
         self.size = model.image_size
-        self.mean
+        self.mean = model.image_mean
+        self.std = model.image_std
         self.normalization = transforms.Normalize(self.mean, self.std)
         self.num_classes = model.num_classes
         self.ignore_label = ignore_label
```
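Dropping `norm_type` means the normalization statistics now come straight from the model (`model.image_mean` / `model.image_std`), so the processor can no longer disagree with the backbone it feeds. For reference, the underlying operation is plain torchvision normalization; the CLIP statistics below are the ones quoted in the removed docstring:

```python
import torch
from torchvision import transforms

# CLIP mean/std, as listed in the old norm_type docstring; a real model now
# supplies its own statistics via model.image_mean / model.image_std.
image_mean = (0.48145466, 0.4578275, 0.40821073)
image_std = (0.26862954, 0.26130258, 0.27577711)

normalize = transforms.Normalize(image_mean, image_std)
img = torch.rand(3, 224, 224)  # dummy CHW float image in [0, 1]
out = normalize(img)           # per-channel (img - mean) / std
```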
```diff
--- autogluon/multimodal/data/process_semantic_seg_img.py (1.2.1b20250303)
+++ autogluon/multimodal/data/process_semantic_seg_img.py (1.2.1b20250304)
@@ -110,10 +89,10 @@ class SemanticSegImageProcessor:
         self.max_img_num_per_col = max_img_num_per_col
         logger.debug(f"max_img_num_per_col: {max_img_num_per_col}")
 
-        self.img_processor = construct_image_processor(
+        self.img_processor = self.construct_image_processor(
             image_transforms=self.img_transforms, size=self.size, normalization=self.normalization
         )
-        self.gt_processor = construct_image_processor(
+        self.gt_processor = self.construct_image_processor(
             image_transforms=self.gt_transforms, size=self.size, normalization=None
         )
         self.train_transforms = self.get_train_transforms(train_transforms)
@@ -325,3 +304,19 @@ class SemanticSegImageProcessor:
         if trans_mode == "random_horizontal_flip":
             train_trans.append(transforms.RandomHorizontalFlip(1.0))
         return transforms.Compose(train_trans)
+
+    def __getstate__(self):
+        odict = self.__dict__.copy()  # get attribute dictionary
+        del odict["img_processor"]
+        del odict["gt_processor"]
+
+        return odict
+
+    def __setstate__(self, state):
+        self.__dict__ = state
+        self.img_processor = self.construct_image_processor(
+            image_transforms=self.img_transforms, size=self.size, normalization=self.normalization
+        )
+        self.gt_processor = self.construct_image_processor(
+            image_transforms=self.gt_transforms, size=self.size, normalization=None
+        )
```
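The `__getstate__`/`__setstate__` pair is the standard pickle-hook pattern: composed transform pipelines can contain callables that do not pickle cleanly, so they are dropped when the processor is serialized and rebuilt from the picklable recipe (`img_transforms`, `size`, normalization) on load. A self-contained sketch of the same pattern:

```python
import pickle

class Processor:
    def __init__(self, scale: int):
        self.scale = scale
        self.pipeline = self._build(scale)  # stand-in for an unpicklable member

    @staticmethod
    def _build(scale):
        return lambda x: x * scale  # lambdas cannot be pickled

    def __getstate__(self):
        state = self.__dict__.copy()
        del state["pipeline"]  # exclude the unpicklable member from the state
        return state

    def __setstate__(self, state):
        self.__dict__ = state
        self.pipeline = self._build(self.scale)  # rebuild it from the recipe

restored = pickle.loads(pickle.dumps(Processor(3)))
print(restored.pipeline(2))  # 6
```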