autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
autogluon/multimodal/utils/data.py (deleted)
@@ -1,701 +0,0 @@
- import copy
- import logging
- import os
- import warnings
- from typing import Dict, List, Optional, Tuple, Union
-
- import pandas as pd
- from omegaconf import DictConfig, OmegaConf
- from torch import nn
-
- from autogluon.core.utils import default_holdout_frac, generate_train_test_split_combined
- from autogluon.core.utils.loaders import load_pd
-
- from ..constants import (
-     BINARY,
-     CATEGORICAL,
-     DEFAULT_SHOT,
-     DOCUMENT,
-     FEW_SHOT,
-     IMAGE,
-     IMAGE_PATH,
-     LABEL,
-     MMLAB_MODELS,
-     MULTICLASS,
-     NER,
-     NER_ANNOTATION,
-     NER_TEXT,
-     NUMERICAL,
-     REGRESSION,
-     ROIS,
-     SAM,
-     SEMANTIC_SEGMENTATION_IMG,
-     TEXT,
-     TEXT_NER,
- )
- from ..data import (
-     CategoricalProcessor,
-     DocumentProcessor,
-     ImageProcessor,
-     LabelProcessor,
-     MixupModule,
-     MMDetProcessor,
-     MMOcrProcessor,
-     MultiModalFeaturePreprocessor,
-     NerLabelEncoder,
-     NerProcessor,
-     NumericalProcessor,
-     SemanticSegImageProcessor,
-     TextProcessor,
- )
- from ..data.infer_types import is_image_column
-
- logger = logging.getLogger(__name__)
-
-
- def init_df_preprocessor(
-     config: DictConfig,
-     column_types: Dict,
-     label_column: Optional[str] = None,
-     train_df_x: Optional[pd.DataFrame] = None,
-     train_df_y: Optional[pd.Series] = None,
- ):
-     """
-     Initialize the dataframe preprocessor by calling .fit().
-
-     Parameters
-     ----------
-     config
-         A DictConfig containing only the data config.
-     column_types
-         A dictionary that maps column names to their data types.
-         For example: `column_types = {"item_name": "text", "image": "image_path",
-         "product_description": "text", "height": "numerical"}`
-         may be used for a table with columns: "item_name", "image", "product_description", and "height".
-     label_column
-         Name of the column that contains the target variable to predict.
-     train_df_x
-         A pd.DataFrame containing only the feature columns.
-     train_df_y
-         A pd.Series object containing only the label column.
-
-     Returns
-     -------
-     Initialized dataframe preprocessor.
-     """
-     if label_column in column_types and column_types[label_column] == NER_ANNOTATION:
-         label_generator = NerLabelEncoder(config)
-     else:
-         label_generator = None
-
-     df_preprocessor = MultiModalFeaturePreprocessor(
-         config=config.data,
-         column_types=column_types,
-         label_column=label_column,
-         label_generator=label_generator,
-     )
-     df_preprocessor.fit(
-         X=train_df_x,
-         y=train_df_y,
-     )
-
-     return df_preprocessor
-
-
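A minimal usage sketch for the function above. The column names are hypothetical, and `config` is assumed to be the predictor's full DictConfig (exposing `config.data`), which this module does not construct:

    import pandas as pd

    # Hypothetical training table: a text feature, a numerical feature, and a label.
    train_df = pd.DataFrame(
        {"item_name": ["shirt", "mug"], "height": [1.2, 3.4], "price": [10.0, 5.0]}
    )
    column_types = {"item_name": "text", "height": "numerical", "price": "numerical"}
    df_preprocessor = init_df_preprocessor(
        config=config,  # assumed: the full DictConfig, not built here
        column_types=column_types,
        label_column="price",
        train_df_x=train_df.drop(columns=["price"]),
        train_df_y=train_df["price"],
    )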
- def create_data_processor(
-     data_type: str,
-     config: DictConfig,
-     model: nn.Module,
-     advanced_hyperparameters: Optional[Dict] = None,
- ):
-     """
-     Create one data processor based on the data type and model.
-
-     Parameters
-     ----------
-     data_type
-         Data type.
-     config
-         The config may contain information required by creating a data processor.
-         In the future, we may move the required config information into model.config
-         to make the data processor conditioned only on the model itself.
-     model
-         The model.
-     advanced_hyperparameters
-         The advanced hyperparameters whose values are complex objects, e.g., image transforms.
-
-     Returns
-     -------
-     One data processor.
-     """
-     model_config = getattr(config.model, model.prefix)
-     if data_type == IMAGE:
-         train_transforms, val_transforms = get_image_transforms(
-             model_config=model_config,
-             model_name=model.prefix,
-             advanced_hyperparameters=advanced_hyperparameters,
-         )
-
-         data_processor = ImageProcessor(
-             model=model,
-             train_transforms=train_transforms,
-             val_transforms=val_transforms,
-             norm_type=model_config.image_norm,
-             size=model_config.image_size,
-             max_img_num_per_col=model_config.max_img_num_per_col,
-             missing_value_strategy=config.data.image.missing_value_strategy,
-         )
-     elif data_type == TEXT:
-         data_processor = TextProcessor(
-             model=model,
-             max_len=model_config.max_text_len,
-             insert_sep=model_config.insert_sep,
-             text_segment_num=model_config.text_segment_num,
-             stochastic_chunk=model_config.stochastic_chunk,
-             text_detection_length=OmegaConf.select(model_config, "text_aug_detect_length"),
-             text_trivial_aug_maxscale=OmegaConf.select(model_config, "text_trivial_aug_maxscale"),
-             train_augment_types=OmegaConf.select(model_config, "text_train_augment_types"),
-             template_config=getattr(config.data, "templates", OmegaConf.create({"turn_on": False})),
-             normalize_text=getattr(config.data.text, "normalize_text", False),
-         )
-     elif data_type == CATEGORICAL:
-         data_processor = CategoricalProcessor(
-             model=model,
-         )
-     elif data_type == NUMERICAL:
-         data_processor = NumericalProcessor(
-             model=model,
-             merge=model_config.merge,
-         )
-     elif data_type == LABEL:
-         data_processor = LabelProcessor(model=model)
-     elif data_type == TEXT_NER:
-         data_processor = NerProcessor(
-             model=model,
-             max_len=model_config.max_text_len,
-             entity_map=config.entity_map,
-         )
-     elif data_type == ROIS:
-         data_processor = MMDetProcessor(
-             model=model,
-             max_img_num_per_col=model_config.max_img_num_per_col,
-             missing_value_strategy=config.data.image.missing_value_strategy,
-         )
-     elif data_type == DOCUMENT:
-         train_transforms, val_transforms = get_image_transforms(
-             model_config=model_config,
-             model_name=model.prefix,
-             advanced_hyperparameters=advanced_hyperparameters,
-         )
-         data_processor = DocumentProcessor(
-             model=model,
-             train_transform_types=train_transforms,
-             val_transform_types=val_transforms,
-             norm_type=model_config.image_norm,
-             size=model_config.image_size,
-             text_max_len=model_config.max_text_len,
-             missing_value_strategy=config.data.document.missing_value_strategy,
-         )
-     elif data_type == SEMANTIC_SEGMENTATION_IMG:
-         data_processor = SemanticSegImageProcessor(
-             model=model,
-             img_transforms=model_config.img_transforms,
-             gt_transforms=model_config.gt_transforms,
-             train_transforms=model_config.train_transforms,
-             val_transforms=model_config.val_transforms,
-             norm_type=model_config.image_norm,
-             ignore_label=model_config.ignore_label,
-         )
-     else:
-         raise ValueError(f"unknown data type: {data_type}")
-
-     return data_processor
-
-
- def create_fusion_data_processors(
-     config: DictConfig,
-     model: nn.Module,
-     requires_label: Optional[bool] = True,
-     requires_data: Optional[bool] = True,
-     advanced_hyperparameters: Optional[Dict] = None,
- ):
-     """
-     Create the data processors for late-fusion models. This function creates one processor for
-     each modality of each model. For example, if one model config contains BERT, ViT, and CLIP, then
-     BERT would have its own text processor, ViT would have its own image processor, and CLIP would have
-     its own text and image processors. This is to support training arbitrary combinations of single-modal
-     and multimodal models since two models may share the same modality but have different processing. Text
-     sequence length is a good example: BERT's sequence length is generally 512, while CLIP uses sequences of
-     length 77.
-
-     Parameters
-     ----------
-     config
-         A DictConfig object. The model config should be accessible by "config.model".
-     model
-         The model object.
-
-     Returns
-     -------
-     A dictionary with modalities as the keys. Each modality has a list of processors.
-     Note that "label" is also treated as a modality for convenience.
-     """
-     data_processors = {
-         IMAGE: [],
-         TEXT: [],
-         CATEGORICAL: [],
-         NUMERICAL: [],
-         LABEL: [],
-         ROIS: [],
-         TEXT_NER: [],
-         DOCUMENT: [],
-         SEMANTIC_SEGMENTATION_IMG: [],
-     }
-
-     model_dict = {model.prefix: model}
-
-     if model.prefix.lower().startswith("fusion"):
-         for per_model in model.model:
-             model_dict[per_model.prefix] = per_model
-
-     assert sorted(list(model_dict.keys())) == sorted(config.model.names)
-
-     for per_name, per_model in model_dict.items():
-         model_config = getattr(config.model, per_model.prefix)
-         if model_config.data_types is not None:
-             data_types = model_config.data_types.copy()
-         else:
-             data_types = None
-
-         if per_name == NER_TEXT:
-             # create a multimodal processor for NER.
-             data_processors[TEXT_NER].append(
-                 create_data_processor(
-                     data_type=TEXT_NER,
-                     config=config,
-                     model=per_model,
-                 )
-             )
-             requires_label = False
-             if data_types is not None and TEXT_NER in data_types:
-                 data_types.remove(TEXT_NER)
-         elif per_name.lower().startswith(MMLAB_MODELS):
-             # create a processor for mmlab models, which consume region-of-interest (ROIS) data.
-             data_processors[ROIS].append(
-                 create_data_processor(
-                     data_type=ROIS,
-                     config=config,
-                     model=per_model,
-                 )
-             )
-             if data_types is not None and IMAGE in data_types:
-                 data_types.remove(IMAGE)
-         elif per_name == SAM:
-             data_processors[SEMANTIC_SEGMENTATION_IMG].append(
-                 create_data_processor(
-                     data_type=SEMANTIC_SEGMENTATION_IMG,
-                     config=config,
-                     model=per_model,
-                 )
-             )
-             if data_types is not None and SEMANTIC_SEGMENTATION_IMG in data_types:
-                 data_types.remove(SEMANTIC_SEGMENTATION_IMG)
-             requires_label = False
-
-         if requires_label:
-             # each model has its own label processor
-             label_processor = create_data_processor(
-                 data_type=LABEL,
-                 config=config,
-                 model=per_model,
-             )
-             data_processors[LABEL].append(label_processor)
-
-         if requires_data and data_types:
-             for data_type in data_types:
-                 per_data_processor = create_data_processor(
-                     data_type=data_type,
-                     model=per_model,
-                     config=config,
-                     advanced_hyperparameters=advanced_hyperparameters,
-                 )
-                 data_processors[data_type].append(per_data_processor)
-
-     # Only keep the modalities with non-empty processors.
-     data_processors = {k: v for k, v in data_processors.items() if len(v) > 0}
-
-     if TEXT_NER in data_processors and LABEL in data_processors:
-         # LabelProcessor is not needed for NER tasks as annotations are handled in NerProcessor.
-         data_processors.pop(LABEL)
-     return data_processors
-
-
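A sketch of the returned structure, assuming `config` and `fusion_model` were already assembled by a learner (neither is built in this module):

    # Hypothetical late-fusion setup with a text tower and an image tower.
    data_processors = create_fusion_data_processors(config=config, model=fusion_model)
    # The result maps modalities to processor lists, roughly:
    # {
    #     "text": [<TextProcessor for hf_text>],
    #     "image": [<ImageProcessor for timm_image>],
    #     "label": [<LabelProcessor>, <LabelProcessor>, <LabelProcessor>],
    # }
    # Two models sharing a modality each keep their own processor, e.g. BERT
    # tokenizing to length 512 while CLIP tokenizes to length 77.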
- def assign_feature_column_names(
-     data_processors: Dict,
-     df_preprocessor: MultiModalFeaturePreprocessor,
- ):
-     """
-     Assign feature column names to data processors.
-     This is to patch the data processors saved by AutoGluon 0.4.0.
-
-     Parameters
-     ----------
-     data_processors
-         The data processors.
-     df_preprocessor
-         The dataframe preprocessor.
-
-     Returns
-     -------
-     The data processors with feature column names added.
-     """
-     for per_modality in data_processors:
-         if per_modality == LABEL or per_modality == TEXT_NER:
-             continue
-         for per_model_processor in data_processors[per_modality]:
-             # requires_column_info=True is used for feature column distillation.
-             per_model_processor.requires_column_info = False
-             if per_modality == IMAGE:
-                 per_model_processor.image_column_names = df_preprocessor.image_path_names
-             elif per_modality == TEXT:
-                 per_model_processor.text_column_names = df_preprocessor.text_feature_names
-             elif per_modality == NUMERICAL:
-                 per_model_processor.numerical_column_names = df_preprocessor.numerical_feature_names
-             elif per_modality == CATEGORICAL:
-                 per_model_processor.categorical_column_names = df_preprocessor.categorical_feature_names
-             else:
-                 raise ValueError(f"Unknown modality: {per_modality}")
-
-     return data_processors
-
-
- def turn_on_off_feature_column_info(
-     data_processors: Dict,
-     flag: bool,
- ):
-     """
-     Turn on or off returning feature column information in data processors.
-     Since feature column information is not always required for training models,
-     this flag can be turned on or off as needed.
-
-     Parameters
-     ----------
-     data_processors
-         The data processors.
-     flag
-         Whether to return feature column information (True/False).
-     """
-     for per_modality_processors in data_processors.values():
-         for per_model_processor in per_modality_processors:
-             # label processor doesn't have requires_column_info.
-             if hasattr(per_model_processor, "requires_column_info"):
-                 per_model_processor.requires_column_info = flag
-
-
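For example, continuing the hypothetical `data_processors` dict from create_fusion_data_processors above:

    # Ask processors to also emit per-column information, e.g. for feature
    # column distillation, then switch it back off for regular training.
    turn_on_off_feature_column_info(data_processors, flag=True)
    # ... build dataloaders / run distillation here (omitted) ...
    turn_on_off_feature_column_info(data_processors, flag=False)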
- def get_mixup(
-     model_config: DictConfig,
-     mixup_config: DictConfig,
-     num_classes: int,
- ):
-     """
-     Get the mixup state for the loss function choice.
-     Currently, mixup supports only image data, and it cannot be used for regression problems.
-
-     Parameters
-     ----------
-     model_config
-         The model configs, used to check for an image model, which mixup requires.
-     mixup_config
-         The mixup configs for mixup and cutmix.
-     num_classes
-         The number of classes in the task. Mixup requires more than one class.
-
-     Returns
-     -------
-     Whether mixup is turned on, and the mixup function (None if mixup is off).
-     """
-     model_active = False
-     names = model_config.names
-     if isinstance(names, str):
-         names = [names]
-     for model_name in names:
-         permodel_config = getattr(model_config, model_name)
-         if hasattr(permodel_config.data_types, IMAGE):
-             model_active = True
-             break
-
-     mixup_active = False
-     if mixup_config is not None and mixup_config.turn_on:
-         mixup_active = (
-             mixup_config.mixup_alpha > 0 or mixup_config.cutmix_alpha > 0.0 or mixup_config.cutmix_minmax is not None
-         )
-
-     mixup_state = model_active & mixup_active & ((num_classes is not None) and (num_classes > 1))
-     mixup_fn = None
-     if mixup_state:
-         mixup_args = dict(
-             mixup_alpha=mixup_config.mixup_alpha,
-             cutmix_alpha=mixup_config.cutmix_alpha,
-             cutmix_minmax=mixup_config.cutmix_minmax,
-             prob=mixup_config.prob,
-             switch_prob=mixup_config.switch_prob,
-             mode=mixup_config.mode,
-             label_smoothing=mixup_config.label_smoothing,
-             num_classes=num_classes,
-         )
-         mixup_fn = MixupModule(**mixup_args)
-     return mixup_state, mixup_fn
-
-
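A minimal sketch of the call shape, with the config fields taken from the code above and the values purely illustrative:

    from omegaconf import OmegaConf

    model_config = OmegaConf.create(
        {"names": ["timm_image"], "timm_image": {"data_types": ["image"]}}
    )
    mixup_config = OmegaConf.create(
        {
            "turn_on": True,
            "mixup_alpha": 0.8,
            "cutmix_alpha": 1.0,
            "cutmix_minmax": None,
            "prob": 1.0,
            "switch_prob": 0.5,
            "mode": "batch",
            "label_smoothing": 0.1,
        }
    )
    mixup_state, mixup_fn = get_mixup(model_config, mixup_config, num_classes=10)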
- def data_to_df(
-     data: Union[pd.DataFrame, Dict, List],
-     required_columns: Optional[List] = None,
-     all_columns: Optional[List] = None,
-     header: Optional[str] = None,
- ):
-     """
-     Convert the input data to a dataframe.
-
-     Parameters
-     ----------
-     data
-         Input data provided by users during prediction/evaluation.
-     required_columns
-         Required columns.
-     all_columns
-         All the possible columns obtained from training data. The column order is preserved.
-     header
-         Provided header to create a dataframe.
-
-     Returns
-     -------
-     A dataframe with required columns.
-     """
-     has_header = True
-     if isinstance(data, pd.DataFrame):
-         pass
-     elif isinstance(data, dict):
-         data = pd.DataFrame(data)
-     elif isinstance(data, list):
-         assert len(data) > 0, f"Expected data to have length > 0, but got {data} of len {len(data)}"
-         if header is None:
-             has_header = False
-             data = pd.DataFrame(data)
-         else:
-             data = pd.DataFrame({header: data})
-     elif isinstance(data, str):
-         df = pd.DataFrame([data])
-         col_name = list(df.columns)[0]
-         if is_image_column(df[col_name], col_name=col_name, image_type=IMAGE_PATH):
-             has_header = False
-             data = df
-         else:
-             data = load_pd.load(data)
-     else:
-         raise NotImplementedError(
-             f"The format of data is not understood. "
-             f'We have type(data)="{type(data)}", but a pd.DataFrame was required.'
-         )
-
-     if required_columns and all_columns:
-         detected_columns = data.columns.values.tolist()
-         missing_columns = []
-         for per_col in required_columns:
-             if per_col not in detected_columns:
-                 missing_columns.append(per_col)
-
-         if len(missing_columns) > 0:
-             # assume no column names are provided and users organize data in the same column order as the training data.
-             if len(detected_columns) == len(all_columns):
-                 if has_header:
-                     warnings.warn(
-                         f"Replacing detected dataframe columns `{detected_columns}` with columns "
-                         f"`{all_columns}` from training data. "
-                         "Double check the correspondences between them to avoid unexpected behaviors.",
-                         UserWarning,
-                     )
-                 data.rename(dict(zip(detected_columns, required_columns)), axis=1, inplace=True)
-             else:
-                 raise ValueError(
-                     f"Dataframe columns `{detected_columns}` are detected, but columns `{missing_columns}` are missing. "
-                     f"Please double check your input data to provide all the "
-                     f"required columns `{required_columns}`."
-                 )
-
-     return data
-
-
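A few self-contained examples of the conversion rules above:

    # A list without a header becomes a single-column dataframe.
    df = data_to_df(["hello", "world"])

    # A dict of columns is converted directly.
    df = data_to_df({"text": ["hello", "world"], "score": [0.1, 0.9]})

    # With required/all columns from training, a headerless frame whose column
    # count matches the training columns is renamed to match them.
    df = data_to_df(
        [["hello", 0.1], ["world", 0.9]],
        required_columns=["text", "score"],
        all_columns=["text", "score"],
    )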
- def infer_scarcity_mode_by_data_size(df_train: pd.DataFrame, scarcity_threshold: int = 50):
-     """
-     Infer the data scarcity mode based on the number of training samples.
-
-     Parameters
-     ----------
-     df_train
-         Training dataframe.
-     scarcity_threshold
-         Threshold number of samples below which FEW_SHOT mode is selected.
-
-     Returns
-     -------
-     Mode in [DEFAULT_SHOT, FEW_SHOT].
-     """
-     row_num = len(df_train)
-     if row_num < scarcity_threshold:
-         return FEW_SHOT
-     else:
-         return DEFAULT_SHOT
-
-
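For example:

    import pandas as pd

    few = pd.DataFrame({"text": ["a"] * 20, "label": [0, 1] * 10})
    infer_scarcity_mode_by_data_size(few)  # 20 rows < 50 -> FEW_SHOT

    default = pd.DataFrame({"text": ["a"] * 200, "label": [0, 1] * 100})
    infer_scarcity_mode_by_data_size(default)  # -> DEFAULT_SHOT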
- def infer_dtypes_by_model_names(model_config: DictConfig):
-     """
-     Get data types according to model types.
-
-     Parameters
-     ----------
-     model_config
-         Model config from `config.model`.
-
-     Returns
-     -------
-     The data types allowed by models and the default fallback data type.
-     """
-     allowable_dtypes = []
-     fallback_dtype = None
-     for per_model in model_config.names:
-         per_model_dtypes = OmegaConf.select(model_config, f"{per_model}.data_types")
-         if per_model_dtypes:
-             allowable_dtypes.extend(per_model_dtypes)
-
-     allowable_dtypes = set(allowable_dtypes)
-     if allowable_dtypes == {IMAGE, TEXT}:
-         fallback_dtype = TEXT
-     elif len(allowable_dtypes) == 1:
-         fallback_dtype = list(allowable_dtypes)[0]
-
-     return allowable_dtypes, fallback_dtype
-
-
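For example, assuming the IMAGE and TEXT constants are the strings "image" and "text" (the model names below are illustrative):

    from omegaconf import OmegaConf

    model_config = OmegaConf.create(
        {
            "names": ["hf_text", "timm_image"],
            "hf_text": {"data_types": ["text"]},
            "timm_image": {"data_types": ["image"]},
        }
    )
    allowable, fallback = infer_dtypes_by_model_names(model_config)
    # allowable == {"text", "image"}; fallback == "text" per the rule above.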
- def split_train_tuning_data(
-     data: pd.DataFrame,
-     holdout_frac: float = None,
-     problem_type: str = None,
-     label_column: str = None,
-     random_state: int = 0,
- ) -> Tuple[pd.DataFrame, pd.DataFrame]:
-     """
-     Splits `data` into `train_data` and `tuning_data`.
-     If the problem_type is one of ['binary', 'multiclass']:
-         The split will be done with stratification on the label column.
-         Will guarantee at least 1 sample of every class in `data` will be present in `train_data`.
-         If only 1 sample of a class exists, it will always be put in `train_data` and not `tuning_data`.
-
-     Parameters
-     ----------
-     data : pd.DataFrame
-         The data to be split.
-     holdout_frac : float, default = None
-         The ratio of data to use as validation.
-         If 0.2, 20% of the data will be used for validation, and 80% for training.
-         If None, the ratio is automatically determined,
-         ranging from 0.2 for small row count to 0.01 for large row count.
-     problem_type : str, default = None
-         The problem type, used to decide whether to stratify on the label column.
-     label_column : str, default = None
-         Name of the label column.
-     random_state : int, default = 0
-         The random state to use when splitting the data, to make the splitting process deterministic.
-         If None, a random value is used.
-
-     Returns
-     -------
-     Tuple of (train_data, tuning_data) of the split `data`
-     """
-     if holdout_frac is None:
-         holdout_frac = default_holdout_frac(num_train_rows=len(data), hyperparameter_tune=False)
-
-     # TODO: Hack since the recognized problem types are only binary, multiclass, and regression.
-     #  Problem types are used for the purpose of stratification, so regression = no stratification.
-     if problem_type in [BINARY, MULTICLASS]:
-         problem_type_for_split = problem_type
-     else:
-         problem_type_for_split = REGRESSION
-
-     train_data, tuning_data = generate_train_test_split_combined(
-         data=data,
-         label=label_column,
-         test_size=holdout_frac,
-         problem_type=problem_type_for_split,
-         random_state=random_state,
-     )
-     return train_data, tuning_data
-
-
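For example:

    import pandas as pd

    data = pd.DataFrame({"text": [f"row {i}" for i in range(100)], "label": [0, 1] * 50})
    train_data, tuning_data = split_train_tuning_data(
        data=data,
        holdout_frac=0.2,  # 80/20 split
        problem_type="binary",  # stratify on the label column
        label_column="label",
        random_state=0,
    )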
- def get_detected_data_types(column_types: Dict):
-     """
-     Extract data types from column types.
-
-     Parameters
-     ----------
-     column_types
-         A dataframe's column types.
-
-     Returns
-     -------
-     A list of detected data types.
-     """
-     data_types = []
-     for col_type in column_types.values():
-         if col_type.startswith(IMAGE) and IMAGE not in data_types:
-             data_types.append(IMAGE)
-         elif col_type.startswith(TEXT_NER) and TEXT_NER not in data_types:
-             data_types.append(TEXT_NER)
-         elif col_type.startswith(TEXT) and TEXT not in data_types:
-             data_types.append(TEXT)
-         elif col_type.startswith(DOCUMENT) and DOCUMENT not in data_types:
-             data_types.append(DOCUMENT)
-         elif col_type.startswith(NUMERICAL) and NUMERICAL not in data_types:
-             data_types.append(NUMERICAL)
-         elif col_type.startswith(CATEGORICAL) and CATEGORICAL not in data_types:
-             data_types.append(CATEGORICAL)
-         elif col_type.startswith(ROIS) and ROIS not in data_types:
-             data_types.append(ROIS)
-
-     return data_types
-
-
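For example, assuming the data-type constants are the usual strings ("image", "text", ...), column types whose values carry a base-type prefix (e.g., "image_path") map back to that base type:

    column_types = {
        "photo": "image_path",
        "title": "text",
        "price": "numerical",
        "brand": "categorical",
    }
    get_detected_data_types(column_types)
    # -> ["image", "text", "numerical", "categorical"]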
- def get_image_transforms(model_config: DictConfig, model_name: str, advanced_hyperparameters: Dict):
-     """
-     Get the image transforms of one image-related model.
-     Use the transforms in advanced_hyperparameters with higher priority.
-
-     Parameters
-     ----------
-     model_config
-         Config of one model.
-     model_name
-         Name of one model.
-     advanced_hyperparameters
-         The advanced hyperparameters whose values are complex objects.
-
-     Returns
-     -------
-     The image transforms used in training and validation.
-     """
-     train_transform_key = f"model.{model_name}.train_transforms"
-     val_transform_key = f"model.{model_name}.val_transforms"
-     if advanced_hyperparameters and train_transform_key in advanced_hyperparameters:
-         train_transforms = advanced_hyperparameters[train_transform_key]
-     else:
-         train_transforms = (
-             model_config.train_transform_types
-             if hasattr(model_config, "train_transform_types")
-             else model_config.train_transforms
-         )
-         train_transforms = list(train_transforms)
-
-     if advanced_hyperparameters and val_transform_key in advanced_hyperparameters:
-         val_transforms = advanced_hyperparameters[val_transform_key]
-     else:
-         val_transforms = (
-             model_config.val_transform_types
-             if hasattr(model_config, "val_transform_types")
-             else model_config.val_transforms
-         )
-         val_transforms = list(val_transforms)
-
-     return train_transforms, val_transforms
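
A minimal sketch of the override priority, using a hypothetical timm_image entry; the transform names are illustrative, not taken from this diff:

    from omegaconf import OmegaConf

    model_config = OmegaConf.create(
        {
            "train_transforms": ["resize_shorter_side", "center_crop", "trivial_augment"],
            "val_transforms": ["resize_shorter_side", "center_crop"],
        }
    )
    # Struct mode makes missing keys raise, matching the hasattr checks above.
    OmegaConf.set_struct(model_config, True)

    advanced_hyperparameters = {
        "model.timm_image.train_transforms": ["resize_to_square", "random_horizontal_flip"]
    }
    train_tfms, val_tfms = get_image_transforms(
        model_config=model_config,
        model_name="timm_image",
        advanced_hyperparameters=advanced_hyperparameters,
    )
    # train_tfms comes from advanced_hyperparameters; val_tfms falls back to model_config.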