autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
@@ -1,79 +1,53 @@
|
|
1
|
-
import
|
2
|
-
import codecs
|
3
|
-
import copy
|
4
|
-
import re
|
1
|
+
import logging
|
5
2
|
import warnings
|
6
|
-
from io import BytesIO
|
7
3
|
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
|
8
4
|
|
9
|
-
import numpy as np
|
10
5
|
import pandas as pd
|
11
|
-
import
|
12
|
-
from
|
13
|
-
|
14
|
-
from
|
15
|
-
|
16
|
-
IMAGENET_DEFAULT_STD,
|
17
|
-
IMAGENET_INCEPTION_MEAN,
|
18
|
-
IMAGENET_INCEPTION_STD,
|
19
|
-
)
|
20
|
-
from tokenizers import pre_tokenizers
|
21
|
-
from torchvision import transforms
|
6
|
+
from omegaconf import DictConfig, OmegaConf
|
7
|
+
from torch import nn
|
8
|
+
|
9
|
+
from autogluon.core.utils import default_holdout_frac, generate_train_test_split_combined
|
10
|
+
from autogluon.core.utils.loaders import load_pd
|
22
11
|
|
23
12
|
from ..constants import (
|
24
|
-
|
25
|
-
|
13
|
+
BINARY,
|
14
|
+
CATEGORICAL,
|
15
|
+
DEFAULT_SHOT,
|
16
|
+
DOCUMENT,
|
17
|
+
FEW_SHOT,
|
26
18
|
IDENTIFIER,
|
27
19
|
IMAGE,
|
28
|
-
IMAGE_BYTEARRAY,
|
29
20
|
IMAGE_PATH,
|
21
|
+
LABEL,
|
30
22
|
MMDET_IMAGE,
|
31
23
|
MMLAB_MODELS,
|
24
|
+
MULTICLASS,
|
25
|
+
NER_ANNOTATION,
|
26
|
+
NER_TEXT,
|
27
|
+
NUMERICAL,
|
28
|
+
REGRESSION,
|
29
|
+
ROIS,
|
30
|
+
SAM,
|
31
|
+
SEMANTIC_SEGMENTATION_IMG,
|
32
|
+
TEXT,
|
33
|
+
TEXT_NER,
|
32
34
|
)
|
33
35
|
from .collator import DictCollator
|
36
|
+
from .infer_types import is_image_column
|
37
|
+
from .label_encoder import NerLabelEncoder
|
38
|
+
from .mixup import MixupModule
|
34
39
|
from .preprocess_dataframe import MultiModalFeaturePreprocessor
|
40
|
+
from .process_categorical import CategoricalProcessor
|
41
|
+
from .process_document import DocumentProcessor
|
42
|
+
from .process_image import ImageProcessor
|
43
|
+
from .process_label import LabelProcessor
|
44
|
+
from .process_mmlab import MMDetProcessor
|
45
|
+
from .process_ner import NerProcessor
|
46
|
+
from .process_numerical import NumericalProcessor
|
47
|
+
from .process_semantic_seg_img import SemanticSegImageProcessor
|
48
|
+
from .process_text import TextProcessor
|
35
49
|
|
36
|
-
|
37
|
-
from torchvision.transforms import InterpolationMode
|
38
|
-
|
39
|
-
BICUBIC = InterpolationMode.BICUBIC
|
40
|
-
NEAREST = InterpolationMode.NEAREST
|
41
|
-
except ImportError:
|
42
|
-
BICUBIC = PIL.Image.BICUBIC
|
43
|
-
NEAREST = PIL.Image.NEAREST
|
44
|
-
|
45
|
-
from .randaug import RandAugment
|
46
|
-
from .trivial_augmenter import TrivialAugment
|
47
|
-
|
48
|
-
|
49
|
-
def extract_value_from_config(
|
50
|
-
config: Dict,
|
51
|
-
keys: Tuple[str, ...],
|
52
|
-
):
|
53
|
-
"""
|
54
|
-
Traverse a config dictionary to get some hyper-parameter's value.
|
55
|
-
|
56
|
-
Parameters
|
57
|
-
----------
|
58
|
-
config
|
59
|
-
A config dictionary.
|
60
|
-
keys
|
61
|
-
The possible names of a hyper-parameter.
|
62
|
-
|
63
|
-
Returns
|
64
|
-
-------
|
65
|
-
The hyper-parameter value.
|
66
|
-
"""
|
67
|
-
result = []
|
68
|
-
for k, v in config.items():
|
69
|
-
if k in keys:
|
70
|
-
result.append(v)
|
71
|
-
elif isinstance(v, dict):
|
72
|
-
result += extract_value_from_config(v, keys)
|
73
|
-
else:
|
74
|
-
pass
|
75
|
-
|
76
|
-
return result
|
50
|
+
logger = logging.getLogger(__name__)
|
77
51
|
|
78
52
|
|
79
53
|
def get_collate_fn(
|
@@ -165,7 +139,7 @@ def apply_df_preprocessor(
|
|
165
139
|
def apply_data_processor(
|
166
140
|
per_sample_features: Dict,
|
167
141
|
data_processors: Dict,
|
168
|
-
|
142
|
+
data_types: Dict,
|
169
143
|
is_training: bool,
|
170
144
|
load_only=False,
|
171
145
|
):
|
@@ -175,9 +149,11 @@ def apply_data_processor(
|
|
175
149
|
Parameters
|
176
150
|
----------
|
177
151
|
per_sample_features
|
178
|
-
|
152
|
+
Features of one sample.
|
179
153
|
data_processors
|
180
154
|
A dict of data processors.
|
155
|
+
data_types
|
156
|
+
Data types of all columns.
|
181
157
|
is_training
|
182
158
|
Whether is training.
|
183
159
|
load_only
|
@@ -194,14 +170,14 @@ def apply_data_processor(
|
|
194
170
|
sample_features.update(
|
195
171
|
per_model_processor(
|
196
172
|
per_sample_features[per_modality],
|
197
|
-
|
173
|
+
data_types[per_modality],
|
198
174
|
is_training=is_training,
|
199
175
|
load_only=load_only,
|
200
176
|
)
|
201
177
|
if per_model_processor.prefix.lower().startswith(MMDET_IMAGE)
|
202
178
|
else per_model_processor(
|
203
179
|
per_sample_features[per_modality],
|
204
|
-
|
180
|
+
data_types[per_modality],
|
205
181
|
is_training=is_training,
|
206
182
|
)
|
207
183
|
)
|
@@ -250,366 +226,616 @@ def get_per_sample_features(
|
|
250
226
|
return ret
|
251
227
|
|
252
228
|
|
253
|
-
def
|
254
|
-
"""
|
229
|
+
def default_holdout_frac(num_train_rows, hyperparameter_tune=False):
|
230
|
+
"""Returns default holdout_frac used in fit().
|
231
|
+
Between row count 5,000 and 25,000 keep 0.1 holdout_frac, as we want to grow validation set to a stable 2500 examples.
|
232
|
+
"""
|
233
|
+
if num_train_rows < 5000:
|
234
|
+
holdout_frac = max(0.1, min(0.2, 500.0 / num_train_rows))
|
235
|
+
else:
|
236
|
+
holdout_frac = max(0.01, min(0.1, 2500.0 / num_train_rows))
|
255
237
|
|
256
|
-
|
257
|
-
|
238
|
+
if hyperparameter_tune:
|
239
|
+
holdout_frac = min(
|
240
|
+
0.2, holdout_frac * 2
|
241
|
+
) # We want to allocate more validation data for HPO to avoid overfitting
|
258
242
|
|
259
|
-
|
260
|
-
return error.object[error.start : error.end].decode("cp1252"), error.end
|
243
|
+
return holdout_frac
|
261
244
|
|
262
|
-
codecs.register_error("replace_encoding_with_utf8", replace_encoding_with_utf8)
|
263
|
-
codecs.register_error("replace_decoding_with_cp1252", replace_decoding_with_cp1252)
|
264
245
|
|
246
|
+
def init_df_preprocessor(
|
247
|
+
config: DictConfig,
|
248
|
+
column_types: Dict,
|
249
|
+
label_column: Optional[str] = None,
|
250
|
+
train_df_x: Optional[pd.DataFrame] = None,
|
251
|
+
train_df_y: Optional[pd.Series] = None,
|
252
|
+
):
|
253
|
+
"""
|
254
|
+
Initialize the dataframe preprocessor by calling .fit().
|
265
255
|
|
266
|
-
|
267
|
-
|
256
|
+
Parameters
|
257
|
+
----------
|
258
|
+
config
|
259
|
+
A DictConfig containing only the data config.
|
260
|
+
column_types
|
261
|
+
A dictionary that maps column names to their data types.
|
262
|
+
For example: `column_types = {"item_name": "text", "image": "image_path",
|
263
|
+
"product_description": "text", "height": "numerical"}`
|
264
|
+
may be used for a table with columns: "item_name", "brand", "product_description", and "height".
|
265
|
+
label_column
|
266
|
+
Name of the column that contains the target variable to predict.
|
267
|
+
train_df_x
|
268
|
+
A pd.DataFrame containing only the feature columns.
|
269
|
+
train_df_y
|
270
|
+
A pd.Series object containing only the label column.
|
268
271
|
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
272
|
+
Returns
|
273
|
+
-------
|
274
|
+
Initialized dataframe preprocessor.
|
275
|
+
"""
|
276
|
+
if label_column in column_types and column_types[label_column] == NER_ANNOTATION:
|
277
|
+
label_generator = NerLabelEncoder(config)
|
278
|
+
else:
|
279
|
+
label_generator = None
|
280
|
+
|
281
|
+
df_preprocessor = MultiModalFeaturePreprocessor(
|
282
|
+
config=config.data,
|
283
|
+
column_types=column_types,
|
284
|
+
label_column=label_column,
|
285
|
+
label_generator=label_generator,
|
286
|
+
)
|
287
|
+
df_preprocessor.fit(
|
288
|
+
X=train_df_x,
|
289
|
+
y=train_df_y,
|
274
290
|
)
|
275
|
-
|
276
|
-
return
|
291
|
+
|
292
|
+
return df_preprocessor
|
277
293
|
|
278
294
|
|
279
|
-
def
|
295
|
+
def get_image_transforms(model_config: DictConfig, model_name: str, advanced_hyperparameters: Dict):
|
280
296
|
"""
|
281
|
-
|
297
|
+
Get the image transforms of one image-related model.
|
298
|
+
Use the transforms in advanced_hyperparameters with higher priority.
|
282
299
|
|
283
300
|
Parameters
|
284
301
|
----------
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
The
|
291
|
-
tokenizer
|
292
|
-
The tokenizer to be used.
|
293
|
-
is_eval
|
294
|
-
Whether it is for evaluation or not, default: False
|
302
|
+
model_config
|
303
|
+
Config of one model.
|
304
|
+
model_name
|
305
|
+
Name of one model.
|
306
|
+
advanced_hyperparameters
|
307
|
+
The advanced hyperparameters whose values are complex objects.
|
295
308
|
|
296
309
|
Returns
|
297
310
|
-------
|
298
|
-
|
311
|
+
The image transforms used in training and validation.
|
299
312
|
"""
|
300
|
-
|
301
|
-
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
for idx, word_offset in enumerate(word_offsets[:num_words, :]):
|
311
|
-
# support multiple words in an annotated offset range.
|
312
|
-
# Allow partial overlapping between custom annotations and pretokenized words.
|
313
|
-
if (word_offset[0] < custom_offset[1]) and (custom_offset[0] < word_offset[1]):
|
314
|
-
if not (
|
315
|
-
re.match(b_prefix, custom_label, re.IGNORECASE) or re.match(i_prefix, custom_label, re.IGNORECASE)
|
316
|
-
):
|
317
|
-
if is_start_word and b_prefix + custom_label in entity_map:
|
318
|
-
word_label[idx] = entity_map[b_prefix + custom_label]
|
319
|
-
is_start_word = False
|
320
|
-
elif i_prefix + custom_label in entity_map:
|
321
|
-
word_label[idx] = entity_map[i_prefix + custom_label]
|
322
|
-
else:
|
323
|
-
if custom_label in entity_map:
|
324
|
-
word_label[idx] = entity_map[custom_label]
|
325
|
-
|
326
|
-
token_label = [0] * len(col_tokens.input_ids)
|
327
|
-
temp = set()
|
328
|
-
counter = 0
|
329
|
-
for idx, token_to_word in enumerate(token_to_word_mappings):
|
330
|
-
if token_to_word != -1 and token_to_word not in temp:
|
331
|
-
temp.add(token_to_word)
|
332
|
-
token_label[idx] = word_label[counter]
|
333
|
-
counter += 1
|
334
|
-
if not is_eval:
|
335
|
-
label = token_label # return token-level labels for training
|
313
|
+
train_transform_key = f"model.{model_name}.train_transforms"
|
314
|
+
val_transform_key = f"model.{model_name}.val_transforms"
|
315
|
+
if advanced_hyperparameters and train_transform_key in advanced_hyperparameters:
|
316
|
+
train_transforms = advanced_hyperparameters[train_transform_key]
|
317
|
+
else:
|
318
|
+
train_transforms = model_config.train_transforms
|
319
|
+
train_transforms = list(train_transforms)
|
320
|
+
|
321
|
+
if advanced_hyperparameters and val_transform_key in advanced_hyperparameters:
|
322
|
+
val_transforms = advanced_hyperparameters[val_transform_key]
|
336
323
|
else:
|
337
|
-
|
324
|
+
val_transforms = model_config.val_transforms
|
325
|
+
val_transforms = list(val_transforms)
|
338
326
|
|
339
|
-
return
|
327
|
+
return train_transforms, val_transforms
|
340
328
|
|
341
329
|
|
342
|
-
def
|
330
|
+
def create_data_processor(
|
331
|
+
data_type: str,
|
332
|
+
config: DictConfig,
|
333
|
+
model: nn.Module,
|
334
|
+
advanced_hyperparameters: Optional[Dict] = None,
|
335
|
+
):
|
343
336
|
"""
|
344
|
-
|
345
|
-
and the input text tokenization.
|
337
|
+
Create one data processor based on the data type and model.
|
346
338
|
|
347
339
|
Parameters
|
348
340
|
----------
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
The
|
341
|
+
data_type
|
342
|
+
Data type.
|
343
|
+
config
|
344
|
+
The config may contain information required by creating a data processor.
|
345
|
+
In future, we may move the required config information into the model.config
|
346
|
+
to make the data processor conditioned only on the model itself.
|
347
|
+
model
|
348
|
+
The model.
|
353
349
|
|
354
350
|
Returns
|
355
351
|
-------
|
356
|
-
|
352
|
+
One data processor.
|
357
353
|
"""
|
358
|
-
|
359
|
-
|
360
|
-
|
361
|
-
|
362
|
-
|
363
|
-
|
364
|
-
|
365
|
-
|
366
|
-
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
374
|
-
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
|
379
|
-
|
354
|
+
model_config = getattr(config.model, model.prefix)
|
355
|
+
if data_type == IMAGE:
|
356
|
+
train_transforms, val_transforms = get_image_transforms(
|
357
|
+
model_config=model_config,
|
358
|
+
model_name=model.prefix,
|
359
|
+
advanced_hyperparameters=advanced_hyperparameters,
|
360
|
+
)
|
361
|
+
data_processor = ImageProcessor(
|
362
|
+
model=model,
|
363
|
+
train_transforms=train_transforms,
|
364
|
+
val_transforms=val_transforms,
|
365
|
+
max_image_num_per_column=model_config.max_image_num_per_column,
|
366
|
+
missing_value_strategy=config.data.image.missing_value_strategy,
|
367
|
+
dropout=config.data.modality_dropout,
|
368
|
+
)
|
369
|
+
elif data_type == TEXT:
|
370
|
+
data_processor = TextProcessor(
|
371
|
+
model=model,
|
372
|
+
insert_sep=model_config.insert_sep,
|
373
|
+
stochastic_chunk=model_config.stochastic_chunk,
|
374
|
+
text_detection_length=model_config.text_aug_detect_length,
|
375
|
+
text_trivial_aug_maxscale=model_config.text_trivial_aug_maxscale,
|
376
|
+
train_augment_types=model_config.text_train_augment_types,
|
377
|
+
normalize_text=config.data.text.normalize_text,
|
378
|
+
template_config=config.data.templates,
|
379
|
+
dropout=config.data.modality_dropout,
|
380
|
+
)
|
381
|
+
elif data_type == CATEGORICAL:
|
382
|
+
data_processor = CategoricalProcessor(
|
383
|
+
model=model,
|
384
|
+
dropout=config.data.modality_dropout,
|
385
|
+
)
|
386
|
+
elif data_type == NUMERICAL:
|
387
|
+
data_processor = NumericalProcessor(
|
388
|
+
model=model,
|
389
|
+
merge=model_config.merge,
|
390
|
+
dropout=config.data.modality_dropout,
|
391
|
+
)
|
392
|
+
elif data_type == LABEL:
|
393
|
+
data_processor = LabelProcessor(model=model)
|
394
|
+
elif data_type == TEXT_NER:
|
395
|
+
data_processor = NerProcessor(
|
396
|
+
model=model,
|
397
|
+
max_len=model_config.max_text_len,
|
398
|
+
entity_map=config.entity_map,
|
399
|
+
)
|
400
|
+
elif data_type == ROIS:
|
401
|
+
data_processor = MMDetProcessor(
|
402
|
+
model=model,
|
403
|
+
max_img_num_per_col=model_config.max_img_num_per_col,
|
404
|
+
missing_value_strategy=config.data.image.missing_value_strategy,
|
405
|
+
)
|
406
|
+
elif data_type == DOCUMENT:
|
407
|
+
train_transforms, val_transforms = get_image_transforms(
|
408
|
+
model_config=model_config,
|
409
|
+
model_name=model.prefix,
|
410
|
+
advanced_hyperparameters=advanced_hyperparameters,
|
411
|
+
)
|
412
|
+
data_processor = DocumentProcessor(
|
413
|
+
model=model,
|
414
|
+
train_transforms=train_transforms,
|
415
|
+
val_transforms=val_transforms,
|
416
|
+
size=model_config.image_size,
|
417
|
+
text_max_len=model_config.max_text_len,
|
418
|
+
missing_value_strategy=config.data.document.missing_value_strategy,
|
419
|
+
)
|
420
|
+
elif data_type == SEMANTIC_SEGMENTATION_IMG:
|
421
|
+
data_processor = SemanticSegImageProcessor(
|
422
|
+
model=model,
|
423
|
+
img_transforms=model_config.img_transforms,
|
424
|
+
gt_transforms=model_config.gt_transforms,
|
425
|
+
train_transforms=model_config.train_transforms,
|
426
|
+
val_transforms=model_config.val_transforms,
|
427
|
+
ignore_label=model_config.ignore_label,
|
428
|
+
)
|
380
429
|
else:
|
381
|
-
|
382
|
-
word_offsets = np.append(offset_mapping[1:], [[0, 0]], axis=0)
|
383
|
-
word_idx = np.arange(len(col_tokens.word_ids()) - col_tokens.word_ids().count(None))
|
384
|
-
token_to_word_mappings = [
|
385
|
-
val + word_idx[idx - 1] if val != None else -1 for idx, val in enumerate(col_tokens.word_ids())
|
386
|
-
]
|
430
|
+
raise ValueError(f"unknown data type: {data_type}")
|
387
431
|
|
388
|
-
return
|
432
|
+
return data_processor
|
389
433
|
|
390
434
|
|
391
|
-
def
|
435
|
+
def create_fusion_data_processors(
|
436
|
+
config: DictConfig,
|
437
|
+
model: nn.Module,
|
438
|
+
requires_label: Optional[bool] = True,
|
439
|
+
requires_data: Optional[bool] = True,
|
440
|
+
advanced_hyperparameters: Optional[Dict] = None,
|
441
|
+
):
|
392
442
|
"""
|
393
|
-
|
394
|
-
|
395
|
-
|
396
|
-
|
397
|
-
|
443
|
+
Create the data processors for late-fusion models. This function creates one processor for
|
444
|
+
each modality of each model. For example, if one model config contains BERT, ViT, and CLIP, then
|
445
|
+
BERT would have its own text processor, ViT would have its own image processor, and CLIP would have
|
446
|
+
its own text and image processors. This is to support training arbitrary combinations of single-modal
|
447
|
+
and multimodal models since two models may share the same modality but have different processing. Text
|
448
|
+
sequence length is a good example. BERT's sequence length is generally 512, while CLIP uses sequences of
|
449
|
+
length 77.
|
450
|
+
|
451
|
+
Parameters
|
452
|
+
----------
|
453
|
+
config
|
454
|
+
A DictConfig object. The model config should be accessible by "config.model".
|
455
|
+
model
|
456
|
+
The model object.
|
398
457
|
|
399
|
-
|
458
|
+
Returns
|
459
|
+
-------
|
460
|
+
A dictionary with modalities as the keys. Each modality has a list of processors.
|
461
|
+
Note that "label" is also treated as a modality for convenience.
|
400
462
|
"""
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
463
|
+
data_processors = {
|
464
|
+
IMAGE: [],
|
465
|
+
TEXT: [],
|
466
|
+
CATEGORICAL: [],
|
467
|
+
NUMERICAL: [],
|
468
|
+
LABEL: [],
|
469
|
+
ROIS: [],
|
470
|
+
TEXT_NER: [],
|
471
|
+
DOCUMENT: [],
|
472
|
+
SEMANTIC_SEGMENTATION_IMG: [],
|
473
|
+
}
|
474
|
+
|
475
|
+
model_dict = {model.prefix: model}
|
476
|
+
|
477
|
+
if model.prefix.lower().startswith("fusion"):
|
478
|
+
for per_model in model.model:
|
479
|
+
model_dict[per_model.prefix] = per_model
|
480
|
+
|
481
|
+
assert sorted(list(model_dict.keys())) == sorted(config.model.names)
|
482
|
+
|
483
|
+
for per_name, per_model in model_dict.items():
|
484
|
+
model_config = getattr(config.model, per_model.prefix)
|
485
|
+
if model_config.data_types is not None:
|
486
|
+
data_types = model_config.data_types.copy()
|
487
|
+
else:
|
488
|
+
data_types = None
|
489
|
+
|
490
|
+
if per_name == NER_TEXT:
|
491
|
+
# create a multimodal processor for NER.
|
492
|
+
data_processors[TEXT_NER].append(
|
493
|
+
create_data_processor(
|
494
|
+
data_type=TEXT_NER,
|
495
|
+
config=config,
|
496
|
+
model=per_model,
|
497
|
+
)
|
498
|
+
)
|
499
|
+
requires_label = False
|
500
|
+
if data_types is not None and TEXT_NER in data_types:
|
501
|
+
data_types.remove(TEXT_NER)
|
502
|
+
elif per_name.lower().startswith(MMLAB_MODELS):
|
503
|
+
# create a multimodal processor for NER.
|
504
|
+
data_processors[ROIS].append(
|
505
|
+
create_data_processor(
|
506
|
+
data_type=ROIS,
|
507
|
+
config=config,
|
508
|
+
model=per_model,
|
509
|
+
)
|
510
|
+
)
|
511
|
+
if data_types is not None and IMAGE in data_types:
|
512
|
+
data_types.remove(IMAGE)
|
513
|
+
elif per_name == SAM:
|
514
|
+
data_processors[SEMANTIC_SEGMENTATION_IMG].append(
|
515
|
+
create_data_processor(
|
516
|
+
data_type=SEMANTIC_SEGMENTATION_IMG,
|
517
|
+
config=config,
|
518
|
+
model=per_model,
|
519
|
+
)
|
520
|
+
)
|
521
|
+
if data_types is not None and SEMANTIC_SEGMENTATION_IMG in data_types:
|
522
|
+
data_types.remove(SEMANTIC_SEGMENTATION_IMG)
|
523
|
+
requires_label = False
|
524
|
+
|
525
|
+
if requires_label:
|
526
|
+
# each model has its own label processor
|
527
|
+
label_processor = create_data_processor(
|
528
|
+
data_type=LABEL,
|
529
|
+
config=config,
|
530
|
+
model=per_model,
|
531
|
+
)
|
532
|
+
data_processors[LABEL].append(label_processor)
|
533
|
+
|
534
|
+
if requires_data and data_types:
|
535
|
+
for data_type in data_types:
|
536
|
+
per_data_processor = create_data_processor(
|
537
|
+
data_type=data_type,
|
538
|
+
model=per_model,
|
539
|
+
config=config,
|
540
|
+
advanced_hyperparameters=advanced_hyperparameters,
|
541
|
+
)
|
542
|
+
data_processors[data_type].append(per_data_processor)
|
414
543
|
|
415
|
-
|
416
|
-
|
417
|
-
else:
|
418
|
-
return words_with_offsets
|
544
|
+
# Only keep the modalities with non-empty processors.
|
545
|
+
data_processors = {k: v for k, v in data_processors.items() if len(v) > 0}
|
419
546
|
|
547
|
+
if TEXT_NER in data_processors and LABEL in data_processors:
|
548
|
+
# LabelProcessor is not needed for NER tasks as annotations are handled in NerProcessor.
|
549
|
+
data_processors.pop(LABEL)
|
550
|
+
return data_processors
|
420
551
|
|
421
|
-
|
552
|
+
|
553
|
+
def turn_on_off_feature_column_info(
|
554
|
+
data_processors: Dict,
|
555
|
+
flag: bool,
|
556
|
+
):
|
422
557
|
"""
|
423
|
-
|
558
|
+
Turn on or off returning feature column information in data processors.
|
559
|
+
Since feature column information is not always required in training models,
|
560
|
+
we optionally turn this flag on or off.
|
424
561
|
|
425
562
|
Parameters
|
426
563
|
----------
|
427
|
-
|
428
|
-
The
|
564
|
+
data_processors
|
565
|
+
The data processors.
|
566
|
+
flag
|
567
|
+
True/False
|
568
|
+
"""
|
569
|
+
for per_modality_processors in data_processors.values():
|
570
|
+
for per_model_processor in per_modality_processors:
|
571
|
+
# label processor doesn't have requires_column_info.
|
572
|
+
if hasattr(per_model_processor, "requires_column_info"):
|
573
|
+
per_model_processor.requires_column_info = flag
|
574
|
+
|
575
|
+
|
576
|
+
def get_mixup(
|
577
|
+
model_config: DictConfig,
|
578
|
+
mixup_config: DictConfig,
|
579
|
+
num_classes: int,
|
580
|
+
):
|
581
|
+
"""
|
582
|
+
Get the mixup state for loss function choice.
|
583
|
+
Now the mixup can only support image data.
|
584
|
+
And the problem type can not support Regression.
|
585
|
+
Parameters
|
586
|
+
----------
|
587
|
+
model_config
|
588
|
+
The model configs to find image model for the necessity of mixup.
|
589
|
+
mixup_config
|
590
|
+
The mixup configs for mixup and cutmix.
|
591
|
+
num_classes
|
592
|
+
The number of classes in the task. Class <= 1 will cause faults.
|
429
593
|
|
430
594
|
Returns
|
431
595
|
-------
|
432
|
-
|
596
|
+
The mixup is on or off.
|
433
597
|
"""
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
598
|
+
model_active = False
|
599
|
+
names = model_config.names
|
600
|
+
if isinstance(names, str):
|
601
|
+
names = [names]
|
602
|
+
for model_name in names:
|
603
|
+
permodel_config = getattr(model_config, model_name)
|
604
|
+
if hasattr(permodel_config.data_types, IMAGE):
|
605
|
+
model_active = True
|
606
|
+
break
|
607
|
+
|
608
|
+
mixup_active = False
|
609
|
+
if mixup_config is not None and mixup_config.turn_on:
|
610
|
+
mixup_active = (
|
611
|
+
mixup_config.mixup_alpha > 0 or mixup_config.cutmix_alpha > 0.0 or mixup_config.cutmix_minmax is not None
|
612
|
+
)
|
613
|
+
|
614
|
+
mixup_state = model_active & mixup_active & ((num_classes is not None) and (num_classes > 1))
|
615
|
+
mixup_fn = None
|
616
|
+
if mixup_state:
|
617
|
+
mixup_args = dict(
|
618
|
+
mixup_alpha=mixup_config.mixup_alpha,
|
619
|
+
cutmix_alpha=mixup_config.cutmix_alpha,
|
620
|
+
cutmix_minmax=mixup_config.cutmix_minmax,
|
621
|
+
prob=mixup_config.prob,
|
622
|
+
switch_prob=mixup_config.switch_prob,
|
623
|
+
mode=mixup_config.mode,
|
624
|
+
label_smoothing=mixup_config.label_smoothing,
|
625
|
+
num_classes=num_classes,
|
626
|
+
)
|
627
|
+
mixup_fn = MixupModule(**mixup_args)
|
628
|
+
return mixup_state, mixup_fn
|
629
|
+
|
630
|
+
|
631
|
+
def data_to_df(
    data: Union[pd.DataFrame, Dict, List, str],
    required_columns: Optional[List] = None,
    all_columns: Optional[List] = None,
    header: Optional[str] = None,
):
    """
    Convert the input data to a dataframe.

    Parameters
    ----------
    data
        Input data provided by users during prediction/evaluation.
        A dataframe is used as is, a dict or a non-empty list is passed to `pd.DataFrame`,
        and a string is either kept as a single image path or loaded with `load_pd.load`.
    required_columns
        Required columns.
    all_columns
        All the possible columns got from training data. The column order is preserved.
    header
        Provided header to create a dataframe. Only used when `data` is a list.

    Returns
    -------
    A dataframe with required columns. A dataframe passed in as `data` is never modified in place;
    when columns have to be relabeled, a renamed copy is returned.

    Raises
    ------
    NotImplementedError
        If `data` is not a dataframe, dict, list, or string.
    ValueError
        If required columns are missing and the number of detected columns
        differs from the number of `all_columns`.
    """
    has_header = True
    if isinstance(data, pd.DataFrame):
        pass
    elif isinstance(data, dict):
        data = pd.DataFrame(data)
    elif isinstance(data, list):
        assert len(data) > 0, f"Expected data to have length > 0, but got {data} of len {len(data)}"
        if header is None:
            has_header = False
            data = pd.DataFrame(data)
        else:
            data = pd.DataFrame({header: data})
    elif isinstance(data, str):
        df = pd.DataFrame([data])
        col_name = list(df.columns)[0]
        if is_image_column(df[col_name], col_name=col_name, image_type=IMAGE_PATH):
            has_header = False
            data = df
        else:
            data = load_pd.load(data)
    else:
        raise NotImplementedError(
            f"The format of data is not understood. "
            f'We have type(data)="{type(data)}", but a pd.DataFrame was required.'
        )

    if required_columns and all_columns:
        detected_columns = data.columns.values.tolist()
        missing_columns = [per_col for per_col in required_columns if per_col not in detected_columns]

        if missing_columns:
            # assume no column names are provided and users organize data in the same column order of training data.
            if len(detected_columns) == len(all_columns):
                if has_header:
                    warnings.warn(
                        f"Replacing detected dataframe columns `{detected_columns}` with columns "
                        f"`{all_columns}` from training data. "
                        "Double check the correspondences between them to avoid unexpected behaviors.",
                        UserWarning,
                    )
                # Map positionally onto `all_columns`: the length check above is against `all_columns`,
                # so zipping with the (possibly shorter) `required_columns` would mislabel columns.
                # A non-inplace rename keeps the caller's dataframe untouched.
                data = data.rename(columns=dict(zip(detected_columns, all_columns)))
            else:
                raise ValueError(
                    f"Dataframe columns `{detected_columns}` are detected, but columns `{missing_columns}` are missing. "
                    f"Please double check your input data to provide all the "
                    f"required columns `{required_columns}`."
                )

    return data
|
477
707
|
|
478
708
|
|
479
|
-
def
|
709
|
+
def infer_scarcity_mode_by_data_size(df_train: pd.DataFrame, scarcity_threshold: int = 50):
    """
    Pick a data-scarcity mode from the number of training samples.

    Parameters
    ----------
    df_train
        Training dataframe.
    scarcity_threshold
        Row count below which FEW_SHOT is selected.

    Returns
    -------
    FEW_SHOT if the training data has fewer rows than `scarcity_threshold`, otherwise DEFAULT_SHOT.
    """
    is_scarce = len(df_train) < scarcity_threshold
    return FEW_SHOT if is_scarce else DEFAULT_SHOT
|
495
729
|
|
496
|
-
if not transform_types:
|
497
|
-
return image_transforms
|
498
730
|
|
499
|
-
|
500
|
-
|
501
|
-
|
502
|
-
transform_types = [transform_types]
|
731
|
+
def infer_dtypes_by_model_names(model_config: DictConfig):
    """
    Get data types according to model types.

    Parameters
    ----------
    model_config
        Model config from `config.model`. `model_config.names` may be a single
        model name or a sequence of model names.

    Returns
    -------
    The data types allowed by models (as a set) and the default fallback data type.
    The fallback is TEXT when the allowed types are exactly {IMAGE, TEXT}, the only
    allowed type when there is exactly one, and None otherwise.
    """
    names = model_config.names
    # A bare string would otherwise be iterated character by character,
    # silently producing no data types.
    if isinstance(names, str):
        names = [names]

    allowable_dtypes = set()
    for per_model in names:
        per_model_dtypes = OmegaConf.select(model_config, f"{per_model}.data_types")
        if per_model_dtypes:
            allowable_dtypes.update(per_model_dtypes)

    fallback_dtype = None
    if allowable_dtypes == {IMAGE, TEXT}:
        fallback_dtype = TEXT
    elif len(allowable_dtypes) == 1:
        fallback_dtype = next(iter(allowable_dtypes))

    return allowable_dtypes, fallback_dtype
|
565
758
|
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
759
|
+
|
760
|
+
def split_train_tuning_data(
    data: pd.DataFrame,
    holdout_frac: float = None,
    problem_type: str = None,
    label_column: str = None,
    random_state: int = 0,
) -> (pd.DataFrame, pd.DataFrame):
    """
    Split `data` into a training part and a tuning (validation) part.

    The actual split is delegated to `generate_train_test_split_combined`.
    BINARY and MULTICLASS problem types are forwarded to it unchanged; every other
    problem type is split as REGRESSION. Per the original documentation of this
    function, the classification split is stratified on the label column and keeps
    at least one sample of every class in the training part; that behavior is
    implemented by the delegated splitter, not here.

    Parameters
    ----------
    data : pd.DataFrame
        The data to be split.
    holdout_frac : float, default = None
        The ratio of data to use as validation, e.g. 0.2 holds out 20% of the rows.
        If None, the ratio is chosen by `default_holdout_frac` from the row count of `data`.
    problem_type : str, default = None
        Problem type used only to decide how the split is performed.
    label_column : str, default = None
        Name of the label column, passed to the splitter as `label`.
    random_state : int, default = 0
        Random state forwarded to the splitter to make the split deterministic.

    Returns
    -------
    Tuple of (train_data, tuning_data) of the split `data`
    """
    # TODO: Hack since the recognized problem types are only binary, multiclass, and regression
    # Problem types used for purpose of stratification, so regression = no stratification
    is_classification = problem_type in (BINARY, MULTICLASS)
    problem_type_for_split = problem_type if is_classification else REGRESSION

    if holdout_frac is None:
        holdout_frac = default_holdout_frac(num_train_rows=len(data), hyperparameter_tune=False)

    split_kwargs = dict(
        data=data,
        label=label_column,
        test_size=holdout_frac,
        problem_type=problem_type_for_split,
        random_state=random_state,
    )
    train_data, tuning_data = generate_train_test_split_combined(**split_kwargs)
    return train_data, tuning_data
|
593
809
|
|
594
810
|
|
595
|
-
def
|
811
|
+
def get_detected_data_types(column_types: Dict):
    """
    Extract data types from column types.

    Parameters
    ----------
    column_types
        A dataframe's column types. Only the values are inspected; each is
        matched by prefix against the known data types.

    Returns
    -------
    A list of detected data types, without duplicates, in order of first appearance.
    """
    # Order matters: each column type is attributed to the FIRST prefix it matches,
    # so a more specific type must come before any shorter type that prefixes it
    # (TEXT_NER is checked before TEXT).
    known_dtypes = (IMAGE, TEXT_NER, TEXT, DOCUMENT, NUMERICAL, CATEGORICAL, ROIS)

    data_types = []
    for col_type in column_types.values():
        for dtype in known_dtypes:
            if col_type.startswith(dtype):
                # Stop at the first match even if it is already recorded; falling through
                # would let a repeated column type be attributed to a later, shorter prefix.
                if dtype not in data_types:
                    data_types.append(dtype)
                break

    return data_types
|