autogluon.multimodal 1.2.1b20250303__py3-none-any.whl → 1.2.1b20250305__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- autogluon/multimodal/__init__.py +4 -2
- autogluon/multimodal/configs/data/default.yaml +4 -2
- autogluon/multimodal/configs/{environment → env}/default.yaml +2 -3
- autogluon/multimodal/configs/model/default.yaml +58 -11
- autogluon/multimodal/configs/{optimization → optim}/default.yaml +21 -4
- autogluon/multimodal/constants.py +16 -5
- autogluon/multimodal/data/__init__.py +14 -2
- autogluon/multimodal/data/dataset.py +2 -2
- autogluon/multimodal/data/infer_types.py +16 -2
- autogluon/multimodal/data/label_encoder.py +3 -3
- autogluon/multimodal/{utils → data}/nlpaug.py +4 -4
- autogluon/multimodal/data/preprocess_dataframe.py +55 -38
- autogluon/multimodal/data/process_categorical.py +35 -6
- autogluon/multimodal/data/process_document.py +59 -33
- autogluon/multimodal/data/process_image.py +198 -163
- autogluon/multimodal/data/process_label.py +7 -3
- autogluon/multimodal/data/process_mmlab/process_mmdet.py +1 -8
- autogluon/multimodal/data/process_mmlab/process_mmlab_base.py +2 -9
- autogluon/multimodal/data/process_mmlab/process_mmocr.py +1 -9
- autogluon/multimodal/data/process_ner.py +192 -4
- autogluon/multimodal/data/process_numerical.py +32 -5
- autogluon/multimodal/data/process_semantic_seg_img.py +23 -28
- autogluon/multimodal/data/process_text.py +95 -58
- autogluon/multimodal/data/template_engine.py +7 -9
- autogluon/multimodal/data/templates.py +0 -2
- autogluon/multimodal/data/trivial_augmenter.py +2 -2
- autogluon/multimodal/data/utils.py +564 -338
- autogluon/multimodal/learners/__init__.py +2 -1
- autogluon/multimodal/learners/base.py +189 -189
- autogluon/multimodal/learners/ensemble.py +748 -0
- autogluon/multimodal/learners/few_shot_svm.py +6 -15
- autogluon/multimodal/learners/matching.py +59 -84
- autogluon/multimodal/learners/ner.py +23 -22
- autogluon/multimodal/learners/object_detection.py +26 -21
- autogluon/multimodal/learners/semantic_segmentation.py +16 -18
- autogluon/multimodal/models/__init__.py +12 -3
- autogluon/multimodal/models/augmenter.py +175 -0
- autogluon/multimodal/models/categorical_mlp.py +13 -8
- autogluon/multimodal/models/clip.py +92 -18
- autogluon/multimodal/models/custom_transformer.py +75 -75
- autogluon/multimodal/models/document_transformer.py +23 -9
- autogluon/multimodal/models/ft_transformer.py +40 -35
- autogluon/multimodal/models/fusion/base.py +2 -4
- autogluon/multimodal/models/fusion/fusion_mlp.py +82 -18
- autogluon/multimodal/models/fusion/fusion_ner.py +1 -1
- autogluon/multimodal/models/fusion/fusion_transformer.py +23 -23
- autogluon/multimodal/models/{huggingface_text.py → hf_text.py} +21 -2
- autogluon/multimodal/models/meta_transformer.py +336 -0
- autogluon/multimodal/models/mlp.py +6 -6
- autogluon/multimodal/models/mmocr_text_detection.py +1 -1
- autogluon/multimodal/models/mmocr_text_recognition.py +0 -1
- autogluon/multimodal/models/ner_text.py +1 -8
- autogluon/multimodal/models/numerical_mlp.py +14 -8
- autogluon/multimodal/models/sam.py +12 -2
- autogluon/multimodal/models/t_few.py +21 -5
- autogluon/multimodal/models/timm_image.py +74 -32
- autogluon/multimodal/models/utils.py +877 -16
- autogluon/multimodal/optim/__init__.py +17 -0
- autogluon/multimodal/{optimization → optim}/lit_distiller.py +2 -1
- autogluon/multimodal/{optimization → optim}/lit_matcher.py +4 -10
- autogluon/multimodal/{optimization → optim}/lit_mmdet.py +2 -10
- autogluon/multimodal/{optimization → optim}/lit_module.py +139 -14
- autogluon/multimodal/{optimization → optim}/lit_ner.py +3 -3
- autogluon/multimodal/{optimization → optim}/lit_semantic_seg.py +1 -1
- autogluon/multimodal/optim/losses/__init__.py +14 -0
- autogluon/multimodal/optim/losses/bce_loss.py +25 -0
- autogluon/multimodal/optim/losses/focal_loss.py +81 -0
- autogluon/multimodal/optim/losses/lemda_loss.py +39 -0
- autogluon/multimodal/optim/losses/rkd_loss.py +103 -0
- autogluon/multimodal/optim/losses/softmax_losses.py +177 -0
- autogluon/multimodal/optim/losses/structure_loss.py +26 -0
- autogluon/multimodal/optim/losses/utils.py +313 -0
- autogluon/multimodal/optim/lr/__init__.py +1 -0
- autogluon/multimodal/optim/lr/utils.py +332 -0
- autogluon/multimodal/optim/metrics/__init__.py +4 -0
- autogluon/multimodal/optim/metrics/coverage_metrics.py +42 -0
- autogluon/multimodal/optim/metrics/hit_rate_metrics.py +78 -0
- autogluon/multimodal/optim/metrics/ranking_metrics.py +231 -0
- autogluon/multimodal/optim/metrics/utils.py +359 -0
- autogluon/multimodal/optim/utils.py +284 -0
- autogluon/multimodal/predictor.py +51 -12
- autogluon/multimodal/utils/__init__.py +19 -45
- autogluon/multimodal/utils/cache.py +23 -2
- autogluon/multimodal/utils/checkpoint.py +58 -5
- autogluon/multimodal/utils/config.py +127 -55
- autogluon/multimodal/utils/device.py +120 -0
- autogluon/multimodal/utils/distillation.py +8 -8
- autogluon/multimodal/utils/download.py +1 -1
- autogluon/multimodal/utils/env.py +22 -0
- autogluon/multimodal/utils/export.py +3 -3
- autogluon/multimodal/utils/hpo.py +5 -5
- autogluon/multimodal/utils/inference.py +37 -4
- autogluon/multimodal/utils/install.py +91 -0
- autogluon/multimodal/utils/load.py +52 -47
- autogluon/multimodal/utils/log.py +6 -41
- autogluon/multimodal/utils/matcher.py +3 -2
- autogluon/multimodal/utils/onnx.py +0 -4
- autogluon/multimodal/utils/path.py +10 -0
- autogluon/multimodal/utils/precision.py +130 -0
- autogluon/multimodal/{presets.py → utils/presets.py} +259 -66
- autogluon/multimodal/{problem_types.py → utils/problem_types.py} +30 -1
- autogluon/multimodal/utils/save.py +47 -29
- autogluon/multimodal/utils/strategy.py +24 -0
- autogluon/multimodal/version.py +1 -1
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/METADATA +5 -5
- autogluon.multimodal-1.2.1b20250305.dist-info/RECORD +163 -0
- autogluon/multimodal/optimization/__init__.py +0 -16
- autogluon/multimodal/optimization/losses.py +0 -394
- autogluon/multimodal/optimization/utils.py +0 -1054
- autogluon/multimodal/utils/cloud_io.py +0 -80
- autogluon/multimodal/utils/data.py +0 -701
- autogluon/multimodal/utils/environment.py +0 -395
- autogluon/multimodal/utils/metric.py +0 -500
- autogluon/multimodal/utils/model.py +0 -558
- autogluon.multimodal-1.2.1b20250303.dist-info/RECORD +0 -145
- /autogluon/multimodal/{optimization → optim}/deepspeed.py +0 -0
- /autogluon/multimodal/{optimization/lr_scheduler.py → optim/lr/lr_schedulers.py} +0 -0
- /autogluon/multimodal/{optimization → optim/metrics}/semantic_seg_metrics.py +0 -0
- /autogluon/multimodal/{registry.py → utils/registry.py} +0 -0
- /autogluon.multimodal-1.2.1b20250303-py3.9-nspkg.pth → /autogluon.multimodal-1.2.1b20250305-py3.9-nspkg.pth +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/LICENSE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/NOTICE +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/WHEEL +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/namespace_packages.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/top_level.txt +0 -0
- {autogluon.multimodal-1.2.1b20250303.dist-info → autogluon.multimodal-1.2.1b20250305.dist-info}/zip-safe +0 -0
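The `{optimization → optim}`, `{presets.py → utils/presets.py}`, and `{problem_types.py → utils/problem_types.py}` moves above change the import paths of any code that referenced those modules directly. A minimal sketch of the implied adjustment; the symbol names (`LitModule`, `get_automm_presets`) are assumptions and are not verified against the new wheel:

```python
# Illustrative only: the module moves listed above, expressed as import paths.
from autogluon.multimodal.optim.lit_module import LitModule  # was autogluon.multimodal.optimization.lit_module
from autogluon.multimodal.utils.presets import get_automm_presets  # was autogluon.multimodal.presets
```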
autogluon/multimodal/models/__init__.py
@@ -1,4 +1,4 @@
-from . import
+from .augmenter import Augmenter
 from .categorical_mlp import CategoricalMLP
 from .clip import CLIPForImageText
 from .document_transformer import DocumentTransformer
@@ -9,7 +9,8 @@ from .fusion import (
     MultimodalFusionNER,
     MultimodalFusionTransformer,
 )
-from .
+from .hf_text import HFAutoModelForTextPrediction
+from .meta_transformer import MetaTransformer
 from .mmdet_image import MMDetAutoModelForObjectDetection
 from .mmocr_text_detection import MMOCRAutoModelForTextDetection
 from .mmocr_text_recognition import MMOCRAutoModelForTextRecognition
@@ -18,4 +19,12 @@ from .numerical_mlp import NumericalMLP
 from .sam import SAMForSemanticSegmentation
 from .t_few import TFewModel
 from .timm_image import TimmAutoModelForImagePrediction
-from .utils import
+from .utils import (
+    create_fusion_model,
+    create_model,
+    get_model_postprocess_fn,
+    is_lazy_weight_tensor,
+    list_timm_models,
+    modify_duplicate_model_names,
+    select_model,
+)
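With the expanded re-exports above, the new model components and the model-construction helpers are importable directly from the `models` subpackage. Illustrative only, assuming the 1.2.1b20250305 wheel is installed:

```python
# Names below are exactly those re-exported by the new models/__init__.py shown above.
from autogluon.multimodal.models import (
    Augmenter,
    HFAutoModelForTextPrediction,
    MetaTransformer,
    create_fusion_model,
    list_timm_models,
)
```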
autogluon/multimodal/models/augmenter.py (new file)
@@ -0,0 +1,175 @@
+import logging
+
+import torch
+import torch.nn as nn
+from omegaconf import DictConfig
+from torch.nn import TransformerEncoder, TransformerEncoderLayer
+
+from .mlp import Unit
+
+logger = logging.getLogger(__name__)
+
+
+class VAETransformer(nn.Module):
+    def __init__(self, config: DictConfig, in_feautres: int, n_modality: int) -> None:
+        super().__init__()
+        self.config = config
+        self.emb_d = in_feautres
+        self.n_modality = n_modality
+        logger.debug(f" VAE Transformer # features {n_modality}, dim {self.emb_d}")
+
+        # encoder
+        encoder_layers = TransformerEncoderLayer(self.emb_d, config.n_head, config.tran_hidden, norm_first=True)
+        self.transformer_encoder = TransformerEncoder(encoder_layers, config.n_layer)
+
+        # encoder linear z
+        self.encoder_fc_z_mu = nn.Linear(self.emb_d, self.config.z_dim)
+        self.encoder_fc_z_logvar = nn.Linear(self.emb_d, self.config.z_dim)
+
+        # decoder linezr z
+        self.decoder_fc = nn.Linear(self.config.z_dim, self.emb_d)
+
+        # decoder
+        decoder_layers = TransformerEncoderLayer(self.emb_d, config.n_head, config.tran_hidden, norm_first=True)
+        self.transformer_decoder = TransformerEncoder(decoder_layers, config.n_layer)
+
+        self.last_layer = nn.Linear(self.emb_d, self.emb_d)
+
+        self.gating = nn.Identity()
+        self.init_parameters()
+
+    def init_parameters(self):
+        self.last_layer.weight.data.zero_()
+        self.last_layer.bias.data.zero_()
+
+    def reparameterize(self, mu, logvar):
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        return mu + eps * std
+
+    def forward(self, X):
+        input = X.reshape(-1, self.n_modality, self.emb_d)  # [B, # modality, emb dim] torch.Size([8, 3, 1024])
+
+        hidden = self.transformer_encoder(input)
+
+        z_mu, z_logvar = self.encoder_fc_z_mu(hidden), self.encoder_fc_z_logvar(hidden)
+
+        z = self.reparameterize(z_mu, z_logvar)
+
+        hidden = self.decoder_fc(z)
+
+        noise = self.gating(self.last_layer(self.transformer_decoder(hidden)[:, : self.n_modality, :]))
+        recon_x = X.reshape(-1, self.n_modality, self.emb_d) + noise
+
+        return recon_x.reshape(len(X), -1), z_mu, z_logvar
+
+
+class MlpVAE(nn.Module):
+    def __init__(self, input_dim, hidden_dim, z_dim=16) -> None:
+        super().__init__()
+        self.input_dim = input_dim
+        self.z_dim = z_dim
+        self.hidden_dim = hidden_dim
+
+        # Encoder P(Z|X)
+        encoder_layers = []
+        dims = [input_dim] + hidden_dim
+        for i in range(len(dims) - 1):
+            encoder_layers.append(
+                Unit(
+                    normalization="layer_norm",
+                    in_features=dims[i],
+                    out_features=dims[i + 1],
+                    activation="relu",
+                    dropout=0.5,
+                )
+            )
+        self.encoder = nn.Sequential(*encoder_layers)
+
+        self.encoder_fc_z_mu = nn.Linear(self.hidden_dim[-1], self.z_dim)
+        self.encoder_fc_z_logvar = nn.Linear(self.hidden_dim[-1], self.z_dim)
+
+        # Decoder P(X|Z)
+        decoder_layers = []
+        dims = [input_dim] + hidden_dim + [z_dim]
+
+        for i in range(len(dims) - 1, 0, -1):
+            decoder_layers.append(
+                Unit(
+                    normalization="layer_norm",
+                    in_features=dims[i],
+                    out_features=dims[i - 1],
+                    activation="relu",
+                    dropout=0.5,
+                )
+            )
+        self.decoder = nn.Sequential(*decoder_layers)
+
+        self.init_parameters()
+
+    def init_parameters(self):
+        self.decoder[-1].fc.weight.data.zero_()
+        self.decoder[-1].fc.bias.data.zero_()
+
+    def reparameterize(self, mu, logvar):
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        return mu + eps * std
+
+    def forward(self, x):
+        hidden = self.encoder(x)
+        z_mu, z_logvar = self.encoder_fc_z_mu(hidden), self.encoder_fc_z_logvar(hidden)
+        z = self.reparameterize(z_mu, z_logvar)
+
+        noise_x = self.decoder(z)
+        recon_x = x + noise_x
+        return recon_x, z_mu, z_logvar
+
+
+class Augmenter(nn.Module):
+    def __init__(
+        self,
+        arch_type: str,
+        input_dim: int,
+        z_dim: int,
+        num_layers: int,
+        adv_weight: float,
+    ) -> None:
+        super().__init__()
+        logger.debug("Initializing Augmenter")
+        self.arch_type = arch_type
+        self.input_dim = input_dim
+        self.z_dim = z_dim
+        self.num_layers = num_layers
+        self.adv_weight = adv_weight
+        logger.debug(f"augmenter arch_type: {self.arch_type}")
+        logger.debug(f"augmenter input_dim: {self.input_dim}")
+        logger.debug(f"augmenter z_dim: {self.z_dim}")
+        logger.debug(f"augmenter num_layers: {self.num_layers}")
+        logger.debug(f"augmenter adv_weight: {self.adv_weight}")
+        if self.arch_type == "mlp_vae":
+            step = int((self.input_dim - self.z_dim) / (self.num_layers + 1))
+            hidden = [*range(self.input_dim - step, self.z_dim + step, -step)]
+            self.vae = MlpVAE(input_dim=self.input_dim, hidden_dim=hidden, z_dim=self.z_dim)
+        else:
+            raise ValueError(f"Unknown arch_type: {self.arch_type}")
+
+        self.name_to_id = self.get_layer_ids()
+
+    def forward(self, x):
+        return self.vae(x)
+
+    def get_layer_ids(
+        self,
+    ):
+        """
+        All layers have the same id 0 since there is no pre-trained models used here.
+
+        Returns
+        -------
+        A dictionary mapping the layer names (keys) to their ids (values).
+        """
+        name_to_id = {}
+        for n, _ in self.named_parameters():
+            name_to_id[n] = 0
+        return name_to_id
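The new `Augmenter` wraps a small VAE over the fused per-modality features: the encoder produces `z_mu` and `z_logvar`, `reparameterize` draws `z = mu + exp(0.5 * logvar) * eps`, and the decoder output is added back onto the input as a learned perturbation. A minimal usage sketch with assumed dimensions (not the package defaults):

```python
import torch

from autogluon.multimodal.models import Augmenter  # re-exported by the new models/__init__.py

# Assumed sizes for illustration: input_dim=64, z_dim=16, num_layers=2 gives
# step = int((64 - 16) / 3) = 16, so hidden_dim = [48] and the encoder maps 64 -> 48 -> (mu, logvar) of size 16.
aug = Augmenter(arch_type="mlp_vae", input_dim=64, z_dim=16, num_layers=2, adv_weight=0.1)
x = torch.randn(8, 64)            # a batch of 8 fused feature vectors
recon_x, z_mu, z_logvar = aug(x)  # reconstruction (input + decoded noise) plus latent statistics
print(recon_x.shape, z_mu.shape)  # torch.Size([8, 64]) torch.Size([8, 16])
```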
autogluon/multimodal/models/categorical_mlp.py
@@ -1,4 +1,5 @@
-
+import logging
+from typing import Dict, Optional

 import torch
 from torch import nn
@@ -7,6 +8,8 @@ from ..constants import CATEGORICAL, FEATURES, LABEL, LOGITS
 from .mlp import MLP
 from .utils import init_weights

+logger = logging.getLogger(__name__)
+

 class CategoricalMLP(nn.Module):
     """
@@ -17,11 +20,11 @@ class CategoricalMLP(nn.Module):
     def __init__(
         self,
         prefix: str,
-        num_categories:
+        num_categories: Dict,
         out_features: Optional[int] = None,
         num_layers: Optional[int] = 1,
         activation: Optional[str] = "gelu",
-
+        dropout: Optional[float] = 0.5,
         normalization: Optional[str] = "layer_norm",
         num_classes: Optional[int] = 0,
     ):
@@ -38,7 +41,7 @@ class CategoricalMLP(nn.Module):
             Number of MLP layers.
         activation
             Name of activation function.
-
+        dropout
             Dropout probability.
         normalization
             Name of normalization function.
@@ -46,15 +49,17 @@ class CategoricalMLP(nn.Module):
             Number of classes. 1 for a regression task.
         """
         super().__init__()
+        logger.debug(f"initializing {prefix} (CategoricalMLP)")
         self.out_features = out_features
         max_embedding_dim = 100
         embed_exponent = 0.56
         size_factor = 1.0
         self.column_embeddings = nn.ModuleList()
         self.column_mlps = nn.ModuleList()
-        assert isinstance(num_categories,
+        assert isinstance(num_categories, dict)
+        self.num_categories = num_categories

-        for num_categories_per_col in num_categories:
+        for num_categories_per_col in num_categories.values():
             embedding_dim_per_col = int(
                 size_factor * max(2, min(max_embedding_dim, 1.6 * num_categories_per_col**embed_exponent))
             )
@@ -72,7 +77,7 @@ class CategoricalMLP(nn.Module):
                     out_features=out_features,
                     num_layers=num_layers,
                     activation=activation,
-
+                    dropout=dropout,
                     normalization=normalization,
                 )
             )
@@ -83,7 +88,7 @@ class CategoricalMLP(nn.Module):
             out_features=out_features,
             num_layers=num_layers,
             activation=activation,
-
+            dropout=dropout,
             normalization=normalization,
         )

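The per-column embedding width computed above grows sublinearly with the cardinality of each categorical column. A worked example with made-up column names and cardinalities:

```python
# Same formula as in CategoricalMLP above, with max_embedding_dim=100,
# embed_exponent=0.56, size_factor=1.0; the columns and counts are hypothetical.
num_categories = {"color": 12, "country": 195}
for col, n in num_categories.items():
    embedding_dim = int(1.0 * max(2, min(100, 1.6 * n**0.56)))
    print(col, embedding_dim)  # color -> 6, country -> 30
```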
autogluon/multimodal/models/clip.py
@@ -5,7 +5,6 @@ import torch
 from torch import nn

 from ..constants import (
-    AUTOMM,
     COLUMN,
     COLUMN_FEATURES,
     FEATURES,
@@ -22,8 +21,12 @@ from .utils import (
     assign_layer_ids,
     get_column_features,
     get_hf_config_and_model,
+    get_image_size_mean_std,
     get_pretrained_tokenizer,
+    get_text_segment_num,
+    get_text_token_max_len,
     init_weights,
+    replace_missing_images_with_learnable,
 )

 logger = logging.getLogger(__name__)
@@ -42,6 +45,15 @@ class CLIPForImageText(nn.Module):
         num_classes: Optional[int] = None,
         pretrained: Optional[bool] = True,
         tokenizer_name: Optional[str] = "clip",
+        has_image: Optional[bool] = True,
+        has_text: Optional[bool] = True,
+        image_size: Optional[int] = None,
+        image_norm: Optional[str] = None,
+        image_chan_num: Optional[int] = 3,
+        use_learnable_image: Optional[bool] = False,
+        max_text_len: Optional[int] = None,
+        text_segment_num: Optional[int] = 1,
+        is_matching: Optional[bool] = False,
     ):
         """
         Load the pretrained CLIP from huggingface transformers.
@@ -60,16 +72,26 @@ class CLIPForImageText(nn.Module):
             Name of the huggingface tokenizer type.
         """
         super().__init__()
-        logger.debug(f"initializing {
+        logger.debug(f"initializing {prefix} (CLIPForImageText)")
+        logger.debug(f"model checkpoint: {checkpoint_name}")
         self.checkpoint_name = checkpoint_name
         self.num_classes = num_classes
+        if is_matching:  # init both image and text attributes for matching
+            has_image, has_text = True, True
+        self.has_image = has_image
+        self.has_text = has_text

         self.config, self.model = get_hf_config_and_model(checkpoint_name=checkpoint_name, pretrained=pretrained)
-
-        self.
-
-
-
+
+        if not self.has_image:
+            self.config.vision_config = None
+            self.model.vision_model = None
+            self.model.visual_projection = None
+
+        if not self.has_text:
+            self.config.text_config = None
+            self.model.text_model = None
+            self.model.text_projection = None

         self.out_features = self.model.config.projection_dim

@@ -77,6 +99,35 @@ class CLIPForImageText(nn.Module):
         self.head.apply(init_weights)

         self.prefix = prefix
+        if has_image:
+            self.image_size, self.image_mean, self.image_std = get_image_size_mean_std(
+                model_name=self.prefix,
+                config=self.model.vision_model.config,
+                provided_size=image_size,
+                provided_norm_type=image_norm,
+                support_variable_input_size=False,
+            )
+            self.use_learnable_image = use_learnable_image
+            if self.use_learnable_image:
+                self.learnable_image = nn.Parameter(torch.zeros(image_chan_num, self.image_size, self.image_size))
+                logger.debug("will use a learnable image to replace missing ones")
+        if has_text:
+            self.tokenizer_name = tokenizer_name
+            self.tokenizer = get_pretrained_tokenizer(
+                tokenizer_name=self.tokenizer_name,
+                checkpoint_name=self.checkpoint_name,
+            )
+            self.max_text_len = get_text_token_max_len(
+                provided_max_len=max_text_len,
+                config=self.model.text_model.config,
+                tokenizer=self.tokenizer,
+                checkpoint_name=self.checkpoint_name,
+            )
+            self.text_segment_num = get_text_segment_num(
+                config=self.model.text_model.config,
+                provided_segment_num=text_segment_num,
+                checkpoint_name=self.checkpoint_name,
+            )

         self.name_to_id = self.get_layer_ids()
         self.head_layer_names = [n for n, layer_id in self.name_to_id.items() if layer_id == 0]
@@ -117,6 +168,15 @@ class CLIPForImageText(nn.Module):
     def image_feature_dim(self):
         return self.model.config.vision_config.hidden_size

+    @property
+    def input_keys(self):
+        ret = []
+        if self.has_image:
+            ret.extend([self.image_key, self.image_valid_num_key])
+        if self.has_text:
+            ret.extend([self.text_token_ids_key, self.text_valid_length_key])
+        return ret
+
     def forward(
         self,
         batch: dict,
@@ -132,8 +192,8 @@ class CLIPForImageText(nn.Module):
         -------
         A dictionary with logits and features.
         """
-        has_image = self.image_key in batch
-        has_text = self.text_token_ids_key in batch
+        has_image = self.has_image and self.image_key in batch
+        has_text = self.has_text and self.text_token_ids_key in batch
         ret = {COLUMN_FEATURES: {FEATURES: {}, MASKS: {}}}

         if has_image:
@@ -141,6 +201,14 @@ class CLIPForImageText(nn.Module):
             image_valid_num = batch[self.image_valid_num_key]
             assert images.dim() == 5
             b, n, c, h, w = images.shape
+            steps = torch.arange(0, n).type_as(image_valid_num)
+            image_masks = steps.reshape((1, -1)) < image_valid_num.reshape((-1, 1))  # (b, n)
+            if self.use_learnable_image:
+                images = replace_missing_images_with_learnable(
+                    images=images,
+                    image_masks=image_masks,
+                    learnable_image=self.learnable_image,
+                )
             vision_outputs = self.model.vision_model(
                 pixel_values=images.reshape((b * n, c, h, w)),
                 output_attentions=True,
@@ -148,9 +216,9 @@ class CLIPForImageText(nn.Module):
                 return_dict=True,
             )
             image_features = self.model.visual_projection(vision_outputs.pooler_output)
-
-            image_features = image_features.
-
+            image_features = image_features.reshape((b, n, -1))  # (b, n, num_features)
+            if not self.use_learnable_image:
+                image_features = image_features * image_masks[:, :, None].type_as(image_features)

             # normalized features
             image_features = image_features / torch.clamp(image_features.norm(dim=-1, keepdim=True), min=1e-6)
@@ -199,18 +267,24 @@ class CLIPForImageText(nn.Module):
             ret[COLUMN_FEATURES][MASKS].update(text_column_feature_masks)
             ret[FEATURES] = text_features

-        if
-            if
+        if self.num_classes:
+            if has_image and has_text:
                 features = image_features + text_features
                 logits = self.head(features)
                 ret[FEATURES] = features
+            elif has_image:
+                logits = self.head(image_features)
+            elif has_text:
+                logits = self.head(text_features)
             else:
+                raise RuntimeError("Neither image or text are used. Must have at least one.")
+            ret[LOGITS] = logits
+        else:
+            ret[LOGIT_SCALE] = self.model.logit_scale.exp()
+            if has_image and has_text:
                 # cosine similarity as logits
                 logits = torch.sum(image_features * text_features, dim=-1)
-
-            ret[LOGITS] = logits
-
-            ret[LOGIT_SCALE] = self.model.logit_scale.exp()
+                ret[LOGITS] = logits

         return {self.prefix: ret}

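The mask logic added to `CLIPForImageText.forward` above distinguishes real images from padded slots in each sample's image stack before the learnable-image substitution and the feature masking. A small standalone illustration of the broadcasted comparison, with made-up values (2 samples, up to n=3 image slots):

```python
import torch

# Mirrors the image_masks construction in the diff above; values are illustrative.
image_valid_num = torch.tensor([2, 1])  # sample 0 has 2 real images, sample 1 has 1
steps = torch.arange(0, 3).type_as(image_valid_num)
image_masks = steps.reshape((1, -1)) < image_valid_num.reshape((-1, 1))  # (b, n)
print(image_masks)
# tensor([[ True,  True, False],
#         [ True, False, False]])
```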