autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
@@ -14,17 +14,14 @@ from sklearn.preprocessing import MinMaxScaler, StandardScaler
14
14
  from autogluon.features import CategoryFeatureGenerator
15
15
 
16
16
  from ..constants import (
17
- AUTOMM,
18
17
  CATEGORICAL,
19
18
  DOCUMENT,
20
- DOCUMENT_IMAGE,
21
19
  IDENTIFIER,
22
20
  IMAGE,
23
21
  IMAGE_BASE64_STR,
24
22
  IMAGE_BYTEARRAY,
25
23
  IMAGE_PATH,
26
24
  LABEL,
27
- NER,
28
25
  NER_ANNOTATION,
29
26
  NULL,
30
27
  NUMERICAL,
@@ -73,19 +70,17 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
73
70
 
74
71
  if label_column:
75
72
  if label_generator is None:
76
- self._label_generator = CustomLabelEncoder(
77
- positive_class=OmegaConf.select(config, "pos_label", default=None)
78
- )
73
+ self._label_generator = CustomLabelEncoder(positive_class=config.pos_label)
79
74
  else:
80
75
  self._label_generator = label_generator
81
76
 
82
77
  # Scaler used for numerical labels
83
- numerical_label_preprocessing = OmegaConf.select(config, "label.numerical_label_preprocessing")
78
+ numerical_label_preprocessing = config.label.numerical_preprocessing
84
79
  if numerical_label_preprocessing == "minmaxscaler":
85
80
  self._label_scaler = MinMaxScaler()
86
81
  elif numerical_label_preprocessing == "standardscaler":
87
82
  self._label_scaler = StandardScaler()
88
- elif numerical_label_preprocessing is None or numerical_label_preprocessing.lower() == "none":
83
+ elif numerical_label_preprocessing is None:
89
84
  self._label_scaler = StandardScaler(with_mean=False, with_std=False)
90
85
  else:
91
86
  raise ValueError(
@@ -135,8 +130,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
135
130
  # Some columns will be ignored
136
131
  self._ignore_columns_set = set()
137
132
  self._text_feature_names = []
138
- self._categorical_feature_names = []
139
- self._categorical_num_categories = []
133
+ self._categorical_num_categories = dict()
140
134
  self._numerical_feature_names = []
141
135
  self._image_feature_names = []
142
136
  self._rois_feature_names = []
@@ -154,10 +148,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
154
148
 
155
149
  @property
156
150
  def image_path_names(self):
157
- if hasattr(self, "_image_path_names"):
158
- return self._image_path_names
159
- else:
160
- return [col_name for col_name in self._image_feature_names if self._column_types[col_name] == IMAGE_PATH]
151
+ return [col_name for col_name in self._image_feature_names if self._column_types[col_name] == IMAGE_PATH]
161
152
 
162
153
  @property
163
154
  def rois_feature_names(self):
@@ -173,7 +164,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
173
164
 
174
165
  @property
175
166
  def image_feature_names(self):
176
- return self._image_path_names if hasattr(self, "_image_path_names") else self._image_feature_names
167
+ return self._image_feature_names
177
168
 
178
169
  @property
179
170
  def text_feature_names(self):
@@ -181,12 +172,21 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
181
172
 
182
173
  @property
183
174
  def categorical_feature_names(self):
184
- return self._categorical_feature_names
175
+ return list(self.categorical_num_categories.keys())
185
176
 
186
177
  @property
187
178
  def numerical_feature_names(self):
188
179
  return self._numerical_feature_names
189
180
 
181
+ @property
182
+ def numerical_fill_values(self):
183
+ ret = dict()
184
+ for col_name in self._numerical_feature_names:
185
+ generator = self._feature_generators[col_name]
186
+ ret[col_name] = generator.transform(np.full([1, 1], np.nan))[:, 0][0]
187
+
188
+ return ret
189
+
190
190
  @property
191
191
  def document_feature_names(self):
192
192
  # Added for backward compatibility.
@@ -216,17 +216,12 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
216
216
 
217
217
  @property
218
218
  def required_feature_names(self):
219
- image_feature_names = (
220
- self._image_path_names if hasattr(self, "_image_path_names") else self._image_feature_names
221
- )
222
- rois_feature_names = self._rois_feature_names if hasattr(self, "_rois_feature_names") else []
223
-
224
219
  return (
225
- image_feature_names
220
+ self._image_feature_names
226
221
  + self._text_feature_names
227
222
  + self._numerical_feature_names
228
- + self._categorical_feature_names
229
- + rois_feature_names
223
+ + self.categorical_feature_names
224
+ + self._rois_feature_names
230
225
  )
231
226
 
232
227
  @property
@@ -268,16 +263,13 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
268
263
 
269
264
  def get_column_names(self, modality: str):
270
265
  if modality.startswith(IMAGE):
271
- if hasattr(self, "_image_path_names"):
272
- return self._image_path_names
273
- else:
274
- return self._image_feature_names
266
+ return self._image_feature_names
275
267
  elif modality == ROIS:
276
268
  return self._rois_feature_names
277
269
  elif modality == TEXT:
278
270
  return self._text_feature_names
279
271
  elif modality == CATEGORICAL:
280
- return self._categorical_feature_names
272
+ return self.categorical_feature_names
281
273
  elif modality == NUMERICAL:
282
274
  return self._numerical_feature_names
283
275
  elif modality.startswith(DOCUMENT):
@@ -344,8 +336,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
344
336
  continue
345
337
  num_categories = len(generator.category_map[col_name])
346
338
  # Add one unknown category
347
- self._categorical_num_categories.append(num_categories + 1)
348
- self._categorical_feature_names.append(col_name)
339
+ self._categorical_num_categories[col_name] = num_categories + 1
349
340
  elif col_type == NUMERICAL:
350
341
  processed_data = pd.to_numeric(col_value)
351
342
  if len(processed_data.unique()) == 1:
@@ -392,7 +383,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
392
383
  elif self.label_type == NUMERICAL:
393
384
  y = pd.to_numeric(y).to_numpy()
394
385
  self._label_scaler.fit(np.expand_dims(y, axis=-1))
395
- elif self.label_type == ROIS or self.label_type == SEMANTIC_SEGMENTATION_GT:
386
+ elif self.label_type in [ROIS, SEMANTIC_SEGMENTATION_GT]:
396
387
  pass # Do nothing. TODO: Shall we call fit here?
397
388
  elif self.label_type == NER_ANNOTATION:
398
389
  # If there are ner annotations and text columns but no NER feature columns,
@@ -426,6 +417,24 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
426
417
  if y is not None:
427
418
  self._fit_y(y=y, X=X)
428
419
 
420
+ @staticmethod
421
+ def convert_categorical_to_text(col_value: pd.Series, template: str, col_name: str):
422
+ # TODO: do we need to consider whether categorical values are valid text?
423
+ col_value = col_value.astype("object")
424
+ if template == "direct":
425
+ processed_data = col_value.apply(lambda ele: "" if pd.isnull(ele) else str(ele))
426
+ elif template == "list":
427
+ processed_data = col_value.apply(lambda ele: "" if pd.isnull(ele) else col_name + ": " + str(ele))
428
+ elif template == "text":
429
+ processed_data = col_value.apply(lambda ele: "" if pd.isnull(ele) else col_name + " is " + str(ele))
430
+ elif template == "latex":
431
+ processed_data = col_value.apply(lambda ele: "" if pd.isnull(ele) else str(ele) + " & ")
432
+ else:
433
+ raise ValueError(
434
+ f"Unsupported template {template} for converting categorical data into text. Select one from: ['direct', 'list', 'text', 'latex']."
435
+ )
436
+ return processed_data
437
+
429
438
  def transform_text(
430
439
  self,
431
440
  df: pd.DataFrame,
@@ -455,10 +464,15 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
455
464
  for col_name in self._text_feature_names:
456
465
  col_value = df[col_name]
457
466
  col_type = self._column_types[col_name]
458
- if col_type == TEXT or col_type == CATEGORICAL:
459
- # TODO: do we need to consider whether categorical values are valid text?
467
+ if col_type == TEXT:
460
468
  col_value = col_value.astype("object")
461
469
  processed_data = col_value.apply(lambda ele: "" if pd.isnull(ele) else str(ele))
470
+ elif col_type == CATEGORICAL:
471
+ processed_data = self.convert_categorical_to_text(
472
+ col_value=col_value,
473
+ template=self._config.categorical.convert_to_text_template,
474
+ col_name=col_name,
475
+ )
462
476
  elif col_type == NUMERICAL:
463
477
  processed_data = pd.to_numeric(col_value).apply("{:.3f}".format)
464
478
  elif col_type == f"{TEXT}_{IDENTIFIER}":
@@ -710,7 +724,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
710
724
  self._fit_called or self._fit_x_called
711
725
  ), "You will need to first call preprocessor.fit before calling preprocessor.transform_categorical."
712
726
  categorical_features = {}
713
- for col_name, num_category in zip(self._categorical_feature_names, self._categorical_num_categories):
727
+ for col_name, num_category in self._categorical_num_categories.items():
714
728
  col_value = df[col_name]
715
729
  processed_data = col_value.astype("category")
716
730
  generator = self._feature_generators[col_name]
@@ -757,7 +771,7 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
757
771
  elif self.label_type == NUMERICAL:
758
772
  y = pd.to_numeric(y_df).to_numpy()
759
773
  y = self._label_scaler.transform(np.expand_dims(y, axis=-1))[:, 0].astype(np.float32)
760
- elif self.label_type == ROIS or self.label_type == SEMANTIC_SEGMENTATION_GT:
774
+ elif self.label_type in [ROIS, SEMANTIC_SEGMENTATION_GT]:
761
775
  y = y_df.to_list()
762
776
  elif self.label_type == NER_ANNOTATION:
763
777
  y = self._label_generator.transform(y_df)
@@ -866,8 +880,11 @@ class MultiModalFeaturePreprocessor(TransformerMixin, BaseEstimator):
866
880
  ), "You will need to first call preprocessor.fit_y() before calling preprocessor.transform_prediction."
867
881
 
868
882
  if self.label_type == CATEGORICAL:
869
- assert y_pred.shape[1] >= 2
870
- y_pred = y_pred.argmax(axis=1)
883
+ assert len(y_pred.shape) <= 2
884
+ if len(y_pred.shape) == 2 and y_pred.shape[1] >= 2:
885
+ y_pred = y_pred.argmax(axis=1)
886
+ else:
887
+ y_pred = (y_pred > 0.5).astype(int)
871
888
  # Transform the predicted label back to the original space (e.g., string values)
872
889
  if inverse_categorical:
873
890
  y_pred = self._label_generator.inverse_transform(y_pred)
@@ -1,11 +1,14 @@
1
+ import logging
2
+ import random
1
3
  from typing import Any, Dict, List, Optional, Union
2
4
 
3
- import numpy as np
4
5
  from torch import nn
5
6
 
6
7
  from ..constants import CATEGORICAL, COLUMN
7
8
  from .collator import StackCollator, TupleCollator
8
9
 
10
+ logger = logging.getLogger(__name__)
11
+
9
12
 
10
13
  class CategoricalProcessor:
11
14
  """
@@ -18,6 +21,7 @@ class CategoricalProcessor:
18
21
  self,
19
22
  model: nn.Module,
20
23
  requires_column_info: bool = False,
24
+ dropout: Optional[float] = 0,
21
25
  ):
22
26
  """
23
27
  Parameters
@@ -27,8 +31,16 @@ class CategoricalProcessor:
27
31
  requires_column_info
28
32
  Whether to require feature column information in dataloader.
29
33
  """
34
+ logger.debug(f"initializing categorical processor for model {model.prefix}")
30
35
  self.prefix = model.prefix
31
36
  self.requires_column_info = requires_column_info
37
+ self.num_categories = model.num_categories
38
+ self.dropout = dropout
39
+ assert 0 <= self.dropout <= 1
40
+ if self.dropout > 0:
41
+ logger.debug(f"categorical value dropout probability: {self.dropout}")
42
+ fill_values = {k: v - 1 for k, v in self.num_categories.items()}
43
+ logger.debug(f"dropped values will be replaced by {fill_values}")
32
44
 
33
45
  @property
34
46
  def categorical_key(self):
@@ -60,6 +72,7 @@ class CategoricalProcessor:
60
72
  def process_one_sample(
61
73
  self,
62
74
  categorical_features: Dict[str, int],
75
+ is_training: bool,
63
76
  ) -> Dict:
64
77
  """
65
78
  Process one sample's categorical features. Assume the categorical features
@@ -69,6 +82,8 @@ class CategoricalProcessor:
69
82
  ----------
70
83
  categorical_features
71
84
  Categorical features of one sample.
85
+ is_training
86
+ Whether to do processing in the training mode.
72
87
 
73
88
  Returns
74
89
  -------
@@ -80,6 +95,17 @@ class CategoricalProcessor:
80
95
  for i, col_name in enumerate(categorical_features.keys()):
81
96
  ret[f"{self.categorical_column_prefix}_{col_name}"] = i
82
97
 
98
+ if is_training and self.dropout > 0:
99
+ categorical_features_copy = dict()
100
+ for k, v in categorical_features.items():
101
+ if random.uniform(0, 1) <= self.dropout:
102
+ categorical_features_copy[k] = self.num_categories[k] - 1
103
+ else:
104
+ categorical_features_copy[k] = v
105
+ categorical_features = categorical_features_copy
106
+
107
+ # make sure keys are in the same order
108
+ assert list(categorical_features.keys()) == list(self.num_categories.keys())
83
109
  ret[self.categorical_key] = list(categorical_features.values())
84
110
 
85
111
  return ret
@@ -87,7 +113,7 @@ class CategoricalProcessor:
87
113
  def __call__(
88
114
  self,
89
115
  categorical_features: Dict[str, int],
90
- feature_modalities: Dict[str, Union[int, float, list]],
116
+ sub_dtypes: Dict[str, str],
91
117
  is_training: bool,
92
118
  ) -> Dict:
93
119
  """
@@ -97,13 +123,16 @@ class CategoricalProcessor:
97
123
  ----------
98
124
  categorical_features
99
125
  Categorical features of one sample.
100
- feature_modalities
101
- The modality of the feature columns.
126
+ sub_dtypes
127
+ The sub data types of all categorical columns.
102
128
  is_training
103
- Whether to do processing in the training mode. This unused flag is for the API compatibility.
129
+ Whether to do processing in the training mode.
104
130
 
105
131
  Returns
106
132
  -------
107
133
  A dictionary containing one sample's processed categorical features.
108
134
  """
109
- return self.process_one_sample(categorical_features)
135
+ return self.process_one_sample(
136
+ categorical_features=categorical_features,
137
+ is_training=is_training,
138
+ )
@@ -1,30 +1,24 @@
1
- import importlib.util
2
1
  import logging
3
2
  import os
4
- import re
5
- import shutil
6
- import subprocess
7
3
  import warnings
8
- from io import BytesIO
9
- from typing import Any, Dict, List, Optional, Union
4
+ from typing import Any, Callable, Dict, List, Optional, Union
10
5
 
11
6
  import numpy as np
12
7
  import PIL
13
8
  import pytesseract
14
- import torch
15
9
  from numpy.typing import NDArray
16
- from PIL import ImageFile
17
10
  from torch import nn
18
11
  from torchvision import transforms
19
12
 
20
- from ..constants import AUTOMM, BBOX, DOCUMENT_PDF
21
- from .collator import PadCollator, StackCollator
22
- from .utils import construct_image_processor, image_mean_std
13
+ from ..constants import BBOX, DOCUMENT_PDF
14
+ from ..models.utils import get_pretrained_tokenizer
15
+ from .collator import PadCollator
16
+ from .process_image import ImageProcessor
23
17
 
24
18
  logger = logging.getLogger(__name__)
25
19
 
26
20
 
27
- class DocumentProcessor:
21
+ class DocumentProcessor(ImageProcessor):
28
22
  """
29
23
  Prepare document data for Document Classification.
30
24
  OCR (Optical character recognition) is applied to get the document texts and bounding boxes.
@@ -34,9 +28,8 @@ class DocumentProcessor:
34
28
  def __init__(
35
29
  self,
36
30
  model: nn.Module,
37
- train_transform_types: List[str],
38
- val_transform_types: List[str],
39
- norm_type: Optional[str] = None,
31
+ train_transforms: Union[List[str], Callable, List[Callable]],
32
+ val_transforms: Union[List[str], Callable, List[Callable]],
40
33
  size: Optional[int] = None,
41
34
  text_max_len: Optional[int] = 512,
42
35
  missing_value_strategy: Optional[str] = "zero",
@@ -46,19 +39,10 @@ class DocumentProcessor:
46
39
  ----------
47
40
  model
48
41
  The model using this data processor.
49
- train_transform_types
42
+ train_transforms
50
43
  A list of image transforms used in training. Note that the transform order matters.
51
- val_transform_types
44
+ val_transforms
52
45
  A list of image transforms used in validation/test/prediction. Note that the transform order matters.
53
- norm_type
54
- How to normalize an image. We now support:
55
- - inception
56
- Normalize image by IMAGENET_INCEPTION_MEAN and IMAGENET_INCEPTION_STD from timm
57
- - imagenet
58
- Normalize image by IMAGENET_DEFAULT_MEAN and IMAGENET_DEFAULT_STD from timm
59
- - clip
60
- Normalize image by mean (0.48145466, 0.4578275, 0.40821073) and
61
- std (0.26862954, 0.26130258, 0.27577711), used for CLIP.
62
46
  size
63
47
  The width / height of a square image.
64
48
  text_max_len
@@ -79,15 +63,16 @@ class DocumentProcessor:
79
63
 
80
64
  # For document image processing.
81
65
  self.size = size
82
- self.train_transform_types = train_transform_types
83
- self.val_transform_types = val_transform_types
84
- self.mean, self.std = image_mean_std(norm_type)
66
+ self.train_transforms = train_transforms
67
+ self.val_transforms = val_transforms
68
+ self.mean = model.image_mean
69
+ self.std = model.image_std
85
70
  self.normalization = transforms.Normalize(self.mean, self.std)
86
- self.train_processor = construct_image_processor(
87
- size=self.size, normalization=self.normalization, image_transforms=self.train_transform_types
71
+ self.train_processor = self.construct_image_processor(
72
+ size=self.size, normalization=self.normalization, image_transforms=self.train_transforms
88
73
  )
89
- self.val_processor = construct_image_processor(
90
- size=self.size, normalization=self.normalization, image_transforms=self.val_transform_types
74
+ self.val_processor = self.construct_image_processor(
75
+ size=self.size, normalization=self.normalization, image_transforms=self.val_transforms
91
76
  )
92
77
 
93
78
  self.missing_value_strategy = missing_value_strategy
@@ -359,6 +344,47 @@ class DocumentProcessor:
359
344
 
360
345
  return ret
361
346
 
347
+ def save_tokenizer(
348
+ self,
349
+ path: str,
350
+ ):
351
+ """
352
+ Save the text tokenizer and record its relative paths, e.g, hf_text.
353
+
354
+ Parameters
355
+ ----------
356
+ path
357
+ The root path of saving.
358
+
359
+ """
360
+ save_path = os.path.join(path, self.prefix)
361
+ self.tokenizer.save_pretrained(save_path)
362
+ self.tokenizer = self.prefix
363
+
364
+ def load_tokenizer(
365
+ self,
366
+ path: str,
367
+ ):
368
+ """
369
+ Load saved text tokenizers. If text/ner processors already have tokenizers,
370
+ then do nothing.
371
+
372
+ Parameters
373
+ ----------
374
+ path
375
+ The root path of loading.
376
+
377
+ Returns
378
+ -------
379
+ A list of text/ner processors with tokenizers loaded.
380
+ """
381
+ if isinstance(self.tokenizer, str):
382
+ load_path = os.path.join(path, self.tokenizer)
383
+ self.tokenizer = get_pretrained_tokenizer(
384
+ tokenizer_name=self.tokenizer_name,
385
+ checkpoint_name=load_path,
386
+ )
387
+
362
388
  def __call__(
363
389
  self,
364
390
  all_features: Dict[str, Union[NDArray, list]],