autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those package versions.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
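Several internal modules are renamed or relocated in this build: `configs/environment` → `configs/env`, `configs/optimization` → `configs/optim`, the `optimization` package → `optim` (with new `losses`, `lr`, and `metrics` subpackages), `models/huggingface_text.py` → `models/hf_text.py`, and `presets.py`, `problem_types.py`, and `registry.py` moving under `utils/`. Code importing these internal paths directly will need updating; a minimal sketch of the path changes (the module paths come from the listing above, but which names each module re-exports is an assumption to verify against the new `__init__.py` files):

    # before: 1.2.1b20250303
    # from autogluon.multimodal.optimization import lit_module
    # from autogluon.multimodal import presets, problem_types

    # after: 1.2.1b20250304
    from autogluon.multimodal.optim import lit_module              # optimization -> optim
    from autogluon.multimodal.utils import presets, problem_types  # moved under utils/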
autogluon/multimodal/utils/model.py (deleted, 558 lines)
@@ -1,558 +0,0 @@
-import functools
-import json
-import logging
-import warnings
-from typing import Dict, List, Optional, Tuple, Union
-
-import timm
-from omegaconf import DictConfig, OmegaConf
-from torch import Tensor, nn
-
-from ..constants import (
-    ALL_MODALITIES,
-    AUTOMM,
-    CATEGORICAL,
-    CATEGORICAL_MLP,
-    CLIP,
-    DOCUMENT,
-    DOCUMENT_TRANSFORMER,
-    FT_TRANSFORMER,
-    FUSION_MLP,
-    FUSION_NER,
-    FUSION_TRANSFORMER,
-    HF_TEXT,
-    IMAGE,
-    MMDET_IMAGE,
-    MMOCR_TEXT_DET,
-    MMOCR_TEXT_RECOG,
-    NER,
-    NER_TEXT,
-    NUMERICAL,
-    NUMERICAL_MLP,
-    PEFT_ADDITIVE_STRATEGIES,
-    SAM,
-    SEMANTIC_SEGMENTATION_IMG,
-    T_FEW,
-    TEXT,
-    TEXT_NER,
-    TIMM_IMAGE,
-    XYXY,
-)
-from ..data import MultiModalFeaturePreprocessor
-from ..models import (
-    CategoricalMLP,
-    CLIPForImageText,
-    DocumentTransformer,
-    FT_Transformer,
-    HFAutoModelForNER,
-    HFAutoModelForTextPrediction,
-    MMDetAutoModelForObjectDetection,
-    MMOCRAutoModelForTextDetection,
-    MMOCRAutoModelForTextRecognition,
-    MultimodalFusionMLP,
-    MultimodalFusionNER,
-    MultimodalFusionTransformer,
-    NumericalMLP,
-    SAMForSemanticSegmentation,
-    TFewModel,
-    TimmAutoModelForImagePrediction,
-)
-from ..models.utils import inject_adaptation_to_linear_layer
-
-logger = logging.getLogger(__name__)
-
-
-def select_model(
-    config: DictConfig,
-    df_preprocessor: MultiModalFeaturePreprocessor,
-    strict: Optional[bool] = True,
-):
-    """
-    Filter model config through the detected modalities in the training data.
-    If MultiModalFeaturePreprocessor can't detect some modality,
-    this function will remove the models that use this modality. This function is to
-    maximize the user flexibility in defining the config.
-    For example, if one uses the default, including hf_text and timm_image, as the model config template
-    but the training data don't have images, this function will filter out timm_image.
-
-    Parameters
-    ----------
-    config
-        A DictConfig object. The model config should be accessible by "config.model"
-    df_preprocessor
-        A MultiModalFeaturePreprocessor object, which has called .fit() on the training data.
-        Column names of the same modality are grouped into one list. If a modality's list is empty,
-        it means the training data don't have this modality.
-    strict
-        If False, allow retaining one model when partial modalities are available for that model.
-
-    Returns
-    -------
-    Config with some unused models removed.
-    """
-    data_status = {}
-    for per_modality in ALL_MODALITIES:
-        data_status[per_modality] = False
-    if len(df_preprocessor.image_feature_names) > 0:
-        data_status[IMAGE] = True
-    if len(df_preprocessor.text_feature_names) > 0:
-        data_status[TEXT] = True
-    if len(df_preprocessor.categorical_feature_names) > 0:
-        data_status[CATEGORICAL] = True
-    if len(df_preprocessor.numerical_feature_names) > 0:
-        data_status[NUMERICAL] = True
-    if len(df_preprocessor.ner_feature_names) > 0:
-        data_status[TEXT_NER] = True
-    if len(df_preprocessor.document_feature_names) > 0:
-        data_status[DOCUMENT] = True
-    if len(df_preprocessor.semantic_segmentation_feature_names) > 0:
-        data_status[SEMANTIC_SEGMENTATION_IMG] = True
-
-    names = config.model.names
-    if isinstance(names, str):
-        names = [names]
-    selected_model_names = []
-    fusion_model_name = []
-    for model_name in names:
-        model_config = getattr(config.model, model_name)
-        strict = getattr(model_config, "requires_all_dtypes", strict)
-        if not model_config.data_types:
-            fusion_model_name.append(model_name)
-            continue
-        model_data_status = [data_status[d_type] for d_type in model_config.data_types]
-        if all(model_data_status):
-            selected_model_names.append(model_name)
-        else:
-            if any(model_data_status) and not strict:
-                selected_model_names.append(model_name)
-            else:
-                delattr(config.model, model_name)
-
-    if len(selected_model_names) == 0:
-        raise ValueError("No model is available for this dataset.")
-    # only allow no more than 1 fusion model
-    if len(fusion_model_name) > 1:
-        raise ValueError(f"More than one fusion models `{fusion_model_name}` are detected, but only one is allowed.")
-
-    if len(selected_model_names) > 1:
-        assert len(fusion_model_name) == 1
-        selected_model_names.extend(fusion_model_name)
-    elif len(fusion_model_name) == 1 and hasattr(config.model, fusion_model_name[0]):
-        delattr(config.model, fusion_model_name[0])
-
-    config.model.names = selected_model_names
-    logger.debug(f"selected models: {selected_model_names}")
-    for model_name in selected_model_names:
-        logger.debug(f"model dtypes: {getattr(config.model, model_name).data_types}")
-
-    # clean up unused model configs
-    model_keys = list(config.model.keys())
-    for model_name in model_keys:
-        if model_name not in selected_model_names + ["names"]:
-            delattr(config.model, model_name)
-
-    return config
-
-
-def create_model(
-    model_name: str,
-    model_config: DictConfig,
-    num_classes: Optional[int] = 0,
-    classes: Optional[list] = None,
-    num_numerical_columns: Optional[int] = None,
-    num_categories: Optional[List[int]] = None,
-    pretrained: Optional[bool] = True,
-):
-    """
-    Create a single model.
-
-    Parameters
-    ----------
-    model_name
-        Name of the model.
-    model_config
-        Config of the model.
-    num_classes
-        The class number for a classification task. It should be 1 for a regression task.
-    classes
-        All classes in this dataset.
-    num_numerical_columns
-        The number of numerical columns in the training dataframe.
-    num_categories
-        The category number for each categorical column in the training dataframe.
-    pretrained
-        Whether using the pretrained timm models. If pretrained=True, download the pretrained model.
-
-    Returns
-    -------
-    A model.
-    """
-    if model_name.lower().startswith(CLIP):
-        model = CLIPForImageText(
-            prefix=model_name,
-            checkpoint_name=model_config.checkpoint_name,
-            num_classes=num_classes,
-            pretrained=pretrained,
-            tokenizer_name=model_config.tokenizer_name,
-        )
-    elif model_name.lower().startswith(TIMM_IMAGE):
-        model = TimmAutoModelForImagePrediction(
-            prefix=model_name,
-            checkpoint_name=model_config.checkpoint_name,
-            num_classes=num_classes,
-            mix_choice=model_config.mix_choice,
-            pretrained=pretrained,
-        )
-    elif model_name.lower().startswith(HF_TEXT):
-        model = HFAutoModelForTextPrediction(
-            prefix=model_name,
-            checkpoint_name=model_config.checkpoint_name,
-            num_classes=num_classes,
-            pooling_mode=OmegaConf.select(model_config, "pooling_mode", default="cls"),
-            gradient_checkpointing=OmegaConf.select(model_config, "gradient_checkpointing"),
-            low_cpu_mem_usage=OmegaConf.select(model_config, "low_cpu_mem_usage", default=False),
-            pretrained=pretrained,
-            tokenizer_name=model_config.tokenizer_name,
-            use_fast=OmegaConf.select(model_config, "use_fast", default=True),
-        )
-    elif model_name.lower().startswith(T_FEW):
-        model = TFewModel(
-            prefix=model_name,
-            checkpoint_name=model_config.checkpoint_name,
-            length_norm=model_config.length_norm,  # Normalizes length to adjust for length bias in target template
-            unlikely_loss=model_config.unlikely_loss,  # Adds loss term that lowers probability of incorrect outputs
-            mc_loss=model_config.mc_loss,  # Adds multiple choice cross entropy loss
-            num_classes=num_classes,
-            gradient_checkpointing=OmegaConf.select(model_config, "gradient_checkpointing"),
-            low_cpu_mem_usage=OmegaConf.select(model_config, "low_cpu_mem_usage", default=False),
-            pretrained=pretrained,
-            tokenizer_name=model_config.tokenizer_name,
-        )
-    elif model_name.lower().startswith(NUMERICAL_MLP):
-        model = NumericalMLP(
-            prefix=model_name,
-            in_features=num_numerical_columns,
-            hidden_features=model_config.hidden_size,
-            out_features=model_config.hidden_size,
-            num_layers=model_config.num_layers,
-            activation=model_config.activation,
-            dropout_prob=model_config.drop_rate,
-            normalization=model_config.normalization,
-            d_token=OmegaConf.select(model_config, "d_token"),
-            embedding_arch=OmegaConf.select(model_config, "embedding_arch"),
-            num_classes=num_classes,
-        )
-    elif model_name.lower().startswith(CATEGORICAL_MLP):
-        model = CategoricalMLP(
-            prefix=model_name,
-            num_categories=num_categories,
-            out_features=model_config.hidden_size,
-            num_layers=model_config.num_layers,
-            activation=model_config.activation,
-            dropout_prob=model_config.drop_rate,
-            normalization=model_config.normalization,
-            num_classes=num_classes,
-        )
-    elif model_name.lower().startswith(DOCUMENT_TRANSFORMER):
-        model = DocumentTransformer(
-            prefix=model_name,
-            checkpoint_name=model_config.checkpoint_name,
-            num_classes=num_classes,
-            pooling_mode=OmegaConf.select(model_config, "pooling_mode", default="cls"),
-            gradient_checkpointing=OmegaConf.select(model_config, "gradient_checkpointing"),
-            low_cpu_mem_usage=OmegaConf.select(model_config, "low_cpu_mem_usage", default=False),
-            pretrained=pretrained,
-            tokenizer_name=model_config.tokenizer_name,
-        )
-    elif model_name.lower().startswith(MMDET_IMAGE):
-        model = MMDetAutoModelForObjectDetection(
-            prefix=model_name,
-            checkpoint_name=model_config.checkpoint_name,
-            config_file=OmegaConf.select(model_config, "config_file", default=None),
-            classes=classes,
-            pretrained=pretrained,
-            output_bbox_format=OmegaConf.select(model_config, "output_bbox_format", default=XYXY),
-            frozen_layers=OmegaConf.select(model_config, "frozen_layers", default=None),
-        )
-    elif model_name.lower().startswith(MMOCR_TEXT_DET):
-        model = MMOCRAutoModelForTextDetection(
-            prefix=model_name,
-            checkpoint_name=model_config.checkpoint_name,
-        )
-    elif model_name.lower().startswith(MMOCR_TEXT_RECOG):
-        model = MMOCRAutoModelForTextRecognition(
-            prefix=model_name,
-            checkpoint_name=model_config.checkpoint_name,
-        )
-    elif model_name.lower().startswith(NER_TEXT):
-        model = HFAutoModelForNER(
-            prefix=model_name,
-            checkpoint_name=model_config.checkpoint_name,
-            num_classes=num_classes,
-            gradient_checkpointing=OmegaConf.select(model_config, "gradient_checkpointing"),
-            low_cpu_mem_usage=OmegaConf.select(model_config, "low_cpu_mem_usage", default=False),
-            pretrained=pretrained,
-            tokenizer_name=model_config.tokenizer_name,
-        )
-    elif model_name.lower().startswith(FUSION_MLP):
-        model = functools.partial(
-            MultimodalFusionMLP,
-            prefix=model_name,
-            hidden_features=model_config.hidden_sizes,
-            num_classes=num_classes,
-            adapt_in_features=model_config.adapt_in_features,
-            activation=model_config.activation,
-            dropout_prob=model_config.drop_rate,
-            normalization=model_config.normalization,
-            loss_weight=model_config.weight if hasattr(model_config, "weight") else None,
-        )
-    elif model_name.lower().startswith(FUSION_NER):
-        model = functools.partial(
-            MultimodalFusionNER,
-            prefix=model_name,
-            hidden_features=model_config.hidden_sizes,
-            num_classes=num_classes,
-            adapt_in_features=model_config.adapt_in_features,
-            activation=model_config.activation,
-            dropout_prob=model_config.drop_rate,
-            normalization=model_config.normalization,
-            loss_weight=model_config.weight if hasattr(model_config, "weight") else None,
-        )
-    elif model_name.lower().startswith(FUSION_TRANSFORMER):
-        model = functools.partial(
-            MultimodalFusionTransformer,
-            prefix=model_name,
-            hidden_features=model_config.hidden_size,
-            num_classes=num_classes,
-            n_blocks=model_config.n_blocks,
-            attention_n_heads=model_config.attention_n_heads,
-            ffn_d_hidden=model_config.ffn_d_hidden,
-            attention_dropout=model_config.attention_dropout,
-            residual_dropout=model_config.residual_dropout,
-            ffn_dropout=model_config.ffn_dropout,
-            attention_normalization=model_config.normalization,
-            ffn_normalization=model_config.normalization,
-            head_normalization=model_config.normalization,
-            ffn_activation=model_config.ffn_activation,
-            head_activation=model_config.head_activation,
-            adapt_in_features=model_config.adapt_in_features,
-            loss_weight=model_config.weight if hasattr(model_config, "weight") else None,
-            additive_attention=OmegaConf.select(model_config, "additive_attention", default=False),
-            share_qv_weights=OmegaConf.select(model_config, "share_qv_weights", default=False),
-        )
-    elif model_name.lower().startswith(FT_TRANSFORMER):
-        model = FT_Transformer(
-            prefix=model_name,
-            num_numerical_columns=num_numerical_columns,
-            num_categories=num_categories,
-            embedding_arch=model_config.embedding_arch,
-            token_dim=model_config.token_dim,
-            hidden_size=model_config.hidden_size,
-            hidden_features=model_config.hidden_size,
-            num_classes=num_classes,
-            num_blocks=model_config.num_blocks,
-            attention_n_heads=model_config.attention_n_heads,
-            attention_dropout=model_config.attention_dropout,
-            attention_normalization=model_config.normalization,
-            ffn_hidden_size=model_config.ffn_hidden_size,
-            ffn_dropout=model_config.ffn_dropout,
-            ffn_normalization=model_config.normalization,
-            ffn_activation=model_config.ffn_activation,
-            residual_dropout=model_config.residual_dropout,
-            head_normalization=model_config.normalization,
-            head_activation=model_config.head_activation,
-            additive_attention=OmegaConf.select(model_config, "additive_attention", default=False),
-            share_qv_weights=OmegaConf.select(model_config, "share_qv_weights", default=False),
-            pooling_mode=OmegaConf.select(model_config, "pooling_mode", default="cls"),
-            checkpoint_name=model_config.checkpoint_name,
-            pretrained=pretrained,
-        )
-    elif model_name.lower().startswith(SAM):
-        model = SAMForSemanticSegmentation(
-            prefix=model_name,
-            checkpoint_name=model_config.checkpoint_name,
-            num_classes=num_classes,
-            pretrained=pretrained,
-            frozen_layers=OmegaConf.select(model_config, "frozen_layers", default=None),
-            num_mask_tokens=OmegaConf.select(model_config, "num_mask_tokens", default=1),
-        )
-    else:
-        raise ValueError(f"unknown model name: {model_name}")
-
-    return model
-
-
-def create_fusion_model(
-    config: DictConfig,
-    num_classes: Optional[int] = None,
-    classes: Optional[list] = None,
-    num_numerical_columns: Optional[int] = None,
-    num_categories: Optional[List[int]] = None,
-    pretrained: Optional[bool] = True,
-):
-    """
-    Create models. It supports the auto models of huggingface text and timm image.
-    Multimodal models, e.g., CLIP, should be added case-by-case since their configs and usages
-    may be different. It uses MLP for the numerical features, categorical features, and late-fusion.
-
-    Parameters
-    ----------
-    config
-        A DictConfig object. The model config should be accessible by "config.model".
-    num_classes
-        The class number for a classification task. It should be 1 for a regression task.
-    classes
-        All classes in this dataset.
-    num_numerical_columns
-        The number of numerical columns in the training dataframe.
-    num_categories
-        The category number for each categorical column in the training dataframe.
-    pretrained
-        Whether using the pretrained timm models. If pretrained=True, download the pretrained model.
-
-    Returns
-    -------
-    A Pytorch model.
-    """
-    names = config.model.names
-    if isinstance(names, str):
-        names = [names]
-    # make sure no duplicate model names
-    assert len(names) == len(set(names))
-    logger.debug(f"output_shape: {num_classes}")
-    names = sorted(names)
-    config.model.names = names
-    single_models = []
-    fusion_model = None
-
-    for model_name in names:
-        model_config = getattr(config.model, model_name)
-        model = create_model(
-            model_name=model_name,
-            model_config=model_config,
-            num_classes=num_classes,
-            classes=classes,
-            num_numerical_columns=num_numerical_columns,
-            num_categories=num_categories,
-            pretrained=pretrained,
-        )
-
-        if isinstance(model, functools.partial):  # fusion model
-            if fusion_model is None:
-                fusion_model = model
-            else:
-                raise ValueError(
-                    f"More than one fusion models are detected in {names}. Only one fusion model is allowed."
-                )
-        else:  # single model
-            if (
-                OmegaConf.select(config, "optimization.efficient_finetune") is not None
-                and OmegaConf.select(config, "optimization.efficient_finetune") != "None"
-            ):
-                model = apply_model_adaptation(model, config)
-            single_models.append(model)
-
-    if len(single_models) > 1:
-        # must have one fusion model if there are multiple independent models
-        return fusion_model(models=single_models)
-    elif len(single_models) == 1:
-        return single_models[0]
-    else:
-        raise ValueError(f"No available models for {names}")
-
-
-def apply_model_adaptation(model: nn.Module, config: DictConfig) -> nn.Module:
-    """
-    Apply an adaptation to the model for efficient fine-tuning.
-
-    Parameters
-    ----------
-    model
-        A PyTorch model.
-    config:
-        A DictConfig object. The optimization config should be accessible by "config.optimization".
-    """
-    if OmegaConf.select(config, "optimization.efficient_finetune") in PEFT_ADDITIVE_STRATEGIES:
-        model = inject_adaptation_to_linear_layer(
-            model=model,
-            efficient_finetune=OmegaConf.select(config, "optimization.efficient_finetune"),
-            lora_r=config.optimization.lora.r,
-            lora_alpha=config.optimization.lora.alpha,
-            module_filter=config.optimization.lora.module_filter,
-            filter=config.optimization.lora.filter,
-            extra_trainable_params=OmegaConf.select(config, "optimization.extra_trainable_params"),
-            conv_lora_expert_num=config.optimization.lora.conv_lora_expert_num,
-        )
-        model.name_to_id = model.get_layer_ids()  # Need to update name to id dictionary.
-
-    return model
-
-
-def modify_duplicate_model_names(
-    learner,
-    postfix: str,
-    blacklist: List[str],
-):
-    """
-    Modify a learner's model names if they exist in a blacklist.
-
-    Parameters
-    ----------
-    learner
-        A BaseLearner object.
-    postfix
-        The postfix used to change the duplicate names.
-    blacklist
-        A list of names. The provided learner can't use model names in the list.
-
-    Returns
-    -------
-    The learner guaranteed has no duplicate model names with the blacklist names.
-    """
-    model_names = []
-    for n in learner._config.model.names:
-        if n in blacklist:
-            new_name = f"{n}_{postfix}"
-            assert new_name not in blacklist
-            assert new_name not in learner._config.model.names
-            # modify model prefix
-            if n == learner._model.prefix:
-                learner._model.prefix = new_name
-            else:
-                assert isinstance(learner._model.model, nn.ModuleList)
-                for per_model in learner._model.model:
-                    if n == per_model.prefix:
-                        per_model.prefix = new_name
-                        break
-            # modify data processor prefix
-            for per_modality_processors in learner._data_processors.values():
-                for per_processor in per_modality_processors:
-                    if n == per_processor.prefix:
-                        per_processor.prefix = new_name
-            # modify model config keys
-            setattr(learner._config.model, new_name, getattr(learner._config.model, n))
-            delattr(learner._config.model, n)
-
-            model_names.append(new_name)
-        else:
-            model_names.append(n)
-
-    learner._config.model.names = model_names
-
-    return learner
-
-
-def list_timm_models(pretrained=True):
-    return timm.list_models(pretrained=pretrained)
-
-
-def is_lazy_weight_tensor(p: Tensor) -> bool:
-    from torch.nn.parameter import UninitializedParameter
-
-    if isinstance(p, UninitializedParameter):
-        warnings.warn(
-            "A layer with UninitializedParameter was found. "
-            "Thus, the total number of parameters detected may be inaccurate."
-        )
-        return True
-    return False
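The removed `utils/model.py` above centralized model construction: `create_model()` returns a ready sub-model for most names, but for fusion models it returns a `functools.partial` of the fusion class, which `create_fusion_model()` only calls once every single-modality model has been built. A minimal self-contained sketch of that deferred-construction pattern (class and names here are illustrative stand-ins, not AutoGluon API):

    import functools
    from torch import nn

    class FusionMLP(nn.Module):
        # stand-in for MultimodalFusionMLP: wraps a list of sub-models
        def __init__(self, models, hidden_features=64):
            super().__init__()
            self.models = nn.ModuleList(models)
            self.head = nn.LazyLinear(hidden_features)

    def build(name):
        if name == "fusion_mlp":
            # the models= argument is unknown at this point, so defer construction
            return functools.partial(FusionMLP, hidden_features=64)
        return nn.Linear(8, 16)  # stand-in for a single-modality model

    single_models, fusion_ctor = [], None
    for name in ["text", "image", "fusion_mlp"]:
        out = build(name)
        if isinstance(out, functools.partial):
            fusion_ctor = out
        else:
            single_models.append(out)

    # mirror the tail of create_fusion_model(): fuse only when multiple sub-models exist
    model = fusion_ctor(models=single_models) if len(single_models) > 1 else single_models[0]

In 1.2.1b20250304 these helpers no longer live at `autogluon/multimodal/utils/model.py`; given the `models/utils.py +877 -16` entry in the listing above, they most likely moved into `autogluon/multimodal/models/utils.py`, but that is an inference from the line counts, not something this diff states.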