optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +48 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +50 -21
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +156 -127
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +26 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +23 -2
- optimum/rbln/utils/runtime_utils.py +42 -6
- optimum/rbln/utils/submodule.py +27 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/colpali/modeling_colpali.py

@@ -14,24 +14,16 @@

 import bisect
 from pathlib import Path
-from
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union

 import torch
-from transformers import PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.modeling_utils import no_init_weights
-from transformers.models.colpali.modeling_colpali import ColPaliForRetrievalOutput
-from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModalProjector
+from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliForRetrievalOutput

-from ....configuration_utils import
+from ....configuration_utils import RBLNModelConfig
 from ....modeling import RBLNModel
 from ...utils.rbln_runtime_wrapper import LoopProcessor
-from .colpali_architecture import RBLNColPaliForRetrievalWrapper
-
-
-if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig


 class LoopVisionTower(LoopProcessor):

@@ -116,17 +108,25 @@ class RBLNColPaliForRetrieval(RBLNModel):
     from optimum.rbln import RBLNColPaliForRetrieval

     # Simple usage using rbln_* arguments
-    # `max_seq_lens` is automatically inferred from the model config
     model = RBLNColPaliForRetrieval.from_pretrained(
         "vidore/colpali-v1.3-hf",
         export=True,
-
+        rbln_config={
+            "vlm": {
+                "language_model": {
+                    "prefill_chunk_size": 8192,  # same as model's max_position_embeddings (max_seq_len)
+                }
+            }
+        }
     )

     # Using a config dictionary
     rbln_config = {
-        "
-
+        "vlm": {
+            "language_model": {
+                "prefill_chunk_size": 8192,  # same as model's max_position_embeddings (max_seq_len)
+            }
+        }
     }
     model = RBLNColPaliForRetrieval.from_pretrained(
         "vidore/colpali-v1.3-hf",

@@ -138,7 +138,9 @@ class RBLNColPaliForRetrieval(RBLNModel):
     from optimum.rbln import RBLNColPaliForRetrievalConfig

     config = RBLNColPaliForRetrievalConfig(
-
+        vlm={
+            "language_model": {"prefill_chunk_size": 8192},
+        },
         output_hidden_states=False,
         tensor_parallel_size=4
     )

@@ -151,250 +153,93 @@ class RBLNColPaliForRetrieval(RBLNModel):
     """

     auto_model_class = None
+    _rbln_submodule_postfix = "model"
     _rbln_submodules = [
-        {"name": "
+        {"name": "vlm"},
     ]

     def __post_init__(self, **kwargs):
-        self.
-        self.language_model = LoopLanguageModel(self.model[0], self.rbln_config)
-
+        self.vlm_model = self.rbln_submodules[0]
         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-        self.
-        self.
-        self.multi_modal_projector = self._create_multi_modal_projector()
-        self.multi_modal_projector.load_state_dict(artifacts["multi_modal_projector"])
-
+        self.embedding_proj_layer = self._create_embedding_proj_layer()
+        self.embedding_proj_layer.load_state_dict(artifacts["embedding_proj_layer"])
         return super().__post_init__(**kwargs)

-    def
+    def _create_embedding_proj_layer(self):
         with no_init_weights():
-
-            self.config.text_config.
-                self.config.text_config.hidden_size,
-                self.config.text_config.pad_token_id,
+            embedding_proj_layer = torch.nn.Linear(
+                self.config.vlm_config.text_config.hidden_size, self.config.embedding_dim
             )
-        return
-
-    def _create_multi_modal_projector(self):
-        with no_init_weights():
-            multi_modal_projector = PaliGemmaMultiModalProjector(self.config.vlm_config)
-        return multi_modal_projector
-
-    @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
-        return RBLNColPaliForRetrievalWrapper(
-            causal_lm=model.vlm,
-            embedding_proj_layer=model.embedding_proj_layer,
-            max_seq_len=max(rbln_config.max_seq_lens),
-            output_hidden_states=rbln_config.output_hidden_states,
-        )
+        return embedding_proj_layer

     @classmethod
     def save_torch_artifacts(
         cls,
-        model: "
+        model: "ColPaliForRetrieval",
         save_dir_path: Path,
         subfolder: str,
         rbln_config: RBLNModelConfig,
     ):
         save_dict = {}
-        save_dict["
-        save_dict["multi_modal_projector"] = model.vlm.multi_modal_projector.state_dict()
+        save_dict["embedding_proj_layer"] = model.embedding_proj_layer.state_dict()
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

-    @classmethod
-    def _update_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-        model: Optional["PreTrainedModel"] = None,
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[RBLNModelConfig] = None,
-    ) -> RBLNModelConfig:
-        hidden_size = model_config.vlm_config.text_config.hidden_size
-        if rbln_config.max_seq_lens is None:
-            rbln_config.max_seq_lens = [model_config.vlm_config.text_config.max_position_embeddings]
-        if isinstance(rbln_config.max_seq_lens, int):
-            rbln_config.max_seq_lens = [rbln_config.max_seq_lens]
-        rbln_config.max_seq_lens = sorted(set(rbln_config.max_seq_lens))
-
-        if rbln_config.output_hidden_states is None:
-            rbln_config.output_hidden_states = model_config.vlm_config.text_config.output_hidden_states
-
-        input_infos = []
-        for max_seq_len in rbln_config.max_seq_lens:
-            input_info = [
-                ("inputs_embeds", [rbln_config.vision_tower.batch_size, max_seq_len, hidden_size], "float32"),
-                ("attention_mask", [rbln_config.vision_tower.batch_size, max_seq_len], "float32"),
-                ("position_ids", [rbln_config.vision_tower.batch_size, max_seq_len], "int32"),
-            ]
-            input_infos.append(input_info)
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_infos)
-        rbln_config.set_compile_cfgs([rbln_compile_config])
-
-        return rbln_config
-
-    @classmethod
-    def from_model(
-        cls,
-        model: "PreTrainedModel",
-        config: Optional[PretrainedConfig] = None,
-        rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        subfolder: str = "",
-        **kwargs: Any,
-    ) -> "RBLNModel":
-        """
-        Converts and compiles a pre-trained HuggingFace library model into a RBLN model.
-        This method performs the actual model conversion and compilation process.
-
-        Args:
-            model (PreTrainedModel): The PyTorch model to be compiled.
-                The object must be an instance of the HuggingFace transformers PreTrainedModel class.
-            config (Optional[PretrainedConfig]): The configuration object associated with the model.
-            rbln_config (Optional[Union[RBLNModelConfig, Dict]]): Configuration for RBLN model compilation and runtime.
-                This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
-                For detailed configuration options, see the specific model's configuration class documentation.
-            kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
-
-        The method performs the following steps:
-
-        1. Compiles the PyTorch model into an optimized RBLN graph
-        2. Configures the model for the specified NPU device
-        3. Creates the necessary runtime objects if requested
-        4. Saves the compiled model and configurations
-
-        Returns:
-            (RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
-        """
-        if not hasattr(model, "vision_tower"):
-            model.vision_tower = model.vlm.vision_tower
-            del model.vlm.model.vision_tower
-        model = super().from_model(model, config, rbln_config, model_save_dir, subfolder, **kwargs)
-        return model
-
-    @classmethod
-    def get_pytorch_model(cls, *args, **kwargs):
-        model = super().get_pytorch_model(*args, **kwargs)
-        model.vision_tower = model.vlm.vision_tower
-        del model.vlm.model.vision_tower
-        return model
-
-    def get_image_features(self, pixel_values: torch.Tensor):
-        # Projects the last hidden state from the vision model into language model space.
-        # Args:
-        #     pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
-        #         The tensors corresponding to the input images.
-        # Returns:
-        #     image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
-
-        vision_output_size = [
-            pixel_values.shape[0],
-            self.config.vlm_config.vision_config.num_image_tokens,
-            self.config.vlm_config.vision_config.hidden_size,
-        ]
-        vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu")
-        self.vision_tower(pixel_values, out=vision_output)
-        image_features = self.multi_modal_projector(vision_output)
-        image_features = image_features / (self.config.text_config.hidden_size**0.5)
-        return image_features
-
-    def _preprocess_inputs(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        pixel_values: Optional[torch.FloatTensor] = None,
-        **kwargs,
-    ):
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-        # Replace image id woth PAD if the image token if OOV, to avoid index-errors
-        if input_ids is not None and self.config.vlm_config.image_token_index >= self.config.text_config.vocab_size:
-            special_image_mask = input_ids == self.config.vlm_config.image_token_index
-            llm_input_ids = input_ids.clone()
-            llm_input_ids[special_image_mask] = 0
-        else:
-            llm_input_ids = input_ids
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(llm_input_ids)
-
-        # Merge text and images
-        image_features = None
-        if pixel_values is not None:
-            image_features = self.get_image_features(pixel_values)
-            special_image_mask = (input_ids == self.config.vlm_config.image_token_index).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-
-            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
-        return inputs_embeds, image_features
-
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
         pixel_values: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         **kwargs,
     ) -> Union[Tuple, ColPaliForRetrievalOutput]:
+        """
+        Forward pass for the RBLN-optimized ColPaliForRetrieval model.
+
+        Args:
+            input_ids (torch.LongTensor of shape (batch_size, sequence_length)): Indices of input sequence tokens in the vocabulary.
+            pixel_values (torch.Tensor of shape (batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length)): Mask to avoid performing attention on padding token indices.
+            output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+            return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            ColPaliForRetrievalOutput or tuple(torch.FloatTensor)
+        """
         if pixel_values is not None:
             pixel_values = pixel_values.to(dtype=self.dtype)

-        if
-
-
-
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.output_hidden_states:
             raise ValueError(
                 f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
                 f"Please compile again with the correct argument."
             )

-
-
-
-
+        vlm_output = self.vlm_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+            **kwargs,
         )
+        vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
+        vlm_image_hidden_states = vlm_output.image_hidden_states if pixel_values is not None else None

-
-
-
-
-            self.rbln_config.max_seq_lens[0],
-            self.rbln_config.max_seq_lens[0],
-        ]
-        outputs.append(torch.empty(size=language_model_out_size, dtype=torch.float32, device="cpu"))
-        if self.rbln_config.output_hidden_states:
-            for i in range(self.config.vlm_config.text_config.num_hidden_layers + 1):
-                outputs.append(torch.empty(size=language_model_hidden_states_size, dtype=torch.float32, device="cpu"))
-
-        # Embedding_proj_layer is fused on the bottom of the language model.
-        self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, out=outputs)
-
-        embeddings = outputs[0][:, : inputs_embeds.shape[1]]
-        hidden_states = (
-            None
-            if not self.rbln_config.output_hidden_states
-            else [tensor[0][:, : inputs_embeds.shape[1]] for tensor in outputs[1:]]
-        )
-
-        # L2 normalization
-        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
+        last_hidden_states = vlm_output[0]
+        proj_dtype = self.embedding_proj_layer.weight.dtype
+        embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype))
+        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

         if attention_mask is not None:
-            embeddings = embeddings * attention_mask.unsqueeze(-1)
+            embeddings = embeddings * attention_mask.unsqueeze(-1)

-
-
-
-
-
-            hidden_states=hidden_states,
-            image_hidden_states=image_features,
-        )
+        return ColPaliForRetrievalOutput(
+            embeddings=embeddings,
+            hidden_states=vlm_hidden_states,
+            image_hidden_states=vlm_image_hidden_states,
+        )
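
Read together, these hunks retire the bespoke ColPali compile path (`RBLNColPaliForRetrievalWrapper`, `max_seq_lens`, `_update_rbln_config`) in favor of a delegating `vlm` submodule plus a small `embedding_proj_layer` kept in PyTorch. A minimal sketch of the new-style call, assembled from the docstring example added in this diff (the model ID and the 8192 chunk size are the example's values; an RBLN NPU environment with this wheel installed is assumed):

```python
from optimum.rbln import RBLNColPaliForRetrieval

# Compile with the nested "vlm" config introduced in this release.
# Values mirror the docstring example in the hunks above.
model = RBLNColPaliForRetrieval.from_pretrained(
    "vidore/colpali-v1.3-hf",
    export=True,
    rbln_config={
        "vlm": {
            "language_model": {"prefill_chunk_size": 8192},
        },
    },
)
```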
optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py

@@ -32,14 +32,16 @@ class RBLNColQwen2ForRetrievalConfig(RBLNDecoderOnlyModelConfig):

     # Create a configuration object
     config = RBLNColQwen2ForRetrievalConfig(
-
-        "
-
-
-
-
-
-
+        vlm = {
+            "visual": {
+                "max_seq_lens": 6400,
+                "device": 0,
+            },
+            "max_seq_len": 32_768,
+            "tensor_parallel_size": 4,
+            "device": [0, 1, 2, 3],
+            "output_hidden_states": False,
+        }
     )

     # Use the configuration with from_pretrained

@@ -51,24 +53,37 @@ class RBLNColQwen2ForRetrievalConfig(RBLNDecoderOnlyModelConfig):
     ```
     """

-    submodules = ["
+    submodules = ["vlm"]
+    _allow_no_compile_cfgs = True

     def __init__(
         self,
-        visual: Optional[RBLNModelConfig] = None,
         batch_size: Optional[int] = None,
-
-
+        output_hidden_states: Optional[bool] = None,
+        vlm: Optional[RBLNModelConfig] = None,
         **kwargs,
     ):
-
-
-
-
-
-
-
-
-
-
-        self.
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for the model.
+            output_hidden_states (Optional[bool]): Whether to output the hidden states of the VLM model.
+            vlm (Optional[RBLNModelConfig]): Configuration for the VLM component.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        self.output_hidden_states = output_hidden_states or False
+
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.vlm = self.initialize_submodule_config(
+            submodule_config=vlm,
+            batch_size=batch_size,
+            output_hidden_states=output_hidden_states,
+            force_kwargs=True,
+            logits_to_keep=0,
+            use_inputs_embeds=True,
+        )
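
On the ColQwen2 side, the retrieval config likewise folds the old top-level `visual`/decoder options into a single `vlm` submodule. A short sketch of building the new config object, copying the values from the docstring example above (it assumes `RBLNColQwen2ForRetrievalConfig` is exported from `optimum.rbln`, as the ColPali config is; the sequence lengths, device IDs, and tensor parallel size are the example's values, not recommendations):

```python
from optimum.rbln import RBLNColQwen2ForRetrievalConfig

# Nested "vlm" submodule config; values copied from the docstring example above.
config = RBLNColQwen2ForRetrievalConfig(
    batch_size=1,
    vlm={
        "visual": {"max_seq_lens": 6400, "device": 0},
        "max_seq_len": 32_768,
        "tensor_parallel_size": 4,
        "device": [0, 1, 2, 3],
        "output_hidden_states": False,
    },
)
```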