languagebind 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- languagebind/__init__.py +91 -0
- languagebind/_compat.py +24 -0
- languagebind/audio/__init__.py +0 -0
- languagebind/audio/configuration_audio.py +420 -0
- languagebind/audio/modeling_audio.py +1031 -0
- languagebind/audio/processing_audio.py +174 -0
- languagebind/audio/tokenization_audio.py +78 -0
- languagebind/depth/__init__.py +0 -0
- languagebind/depth/configuration_depth.py +415 -0
- languagebind/depth/modeling_depth.py +1031 -0
- languagebind/depth/processing_depth.py +108 -0
- languagebind/depth/tokenization_depth.py +78 -0
- languagebind/image/__init__.py +0 -0
- languagebind/image/configuration_image.py +413 -0
- languagebind/image/modeling_image.py +1031 -0
- languagebind/image/processing_image.py +77 -0
- languagebind/image/tokenization_image.py +78 -0
- languagebind/thermal/__init__.py +0 -0
- languagebind/thermal/configuration_thermal.py +413 -0
- languagebind/thermal/modeling_thermal.py +1031 -0
- languagebind/thermal/processing_thermal.py +77 -0
- languagebind/thermal/tokenization_thermal.py +78 -0
- languagebind/video/__init__.py +0 -0
- languagebind/video/configuration_video.py +413 -0
- languagebind/video/modeling_video.py +1143 -0
- languagebind/video/processing_video.py +174 -0
- languagebind/video/tokenization_video.py +78 -0
- languagebind-0.1.0.dist-info/METADATA +71 -0
- languagebind-0.1.0.dist-info/RECORD +30 -0
- languagebind-0.1.0.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,1143 @@
|
|
|
1
|
+
import math
|
|
2
|
+
from typing import Optional, Tuple, Union
|
|
3
|
+
|
|
4
|
+
import torch
|
|
5
|
+
from einops import rearrange
|
|
6
|
+
from peft import LoraConfig, get_peft_model
|
|
7
|
+
from torch import nn
|
|
8
|
+
from torch.nn import functional as F
|
|
9
|
+
from transformers import PreTrainedModel, add_start_docstrings
|
|
10
|
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
|
11
|
+
from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \
|
|
12
|
+
CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPOutput
|
|
13
|
+
from languagebind._compat import _expand_mask, clip_loss
|
|
14
|
+
from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
|
|
15
|
+
|
|
16
|
+
from .configuration_video import LanguageBindVideoConfig, CLIPVisionConfig, CLIPTextConfig
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class CLIPVisionEmbeddings(nn.Module):
    """Per-image CLIP embeddings: a learned class token, non-overlapping patch projections, and learned positions.

    Note: this class shadows the ``CLIPVisionEmbeddings`` imported from ``transformers`` at the top of the module.
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        # Non-overlapping patches: kernel size == stride == patch_size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1  # +1 for the class token
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            pixel_values: ``(batch, num_channels, height, width)`` images; they are cast to the patch
                projection's dtype so that e.g. fp32 inputs work with a half-precision model.

        Returns:
            ``(batch, num_patches + 1, embed_dim)``, with the class token first.
        """
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))  # [*, width, grid, grid]
        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)  # [*, grid * grid, width]

        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
        embeddings = embeddings + self.position_embedding(self.position_ids)
        return embeddings
|
|
53
|
+
|
|
54
|
+
class CLIPVisionEmbeddings3D(nn.Module):
    """Patch, class and position embeddings for video clips.

    The module is built with a 2D ``Conv2d`` patch projection (the same layout as ``CLIPVisionEmbeddings``).
    :meth:`expand3d` then inflates it into a ``Conv3d`` whose temporal kernel and stride are ``tube_size``.
    It also gives every temporal slot its own class token.
    Frames arrive flattened into the batch dimension as ``(batch * num_frames, C, H, W)``.
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.num_frames = config.num_frames
        # Number of consecutive frames folded into one temporal slot (1 if the config has no `tube_size`).
        self.tube_size = getattr(config, 'tube_size', 1)

        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))

        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.num_patches = (self.image_size // self.patch_size) ** 2
        self.num_positions = self.num_patches + 1
        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
        self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))

        self.expand3d()

    def expand3d(self):
        """Replace the 2D patch conv with a 3D conv and give each temporal slot its own class token.

        The 2D kernel becomes the first temporal slice of the 3D kernel, and the remaining
        ``tube_size - 1`` slices are zero. At construction time the 3D conv therefore reproduces
        the 2D projection of the first frame of each tube.
        """
        weight_2d = self.patch_embedding.state_dict()['weight']
        weight_5d = weight_2d.unsqueeze(2)  # (out, in, 1, k, k)
        device, dtype = weight_5d.device, weight_5d.dtype

        zero_slice = torch.zeros_like(weight_5d)
        weight_3d = torch.cat([weight_5d] + (self.tube_size - 1) * [zero_slice], dim=2)

        patch_embedding = nn.Conv3d(
            in_channels=self.patch_embedding.in_channels,
            out_channels=self.embed_dim,
            kernel_size=(self.tube_size, self.patch_size, self.patch_size),
            stride=(self.tube_size, self.patch_size, self.patch_size),
            bias=False,
        ).to(device=device, dtype=dtype)
        patch_embedding.load_state_dict({'weight': weight_3d})
        self.patch_embedding = patch_embedding

        # Move/cast *before* wrapping in nn.Parameter. Calling `.to()` on a Parameter returns a plain
        # tensor whenever the device/dtype change, and nn.Module refuses to assign a plain tensor to the
        # registered `class_embedding` parameter.
        num_slots = self.num_frames // self.tube_size
        self.class_embedding = nn.Parameter(
            self.class_embedding.data.repeat(num_slots, 1).to(device=device, dtype=dtype)
        )

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed a batch of clips.

        Args:
            pixel_values: ``(batch * num_frames, C, H, W)``, with the frames of one clip contiguous.
                They are cast to the patch projection's dtype.

        Returns:
            ``(batch * (num_frames // tube_size), num_patches + 1, embed_dim)``, with the class token first
            in every temporal slot.
        """
        batch_size = pixel_values.shape[0] // self.num_frames
        target_dtype = self.patch_embedding.weight.dtype
        pixel_values = rearrange(
            pixel_values.to(dtype=target_dtype), '(b t) c h w -> b c t h w', b=batch_size, t=self.num_frames
        )
        patch_embeds = self.patch_embedding(pixel_values)  # b, width, t', grid, grid
        patch_embeds = rearrange(patch_embeds, 'b c t h w -> b t (h w) c')

        class_embeds = self.class_embedding.unsqueeze(1).unsqueeze(0).repeat(batch_size, 1, 1, 1)  # b t' 1 c
        embeddings = torch.cat([class_embeds, patch_embeds], dim=2)  # b t' hw+1 c
        embeddings = embeddings + self.position_embedding(self.position_ids)  # shared across slots
        return rearrange(embeddings, 'b t hw_1 c -> (b t) hw_1 c')
|
|
125
|
+
|
|
126
|
+
class PatchDropout(nn.Module):
    """
    https://arxiv.org/abs/2212.00794

    During training, keep a random subset of ``max(1, int(num_tokens * (1 - prob)))`` tokens per row.
    When ``exclude_first_token`` is set, the first (CLS) token is always kept and prepended again.
    For video input (``T > 1``), the rows are ``B`` clips of ``T`` contiguous frames, and every frame
    of a clip keeps the same token positions.
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not 0 <= prob < 1.:
            raise ValueError(f"PatchDropout prob must be in [0, 1), got {prob}")
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token

    def forward(self, x, B, T):
        """Drop tokens from ``x`` of shape ``(rows, tokens, dim)``; identity in eval mode or when ``prob == 0``.

        Args:
            x: token tensor; for ``T > 1`` its rows are expected to be ``B * T`` frames, clip-major.
            B: number of clips (only used when ``T > 1``).
            T: frames per clip.
        """
        if not self.training or self.prob == 0.:
            return x

        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch, num_tokens = x.shape[0], x.shape[1]
        device = x.device
        # Build indices/noise on x's device to avoid a host-to-device copy every step.
        batch_indices = torch.arange(batch, device=device)[..., None]

        keep_prob = 1 - self.prob
        num_patches_keep = max(1, int(num_tokens * keep_prob))

        if T == 1:
            rand = torch.randn(batch, num_tokens, device=device)
            patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
        else:
            # One draw per clip, repeated for each of its T frames: (B, k) -> (B * T, k), clip-major.
            rand = torch.randn(B, num_tokens, device=device)
            patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices.repeat_interleave(T, dim=0)

        x = x[batch_indices, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        return x
|
|
171
|
+
|
|
172
|
+
class CLIPEncoderLayer(nn.Module):
    """Pre-LayerNorm CLIP transformer block, optionally preceded by a temporal-attention step.

    When ``config.add_time_attn`` is set, the input is ``(batch * frames, tokens, embed_dim)`` with the
    frames of one clip contiguous. A learned temporal embedding is added first. Each token position
    then attends across the frames of its clip, and finally the usual spatial self-attention and MLP
    run per frame.
    """

    def __init__(self, config: LanguageBindVideoConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        self.add_time_attn = config.add_time_attn
        if self.add_time_attn:
            # Frames per clip as seen by the encoder. CLIPVisionEmbeddings3D folds `tube_size` frames into one
            # temporal slot (same `getattr` default of 1), so the encoder sees num_frames // tube_size per clip.
            self.t = config.num_frames // getattr(config, 'tube_size', 1)
            # Kept at the raw `num_frames` length so the parameter shape (and saved state_dicts) is unchanged;
            # only the first `self.t` rows are used.
            self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size))
            nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5)

            self.temporal_attn = CLIPAttention(config)
            self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`.
                With `add_time_attn` the leading dimension is `batch * frames`, with the frames of one clip
                contiguous.
            attention_mask (`torch.FloatTensor`): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                Applied to the spatial self-attention only.
            causal_attention_mask (`torch.FloatTensor`): additive causal mask with the same layout, applied to the
                spatial self-attention only.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, followed by the spatial attention weights when `output_attentions` is true.
        """
        if self.add_time_attn:
            t = self.t
            n = hidden_states.shape[1]

            # Temporal embedding, broadcast over every token position.
            if t != 1:
                hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
                hidden_states = hidden_states + self.temporal_embedding[:, :t, :]
                hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)

            # Temporal attention: each token position attends over the frames of its own clip.
            # The incoming masks are laid out for spatial (token) attention. They cannot match the
            # `(b n) t t` attention computed here, so none are applied.
            residual = hidden_states
            hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
            hidden_states = self.temporal_layer_norm1(hidden_states)
            hidden_states, attn_weights = self.temporal_attn(
                hidden_states=hidden_states,
                attention_mask=None,
                causal_attention_mask=None,
                output_attentions=output_attentions,
            )
            hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)

        # Spatial self-attention (pre-LN, residual).
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        # Feed-forward (pre-LN, residual).
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
class CLIPPreTrainedModel(PreTrainedModel):
    """
    Shared base for the models in this module.

    It binds the config class and implements CLIP-style weight initialization. It also provides the
    hook that switches gradient checkpointing on the encoder.
    """

    config_class = LanguageBindVideoConfig
    base_model_prefix = "clip"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialize the weights of ``module`` in place according to its type."""
        factor = self.config.initializer_factor

        if isinstance(module, CLIPTextEmbeddings):
            for table in (module.token_embedding, module.position_embedding):
                table.weight.data.normal_(mean=0.0, std=factor * 0.02)
        elif isinstance(module, CLIPVisionEmbeddings):
            embed_std = module.config.initializer_range * factor
            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            nn.init.normal_(module.patch_embedding.weight, std=embed_std)
            nn.init.normal_(module.position_embedding.weight, std=embed_std)
        elif isinstance(module, CLIPAttention):
            width_scale = module.embed_dim**-0.5
            depth_scale = (2 * module.config.num_hidden_layers) ** -0.5
            for projection in (module.q_proj, module.k_proj, module.v_proj):
                nn.init.normal_(projection.weight, std=width_scale * depth_scale * factor)
            nn.init.normal_(module.out_proj.weight, std=width_scale * factor)
        elif isinstance(module, CLIPMLP):
            hidden = module.config.hidden_size
            fc2_std = (hidden**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc1_std = (2 * hidden) ** -0.5 * factor
            nn.init.normal_(module.fc1.weight, std=fc1_std)
            nn.init.normal_(module.fc2.weight, std=fc2_std)
        elif isinstance(module, LanguageBindVideo):
            nn.init.normal_(module.text_projection.weight, std=module.text_embed_dim**-0.5 * factor)
            nn.init.normal_(module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * factor)
        elif isinstance(module, CLIPVisionModelWithProjection):
            nn.init.normal_(module.visual_projection.weight, std=self.config.hidden_size**-0.5 * factor)
        elif isinstance(module, CLIPTextModelWithProjection):
            nn.init.normal_(module.text_projection.weight, std=self.config.hidden_size**-0.5 * factor)

        # Independent of the type-specific branch above.
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Set the ``gradient_checkpointing`` flag on ``CLIPEncoder`` modules; other modules are left untouched."""
        if not isinstance(module, CLIPEncoder):
            return
        module.gradient_checkpointing = value
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
# Class-level documentation fragment; passed to `add_start_docstrings` on the model classes below
# (e.g. the text model) so it is prepended to their docstrings.
CLIP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`CLIPConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Argument documentation for text-only forwards; injected via `add_start_docstrings_to_model_forward`
# (used by `CLIPTextTransformer.forward` and the text model's `forward`).
CLIP_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Argument documentation for vision-only forwards; presumably used by the vision model defined later
# in this file (not visible here).
# NOTE(review): the `pixel_values` shape below is the image-CLIP layout. The video embeddings in this
# file consume frames flattened as `(batch * num_frames, ...)`. Confirm against the vision forward.
CLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

# Argument documentation for the joint text+vision forward (includes `return_loss`); presumably used
# by the combined model defined later in this file (not visible here).
CLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
|
442
|
+
|
|
443
|
+
|
|
444
|
+
class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        config: model configuration providing `num_hidden_layers`, the output flags read in `forward`,
            and the per-layer settings consumed by `CLIPEncoderLayer`.
    """

    def __init__(self, config: LanguageBindVideoConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        # Set through `_set_gradient_checkpointing`; only takes effect in training mode.
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Embedded input sequence fed to the first layer.
            attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
                Additive padding mask, already expanded (e.g. by `_expand_mask`). It is `0` where attention is
                allowed and a very large negative value at masked positions.
            causal_attention_mask (`torch.Tensor` of shape `(batch_size, 1, sequence_length, sequence_length)`, *optional*):
                Additive causal mask for the text model (see `_make_causal_mask`), using the same convention.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        # Imported explicitly: `import torch` alone does not guarantee that the `torch.utils.checkpoint`
        # submodule has been loaded.
        from torch.utils.checkpoint import checkpoint

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        def create_custom_forward(module):
            # `checkpoint` forwards positional tensors only; bind `output_attentions` here.
            def custom_forward(*inputs):
                return module(*inputs, output_attentions)

            return custom_forward

        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            if self.gradient_checkpointing and self.training:
                layer_outputs = checkpoint(
                    create_custom_forward(encoder_layer),
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                )
            else:
                layer_outputs = encoder_layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
|
|
545
|
+
|
|
546
|
+
|
|
547
|
+
# Adapted from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Build the additive causal (uni-directional) self-attention mask.

    Args:
        input_ids_shape: `(batch_size, tgt_len)`.
        dtype: dtype of the returned mask (normally the hidden-state dtype).
        device: device of the returned mask.
        past_key_values_length: number of cached key positions prepended to the mask. Every query may
            attend to all of them.

    Returns:
        A tensor of shape `(batch_size, 1, tgt_len, tgt_len + past_key_values_length)`. It is `0` where
        query `i` may attend key `j` (`j <= i` among the new positions) and `torch.finfo(dtype).min` elsewhere.
    """
    bsz, tgt_len = input_ids_shape
    # Allocate directly in `dtype`. Filling a default-dtype (typically float32) tensor first cannot hold
    # e.g. float64's `finfo.min`.
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, dtype=dtype, device=device)
    mask_cond = torch.arange(tgt_len, device=device)
    # Zero the lower triangle including the diagonal (key index <= query index).
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(tgt_len, 1), 0)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
|
563
|
+
|
|
564
|
+
|
|
565
|
+
class CLIPTextTransformer(nn.Module):
    """Text tower: token/position embeddings, then a causal `CLIPEncoder`, then a final LayerNorm.

    The pooled output is the hidden state at the position of the largest token id in each sequence.
    """

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.config = config
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Encode token ids and pool each sequence at its end-of-text position.

        Returns:

        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        original_shape = input_ids.size()
        token_ids = input_ids.view(-1, original_shape[-1])

        embedded = self.embeddings(input_ids=token_ids, position_ids=position_ids)
        act_dtype, act_device = embedded.dtype, embedded.device

        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_mask = _make_causal_mask(original_shape, act_dtype, device=act_device)
        # Padding mask: [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        expanded_mask = None if attention_mask is None else _expand_mask(attention_mask, act_dtype)

        encoded = self.encoder(
            inputs_embeds=embedded,
            attention_mask=expanded_mask,
            causal_attention_mask=causal_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = self.final_layer_norm(encoded[0])

        # Pool at the end-of-text token, i.e. the highest id in each sequence.
        # Cast to torch.int for ONNX: argmax doesn't support int64 inputs with opset 14.
        out_device = sequence_output.device
        row_index = torch.arange(sequence_output.shape[0], device=out_device)
        eot_index = token_ids.to(dtype=torch.int, device=out_device).argmax(dim=-1)
        pooled_output = sequence_output[row_index, eot_index]

        if not return_dict:
            return (sequence_output, pooled_output) + encoded[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoded.hidden_states,
            attentions=encoded.attentions,
        )
|
|
640
|
+
|
|
641
|
+
|
|
642
|
+
@add_start_docstrings(
    """The text model from CLIP without any head or projection on top.""",
    CLIP_START_DOCSTRING,
)
class CLIPTextModel(CLIPPreTrainedModel):
    # Standalone CLIP text encoder: a `CLIPPreTrainedModel` wrapper that owns a
    # single `CLIPTextTransformer` under the `text_model` attribute and forwards
    # every call to it.
    config_class = CLIPTextConfig

    # Module class names that Hugging Face device-map sharding must not split.
    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = CLIPTextTransformer(config)
        # Base-class hook: weight initialisation and final processing.
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token-embedding module of the wrapped text transformer."""
        embeddings = self.text_model.embeddings
        return embeddings.token_embedding

    def set_input_embeddings(self, value):
        """Install ``value`` as the token-embedding module of the wrapped text transformer."""
        embeddings = self.text_model.embeddings
        embeddings.token_embedding = value

    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Encode token ids with the wrapped `CLIPTextTransformer`.

        Only `return_dict` is defaulted here (from `config.use_return_dict`); the other
        optional flags are passed through unchanged.

        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPTextModel

        >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        text_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        return self.text_model(**text_kwargs)
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
class CLIPVisionTransformer(nn.Module):
    # CLIP vision tower adapted for video. Frames are folded into the batch
    # dimension before patch embedding, and the pooled (CLS) feature is averaged
    # over the frame axis at the end.
    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size
        vl_new = getattr(config, 'clip_type', 'vl') == 'vl_new'
        add_time_attn = config.add_time_attn

        # The embeddings used to be assigned only inside `if add_time_attn:`,
        # which left one configuration without `self.embeddings`. forward()
        # then failed with AttributeError. The 3D embeddings are used for the
        # vl_new + temporal-attention combination; every other configuration
        # gets the standard 2D patch embeddings.
        if add_time_attn and vl_new:
            self.embeddings = CLIPVisionEmbeddings3D(config)
        else:
            self.embeddings = CLIPVisionEmbeddings(config)

        self.patch_dropout = PatchDropout(config.force_patch_dropout)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = CLIPEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @staticmethod
    def _fold_frames(pixel_values: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """Flatten the frame axis into the batch axis.

        Returns ``(frames, B, T)``. ``frames`` has shape ``(B * T, channels, height, width)``
        and its leading axis is ordered ``(B, T)``, i.e. all frames of one sample are
        contiguous. Raises ``ValueError`` for tensors that are not 4D, 5D or 7D.
        """
        ndim = pixel_values.dim()
        if ndim == 7:
            b_new, pair_new, T, bs_new, channel_new, h_new, w_new = pixel_values.shape
            B = b_new * pair_new * bs_new
            # Move the frame axis behind `bs` before flattening. Reshaping the raw
            # (b, pair, T, bs, ...) layout would order the batch as (b, pair, T, bs),
            # which breaks the (B, T) grouping that the 5D path produces and that
            # the pooling reshape below relies on whenever bs_new > 1.
            pixel_values = pixel_values.permute(0, 1, 3, 2, 4, 5, 6).reshape(B * T, channel_new, h_new, w_new)
        elif ndim == 5:
            B, _, T, _, _ = pixel_values.shape
            pixel_values = rearrange(pixel_values, 'b c t h w -> (b t) c h w')
        elif ndim == 4:
            # Plain images: one "frame" per sample.
            B = pixel_values.shape[0]
            T = 1
        else:
            raise ValueError(
                f"pixel_values must be a 4D, 5D or 7D tensor, got shape {tuple(pixel_values.shape)}"
            )
        return pixel_values, B, T

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Encode images or video clips.

        `pixel_values` may take one of three layouts:
        - 4D `(batch, channels, height, width)`: images, one frame each.
        - 5D `(batch, channels, frames, height, width)`.
        - 7D `(batch, pair, frames, bs, channels, height, width)`.

        Frames are folded into the batch axis in `(batch, frames)` order.
        `last_hidden_state` is therefore per frame, while `pooler_output` is the
        mean of the per-frame post-layernormed CLS features.

        Raises `ValueError` when `pixel_values` is `None` or has an unsupported rank.

        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        pixel_values, B, T = self._fold_frames(pixel_values)

        hidden_states = self.embeddings(pixel_values)
        # PatchDropout receives the (B, T) split of the folded batch axis.
        hidden_states = self.patch_dropout(hidden_states, B, T)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        pooled_output = last_hidden_state[:, 0, :]
        pooled_output = self.post_layernorm(pooled_output)

        # Un-fold the frame axis and average the per-frame CLS features.
        pooled_output = pooled_output.reshape(B, T, -1).mean(1)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
@add_start_docstrings(
    """The vision model from CLIP without any head or projection on top.""",
    CLIP_START_DOCSTRING,
)
class CLIPVisionModel(CLIPPreTrainedModel):
    # Standalone vision tower: a `CLIPPreTrainedModel` wrapper that owns one
    # `CLIPVisionTransformer` as `self.vision_model` and delegates to it.
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)
        self.vision_model = CLIPVisionTransformer(config)
        # Base-class hook: weight initialisation and final processing.
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch-embedding module of the wrapped vision transformer."""
        embeddings = self.vision_model.embeddings
        return embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Encode `pixel_values` with the wrapped `CLIPVisionTransformer`.

        Only `return_dict` is defaulted here (from `config.use_return_dict`); the other
        optional flags are passed through unchanged.

        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPVisionModel

        >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        vision_kwargs = {
            "pixel_values": pixel_values,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        return self.vision_model(**vision_kwargs)
|
|
844
|
+
|
|
845
|
+
|
|
846
|
+
@add_start_docstrings(CLIP_START_DOCSTRING)
class LanguageBindVideo(CLIPPreTrainedModel):
    # Dual-encoder (text / video) CLIP model. Each tower's pooled output is
    # linearly projected to `projection_dim`, and the similarity logits are
    # scaled by exp(logit_scale). The vision encoder may additionally be
    # wrapped with PEFT LoRA adapters (see `convert_to_lora`).
    config_class = LanguageBindVideoConfig

    def __init__(self, config: LanguageBindVideoConfig):
        super().__init__(config)

        text_config = config.text_config
        vision_config = config.vision_config

        # Reject wrongly-typed sub-configs up front with a readable message.
        if not isinstance(text_config, CLIPTextConfig):
            raise ValueError(
                "config.text_config is expected to be of type CLIPTextConfig but is of type"
                f" {type(text_config)}."
            )
        if not isinstance(vision_config, CLIPVisionConfig):
            raise ValueError(
                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
                f" {type(vision_config)}."
            )

        # Temporal-attention and LoRA settings are read from the vision sub-config.
        self.add_time_attn = vision_config.add_time_attn
        self.lora_r, self.lora_alpha, self.lora_dropout = (
            vision_config.lora_r,
            vision_config.lora_alpha,
            vision_config.lora_dropout,
        )

        self.projection_dim = config.projection_dim
        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size

        # Sub-modules are registered in this fixed order.
        self.text_model = CLIPTextTransformer(text_config)
        self.vision_model = CLIPVisionTransformer(vision_config)

        self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
        self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
        self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

        # Weight initialisation runs first; LoRA wrapping of the vision encoder comes after it.
        self.post_init()
        self.convert_to_lora()
        # NOTE: resize_pos() is intentionally not invoked here.

    def convert_to_lora(self):
        """Wrap ``self.vision_model.encoder`` with PEFT LoRA adapters.

        Does nothing when ``self.lora_r == 0``. When temporal attention is enabled, only
        the temporal attention projections and the temporal MLP layers are targeted.
        Otherwise the generic ``k_proj``/``v_proj``/``q_proj``/``out_proj`` names are.
        """
        if self.lora_r == 0:
            return

        if self.add_time_attn:
            targets = [
                "temporal_attn.k_proj",
                "temporal_attn.v_proj",
                "temporal_attn.q_proj",
                "temporal_attn.out_proj",
                "temporal_mlp.fc1",
                "temporal_mlp.fc2",
            ]
        else:
            targets = ["k_proj", "v_proj", "q_proj", "out_proj"]

        lora_config = LoraConfig(
            r=self.lora_r,
            lora_alpha=self.lora_alpha,
            target_modules=targets,
            lora_dropout=self.lora_dropout,
            bias="none",
            modules_to_save=[],
        )
        encoder = self.vision_model.encoder
        encoder.is_gradient_checkpointing = False
        self.vision_model.encoder = get_peft_model(encoder, lora_config)

    def resize_pos(self, m, vision_config):
        """Resize ``m.position_embedding`` in place to match ``m``'s patch grid.

        ``m`` is presumably a vision-embeddings module exposing ``image_size``,
        ``patch_size``, ``embed_dim``, ``config`` and ``position_embedding`` — TODO confirm.
        When ``vision_config.num_mel_bins`` and ``vision_config.target_length`` are both
        non-zero, they override ``m.image_size`` as ``[num_mel_bins, target_length]``.
        The table is split into one leading class-token row and the grid rows. The
        grid rows (assumed to form a square old grid) are bicubically interpolated to
        the new grid and the class-token row is re-attached. Returns early if the
        sequence length already matches.
        """
        if vision_config.num_mel_bins != 0 and vision_config.target_length != 0:
            m.image_size = [vision_config.num_mel_bins, vision_config.target_length]
        size = m.image_size
        m.config.image_size = size if not isinstance(size, int) else [size, size]

        old_state = m.position_embedding.state_dict()
        old_table = old_state['weight']
        dtype = old_table.dtype
        height, width = m.config.image_size[0], m.config.image_size[1]
        grid_size = [height // m.patch_size, width // m.patch_size]
        extra_tokens = 1  # FIXME: prefix-token count is hard-coded (no support for 0 or >1 class tokens)
        num_grid_positions = grid_size[0] * grid_size[1]
        if num_grid_positions + extra_tokens == old_table.shape[0]:
            return

        m.num_patches = num_grid_positions
        m.num_positions = m.num_patches + 1
        m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1)))
        new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim)

        if extra_tokens:
            prefix_rows, grid_rows = old_table[:extra_tokens], old_table[extra_tokens:]
        else:
            prefix_rows, grid_rows = None, old_table
        old_side = int(math.sqrt(len(grid_rows)))

        # (N, D) -> (1, D, side, side) so the table can be resampled like an image.
        grid_image = grid_rows.reshape(1, old_side, old_side, -1).permute(0, 3, 1, 2)
        grid_image = F.interpolate(
            grid_image,
            size=grid_size,
            mode='bicubic',
            antialias=True,
            align_corners=False,
        )
        grid_rows = grid_image.permute(0, 2, 3, 1).reshape(1, num_grid_positions, -1)[0]

        new_table = grid_rows if prefix_rows is None else torch.cat([prefix_rows, grid_rows], dim=0)
        old_state['weight'] = new_table.to(dtype)
        m.position_embedding = new_position_embedding
        m.position_embedding.load_state_dict(old_state)

    def _resolve_output_flags(self, output_attentions, output_hidden_states, return_dict):
        """Replace each ``None`` flag with its ``self.config`` default and return all three."""
        cfg = self.config
        return (
            cfg.output_attentions if output_attentions is None else output_attentions,
            cfg.output_hidden_states if output_hidden_states is None else output_hidden_states,
            cfg.use_return_dict if return_dict is None else return_dict,
        )

    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Project the text tower's pooled output with `text_projection` (no normalisation).

        Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
            applying the projection layer to the pooled output of [`CLIPTextModel`].
        """
        # Flags default to the combined model config rather than the tower configs.
        output_attentions, output_hidden_states, return_dict = self._resolve_output_flags(
            output_attentions, output_hidden_states, return_dict
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Index 1 is the pooled output for both tuple and ModelOutput returns.
        return self.text_projection(text_outputs[1])

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        r"""
        Project the vision tower's (frame-averaged) pooled output with `visual_projection`.

        Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
            applying the projection layer to the pooled output of [`CLIPVisionModel`].
        """
        output_attentions, output_hidden_states, return_dict = self._resolve_output_flags(
            output_attentions, output_hidden_states, return_dict
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return self.visual_projection(vision_outputs[1])

    @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CLIPOutput, config_class=LanguageBindVideoConfig)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPOutput]:
        r"""
        Encode both modalities and compute scaled cosine-similarity logits.

        The vision tower runs first, then the text tower. Both projected embeddings
        are L2-normalised. `logits_per_text` is `text @ image.T * exp(logit_scale)`,
        and `logits_per_image` is its transpose. `clip_loss` is computed only when
        `return_loss` is truthy.

        Returns:

        Examples (upstream Hugging Face CLIP example; it uses `CLIPModel`, not this class):

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPModel

        >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(
        ...     text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
        ... )

        >>> outputs = model(**inputs)
        >>> logits_per_image = outputs.logits_per_image  # this is the image-text similarity score
        >>> probs = logits_per_image.softmax(dim=1)  # we can take the softmax to get the label probabilities
        ```"""
        output_attentions, output_hidden_states, return_dict = self._resolve_output_flags(
            output_attentions, output_hidden_states, return_dict
        )

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = self.visual_projection(vision_outputs[1])
        text_embeds = self.text_projection(text_outputs[1])

        # Unit-normalise so the matmul below yields cosine similarities.
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)

        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp()
        logits_per_image = logits_per_text.t()

        loss = clip_loss(logits_per_text) if return_loss else None

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return output if loss is None else (loss,) + output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
|