optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +48 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +50 -21
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +156 -127
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +26 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +23 -2
- optimum/rbln/utils/runtime_utils.py +42 -6
- optimum/rbln/utils/submodule.py +27 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
```diff
--- a/optimum/rbln/transformers/models/bert/modeling_bert.py
+++ b/optimum/rbln/transformers/models/bert/modeling_bert.py
@@ -12,7 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional, Tuple, Union
+
 import torch
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    MaskedLMOutput,
+    QuestionAnsweringModelOutput,
+)
 
 from ...modeling_generic import (
     RBLNModelForMaskedLM,
@@ -35,9 +42,45 @@ class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNBertModelConfig) -> torch.nn.Module:
         return BertModelWrapper(model, rbln_config)
 
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, Tuple]:
+        """
+        Forward pass for the RBLN-optimized BERT model for feature extraction tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
+            position_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of positions of each input sequence tokens in the position embeddings.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPoolingAndCrossAttentions object.
+        """
+
+        input_map = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "token_type_ids": token_type_ids,
+            "position_ids": position_ids,
+        }
+
+        model_input_names = getattr(self.rbln_config, "model_input_names", None)
+        if model_input_names is None:
+            model_input_names = self.rbln_model_input_names
+
+        ordered_inputs = [input_map[name] for name in model_input_names if name in input_map]
+
+        return super().forward(*ordered_inputs, **kwargs)
+
 
 class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
     """
@@ -50,6 +93,27 @@ class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
 
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
 
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[MaskedLMOutput, Tuple]:
+        """
+        Forward pass for the RBLN-optimized BERT model for masked language modeling tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a MaskedLMOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, token_type_ids, **kwargs)
+
 
 class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
     """
@@ -61,3 +125,24 @@ class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
     """
 
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[QuestionAnsweringModelOutput, Tuple]:
+        """
+        Forward pass for the RBLN-optimized BERT model for question answering tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a QuestionAnsweringModelOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, token_type_ids, **kwargs)
```
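These hunks give the BERT wrappers explicit, keyword-style `forward()` signatures that return the familiar `transformers` output objects. A minimal usage sketch follows, assuming the library's `export=True` compile-on-load convention and an illustrative checkpoint id (neither is taken from this diff):

```python
import torch
from transformers import AutoTokenizer
from optimum.rbln import RBLNBertForQuestionAnswering

model_id = "deepset/bert-base-cased-squad2"  # illustrative QA checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RBLNBertForQuestionAnswering.from_pretrained(model_id, export=True)

# Compiled graphs expect fixed shapes, so pad to the compiled sequence length.
inputs = tokenizer(
    "What does optimum-rbln target?",
    "optimum-rbln compiles Hugging Face models for RBLN NPUs.",
    return_tensors="pt",
    padding="max_length",
    max_length=512,
)

# Keyword arguments are now part of the documented signature.
outputs = model(**inputs)
start = int(torch.argmax(outputs.start_logits, dim=-1))
end = int(torch.argmax(outputs.end_logits, dim=-1))
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))
```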
```diff
--- a/optimum/rbln/transformers/models/blip_2/modeling_blip_2.py
+++ b/optimum/rbln/transformers/models/blip_2/modeling_blip_2.py
@@ -14,7 +14,7 @@
 
 import inspect
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
 
 import torch
 from transformers import (
@@ -71,7 +71,7 @@ class RBLNBlip2VisionModel(RBLNModel):
         return self.embeddings
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         class Blip2VisionModelWrapper(torch.nn.Module):
             def __init__(self, model: "Blip2VisionModel") -> None:
                 super().__init__()
@@ -111,11 +111,20 @@ class RBLNBlip2VisionModel(RBLNModel):
     def forward(
         self,
         pixel_values: torch.FloatTensor,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        """
+        Forward pass for the RBLN-optimized Blip2VisionModel model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)): The tensors corresponding to the input images.
+            interpolate_pos_encoding (bool, optional): Whether to interpolate the positional encoding of the image embeddings. Defaults to False.
+            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            BaseModelOutputWithPooling or tuple(torch.FloatTensor): The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
+        """
         batch_size = pixel_values.shape[0]
         outputs = []
         for i in range(batch_size):
@@ -151,7 +160,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
         return self.embeddings.word_embeddings
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         class Blip2QFormerModelWrapper(torch.nn.Module):
             def __init__(self, model: "Blip2QFormerModel"):
                 super().__init__()
@@ -231,17 +240,22 @@ class RBLNBlip2QFormerModel(RBLNModel):
     def forward(
         self,
         query_embeds: torch.FloatTensor,
-        query_length: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        """
+        The forward pass for the RBLN-optimized Blip2QFormerModel model.
+
+        Args:
+            query_embeds (torch.FloatTensor): Hidden states to be used in the attention computation.
+            encoder_hidden_states (torch.FloatTensor, optional): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder.
+            encoder_attention_mask (torch.FloatTensor, optional): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder.
+            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            BaseModelOutputWithPoolingAndCrossAttentions or tuple(torch.FloatTensor): The model outputs. If `return_dict=False` is passed, returns a tuple of tensors. Otherwise, returns a `BaseModelOutputWithPoolingAndCrossAttentions` object.
+        """
         batch_size = query_embeds.shape[0]
         outputs = []
         for i in range(batch_size):
@@ -349,7 +363,7 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
         return self.language_model.get_input_embeddings()
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model, rbln_config):
         return model.language_projection
 
     @classmethod
@@ -444,7 +458,20 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
         inputs_embeds: Optional[torch.FloatTensor] = None,
         interpolate_pos_encoding: bool = False,
         **generate_kwargs,
-    ) -> torch.LongTensor:
+    ) -> List[torch.LongTensor]:
+        """
+        The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
+        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/model_doc/blip-2#transformers.Blip2ForConditionalGeneration.generate) for more details.
+
+        Args:
+            pixel_values (torch.FloatTensor): Input images to be processed.
+            input_ids (torch.LongTensor, optional): The sequence used as a prompt for the generation.
+            attention_mask (torch.LongTensor, optional): Mask to avoid performing attention on padding token indices
+            inputs_embeds (torch.FloatTensor, optional): Embedded representation of the inputs. Should be float, not int tokens.
+            interpolate_pos_encoding (bool, optional, defaults to False): Whether to interpolate the positional encoding of the image embeddings.
+        Returns:
+            A list of strings of length batch_size * num_captions.
+        """
         batch_size = pixel_values.shape[0]
         image_embeds = self.vision_model(
             pixel_values,
```
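`generate()` on `RBLNBlip2ForConditionalGeneration` is now documented and typed to return a list of generated token sequences. A sketch of the intended call pattern, assuming the standard BLIP-2 processor and an illustrative checkpoint id (not taken from this diff):

```python
from PIL import Image
from transformers import Blip2Processor
from optimum.rbln import RBLNBlip2ForConditionalGeneration

model_id = "Salesforce/blip2-opt-2.7b"  # illustrative checkpoint
processor = Blip2Processor.from_pretrained(model_id)
model = RBLNBlip2ForConditionalGeneration.from_pretrained(model_id, export=True)

image = Image.open("photo.jpg")  # any RGB image
prompt = "Question: what is shown in the picture? Answer:"
inputs = processor(images=image, text=prompt, return_tensors="pt")

# generate() accepts pixel_values / input_ids / attention_mask as documented above.
generated_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(generated_ids, skip_special_tokens=True))
```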
```diff
--- a/optimum/rbln/transformers/models/clip/modeling_clip.py
+++ b/optimum/rbln/transformers/models/clip/modeling_clip.py
@@ -54,7 +54,7 @@ class RBLNCLIPTextModel(RBLNModel):
     _tp_support = False
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPTextModelConfig) -> torch.nn.Module:
         return _TextEncoder(model).eval()
 
     @classmethod
@@ -92,6 +92,9 @@ class RBLNCLIPTextModel(RBLNModel):
         Args:
             input_ids (torch.LongTensor): The input ids to the model.
             return_dict (Optional[bool]): Whether to return a dictionary of outputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a CLIPTextModelOutput object.
         """
 
         # To ignore using attention_mask, we override forward method.
@@ -157,7 +160,7 @@ class RBLNCLIPVisionModel(RBLNModel):
     _tp_support = False
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
         wrapper_cfg = {
             "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
             "output_hidden_states": rbln_config.output_hidden_states,
@@ -230,6 +233,9 @@ class RBLNCLIPVisionModel(RBLNModel):
             output_attentions (Optional[bool]): Whether to return attentions.
             output_hidden_states (Optional[bool]): Whether to return hidden states.
             interpolate_pos_encoding (bool): Whether to interpolate position encoding.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
         """
 
         if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
@@ -307,6 +313,38 @@ class RBLNCLIPVisionModelWithProjection(RBLNCLIPVisionModel):
     multimodal embedding alignment tasks.
     """
 
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        return_dict: bool = True,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        **kwargs,
+    ) -> Union[Tuple, CLIPVisionModelOutput]:
+        """
+        Forward pass for the RBLN-optimized CLIP vision encoder model with projection.
+
+        Args:
+            pixel_values (torch.Tensor): The pixel values to the model.
+            return_dict (bool): Whether to return a dictionary of outputs.
+            output_attentions (Optional[bool]): Whether to return attentions.
+            output_hidden_states (Optional[bool]): Whether to return hidden states.
+            interpolate_pos_encoding (bool): Whether to interpolate position encoding.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a CLIPVisionModelOutput object.
+        """
+
+        return super().forward(
+            pixel_values=pixel_values,
+            return_dict=return_dict,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            **kwargs,
+        )
+
     def _prepare_output(self, output, return_dict):
         # Prepare model output based on return_dict flag.
         # This method can be overridden by subclasses to provide task-specific output handling.
```
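`RBLNCLIPVisionModelWithProjection` now defines an explicit `forward()`, so the projected image embeddings can be read off the returned `CLIPVisionModelOutput`. A short sketch, with an illustrative checkpoint id and the same hedges as above:

```python
from PIL import Image
from transformers import CLIPImageProcessor
from optimum.rbln import RBLNCLIPVisionModelWithProjection

model_id = "openai/clip-vit-base-patch32"  # illustrative checkpoint
processor = CLIPImageProcessor.from_pretrained(model_id)
model = RBLNCLIPVisionModelWithProjection.from_pretrained(model_id, export=True)

pixel_values = processor(images=Image.open("photo.jpg"), return_tensors="pt").pixel_values

outputs = model(pixel_values=pixel_values, return_dict=True)
image_embeds = outputs.image_embeds  # shape: (batch_size, projection_dim)
```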
```diff
--- a/optimum/rbln/transformers/models/colpali/colpali_architecture.py
+++ b/optimum/rbln/transformers/models/colpali/colpali_architecture.py
@@ -77,11 +77,11 @@ class ColPaliModel(nn.Module):
         self, model, layers: List["ColPaliLayer"], output_hidden_states: bool = False, max_seq_len: int = 2048
     ):
         super().__init__()
-        self._original_mod = model
         self.layers = nn.ModuleList(layers)
         self.output_hidden_states = output_hidden_states
-        self.
-        self.
+        self.config = model.config
+        self.norm = model.norm
+        self.hidden_size = self.config.hidden_size
         self.max_seq_len = max_seq_len
 
     def forward(
@@ -118,7 +118,6 @@ class ColPaliModel(nn.Module):
 class ColPaliLayer(nn.Module):
     def __init__(self, layer, self_attn: "ColPaliAttention"):
         super().__init__()
-        self._original_mod = layer
         self.self_attn = self_attn
         self.mlp = layer.mlp
         self.input_layernorm = layer.input_layernorm
@@ -155,27 +154,22 @@ class ColPaliLayer(nn.Module):
 class ColPaliAttention(nn.Module):
     def __init__(self, self_attn):
         super().__init__()
-        self.
-        self.num_heads = getattr(
-
-        )
-        self.head_dim = self._original_mod.head_dim
+        self.config = self_attn.config
+        self.num_heads = getattr(self_attn, "num_heads", None) or self_attn.config.num_attention_heads
+        self.head_dim = self_attn.head_dim
         self.scaling = self.head_dim**-0.5
 
-        if hasattr(
-            self.num_key_value_heads =
-        elif hasattr(
-            self.num_key_value_heads =
+        if hasattr(self_attn, "num_key_value_heads"):
+            self.num_key_value_heads = self_attn.num_key_value_heads
+        elif hasattr(self_attn, "config") and hasattr(self_attn.config, "num_key_value_heads"):
+            self.num_key_value_heads = self_attn.config.num_key_value_heads
         else:
             self.num_key_value_heads = self.num_heads
 
-        self.
-
-
-        self.
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.o_proj
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = self_attn.o_proj
 
     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states)
```
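The refactor above drops the `_original_mod` indirection: the ColPali wrapper modules now borrow `config`, `norm`, and the attention projections directly from the wrapped Hugging Face layers, so they no longer hold a reference to the entire original module. A simplified sketch of the pattern (an illustrative class, not the package's own):

```python
import torch.nn as nn


class AttentionView(nn.Module):
    """Illustrative wrapper that reuses submodules of an existing HF attention layer."""

    def __init__(self, self_attn: nn.Module):
        super().__init__()
        # Borrow configuration and shape information from the wrapped layer.
        self.config = self_attn.config
        self.num_heads = getattr(self_attn, "num_heads", None) or self_attn.config.num_attention_heads
        self.head_dim = self_attn.head_dim
        # Grouped-query attention: fall back to num_heads if no KV-head count is exposed.
        self.num_key_value_heads = getattr(
            self_attn,
            "num_key_value_heads",
            getattr(self_attn.config, "num_key_value_heads", self.num_heads),
        )
        # Share the projection modules instead of copying their weights.
        self.q_proj, self.k_proj = self_attn.q_proj, self_attn.k_proj
        self.v_proj, self.o_proj = self_attn.v_proj, self_attn.o_proj
```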
````diff
--- a/optimum/rbln/transformers/models/colpali/configuration_colpali.py
+++ b/optimum/rbln/transformers/models/colpali/configuration_colpali.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any,
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
 from ....utils.logging import get_logger
@@ -33,7 +33,9 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
 
     # Create a configuration object
     config = RBLNColPaliForRetrievalConfig(
-
+        vlm={
+            "language_model": {"prefill_chunk_size": 8192},
+        }
         output_hidden_states=False,
         tensor_parallel_size=4
     )
@@ -47,24 +49,21 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
     ```
     """
 
-
+    _allow_no_compile_cfgs = True
+    submodules = ["vlm"]
 
     def __init__(
         self,
         batch_size: Optional[int] = None,
-
+        vlm: Optional[RBLNModelConfig] = None,
         output_hidden_states: Optional[bool] = None,
-        vision_tower: Optional[RBLNModelConfig] = None,
         **kwargs: Any,
     ):
         """
         Args:
             batch_size (Optional[int]): The batch size for the model.
-
-
-                This can be multiple values, and the model will be compiled for each max_seq_len, allowing selection of the most appropriate max_seq_len at inference time.
-            output_hidden_states (Optional[bool]): Whether to output the hidden states of the language model.
-            vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+            vlm (Optional[RBLNModelConfig]): Configuration for the VLM component.
+            output_hidden_states (Optional[bool]): Whether to output the hidden states of the decoder. Defaults to False.
             kwargs: Additional arguments passed to the parent RBLNModelConfig.
         Raises:
             ValueError: If batch_size is not a positive integer.
@@ -74,11 +73,7 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
-
-
-
-        self.vision_tower = self.initialize_submodule_config(
-            submodule_config=vision_tower, batch_size=1, force_kwargs=True
+        self.output_hidden_states = output_hidden_states or False
+        self.vlm = self.initialize_submodule_config(
+            submodule_config=vlm, batch_size=batch_size, output_hidden_states=output_hidden_states
         )
-        self.max_seq_lens = max_seq_lens
-        self.output_hidden_states = output_hidden_states
````