autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
autogluon/multimodal/utils/precision.py (new file)
@@ -0,0 +1,130 @@
+import contextlib
+import logging
+import warnings
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def convert_to_torch_precision(precision: Union[int, str]):
+    """
+    Convert a precision integer or string to the corresponding torch precision.
+
+    Parameters
+    ----------
+    precision
+        a precision integer or string from the config.
+
+    Returns
+    -------
+    A torch precision object.
+    """
+    precision_mapping = {
+        16: torch.half,
+        "16": torch.half,
+        "16-mixed": torch.half,
+        "16-true": torch.half,
+        "bf16": torch.bfloat16,
+        "bf16-mixed": torch.bfloat16,
+        "bf16-true": torch.bfloat16,
+        32: torch.float32,
+        "32": torch.float32,
+        "32-true": torch.float32,
+        64: torch.float64,
+        "64": torch.float64,
+        "64-true": torch.float64,
+    }
+
+    if precision in precision_mapping:
+        precision = precision_mapping[precision]
+    else:
+        raise ValueError(f"Unknown precision: {precision}")
+
+    return precision
+
+
+def infer_precision(
+    num_gpus: int, precision: Union[int, str], as_torch: Optional[bool] = False, cpu_only_warning: bool = True
+):
+    """
+    Infer the proper precision based on the environment setup and the provided precision.
+
+    Parameters
+    ----------
+    num_gpus
+        GPU number.
+    precision
+        The precision provided in config.
+    as_torch
+        Whether to convert the precision to the Pytorch format.
+    cpu_only_warning
+        Whether to turn on warning if the instance has only CPU.
+
+    Returns
+    -------
+    The inferred precision.
+    """
+    if num_gpus == 0:  # CPU only prediction
+        if cpu_only_warning:
+            warnings.warn(
+                "Only CPU is detected in the instance. "
+                "This may result in slow speed for MultiModalPredictor. "
+                "Consider using an instance with GPU support.",
+                UserWarning,
+            )
+        precision = 32  # Force to use fp32 for training since 16-mixed is not available in CPU
+    else:
+        if isinstance(precision, str) and "bf16" in precision and not torch.cuda.is_bf16_supported():
+            warnings.warn(
+                f"{precision} is not supported by the GPU device / cuda version. "
+                "Consider using GPU devices with versions after Amphere or upgrading cuda to be >=11.0. "
+                f"MultiModalPredictor is switching precision from {precision} to 32.",
+                UserWarning,
+            )
+            precision = 32
+
+    if as_torch:
+        precision = convert_to_torch_precision(precision=precision)
+
+    return precision
+
+
+@contextlib.contextmanager
+def double_precision_context():
+    """
+    Double precision context manager.
+    """
+    default_dtype = torch.get_default_dtype()
+    torch.set_default_dtype(torch.float64)
+    yield
+    torch.set_default_dtype(default_dtype)
+
+
+def get_precision_context(precision: Union[int, str], device_type: Optional[str] = None):
+    """
+    Choose the proper context manager based on the precision.
+
+    Parameters
+    ----------
+    precision
+        The precision.
+    device_type
+        gpu or cpu.
+
+    Returns
+    -------
+    A precision context manager.
+    """
+    precision = convert_to_torch_precision(precision=precision)
+
+    if precision in [torch.half, torch.float16, torch.bfloat16]:
+        return torch.autocast(device_type=device_type, dtype=precision)
+    if precision == torch.float32:
+        assert torch.get_default_dtype() == torch.float32
+        return contextlib.nullcontext()
+    elif precision == torch.float64:
+        return double_precision_context()
+    else:
+        raise ValueError(f"Unknown precision: {precision}")
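For reference, a minimal usage sketch of the three helpers added above (the import path follows the new utils/precision.py location in the file list; the tensor math is illustrative only):

import torch

from autogluon.multimodal.utils.precision import (
    convert_to_torch_precision,
    get_precision_context,
    infer_precision,
)

# "bf16-mixed" maps to torch.bfloat16 via the precision_mapping table above.
assert convert_to_torch_precision(precision="bf16-mixed") is torch.bfloat16

# With num_gpus=0, precision falls back to fp32 (a UserWarning is emitted
# unless cpu_only_warning=False); as_torch=True converts it to a torch dtype.
assert infer_precision(num_gpus=0, precision="16-mixed", as_torch=True) is torch.float32

# Half/bfloat16 precisions return a torch.autocast context; fp32 a nullcontext.
with get_precision_context(precision="bf16", device_type="cpu"):
    _ = torch.ones(2, 2) @ torch.ones(2, 2)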
autogluon/multimodal/{presets.py → utils/presets.py}
@@ -2,17 +2,17 @@ from typing import List, Optional
 
 from autogluon.common.utils.try_import import try_import_ray
 
-from .constants import (
+from ..constants import (
     BEST_QUALITY,
     BINARY,
     DATA,
     DEFAULT,
-    ENVIRONMENT,
+    ENV,
     HIGH_QUALITY,
     MEDIUM_QUALITY,
     MODEL,
     MULTICLASS,
-    OPTIMIZATION,
+    OPTIM,
     REGRESSION,
 )
 from .registry import Registry
@@ -32,9 +32,9 @@ def get_default_hpo_setup():
     }
 
     default_tunable_hyperparameters = {
-        "optimization.learning_rate": tune.loguniform(1e-5, 1e-2),
-        "optimization.optim_type": tune.choice(["adamw", "sgd"]),
-        "optimization.max_epochs": tune.choice(list(range(5, 31))),
+        "optim.lr": tune.loguniform(1e-5, 1e-2),
+        "optim.optim_type": tune.choice(["adamw", "sgd"]),
+        "optim.max_epochs": tune.choice(list(range(5, 31))),
         "env.batch_size": tune.choice([16, 32, 64, 128, 256]),
     }
 
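The renamed search-space keys plug into MultiModalPredictor's existing HPO entry points; a hedged sketch (fit() accepting hyperparameters and hyperparameter_tune_kwargs is long-standing AutoGluon API, while the toy DataFrame and tuning settings here are placeholders):

import pandas as pd
from ray import tune  # ray is an optional dependency; presets.py guards it via try_import_ray

from autogluon.multimodal import MultiModalPredictor

# Toy placeholder data; any text/tabular DataFrame works here.
train_df = pd.DataFrame({"text": ["good", "bad", "fine", "poor"], "label": [1, 0, 1, 0]})

predictor = MultiModalPredictor(label="label")
predictor.fit(
    train_data=train_df,
    hyperparameters={
        "optim.lr": tune.loguniform(1e-5, 1e-2),  # formerly "optimization.learning_rate"
        "optim.max_epochs": tune.choice(list(range(5, 31))),
        "env.batch_size": tune.choice([16, 32, 64, 128, 256]),
    },
    hyperparameter_tune_kwargs={"searcher": "bayes", "scheduler": "ASHA", "num_trials": 4},
)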
@@ -146,7 +146,7 @@ def default(presets: str = DEFAULT):
                 "model.hf_text.checkpoint_name": "google/electra-small-discriminator",
                 "model.timm_image.checkpoint_name": "mobilenetv3_large_100",
                 "model.document_transformer.checkpoint_name": "microsoft/layoutlmv2-base-uncased",
-                "optimization.learning_rate": 4e-4,
+                "optim.lr": 4e-4,
             }
         )
     elif presets == BEST_QUALITY:
@@ -186,7 +186,7 @@ def default(presets: str = DEFAULT):
         hyperparameters.update(
             {
                 "model.hf_text.checkpoint_name": "microsoft/mdeberta-v3-base",
-                "optimization.top_k": 1,
+                "optim.top_k": 1,
                 "env.precision": "bf16-mixed",
                 "env.per_gpu_batch_size": 4,
             }
@@ -222,7 +222,7 @@ def few_shot_classification(presets: str = DEFAULT):
            "model.names": ["hf_text", "clip"],
            "model.clip.checkpoint_name": "openai/clip-vit-large-patch14-336",
            "model.clip.image_size": 336,
-            "env.eval_batch_size_ratio": 1,
+            "env.inference_batch_size_ratio": 1,
        }
    )
    hyperparameter_tune_kwargs = {}
@@ -258,14 +258,14 @@ def zero_shot_image_classification(presets: str = DEFAULT):
            {
                "model.clip.checkpoint_name": "openai/clip-vit-large-patch14-336",
                "model.clip.image_size": 336,
-                "env.eval_batch_size_ratio": 1,
+                "env.inference_batch_size_ratio": 1,
            }
        )
    elif presets == HIGH_QUALITY:
        hyperparameters.update(
            {
                "model.clip.checkpoint_name": "openai/clip-vit-large-patch14",
-                "env.eval_batch_size_ratio": 1,
+                "env.inference_batch_size_ratio": 1,
            }
        )
    elif presets == MEDIUM_QUALITY:
@@ -300,27 +300,27 @@ def object_detection(presets: str = DEFAULT):
    hyperparameters = {
        "model.names": ["mmdet_image"],
        "model.mmdet_image.frozen_layers": [],
-        "optimization.patience": 20,
-        "optimization.val_check_interval": 1.0,
-        "optimization.check_val_every_n_epoch": 1,
+        "optim.patience": 20,
+        "optim.val_check_interval": 1.0,
+        "optim.check_val_every_n_epoch": 1,
        "env.batch_size": 32,
        "env.per_gpu_batch_size": 1,
        "env.num_workers": 2,
-        "optimization.learning_rate": 1e-5,
-        "optimization.weight_decay": 1e-4,
-        "optimization.lr_mult": 10,
-        "optimization.lr_choice": "two_stages",
-        "optimization.lr_schedule": "multi_step",
-        "optimization.gradient_clip_val": 0.1,
-        "optimization.max_epochs": 60,
-        "optimization.warmup_steps": 0.0,
-        "optimization.top_k": 1,
-        "optimization.top_k_average_method": "best",
-        "env.eval_batch_size_ratio": 1,
+        "optim.lr": 1e-5,
+        "optim.weight_decay": 1e-4,
+        "optim.lr_mult": 10,
+        "optim.lr_choice": "two_stages",
+        "optim.lr_schedule": "multi_step",
+        "optim.gradient_clip_val": 0.1,
+        "optim.max_epochs": 60,
+        "optim.warmup_steps": 0.0,
+        "optim.top_k": 1,
+        "optim.top_k_average_method": "best",
+        "env.inference_batch_size_ratio": 1,
        "env.strategy": "ddp",
        "env.auto_select_gpus": True,  # Turn on for detection to return devices in a list, TODO: fix the extra GPU usage bug
        "env.num_gpus": -1,
-        "optimization.lr_decay": 0.9,
+        "optim.lr_decay": 0.9,
    }
    hyperparameter_tune_kwargs = {}
 
@@ -335,15 +335,15 @@ def object_detection(presets: str = DEFAULT):
            {
                "model.mmdet_image.checkpoint_name": "yolox_l",
                "env.per_gpu_batch_size": 2,  # Works on 8G GPU
-                "optimization.learning_rate": 5e-5,
-                "optimization.patience": 5,
-                "optimization.max_epochs": 50,
-                "optimization.val_check_interval": 1.0,
-                "optimization.check_val_every_n_epoch": 3,
-                "optimization.lr_mult": 100,
-                "optimization.weight_decay": 1e-3,
-                "optimization.lr_schedule": "cosine_decay",
-                "optimization.gradient_clip_val": 1,
+                "optim.lr": 5e-5,
+                "optim.patience": 5,
+                "optim.max_epochs": 50,
+                "optim.val_check_interval": 1.0,
+                "optim.check_val_every_n_epoch": 3,
+                "optim.lr_mult": 100,
+                "optim.weight_decay": 1e-3,
+                "optim.lr_schedule": "cosine_decay",
+                "optim.gradient_clip_val": 1,
            }
        )
    elif presets in [DEFAULT, HIGH_QUALITY]:
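The detection presets above can still be overridden per call; a sketch assuming the standard object-detection entry point is unchanged (the annotation path is a placeholder):

from autogluon.multimodal import MultiModalPredictor

train_path = "data/annotations/train.json"  # placeholder COCO-format annotation file

predictor = MultiModalPredictor(
    problem_type="object_detection",
    sample_data_path=train_path,
    presets="medium_quality",
)
# User overrides now use the optim.* prefix instead of optimization.*.
predictor.fit(train_path, hyperparameters={"optim.lr": 2e-4, "optim.max_epochs": 10})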
@@ -386,32 +386,32 @@ def semantic_segmentation(presets: str = DEFAULT):
        "model.sam.checkpoint_name": "facebook/sam-vit-huge",
        "env.batch_size": 4,
        "env.per_gpu_batch_size": 1,
-        "env.eval_batch_size_ratio": 1,
+        "env.inference_batch_size_ratio": 1,
        "env.strategy": "ddp_find_unused_parameters_true",
        "env.auto_select_gpus": False,
        "env.num_gpus": -1,
        "env.num_workers": 4,
        "env.precision": "16-mixed",
-        "optimization.learning_rate": 1e-4,
-        "optimization.loss_function": "structure_loss",
-        "optimization.lr_decay": 0,
-        "optimization.lr_mult": 1,
-        "optimization.lr_choice": "single_stage",
-        "optimization.lr_schedule": "polynomial_decay",
-        "optimization.max_epochs": 30,
-        "optimization.top_k": 3,
-        "optimization.top_k_average_method": "best",
-        "optimization.warmup_steps": 0.0,
-        "optimization.weight_decay": 0.0001,
-        "optimization.patience": 10,
-        "optimization.val_check_interval": 1.0,
-        "optimization.check_val_every_n_epoch": 1,
-        "optimization.efficient_finetune": "lora",
-        "optimization.lora.module_filter": [".*vision_encoder.*attn"],
-        "optimization.lora.filter": ["q", "v"],
-        "optimization.extra_trainable_params": [".*mask_decoder"],
-        "optimization.lora.r": 3,
-        "optimization.lora.alpha": 32,
+        "optim.lr": 1e-4,
+        "optim.loss_func": "structure_loss",
+        "optim.lr_decay": 0,
+        "optim.lr_mult": 1,
+        "optim.lr_choice": "single_stage",
+        "optim.lr_schedule": "polynomial_decay",
+        "optim.max_epochs": 30,
+        "optim.top_k": 3,
+        "optim.top_k_average_method": "best",
+        "optim.warmup_steps": 0.0,
+        "optim.weight_decay": 0.0001,
+        "optim.patience": 10,
+        "optim.val_check_interval": 1.0,
+        "optim.check_val_every_n_epoch": 1,
+        "optim.peft": "lora",
+        "optim.lora.module_filter": [".*vision_encoder.*attn"],
+        "optim.lora.filter": ["q", "v"],
+        "optim.extra_trainable_params": [".*mask_decoder"],
+        "optim.lora.r": 3,
+        "optim.lora.alpha": 32,
    }
    hyperparameter_tune_kwargs = {}
 
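Taken together, the hunks above imply a mechanical key-rename scheme; a small hypothetical migration helper (the helper name and the special-case map are ours, read off the removed/added pairs above):

# Special cases where the suffix changed too; all other optimization.* keys
# keep their suffix under the new optim.* prefix (e.g. optimization.patience -> optim.patience).
SPECIAL_RENAMES = {
    "optimization.learning_rate": "optim.lr",
    "optimization.loss_function": "optim.loss_func",
    "optimization.efficient_finetune": "optim.peft",
    "env.eval_batch_size_ratio": "env.inference_batch_size_ratio",
}


def migrate_hyperparameters(hparams: dict) -> dict:
    """Rewrite pre-rename config keys to their post-rename equivalents."""
    migrated = {}
    for key, value in hparams.items():
        key = SPECIAL_RENAMES.get(key, key.replace("optimization.", "optim.", 1))
        migrated[key] = value
    return migrated


assert migrate_hyperparameters({"optimization.learning_rate": 1e-4}) == {"optim.lr": 1e-4}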
@@ -444,7 +444,7 @@ def ocr_text_detection(presets: str = DEFAULT):
    hyperparameters = {
        "model.names": ["mmocr_text_detection"],
        "model.mmocr_text_detection.checkpoint_name": "TextSnake",
-        "env.eval_batch_size_ratio": 1,
+        "env.inference_batch_size_ratio": 1,
        "env.num_gpus": 1,
        "env.precision": 32,
    }
@@ -479,7 +479,7 @@ def ocr_text_recognition(presets: str = DEFAULT):
    hyperparameters = {
        "model.names": ["mmocr_text_recognition"],
        "model.mmocr_text_recognition.checkpoint_name": "ABINet",
-        "env.eval_batch_size_ratio": 1,
+        "env.inference_batch_size_ratio": 1,
        "env.num_gpus": 1,
        "env.precision": 32,
    }
@@ -514,7 +514,7 @@ def feature_extraction(presets: str = DEFAULT):  # TODO: rename the problem type
        "model.names": ["hf_text"],
        "model.hf_text.checkpoint_name": "sentence-transformers/msmarco-MiniLM-L-12-v3",
        "model.hf_text.pooling_mode": "mean",
-        "env.eval_batch_size_ratio": 1,
+        "env.inference_batch_size_ratio": 1,
    }
    hyperparameter_tune_kwargs = {}
 
@@ -651,7 +651,7 @@ def image_text_similarity(presets: str = DEFAULT):
    hyperparameters = {
        "model.names": ["clip"],
        "matcher.loss.type": "multi_negatives_softmax_loss",
-        "optimization.learning_rate": 1e-5,
+        "optim.lr": 1e-5,
    }
    hyperparameter_tune_kwargs = {}
 
@@ -747,9 +747,196 @@ def ner(presets: str = DEFAULT):
     return hyperparameters, hyperparameter_tune_kwargs
 
 
-def list_automm_presets(verbose: bool = False):
+@automm_presets.register()
+def ensemble(presets: str = DEFAULT):
+    hyperparameters = {
+        "lf_mlp": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "lf_transformer": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_transformer"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "lf_clip": {
+            "model.names": ["ft_transformer", "clip_image", "clip_text", "fusion_mlp"],
+            "model.clip_image.data_types": ["image"],
+            "model.clip_text.data_types": ["text"],
+            "model.clip_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.clip_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.clip_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "early_fusion": {
+            "model.names": ["meta_transformer"],
+            "model.meta_transformer.checkpoint_path": "null",
+            "model.meta_transformer.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.meta_transformer.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.meta_transformer.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "convert_categorical_to_text": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": True,
+            "data.categorical.convert_to_text_template": "latex",
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "convert_numeric_to_text": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": True,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "cross_modal_align_pos_only": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "positive_only",
+            "optim.cross_modal_align_weight": 1,
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "input_aug": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop", "trivial_augment"],
+            "model.hf_text.text_trivial_aug_maxscale": 0.1,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "feature_aug_lemda": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": True,
+            "optim.automatic_optimization": False,
+        },
+        "modality_dropout": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0.2,
+            "model.timm_image.use_learnable_image": False,
+            "optim.lemda.turn_on": False,
+        },
+        "learnable_image": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0,
+            "model.timm_image.use_learnable_image": True,
+            "optim.lemda.turn_on": False,
+        },
+        "modality_dropout_and_learnable_image": {
+            "model.names": ["ft_transformer", "timm_image", "hf_text", "fusion_mlp"],
+            "model.timm_image.train_transforms": ["resize_shorter_side", "center_crop"],
+            "model.hf_text.text_trivial_aug_maxscale": 0,
+            "data.categorical.convert_to_text": False,
+            "data.numerical.convert_to_text": False,
+            "optim.cross_modal_align": "null",
+            "data.modality_dropout": 0.2,
+            "model.timm_image.use_learnable_image": True,
+            "optim.lemda.turn_on": False,
+        },
+    }
+
+    if presets in [DEFAULT, HIGH_QUALITY]:
+        for v in hyperparameters.values():
+            if "timm_image" in v["model.names"]:
+                v["model.timm_image.checkpoint_name"] = "caformer_b36.sail_in22k_ft_in1k"
+            if "hf_text" in v["model.names"]:
+                v["model.hf_text.checkpoint_name"] = "google/electra-base-discriminator"
+            if "meta_transformer" in v["model.names"]:
+                v["model.meta_transformer.model_version"] = "base"
+            if "clip_image" in v["model.names"]:
+                v["model.clip_image.checkpoint_name"] = "openai/clip-vit-base-patch32"
+            if "clip_text" in v["model.names"]:
+                v["model.clip_text.checkpoint_name"] = "openai/clip-vit-base-patch32"
+
+    elif presets == MEDIUM_QUALITY:
+        for v in hyperparameters.values():
+            if "timm_image" in v["model.names"]:
+                v["model.timm_image.checkpoint_name"] = "mobilenetv3_large_100"
+            if "hf_text" in v["model.names"]:
+                v["model.hf_text.checkpoint_name"] = "google/electra-small-discriminator"
+            if "meta_transformer" in v["model.names"]:
+                v["model.meta_transformer.model_version"] = "base"
+            if "clip_image" in v["model.names"]:
+                v["model.clip_image.checkpoint_name"] = "openai/clip-vit-base-patch32"
+            if "clip_text" in v["model.names"]:
+                v["model.clip_text.checkpoint_name"] = "openai/clip-vit-base-patch32"
+    elif presets == BEST_QUALITY:
+        for v in hyperparameters.values():
+            if "timm_image" in v["model.names"]:
+                v["model.timm_image.checkpoint_name"] = "swin_large_patch4_window7_224"
+            if "hf_text" in v["model.names"]:
+                v["model.hf_text.checkpoint_name"] = "microsoft/deberta-v3-base"
+            if "meta_transformer" in v["model.names"]:
+                v["model.meta_transformer.model_version"] = "large"
+            if "clip_image" in v["model.names"]:
+                v["model.clip_image.checkpoint_name"] = "openai/clip-vit-large-patch14"
+            if "clip_text" in v["model.names"]:
+                v["model.clip_text.checkpoint_name"] = "openai/clip-vit-large-patch14"
+    else:
+        raise ValueError(f"Unknown preset type: {presets}")
+
+    return hyperparameters, None
+
+
+def list_presets(verbose: bool = False):
     """
     List all available presets.
+
     Returns
     -------
     A list of presets.
@@ -765,7 +952,7 @@ def list_automm_presets(verbose: bool = False):
     return preset_details
 
 
-def get_basic_automm_config(extra: Optional[List[str]] = None):
+def get_basic_config(extra: Optional[List[str]] = None):
     """
     Get the basic config of AutoMM.
 
@@ -776,13 +963,13 @@ def get_basic_automm_config(extra: Optional[List[str]] = None):
 
     Returns
     -------
-    A dict config with keys: MODEL, DATA, OPTIMIZATION, ENVIRONMENT, and their default values.
+    A dict config with keys: MODEL, DATA, OPTIM, ENV, and their default values.
     """
     config = {
         MODEL: DEFAULT,
         DATA: DEFAULT,
-        OPTIMIZATION: DEFAULT,
-        ENVIRONMENT: DEFAULT,
+        OPTIM: DEFAULT,
+        ENV: DEFAULT,
     }
     if extra:
         for k in extra:
@@ -791,7 +978,7 @@ def get_basic_automm_config(extra: Optional[List[str]] = None):
     return config
 
 
-def get_automm_presets(problem_type: str, presets: str):
+def get_presets(problem_type: str, presets: str):
     """
     Get the default hyperparameters and hyperparameter_tune_kwargs given problem type and presets.
 
@@ -831,3 +1018,9 @@ def get_automm_presets(problem_type: str, presets: str):
     )
 
     return hyperparameters, hyperparameter_tune_kwargs
+
+
+def get_ensemble_presets(presets):
+    if not presets:
+        presets = DEFAULT
+    return automm_presets.create("ensemble", presets)
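A brief sketch of the new ensemble registry entry in use (the import path assumes the presets.py → utils/presets.py move from the file list):

from autogluon.multimodal.utils.presets import get_ensemble_presets

# Falsy presets fall back to DEFAULT, per get_ensemble_presets above.
hyperparameters, hpo_kwargs = get_ensemble_presets(presets=None)

# ensemble() returns one hyperparameter dict per base learner and no HPO kwargs.
assert hpo_kwargs is None
assert "lf_mlp" in hyperparameters and "early_fusion" in hyperparameters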
autogluon/multimodal/{problem_types.py → utils/problem_types.py}
@@ -1,9 +1,10 @@
 """Problem types supported in MultiModalPredictor"""
 
+import logging
 from dataclasses import dataclass, field
 from typing import List, Optional, Set
 
-from .constants import (
+from ..constants import (
     ACCURACY,
     BINARY,
     CATEGORICAL,
@@ -38,6 +39,8 @@ from .constants import (
 )
 from .registry import Registry
 
+logger = logging.getLogger(__name__)
+
 PROBLEM_TYPES_REG = Registry("problem_type_properties")
 
 
@@ -277,3 +280,29 @@ PROBLEM_TYPES_REG.register(
         _fallback_validation_metric=ACCURACY,
     ),
 )
+
+
+def infer_problem_type_by_eval_metric(eval_metric_name: str, problem_type: str):
+    if eval_metric_name is not None and eval_metric_name.lower() in [
+        "rmse",
+        "r2",
+        "pearsonr",
+        "spearmanr",
+    ]:
+        if problem_type is None:
+            logger.debug(
+                f"Infer problem type to be a regression problem "
+                f"since the evaluation metric is set as {eval_metric_name}."
+            )
+            problem_type = REGRESSION
+        else:
+            problem_prop = PROBLEM_TYPES_REG.get(problem_type)
+            if NUMERICAL not in problem_prop.supported_label_type:
+                raise ValueError(
+                    f"The provided evaluation metric will require the problem "
+                    f"to support label type = {NUMERICAL}. However, "
+                    f"the provided problem type = {problem_type} only "
+                    f"supports label type = {problem_prop.supported_label_type}."
+                )
+
+    return problem_type
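And a behavior sketch for the new infer_problem_type_by_eval_metric (the import path assumes the problem_types.py → utils/problem_types.py move from the file list):

from autogluon.multimodal.utils.problem_types import infer_problem_type_by_eval_metric

# A regression-style metric with no explicit problem type infers regression.
assert infer_problem_type_by_eval_metric("rmse", None) == "regression"

# Metrics outside {rmse, r2, pearsonr, spearmanr} pass the problem type through.
assert infer_problem_type_by_eval_metric("accuracy", "multiclass") == "multiclass"

# A regression metric paired with a problem type that lacks numerical label
# support raises ValueError.
try:
    infer_problem_type_by_eval_metric("rmse", "multiclass")
except ValueError as err:
    print(err)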