autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250304__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (126)
  1. autogluon/multimodal/__init__.py +4 -2
  2. autogluon/multimodal/configs/data/default.yaml +4 -2
  3. autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
  4. autogluon/multimodal/configs/model/default.yaml +58 -11
  5. autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
  6. autogluon/multimodal/constants.py +16 -5
  7. autogluon/multimodal/data/__init__.py +14 -2
  8. autogluon/multimodal/data/dataset.py +2 -2
  9. autogluon/multimodal/data/infer_types.py +16 -2
  10. autogluon/multimodal/data/label_encoder.py +3 -3
  11. autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
  12. autogluon/multimodal/data/preprocess_dataframe.py +55 -38
  13. autogluon/multimodal/data/process_categorical.py +35 -6
  14. autogluon/multimodal/data/process_document.py +59 -33
  15. autogluon/multimodal/data/process_image.py +198 -163
  16. autogluon/multimodal/data/process_label.py +7 -3
  17. autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
  18. autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
  19. autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
  20. autogluon/multimodal/data/process_ner.py +192 -4
  21. autogluon/multimodal/data/process_numerical.py +32 -5
  22. autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
  23. autogluon/multimodal/data/process_text.py +95 -58
  24. autogluon/multimodal/data/template_engine.py +7 -9
  25. autogluon/multimodal/data/templates.py +0 -2
  26. autogluon/multimodal/data/trivial_augmenter.py +2 -2
  27. autogluon/multimodal/data/utils.py +564 -338
  28. autogluon/multimodal/learners/__init__.py +2 -1
  29. autogluon/multimodal/learners/base.py +189 -189
  30. autogluon/multimodal/learners/ensemble.py +748 -0
  31. autogluon/multimodal/learners/few_shot_svm.py +6 -15
  32. autogluon/multimodal/learners/matching.py +59 -84
  33. autogluon/multimodal/learners/ner.py +23 -22
  34. autogluon/multimodal/learners/object_detection.py +26 -21
  35. autogluon/multimodal/learners/semantic_segmentation.py +16 -18
  36. autogluon/multimodal/models/__init__.py +12 -3
  37. autogluon/multimodal/models/augmenter.py +175 -0
  38. autogluon/multimodal/models/categorical_mlp.py +13 -8
  39. autogluon/multimodal/models/clip.py +92 -18
  40. autogluon/multimodal/models/custom_transformer.py +75 -75
  41. autogluon/multimodal/models/document_transformer.py +23 -9
  42. autogluon/multimodal/models/ft_transformer.py +40 -35
  43. autogluon/multimodal/models/fusion/base.py +2 -4
  44. autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
  45. autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
  46. autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
  47. autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
  48. autogluon/multimodal/models/meta_transformer.py +336 -0
  49. autogluon/multimodal/models/mlp.py +6 -6
  50. autogluon/multimodal/models/mmocr_text_detection.py +1 -1
  51. autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
  52. autogluon/multimodal/models/ner_text.py +1 -8
  53. autogluon/multimodal/models/numerical_mlp.py +14 -8
  54. autogluon/multimodal/models/sam.py +12 -2
  55. autogluon/multimodal/models/t_few.py +21 -5
  56. autogluon/multimodal/models/timm_image.py +74 -32
  57. autogluon/multimodal/models/utils.py +877 -16
  58. autogluon/multimodal/optim/__init__.py +17 -0
  59. autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
  60. autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
  61. autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
  62. autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
  63. autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
  64. autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
  65. autogluon/multimodal/optim/losses/__init__.py +14 -0
  66. autogluon/multimodal/optim/losses/bce_loss.py +25 -0
  67. autogluon/multimodal/optim/losses/focal_loss.py +81 -0
  68. autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
  69. autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
  70. autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
  71. autogluon/multimodal/optim/losses/structure_loss.py +26 -0
  72. autogluon/multimodal/optim/losses/utils.py +313 -0
  73. autogluon/multimodal/optim/lr/__init__.py +1 -0
  74. autogluon/multimodal/optim/lr/utils.py +332 -0
  75. autogluon/multimodal/optim/metrics/__init__.py +4 -0
  76. autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
  77. autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
  78. autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
  79. autogluon/multimodal/optim/metrics/utils.py +359 -0
  80. autogluon/multimodal/optim/utils.py +284 -0
  81. autogluon/multimodal/predictor.py +51 -12
  82. autogluon/multimodal/utils/__init__.py +19 -45
  83. autogluon/multimodal/utils/cache.py +23 -2
  84. autogluon/multimodal/utils/checkpoint.py +58 -5
  85. autogluon/multimodal/utils/config.py +127 -55
  86. autogluon/multimodal/utils/device.py +120 -0
  87. autogluon/multimodal/utils/distillation.py +8 -8
  88. autogluon/multimodal/utils/download.py +1 -1
  89. autogluon/multimodal/utils/env.py +22 -0
  90. autogluon/multimodal/utils/export.py +3 -3
  91. autogluon/multimodal/utils/hpo.py +5 -5
  92. autogluon/multimodal/utils/inference.py +37 -4
  93. autogluon/multimodal/utils/install.py +91 -0
  94. autogluon/multimodal/utils/load.py +52 -47
  95. autogluon/multimodal/utils/log.py +6 -41
  96. autogluon/multimodal/utils/matcher.py +3 -2
  97. autogluon/multimodal/utils/onnx.py +0 -4
  98. autogluon/multimodal/utils/path.py +10 -0
  99. autogluon/multimodal/utils/precision.py +130 -0
  100. autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
  101. autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
  102. autogluon/multimodal/utils/save.py +47 -29
  103. autogluon/multimodal/utils/strategy.py +24 -0
  104. autogluon/multimodal/version.py +1 -1
  105. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/METADATA +5 -5
  106. autogluon.multimodal-1.2.1b20250304.dist-info/RECORD +163 -0
  107. autogluon/multimodal/optimization/__init__.py +0 -16
  108. autogluon/multimodal/optimization/losses.py +0 -394
  109. autogluon/multimodal/optimization/utils.py +0 -1054
  110. autogluon/multimodal/utils/cloud_io.py +0 -80
  111. autogluon/multimodal/utils/data.py +0 -701
  112. autogluon/multimodal/utils/environment.py +0 -395
  113. autogluon/multimodal/utils/metric.py +0 -500
  114. autogluon/multimodal/utils/model.py +0 -558
  115. autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
  116. /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
  117. /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
  118. /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
  119. /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
  120. /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250304-py3.9-nspkg.pth +0 -0
  121. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/LICENSE +0 -0
  122. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/NOTICE +0 -0
  123. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/WHEEL +0 -0
  124. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/namespace_packages.txt +0 -0
  125. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/top_level.txt +0 -0
  126. {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250304.dist-info}/zip-safe +0 -0
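
The most sweeping structural change in this build is the rename of the optimization package to optim, with losses, learning-rate utilities, and metrics split into subpackages (see the {optimization → optim} entries above). A minimal sketch of the new module paths implied by those file moves; whether each module is re-exported at a higher level is not shown in this diff:

    # Module locations in 1.2.1b20250304; the commented paths are the
    # 1.2.1b20250303 locations removed or moved in this diff.
    import autogluon.multimodal.optim.losses.focal_loss             # was optimization/losses.py (one module for all losses)
    import autogluon.multimodal.optim.lr.lr_schedulers              # was optimization/lr_scheduler.py
    import autogluon.multimodal.optim.metrics.semantic_seg_metrics  # was optimization/semantic_seg_metrics.py

The detailed hunks below cover the four files under autogluon/multimodal/models/ whose contents changed most visibly in this build.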
autogluon/multimodal/models/__init__.py
@@ -1,4 +1,4 @@
- from . import utils
+ from .augmenter import Augmenter
  from .categorical_mlp import CategoricalMLP
  from .clip import CLIPForImageText
  from .document_transformer import DocumentTransformer
@@ -9,7 +9,8 @@ from .fusion import (
      MultimodalFusionNER,
      MultimodalFusionTransformer,
  )
- from .huggingface_text import HFAutoModelForTextPrediction
+ from .hf_text import HFAutoModelForTextPrediction
+ from .meta_transformer import MetaTransformer
  from .mmdet_image import MMDetAutoModelForObjectDetection
  from .mmocr_text_detection import MMOCRAutoModelForTextDetection
  from .mmocr_text_recognition import MMOCRAutoModelForTextRecognition
@@ -18,4 +19,12 @@ from .numerical_mlp import NumericalMLP
  from .sam import SAMForSemanticSegmentation
  from .t_few import TFewModel
  from .timm_image import TimmAutoModelForImagePrediction
- from .utils import get_model_postprocess_fn
+ from .utils import (
+     create_fusion_model,
+     create_model,
+     get_model_postprocess_fn,
+     is_lazy_weight_tensor,
+     list_timm_models,
+     modify_duplicate_model_names,
+     select_model,
+ )
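
With this refactor, model-construction helpers previously reached through separate utility modules are re-exported from the models package alongside the new Augmenter and MetaTransformer classes. A small sketch of importing through the updated __init__ shown above (top-level re-exports beyond this file are not part of this diff):

    # names exported by autogluon/multimodal/models/__init__.py in 1.2.1b20250304
    from autogluon.multimodal.models import (
        Augmenter,
        HFAutoModelForTextPrediction,  # now defined in hf_text.py (renamed from huggingface_text.py)
        MetaTransformer,
        create_model,
        list_timm_models,
    )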
autogluon/multimodal/models/augmenter.py (new file)
@@ -0,0 +1,175 @@
+ import logging
+ 
+ import torch
+ import torch.nn as nn
+ from omegaconf import DictConfig
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
+ 
+ from .mlp import Unit
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ class VAETransformer(nn.Module):
+     def __init__(self, config: DictConfig, in_feautres: int, n_modality: int) -> None:
+         super().__init__()
+         self.config = config
+         self.emb_d = in_feautres
+         self.n_modality = n_modality
+         logger.debug(f" VAE Transformer # features {n_modality}, dim {self.emb_d}")
+ 
+         # encoder
+         encoder_layers = TransformerEncoderLayer(self.emb_d, config.n_head, config.tran_hidden, norm_first=True)
+         self.transformer_encoder = TransformerEncoder(encoder_layers, config.n_layer)
+ 
+         # encoder linear z
+         self.encoder_fc_z_mu = nn.Linear(self.emb_d, self.config.z_dim)
+         self.encoder_fc_z_logvar = nn.Linear(self.emb_d, self.config.z_dim)
+ 
+         # decoder linezr z
+         self.decoder_fc = nn.Linear(self.config.z_dim, self.emb_d)
+ 
+         # decoder
+         decoder_layers = TransformerEncoderLayer(self.emb_d, config.n_head, config.tran_hidden, norm_first=True)
+         self.transformer_decoder = TransformerEncoder(decoder_layers, config.n_layer)
+ 
+         self.last_layer = nn.Linear(self.emb_d, self.emb_d)
+ 
+         self.gating = nn.Identity()
+         self.init_parameters()
+ 
+     def init_parameters(self):
+         self.last_layer.weight.data.zero_()
+         self.last_layer.bias.data.zero_()
+ 
+     def reparameterize(self, mu, logvar):
+         std = torch.exp(0.5 * logvar)
+         eps = torch.randn_like(std)
+         return mu + eps * std
+ 
+     def forward(self, X):
+         input = X.reshape(-1, self.n_modality, self.emb_d)  # [B, # modality, emb dim] torch.Size([8, 3, 1024])
+ 
+         hidden = self.transformer_encoder(input)
+ 
+         z_mu, z_logvar = self.encoder_fc_z_mu(hidden), self.encoder_fc_z_logvar(hidden)
+ 
+         z = self.reparameterize(z_mu, z_logvar)
+ 
+         hidden = self.decoder_fc(z)
+ 
+         noise = self.gating(self.last_layer(self.transformer_decoder(hidden)[:, : self.n_modality, :]))
+         recon_x = X.reshape(-1, self.n_modality, self.emb_d) + noise
+ 
+         return recon_x.reshape(len(X), -1), z_mu, z_logvar
+ 
+ 
+ class MlpVAE(nn.Module):
+     def __init__(self, input_dim, hidden_dim, z_dim=16) -> None:
+         super().__init__()
+         self.input_dim = input_dim
+         self.z_dim = z_dim
+         self.hidden_dim = hidden_dim
+ 
+         # Encoder P(Z|X)
+         encoder_layers = []
+         dims = [input_dim] + hidden_dim
+         for i in range(len(dims) - 1):
+             encoder_layers.append(
+                 Unit(
+                     normalization="layer_norm",
+                     in_features=dims[i],
+                     out_features=dims[i + 1],
+                     activation="relu",
+                     dropout=0.5,
+                 )
+             )
+         self.encoder = nn.Sequential(*encoder_layers)
+ 
+         self.encoder_fc_z_mu = nn.Linear(self.hidden_dim[-1], self.z_dim)
+         self.encoder_fc_z_logvar = nn.Linear(self.hidden_dim[-1], self.z_dim)
+ 
+         # Decoder P(X|Z)
+         decoder_layers = []
+         dims = [input_dim] + hidden_dim + [z_dim]
+ 
+         for i in range(len(dims) - 1, 0, -1):
+             decoder_layers.append(
+                 Unit(
+                     normalization="layer_norm",
+                     in_features=dims[i],
+                     out_features=dims[i - 1],
+                     activation="relu",
+                     dropout=0.5,
+                 )
+             )
+         self.decoder = nn.Sequential(*decoder_layers)
+ 
+         self.init_parameters()
+ 
+     def init_parameters(self):
+         self.decoder[-1].fc.weight.data.zero_()
+         self.decoder[-1].fc.bias.data.zero_()
+ 
+     def reparameterize(self, mu, logvar):
+         std = torch.exp(0.5 * logvar)
+         eps = torch.randn_like(std)
+         return mu + eps * std
+ 
+     def forward(self, x):
+         hidden = self.encoder(x)
+         z_mu, z_logvar = self.encoder_fc_z_mu(hidden), self.encoder_fc_z_logvar(hidden)
+         z = self.reparameterize(z_mu, z_logvar)
+ 
+         noise_x = self.decoder(z)
+         recon_x = x + noise_x
+         return recon_x, z_mu, z_logvar
+ 
+ 
+ class Augmenter(nn.Module):
+     def __init__(
+         self,
+         arch_type: str,
+         input_dim: int,
+         z_dim: int,
+         num_layers: int,
+         adv_weight: float,
+     ) -> None:
+         super().__init__()
+         logger.debug("Initializing Augmenter")
+         self.arch_type = arch_type
+         self.input_dim = input_dim
+         self.z_dim = z_dim
+         self.num_layers = num_layers
+         self.adv_weight = adv_weight
+         logger.debug(f"augmenter arch_type: {self.arch_type}")
+         logger.debug(f"augmenter input_dim: {self.input_dim}")
+         logger.debug(f"augmenter z_dim: {self.z_dim}")
+         logger.debug(f"augmenter num_layers: {self.num_layers}")
+         logger.debug(f"augmenter adv_weight: {self.adv_weight}")
+         if self.arch_type == "mlp_vae":
+             step = int((self.input_dim - self.z_dim) / (self.num_layers + 1))
+             hidden = [*range(self.input_dim - step, self.z_dim + step, -step)]
+             self.vae = MlpVAE(input_dim=self.input_dim, hidden_dim=hidden, z_dim=self.z_dim)
+         else:
+             raise ValueError(f"Unknown arch_type: {self.arch_type}")
+ 
+         self.name_to_id = self.get_layer_ids()
+ 
+     def forward(self, x):
+         return self.vae(x)
+ 
+     def get_layer_ids(
+         self,
+     ):
+         """
+         All layers have the same id 0 since there is no pre-trained models used here.
+ 
+         Returns
+         -------
+         A dictionary mapping the layer names (keys) to their ids (values).
+         """
+         name_to_id = {}
+         for n, _ in self.named_parameters():
+             name_to_id[n] = 0
+         return name_to_id
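
Augmenter wraps an MLP-based VAE ("mlp_vae" is the only arch_type this version accepts) that reconstructs its input features and returns the latent Gaussian statistics. A minimal standalone sketch of exercising it; the dimensions and the KL term below are illustrative, not values or losses taken from this package:

    import torch

    from autogluon.multimodal.models.augmenter import Augmenter

    # hidden layer sizes are derived internally from input_dim, z_dim, and num_layers
    augmenter = Augmenter(arch_type="mlp_vae", input_dim=1024, z_dim=16, num_layers=4, adv_weight=0.1)
    features = torch.randn(8, 1024)              # e.g. a batch of per-sample feature vectors
    recon, z_mu, z_logvar = augmenter(features)  # reconstruction = input + learned "noise"

    # a standard VAE KL regularizer could be formed from the returned statistics
    kl = -0.5 * torch.mean(torch.sum(1 + z_logvar - z_mu.pow(2) - z_logvar.exp(), dim=-1))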
autogluon/multimodal/models/categorical_mlp.py
@@ -1,4 +1,5 @@
- from typing import List, Optional
+ import logging
+ from typing import Dict, Optional
  
  import torch
  from torch import nn
@@ -7,6 +8,8 @@ from ..constants import CATEGORICAL, FEATURES, LABEL, LOGITS
  from .mlp import MLP
  from .utils import init_weights
  
+ logger = logging.getLogger(__name__)
+ 
  
  class CategoricalMLP(nn.Module):
      """
@@ -17,11 +20,11 @@ class CategoricalMLP(nn.Module):
      def __init__(
          self,
          prefix: str,
-         num_categories: List[int],
+         num_categories: Dict,
          out_features: Optional[int] = None,
          num_layers: Optional[int] = 1,
         activation: Optional[str] = "gelu",
-         dropout_prob: Optional[float] = 0.5,
+         dropout: Optional[float] = 0.5,
          normalization: Optional[str] = "layer_norm",
          num_classes: Optional[int] = 0,
      ):
@@ -38,7 +41,7 @@
              Number of MLP layers.
          activation
              Name of activation function.
-         dropout_prob
+         dropout
              Dropout probability.
          normalization
              Name of normalization function.
@@ -46,15 +49,17 @@
              Number of classes. 1 for a regression task.
          """
          super().__init__()
+         logger.debug(f"initializing {prefix} (CategoricalMLP)")
          self.out_features = out_features
          max_embedding_dim = 100
          embed_exponent = 0.56
          size_factor = 1.0
          self.column_embeddings = nn.ModuleList()
          self.column_mlps = nn.ModuleList()
-         assert isinstance(num_categories, list)
+         assert isinstance(num_categories, dict)
+         self.num_categories = num_categories
  
-         for num_categories_per_col in num_categories:
+         for num_categories_per_col in num_categories.values():
              embedding_dim_per_col = int(
                  size_factor * max(2, min(max_embedding_dim, 1.6 * num_categories_per_col**embed_exponent))
              )
@@ -72,7 +77,7 @@
                      out_features=out_features,
                      num_layers=num_layers,
                      activation=activation,
-                     dropout_prob=dropout_prob,
+                     dropout=dropout,
                      normalization=normalization,
                  )
              )
@@ -83,7 +88,7 @@
              out_features=out_features,
              num_layers=num_layers,
              activation=activation,
-             dropout_prob=dropout_prob,
+             dropout=dropout,
              normalization=normalization,
          )
  
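
The constructor now takes a per-column mapping of category counts instead of a plain list, and dropout_prob is renamed to dropout. A minimal sketch of the updated call (the column names and counts are made up for illustration):

    from autogluon.multimodal.models.categorical_mlp import CategoricalMLP

    # 1.2.1b20250303: CategoricalMLP(prefix=..., num_categories=[4, 10], dropout_prob=0.2, ...)
    model = CategoricalMLP(
        prefix="categorical_mlp",
        num_categories={"home_ownership": 4, "purpose": 10},  # one entry per categorical column
        out_features=128,
        dropout=0.2,
    )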
autogluon/multimodal/models/clip.py
@@ -5,7 +5,6 @@ import torch
  from torch import nn
  
  from ..constants import (
-     AUTOMM,
      COLUMN,
      COLUMN_FEATURES,
      FEATURES,
@@ -22,8 +21,12 @@ from .utils import (
      assign_layer_ids,
      get_column_features,
      get_hf_config_and_model,
+     get_image_size_mean_std,
      get_pretrained_tokenizer,
+     get_text_segment_num,
+     get_text_token_max_len,
      init_weights,
+     replace_missing_images_with_learnable,
  )
  
  logger = logging.getLogger(__name__)
@@ -42,6 +45,15 @@ class CLIPForImageText(nn.Module):
          num_classes: Optional[int] = None,
          pretrained: Optional[bool] = True,
          tokenizer_name: Optional[str] = "clip",
+         has_image: Optional[bool] = True,
+         has_text: Optional[bool] = True,
+         image_size: Optional[int] = None,
+         image_norm: Optional[str] = None,
+         image_chan_num: Optional[int] = 3,
+         use_learnable_image: Optional[bool] = False,
+         max_text_len: Optional[int] = None,
+         text_segment_num: Optional[int] = 1,
+         is_matching: Optional[bool] = False,
      ):
          """
          Load the pretrained CLIP from huggingface transformers.
@@ -60,16 +72,26 @@
              Name of the huggingface tokenizer type.
          """
          super().__init__()
-         logger.debug(f"initializing {checkpoint_name}")
+         logger.debug(f"initializing {prefix} (CLIPForImageText)")
+         logger.debug(f"model checkpoint: {checkpoint_name}")
          self.checkpoint_name = checkpoint_name
          self.num_classes = num_classes
+         if is_matching:  # init both image and text attributes for matching
+             has_image, has_text = True, True
+         self.has_image = has_image
+         self.has_text = has_text
  
          self.config, self.model = get_hf_config_and_model(checkpoint_name=checkpoint_name, pretrained=pretrained)
-         self.tokenizer_name = tokenizer_name
-         self.tokenizer = get_pretrained_tokenizer(
-             tokenizer_name=self.tokenizer_name,
-             checkpoint_name=self.checkpoint_name,
-         )
+ 
+         if not self.has_image:
+             self.config.vision_config = None
+             self.model.vision_model = None
+             self.model.visual_projection = None
+ 
+         if not self.has_text:
+             self.config.text_config = None
+             self.model.text_model = None
+             self.model.text_projection = None
  
          self.out_features = self.model.config.projection_dim
  
@@ -77,6 +99,35 @@
          self.head.apply(init_weights)
  
          self.prefix = prefix
+         if has_image:
+             self.image_size, self.image_mean, self.image_std = get_image_size_mean_std(
+                 model_name=self.prefix,
+                 config=self.model.vision_model.config,
+                 provided_size=image_size,
+                 provided_norm_type=image_norm,
+                 support_variable_input_size=False,
+             )
+             self.use_learnable_image = use_learnable_image
+             if self.use_learnable_image:
+                 self.learnable_image = nn.Parameter(torch.zeros(image_chan_num, self.image_size, self.image_size))
+                 logger.debug("will use a learnable image to replace missing ones")
+         if has_text:
+             self.tokenizer_name = tokenizer_name
+             self.tokenizer = get_pretrained_tokenizer(
+                 tokenizer_name=self.tokenizer_name,
+                 checkpoint_name=self.checkpoint_name,
+             )
+             self.max_text_len = get_text_token_max_len(
+                 provided_max_len=max_text_len,
+                 config=self.model.text_model.config,
+                 tokenizer=self.tokenizer,
+                 checkpoint_name=self.checkpoint_name,
+             )
+             self.text_segment_num = get_text_segment_num(
+                 config=self.model.text_model.config,
+                 provided_segment_num=text_segment_num,
+                 checkpoint_name=self.checkpoint_name,
+             )
  
          self.name_to_id = self.get_layer_ids()
          self.head_layer_names = [n for n, layer_id in self.name_to_id.items() if layer_id == 0]
@@ -117,6 +168,15 @@
      def image_feature_dim(self):
          return self.model.config.vision_config.hidden_size
  
+     @property
+     def input_keys(self):
+         ret = []
+         if self.has_image:
+             ret.extend([self.image_key, self.image_valid_num_key])
+         if self.has_text:
+             ret.extend([self.text_token_ids_key, self.text_valid_length_key])
+         return ret
+ 
      def forward(
          self,
          batch: dict,
@@ -132,8 +192,8 @@
          -------
          A dictionary with logits and features.
          """
-         has_image = self.image_key in batch
-         has_text = self.text_token_ids_key in batch
+         has_image = self.has_image and self.image_key in batch
+         has_text = self.has_text and self.text_token_ids_key in batch
          ret = {COLUMN_FEATURES: {FEATURES: {}, MASKS: {}}}
  
          if has_image:
@@ -141,6 +201,14 @@
              image_valid_num = batch[self.image_valid_num_key]
              assert images.dim() == 5
              b, n, c, h, w = images.shape
+             steps = torch.arange(0, n).type_as(image_valid_num)
+             image_masks = steps.reshape((1, -1)) < image_valid_num.reshape((-1, 1))  # (b, n)
+             if self.use_learnable_image:
+                 images = replace_missing_images_with_learnable(
+                     images=images,
+                     image_masks=image_masks,
+                     learnable_image=self.learnable_image,
+                 )
              vision_outputs = self.model.vision_model(
                  pixel_values=images.reshape((b * n, c, h, w)),
                  output_attentions=True,
@@ -148,9 +216,9 @@
                  return_dict=True,
              )
              image_features = self.model.visual_projection(vision_outputs.pooler_output)
-             steps = torch.arange(0, n).type_as(image_valid_num)
-             image_masks = (steps.reshape((1, -1)) < image_valid_num.reshape((-1, 1))).type_as(image_features)  # (b, n)
-             image_features = image_features.reshape((b, n, -1)) * image_masks[:, :, None]  # (b, n, num_features)
+             image_features = image_features.reshape((b, n, -1))  # (b, n, num_features)
+             if not self.use_learnable_image:
+                 image_features = image_features * image_masks[:, :, None].type_as(image_features)
  
              # normalized features
              image_features = image_features / torch.clamp(image_features.norm(dim=-1, keepdim=True), min=1e-6)
@@ -199,18 +267,24 @@
              ret[COLUMN_FEATURES][MASKS].update(text_column_feature_masks)
              ret[FEATURES] = text_features
  
-         if has_image and has_text:
-             if self.num_classes:
+         if self.num_classes:
+             if has_image and has_text:
                  features = image_features + text_features
                  logits = self.head(features)
                  ret[FEATURES] = features
+             elif has_image:
+                 logits = self.head(image_features)
+             elif has_text:
+                 logits = self.head(text_features)
              else:
+                 raise RuntimeError("Neither image or text are used. Must have at least one.")
+             ret[LOGITS] = logits
+         else:
+             ret[LOGIT_SCALE] = self.model.logit_scale.exp()
+             if has_image and has_text:
                  # cosine similarity as logits
                  logits = torch.sum(image_features * text_features, dim=-1)
- 
-             ret[LOGITS] = logits
- 
-         ret[LOGIT_SCALE] = self.model.logit_scale.exp()
+                 ret[LOGITS] = logits
  
          return {self.prefix: ret}
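
The image mask is now computed before the vision encoder runs, so missing image slots can optionally be replaced by the learnable image; the mask itself is just an index row broadcast against the per-sample valid counts. A standalone illustration of that broadcast (the sizes are illustrative):

    import torch

    image_valid_num = torch.tensor([2, 0, 1])             # valid images per sample in the batch
    n = 2                                                  # image slots allocated per sample
    steps = torch.arange(0, n).type_as(image_valid_num)   # tensor([0, 1])
    image_masks = steps.reshape((1, -1)) < image_valid_num.reshape((-1, 1))  # shape (b, n)
    # tensor([[ True,  True],
    #         [False, False],
    #         [ True, False]])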