optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0

--- a/optimum/rbln/transformers/models/clip/modeling_clip.py
+++ b/optimum/rbln/transformers/models/clip/modeling_clip.py
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import torch
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPVisionConfig, CLIPVisionModel
+from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.models.clip.modeling_clip import CLIPTextModelOutput, CLIPVisionModelOutput
 
 from ....configuration_utils import RBLNCompileConfig
@@ -50,8 +51,10 @@ class RBLNCLIPTextModel(RBLNModel):
     on RBLN devices, supporting text encoding for multimodal tasks.
     """
 
+    _tp_support = False
+
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPTextModelConfig) -> torch.nn.Module:
         return _TextEncoder(model).eval()
 
     @classmethod
@@ -82,7 +85,18 @@ class RBLNCLIPTextModel(RBLNModel):
         rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
         return rbln_config
 
-    def forward(self, input_ids: torch.LongTensor, return_dict: bool = None, **kwargs) -> torch.FloatTensor:
+    def forward(self, input_ids: torch.LongTensor, return_dict: Optional[bool] = None, **kwargs) -> torch.FloatTensor:
+        """
+        Forward pass for the RBLN-optimized CLIP text encoder model.
+
+        Args:
+            input_ids (torch.LongTensor): The input ids to the model.
+            return_dict (Optional[bool]): Whether to return a dictionary of outputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a CLIPTextModelOutput object.
+        """
+
         # To ignore using attention_mask, we override forward method.
         output = super().forward(input_ids, return_dict=return_dict)
         return output
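
A minimal usage sketch of the updated RBLNCLIPTextModel.forward() shown in the hunk above. The checkpoint name and the export=True compile flow are assumptions following the from_pretrained pattern used elsewhere in this diff; they are not part of the change itself.

# Hedged sketch: compile the RBLN CLIP text encoder and call the overridden forward().
# The checkpoint and the compiled sequence length (CLIP's default 77) are assumptions.
from transformers import AutoTokenizer
from optimum.rbln import RBLNCLIPTextModel

model_id = "openai/clip-vit-base-patch32"  # hypothetical example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RBLNCLIPTextModel.from_pretrained(model_id, export=True)

inputs = tokenizer(["a photo of a cat"], padding="max_length", return_tensors="pt")
# attention_mask is ignored: the overridden forward() only consumes input_ids.
outputs = model(input_ids=inputs["input_ids"], return_dict=True)
print(outputs.last_hidden_state.shape)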
@@ -111,12 +125,27 @@ class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
 
 
 class _VisionEncoder(torch.nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        enc: CLIPVisionModel,
+        interpolate_pos_encoding: bool,
+        output_hidden_states: bool,
+        output_attentions: bool,
+    ):
         super().__init__()
         self.enc = enc
+        self.interpolate_pos_encoding = interpolate_pos_encoding
+        self.output_hidden_states = output_hidden_states
+        self.output_attentions = output_attentions
 
     def forward(self, inp):
-        enc_out = self.enc(
+        enc_out = self.enc(
+            inp,
+            output_hidden_states=self.output_hidden_states,
+            interpolate_pos_encoding=self.interpolate_pos_encoding,
+            output_attentions=self.output_attentions,
+            return_dict=False,
+        )
         return enc_out
 
 
@@ -128,9 +157,16 @@ class RBLNCLIPVisionModel(RBLNModel):
     on RBLN devices, supporting image encoding for multimodal tasks.
     """
 
+    _tp_support = False
+
     @classmethod
-    def
-
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
+        wrapper_cfg = {
+            "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
+            "output_hidden_states": rbln_config.output_hidden_states,
+            "output_attentions": rbln_config.output_attentions,
+        }
+        return _VisionEncoder(model, **wrapper_cfg).eval()
 
     @classmethod
     def update_rbln_config_using_pipe(
@@ -155,6 +191,12 @@ class RBLNCLIPVisionModel(RBLNModel):
         if rbln_config.image_size is None:
             raise ValueError("`rbln_image_size` should be specified!")
 
+        if rbln_config.output_attentions is None:
+            rbln_config.output_attentions = getattr(model_config, "output_attentions", False)
+
+        if rbln_config.output_hidden_states is None:
+            rbln_config.output_hidden_states = getattr(model_config, "output_hidden_states", False)
+
         rbln_compile_config = RBLNCompileConfig(
             input_info=[
                 (
@@ -175,28 +217,91 @@ class RBLNCLIPVisionModel(RBLNModel):
 
     def forward(
         self,
-        pixel_values:
-        return_dict: bool =
+        pixel_values: torch.FloatTensor,
+        return_dict: bool = True,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         **kwargs,
-    ) -> Union[Tuple,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        """
+        Forward pass for the RBLN-optimized CLIP vision encoder model.
+
+        Args:
+            pixel_values (torch.Tensor): The pixel values to the model.
+            return_dict (bool): Whether to return a dictionary of outputs.
+            output_attentions (Optional[bool]): Whether to return attentions.
+            output_hidden_states (Optional[bool]): Whether to return hidden states.
+            interpolate_pos_encoding (bool): Whether to interpolate position encoding.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
+        """
+
         if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
             logger.warning(
                 f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
             )
+
+        output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+
+        if output_attentions != self.rbln_config.output_attentions:
+            raise ValueError(
+                f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
+                f"Please compile again with the correct argument."
+            )
+
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
+        if interpolate_pos_encoding != self.rbln_config.interpolate_pos_encoding:
+            raise ValueError(
+                f"Variable interpolate_pos_encoding {interpolate_pos_encoding} is not equal to rbln_config.interpolate_pos_encoding {self.rbln_config.interpolate_pos_encoding} "
+                f"Please compile again with the correct argument."
+            )
+
         output = super().forward(pixel_values, return_dict=return_dict)
         return output
 
     def _prepare_output(self, output, return_dict):
         # Prepare model output based on return_dict flag.
         # This method can be overridden by subclasses to provide task-specific output handling.
+        last_hidden_state = output.pop(0)
+        pooler_output = output.pop(0)
+        vision_config = self.config.vision_config if hasattr(self.config, "vision_config") else self.config
+
+        if self.rbln_config.output_hidden_states:
+            hidden_states = ()
+            num_hidden_layers = vision_config.num_hidden_layers
+            for _ in range(num_hidden_layers + 1):
+                hidden_states += (output.pop(0),)
+        else:
+            hidden_states = None
+
+        if self.rbln_config.output_attentions:
+            attentions = ()
+            num_hidden_layers = vision_config.num_hidden_layers
+            for _ in range(num_hidden_layers):
+                attentions += (output.pop(0),)
+        else:
+            attentions = None
 
         if not return_dict:
-            return (
+            return tuple(
+                item for item in (last_hidden_state, pooler_output, hidden_states, attentions) if item is not None
+            )
         else:
-            return
-
-
-                hidden_states=
+            return BaseModelOutputWithPooling(
+                last_hidden_state=last_hidden_state,
+                pooler_output=pooler_output,
+                hidden_states=hidden_states,
+                attentions=attentions,
             )
 
 
@@ -210,19 +315,70 @@ class RBLNCLIPVisionModelWithProjection(RBLNCLIPVisionModel):
 
     def forward(
         self,
-        pixel_values:
+        pixel_values: torch.FloatTensor,
+        return_dict: bool = True,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         **kwargs,
     ) -> Union[Tuple, CLIPVisionModelOutput]:
-
-
-
-
-
-
-
-
-
-
-
-
+        """
+        Forward pass for the RBLN-optimized CLIP vision encoder model with projection.
+
+        Args:
+            pixel_values (torch.Tensor): The pixel values to the model.
+            return_dict (bool): Whether to return a dictionary of outputs.
+            output_attentions (Optional[bool]): Whether to return attentions.
+            output_hidden_states (Optional[bool]): Whether to return hidden states.
+            interpolate_pos_encoding (bool): Whether to interpolate position encoding.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a CLIPVisionModelOutput object.
+        """
+
+        return super().forward(
+            pixel_values=pixel_values,
+            return_dict=return_dict,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            **kwargs,
         )
+
+    def _prepare_output(self, output, return_dict):
+        # Prepare model output based on return_dict flag.
+        # This method can be overridden by subclasses to provide task-specific output handling.
+
+        image_embeds = output.pop(0) if isinstance(output, (tuple, list)) else output
+        last_hidden_state = output.pop(0)
+
+        vision_config = self.config.vision_config if hasattr(self.config, "vision_config") else self.config
+
+        if self.rbln_config.output_hidden_states:
+            hidden_states = ()
+            num_hidden_layers = vision_config.num_hidden_layers
+            for _ in range(num_hidden_layers + 1):
+                hidden_states += (output.pop(0),)
+        else:
+            hidden_states = None
+
+        if self.rbln_config.output_attentions:
+            attentions = ()
+            num_hidden_layers = vision_config.num_hidden_layers
+            for _ in range(num_hidden_layers):
+                attentions += (output.pop(0),)
+        else:
+            attentions = None
+
+        if not return_dict:
+            return tuple(
+                item for item in (image_embeds, last_hidden_state, hidden_states, attentions) if item is not None
+            )
+
+        else:
+            return CLIPVisionModelOutput(
+                image_embeds=image_embeds,
+                last_hidden_state=last_hidden_state,
+                hidden_states=hidden_states,
+                attentions=attentions,
+            )
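
With the changes above, output_attentions, output_hidden_states, and interpolate_pos_encoding for the CLIP vision models become compile-time settings taken from the RBLN config; at inference they can only repeat the compiled values, otherwise forward() raises a ValueError. A minimal sketch of that behavior, assuming RBLNCLIPVisionModelConfig accepts these fields (and image_size) as constructor arguments and using a hypothetical checkpoint:

# Hedged sketch: the constructor kwargs mirror the rbln_config fields read by
# _wrap_model_if_needed() above; the exact config signature and the checkpoint
# name are assumptions, not taken from this diff.
import torch
from optimum.rbln import RBLNCLIPVisionModel, RBLNCLIPVisionModelConfig

config = RBLNCLIPVisionModelConfig(image_size=224, output_hidden_states=True)
model = RBLNCLIPVisionModel.from_pretrained(
    "openai/clip-vit-base-patch32",  # hypothetical checkpoint
    export=True,
    rbln_config=config,
)

pixel_values = torch.randn(1, 3, 224, 224)
out = model(pixel_values, output_hidden_states=True)  # matches the compiled config
print(len(out.hidden_states))  # num_hidden_layers + 1 entries

# A flag that differs from the compiled config raises a ValueError
# ("Please compile again with the correct argument."):
# model(pixel_values, output_hidden_states=False)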

--- a/optimum/rbln/transformers/models/colpali/colpali_architecture.py
+++ b/optimum/rbln/transformers/models/colpali/colpali_architecture.py
@@ -4,10 +4,7 @@ import torch
 from torch import nn
 from transformers import GemmaForCausalLM, GemmaModel
 
-from ..decoderonly.decoderonly_architecture import (
-    RotaryEmbedding,
-    apply_rotary_pos_emb,
-)
+from ..decoderonly.decoderonly_architecture import RotaryEmbedding, apply_rotary_pos_emb
 
 
 def slice_and_unsqueeze_cos_sin(cos, sin, position_ids):
@@ -27,11 +24,11 @@ class RBLNColPaliForRetrievalWrapper(nn.Module):
         output_hidden_states: bool = False,
     ):
         super().__init__()
-        self.text_config = causal_lm.config
+        self.text_config = causal_lm.config.text_config
         self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
 
         self.output_hidden_states = output_hidden_states
-        self.language_model = self.convert_to_rbln_language_model(causal_lm.model, max_seq_len)
+        self.language_model = self.convert_to_rbln_language_model(causal_lm.model.language_model, max_seq_len)
 
         self.num_hidden_layers = getattr(self.text_config, "num_hidden_layers", None)
         self.embedding_proj_layer = embedding_proj_layer

--- a/optimum/rbln/transformers/models/colpali/configuration_colpali.py
+++ b/optimum/rbln/transformers/models/colpali/configuration_colpali.py
@@ -11,9 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any,
+from typing import Any, List, Optional, Union
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
@@ -24,45 +28,57 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
     including vision tower settings and multi-sequence length support.
 
     Example usage:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    ```python
+    from optimum.rbln import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig
+
+    # Create a configuration object
+    config = RBLNColPaliForRetrievalConfig(
+        max_seq_lens=1152,
+        output_hidden_states=False,
+        tensor_parallel_size=4
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNColPaliForRetrieval.from_pretrained(
+        "vidore/colpali-v1.3-hf",
+        export=True,
+        rbln_config=config
+    )
+    ```
     """
 
     submodules = ["vision_tower"]
 
     def __init__(
         self,
+        batch_size: Optional[int] = None,
         max_seq_lens: Union[int, List[int]] = None,
        output_hidden_states: Optional[bool] = None,
         vision_tower: Optional[RBLNModelConfig] = None,
-        **kwargs:
+        **kwargs: Any,
     ):
         """
         Args:
+            batch_size (Optional[int]): The batch size for the model.
             vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
             max_seq_lens (Union[int, List[int]]): The maximum sequence lengths for the language model.
                 This can be multiple values, and the model will be compiled for each max_seq_len, allowing selection of the most appropriate max_seq_len at inference time.
             output_hidden_states (Optional[bool]): Whether to output the hidden states of the language model.
-
+            vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
         Raises:
             ValueError: If batch_size is not a positive integer.
         """
         super().__init__(**kwargs)
-        self.
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for ColPali vision tower. It will be set to 1.")
+
+        self.vision_tower = self.initialize_submodule_config(
+            submodule_config=vision_tower, batch_size=1, force_kwargs=True
+        )
         self.max_seq_lens = max_seq_lens
         self.output_hidden_states = output_hidden_states
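
A short sketch of the new batch_size handling in RBLNColPaliForRetrievalConfig shown above. The behavior is read off the diff; the vision_tower.batch_size attribute access assumes initialize_submodule_config() returns a config object exposing that field:

# Hedged sketch of the batch_size handling added above; attribute names on the
# resulting vision_tower config are assumptions.
from optimum.rbln import RBLNColPaliForRetrievalConfig

config = RBLNColPaliForRetrievalConfig(batch_size=4, max_seq_lens=1152)
assert config.batch_size == 4  # stored on the top-level config

# A warning ("Ignore batch_size for ColPali vision tower. It will be set to 1.")
# is logged because batch_size != 1, and the vision tower is always compiled
# with batch size 1:
print(config.vision_tower.batch_size)  # expected: 1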