autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
@@ -0,0 +1,336 @@
1
+ import logging
2
+ from typing import Dict, Optional
3
+
4
+ import torch
5
+ from timm import create_model
6
+ from timm.models.vision_transformer import Block
7
+ from torch import nn
8
+
9
+ from ..constants import (
10
+ CATEGORICAL,
11
+ FEATURES,
12
+ IMAGE,
13
+ IMAGE_VALID_NUM,
14
+ LABEL,
15
+ LOGITS,
16
+ NUMERICAL,
17
+ TEXT_SEGMENT_IDS,
18
+ TEXT_TOKEN_IDS,
19
+ TEXT_VALID_LENGTH,
20
+ )
21
+ from .custom_transformer import CLSToken
22
+ from .ft_transformer import CategoricalFeatureTokenizer, NumEmbeddings
23
+ from .utils import (
24
+ assign_layer_ids,
25
+ get_hf_config_and_model,
26
+ get_image_size_mean_std,
27
+ get_pretrained_tokenizer,
28
+ get_text_segment_num,
29
+ get_text_token_max_len,
30
+ init_weights,
31
+ replace_missing_images_with_learnable,
32
+ )
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
class MetaTransformer(nn.Module):
    """
    Multimodal backbone that tokenizes each available modality (image, text, categorical,
    numerical), projects the tokens to a common width, concatenates them along the sequence
    axis, and runs them through one shared stack of timm ViT `Block`s whose weights are
    loaded from a Meta-Transformer checkpoint.
    Refer to https://github.com/invictus717/MetaTransformer
    """

    def __init__(
        self,
        prefix: str,
        num_classes: int,
        checkpoint_path: str,
        model_version: str,
        has_image: bool,
        has_text: bool,
        num_numerical_columns: int,
        num_categories: Dict,
        numerical_fill_values: Dict,
        image_size: Optional[int] = None,
        image_norm: Optional[str] = None,
        image_chan_num: Optional[int] = 3,
        use_learnable_image: Optional[bool] = False,
        max_text_len: Optional[int] = None,
        text_segment_num: Optional[int] = 1,
    ):
        """
        Parameters
        ----------
        prefix
            The model prefix. It namespaces the batch keys read in `forward` and
            keys the dictionary that `forward` returns.
        num_classes
            Output size of the linear head. A falsy value (0 or None) replaces the head
            with an identity layer.
        checkpoint_path
            Path of the encoder checkpoint. It is read with `torch.load` and loaded
            strictly into the block stack.
        model_version
            Either "base" (width 768, 12 heads, 12 blocks) or "large" (width 1024,
            16 heads, 24 blocks). Any other value raises a ValueError.
        has_image
            Whether to build the image branch (patch embedding + linear adaptor).
        has_text
            Whether to build the text branch (DeBERTa-v3 embeddings + linear adaptor).
        num_numerical_columns
            Number of numerical columns. The numerical branch is built when it is > 0.
        num_categories
            A dictionary whose values are the per-column category counts. The categorical
            branch is built when it is non-empty.
        numerical_fill_values
            Stored on the instance when the numerical branch is built; it is not used
            inside this class.
        image_size
            Image size forwarded to `get_image_size_mean_std`.
        image_norm
            Image normalization type forwarded to `get_image_size_mean_std`.
        image_chan_num
            Number of channels of the learnable placeholder image.
        use_learnable_image
            Whether to replace missing images with a learnable image.
        max_text_len
            Maximum text length forwarded to `get_text_token_max_len`.
        text_segment_num
            Number of text segments forwarded to `get_text_segment_num`.
        """
        super().__init__()
        logger.debug(f"initializing {prefix} (MetaTransformer)")
        self.prefix = prefix
        self.checkpoint_name = checkpoint_path

        if model_version == "base":
            dim, num_heads, layer_num = 768, 12, 12
        elif model_version == "large":
            dim, num_heads, layer_num = 1024, 16, 24
        else:
            raise ValueError(f"Unsupported model version: {model_version}. Options are 'base' and 'large'.")

        self.model = nn.Sequential(
            *[
                Block(
                    dim=dim,
                    num_heads=num_heads,
                    mlp_ratio=4.0,
                    qkv_bias=True,
                    norm_layer=nn.LayerNorm,
                    act_layer=nn.GELU,
                )
                for _ in range(layer_num)
            ]
        )

        # Deserialize onto CPU so that a checkpoint saved from a GPU also loads on CPU-only
        # hosts; `load_state_dict` copies the values into the already-allocated parameters.
        # NOTE(review): torch.load unpickles the file - only point this at trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")  # nosec B614
        self.checkpoint_path = checkpoint_path
        self.model.load_state_dict(checkpoint, strict=True)

        self.head = nn.Linear(dim, num_classes) if num_classes else nn.Identity()

        self.cls_token = CLSToken(token_dim=dim, initialization="uniform")
        self.config = None

        # Per-modality components stay None unless the matching branch is built below.
        # `forward` uses the adaptors / feature tokenizers to decide which branches run.
        self.tokenizer = None
        self.text_adaptor = None
        self.image_tokenizer = None
        self.image_adaptor = None
        self.categorical_feature_tokenizer = None
        self.categorical_adapter = None
        self.numerical_feature_tokenizer = None
        self.numerical_adapter = None
        # Always defined so the attribute exists even when no image branch is built.
        self.use_learnable_image = False

        if has_text:
            checkpoint_name = "microsoft/deberta-v3-base"
            _, text_model = get_hf_config_and_model(checkpoint_name=checkpoint_name, pretrained=True)
            self.text_config = text_model.config
            # refer to https://github.com/invictus717/MetaTransformer/blob/master/Data2Seq/Data2Seq.py#L28
            self.tokenizer = get_pretrained_tokenizer(
                tokenizer_name="hf_auto",
                checkpoint_name=checkpoint_name,
            )
            # Only the embedding layer of the text model is kept.
            self.text_embed = text_model.embeddings
            self.text_adaptor = nn.Linear(self.text_config.hidden_size, dim)
            self.tokenizer_name = "hf_auto"
            self.max_text_len = get_text_token_max_len(
                provided_max_len=max_text_len,
                config=self.text_config,
                tokenizer=self.tokenizer,
                checkpoint_name=checkpoint_name,
            )
            self.text_segment_num = get_text_segment_num(
                config=self.text_config,
                provided_segment_num=text_segment_num,
                checkpoint_name=checkpoint_name,
            )
        if has_image:
            image_model = create_model("timm/vit_base_patch16_224.mae", pretrained=True)
            self.image_config = image_model.default_cfg
            # Only the patch embedding of the image model is kept.
            self.patch_embed = image_model.patch_embed
            self.image_adaptor = nn.Linear(image_model.embed_dim, dim)
            self.image_size, self.image_mean, self.image_std = get_image_size_mean_std(
                model_name=self.prefix,
                config=self.image_config,
                provided_size=image_size,
                provided_norm_type=image_norm,
                support_variable_input_size=False,
            )
            self.use_learnable_image = use_learnable_image
            if self.use_learnable_image:
                self.learnable_image = nn.Parameter(torch.zeros(image_chan_num, self.image_size, self.image_size))
                logger.debug("will use a learnable image to replace missing ones")

        if num_categories:
            self.num_categories = num_categories
            self.categorical_feature_tokenizer = CategoricalFeatureTokenizer(
                num_categories=list(num_categories.values()),
                token_dim=dim,
                bias=True,
                initialization="normal",
            )
            self.categorical_adapter = nn.Linear(dim, dim)

        if num_numerical_columns > 0:
            self.num_numerical_columns = num_numerical_columns
            self.numerical_fill_values = numerical_fill_values
            self.numerical_feature_tokenizer = NumEmbeddings(
                in_features=num_numerical_columns,
                d_embedding=dim,
                embedding_arch=["linear"],
            )
            self.numerical_adapter = nn.Linear(dim, dim)

        self.out_features = dim

        # init weights
        self.head.apply(init_weights)
        self.name_to_id = self.get_layer_ids()
        self.head_layer_names = [n for n, layer_id in self.name_to_id.items() if layer_id == 0]

    @property
    def text_token_ids_key(self):
        return f"{self.prefix}_{TEXT_TOKEN_IDS}"

    @property
    def text_valid_length_key(self):
        return f"{self.prefix}_{TEXT_VALID_LENGTH}"

    @property
    def text_segment_ids_key(self):
        return f"{self.prefix}_{TEXT_SEGMENT_IDS}"

    @property
    def image_key(self):
        return f"{self.prefix}_{IMAGE}"

    @property
    def image_valid_num_key(self):
        return f"{self.prefix}_{IMAGE_VALID_NUM}"

    @property
    def categorical_key(self):
        return f"{self.prefix}_{CATEGORICAL}"

    @property
    def numerical_key(self):
        return f"{self.prefix}_{NUMERICAL}"

    @property
    def label_key(self):
        return f"{self.prefix}_{LABEL}"

    @staticmethod
    def _valid_position_mask(num_positions: int, valid_num: torch.Tensor) -> torch.Tensor:
        """
        Build a boolean mask of shape (len(valid_num), num_positions) in which entry
        (i, j) is True iff j < valid_num[i].
        """
        # type_as gives the position indices the dtype and device of `valid_num`.
        steps = torch.arange(0, num_positions).type_as(valid_num)
        return steps.reshape((1, -1)) < valid_num.reshape((-1, 1))

    @staticmethod
    def _merge_image_tokens(image_embeddings: torch.Tensor, batch_size: int, image_num: int) -> torch.Tensor:
        """
        Fold per-image token sequences of shape (b*n, l, d) into one sequence per sample
        of shape (b, n*l, d), so they share the batch axis with the other modalities.
        """
        _, token_num, token_dim = image_embeddings.shape
        return image_embeddings.reshape(batch_size, image_num * token_num, token_dim)

    def forward(
        self,
        batch: dict,
    ):
        """
        Parameters
        ----------
        batch
            A dictionary holding the inputs of every enabled modality under the
            prefix-derived keys (see the `*_key` properties).

        Returns
        -------
        A dictionary {prefix: {LOGITS: ..., FEATURES: ...}}, where FEATURES is the output
        at the CLS position and LOGITS is the head applied to it.
        """
        multimodal_features = []
        # `image_tokenizer` is never assigned a module, so gating on it silently skipped
        # every image. The branch is keyed on the adaptor that `has_image` actually builds.
        if self.image_adaptor is not None:
            images = batch[self.image_key]
            image_valid_num = batch[self.image_valid_num_key]
            b, n, c, h, w = images.shape
            if self.use_learnable_image:
                image_masks = self._valid_position_mask(n, image_valid_num)  # (b, n)
                images = replace_missing_images_with_learnable(
                    images=images,
                    image_masks=image_masks,
                    learnable_image=self.learnable_image,
                )
            image_embeddings = self.patch_embed(images.reshape((b * n, c, h, w)))  # (b*n, l, d)
            assert image_embeddings.ndim == 3
            image_embeddings = self.image_adaptor(image_embeddings)
            # Without this fold the batch axis would be b*n and the concatenation with the
            # other modalities along dim=1 would fail whenever n > 1.
            image_embeddings = self._merge_image_tokens(image_embeddings, b, n)  # (b, n*l, d)
            multimodal_features.append(image_embeddings)
        if self.text_adaptor is not None:  # text tokenizer is used in text processor
            text_token_ids = batch[self.text_token_ids_key]  # presumably (b, l) - TODO confirm
            text_valid_length = batch[self.text_valid_length_key]
            text_masks = self._valid_position_mask(text_token_ids.shape[1], text_valid_length).type_as(
                text_token_ids
            )
            if "token_type_ids" in self.tokenizer.model_input_names:
                token_type_ids = batch[self.text_segment_ids_key]
            else:
                token_type_ids = None
            if token_type_ids is None:
                token_type_ids = torch.zeros(text_token_ids.size(), dtype=torch.long, device=text_token_ids.device)

            text_embeddings = self.text_embed(
                input_ids=text_token_ids,
                token_type_ids=token_type_ids,
                position_ids=None,
                mask=text_masks,
                inputs_embeds=None,
            )
            text_embeddings = self.text_adaptor(text_embeddings)
            assert text_embeddings.ndim == 3
            multimodal_features.append(text_embeddings)
        if self.categorical_feature_tokenizer is not None:
            # The batch entry is iterated column by column and the columns are stacked on dim=1.
            categorical_inputs = torch.stack(list(batch[self.categorical_key]), dim=1)

            categorical_features = self.categorical_feature_tokenizer(categorical_inputs)
            categorical_features = self.categorical_adapter(categorical_features)  # (b, l, d)
            assert categorical_features.ndim == 3
            multimodal_features.append(categorical_features)
        if self.numerical_feature_tokenizer is not None:
            numerical_features = self.numerical_feature_tokenizer(batch[self.numerical_key])
            numerical_features = self.numerical_adapter(numerical_features)
            assert numerical_features.ndim == 3
            multimodal_features.append(numerical_features)

        if not multimodal_features:
            # torch.cat on an empty list also raises RuntimeError, but with an opaque message.
            raise RuntimeError(f"{self.prefix} (MetaTransformer) has no enabled modality to encode.")

        multimodal_features = torch.cat(multimodal_features, dim=1)
        multimodal_features = self.cls_token(multimodal_features)
        features = self.model(multimodal_features)
        pooled_features = features[:, -1, :]  # CLSToken append the cls token to the sequence tail
        logits = self.head(pooled_features)
        ret = {
            LOGITS: logits,
            FEATURES: pooled_features,
        }
        return {self.prefix: ret}

    def get_layer_ids(self):
        """
        Assign an id to each layer. Layer ids will be used in layer-wise lr decay.
        Basically, id gradually increases when going from the output end to
        the input end. The layers defined in this class, e.g., head, have id 0.

        In the AutoModel scenario, this function may not always return the correct result.
        Thus, you can use "print(json.dumps(name_to_id, indent=2))" to manually check whether
        the layer ids are reasonable.

        Returns
        -------
        A dictionary mapping the layer names (keys) to their ids (values).
        """
        model_prefix = "model"
        pre_encoder_patterns = (
            "embeddings",
            "LayerNorm",
            "wte",
            "wpe",
            "shared.weight",
            "encoder.conv.conv",
            "relative_attention_bias",
            "dummy_layer",
        )
        post_encoder_patterns = ("head", "pooler", "ln_f", "final_layer_norm")
        names = [n for n, _ in self.named_parameters()]

        name_to_id, names = assign_layer_ids(
            names=names,
            pre_encoder_patterns=pre_encoder_patterns,
            post_encoder_patterns=post_encoder_patterns,
            model_pre=model_prefix,
        )
        # Names left over after the assignment are treated as head layers (id 0).
        if len(names) > 0:
            logger.debug(f"outer layers are treated as head: {names}")
        for n in names:
            assert n not in name_to_id
            name_to_id[n] = 0

        return name_to_id
@@ -51,7 +51,7 @@ class Unit(nn.Module):
51
51
  in_features: int,
52
52
  out_features: int,
53
53
  activation: str,
54
- dropout_prob: float,
54
+ dropout: float,
55
55
  ):
56
56
  """
57
57
  Parameters
@@ -64,7 +64,7 @@ class Unit(nn.Module):
64
64
  Dimension of output features.
65
65
  activation
66
66
  Name of activation function.
67
- dropout_prob
67
+ dropout
68
68
  Dropout probability.
69
69
  """
70
70
  super().__init__()
@@ -78,7 +78,7 @@ class Unit(nn.Module):
78
78
  raise ValueError(f"unknown normalization: {normalization}")
79
79
  self.fc = nn.Linear(in_features, out_features)
80
80
  self.act_fn = ALL_ACT_LAYERS[activation]()
81
- self.dropout = nn.Dropout(dropout_prob)
81
+ self.dropout = nn.Dropout(dropout)
82
82
 
83
83
  def forward(self, x):
84
84
  # pre normalization
@@ -102,7 +102,7 @@ class MLP(nn.Module):
102
102
  out_features: Optional[int] = None,
103
103
  num_layers: Optional[int] = 1,
104
104
  activation: Optional[str] = "gelu",
105
- dropout_prob: Optional[float] = 0.5,
105
+ dropout: Optional[float] = 0.5,
106
106
  normalization: Optional[str] = "layer_norm",
107
107
  ):
108
108
  """
@@ -118,7 +118,7 @@ class MLP(nn.Module):
118
118
  Number of layers.
119
119
  activation
120
120
  Name of activation function.
121
- dropout_prob
121
+ dropout
122
122
  Dropout probability.
123
123
  normalization
124
124
  Name of normalization function.
@@ -134,7 +134,7 @@ class MLP(nn.Module):
134
134
  in_features=in_features,
135
135
  out_features=hidden_features,
136
136
  activation=activation,
137
- dropout_prob=dropout_prob,
137
+ dropout=dropout,
138
138
  )
139
139
  in_features = hidden_features
140
140
  layers.append(per_unit)
@@ -11,7 +11,7 @@ except ImportError:
11
11
  mmocr = None
12
12
  from torch import nn
13
13
 
14
- from ..constants import AUTOMM, BBOX, COLUMN, COLUMN_FEATURES, FEATURES, IMAGE, IMAGE_VALID_NUM, LABEL, LOGITS, MASKS
14
+ from ..constants import BBOX, COLUMN, COLUMN_FEATURES, FEATURES, IMAGE, IMAGE_VALID_NUM, LABEL, LOGITS, MASKS
15
15
  from .utils import assign_layer_ids, get_column_features, get_mmocr_config_and_model, get_model_head
16
16
 
17
17
  logger = logging.getLogger(__name__)
@@ -12,7 +12,6 @@ except ImportError:
12
12
  from torch import nn
13
13
 
14
14
  from ..constants import (
15
- AUTOMM,
16
15
  COLUMN,
17
16
  COLUMN_FEATURES,
18
17
  FEATURES,
@@ -3,25 +3,18 @@ from typing import Dict, List, Optional, Tuple
3
3
 
4
4
  import torch
5
5
  import torch.nn.functional as F
6
- from torch import nn
7
6
  from transformers import logging as hf_logging
8
7
 
9
8
  from ..constants import (
10
- AUTOMM,
11
- COLUMN,
12
9
  COLUMN_FEATURES,
13
10
  FEATURES,
14
- LABEL,
15
11
  LOGITS,
16
12
  MASKS,
17
13
  NER_ANNOTATION,
18
- TEXT_SEGMENT_IDS,
19
- TEXT_TOKEN_IDS,
20
- TEXT_VALID_LENGTH,
21
14
  TOKEN_WORD_MAPPING,
22
15
  WORD_OFFSETS,
23
16
  )
24
- from .huggingface_text import HFAutoModelForTextPrediction
17
+ from .hf_text import HFAutoModelForTextPrediction
25
18
  from .utils import assign_layer_ids, get_column_features, get_pretrained_tokenizer
26
19
 
27
20
  hf_logging.set_verbosity_error()
@@ -1,4 +1,5 @@
1
- from typing import List, Optional
1
+ import logging
2
+ from typing import Dict, List, Optional
2
3
 
3
4
  from torch import nn
4
5
 
@@ -7,6 +8,8 @@ from .ft_transformer import NumEmbeddings
7
8
  from .mlp import MLP
8
9
  from .utils import init_weights
9
10
 
11
+ logger = logging.getLogger(__name__)
12
+
10
13
 
11
14
  class NumericalMLP(nn.Module):
12
15
  """
@@ -21,11 +24,12 @@ class NumericalMLP(nn.Module):
21
24
  out_features: Optional[int] = None,
22
25
  num_layers: Optional[int] = 1,
23
26
  activation: Optional[str] = "leaky_relu",
24
- dropout_prob: Optional[float] = 0.5,
27
+ dropout: Optional[float] = 0.5,
25
28
  normalization: Optional[str] = "layer_norm",
26
29
  num_classes: Optional[int] = 0,
27
- d_token: Optional[int] = 8,
30
+ token_dim: Optional[int] = 8,
28
31
  embedding_arch: Optional[List[str]] = None,
32
+ numerical_fill_values: Optional[Dict] = None,
29
33
  ):
30
34
  """
31
35
  Parameters
@@ -42,13 +46,13 @@ class NumericalMLP(nn.Module):
42
46
  Number of MLP layers.
43
47
  activation
44
48
  Name of activation function.
45
- dropout_prob
49
+ dropout
46
50
  Dropout probability.
47
51
  normalization
48
52
  Name of normalization function.
49
53
  num_classes
50
54
  Number of classes. 1 for a regression task.
51
- d_token
55
+ token_dim
52
56
  The size of one token for `NumericalEmbedding`.
53
57
  embedding_arch
54
58
  A list containing the names of embedding layers.
@@ -56,19 +60,21 @@ class NumericalMLP(nn.Module):
56
60
  {'linear', 'shared_linear', 'autodis', 'positional', 'relu', 'layernorm'}
57
61
  """
58
62
  super().__init__()
63
+ logger.debug(f"initializing {prefix} (NumericalMLP)")
59
64
  self.out_features = out_features
65
+ self.numerical_fill_values = numerical_fill_values
60
66
 
61
67
  self.numerical_feature_tokenizer = (
62
68
  NumEmbeddings(
63
69
  in_features=in_features,
64
- d_embedding=d_token,
70
+ d_embedding=token_dim,
65
71
  embedding_arch=embedding_arch,
66
72
  )
67
73
  if embedding_arch is not None
68
74
  else nn.Identity()
69
75
  )
70
76
 
71
- in_features = in_features * d_token if embedding_arch is not None else in_features
77
+ in_features = in_features * token_dim if embedding_arch is not None else in_features
72
78
 
73
79
  self.mlp = MLP(
74
80
  in_features=in_features,
@@ -76,7 +82,7 @@ class NumericalMLP(nn.Module):
76
82
  out_features=out_features,
77
83
  num_layers=num_layers,
78
84
  activation=activation,
79
- dropout_prob=dropout_prob,
85
+ dropout=dropout,
80
86
  normalization=normalization,
81
87
  )
82
88
  self.head = nn.Linear(out_features, num_classes) if num_classes > 0 else nn.Identity()
@@ -3,14 +3,13 @@ from typing import Dict, List, Optional, Tuple
3
3
 
4
4
  import torch
5
5
  import torch.nn.functional as F
6
- from omegaconf import DictConfig
7
6
  from torch import nn
8
7
  from transformers import SamConfig
9
8
 
10
9
  from ..constants import CLASS_LABEL, CLASS_LOGITS, COLUMN, IMAGE, IMAGE_VALID_NUM, LABEL, LOGITS, MASK_LABEL, MOE_LOSS
11
10
  from .adaptation_layers import ConvLoRALinear
12
11
  from .custom_hf_models.modeling_sam_for_conv_lora import SamImageSegmentationOutput, SamModel
13
- from .utils import assign_layer_ids, freeze_model_layers
12
+ from .utils import assign_layer_ids, freeze_model_layers, image_mean_std
14
13
 
15
14
  logger = logging.getLogger(__name__)
16
15
 
@@ -269,6 +268,7 @@ class SAMForSemanticSegmentation(nn.Module):
269
268
  pretrained: Optional[bool] = True,
270
269
  frozen_layers: Optional[list] = None,
271
270
  num_mask_tokens: int = 1,
271
+ image_norm: Optional[str] = None,
272
272
  ):
273
273
  """
274
274
  Load a pretrained Segment Anything Model (SAM).
@@ -287,6 +287,15 @@ class SAMForSemanticSegmentation(nn.Module):
287
287
  A list of substrings of frozen layers' names.
288
288
  num_mask_tokens
289
289
  The number of mask proposals.
290
+ image_norm
291
+ How to normalize an image. We now support:
292
+ - inception
293
+ Normalize image by IMAGENET_INCEPTION_MEAN and IMAGENET_INCEPTION_STD from timm
294
+ - imagenet
295
+ Normalize image by IMAGENET_DEFAULT_MEAN and IMAGENET_DEFAULT_STD from timm
296
+ - clip
297
+ Normalize image by mean (0.48145466, 0.4578275, 0.40821073) and
298
+ std (0.26862954, 0.26130258, 0.27577711), used for CLIP.
290
299
  """
291
300
 
292
301
  super().__init__()
@@ -305,6 +314,7 @@ class SAMForSemanticSegmentation(nn.Module):
305
314
 
306
315
  self.image_size = self.model.vision_encoder.image_size
307
316
  self.config = self.model.config
317
+ self.image_mean, self.image_std = image_mean_std(image_norm)
308
318
 
309
319
  self.model.mask_decoder.num_mask_tokens = num_mask_tokens
310
320
  mask_token_data = self.model.mask_decoder.mask_tokens.weight.data[0]
@@ -1,7 +1,4 @@
1
- import collections
2
1
  import logging
3
- import os
4
- import random
5
2
  from functools import lru_cache
6
3
  from typing import Dict, List, Optional, Tuple
7
4
 
@@ -12,7 +9,6 @@ from transformers import AutoConfig, AutoModelForSeq2SeqLM
12
9
  from transformers import logging as hf_logging
13
10
 
14
11
  from ..constants import (
15
- AUTOMM,
16
12
  CHOICES_IDS,
17
13
  COLUMN,
18
14
  COLUMN_FEATURES,
@@ -26,7 +22,14 @@ from ..constants import (
26
22
  TEXT_TOKEN_IDS,
27
23
  TEXT_VALID_LENGTH,
28
24
  )
29
- from .utils import DummyLayer, assign_layer_ids, get_column_features, get_pretrained_tokenizer
25
+ from .utils import (
26
+ DummyLayer,
27
+ assign_layer_ids,
28
+ get_column_features,
29
+ get_pretrained_tokenizer,
30
+ get_text_segment_num,
31
+ get_text_token_max_len,
32
+ )
30
33
 
31
34
  hf_logging.set_verbosity_error()
32
35
 
@@ -56,6 +59,8 @@ class TFewModel(nn.Module):
56
59
  low_cpu_mem_usage: Optional[bool] = False,
57
60
  pretrained: Optional[bool] = True,
58
61
  tokenizer_name: Optional[str] = "hf_auto",
62
+ max_text_len: Optional[int] = None,
63
+ text_segment_num: Optional[int] = 1,
59
64
  ):
60
65
  """
61
66
  Load a pretrained T5-based text transformer backbone.
@@ -106,6 +111,17 @@ class TFewModel(nn.Module):
106
111
  tokenizer_name=self.tokenizer_name,
107
112
  checkpoint_name=self.checkpoint_name,
108
113
  )
114
+ self.max_text_len = get_text_token_max_len(
115
+ provided_max_len=max_text_len,
116
+ config=self.config,
117
+ tokenizer=self.tokenizer,
118
+ checkpoint_name=self.checkpoint_name,
119
+ )
120
+ self.text_segment_num = get_text_segment_num(
121
+ config=self.config,
122
+ provided_segment_num=text_segment_num,
123
+ checkpoint_name=self.checkpoint_name,
124
+ )
109
125
  self.eos_token = self.tokenizer.eos_token
110
126
  self.out_features = (
111
127
  self.model.config.hidden_size