autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
@@ -14,17 +14,14 @@ from sklearn.preprocessing import MinMaxScaler, StandardScaler
|
|
14
14
|
from autogluon.features import CategoryFeatureGenerator
|
15
15
|
|
16
16
|
from ..constants import (
|
17
|
-
AUTOMM,
|
18
17
|
CATEGORICAL,
|
19
18
|
DOCUMENT,
|
20
|
-
DOCUMENT_IMAGE,
|
21
19
|
IDENTIFIER,
|
22
20
|
IMAGE,
|
23
21
|
IMAGE_BASE64_STR,
|
24
22
|
IMAGE_BYTEARRAY,
|
25
23
|
IMAGE_PATH,
|
26
24
|
LABEL,
|
27
|
-
NER,
|
28
25
|
NER_ANNOTATION,
|
29
26
|
NULL,
|
30
27
|
NUMERICAL,
|
@@ -73,19 +70,17 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
|
|
73
70
|
|
74
71
|
if label_column:
|
75
72
|
if label_generator is None:
|
76
|
-
self._label_generator = CustomLabelEncoder(
|
77
|
-
positive_class=OmegaConf.select(config, "pos_label", default=None)
|
78
|
-
)
|
73
|
+
self._label_generator = CustomLabelEncoder(positive_class=config.pos_label)
|
79
74
|
else:
|
80
75
|
self._label_generator = label_generator
|
81
76
|
|
82
77
|
# Scaler used for numerical labels
|
83
|
-
numerical_label_preprocessing =
|
78
|
+
numerical_label_preprocessing = config.label.numerical_preprocessing
|
84
79
|
if numerical_label_preprocessing == "minmaxscaler":
|
85
80
|
self._label_scaler = MinMaxScaler()
|
86
81
|
elif numerical_label_preprocessing == "standardscaler":
|
87
82
|
self._label_scaler = StandardScaler()
|
88
|
-
elif numerical_label_preprocessing is None
|
83
|
+
elif numerical_label_preprocessing is None:
|
89
84
|
self._label_scaler = StandardScaler(with_mean=False, with_std=False)
|
90
85
|
else:
|
91
86
|
raise ValueError(
|
@@ -135,8 +130,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
|
|
135
130
|
# Some columns will be ignored
|
136
131
|
self._ignore_columns_set = set()
|
137
132
|
self._text_feature_names = []
|
138
|
-
self.
|
139
|
-
self._categorical_num_categories = []
|
133
|
+
self._categorical_num_categories = dict()
|
140
134
|
self._numerical_feature_names = []
|
141
135
|
self._image_feature_names = []
|
142
136
|
self._rois_feature_names = []
|
@@ -154,10 +148,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
|
|
154
148
|
|
155
149
|
@property
|
156
150
|
def image_path_names(self):
|
157
|
-
if
|
158
|
-
return self._image_path_names
|
159
|
-
else:
|
160
|
-
return [col_name for col_name in self._image_feature_names if self._column_types[col_name] == IMAGE_PATH]
|
151
|
+
return [col_name for col_name in self._image_feature_names if self._column_types[col_name] == IMAGE_PATH]
|
161
152
|
|
162
153
|
@property
|
163
154
|
def rois_feature_names(self):
|
@@ -173,7 +164,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
|
|
173
164
|
|
174
165
|
@property
|
175
166
|
def image_feature_names(self):
|
176
|
-
return self.
|
167
|
+
return self._image_feature_names
|
177
168
|
|
178
169
|
@property
|
179
170
|
def text_feature_names(self):
|
@@ -181,12 +172,21 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
|
|
181
172
|
|
182
173
|
@property
|
183
174
|
def categorical_feature_names(self):
|
184
|
-
return self.
|
175
|
+
return list(self.categorical_num_categories.keys())
|
185
176
|
|
186
177
|
@property
|
187
178
|
def numerical_feature_names(self):
|
188
179
|
return self._numerical_feature_names
|
189
180
|
|
181
|
+
@property
|
182
|
+
def numerical_fill_values(self):
|
183
|
+
ret = dict()
|
184
|
+
for col_name in self._numerical_feature_names:
|
185
|
+
generator = self._feature_generators[col_name]
|
186
|
+
ret[col_name] = generator.transform(np.full([1, 1], np.nan))[:, 0][0]
|
187
|
+
|
188
|
+
return ret
|
189
|
+
|
190
190
|
@property
|
191
191
|
def document_feature_names(self):
|
192
192
|
# Added for backward compatibility.
|
@@ -216,17 +216,12 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
|
|
216
216
|
|
217
217
|
@property
|
218
218
|
def required_feature_names(self):
|
219
|
-
image_feature_names = (
|
220
|
-
self._image_path_names if hasattr(self, "_image_path_names") else self._image_feature_names
|
221
|
-
)
|
222
|
-
rois_feature_names = self._rois_feature_names if hasattr(self, "_rois_feature_names") else []
|
223
|
-
|
224
219
|
return (
|
225
|
-
|
220
|
+
self._image_feature_names
|
226
221
|
+ self._text_feature_names
|
227
222
|
+ self._numerical_feature_names
|
228
|
-
+ self.
|
229
|
-
+
|
223
|
+
+ self.categorical_feature_names
|
224
|
+
+ self._rois_feature_names
|
230
225
|
)
|
231
226
|
|
232
227
|
@property
|
@@ -268,16 +263,13 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
|
|
268
263
|
|
269
264
|
def get_column_names(self, modality: str):
|
270
265
|
if modality.startswith(IMAGE):
|
271
|
-
|
272
|
-
return self._image_path_names
|
273
|
-
else:
|
274
|
-
return self._image_feature_names
|
266
|
+
return self._image_feature_names
|
275
267
|
elif modality == ROIS:
|
276
268
|
return self._rois_feature_names
|
277
269
|
elif modality == TEXT:
|
278
270
|
return self._text_feature_names
|
279
271
|
elif modality == CATEGORICAL:
|
280
|
-
return self.
|
272
|
+
return self.categorical_feature_names
|
281
273
|
elif modality == NUMERICAL:
|
282
274
|
return self._numerical_feature_names
|
283
275
|
elif modality.startswith(DOCUMENT):
|
@@ -344,8 +336,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
|
|
344
336
|
continue
|
345
337
|
num_categories = len(generator.category_map[col_name])
|
346
338
|
# Add one unknown category
|
347
|
-
self._categorical_num_categories
|
348
|
-
self._categorical_feature_names.append(col_name)
|
339
|
+
self._categorical_num_categories[col_name] = num_categories + 1
|
349
340
|
elif col_type == NUMERICAL:
|
350
341
|
processed_data = pd.to_numeric(col_value)
|
351
342
|
if len(processed_data.unique()) == 1:
|
@@ -392,7 +383,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
|
|
392
383
|
elif self.label_type == NUMERICAL:
|
393
384
|
y = pd.to_numeric(y).to_numpy()
|
394
385
|
self._label_scaler.fit(np.expand_dims(y, axis=-1))
|
395
|
-
elif self.label_type
|
386
|
+
elif self.label_type in [ROIS, SEMANTIC_SEGMENTATION_GT]:
|
396
387
|
pass # Do nothing. TODO: Shall we call fit here?
|
397
388
|
elif self.label_type == NER_ANNOTATION:
|
398
389
|
# If there are ner annotations and text columns but no NER feature columns,
|
@@ -426,6 +417,24 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
|
|
426
417
|
if y is not None:
|
427
418
|
self._fit_y(y=y, X=X)
|
428
419
|
|
420
|
+
@staticmethod
|
421
|
+
def convert_categorical_to_text(col_value: pd.Series, template: str, col_name: str):
|
422
|
+
# TODO: do we need to consider whether categorical values are valid text?
|
423
|
+
col_value = col_value.astype("object")
|
424
|
+
if template == "direct":
|
425
|
+
processed_data = col_value.apply(lambda ele: "" if pd.isnull(ele) else str(ele))
|
426
|
+
elif template == "list":
|
427
|
+
processed_data = col_value.apply(lambda ele: "" if pd.isnull(ele) else col_name + ": " + str(ele))
|
428
|
+
elif template == "text":
|
429
|
+
processed_data = col_value.apply(lambda ele: "" if pd.isnull(ele) else col_name + " is " + str(ele))
|
430
|
+
elif template == "latex":
|
431
|
+
processed_data = col_value.apply(lambda ele: "" if pd.isnull(ele) else str(ele) + " & ")
|
432
|
+
else:
|
433
|
+
raise ValueError(
|
434
|
+
f"Unsupported template {template} for converting categorical data into text. Select one from: ['direct', 'list', 'text', 'latex']."
|
435
|
+
)
|
436
|
+
return processed_data
|
437
|
+
|
429
438
|
def transform_text(
|
430
439
|
self,
|
431
440
|
df: pd.DataFrame,
|
@@ -455,10 +464,15 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
|
|
455
464
|
for col_name in self._text_feature_names:
|
456
465
|
col_value = df[col_name]
|
457
466
|
col_type = self._column_types[col_name]
|
458
|
-
if col_type == TEXT
|
459
|
-
# TODO: do we need to consider whether categorical values are valid text?
|
467
|
+
if col_type == TEXT:
|
460
468
|
col_value = col_value.astype("object")
|
461
469
|
processed_data = col_value.apply(lambda ele: "" if pd.isnull(ele) else str(ele))
|
470
|
+
elif col_type == CATEGORICAL:
|
471
|
+
processed_data = self.convert_categorical_to_text(
|
472
|
+
col_value=col_value,
|
473
|
+
template=self._config.categorical.convert_to_text_template,
|
474
|
+
col_name=col_name,
|
475
|
+
)
|
462
476
|
elif col_type == NUMERICAL:
|
463
477
|
processed_data = pd.to_numeric(col_value).apply("{:.3f}".format)
|
464
478
|
elif col_type == f"{TEXT}_{IDENTIFIER}":
|
@@ -710,7 +724,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
|
|
710
724
|
self._fit_called or self._fit_x_called
|
711
725
|
), "You will need to first call preprocessor.fit before calling preprocessor.transform_categorical."
|
712
726
|
categorical_features = {}
|
713
|
-
for col_name, num_category in
|
727
|
+
for col_name, num_category in self._categorical_num_categories.items():
|
714
728
|
col_value = df[col_name]
|
715
729
|
processed_data = col_value.astype("category")
|
716
730
|
generator = self._feature_generators[col_name]
|
@@ -757,7 +771,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
|
|
757
771
|
elif self.label_type == NUMERICAL:
|
758
772
|
y = pd.to_numeric(y_df).to_numpy()
|
759
773
|
y = self._label_scaler.transform(np.expand_dims(y, axis=-1))[:, 0].astype(np.float32)
|
760
|
-
elif self.label_type
|
774
|
+
elif self.label_type in [ROIS, SEMANTIC_SEGMENTATION_GT]:
|
761
775
|
y = y_df.to_list()
|
762
776
|
elif self.label_type == NER_ANNOTATION:
|
763
777
|
y = self._label_generator.transform(y_df)
|
@@ -866,8 +880,11 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
|
|
866
880
|
), "You will need to first call preprocessor.fit_y() before calling preprocessor.transform_prediction."
|
867
881
|
|
868
882
|
if self.label_type == CATEGORICAL:
|
869
|
-
assert y_pred.shape
|
870
|
-
y_pred
|
883
|
+
assert len(y_pred.shape) <= 2
|
884
|
+
if len(y_pred.shape) == 2 and y_pred.shape[1] >= 2:
|
885
|
+
y_pred = y_pred.argmax(axis=1)
|
886
|
+
else:
|
887
|
+
y_pred = (y_pred > 0.5).astype(int)
|
871
888
|
# Transform the predicted label back to the original space (e.g., string values)
|
872
889
|
if inverse_categorical:
|
873
890
|
y_pred = self._label_generator.inverse_transform(y_pred)
|
@@ -1,11 +1,14 @@
|
|
1
|
+
import logging
|
2
|
+
import random
|
1
3
|
from typing import Any, Dict, List, Optional, Union
|
2
4
|
|
3
|
-
import numpy as np
|
4
5
|
from torch import nn
|
5
6
|
|
6
7
|
from ..constants import CATEGORICAL, COLUMN
|
7
8
|
from .collator import StackCollator, TupleCollator
|
8
9
|
|
10
|
+
logger = logging.getLogger(__name__)
|
11
|
+
|
9
12
|
|
10
13
|
class CategoricalProcessor:
|
11
14
|
"""
|
@@ -18,6 +21,7 @@ class CategoricalProcessor:
|
|
18
21
|
self,
|
19
22
|
model: nn.Module,
|
20
23
|
requires_column_info: bool = False,
|
24
|
+
dropout: Optional[float] = 0,
|
21
25
|
):
|
22
26
|
"""
|
23
27
|
Parameters
|
@@ -27,8 +31,16 @@ class CategoricalProcessor:
|
|
27
31
|
requires_column_info
|
28
32
|
Whether to require feature column information in dataloader.
|
29
33
|
"""
|
34
|
+
logger.debug(f"initializing categorical processor for model {model.prefix}")
|
30
35
|
self.prefix = model.prefix
|
31
36
|
self.requires_column_info = requires_column_info
|
37
|
+
self.num_categories = model.num_categories
|
38
|
+
self.dropout = dropout
|
39
|
+
assert 0 <= self.dropout <= 1
|
40
|
+
if self.dropout > 0:
|
41
|
+
logger.debug(f"categorical value dropout probability: {self.dropout}")
|
42
|
+
fill_values = {k: v - 1 for k, v in self.num_categories.items()}
|
43
|
+
logger.debug(f"dropped values will be replaced by {fill_values}")
|
32
44
|
|
33
45
|
@property
|
34
46
|
def categorical_key(self):
|
@@ -60,6 +72,7 @@ class CategoricalProcessor:
|
|
60
72
|
def process_one_sample(
|
61
73
|
self,
|
62
74
|
categorical_features: Dict[str, int],
|
75
|
+
is_training: bool,
|
63
76
|
) -> Dict:
|
64
77
|
"""
|
65
78
|
Process one sample's categorical features. Assume the categorical features
|
@@ -69,6 +82,8 @@ class CategoricalProcessor:
|
|
69
82
|
----------
|
70
83
|
categorical_features
|
71
84
|
Categorical features of one sample.
|
85
|
+
is_training
|
86
|
+
Whether to do processing in the training mode.
|
72
87
|
|
73
88
|
Returns
|
74
89
|
-------
|
@@ -80,6 +95,17 @@ class CategoricalProcessor:
|
|
80
95
|
for i, col_name in enumerate(categorical_features.keys()):
|
81
96
|
ret[f"{self.categorical_column_prefix}_{col_name}"] = i
|
82
97
|
|
98
|
+
if is_training and self.dropout > 0:
|
99
|
+
categorical_features_copy = dict()
|
100
|
+
for k, v in categorical_features.items():
|
101
|
+
if random.uniform(0, 1) <= self.dropout:
|
102
|
+
categorical_features_copy[k] = self.num_categories[k] - 1
|
103
|
+
else:
|
104
|
+
categorical_features_copy[k] = v
|
105
|
+
categorical_features = categorical_features_copy
|
106
|
+
|
107
|
+
# make sure keys are in the same order
|
108
|
+
assert list(categorical_features.keys()) == list(self.num_categories.keys())
|
83
109
|
ret[self.categorical_key] = list(categorical_features.values())
|
84
110
|
|
85
111
|
return ret
|
@@ -87,7 +113,7 @@ class CategoricalProcessor:
|
|
87
113
|
def __call__(
|
88
114
|
self,
|
89
115
|
categorical_features: Dict[str, int],
|
90
|
-
|
116
|
+
sub_dtypes: Dict[str, str],
|
91
117
|
is_training: bool,
|
92
118
|
) -> Dict:
|
93
119
|
"""
|
@@ -97,13 +123,16 @@ class CategoricalProcessor:
|
|
97
123
|
----------
|
98
124
|
categorical_features
|
99
125
|
Categorical features of one sample.
|
100
|
-
|
101
|
-
The
|
126
|
+
sub_dtypes
|
127
|
+
The sub data types of all categorical columns.
|
102
128
|
is_training
|
103
|
-
Whether to do processing in the training mode.
|
129
|
+
Whether to do processing in the training mode.
|
104
130
|
|
105
131
|
Returns
|
106
132
|
-------
|
107
133
|
A dictionary containing one sample's processed categorical features.
|
108
134
|
"""
|
109
|
-
return self.process_one_sample(
|
135
|
+
return self.process_one_sample(
|
136
|
+
categorical_features=categorical_features,
|
137
|
+
is_training=is_training,
|
138
|
+
)
|
@@ -1,30 +1,24 @@
|
|
1
|
-
import importlib.util
|
2
1
|
import logging
|
3
2
|
import os
|
4
|
-
import re
|
5
|
-
import shutil
|
6
|
-
import subprocess
|
7
3
|
import warnings
|
8
|
-
from
|
9
|
-
from typing import Any, Dict, List, Optional, Union
|
4
|
+
from typing import Any, Callable, Dict, List, Optional, Union
|
10
5
|
|
11
6
|
import numpy as np
|
12
7
|
import PIL
|
13
8
|
import pytesseract
|
14
|
-
import torch
|
15
9
|
from numpy.typing import NDArray
|
16
|
-
from PIL import ImageFile
|
17
10
|
from torch import nn
|
18
11
|
from torchvision import transforms
|
19
12
|
|
20
|
-
from ..constants import
|
21
|
-
from .
|
22
|
-
from .
|
13
|
+
from ..constants import BBOX, DOCUMENT_PDF
|
14
|
+
from ..models.utils import get_pretrained_tokenizer
|
15
|
+
from .collator import PadCollator
|
16
|
+
from .process_image import ImageProcessor
|
23
17
|
|
24
18
|
logger = logging.getLogger(__name__)
|
25
19
|
|
26
20
|
|
27
|
-
class DocumentProcessor:
|
21
|
+
class DocumentProcessor(ImageProcessor):
|
28
22
|
"""
|
29
23
|
Prepare document data for Document Classification.
|
30
24
|
OCR (Optical character recognition) is applied to get the document texts and bounding boxes.
|
@@ -34,9 +28,8 @@ class DocumentProcessor:
|
|
34
28
|
def __init__(
|
35
29
|
self,
|
36
30
|
model: nn.Module,
|
37
|
-
|
38
|
-
|
39
|
-
norm_type: Optional[str] = None,
|
31
|
+
train_transforms: Union[List[str], Callable, List[Callable]],
|
32
|
+
val_transforms: Union[List[str], Callable, List[Callable]],
|
40
33
|
size: Optional[int] = None,
|
41
34
|
text_max_len: Optional[int] = 512,
|
42
35
|
missing_value_strategy: Optional[str] = "zero",
|
@@ -46,19 +39,10 @@ class DocumentProcessor:
|
|
46
39
|
----------
|
47
40
|
model
|
48
41
|
The model using this data processor.
|
49
|
-
|
42
|
+
train_transforms
|
50
43
|
A list of image transforms used in training. Note that the transform order matters.
|
51
|
-
|
44
|
+
val_transforms
|
52
45
|
A list of image transforms used in validation/test/prediction. Note that the transform order matters.
|
53
|
-
norm_type
|
54
|
-
How to normalize an image. We now support:
|
55
|
-
- inception
|
56
|
-
Normalize image by IMAGENET_INCEPTION_MEAN and IMAGENET_INCEPTION_STD from timm
|
57
|
-
- imagenet
|
58
|
-
Normalize image by IMAGENET_DEFAULT_MEAN and IMAGENET_DEFAULT_STD from timm
|
59
|
-
- clip
|
60
|
-
Normalize image by mean (0.48145466, 0.4578275, 0.40821073) and
|
61
|
-
std (0.26862954, 0.26130258, 0.27577711), used for CLIP.
|
62
46
|
size
|
63
47
|
The width / height of a square image.
|
64
48
|
text_max_len
|
@@ -79,15 +63,16 @@ class DocumentProcessor:
|
|
79
63
|
|
80
64
|
# For document image processing.
|
81
65
|
self.size = size
|
82
|
-
self.
|
83
|
-
self.
|
84
|
-
self.mean
|
66
|
+
self.train_transforms = train_transforms
|
67
|
+
self.val_transforms = val_transforms
|
68
|
+
self.mean = model.image_mean
|
69
|
+
self.std = model.image_std
|
85
70
|
self.normalization = transforms.Normalize(self.mean, self.std)
|
86
|
-
self.train_processor = construct_image_processor(
|
87
|
-
size=self.size, normalization=self.normalization, image_transforms=self.
|
71
|
+
self.train_processor = self.construct_image_processor(
|
72
|
+
size=self.size, normalization=self.normalization, image_transforms=self.train_transforms
|
88
73
|
)
|
89
|
-
self.val_processor = construct_image_processor(
|
90
|
-
size=self.size, normalization=self.normalization, image_transforms=self.
|
74
|
+
self.val_processor = self.construct_image_processor(
|
75
|
+
size=self.size, normalization=self.normalization, image_transforms=self.val_transforms
|
91
76
|
)
|
92
77
|
|
93
78
|
self.missing_value_strategy = missing_value_strategy
|
@@ -359,6 +344,47 @@ class DocumentProcessor:
|
|
359
344
|
|
360
345
|
return ret
|
361
346
|
|
347
|
+
def save_tokenizer(
|
348
|
+
self,
|
349
|
+
path: str,
|
350
|
+
):
|
351
|
+
"""
|
352
|
+
Save the text tokenizer and record its relative paths, e.g, hf_text.
|
353
|
+
|
354
|
+
Parameters
|
355
|
+
----------
|
356
|
+
path
|
357
|
+
The root path of saving.
|
358
|
+
|
359
|
+
"""
|
360
|
+
save_path = os.path.join(path, self.prefix)
|
361
|
+
self.tokenizer.save_pretrained(save_path)
|
362
|
+
self.tokenizer = self.prefix
|
363
|
+
|
364
|
+
def load_tokenizer(
|
365
|
+
self,
|
366
|
+
path: str,
|
367
|
+
):
|
368
|
+
"""
|
369
|
+
Load saved text tokenizers. If text/ner processors already have tokenizers,
|
370
|
+
then do nothing.
|
371
|
+
|
372
|
+
Parameters
|
373
|
+
----------
|
374
|
+
path
|
375
|
+
The root path of loading.
|
376
|
+
|
377
|
+
Returns
|
378
|
+
-------
|
379
|
+
A list of text/ner processors with tokenizers loaded.
|
380
|
+
"""
|
381
|
+
if isinstance(self.tokenizer, str):
|
382
|
+
load_path = os.path.join(path, self.tokenizer)
|
383
|
+
self.tokenizer = get_pretrained_tokenizer(
|
384
|
+
tokenizer_name=self.tokenizer_name,
|
385
|
+
checkpoint_name=load_path,
|
386
|
+
)
|
387
|
+
|
362
388
|
def __call__(
|
363
389
|
self,
|
364
390
|
all_features: Dict[str, Union[NDArray, list]],
|