languagebind 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1031 @@
1
+ import math
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from peft import LoraConfig, get_peft_model
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+ from transformers import PreTrainedModel, add_start_docstrings
10
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
11
+ from transformers.models.clip.modeling_clip import CLIPMLP, CLIPAttention, CLIPTextEmbeddings, CLIPVisionEmbeddings, \
12
+ CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPOutput
13
+ from languagebind._compat import _expand_mask, clip_loss
14
+ from transformers.utils import add_start_docstrings_to_model_forward, replace_return_docstrings
15
+
16
+ from .configuration_image import LanguageBindImageConfig, CLIPVisionConfig, CLIPTextConfig
17
+
18
+
19
+
20
class PatchDropout(nn.Module):
    """
    Randomly drop patch tokens while training (https://arxiv.org/abs/2212.00794).

    Args:
        prob: fraction of tokens to drop; must lie in ``[0, 1)``.
        exclude_first_token: when True the first token (CLS) is always kept and dropout only
            chooses among the remaining tokens.

    Raises:
        ValueError: if ``prob`` is outside ``[0, 1)``.
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if not 0 <= prob < 1.0:
            raise ValueError(f"PatchDropout prob must be in [0, 1), got {prob}")
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token

    def forward(self, x, B, T):
        """
        Keep a random subset of ``max(1, int(num_tokens * (1 - prob)))`` tokens per row.

        Args:
            x: token tensor indexed as ``(batch, num_tokens, ...)``. When ``T > 1``, ``batch`` must
                equal ``B * T`` with the ``T`` frames of each clip stored contiguously.
            B: number of clips (only used when ``T > 1``).
            T: frames per clip. With ``T > 1`` every frame of a clip keeps the same token positions.

        Returns:
            ``x`` unchanged in eval mode or when ``prob == 0``; otherwise the kept tokens, with the
            first token re-prepended when ``exclude_first_token`` is set.
        """
        if not self.training or self.prob == 0.0:
            return x

        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            # Presumably kept so TorchScript sees ``cls_tokens`` bound on every path.
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch, num_tokens = x.shape[0], x.shape[1]
        keep_prob = 1 - self.prob
        num_patches_keep = max(1, int(num_tokens * keep_prob))

        # Build random scores and gather indices on x's device to avoid a per-step host->device copy.
        device = x.device
        batch_indices = torch.arange(batch, device=device)[..., None]

        if T == 1:
            rand = torch.randn(batch, num_tokens, device=device)
            patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
        else:
            # One draw per clip, repeated for each of its T contiguous frames: (B, k) -> (B * T, k).
            rand = torch.randn(B, num_tokens, device=device)
            patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices.repeat_interleave(T, dim=0)

        x = x[batch_indices, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        return x
65
+
66
class CLIPEncoderLayer(nn.Module):
    """
    Pre-norm CLIP transformer block with an optional temporal stage.

    When ``config.add_time_attn`` is set, each forward pass first adds a learned temporal embedding and
    runs a temporal attention + MLP (with their own layer norms) across the ``config.num_frames`` frames
    of every spatial token, then runs the usual spatial self-attention + MLP.
    """

    def __init__(self, config: LanguageBindImageConfig):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.self_attn = CLIPAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
        self.mlp = CLIPMLP(config)
        self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

        self.add_time_attn = config.add_time_attn
        if self.add_time_attn:
            # Construction order is unchanged so random initialisation stays reproducible.
            self.t = config.num_frames
            self.temporal_embedding = nn.Parameter(torch.zeros(1, config.num_frames, config.hidden_size))
            nn.init.normal_(self.temporal_embedding, std=config.hidden_size ** -0.5)

            self.temporal_attn = CLIPAttention(config)
            self.temporal_layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
            self.temporal_mlp = CLIPMLP(config)
            self.temporal_layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)

    def _temporal_forward(self, hidden_states, output_attentions):
        """Temporal embedding, then pre-norm temporal attention and MLP over the frames of each token."""
        t = self.t
        n = hidden_states.shape[1]

        # time embed
        if t != 1:
            hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
            hidden_states = hidden_states + self.temporal_embedding[:, :t, :]
            hidden_states = rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)

        # time attn
        residual = hidden_states
        hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
        hidden_states = self.temporal_layer_norm1(hidden_states)
        # The layer's masks are laid out for the n-token spatial sequence (batch, 1, n, n) and cannot
        # apply to a length-t frame sequence, so temporal attention is always unmasked.
        hidden_states, _ = self.temporal_attn(
            hidden_states=hidden_states,
            attention_mask=None,
            causal_attention_mask=None,
            output_attentions=output_attentions,
        )
        hidden_states = residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)

        # time mlp
        residual = hidden_states
        hidden_states = rearrange(hidden_states, '(b t) n d -> (b n) t d', t=t)
        hidden_states = self.temporal_layer_norm2(hidden_states)
        hidden_states = self.temporal_mlp(hidden_states)
        return residual + rearrange(hidden_states, '(b n) t d -> (b t) n d', n=n)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        causal_attention_mask: torch.Tensor,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`.
                With temporal attention enabled, `batch` must be `num_clips * config.num_frames` with the
                frames of each clip stored contiguously (the `(b t)` layout).
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
                Applied to the spatial attention only.
            causal_attention_mask (`torch.FloatTensor`, *optional*): causal mask with the same layout as
                `attention_mask`. Applied to the spatial attention only.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.

        Returns:
            `(hidden_states,)`, extended with the spatial attention weights when `output_attentions` is set.
        """
        if self.add_time_attn:
            hidden_states = self._temporal_forward(hidden_states, output_attentions)

        # spatial attn
        residual = hidden_states
        hidden_states = self.layer_norm1(hidden_states)
        hidden_states, attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
        )
        hidden_states = residual + hidden_states

        # spatial mlp
        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
160
+
161
+
162
+
163
+
164
+
165
+
166
+
167
+
168
+
169
class CLIPPreTrainedModel(PreTrainedModel):
    """
    Abstract base that plugs the CLIP/LanguageBind modules of this file into the Hugging Face
    `PreTrainedModel` machinery (weight initialisation, downloading and loading pretrained weights).
    """

    config_class = LanguageBindImageConfig
    base_model_prefix = "clip"
    supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
        """Initialise ``module`` in place with the CLIP scheme, scaled by ``config.initializer_factor``."""
        factor = self.config.initializer_factor

        if isinstance(module, CLIPTextEmbeddings):
            embed_std = factor * 0.02
            module.token_embedding.weight.data.normal_(mean=0.0, std=embed_std)
            module.position_embedding.weight.data.normal_(mean=0.0, std=embed_std)
        elif isinstance(module, CLIPVisionEmbeddings):
            range_std = module.config.initializer_range * factor
            nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
            nn.init.normal_(module.patch_embedding.weight, std=range_std)
            nn.init.normal_(module.position_embedding.weight, std=range_std)
        elif isinstance(module, CLIPAttention):
            width_scale = module.embed_dim**-0.5
            depth_scale = (2 * module.config.num_hidden_layers) ** -0.5
            in_proj_std = width_scale * depth_scale * factor
            out_proj_std = width_scale * factor
            # q, k, v are initialised in this order so the random stream is unchanged.
            for projection in (module.q_proj, module.k_proj, module.v_proj):
                nn.init.normal_(projection.weight, std=in_proj_std)
            nn.init.normal_(module.out_proj.weight, std=out_proj_std)
        elif isinstance(module, CLIPMLP):
            hidden = module.config.hidden_size
            fc2_std = (hidden**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
            fc1_std = (2 * hidden) ** -0.5 * factor
            nn.init.normal_(module.fc1.weight, std=fc1_std)
            nn.init.normal_(module.fc2.weight, std=fc2_std)
        elif isinstance(module, LanguageBindImage):
            nn.init.normal_(module.text_projection.weight, std=module.text_embed_dim**-0.5 * factor)
            nn.init.normal_(module.visual_projection.weight, std=module.vision_embed_dim**-0.5 * factor)
        elif isinstance(module, CLIPVisionModelWithProjection):
            nn.init.normal_(module.visual_projection.weight, std=self.config.hidden_size**-0.5 * factor)
        elif isinstance(module, CLIPTextModelWithProjection):
            nn.init.normal_(module.text_projection.weight, std=self.config.hidden_size**-0.5 * factor)

        # These two apply on top of (not instead of) the branch above.
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Set the ``gradient_checkpointing`` flag on ``CLIPEncoder`` submodules; others are ignored."""
        if not isinstance(module, CLIPEncoder):
            return
        module.gradient_checkpointing = value
236
+
237
+
238
# Docstring fragments attached to the model classes and forward methods of this module through the
# `add_start_docstrings` / `add_start_docstrings_to_model_forward` decorators.

CLIP_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning
    heads etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`LanguageBindImageConfig`], or [`CLIPTextConfig`] / [`CLIPVisionConfig`] for the single-tower models):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

CLIP_TEXT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

CLIP_VISION_INPUTS_DOCSTRING = r"""
    Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. Video clips may instead be passed
            as a 5-D tensor of shape `(batch_size, num_channels, num_frames, height, width)`; the pooled output is
            then averaged over the frames of each clip.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

CLIP_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
            [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. Video clips may instead be passed
            as a 5-D tensor of shape `(batch_size, num_channels, num_frames, height, width)`.
        return_loss (`bool`, *optional*):
            Whether or not to return the contrastive loss.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
336
+
337
+
338
class CLIPEncoder(nn.Module):
    """
    Sequential stack of `config.num_hidden_layers` [`CLIPEncoderLayer`] blocks.

    Args:
        config: model configuration; `num_hidden_layers` sizes the stack and `output_attentions`,
            `output_hidden_states` and `use_return_dict` supply the forward defaults.
    """

    def __init__(self, config: LanguageBindImageConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers))
        # Flipped by CLIPPreTrainedModel._set_gradient_checkpointing.
        self.gradient_checkpointing = False

    def forward(
        self,
        inputs_embeds,
        attention_mask: Optional[torch.Tensor] = None,
        causal_attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Run every encoder layer over `inputs_embeds` in order.

        Args:
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Already-embedded sequence fed to the first layer.
            attention_mask (`torch.Tensor`, *optional*):
                Padding mask handed unchanged to every layer (the text transformer passes it already expanded
                with `_expand_mask`).
            causal_attention_mask (`torch.Tensor`, *optional*):
                Causal mask handed unchanged to every layer (built by the text transformer).
            output_attentions (`bool`, *optional*):
                Whether to collect each layer's attention weights; defaults to `config.output_attentions`.
            output_hidden_states (`bool`, *optional*):
                Whether to collect the input of every layer plus the final output; defaults to
                `config.output_hidden_states`.
            return_dict (`bool`, *optional*):
                Whether to return a [`BaseModelOutput`] instead of a plain tuple; defaults to
                `config.use_return_dict`.
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        states_seen = [] if output_hidden_states else None
        attentions_seen = [] if output_attentions else None
        use_checkpointing = self.gradient_checkpointing and self.training

        hidden_states = inputs_embeds
        for layer in self.layers:
            if states_seen is not None:
                states_seen.append(hidden_states)

            if use_checkpointing:
                # Bind the layer as a default so recomputation during backward hits the right module.
                def run_layer(*layer_args, _layer=layer):
                    return _layer(*layer_args, output_attentions)

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    run_layer, hidden_states, attention_mask, causal_attention_mask
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask,
                    causal_attention_mask,
                    output_attentions=output_attentions,
                )

            hidden_states = layer_outputs[0]
            if attentions_seen is not None:
                attentions_seen.append(layer_outputs[1])

        if states_seen is not None:
            states_seen.append(hidden_states)

        encoder_states = None if states_seen is None else tuple(states_seen)
        all_attentions = None if attentions_seen is None else tuple(attentions_seen)

        if not return_dict:
            return tuple(item for item in (hidden_states, encoder_states, all_attentions) if item is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
        )
439
+
440
+
441
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
442
+ def _make_causal_mask(
443
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
444
+ ):
445
+ """
446
+ Make causal mask used for bi-directional self-attention.
447
+ """
448
+ bsz, tgt_len = input_ids_shape
449
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
450
+ mask_cond = torch.arange(mask.size(-1), device=device)
451
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
452
+ mask = mask.to(dtype)
453
+
454
+ if past_key_values_length > 0:
455
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
456
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
457
+
458
+
459
class CLIPTextTransformer(nn.Module):
    """
    CLIP text tower: token/position embeddings, a causally-masked [`CLIPEncoder`], a final layer norm,
    and pooling at each sequence's highest token id (the EOT token).
    """

    def __init__(self, config: CLIPTextConfig):
        super().__init__()
        self.config = config
        self.embeddings = CLIPTextEmbeddings(config)
        self.encoder = CLIPEncoder(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Encode token ids and pool the final hidden state at the EOT position of each sequence.

        Returns:

        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        flat_ids = input_ids.view(-1, input_shape[-1])
        embedded = self.embeddings(input_ids=flat_ids, position_ids=position_ids)

        # CLIP's text model attends causally; see
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = _make_causal_mask(input_shape, embedded.dtype, device=embedded.device)
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, embedded.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=embedded,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = self.final_layer_norm(encoder_outputs[0])

        # The EOT token has the highest id in each row; ids are cast to int32 because ONNX argmax
        # (opset 14) does not accept int64 inputs.
        target_device = last_hidden_state.device
        row_index = torch.arange(last_hidden_state.shape[0], device=target_device)
        eot_index = flat_ids.to(dtype=torch.int, device=target_device).argmax(dim=-1)
        pooled_output = last_hidden_state[row_index, eot_index]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
534
+
535
+
536
@add_start_docstrings(
    """The text model from CLIP without any head or projection on top.""",
    CLIP_START_DOCSTRING,
)
class CLIPTextModel(CLIPPreTrainedModel):
    # Thin PreTrainedModel wrapper around CLIPTextTransformer.
    config_class = CLIPTextConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = CLIPTextTransformer(config)
        # Weight initialisation and final processing.
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # The token-embedding table of the text tower.
        embeddings = self.text_model.embeddings
        return embeddings.token_embedding

    def set_input_embeddings(self, value):
        embeddings = self.text_model.embeddings
        embeddings.token_embedding = value

    @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import AutoTokenizer, CLIPTextModel

        >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Everything else is resolved inside the transformer itself.
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
595
+
596
+
597
class CLIPVisionTransformer(nn.Module):
    """
    CLIP ViT vision tower extended for video.

    Frames are folded into the batch axis (frame index fastest), embedded, patch-dropped, encoded, and the
    post-layer-normed CLS feature of every frame is averaged per clip to form the pooled output.
    """

    def __init__(self, config: CLIPVisionConfig):
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = CLIPVisionEmbeddings(config)
        self.patch_dropout = PatchDropout(config.force_patch_dropout)
        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        self.encoder = CLIPEncoder(config)
        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)

    @staticmethod
    def _flatten_frames(pixel_values):
        """
        Fold any frame axis into the batch axis using the ``(b t)`` layout (frames of a clip contiguous).

        Accepts 7-D ``(b, pair, T, bs, c, h, w)``, 5-D ``(b, c, T, h, w)`` or 4-D ``(b, c, h, w)`` input.

        Returns:
            ``(pixel_values of shape (B * T, c, h, w), B, T)``.
        """
        if pixel_values.dim() == 7:
            b_new, pair_new, T, bs_new, channel_new, h_new, w_new = pixel_values.shape
            B = b_new * pair_new * bs_new
            # Move the frame axis innermost before flattening, so rows are ordered (b, pair, bs, t).
            # PatchDropout, temporal attention and the per-clip mean below all assume that '(b t)' order.
            # Without the permute, frames of different `bs` samples get interleaved whenever bs_new > 1.
            pixel_values = pixel_values.permute(0, 1, 3, 2, 4, 5, 6).reshape(B * T, channel_new, h_new, w_new)
            return pixel_values, B, T

        if pixel_values.dim() == 5:
            B, _, T, _, _ = pixel_values.shape
            return rearrange(pixel_values, 'b c t h w -> (b t) c h w'), B, T

        B, _, _, _ = pixel_values.shape
        return pixel_values, B, 1

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Encode images or video clips; `last_hidden_state` keeps one row per frame, while `pooler_output`
        averages the frame CLS features of each clip.

        Returns:

        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        pixel_values, B, T = self._flatten_frames(pixel_values)

        hidden_states = self.embeddings(pixel_values)
        hidden_states = self.patch_dropout(hidden_states, B, T)
        hidden_states = self.pre_layrnorm(hidden_states)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs[0]
        pooled_output = self.post_layernorm(last_hidden_state[:, 0, :])
        # (B * T, d) -> (B, T, d) -> mean over frames.
        pooled_output = pooled_output.reshape(B, T, -1).mean(1)

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
674
+
675
+
676
@add_start_docstrings(
    """The vision model from CLIP without any head or projection on top.""",
    CLIP_START_DOCSTRING,
)
class CLIPVisionModel(CLIPPreTrainedModel):
    # Thin PreTrainedModel wrapper around CLIPVisionTransformer.
    config_class = CLIPVisionConfig
    main_input_name = "pixel_values"

    def __init__(self, config: CLIPVisionConfig):
        super().__init__(config)
        self.vision_model = CLIPVisionTransformer(config)
        # Weight initialisation and final processing.
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        # The vision embeddings' patch-embedding module.
        embeddings = self.vision_model.embeddings
        return embeddings.patch_embedding

    @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        r"""
        Returns:

        Examples:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, CLIPVisionModel

        >>> model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, return_tensors="pt")

        >>> outputs = model(**inputs)
        >>> last_hidden_state = outputs.last_hidden_state
        >>> pooled_output = outputs.pooler_output  # pooled CLS states
        ```"""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # The transformer resolves the remaining defaults itself.
        return self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
732
+
733
+
734
+ @add_start_docstrings(CLIP_START_DOCSTRING)
735
+ class LanguageBindImage(CLIPPreTrainedModel):
736
+ config_class = LanguageBindImageConfig
737
+
738
def __init__(self, config: LanguageBindImageConfig):
    """
    Build the text and vision towers, their projection heads and the learnable logit scale.

    Then initialise weights, optionally wrap the vision encoder with LoRA, and resize the vision
    position embeddings.

    Raises:
        ValueError: if `config.text_config` is not a `CLIPTextConfig` or `config.vision_config` is not a
            `CLIPVisionConfig`.
    """
    super().__init__(config)

    sub_config_checks = (
        ("text_config", CLIPTextConfig, "CLIPTextConfig"),
        ("vision_config", CLIPVisionConfig, "CLIPVisionConfig"),
    )
    for attr_name, expected_cls, expected_name in sub_config_checks:
        sub_config = getattr(config, attr_name)
        if not isinstance(sub_config, expected_cls):
            raise ValueError(
                f"config.{attr_name} is expected to be of type {expected_name} but is of type"
                f" {type(sub_config)}."
            )

    text_config, vision_config = config.text_config, config.vision_config

    # Temporal-attention and LoRA switches are read from the vision sub-config.
    self.add_time_attn = vision_config.add_time_attn
    self.lora_r = vision_config.lora_r
    self.lora_alpha = vision_config.lora_alpha
    self.lora_dropout = vision_config.lora_dropout

    self.projection_dim = config.projection_dim
    self.text_embed_dim = text_config.hidden_size
    self.vision_embed_dim = vision_config.hidden_size

    # Submodules are created in the same order as before, keeping random init and registration order stable.
    self.text_model = CLIPTextTransformer(text_config)
    self.vision_model = CLIPVisionTransformer(vision_config)
    self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
    self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
    self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))

    # Weight init first, then LoRA wrapping, then position-embedding resize.
    self.post_init()
    self.convert_to_lora()
    self.resize_pos(self.vision_model.embeddings, vision_config)
775
+
776
+ def convert_to_lora(self):
777
+ if self.lora_r == 0:
778
+ return
779
+ if self.add_time_attn:
780
+ target_modules = ["temporal_attn.k_proj", "temporal_attn.v_proj",
781
+ "temporal_attn.q_proj", "temporal_attn.out_proj",
782
+ "temporal_mlp.fc1", "temporal_mlp.fc2"]
783
+ else:
784
+ target_modules = ["k_proj", "v_proj", "q_proj", "out_proj"]
785
+ config = LoraConfig(
786
+ r=self.lora_r, # 16
787
+ lora_alpha=self.lora_alpha, # 16
788
+ target_modules=target_modules, # self_attn.out_proj
789
+ lora_dropout=self.lora_dropout, # 0.1
790
+ bias="none",
791
+ modules_to_save=[],
792
+ )
793
+ self.vision_model.encoder.is_gradient_checkpointing = False
794
+ self.vision_model.encoder = get_peft_model(self.vision_model.encoder, config)
795
+
796
+ def resize_pos(self, m, vision_config):
797
+ # convert embedding
798
+ if vision_config.num_mel_bins!=0 and vision_config.target_length!=0:
799
+ m.image_size = [vision_config.num_mel_bins, vision_config.target_length]
800
+ m.config.image_size = [m.image_size, m.image_size] if isinstance(m.image_size, int) else m.image_size
801
+ # pos resize
802
+ old_pos_embed_state_dict = m.position_embedding.state_dict()
803
+ old_pos_embed = old_pos_embed_state_dict['weight']
804
+ dtype = old_pos_embed.dtype
805
+ grid_size = [m.config.image_size[0] // m.patch_size, m.config.image_size[1] // m.patch_size]
806
+ extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more)
807
+ new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
808
+ if new_seq_len == old_pos_embed.shape[0]:
809
+ # m.to(args.device)
810
+ return
811
+
812
+ m.num_patches = grid_size[0] * grid_size[1]
813
+ m.num_positions = m.num_patches + 1
814
+ m.register_buffer("position_ids", torch.arange(m.num_positions).expand((1, -1)))
815
+ new_position_embedding = nn.Embedding(m.num_positions, m.embed_dim)
816
+
817
+ if extra_tokens:
818
+ pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
819
+ else:
820
+ pos_emb_tok, pos_emb_img = None, old_pos_embed
821
+ old_grid_size = [int(math.sqrt(len(pos_emb_img)))] * 2
822
+
823
+ # if is_master(args):
824
+ # logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
825
+ pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
826
+ pos_emb_img = F.interpolate(
827
+ pos_emb_img,
828
+ size=grid_size,
829
+ mode='bicubic',
830
+ antialias=True,
831
+ align_corners=False,
832
+ )
833
+ pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
834
+ if pos_emb_tok is not None:
835
+ new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
836
+ else:
837
+ new_pos_embed = pos_emb_img
838
+ old_pos_embed_state_dict['weight'] = new_pos_embed.to(dtype)
839
+ m.position_embedding = new_position_embedding
840
+ m.position_embedding.load_state_dict(old_pos_embed_state_dict)
841
+
842
+ # m.to(args.device)
843
+
844
+ @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
845
+ def get_text_features(
846
+ self,
847
+ input_ids: Optional[torch.Tensor] = None,
848
+ attention_mask: Optional[torch.Tensor] = None,
849
+ position_ids: Optional[torch.Tensor] = None,
850
+ output_attentions: Optional[bool] = None,
851
+ output_hidden_states: Optional[bool] = None,
852
+ return_dict: Optional[bool] = None,
853
+ ) -> torch.FloatTensor:
854
+ r"""
855
+ Returns:
856
+ text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
857
+ applying the projection layer to the pooled output of [`CLIPTextModel`].
858
+
859
+ Examples:
860
+
861
+ ```python
862
+ >>> from transformers import AutoTokenizer, CLIPModel
863
+
864
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
865
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
866
+
867
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
868
+ >>> text_features = model.get_text_features(**inputs)
869
+ ```"""
870
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
871
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
872
+ output_hidden_states = (
873
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
874
+ )
875
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
876
+
877
+ text_outputs = self.text_model(
878
+ input_ids=input_ids,
879
+ attention_mask=attention_mask,
880
+ position_ids=position_ids,
881
+ output_attentions=output_attentions,
882
+ output_hidden_states=output_hidden_states,
883
+ return_dict=return_dict,
884
+ )
885
+
886
+ pooled_output = text_outputs[1]
887
+ text_features = self.text_projection(pooled_output)
888
+
889
+ return text_features
890
+
891
+ @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
892
+ def get_image_features(
893
+ self,
894
+ pixel_values: Optional[torch.FloatTensor] = None,
895
+ output_attentions: Optional[bool] = None,
896
+ output_hidden_states: Optional[bool] = None,
897
+ return_dict: Optional[bool] = None,
898
+ ) -> torch.FloatTensor:
899
+ r"""
900
+ Returns:
901
+ image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
902
+ applying the projection layer to the pooled output of [`CLIPVisionModel`].
903
+
904
+ Examples:
905
+
906
+ ```python
907
+ >>> from PIL import Image
908
+ >>> import requests
909
+ >>> from transformers import AutoProcessor, CLIPModel
910
+
911
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
912
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
913
+
914
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
915
+ >>> image = Image.open(requests.get(url, stream=True).raw)
916
+
917
+ >>> inputs = processor(images=image, return_tensors="pt")
918
+
919
+ >>> image_features = model.get_image_features(**inputs)
920
+ ```"""
921
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
922
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
923
+ output_hidden_states = (
924
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
925
+ )
926
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
927
+
928
+ vision_outputs = self.vision_model(
929
+ pixel_values=pixel_values,
930
+ output_attentions=output_attentions,
931
+ output_hidden_states=output_hidden_states,
932
+ return_dict=return_dict,
933
+ )
934
+
935
+ pooled_output = vision_outputs[1] # pooled_output
936
+ image_features = self.visual_projection(pooled_output)
937
+
938
+ return image_features
939
+
940
+ @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING)
941
+ @replace_return_docstrings(output_type=CLIPOutput, config_class=LanguageBindImageConfig)
942
+ def forward(
943
+ self,
944
+ input_ids: Optional[torch.LongTensor] = None,
945
+ pixel_values: Optional[torch.FloatTensor] = None,
946
+ attention_mask: Optional[torch.Tensor] = None,
947
+ position_ids: Optional[torch.LongTensor] = None,
948
+ return_loss: Optional[bool] = None,
949
+ output_attentions: Optional[bool] = None,
950
+ output_hidden_states: Optional[bool] = None,
951
+ return_dict: Optional[bool] = None,
952
+ ) -> Union[Tuple, CLIPOutput]:
953
+ r"""
954
+ Returns:
955
+
956
+ Examples:
957
+
958
+ ```python
959
+ >>> from PIL import Image
960
+ >>> import requests
961
+ >>> from transformers import AutoProcessor, CLIPModel
962
+
963
+ >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
964
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
965
+
966
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
967
+ >>> image = Image.open(requests.get(url, stream=True).raw)
968
+
969
+ >>> inputs = processor(
970
+ ... text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True
971
+ ... )
972
+
973
+ >>> outputs = model(**inputs)
974
+ >>> logits_per_image = outputs.logits_per_image # this is the image-text similarity score
975
+ >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
976
+ ```"""
977
+ # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
978
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
979
+ output_hidden_states = (
980
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
981
+ )
982
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
983
+
984
+ vision_outputs = self.vision_model(
985
+ pixel_values=pixel_values,
986
+ output_attentions=output_attentions,
987
+ output_hidden_states=output_hidden_states,
988
+ return_dict=return_dict,
989
+ )
990
+
991
+ text_outputs = self.text_model(
992
+ input_ids=input_ids,
993
+ attention_mask=attention_mask,
994
+ position_ids=position_ids,
995
+ output_attentions=output_attentions,
996
+ output_hidden_states=output_hidden_states,
997
+ return_dict=return_dict,
998
+ )
999
+
1000
+ image_embeds = vision_outputs[1]
1001
+ image_embeds = self.visual_projection(image_embeds)
1002
+
1003
+ text_embeds = text_outputs[1]
1004
+ text_embeds = self.text_projection(text_embeds)
1005
+
1006
+ # normalized features
1007
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
1008
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
1009
+
1010
+ # cosine similarity as logits
1011
+ logit_scale = self.logit_scale.exp()
1012
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
1013
+ logits_per_image = logits_per_text.t()
1014
+
1015
+ loss = None
1016
+ if return_loss:
1017
+ loss = clip_loss(logits_per_text)
1018
+
1019
+ if not return_dict:
1020
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
1021
+ return ((loss,) + output) if loss is not None else output
1022
+
1023
+ return CLIPOutput(
1024
+ loss=loss,
1025
+ logits_per_image=logits_per_image,
1026
+ logits_per_text=logits_per_text,
1027
+ text_embeds=text_embeds,
1028
+ image_embeds=image_embeds,
1029
+ text_model_output=text_outputs,
1030
+ vision_model_output=vision_outputs,
1031
+ )