optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
--- a/optimum/rbln/transformers/models/bart/modeling_bart.py
+++ b/optimum/rbln/transformers/models/bart/modeling_bart.py
@@ -13,9 +13,11 @@
 # limitations under the License.

 import inspect
-from typing import Any, Callable
+from typing import Any, Callable, Optional, Tuple, Union

+import torch
 from transformers import BartForConditionalGeneration, PreTrainedModel
+from transformers.modeling_outputs import Seq2SeqModelOutput

 from ....utils.logging import get_logger
 from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
@@ -35,6 +37,25 @@ class RBLNBartModel(RBLNTransformerEncoderForFeatureExtraction):
     on RBLN devices, optimized for feature extraction use cases.
     """

+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[Tuple, Seq2SeqModelOutput]:
+        """
+        Forward pass for the RBLN-optimized BART model for feature extraction tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a Seq2SeqModelOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, **kwargs)
+

 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
     """
@@ -48,7 +69,7 @@ class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
     support_causal_attn = True

     @classmethod
-    def
+    def _wrap_model_if_needed(self, model: PreTrainedModel, rbln_config: RBLNBartForConditionalGenerationConfig):
         return BartWrapper(
             model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
         )
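The hunks above only change how RBLNBartModel documents and dispatches its inputs; loading is unchanged. A minimal usage sketch, not part of the diff, assuming the usual optimum-rbln `from_pretrained(..., export=True)` convention; the checkpoint id and padding length are illustrative:

    # Hypothetical usage sketch: compile a BART encoder for feature extraction and run it.
    from transformers import AutoTokenizer
    from optimum.rbln import RBLNBartModel

    tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    # export=True compiles the PyTorch checkpoint for the RBLN NPU (assumed convention)
    model = RBLNBartModel.from_pretrained("facebook/bart-base", export=True)

    enc = tokenizer("An example sentence.", return_tensors="pt", padding="max_length", max_length=512)
    out = model(input_ids=enc.input_ids, attention_mask=enc.attention_mask)
    print(out.last_hidden_state.shape)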
--- /dev/null
+++ b/optimum/rbln/transformers/models/bert/bert_architecture.py
@@ -0,0 +1,16 @@
+import torch
+
+
+class BertModelWrapper(torch.nn.Module):
+    def __init__(self, model, rbln_config):
+        super().__init__()
+        self.model = model
+        self.rbln_config = rbln_config
+
+    def forward(self, *args, **kwargs):
+        output = self.model(*args, **kwargs)
+        if isinstance(output, torch.Tensor):
+            return output
+        elif isinstance(output, tuple):
+            return tuple(x for x in output if x is not None)
+        return output
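A small standalone sketch, not from the package, of what this wrapper does: tuple outputs have their None entries dropped so the traced graph only carries real tensors. The deep import path is assumed from the file layout above; the toy module is purely illustrative.

    import torch
    from optimum.rbln.transformers.models.bert.bert_architecture import BertModelWrapper  # assumed path

    class ToyModel(torch.nn.Module):
        # Stands in for a BERT-like model that returns (hidden_states, None, attentions)
        def forward(self, x):
            return (x * 2, None, x + 1)

    wrapped = BertModelWrapper(ToyModel(), rbln_config=None)  # rbln_config is unused in forward()
    outputs = wrapped(torch.ones(1, 3))
    print(len(outputs))  # 2: the None entry was filtered out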
--- a/optimum/rbln/transformers/models/bert/modeling_bert.py
+++ b/optimum/rbln/transformers/models/bert/modeling_bert.py
@@ -12,15 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from
+from typing import Optional, Tuple, Union
+
+import torch
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    MaskedLMOutput,
+    QuestionAnsweringModelOutput,
+)
+
 from ...modeling_generic import (
     RBLNModelForMaskedLM,
     RBLNModelForQuestionAnswering,
     RBLNTransformerEncoderForFeatureExtraction,
 )
-
-
-logger = get_logger(__name__)
+from .bert_architecture import BertModelWrapper
+from .configuration_bert import RBLNBertModelConfig


 class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
@@ -34,6 +41,46 @@ class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):

     rbln_model_input_names = ["input_ids", "attention_mask"]

+    @classmethod
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNBertModelConfig) -> torch.nn.Module:
+        return BertModelWrapper(model, rbln_config)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, Tuple]:
+        """
+        Forward pass for the RBLN-optimized BERT model for feature extraction tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
+            position_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of positions of each input sequence tokens in the position embeddings.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPoolingAndCrossAttentions object.
+        """
+
+        input_map = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "token_type_ids": token_type_ids,
+            "position_ids": position_ids,
+        }
+
+        model_input_names = getattr(self.rbln_config, "model_input_names", None)
+        if model_input_names is None:
+            model_input_names = self.rbln_model_input_names
+
+        ordered_inputs = [input_map[name] for name in model_input_names if name in input_map]
+
+        return super().forward(*ordered_inputs, **kwargs)
+

 class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
     """
@@ -46,6 +93,27 @@ class RBLNBertForMaskedLM(RBLNModelForMaskedLM):

     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]

+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[MaskedLMOutput, Tuple]:
+        """
+        Forward pass for the RBLN-optimized BERT model for masked language modeling tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a MaskedLMOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, token_type_ids, **kwargs)
+

 class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
     """
@@ -57,3 +125,24 @@ class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
     """

     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[QuestionAnsweringModelOutput, Tuple]:
+        """
+        Forward pass for the RBLN-optimized BERT model for question answering tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a QuestionAnsweringModelOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, token_type_ids, **kwargs)
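The new RBLNBertModel.forward maps keyword arguments onto the positional order the compiled runtime was traced with, driven by rbln_model_input_names (or rbln_config.model_input_names when set). A standalone sketch of that ordering rule, with placeholder values:

    def order_inputs(input_map, model_input_names):
        # Keep only the inputs the compiled graph knows about, in that exact order.
        return [input_map[name] for name in model_input_names if name in input_map]

    kwargs = {
        "input_ids": "IDS",
        "attention_mask": "MASK",
        "token_type_ids": None,
        "position_ids": None,
    }
    print(order_inputs(kwargs, ["input_ids", "attention_mask"]))
    # ['IDS', 'MASK']: token_type_ids and position_ids are ignored unless they were compiled in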
--- a/optimum/rbln/transformers/models/blip_2/configuration_blip_2.py
+++ b/optimum/rbln/transformers/models/blip_2/configuration_blip_2.py
@@ -12,9 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any,
+from typing import Any, Optional

 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)


 class RBLNBlip2VisionModelConfig(RBLNModelConfig):
@@ -25,6 +29,16 @@ class RBLNBlip2VisionModelConfig(RBLNModelConfig):
     RBLN-optimized BLIP-2 vision encoder models for multimodal tasks.
     """

+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+

 class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
     """
@@ -36,24 +50,34 @@ class RBLNBlip2QFormerModelConfig(RBLNModelConfig):

     def __init__(
         self,
+        batch_size: Optional[int] = None,
         num_query_tokens: Optional[int] = None,
         image_text_hidden_size: Optional[int] = None,
         **kwargs,
     ):
         """
         Args:
-
-
-
-        Raises:
-            ValueError: If batch_size is not a positive integer.
+            num_query_tokens (Optional[int]): The number of query tokens passed through the Transformer.
+            image_text_hidden_size (Optional[int]): Dimensionality of the hidden state of the image-text fusion layer.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
         """
         super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
         self.num_query_tokens = num_query_tokens
         self.image_text_hidden_size = image_text_hidden_size


 class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNBlip2ForConditionalGeneration.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BLIP-2 models for conditional generation tasks that involve both image and text inputs.
+    """
+
     submodules = ["vision_model", "qformer", "language_model"]

     def __init__(
@@ -62,14 +86,15 @@ class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
         vision_model: Optional[RBLNModelConfig] = None,
         qformer: Optional[RBLNModelConfig] = None,
         language_model: Optional[RBLNModelConfig] = None,
-        **kwargs:
+        **kwargs: Any,
     ):
         """
         Args:
             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
             vision_model (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+            qformer (Optional[RBLNModelConfig]): Configuration for the RBLN-optimized BLIP-2 Q-Former model.
             language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
-
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
             ValueError: If batch_size is not a positive integer.
@@ -79,6 +104,12 @@ class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

-        self.
-
-
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for Blip2 vision model. It will be set to 1.")
+            logger.warning("Ignore batch_size for Blip2 qformer. It will be set to 1.")
+
+        self.vision_model = self.initialize_submodule_config(
+            submodule_config=vision_model, batch_size=1, force_kwargs=True
+        )
+        self.qformer = self.initialize_submodule_config(submodule_config=qformer, batch_size=1, force_kwargs=True)
+        self.language_model = self.initialize_submodule_config(submodule_config=language_model)
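As the new constructor logic shows, the parent BLIP-2 config pins the vision and Q-Former submodules to batch size 1 and only propagates the requested batch size to the language model. A hypothetical configuration sketch; the top-level import path and keyword usage are assumed from this diff, not confirmed elsewhere:

    from optimum.rbln import RBLNBlip2ForConditionalGenerationConfig

    # batch_size applies to the language model; vision_model and qformer are pinned to batch_size=1,
    # and a warning is logged whenever batch_size != 1.
    cfg = RBLNBlip2ForConditionalGenerationConfig(batch_size=4)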
--- a/optimum/rbln/transformers/models/blip_2/modeling_blip_2.py
+++ b/optimum/rbln/transformers/models/blip_2/modeling_blip_2.py
@@ -14,7 +14,7 @@

 import inspect
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union

 import torch
 from transformers import (
@@ -30,38 +30,31 @@ from transformers.utils import logging

 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
+from ...utils.rbln_runtime_wrapper import LoopProcessor
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin


 logger = logging.get_logger(__name__)

 if TYPE_CHECKING:
-
-
-        AutoProcessor,
-        AutoTokenizer,
-    )
+    import rebel
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer


-class LoopProjector:
-    def __init__(self, language_projection
-
+class LoopProjector(LoopProcessor):
+    def __init__(self, language_projection: Union[RBLNModel, "rebel.Runtime"]):
+        super().__init__(model=language_projection)

-    def
-        query_output
+    def _get_batch_size(self, query_output, **kwargs):
+        return query_output.shape[0]

-
-
-
-            outputs.append(self.language_projection(query_output[i : i + 1]))
-
-        outputs = torch.cat(outputs, dim=0)
-        return outputs
+    def _prepare_inputs_for_iteration(self, index, common_inputs, query_output, **kwargs):
+        query_output_item = query_output[index : index + 1]
+        return ([query_output_item], {})

-    def
-
-
-    def __repr__(self) -> str:
-        return repr(self.language_projection)
+    def _process_outputs(self, outputs: list, **kwargs):
+        output = torch.cat(outputs, dim=0)
+        return output


 class RBLNBlip2VisionModel(RBLNModel):
@@ -72,11 +65,13 @@ class RBLNBlip2VisionModel(RBLNModel):
     on RBLN devices, supporting image encoding for multimodal vision-language tasks.
     """

+    _tp_support = False
+
     def get_input_embeddings(self):
         return self.embeddings

     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         class Blip2VisionModelWrapper(torch.nn.Module):
             def __init__(self, model: "Blip2VisionModel") -> None:
                 super().__init__()
@@ -100,8 +95,7 @@ class RBLNBlip2VisionModel(RBLNModel):
             (
                 "pixel_values",
                 [
-
-                    1,
+                    rbln_config.batch_size,
                     model_config.num_channels,
                     model_config.image_size,
                     model_config.image_size,
@@ -116,12 +110,21 @@ class RBLNBlip2VisionModel(RBLNModel):

     def forward(
         self,
-        pixel_values,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
+        pixel_values: torch.FloatTensor,
         interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        """
+        Forward pass for the RBLN-optimized Blip2VisionModel model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)): The tensors corresponding to the input images.
+            interpolate_pos_encoding (bool, optional): Whether to interpolate the positional encoding of the image embeddings. Defaults to False.
+            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            BaseModelOutputWithPooling or tuple(torch.FloatTensor): The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
+        """
         batch_size = pixel_values.shape[0]
         outputs = []
         for i in range(batch_size):
@@ -151,11 +154,13 @@ class RBLNBlip2QFormerModel(RBLNModel):
     mechanisms for multimodal understanding tasks.
     """

+    _tp_support = False
+
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings

     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         class Blip2QFormerModelWrapper(torch.nn.Module):
             def __init__(self, model: "Blip2QFormerModel"):
                 super().__init__()
@@ -178,7 +183,12 @@ class RBLNBlip2QFormerModel(RBLNModel):
         return Blip2QFormerModelWrapper(model).eval()

     @classmethod
-    def _update_submodule_config(
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
         if rbln_config.num_query_tokens is None:
             rbln_config.num_query_tokens = model.config.num_query_tokens

@@ -199,7 +209,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
             (
                 "query_embeds",
                 [
-
+                    rbln_config.batch_size,
                     rbln_config.num_query_tokens,
                     model_config.hidden_size,
                 ],
@@ -208,7 +218,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
             (
                 "encoder_hidden_states",
                 [
-
+                    rbln_config.batch_size,
                     # image_text_hidden_size + cls token
                     rbln_config.image_text_hidden_size + 1,
                     model_config.encoder_hidden_size,
@@ -218,7 +228,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
             (
                 "encoder_attention_mask",
                 # image_text_hidden_size + cls token
-                [
+                [rbln_config.batch_size, rbln_config.image_text_hidden_size + 1],
                 "int64",
             ),
         ]
@@ -230,17 +240,22 @@
     def forward(
         self,
         query_embeds: torch.FloatTensor,
-        query_length: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        """
+        The forward pass for the RBLN-optimized Blip2QFormerModel model.
+
+        Args:
+            query_embeds (torch.FloatTensor): Hidden states to be used in the attention computation.
+            encoder_hidden_states (torch.FloatTensor, optional): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder.
+            encoder_attention_mask (torch.FloatTensor, optional): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder.
+            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            BaseModelOutputWithPoolingAndCrossAttentions or tuple(torch.FloatTensor): The model outputs. If `return_dict=False` is passed, returns a tuple of tensors. Otherwise, returns a `BaseModelOutputWithPoolingAndCrossAttentions` object.
+        """
         batch_size = query_embeds.shape[0]
         outputs = []
         for i in range(batch_size):
@@ -265,7 +280,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
         )


-class RBLNBlip2ForConditionalGeneration(RBLNModel):
+class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     """
     RBLNBlip2ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
@@ -348,7 +363,7 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel):
         return self.language_model.get_input_embeddings()

     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model, rbln_config):
         return model.language_projection

     @classmethod
@@ -433,3 +448,79 @@
         )

         return inputs_embeds
+
+    @torch.no_grad()
+    def generate(
+        self,
+        pixel_values: torch.FloatTensor,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        interpolate_pos_encoding: bool = False,
+        **generate_kwargs,
+    ) -> List[torch.LongTensor]:
+        """
+        The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
+        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/model_doc/blip-2#transformers.Blip2ForConditionalGeneration.generate) for more details.
+
+        Args:
+            pixel_values (torch.FloatTensor): Input images to be processed.
+            input_ids (torch.LongTensor, optional): The sequence used as a prompt for the generation.
+            attention_mask (torch.LongTensor, optional): Mask to avoid performing attention on padding token indices
+            inputs_embeds (torch.FloatTensor, optional): Embedded representation of the inputs. Should be float, not int tokens.
+            interpolate_pos_encoding (bool, optional, defaults to False) — Whether to interpolate the positional encoding of the image embeddings.
+        Returns:
+            A list of strings of length batch_size * num_captions.
+        """
+        batch_size = pixel_values.shape[0]
+        image_embeds = self.vision_model(
+            pixel_values,
+            return_dict=True,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+        ).last_hidden_state
+        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        query_outputs = self.qformer(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_attention_mask,
+            return_dict=True,
+        )
+        query_output = query_outputs.last_hidden_state
+
+        if query_output.dtype != image_embeds.dtype:
+            query_output = query_output.to(image_embeds.dtype)
+
+        language_model_inputs = self.language_projection(query_output)
+
+        if inputs_embeds is None:
+            if input_ids is None:
+                image_tokens = [self.config.image_token_index] * self.config.num_query_tokens
+                start_tokens = image_tokens + [self.config.text_config.bos_token_id]
+                input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
+                input_ids = input_ids.repeat(batch_size, 1)
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_id
+
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
+        if not self.language_model.config.is_encoder_decoder:
+            inputs["input_ids"] = input_ids
+
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)
+
+        return outputs
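A hypothetical end-to-end sketch of the new generate() path; the model id, image URL, and the export=True loading convention are illustrative assumptions, not taken from the diff:

    import requests
    from PIL import Image
    from transformers import Blip2Processor
    from optimum.rbln import RBLNBlip2ForConditionalGeneration

    model_id = "Salesforce/blip2-opt-2.7b"
    processor = Blip2Processor.from_pretrained(model_id)
    # Assumed convention: export=True compiles the checkpoint for the RBLN NPU
    model = RBLNBlip2ForConditionalGeneration.from_pretrained(model_id, export=True)

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    inputs = processor(images=image, text="Question: what is shown? Answer:", return_tensors="pt")

    generated_ids = model.generate(**inputs, max_new_tokens=20)
    print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())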
--- a/optimum/rbln/transformers/models/clip/configuration_clip.py
+++ b/optimum/rbln/transformers/models/clip/configuration_clip.py
@@ -12,20 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any,
+from typing import Any, Optional

 from ....configuration_utils import RBLNModelConfig


 class RBLNCLIPTextModelConfig(RBLNModelConfig):
-    def __init__(self, batch_size: Optional[int] = None, **kwargs:
+    def __init__(self, batch_size: Optional[int] = None, **kwargs: Any):
         """
         Args:
             batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
-
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
-            ValueError: If batch_size is not a positive integer.
+            ValueError: If `batch_size` is not a positive integer.
         """
         super().__init__(**kwargs)
         self.batch_size = batch_size or 1
@@ -43,16 +43,27 @@ class RBLNCLIPTextModelWithProjectionConfig(RBLNCLIPTextModelConfig):


 class RBLNCLIPVisionModelConfig(RBLNModelConfig):
-    def __init__(
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        image_size: Optional[int] = None,
+        interpolate_pos_encoding: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        **kwargs: Any,
+    ):
         """
         Args:
             batch_size (Optional[int]): The batch size for image processing. Defaults to 1.
             image_size (Optional[int]): The size of input images. Can be an integer for square images,
                 a tuple/list (height, width), or a dictionary with 'height' and 'width' keys.
-
+            interpolate_pos_encoding (Optional[bool]): Whether or not to interpolate pre-trained position encodings. Defaults to `False`.
+            output_hidden_states (Optional[bool]): Whether or not to return the hidden states of all layers.
+            output_attentions (Optional[bool]): Whether or not to return the attentions tensors of all attention layers
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.

         Raises:
-            ValueError: If batch_size is not a positive integer.
+            ValueError: If `batch_size` is not a positive integer.
         """
         super().__init__(**kwargs)
         self.batch_size = batch_size or 1
@@ -60,6 +71,9 @@ class RBLNCLIPVisionModelConfig(RBLNModelConfig):
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

         self.image_size = image_size
+        self.interpolate_pos_encoding = interpolate_pos_encoding or False
+        self.output_hidden_states = output_hidden_states
+        self.output_attentions = output_attentions

     @property
     def image_width(self):