autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
@@ -1,79 +1,53 @@
1
- import ast
2
- import codecs
3
- import copy
4
- import re
1
+ import logging
5
2
  import warnings
6
- from io import BytesIO
7
3
  from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
8
4
 
9
- import numpy as np
10
5
  import pandas as pd
11
- import PIL
12
- from omegaconf import ListConfig
13
- from text_unidecode import unidecode
14
- from timm.data.constants import (
15
- IMAGENET_DEFAULT_MEAN,
16
- IMAGENET_DEFAULT_STD,
17
- IMAGENET_INCEPTION_MEAN,
18
- IMAGENET_INCEPTION_STD,
19
- )
20
- from tokenizers import pre_tokenizers
21
- from torchvision import transforms
6
+ from omegaconf import DictConfig, OmegaConf
7
+ from torch import nn
8
+
9
+ from autogluon.core.utils import default_holdout_frac, generate_train_test_split_combined
10
+ from autogluon.core.utils.loaders import load_pd
22
11
 
23
12
  from ..constants import (
24
- CLIP_IMAGE_MEAN,
25
- CLIP_IMAGE_STD,
13
+ BINARY,
14
+ CATEGORICAL,
15
+ DEFAULT_SHOT,
16
+ DOCUMENT,
17
+ FEW_SHOT,
26
18
  IDENTIFIER,
27
19
  IMAGE,
28
- IMAGE_BYTEARRAY,
29
20
  IMAGE_PATH,
21
+ LABEL,
30
22
  MMDET_IMAGE,
31
23
  MMLAB_MODELS,
24
+ MULTICLASS,
25
+ NER_ANNOTATION,
26
+ NER_TEXT,
27
+ NUMERICAL,
28
+ REGRESSION,
29
+ ROIS,
30
+ SAM,
31
+ SEMANTIC_SEGMENTATION_IMG,
32
+ TEXT,
33
+ TEXT_NER,
32
34
  )
33
35
  from .collator import DictCollator
36
+ from .infer_types import is_image_column
37
+ from .label_encoder import NerLabelEncoder
38
+ from .mixup import MixupModule
34
39
  from .preprocess_dataframe import MultiModalFeaturePreprocessor
40
+ from .process_categorical import CategoricalProcessor
41
+ from .process_document import DocumentProcessor
42
+ from .process_image import ImageProcessor
43
+ from .process_label import LabelProcessor
44
+ from .process_mmlab import MMDetProcessor
45
+ from .process_ner import NerProcessor
46
+ from .process_numerical import NumericalProcessor
47
+ from .process_semantic_seg_img import SemanticSegImageProcessor
48
+ from .process_text import TextProcessor
35
49
 
36
- try:
37
- from torchvision.transforms import InterpolationMode
38
-
39
- BICUBIC = InterpolationMode.BICUBIC
40
- NEAREST = InterpolationMode.NEAREST
41
- except ImportError:
42
- BICUBIC = PIL.Image.BICUBIC
43
- NEAREST = PIL.Image.NEAREST
44
-
45
- from .randaug import RandAugment
46
- from .trivial_augmenter import TrivialAugment
47
-
48
-
49
- def extract_value_from_config(
50
- config: Dict,
51
- keys: Tuple[str, ...],
52
- ):
53
- """
54
- Traverse a config dictionary to get some hyper-parameter's value.
55
-
56
- Parameters
57
- ----------
58
- config
59
- A config dictionary.
60
- keys
61
- The possible names of a hyper-parameter.
62
-
63
- Returns
64
- -------
65
- The hyper-parameter value.
66
- """
67
- result = []
68
- for k, v in config.items():
69
- if k in keys:
70
- result.append(v)
71
- elif isinstance(v, dict):
72
- result += extract_value_from_config(v, keys)
73
- else:
74
- pass
75
-
76
- return result
50
+ logger = logging.getLogger(__name__)
77
51
 
78
52
 
79
53
  def get_collate_fn(
@@ -165,7 +139,7 @@ def apply_df_preprocessor(
165
139
  def apply_data_processor(
166
140
  per_sample_features: Dict,
167
141
  data_processors: Dict,
168
- feature_modalities: Dict,
142
+ data_types: Dict,
169
143
  is_training: bool,
170
144
  load_only=False,
171
145
  ):
@@ -175,9 +149,11 @@ def apply_data_processor(
175
149
  Parameters
176
150
  ----------
177
151
  per_sample_features
178
- Modality features of one sample.
152
+ Features of one sample.
179
153
  data_processors
180
154
  A dict of data processors.
155
+ data_types
156
+ Data types of all columns.
181
157
  is_training
182
158
  Whether is training.
183
159
  load_only
@@ -194,14 +170,14 @@ def apply_data_processor(
194
170
  sample_features.update(
195
171
  per_model_processor(
196
172
  per_sample_features[per_modality],
197
- feature_modalities[per_modality],
173
+ data_types[per_modality],
198
174
  is_training=is_training,
199
175
  load_only=load_only,
200
176
  )
201
177
  if per_model_processor.prefix.lower().startswith(MMDET_IMAGE)
202
178
  else per_model_processor(
203
179
  per_sample_features[per_modality],
204
- feature_modalities[per_modality],
180
+ data_types[per_modality],
205
181
  is_training=is_training,
206
182
  )
207
183
  )
@@ -250,366 +226,616 @@ def get_per_sample_features(
250
226
  return ret
251
227
 
252
228
 
253
- def register_encoding_decoding_error_handlers() -> None:
254
- """Register the encoding and decoding error handlers for `utf-8` and `cp1252`."""
229
+ def default_holdout_frac(num_train_rows, hyperparameter_tune=False):
230
+ """Returns default holdout_frac used in fit().
231
+ Between row count 5,000 and 25,000 keep 0.1 holdout_frac, as we want to grow validation set to a stable 2500 examples.
232
+ """
233
+ if num_train_rows < 5000:
234
+ holdout_frac = max(0.1, min(0.2, 500.0 / num_train_rows))
235
+ else:
236
+ holdout_frac = max(0.01, min(0.1, 2500.0 / num_train_rows))
255
237
 
256
- def replace_encoding_with_utf8(error: UnicodeError) -> Tuple[bytes, int]:
257
- return error.object[error.start : error.end].encode("utf-8"), error.end
238
+ if hyperparameter_tune:
239
+ holdout_frac = min(
240
+ 0.2, holdout_frac * 2
241
+ ) # We want to allocate more validation data for HPO to avoid overfitting
258
242
 
259
- def replace_decoding_with_cp1252(error: UnicodeError) -> Tuple[str, int]:
260
- return error.object[error.start : error.end].decode("cp1252"), error.end
243
+ return holdout_frac
261
244
 
262
- codecs.register_error("replace_encoding_with_utf8", replace_encoding_with_utf8)
263
- codecs.register_error("replace_decoding_with_cp1252", replace_decoding_with_cp1252)
264
245
 
246
+ def init_df_preprocessor(
247
+ config: DictConfig,
248
+ column_types: Dict,
249
+ label_column: Optional[str] = None,
250
+ train_df_x: Optional[pd.DataFrame] = None,
251
+ train_df_y: Optional[pd.Series] = None,
252
+ ):
253
+ """
254
+ Initialize the dataframe preprocessor by calling .fit().
265
255
 
266
- def normalize_txt(text: str) -> str:
267
- """Resolve the encoding problems and normalize the abnormal characters."""
256
+ Parameters
257
+ ----------
258
+ config
259
+ A DictConfig containing only the data config.
260
+ column_types
261
+ A dictionary that maps column names to their data types.
262
+ For example: `column_types = {"item_name": "text", "image": "image_path",
263
+ "product_description": "text", "height": "numerical"}`
264
+ may be used for a table with columns: "item_name", "brand", "product_description", and "height".
265
+ label_column
266
+ Name of the column that contains the target variable to predict.
267
+ train_df_x
268
+ A pd.DataFrame containing only the feature columns.
269
+ train_df_y
270
+ A pd.Series object containing only the label column.
268
271
 
269
- text = (
270
- text.encode("raw_unicode_escape")
271
- .decode("utf-8", errors="replace_decoding_with_cp1252")
272
- .encode("cp1252", errors="replace_encoding_with_utf8")
273
- .decode("utf-8", errors="replace_decoding_with_cp1252")
272
+ Returns
273
+ -------
274
+ Initialized dataframe preprocessor.
275
+ """
276
+ if label_column in column_types and column_types[label_column] == NER_ANNOTATION:
277
+ label_generator = NerLabelEncoder(config)
278
+ else:
279
+ label_generator = None
280
+
281
+ df_preprocessor = MultiModalFeaturePreprocessor(
282
+ config=config.data,
283
+ column_types=column_types,
284
+ label_column=label_column,
285
+ label_generator=label_generator,
286
+ )
287
+ df_preprocessor.fit(
288
+ X=train_df_x,
289
+ y=train_df_y,
274
290
  )
275
- text = unidecode(text)
276
- return text
291
+
292
+ return df_preprocessor
277
293
 
278
294
 
279
- def process_ner_annotations(ner_annotations, ner_text, entity_map, tokenizer, is_eval=False):
295
+ def get_image_transforms(model_config: DictConfig, model_name: str, advanced_hyperparameters: Dict):
280
296
  """
281
- Generate token-level/word-level labels with given text and NER annotations.
297
+ Get the image transforms of one image-related model.
298
+ Use the transforms in advanced_hyperparameters with higher priority.
282
299
 
283
300
  Parameters
284
301
  ----------
285
- ner_annotations
286
- The NER annotations.
287
- ner_text
288
- The corresponding raw text.
289
- entity_map
290
- The map between tags and tag indexes. e.g., {"PER":2, "LOC":3}.
291
- tokenizer
292
- The tokenizer to be used.
293
- is_eval
294
- Whether it is for evaluation or not, default: False
302
+ model_config
303
+ Config of one model.
304
+ model_name
305
+ Name of one model.
306
+ advanced_hyperparameters
307
+ The advanced hyperparameters whose values are complex objects.
295
308
 
296
309
  Returns
297
310
  -------
298
- Token-level/word-level labels and text features.
311
+ The image transforms used in training and validation.
299
312
  """
300
- col_tokens, token_to_word_mappings, word_offsets = tokenize_ner_text(ner_text, tokenizer)
301
- num_words = len(set(token_to_word_mappings)) - 1
302
- word_label = [1] * num_words
303
- # TODO: Potentially optimize word label generation via binary search
304
- b_prefix = "B-"
305
- i_prefix = "I-"
306
- for annot in ner_annotations:
307
- custom_offset = annot[0]
308
- custom_label = annot[1]
309
- is_start_word = True
310
- for idx, word_offset in enumerate(word_offsets[:num_words, :]):
311
- # support multiple words in an annotated offset range.
312
- # Allow partial overlapping between custom annotations and pretokenized words.
313
- if (word_offset[0] < custom_offset[1]) and (custom_offset[0] < word_offset[1]):
314
- if not (
315
- re.match(b_prefix, custom_label, re.IGNORECASE) or re.match(i_prefix, custom_label, re.IGNORECASE)
316
- ):
317
- if is_start_word and b_prefix + custom_label in entity_map:
318
- word_label[idx] = entity_map[b_prefix + custom_label]
319
- is_start_word = False
320
- elif i_prefix + custom_label in entity_map:
321
- word_label[idx] = entity_map[i_prefix + custom_label]
322
- else:
323
- if custom_label in entity_map:
324
- word_label[idx] = entity_map[custom_label]
325
-
326
- token_label = [0] * len(col_tokens.input_ids)
327
- temp = set()
328
- counter = 0
329
- for idx, token_to_word in enumerate(token_to_word_mappings):
330
- if token_to_word != -1 and token_to_word not in temp:
331
- temp.add(token_to_word)
332
- token_label[idx] = word_label[counter]
333
- counter += 1
334
- if not is_eval:
335
- label = token_label # return token-level labels for training
313
+ train_transform_key = f"model.{model_name}.train_transforms"
314
+ val_transform_key = f"model.{model_name}.val_transforms"
315
+ if advanced_hyperparameters and train_transform_key in advanced_hyperparameters:
316
+ train_transforms = advanced_hyperparameters[train_transform_key]
317
+ else:
318
+ train_transforms = model_config.train_transforms
319
+ train_transforms = list(train_transforms)
320
+
321
+ if advanced_hyperparameters and val_transform_key in advanced_hyperparameters:
322
+ val_transforms = advanced_hyperparameters[val_transform_key]
336
323
  else:
337
- label = word_label # return word-level labels for evaluation
324
+ val_transforms = model_config.val_transforms
325
+ val_transforms = list(val_transforms)
338
326
 
339
- return label, col_tokens, token_to_word_mappings, word_offsets
327
+ return train_transforms, val_transforms
340
328
 
341
329
 
342
- def tokenize_ner_text(text, tokenizer):
330
+ def create_data_processor(
331
+ data_type: str,
332
+ config: DictConfig,
333
+ model: nn.Module,
334
+ advanced_hyperparameters: Optional[Dict] = None,
335
+ ):
343
336
  """
344
- Tokenization process for the NER task. It will be used for the token-level label generation
345
- and the input text tokenization.
337
+ Create one data processor based on the data type and model.
346
338
 
347
339
  Parameters
348
340
  ----------
349
- text
350
- The raw text data.
351
- tokenizer
352
- The tokenizer to be used.
341
+ data_type
342
+ Data type.
343
+ config
344
+ The config may contain information required by creating a data processor.
345
+ In future, we may move the required config information into the model.config
346
+ to make the data processor conditioned only on the model itself.
347
+ model
348
+ The model.
353
349
 
354
350
  Returns
355
351
  -------
356
- The output of tokenizer and word offsets.
352
+ One data processor.
357
353
  """
358
- # pre-tokenization is required for NER token-level label generation.
359
- words_with_offsets = pre_tokenizers.BertPreTokenizer().pre_tokenize_str(text)
360
- words_with_offsets = is_space_counted(words_with_offsets) if len(words_with_offsets) > 1 else words_with_offsets
361
- words = [word for word, offset in words_with_offsets]
362
- word_offsets = np.array([[offset[0], offset[1]] for word, offset in words_with_offsets], dtype=np.int32)
363
- col_tokens = tokenizer(
364
- words,
365
- is_split_into_words=True,
366
- return_offsets_mapping=True,
367
- padding="max_length",
368
- truncation=True,
369
- max_length=tokenizer.model_max_length,
370
- return_token_type_ids=True,
371
- )
372
- offset_mapping = np.array(col_tokens.offset_mapping, dtype=np.int32)
373
- if len(words_with_offsets) > 1:
374
- if offset_mapping.shape[0] > len(words):
375
- word_offsets = np.pad(word_offsets, ((0, offset_mapping.shape[0] - len(words)), (0, 0)), "constant")
376
- # token to word mappings: it will tell us which token belongs to which word.
377
- token_to_word_mappings = [i if i != None else -1 for i in col_tokens.word_ids()]
378
- if len(set(token_to_word_mappings)) != len(words) + 1:
379
- warnings.warn(f"The token to word mappings are incorrect!")
354
+ model_config = getattr(config.model, model.prefix)
355
+ if data_type == IMAGE:
356
+ train_transforms, val_transforms = get_image_transforms(
357
+ model_config=model_config,
358
+ model_name=model.prefix,
359
+ advanced_hyperparameters=advanced_hyperparameters,
360
+ )
361
+ data_processor = ImageProcessor(
362
+ model=model,
363
+ train_transforms=train_transforms,
364
+ val_transforms=val_transforms,
365
+ max_image_num_per_column=model_config.max_image_num_per_column,
366
+ missing_value_strategy=config.data.image.missing_value_strategy,
367
+ dropout=config.data.modality_dropout,
368
+ )
369
+ elif data_type == TEXT:
370
+ data_processor = TextProcessor(
371
+ model=model,
372
+ insert_sep=model_config.insert_sep,
373
+ stochastic_chunk=model_config.stochastic_chunk,
374
+ text_detection_length=model_config.text_aug_detect_length,
375
+ text_trivial_aug_maxscale=model_config.text_trivial_aug_maxscale,
376
+ train_augment_types=model_config.text_train_augment_types,
377
+ normalize_text=config.data.text.normalize_text,
378
+ template_config=config.data.templates,
379
+ dropout=config.data.modality_dropout,
380
+ )
381
+ elif data_type == CATEGORICAL:
382
+ data_processor = CategoricalProcessor(
383
+ model=model,
384
+ dropout=config.data.modality_dropout,
385
+ )
386
+ elif data_type == NUMERICAL:
387
+ data_processor = NumericalProcessor(
388
+ model=model,
389
+ merge=model_config.merge,
390
+ dropout=config.data.modality_dropout,
391
+ )
392
+ elif data_type == LABEL:
393
+ data_processor = LabelProcessor(model=model)
394
+ elif data_type == TEXT_NER:
395
+ data_processor = NerProcessor(
396
+ model=model,
397
+ max_len=model_config.max_text_len,
398
+ entity_map=config.entity_map,
399
+ )
400
+ elif data_type == ROIS:
401
+ data_processor = MMDetProcessor(
402
+ model=model,
403
+ max_img_num_per_col=model_config.max_img_num_per_col,
404
+ missing_value_strategy=config.data.image.missing_value_strategy,
405
+ )
406
+ elif data_type == DOCUMENT:
407
+ train_transforms, val_transforms = get_image_transforms(
408
+ model_config=model_config,
409
+ model_name=model.prefix,
410
+ advanced_hyperparameters=advanced_hyperparameters,
411
+ )
412
+ data_processor = DocumentProcessor(
413
+ model=model,
414
+ train_transforms=train_transforms,
415
+ val_transforms=val_transforms,
416
+ size=model_config.image_size,
417
+ text_max_len=model_config.max_text_len,
418
+ missing_value_strategy=config.data.document.missing_value_strategy,
419
+ )
420
+ elif data_type == SEMANTIC_SEGMENTATION_IMG:
421
+ data_processor = SemanticSegImageProcessor(
422
+ model=model,
423
+ img_transforms=model_config.img_transforms,
424
+ gt_transforms=model_config.gt_transforms,
425
+ train_transforms=model_config.train_transforms,
426
+ val_transforms=model_config.val_transforms,
427
+ ignore_label=model_config.ignore_label,
428
+ )
380
429
  else:
381
- # If pre_tokenizer does not give word offsets, use word_ids and offset_mappings instead.
382
- word_offsets = np.append(offset_mapping[1:], [[0, 0]], axis=0)
383
- word_idx = np.arange(len(col_tokens.word_ids()) - col_tokens.word_ids().count(None))
384
- token_to_word_mappings = [
385
- val + word_idx[idx - 1] if val != None else -1 for idx, val in enumerate(col_tokens.word_ids())
386
- ]
430
+ raise ValueError(f"unknown data type: {data_type}")
387
431
 
388
- return col_tokens, token_to_word_mappings, word_offsets
432
+ return data_processor
389
433
 
390
434
 
391
- def is_space_counted(words_with_offsets):
435
+ def create_fusion_data_processors(
436
+ config: DictConfig,
437
+ model: nn.Module,
438
+ requires_label: Optional[bool] = True,
439
+ requires_data: Optional[bool] = True,
440
+ advanced_hyperparameters: Optional[Dict] = None,
441
+ ):
392
442
  """
393
- Some tokenizers will count space into words for example.
394
- Given text: 'hello world', normal bert will output: [('hello', (0, 5)), ('world', (6, 11))]
395
- while some checkpoint will output: [('▁hello', (0, 5)), ('▁world', (5, 11))]
396
- This will lead to inconsistency issue during labelling, details can be found here:
397
- https://github.com/huggingface/transformers/issues/18111
443
+ Create the data processors for late-fusion models. This function creates one processor for
444
+ each modality of each model. For example, if one model config contains BERT, ViT, and CLIP, then
445
+ BERT would have its own text processor, ViT would have its own image processor, and CLIP would have
446
+ its own text and image processors. This is to support training arbitrary combinations of single-modal
447
+ and multimodal models since two models may share the same modality but have different processing. Text
448
+ sequence length is a good example. BERT's sequence length is generally 512, while CLIP uses sequences of
449
+ length 77.
450
+
451
+ Parameters
452
+ ----------
453
+ config
454
+ A DictConfig object. The model config should be accessible by "config.model".
455
+ model
456
+ The model object.
398
457
 
399
- This function will check whether space is counted or not and realign the offset.
458
+ Returns
459
+ -------
460
+ A dictionary with modalities as the keys. Each modality has a list of processors.
461
+ Note that "label" is also treated as a modality for convenience.
400
462
  """
401
- offset0, offset1 = [], []
402
- for word, offset in words_with_offsets:
403
- offset0.append(offset[0])
404
- offset1.append(offset[1])
405
-
406
- realign = []
407
- if offset0[1:] == offset1[:-1]: # space are counted
408
- realign = [words_with_offsets[0]]
409
- for word, offset in words_with_offsets[1:]:
410
- if word.startswith("▁"): # it is "Lower One Eighth Block" (U+2581) rather than lower line (U+005F).
411
- realign.append((word, (offset[0] + 1, offset[1])))
412
- else:
413
- realign.append((word, offset))
463
+ data_processors = {
464
+ IMAGE: [],
465
+ TEXT: [],
466
+ CATEGORICAL: [],
467
+ NUMERICAL: [],
468
+ LABEL: [],
469
+ ROIS: [],
470
+ TEXT_NER: [],
471
+ DOCUMENT: [],
472
+ SEMANTIC_SEGMENTATION_IMG: [],
473
+ }
474
+
475
+ model_dict = {model.prefix: model}
476
+
477
+ if model.prefix.lower().startswith("fusion"):
478
+ for per_model in model.model:
479
+ model_dict[per_model.prefix] = per_model
480
+
481
+ assert sorted(list(model_dict.keys())) == sorted(config.model.names)
482
+
483
+ for per_name, per_model in model_dict.items():
484
+ model_config = getattr(config.model, per_model.prefix)
485
+ if model_config.data_types is not None:
486
+ data_types = model_config.data_types.copy()
487
+ else:
488
+ data_types = None
489
+
490
+ if per_name == NER_TEXT:
491
+ # create a multimodal processor for NER.
492
+ data_processors[TEXT_NER].append(
493
+ create_data_processor(
494
+ data_type=TEXT_NER,
495
+ config=config,
496
+ model=per_model,
497
+ )
498
+ )
499
+ requires_label = False
500
+ if data_types is not None and TEXT_NER in data_types:
501
+ data_types.remove(TEXT_NER)
502
+ elif per_name.lower().startswith(MMLAB_MODELS):
503
+ # create a multimodal processor for NER.
504
+ data_processors[ROIS].append(
505
+ create_data_processor(
506
+ data_type=ROIS,
507
+ config=config,
508
+ model=per_model,
509
+ )
510
+ )
511
+ if data_types is not None and IMAGE in data_types:
512
+ data_types.remove(IMAGE)
513
+ elif per_name == SAM:
514
+ data_processors[SEMANTIC_SEGMENTATION_IMG].append(
515
+ create_data_processor(
516
+ data_type=SEMANTIC_SEGMENTATION_IMG,
517
+ config=config,
518
+ model=per_model,
519
+ )
520
+ )
521
+ if data_types is not None and SEMANTIC_SEGMENTATION_IMG in data_types:
522
+ data_types.remove(SEMANTIC_SEGMENTATION_IMG)
523
+ requires_label = False
524
+
525
+ if requires_label:
526
+ # each model has its own label processor
527
+ label_processor = create_data_processor(
528
+ data_type=LABEL,
529
+ config=config,
530
+ model=per_model,
531
+ )
532
+ data_processors[LABEL].append(label_processor)
533
+
534
+ if requires_data and data_types:
535
+ for data_type in data_types:
536
+ per_data_processor = create_data_processor(
537
+ data_type=data_type,
538
+ model=per_model,
539
+ config=config,
540
+ advanced_hyperparameters=advanced_hyperparameters,
541
+ )
542
+ data_processors[data_type].append(per_data_processor)
414
543
 
415
- if realign:
416
- return realign
417
- else:
418
- return words_with_offsets
544
+ # Only keep the modalities with non-empty processors.
545
+ data_processors = {k: v for k, v in data_processors.items() if len(v) > 0}
419
546
 
547
+ if TEXT_NER in data_processors and LABEL in data_processors:
548
+ # LabelProcessor is not needed for NER tasks as annotations are handled in NerProcessor.
549
+ data_processors.pop(LABEL)
550
+ return data_processors
420
551
 
421
- def is_rois_input(sample):
552
+
553
+ def turn_on_off_feature_column_info(
554
+ data_processors: Dict,
555
+ flag: bool,
556
+ ):
422
557
  """
423
- check if a sample is rois for object detection
558
+ Turn on or off returning feature column information in data processors.
559
+ Since feature column information is not always required in training models,
560
+ we optionally turn this flag on or off.
424
561
 
425
562
  Parameters
426
563
  ----------
427
- sample
428
- The sampled data.
564
+ data_processors
565
+ The data processors.
566
+ flag
567
+ True/False
568
+ """
569
+ for per_modality_processors in data_processors.values():
570
+ for per_model_processor in per_modality_processors:
571
+ # label processor doesn't have requires_column_info.
572
+ if hasattr(per_model_processor, "requires_column_info"):
573
+ per_model_processor.requires_column_info = flag
574
+
575
+
576
def get_mixup(
    model_config: DictConfig,
    mixup_config: DictConfig,
    num_classes: int,
):
    """
    Decide whether mixup/cutmix is active and, if so, build the mixup function.

    Mixup is only enabled when all of the following hold:
    - at least one model listed in ``model_config.names`` has IMAGE in its ``data_types``;
    - ``mixup_config`` is given, ``turn_on`` is set, and ``mixup_alpha > 0``,
      ``cutmix_alpha > 0`` or ``cutmix_minmax`` is not None;
    - ``num_classes`` is not None and greater than 1 (regression is not supported).

    Parameters
    ----------
    model_config
        The model configs, used to find an image model that makes mixup applicable.
        ``model_config.names`` may be a list of model names or a single name string.
    mixup_config
        The mixup configs for mixup and cutmix. May be None.
    num_classes
        The number of classes in the task. None or values <= 1 disable mixup.

    Returns
    -------
    A tuple ``(mixup_state, mixup_fn)``: ``mixup_state`` is a bool telling whether mixup
    is on; ``mixup_fn`` is a ``MixupModule`` when it is on, otherwise None.
    """
    names = model_config.names
    if isinstance(names, str):
        names = [names]

    model_active = False
    for model_name in names:
        permodel_config = getattr(model_config, model_name)
        # `data_types` is a collection of modality names, so check membership.
        # The previous `hasattr(data_types, IMAGE)` looked for an *attribute* named
        # after the modality and was therefore always False, disabling mixup entirely.
        per_model_dtypes = getattr(permodel_config, "data_types", None) or []
        if IMAGE in per_model_dtypes:
            model_active = True
            break

    mixup_active = False
    if mixup_config is not None and mixup_config.turn_on:
        mixup_active = (
            mixup_config.mixup_alpha > 0 or mixup_config.cutmix_alpha > 0.0 or mixup_config.cutmix_minmax is not None
        )

    mixup_state = model_active and mixup_active and num_classes is not None and num_classes > 1

    mixup_fn = None
    if mixup_state:
        mixup_fn = MixupModule(
            mixup_alpha=mixup_config.mixup_alpha,
            cutmix_alpha=mixup_config.cutmix_alpha,
            cutmix_minmax=mixup_config.cutmix_minmax,
            prob=mixup_config.prob,
            switch_prob=mixup_config.switch_prob,
            mode=mixup_config.mode,
            label_smoothing=mixup_config.label_smoothing,
            num_classes=num_classes,
        )
    return mixup_state, mixup_fn
629
+
630
+
631
def data_to_df(
    data: Union[pd.DataFrame, Dict, List],
    required_columns: Optional[List] = None,
    all_columns: Optional[List] = None,
    header: Optional[str] = None,
):
    """
    Convert the input data to a dataframe.

    Parameters
    ----------
    data
        Input data provided by users during prediction/evaluation. Accepted forms:
        a pd.DataFrame (used as is), a dict (passed to ``pd.DataFrame``), a non-empty
        list, or a str that is either a single image path or something ``load_pd.load``
        can read (presumably a path to a tabular file).
    required_columns
        Required columns.
    all_columns
        All the possible columns got from training data. The column order is preserved.
    header
        Provided header to create a dataframe when ``data`` is a list.

    Returns
    -------
    A dataframe with required columns. If required columns are missing but the number
    of detected columns equals ``len(all_columns)``, the columns are renamed positionally
    to ``all_columns``. A caller-supplied DataFrame is never modified in place.

    Raises
    ------
    NotImplementedError
        If ``data`` has an unsupported type.
    ValueError
        If required columns are missing and the column count differs from ``all_columns``.
    """
    has_header = True
    if isinstance(data, pd.DataFrame):
        pass
    elif isinstance(data, dict):
        data = pd.DataFrame(data)
    elif isinstance(data, list):
        assert len(data) > 0, f"Expected data to have length > 0, but got {data} of len {len(data)}"
        if header is None:
            has_header = False
            data = pd.DataFrame(data)
        else:
            data = pd.DataFrame({header: data})
    elif isinstance(data, str):
        df = pd.DataFrame([data])
        col_name = list(df.columns)[0]
        if is_image_column(df[col_name], col_name=col_name, image_type=IMAGE_PATH):
            # A single image path: a one-row dataframe without a meaningful header.
            has_header = False
            data = df
        else:
            data = load_pd.load(data)
    else:
        raise NotImplementedError(
            f"The format of data is not understood. "
            f'We have type(data)="{type(data)}", but a pd.DataFrame was required.'
        )

    if required_columns and all_columns:
        detected_columns = data.columns.values.tolist()
        missing_columns = [per_col for per_col in required_columns if per_col not in detected_columns]

        if missing_columns:
            # Assume no column names are provided and users organize data in the same column order as the training data.
            if len(detected_columns) == len(all_columns):
                if has_header:
                    warnings.warn(
                        f"Replacing detected dataframe columns `{detected_columns}` with columns "
                        f"`{all_columns}` from training data. "
                        "Double check the correspondences between them to avoid unexpected behaviors.",
                        UserWarning,
                    )
                # Map positionally onto `all_columns`, whose length was just checked and which the
                # warning names. Zipping with `required_columns` mislabeled columns whenever it was a
                # shorter or reordered subset. Rename out of place so a caller's DataFrame is not mutated.
                data = data.rename(columns=dict(zip(detected_columns, all_columns)))
            else:
                raise ValueError(
                    f"Dataframe columns `{detected_columns}` are detected, but columns `{missing_columns}` are missing. "
                    f"Please double check your input data to provide all the "
                    f"required columns `{required_columns}`."
                )

    return data
477
707
 
478
708
 
479
- def get_image_transform_funcs(transform_types: Union[List[str], ListConfig, List[Callable]], size: int):
709
def infer_scarcity_mode_by_data_size(df_train: pd.DataFrame, scarcity_threshold: int = 50):
    """
    Choose a data-scarcity mode from the number of training samples.

    Parameters
    ----------
    df_train
        Training dataframe.
    scarcity_threshold
        Training sets with strictly fewer rows than this are treated as few-shot.

    Returns
    -------
    FEW_SHOT when ``len(df_train) < scarcity_threshold``, otherwise DEFAULT_SHOT.
    """
    is_scarce = len(df_train) < scarcity_threshold
    return FEW_SHOT if is_scarce else DEFAULT_SHOT
495
729
 
496
- if not transform_types:
497
- return image_transforms
498
730
 
499
- if isinstance(transform_types, ListConfig):
500
- transform_types = list(transform_types)
501
- elif not isinstance(transform_types, list):
502
- transform_types = [transform_types]
731
def infer_dtypes_by_model_names(model_config: DictConfig):
    """
    Get data types according to model types.

    Parameters
    ----------
    model_config
        Model config from `config.model`. ``model_config.names`` may be a list of
        model names or a single model name string.

    Returns
    -------
    A tuple ``(allowable_dtypes, fallback_dtype)``:
    - ``allowable_dtypes``: the set union of every listed model's ``data_types``
      (models without ``data_types`` contribute nothing);
    - ``fallback_dtype``: TEXT when the allowed types are exactly {IMAGE, TEXT}, the
      single allowed type when there is only one, otherwise None.
    """
    names = model_config.names
    # Accept a single model name, consistent with `get_mixup`; iterating a bare
    # string would otherwise look up one "model" per character.
    if isinstance(names, str):
        names = [names]

    allowable_dtypes = set()
    for per_model in names:
        per_model_dtypes = OmegaConf.select(model_config, f"{per_model}.data_types")
        if per_model_dtypes:
            allowable_dtypes.update(per_model_dtypes)

    fallback_dtype = None
    if allowable_dtypes == {IMAGE, TEXT}:
        fallback_dtype = TEXT
    elif len(allowable_dtypes) == 1:
        (fallback_dtype,) = allowable_dtypes

    return allowable_dtypes, fallback_dtype
565
758
 
566
- def construct_image_processor(
567
- image_transforms: Union[List[Callable], List[str]],
568
- size: int,
569
- normalization,
570
- ) -> transforms.Compose:
759
+
760
def split_train_tuning_data(
    data: pd.DataFrame,
    holdout_frac: Optional[float] = None,
    problem_type: Optional[str] = None,
    label_column: Optional[str] = None,
    random_state: int = 0,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Splits `data` into `train_data` and `tuning_data`.
    If the problem_type is one of ['binary', 'multiclass']:
        The split will be done with stratification on the label column.
        Will guarantee at least 1 sample of every class in `data` will be present in `train_data`.
            If only 1 sample of a class exists, it will always be put in `train_data` and not `tuning_data`.

    Parameters
    ----------
    data : pd.DataFrame
        The data to be split
    holdout_frac : float, default = None
        The ratio of data to use as validation.
        If 0.2, 20% of the data will be used for validation, and 80% for training.
        If None, the ratio is determined by `default_holdout_frac` from the row count
        of `data`, ranging from 0.2 for small row count to 0.01 for large row count.
    problem_type : str, default = None
        Used only to decide on stratification: BINARY and MULTICLASS split with
        stratification on `label_column`; any other value (including None) splits
        without stratification.
    label_column : str, default = None
        Name of the label column, passed to the splitter as `label`.
    random_state : int, default = 0
        The random state to use when splitting the data, to make the splitting process deterministic.
        If None, a random value is used.

    Returns
    -------
    Tuple of (train_data, tuning_data) of the split `data`
    """
    if holdout_frac is None:
        holdout_frac = default_holdout_frac(num_train_rows=len(data), hyperparameter_tune=False)

    # TODO: Hack since the recognized problem types are only binary, multiclass, and regression
    #  Problem types used for purpose of stratification, so regression = no stratification
    problem_type_for_split = problem_type if problem_type in [BINARY, MULTICLASS] else REGRESSION

    train_data, tuning_data = generate_train_test_split_combined(
        data=data,
        label=label_column,
        test_size=holdout_frac,
        problem_type=problem_type_for_split,
        random_state=random_state,
    )
    return train_data, tuning_data
593
809
 
594
810
 
595
- def image_mean_std(norm_type: str):
811
def get_detected_data_types(column_types: Dict):
    """
    Extract data types from column types.

    Each column type is classified as the first data type whose name it starts with,
    tested in the order IMAGE, TEXT_NER, TEXT, DOCUMENT, NUMERICAL, CATEGORICAL, ROIS.
    Column types matching none of them are ignored.

    Parameters
    ----------
    column_types
        A dataframe's column types (a dict whose values are column-type strings).

    Returns
    -------
    A list of detected data types, without duplicates, in order of first appearance.
    """
    # Order matters: TEXT_NER is tested before TEXT, presumably because a NER column
    # type also starts with the TEXT prefix (e.g. "text_ner" vs "text") — keep it first.
    candidate_types = (IMAGE, TEXT_NER, TEXT, DOCUMENT, NUMERICAL, CATEGORICAL, ROIS)

    data_types = []
    for col_type in column_types.values():
        # Classify first, then deduplicate. Folding the "not yet seen" check into each
        # prefix test let an already-seen type fall through to later prefixes, so a
        # second NER column was wrongly reported as TEXT.
        matched = next((dtype for dtype in candidate_types if col_type.startswith(dtype)), None)
        if matched is not None and matched not in data_types:
            data_types.append(matched)

    return data_types