autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
@@ -0,0 +1,336 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Dict, Optional
|
3
|
+
|
4
|
+
import torch
|
5
|
+
from timm import create_model
|
6
|
+
from timm.models.vision_transformer import Block
|
7
|
+
from torch import nn
|
8
|
+
|
9
|
+
from ..constants import (
|
10
|
+
CATEGORICAL,
|
11
|
+
FEATURES,
|
12
|
+
IMAGE,
|
13
|
+
IMAGE_VALID_NUM,
|
14
|
+
LABEL,
|
15
|
+
LOGITS,
|
16
|
+
NUMERICAL,
|
17
|
+
TEXT_SEGMENT_IDS,
|
18
|
+
TEXT_TOKEN_IDS,
|
19
|
+
TEXT_VALID_LENGTH,
|
20
|
+
)
|
21
|
+
from .custom_transformer import CLSToken
|
22
|
+
from .ft_transformer import CategoricalFeatureTokenizer, NumEmbeddings
|
23
|
+
from .utils import (
|
24
|
+
assign_layer_ids,
|
25
|
+
get_hf_config_and_model,
|
26
|
+
get_image_size_mean_std,
|
27
|
+
get_pretrained_tokenizer,
|
28
|
+
get_text_segment_num,
|
29
|
+
get_text_token_max_len,
|
30
|
+
init_weights,
|
31
|
+
replace_missing_images_with_learnable,
|
32
|
+
)
|
33
|
+
|
34
|
+
logger = logging.getLogger(__name__)
|
35
|
+
|
36
|
+
|
37
|
+
class MetaTransformer(nn.Module):
    """
    Meta-Transformer style multimodal backbone.

    Every enabled modality (image, text, categorical, numerical) is turned into a token
    sequence by a modality-specific tokenizer and projected to a shared width ``dim`` by a
    linear adaptor. The sequences are concatenated along the sequence axis, a CLS token is
    added by ``CLSToken``, and the result is encoded by a stack of timm ``Block`` layers
    whose weights are loaded from ``checkpoint_path``. The output at the CLS position feeds
    a linear head.
    """

    # model_version -> (token width, number of attention heads, number of transformer blocks)
    _VERSION_SPECS = {
        "base": (768, 12, 12),
        "large": (1024, 16, 24),
    }

    def __init__(
        self,
        prefix: str,
        num_classes: int,
        checkpoint_path: str,
        model_version: str,
        has_image: bool,
        has_text: bool,
        num_numerical_columns: int,
        num_categories: Dict,
        numerical_fill_values: Dict,
        image_size: Optional[int] = None,
        image_norm: Optional[str] = None,
        image_chan_num: Optional[int] = 3,
        use_learnable_image: Optional[bool] = False,
        max_text_len: Optional[int] = None,
        text_segment_num: Optional[int] = 1,
    ):
        """
        Parameters
        ----------
        prefix
            The model prefix; it is used to build the batch keys this module reads.
        num_classes
            Output size of the linear head. A falsy value replaces the head with ``nn.Identity``.
        checkpoint_path
            Path to a state dict for the transformer block stack, loaded with ``strict=True``.
            NOTE: ``torch.load`` unpickles the file, so only pass trusted checkpoints.
        model_version
            ``"base"`` (width 768, 12 heads, 12 blocks) or ``"large"`` (width 1024, 16 heads,
            24 blocks). Any other value raises ``ValueError``.
        has_image
            Whether to build the image tokenizer, i.e. the patch embedding of
            ``timm/vit_base_patch16_224.mae``.
        has_text
            Whether to build the text tokenizer, i.e. the embeddings of
            ``microsoft/deberta-v3-base``.
        num_numerical_columns
            Number of numerical columns. The numerical tokenizer is built only when this is > 0.
        num_categories
            Mapping whose values are the category counts of the categorical columns. The
            categorical tokenizer is built only when it is non-empty.
        numerical_fill_values
            Stored on the module when numerical columns exist. This class does not use it directly.
        image_size
            Requested image size, forwarded to ``get_image_size_mean_std``.
        image_norm
            Requested image normalization type, forwarded to ``get_image_size_mean_std``.
        image_chan_num
            Number of channels of the learnable placeholder image.
        use_learnable_image
            Whether missing images are replaced by a learnable image.
        max_text_len
            Requested maximum text length, forwarded to ``get_text_token_max_len``.
        text_segment_num
            Requested number of text segments, forwarded to ``get_text_segment_num``.
        """
        super().__init__()
        logger.debug(f"initializing {prefix} (MetaTransformer)")
        self.prefix = prefix
        self.checkpoint_name = checkpoint_path

        if model_version not in self._VERSION_SPECS:
            raise ValueError(f"Unsupported model version: {model_version}. Options are 'base' and 'large'.")
        dim, num_heads, layer_num = self._VERSION_SPECS[model_version]

        self.model = nn.Sequential(
            *[
                Block(
                    dim=dim,
                    num_heads=num_heads,
                    mlp_ratio=4.0,
                    qkv_bias=True,
                    norm_layer=nn.LayerNorm,
                    act_layer=nn.GELU,
                )
                for _ in range(layer_num)
            ]
        )

        # Map storages to CPU so checkpoints saved from a GPU also load on CPU-only hosts.
        # load_state_dict then copies the values into this module's parameters.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")  # nosec B614
        self.checkpoint_path = checkpoint_path
        self.model.load_state_dict(checkpoint, strict=True)

        self.head = nn.Linear(dim, num_classes) if num_classes else nn.Identity()

        self.cls_token = CLSToken(token_dim=dim, initialization="uniform")
        self.config = None

        # Modality-specific modules. Each stays None unless its modality is enabled below.
        # `image_tokenizer` is kept for attribute compatibility; images are tokenized by
        # `self.patch_embed`, and the image path is gated on `image_adaptor`.
        self.tokenizer = None
        self.text_adaptor = None
        self.image_tokenizer = None
        self.image_adaptor = None
        self.categorical_feature_tokenizer = None
        self.categorical_adapter = None
        self.numerical_feature_tokenizer = None
        self.numerical_adapter = None

        if has_text:
            checkpoint_name = "microsoft/deberta-v3-base"
            _, text_model = get_hf_config_and_model(checkpoint_name=checkpoint_name, pretrained=True)
            self.text_config = text_model.config
            # refer to https://github.com/invictus717/MetaTransformer/blob/master/Data2Seq/Data2Seq.py#L28
            self.tokenizer = get_pretrained_tokenizer(
                tokenizer_name="hf_auto",
                checkpoint_name=checkpoint_name,
            )
            self.text_embed = text_model.embeddings
            self.text_adaptor = nn.Linear(self.text_config.hidden_size, dim)
            self.tokenizer_name = "hf_auto"
            self.max_text_len = get_text_token_max_len(
                provided_max_len=max_text_len,
                config=self.text_config,
                tokenizer=self.tokenizer,
                checkpoint_name=checkpoint_name,
            )
            self.text_segment_num = get_text_segment_num(
                config=self.text_config,
                provided_segment_num=text_segment_num,
                checkpoint_name=checkpoint_name,
            )
        if has_image:
            image_model = create_model("timm/vit_base_patch16_224.mae", pretrained=True)
            self.image_config = image_model.default_cfg
            self.patch_embed = image_model.patch_embed
            self.image_adaptor = nn.Linear(image_model.embed_dim, dim)
            self.image_size, self.image_mean, self.image_std = get_image_size_mean_std(
                model_name=self.prefix,
                config=self.image_config,
                provided_size=image_size,
                provided_norm_type=image_norm,
                support_variable_input_size=False,
            )
            self.use_learnable_image = use_learnable_image
            if self.use_learnable_image:
                self.learnable_image = nn.Parameter(torch.zeros(image_chan_num, self.image_size, self.image_size))
                logger.debug("will use a learnable image to replace missing ones")

        if num_categories:
            self.num_categories = num_categories
            self.categorical_feature_tokenizer = CategoricalFeatureTokenizer(
                num_categories=list(num_categories.values()),
                token_dim=dim,
                bias=True,
                initialization="normal",
            )
            self.categorical_adapter = nn.Linear(dim, dim)

        if num_numerical_columns > 0:
            self.num_numerical_columns = num_numerical_columns
            self.numerical_fill_values = numerical_fill_values
            self.numerical_feature_tokenizer = NumEmbeddings(
                in_features=num_numerical_columns,
                d_embedding=dim,
                embedding_arch=["linear"],
            )
            self.numerical_adapter = nn.Linear(dim, dim)

        self.out_features = dim

        # init weights
        self.head.apply(init_weights)
        self.name_to_id = self.get_layer_ids()
        self.head_layer_names = [n for n, layer_id in self.name_to_id.items() if layer_id == 0]

    @property
    def text_token_ids_key(self):
        return f"{self.prefix}_{TEXT_TOKEN_IDS}"

    @property
    def text_valid_length_key(self):
        return f"{self.prefix}_{TEXT_VALID_LENGTH}"

    @property
    def text_segment_ids_key(self):
        return f"{self.prefix}_{TEXT_SEGMENT_IDS}"

    @property
    def image_key(self):
        return f"{self.prefix}_{IMAGE}"

    @property
    def image_valid_num_key(self):
        return f"{self.prefix}_{IMAGE_VALID_NUM}"

    @property
    def categorical_key(self):
        return f"{self.prefix}_{CATEGORICAL}"

    @property
    def numerical_key(self):
        return f"{self.prefix}_{NUMERICAL}"

    @property
    def label_key(self):
        return f"{self.prefix}_{LABEL}"

    def _image_tokens(self, batch: dict) -> torch.Tensor:
        """Tokenize the images in ``batch`` into a ``(b, n * l, dim)`` token sequence."""
        images = batch[self.image_key]
        image_valid_num = batch[self.image_valid_num_key]
        b, n, c, h, w = images.shape
        if self.use_learnable_image:
            steps = torch.arange(0, n).type_as(image_valid_num)
            image_masks = steps.reshape((1, -1)) < image_valid_num.reshape((-1, 1))  # (b, n)
            images = replace_missing_images_with_learnable(
                images=images,
                image_masks=image_masks,
                learnable_image=self.learnable_image,
            )
        # NOTE(review): the timm patch embedding may enforce its configured input size; confirm
        # the image processor resizes to a size it accepts.
        image_embeddings = self.patch_embed(images.reshape((b * n, c, h, w)))  # (b*n, l, d)
        assert image_embeddings.ndim == 3
        image_embeddings = self.image_adaptor(image_embeddings)  # (b*n, l, dim)
        # Fold the n images of each sample into the sequence axis. This keeps the batch axis
        # equal to b, so the tokens can be concatenated with the other modalities on dim=1.
        # NOTE(review): tokens of padded image slots are not masked out when no learnable
        # image is used — confirm that is acceptable.
        return image_embeddings.reshape((b, -1, image_embeddings.shape[-1]))

    def _text_tokens(self, batch: dict) -> torch.Tensor:
        """Embed the pre-tokenized text ids in ``batch`` into a ``(b, l, dim)`` token sequence."""
        text_token_ids = batch[self.text_token_ids_key]
        text_valid_length = batch[self.text_valid_length_key]
        steps = torch.arange(0, text_token_ids.shape[1]).type_as(text_valid_length)
        text_masks = (steps.reshape((1, -1)) < text_valid_length.reshape((-1, 1))).type_as(text_token_ids)
        if "token_type_ids" in self.tokenizer.model_input_names:
            token_type_ids = batch[self.text_segment_ids_key]
        else:
            token_type_ids = torch.zeros(text_token_ids.size(), dtype=torch.long, device=text_token_ids.device)

        text_embeddings = self.text_embed(
            input_ids=text_token_ids,
            token_type_ids=token_type_ids,
            position_ids=None,
            mask=text_masks,
            inputs_embeds=None,
        )
        text_embeddings = self.text_adaptor(text_embeddings)
        assert text_embeddings.ndim == 3
        return text_embeddings

    def _categorical_tokens(self, batch: dict) -> torch.Tensor:
        """Tokenize the per-column categorical inputs in ``batch``; one token per column."""
        # Stack the per-column entries of the batch along dim=1 before tokenizing.
        categorical_inputs = torch.stack(list(batch[self.categorical_key]), dim=1)
        categorical_features = self.categorical_feature_tokenizer(categorical_inputs)
        categorical_features = self.categorical_adapter(categorical_features)  # (b, l, d)
        assert categorical_features.ndim == 3
        return categorical_features

    def _numerical_tokens(self, batch: dict) -> torch.Tensor:
        """Tokenize the numerical inputs in ``batch``."""
        numerical_features = self.numerical_feature_tokenizer(batch[self.numerical_key])
        numerical_features = self.numerical_adapter(numerical_features)
        assert numerical_features.ndim == 3
        return numerical_features

    def forward(
        self,
        batch: dict,
    ):
        """
        Encode all enabled modalities of ``batch`` jointly.

        Parameters
        ----------
        batch
            A dictionary keyed by this model's prefixed input keys.

        Returns
        -------
        ``{prefix: {LOGITS: head output, FEATURES: CLS-position features}}``.

        Raises
        ------
        ValueError
            If no modality is enabled, so there is nothing to encode.
        """
        multimodal_features = []
        if self.image_adaptor is not None:
            multimodal_features.append(self._image_tokens(batch))
        if self.text_adaptor is not None:  # text tokenizer is used in text processor
            multimodal_features.append(self._text_tokens(batch))
        if self.categorical_feature_tokenizer is not None:
            multimodal_features.append(self._categorical_tokens(batch))
        if self.numerical_feature_tokenizer is not None:
            multimodal_features.append(self._numerical_tokens(batch))

        if not multimodal_features:
            raise ValueError(
                f"{self.prefix} (MetaTransformer) has no enabled modality; "
                "at least one of image, text, categorical or numerical inputs is required."
            )

        multimodal_features = torch.cat(multimodal_features, dim=1)
        multimodal_features = self.cls_token(multimodal_features)
        features = self.model(multimodal_features)
        pooled_features = features[:, -1, :]  # CLSToken append the cls token to the sequence tail
        logits = self.head(pooled_features)
        ret = {
            LOGITS: logits,
            FEATURES: pooled_features,
        }
        return {self.prefix: ret}

    def get_layer_ids(self):
        """
        Assign an id to each layer. Layer ids will be used in layer-wise lr decay.
        Basically, id gradually increases when going from the output end to
        the input end. The layers defined in this class, e.g., head, have id 0.

        In the AutoModel scenario, this function may not always return the correct result.
        Thus, you can use "print(json.dumps(name_to_id, indent=2))" to manually check whether
        the layer ids are reasonable.

        Returns
        -------
        A dictionary mapping the layer names (keys) to their ids (values).
        """
        model_prefix = "model"
        pre_encoder_patterns = (
            "embeddings",
            "LayerNorm",
            "wte",
            "wpe",
            "shared.weight",
            "encoder.conv.conv",
            "relative_attention_bias",
            "dummy_layer",
        )
        post_encoder_patterns = ("head", "pooler", "ln_f", "final_layer_norm")
        names = [n for n, _ in self.named_parameters()]

        name_to_id, names = assign_layer_ids(
            names=names,
            pre_encoder_patterns=pre_encoder_patterns,
            post_encoder_patterns=post_encoder_patterns,
            model_pre=model_prefix,
        )
        if len(names) > 0:
            logger.debug(f"outer layers are treated as head: {names}")
        # Parameters not matched by assign_layer_ids are grouped with the head (id 0).
        for n in names:
            assert n not in name_to_id
            name_to_id[n] = 0

        return name_to_id
|
@@ -51,7 +51,7 @@ class Unit(nn.Module):
|
|
51
51
|
in_features: int,
|
52
52
|
out_features: int,
|
53
53
|
activation: str,
|
54
|
-
|
54
|
+
dropout: float,
|
55
55
|
):
|
56
56
|
"""
|
57
57
|
Parameters
|
@@ -64,7 +64,7 @@ class Unit(nn.Module):
|
|
64
64
|
Dimension of output features.
|
65
65
|
activation
|
66
66
|
Name of activation function.
|
67
|
-
|
67
|
+
dropout
|
68
68
|
Dropout probability.
|
69
69
|
"""
|
70
70
|
super().__init__()
|
@@ -78,7 +78,7 @@ class Unit(nn.Module):
|
|
78
78
|
raise ValueError(f"unknown normalization: {normalization}")
|
79
79
|
self.fc = nn.Linear(in_features, out_features)
|
80
80
|
self.act_fn = ALL_ACT_LAYERS[activation]()
|
81
|
-
self.dropout = nn.Dropout(
|
81
|
+
self.dropout = nn.Dropout(dropout)
|
82
82
|
|
83
83
|
def forward(self, x):
|
84
84
|
# pre normalization
|
@@ -102,7 +102,7 @@ class MLP(nn.Module):
|
|
102
102
|
out_features: Optional[int] = None,
|
103
103
|
num_layers: Optional[int] = 1,
|
104
104
|
activation: Optional[str] = "gelu",
|
105
|
-
|
105
|
+
dropout: Optional[float] = 0.5,
|
106
106
|
normalization: Optional[str] = "layer_norm",
|
107
107
|
):
|
108
108
|
"""
|
@@ -118,7 +118,7 @@ class MLP(nn.Module):
|
|
118
118
|
Number of layers.
|
119
119
|
activation
|
120
120
|
Name of activation function.
|
121
|
-
|
121
|
+
dropout
|
122
122
|
Dropout probability.
|
123
123
|
normalization
|
124
124
|
Name of normalization function.
|
@@ -134,7 +134,7 @@ class MLP(nn.Module):
|
|
134
134
|
in_features=in_features,
|
135
135
|
out_features=hidden_features,
|
136
136
|
activation=activation,
|
137
|
-
|
137
|
+
dropout=dropout,
|
138
138
|
)
|
139
139
|
in_features = hidden_features
|
140
140
|
layers.append(per_unit)
|
@@ -11,7 +11,7 @@ except ImportError:
|
|
11
11
|
mmocr = None
|
12
12
|
from torch import nn
|
13
13
|
|
14
|
-
from ..constants import
|
14
|
+
from ..constants import BBOX, COLUMN, COLUMN_FEATURES, FEATURES, IMAGE, IMAGE_VALID_NUM, LABEL, LOGITS, MASKS
|
15
15
|
from .utils import assign_layer_ids, get_column_features, get_mmocr_config_and_model, get_model_head
|
16
16
|
|
17
17
|
logger = logging.getLogger(__name__)
|
@@ -3,25 +3,18 @@ from typing import Dict, List, Optional, Tuple
|
|
3
3
|
|
4
4
|
import torch
|
5
5
|
import torch.nn.functional as F
|
6
|
-
from torch import nn
|
7
6
|
from transformers import logging as hf_logging
|
8
7
|
|
9
8
|
from ..constants import (
|
10
|
-
AUTOMM,
|
11
|
-
COLUMN,
|
12
9
|
COLUMN_FEATURES,
|
13
10
|
FEATURES,
|
14
|
-
LABEL,
|
15
11
|
LOGITS,
|
16
12
|
MASKS,
|
17
13
|
NER_ANNOTATION,
|
18
|
-
TEXT_SEGMENT_IDS,
|
19
|
-
TEXT_TOKEN_IDS,
|
20
|
-
TEXT_VALID_LENGTH,
|
21
14
|
TOKEN_WORD_MAPPING,
|
22
15
|
WORD_OFFSETS,
|
23
16
|
)
|
24
|
-
from .
|
17
|
+
from .hf_text import HFAutoModelForTextPrediction
|
25
18
|
from .utils import assign_layer_ids, get_column_features, get_pretrained_tokenizer
|
26
19
|
|
27
20
|
hf_logging.set_verbosity_error()
|
@@ -1,4 +1,5 @@
|
|
1
|
-
|
1
|
+
import logging
|
2
|
+
from typing import Dict, List, Optional
|
2
3
|
|
3
4
|
from torch import nn
|
4
5
|
|
@@ -7,6 +8,8 @@ from .ft_transformer import NumEmbeddings
|
|
7
8
|
from .mlp import MLP
|
8
9
|
from .utils import init_weights
|
9
10
|
|
11
|
+
logger = logging.getLogger(__name__)
|
12
|
+
|
10
13
|
|
11
14
|
class NumericalMLP(nn.Module):
|
12
15
|
"""
|
@@ -21,11 +24,12 @@ class NumericalMLP(nn.Module):
|
|
21
24
|
out_features: Optional[int] = None,
|
22
25
|
num_layers: Optional[int] = 1,
|
23
26
|
activation: Optional[str] = "leaky_relu",
|
24
|
-
|
27
|
+
dropout: Optional[float] = 0.5,
|
25
28
|
normalization: Optional[str] = "layer_norm",
|
26
29
|
num_classes: Optional[int] = 0,
|
27
|
-
|
30
|
+
token_dim: Optional[int] = 8,
|
28
31
|
embedding_arch: Optional[List[str]] = None,
|
32
|
+
numerical_fill_values: Optional[Dict] = None,
|
29
33
|
):
|
30
34
|
"""
|
31
35
|
Parameters
|
@@ -42,13 +46,13 @@ class NumericalMLP(nn.Module):
|
|
42
46
|
Number of MLP layers.
|
43
47
|
activation
|
44
48
|
Name of activation function.
|
45
|
-
|
49
|
+
dropout
|
46
50
|
Dropout probability.
|
47
51
|
normalization
|
48
52
|
Name of normalization function.
|
49
53
|
num_classes
|
50
54
|
Number of classes. 1 for a regression task.
|
51
|
-
|
55
|
+
token_dim
|
52
56
|
The size of one token for `NumericalEmbedding`.
|
53
57
|
embedding_arch
|
54
58
|
A list containing the names of embedding layers.
|
@@ -56,19 +60,21 @@ class NumericalMLP(nn.Module):
|
|
56
60
|
{'linear', 'shared_linear', 'autodis', 'positional', 'relu', 'layernorm'}
|
57
61
|
"""
|
58
62
|
super().__init__()
|
63
|
+
logger.debug(f"initializing {prefix} (NumericalMLP)")
|
59
64
|
self.out_features = out_features
|
65
|
+
self.numerical_fill_values = numerical_fill_values
|
60
66
|
|
61
67
|
self.numerical_feature_tokenizer = (
|
62
68
|
NumEmbeddings(
|
63
69
|
in_features=in_features,
|
64
|
-
d_embedding=
|
70
|
+
d_embedding=token_dim,
|
65
71
|
embedding_arch=embedding_arch,
|
66
72
|
)
|
67
73
|
if embedding_arch is not None
|
68
74
|
else nn.Identity()
|
69
75
|
)
|
70
76
|
|
71
|
-
in_features = in_features *
|
77
|
+
in_features = in_features * token_dim if embedding_arch is not None else in_features
|
72
78
|
|
73
79
|
self.mlp = MLP(
|
74
80
|
in_features=in_features,
|
@@ -76,7 +82,7 @@ class NumericalMLP(nn.Module):
|
|
76
82
|
out_features=out_features,
|
77
83
|
num_layers=num_layers,
|
78
84
|
activation=activation,
|
79
|
-
|
85
|
+
dropout=dropout,
|
80
86
|
normalization=normalization,
|
81
87
|
)
|
82
88
|
self.head = nn.Linear(out_features, num_classes) if num_classes > 0 else nn.Identity()
|
@@ -3,14 +3,13 @@ from typing import Dict, List, Optional, Tuple
|
|
3
3
|
|
4
4
|
import torch
|
5
5
|
import torch.nn.functional as F
|
6
|
-
from omegaconf import DictConfig
|
7
6
|
from torch import nn
|
8
7
|
from transformers import SamConfig
|
9
8
|
|
10
9
|
from ..constants import CLASS_LABEL, CLASS_LOGITS, COLUMN, IMAGE, IMAGE_VALID_NUM, LABEL, LOGITS, MASK_LABEL, MOE_LOSS
|
11
10
|
from .adaptation_layers import ConvLoRALinear
|
12
11
|
from .custom_hf_models.modeling_sam_for_conv_lora import SamImageSegmentationOutput, SamModel
|
13
|
-
from .utils import assign_layer_ids, freeze_model_layers
|
12
|
+
from .utils import assign_layer_ids, freeze_model_layers, image_mean_std
|
14
13
|
|
15
14
|
logger = logging.getLogger(__name__)
|
16
15
|
|
@@ -269,6 +268,7 @@ class SAMForSemanticSegmentation(nn.Module):
|
|
269
268
|
pretrained: Optional[bool] = True,
|
270
269
|
frozen_layers: Optional[list] = None,
|
271
270
|
num_mask_tokens: int = 1,
|
271
|
+
image_norm: Optional[str] = None,
|
272
272
|
):
|
273
273
|
"""
|
274
274
|
Load a pretrained Segment Anything Model (SAM).
|
@@ -287,6 +287,15 @@ class SAMForSemanticSegmentation(nn.Module):
|
|
287
287
|
A list of substrings of frozen layers' names.
|
288
288
|
num_mask_tokens
|
289
289
|
The number of mask proposals.
|
290
|
+
image_norm
|
291
|
+
How to normalize an image. We now support:
|
292
|
+
- inception
|
293
|
+
Normalize image by IMAGENET_INCEPTION_MEAN and IMAGENET_INCEPTION_STD from timm
|
294
|
+
- imagenet
|
295
|
+
Normalize image by IMAGENET_DEFAULT_MEAN and IMAGENET_DEFAULT_STD from timm
|
296
|
+
- clip
|
297
|
+
Normalize image by mean (0.48145466, 0.4578275, 0.40821073) and
|
298
|
+
std (0.26862954, 0.26130258, 0.27577711), used for CLIP.
|
290
299
|
"""
|
291
300
|
|
292
301
|
super().__init__()
|
@@ -305,6 +314,7 @@ class SAMForSemanticSegmentation(nn.Module):
|
|
305
314
|
|
306
315
|
self.image_size = self.model.vision_encoder.image_size
|
307
316
|
self.config = self.model.config
|
317
|
+
self.image_mean, self.image_std = image_mean_std(image_norm)
|
308
318
|
|
309
319
|
self.model.mask_decoder.num_mask_tokens = num_mask_tokens
|
310
320
|
mask_token_data = self.model.mask_decoder.mask_tokens.weight.data[0]
|
@@ -1,7 +1,4 @@
|
|
1
|
-
import collections
|
2
1
|
import logging
|
3
|
-
import os
|
4
|
-
import random
|
5
2
|
from functools import lru_cache
|
6
3
|
from typing import Dict, List, Optional, Tuple
|
7
4
|
|
@@ -12,7 +9,6 @@ from transformers import AutoConfig, AutoModelForSeq2SeqLM
|
|
12
9
|
from transformers import logging as hf_logging
|
13
10
|
|
14
11
|
from ..constants import (
|
15
|
-
AUTOMM,
|
16
12
|
CHOICES_IDS,
|
17
13
|
COLUMN,
|
18
14
|
COLUMN_FEATURES,
|
@@ -26,7 +22,14 @@ from ..constants import (
|
|
26
22
|
TEXT_TOKEN_IDS,
|
27
23
|
TEXT_VALID_LENGTH,
|
28
24
|
)
|
29
|
-
from .utils import
|
25
|
+
from .utils import (
|
26
|
+
DummyLayer,
|
27
|
+
assign_layer_ids,
|
28
|
+
get_column_features,
|
29
|
+
get_pretrained_tokenizer,
|
30
|
+
get_text_segment_num,
|
31
|
+
get_text_token_max_len,
|
32
|
+
)
|
30
33
|
|
31
34
|
hf_logging.set_verbosity_error()
|
32
35
|
|
@@ -56,6 +59,8 @@ class TFewModel(nn.Module):
|
|
56
59
|
low_cpu_mem_usage: Optional[bool] = False,
|
57
60
|
pretrained: Optional[bool] = True,
|
58
61
|
tokenizer_name: Optional[str] = "hf_auto",
|
62
|
+
max_text_len: Optional[int] = None,
|
63
|
+
text_segment_num: Optional[int] = 1,
|
59
64
|
):
|
60
65
|
"""
|
61
66
|
Load a pretrained T5-based text transformer backbone.
|
@@ -106,6 +111,17 @@ class TFewModel(nn.Module):
|
|
106
111
|
tokenizer_name=self.tokenizer_name,
|
107
112
|
checkpoint_name=self.checkpoint_name,
|
108
113
|
)
|
114
|
+
self.max_text_len = get_text_token_max_len(
|
115
|
+
provided_max_len=max_text_len,
|
116
|
+
config=self.config,
|
117
|
+
tokenizer=self.tokenizer,
|
118
|
+
checkpoint_name=self.checkpoint_name,
|
119
|
+
)
|
120
|
+
self.text_segment_num = get_text_segment_num(
|
121
|
+
config=self.config,
|
122
|
+
provided_segment_num=text_segment_num,
|
123
|
+
checkpoint_name=self.checkpoint_name,
|
124
|
+
)
|
109
125
|
self.eos_token = self.tokenizer.eos_token
|
110
126
|
self.out_features = (
|
111
127
|
self.model.config.hidden_size
|