optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +48 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +50 -21
- optimum/rbln/diffusers/__init__.py +12 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +17 -3
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/controlnet.py +17 -2
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +4 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +20 -45
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/configuration_generic.py +0 -27
- optimum/rbln/transformers/modeling_attention_utils.py +156 -127
- optimum/rbln/transformers/modeling_generic.py +2 -61
- optimum/rbln/transformers/modeling_outputs.py +26 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
- optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
- optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
- optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/import_utils.py +23 -2
- optimum/rbln/utils/runtime_utils.py +42 -6
- optimum/rbln/utils/submodule.py +27 -1
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- optimum/rbln/utils/depreacate_utils.py +0 -16
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
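
The hunks reproduced below cover the two Qwen2-VL files from the list above: `RBLNQwen2VLForConditionalGeneration` is rebuilt on top of a new `RBLNQwen2VLModel` base class, and the vision and language paths now carry `rbln_config.dtype` through embeddings, rotary tables, and attention masks instead of assuming float32. For orientation, here is a minimal compile-and-reload sketch that mirrors the docstring example shipped in the diff; the `rbln_config` values are illustrative, and the user-facing switch for non-fp32 compilation is not part of these hunks.

```python
# Sketch only: mirrors the usage example embedded in the diff below; values are illustrative.
from optimum.rbln import RBLNQwen2VLForConditionalGeneration

# Compile the vision encoder and the language model for RBLN NPUs.
model = RBLNQwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    export=True,
    rbln_config={
        "visual": {"max_seq_lens": 6400, "device": 0},
        "tensor_parallel_size": 8,
        "max_seq_len": 32_768,
        "device": [0, 1, 2, 3, 4, 5, 6, 7],
    },
)
model.save_pretrained("compiled-qwen2-vl-7b-instruct")

# Reload the compiled artifacts later without re-exporting.
model = RBLNQwen2VLForConditionalGeneration.from_pretrained("compiled-qwen2-vl-7b-instruct")
```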
optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111

@@ -27,6 +27,7 @@ from transformers.modeling_utils import no_init_weights
 from transformers.models.qwen2_vl.modeling_qwen2_vl import (
     PatchEmbed,
     Qwen2VisionTransformerPretrainedModel,
+    Qwen2VLConfig,
     Qwen2VLModel,
     Qwen2VLRotaryEmbedding,
     VisionRotaryEmbedding,
@@ -35,7 +36,12 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import (
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from
+from ...modeling_outputs import _validate_output_hidden_states
+from ..decoderonly.modeling_decoderonly import (
+    RBLNDecoderOnlyModel,
+    RBLNDecoderOnlyModelForCausalLM,
+    RBLNDecoderOnlyOutput,
+)
 from .configuration_qwen2_vl import (
     RBLNQwen2VisionTransformerPretrainedModelConfig,
     RBLNQwen2VLForConditionalGenerationConfig,
@@ -56,6 +62,7 @@ if TYPE_CHECKING:

 class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
     auto_model_class = None
+    _supports_non_fp32 = True

     def __post_init__(self, **kwargs):
         self.transformer = self.model[0]
@@ -89,10 +96,10 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

     @classmethod
-    def
+    def _wrap_model_if_needed(
         cls, model: "PreTrainedModel", rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig
     ):
-        return Qwen2VisionTransformerWrapper(model).eval()
+        return Qwen2VisionTransformerWrapper(model, rbln_config).eval()

     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
@@ -112,24 +119,24 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
         model_config: "PretrainedConfig" = None,
         rbln_config: Optional[RBLNQwen2VisionTransformerPretrainedModelConfig] = None,
     ) -> RBLNQwen2VisionTransformerPretrainedModelConfig:
-        hidden_size =
-        num_heads =
+        hidden_size = model_config.embed_dim
+        num_heads = model_config.num_heads
         head_dim = hidden_size // num_heads

         input_infos = []
         for max_seq_len in rbln_config.max_seq_lens:
             input_info = [
-                ("hidden_states", [max_seq_len, hidden_size],
-                ("full_attn_masks", [1, 1, max_seq_len, max_seq_len],
+                ("hidden_states", [max_seq_len, hidden_size], rbln_config.dtype),
+                ("full_attn_masks", [1, 1, max_seq_len, max_seq_len], rbln_config.dtype),
                 (
                     "cos",
                     [1, 1, max_seq_len, head_dim],
-
+                    rbln_config.dtype,
                 ),
                 (
                     "sin",
                     [1, 1, max_seq_len, head_dim],
-
+                    rbln_config.dtype,
                 ),
             ]
             input_infos.append(input_info)
@@ -166,7 +173,7 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
                 1,
                 max_seq_len,
                 max_seq_len,
-                dtype=
+                dtype=hidden_state.dtype,
             )

             full_attn_masks[:, :, hidden_state.shape[0] : max_seq_len, :] = 0
@@ -177,10 +184,10 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
         # Processes a batch of images (or frames) through the vision transformer.
         # Each image is handled independently for padding and attention mask generation.

-        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = self.patch_embed(hidden_states).to(self.rbln_config.dtype)
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-        position_embeddings = (emb.cos(), emb.sin())
+        position_embeddings = (emb.cos().to(self.rbln_config.dtype), emb.sin().to(self.rbln_config.dtype))

         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
             dim=0,
@@ -200,10 +207,10 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
             try:
                 cu_index = torch.searchsorted(self.max_seq_lens, cu_seq_len).item()
                 max_seq_len = self.max_seq_lens[cu_index]
-            except Exception:
+            except Exception as e:
                 raise ValueError(
                     f"Required seq_len({cu_seq_len}) is larger than available max_seq_lens({self.max_seq_lens.tolist()})."
-                )
+                ) from e

             # Padding for Full Attention Layers
             hidden_state_full_padded, cos_full_padded, sin_full_padded, full_attn_masks = (
@@ -230,64 +237,48 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
         return hidden_states


-class
-    """
-    RBLNQwen2VLForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
-    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
-
-    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
-
-    Important Note:
-        This model includes a Large Language Model (LLM). For optimal performance, it is highly recommended to use
-        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
-        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNQwen2VLForConditionalGenerationConfig class for details.
-
-    Examples:
-        ```python
-        from optimum.rbln import RBLNQwen2VLForConditionalGeneration
-
-        model = RBLNQwen2VLForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2-VL-7B-Instruct",
-            export=True,
-            rbln_config={
-                "visual": {
-                    "max_seq_lens": 6400,
-                    "device": 0,
-                },
-                "tensor_parallel_size": 8,
-                "max_seq_len": 32_768,
-                "device": [0, 1, 2, 3, 4, 5, 6, 7],
-            },
-        )
-
-        model.save_pretrained("compiled-qwen2-vl-7b-instruct")
-        ```
-    """
-
+class RBLNQwen2VLModel(RBLNDecoderOnlyModel):
     auto_model_class = AutoModelForVision2Seq
+    _decoder_wrapper_cls = Qwen2VL_LanguageModelWrapper
+    _supports_non_fp32 = True
+    _use_rotary_emb = False
     _rbln_submodules = [
         {"name": "visual"},
     ]
-
-
+    _config_class = Qwen2VLConfig
+    _rotary_emb_class = Qwen2VLRotaryEmbedding
+    _get_rope_index_func = Qwen2VLModel.get_rope_index

     def __post_init__(self, **kwargs):
+        if hasattr(self.config, "embedding_dim"):
+            self.embedding_dim = self.config.embedding_dim
+
+        if not isinstance(self.config.text_config, PretrainedConfig):
+            self.config = self._config_class(
+                text_config=self.config.text_config, vision_config=self.config.vision_config
+            )
+
         super().__post_init__(**kwargs)
         self.visual = self.rbln_submodules[0]
-        self.
-
-
-
-
-
+        self.rotary_emb = self._rotary_emb_class(self.config)
+        if not self.can_generate():
+            self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+
+    @property
+    def logits_last_dim(self):
+        if self.can_generate():
+            return self.config.vocab_size
+        else:
+            return self.embedding_dim if hasattr(self, "embedding_dim") else self.config.hidden_size

-
-
-
-
-
-
-
+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.text_config.vocab_size,
+                self.config.text_config.hidden_size,
+                self.config.text_config.pad_token_id,
+            )
+        return embed_tokens

     @classmethod
     def get_input_info(
@@ -304,52 +295,25 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
             (
                 "position_emb",
                 [2, batch_size, 1, query_length, model_config.hidden_size // model_config.num_attention_heads],
-
+                rbln_config.dtype,
             ),
         )

         return input_info

-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor,
-        generate_idx: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        pixel_values=None,
-        pixel_values_videos=None,
-        image_grid_thw=None,
-        video_grid_thw=None,
-        **kwargs,
-    ):
-        model_inputs = super().prepare_inputs_for_generation(
-            input_ids,
-            generate_idx,
-            attention_mask,
-            inputs_embeds,
-            **kwargs,
-        )
-
-        is_prefill_phase = generate_idx is None
-        if is_prefill_phase:
-            model_inputs.update({"input_ids": input_ids})
-
-        model_inputs.update(
-            {
-                "pixel_values": pixel_values,
-                "pixel_values_videos": pixel_values_videos,
-                "image_grid_thw": image_grid_thw,
-                "video_grid_thw": video_grid_thw,
-            }
-        )
-
-        return model_inputs
-
     def _get_position_embeddings(self, hidden_states, position_ids):
         cos, sin = self.rotary_emb(hidden_states, position_ids)
-        mrope_section = self.mrope_section * 2
-        cos =
-
+        mrope_section = self.config.rope_scaling["mrope_section"] * 2
+        cos = (
+            torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
+            .unsqueeze(1)
+            .to(self.rbln_config.dtype)
+        )
+        sin = (
+            torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
+            .unsqueeze(1)
+            .to(self.rbln_config.dtype)
+        )
         return torch.stack([cos, sin])

     def _preprocess_prefill(
@@ -362,7 +326,7 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         video_grid_thw: torch.LongTensor = None,
     ):
         batch_size = input_ids.shape[0]
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.embed_tokens(input_ids).to(self.rbln_config.dtype)

         if pixel_values is not None:
             image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
@@ -397,7 +361,7 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         max_inputs_len = input_ids.shape[1]

         head_dim = getattr(self.config, "head_dim", None) or self.config.hidden_size // self.config.num_attention_heads
-        all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim)
+        all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim, dtype=self.rbln_config.dtype)
         all_rope_deltas = []

         image_token_id = self.config.image_token_id
@@ -411,8 +375,7 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
             vision_tokens = input_id[0][vision_start_indices + 1]
             image_nums = (vision_tokens == image_token_id).sum()
             video_nums = (vision_tokens == video_token_id).sum()
-            position_ids, rope_deltas =
-                self,
+            position_ids, rope_deltas = self._get_rope_index_func(
                 input_id,
                 image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
                 video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,
@@ -429,6 +392,177 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):

         return inputs_embeds, all_position_embeds, rope_deltas

+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> RBLNDecoderOnlyOutput:
+        inputs_embeds, position_embed, rope_deltas = self._preprocess_prefill(
+            input_ids,
+            attention_mask,
+            pixel_values,
+            pixel_values_videos,
+            image_grid_thw,
+            video_grid_thw,
+        )
+
+        self.rope_deltas = rope_deltas
+        batch_size, seq_len = inputs_embeds.shape[:2]
+
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
+
+        all_hidden_states = (
+            tuple(
+                torch.zeros(
+                    batch_size,
+                    seq_len,
+                    self.config.hidden_size,
+                    dtype=self.rbln_config.dtype,
+                )
+                for _ in range(self.config.num_hidden_layers + 1)
+            )
+            if output_hidden_states
+            else None
+        )
+
+        logits = []
+        for b_idx in range(batch_size):
+            query_length = attention_mask[b_idx].sum(dim=-1).int().item()
+            cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
+
+            outputs = self.prefill_decoder(
+                inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
+                attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                cache_position=cache_position,
+                batch_idx=b_idx,
+                position_embed=position_embed[:, b_idx : b_idx + 1],
+                block_tables=self.block_tables,
+            )
+
+            logits.append(outputs.logits)
+            if self.rbln_config.output_hidden_states:
+                for l_idx in range(self.config.num_hidden_layers + 1):
+                    all_hidden_states[l_idx][b_idx].copy_(outputs.hidden_states[l_idx][0])
+
+        logits = torch.cat(logits, dim=0)
+
+        if not return_dict:
+            return_value = logits if not output_hidden_states else (logits, all_hidden_states)
+            return return_value
+        else:
+            return (
+                RBLNDecoderOnlyOutput(logits=logits, hidden_states=all_hidden_states)
+                if output_hidden_states
+                else RBLNDecoderOnlyOutput(logits=logits)
+            )
+
+
+# MRO: RBLNQwen2VLForConditionalGeneration -> RBLNQwen2VLModel -> RBLNDecoderOnlyModelForCausalLM -> RBLNDecoderOnlyModel -> RBLNModel
+class RBLNQwen2VLForConditionalGeneration(RBLNQwen2VLModel, RBLNDecoderOnlyModelForCausalLM):
+    """
+    RBLNQwen2VLForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model includes a Large Language Model (LLM). For optimal performance, it is highly recommended to use
+        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNQwen2VLForConditionalGenerationConfig class for details.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNQwen2VLForConditionalGeneration
+
+        model = RBLNQwen2VLForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2-VL-7B-Instruct",
+            export=True,
+            rbln_config={
+                "visual": {
+                    "max_seq_lens": 6400,
+                    "device": 0,
+                },
+                "tensor_parallel_size": 8,
+                "max_seq_len": 32_768,
+                "device": [0, 1, 2, 3, 4, 5, 6, 7],
+            },
+        )
+
+        model.save_pretrained("compiled-qwen2-vl-7b-instruct")
+        ```
+    """
+
+    auto_model_class = AutoModelForVision2Seq
+    _decoder_wrapper_cls = Qwen2VL_LanguageModelWrapper
+    _supports_non_fp32 = True
+    _use_rotary_emb = False
+    _rbln_submodules = [
+        {"name": "visual"},
+    ]
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+        self.rope_deltas = torch.zeros(self.rbln_config.batch_size)
+
+    def can_generate(self):
+        return True
+
+    @classmethod
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+        model.model.lm_head = model.lm_head
+        return model
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        generate_idx: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        pixel_values=None,
+        pixel_values_videos=None,
+        image_grid_thw=None,
+        video_grid_thw=None,
+        **kwargs,
+    ):
+        model_inputs = {}
+        is_prefill_phase = generate_idx is None
+
+        if is_prefill_phase:
+            generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+            cache_position = None
+            model_inputs.update({"input_ids": input_ids})
+        else:
+            if inputs_embeds is not None:
+                raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
+
+            input_ids = input_ids[:, -1:]
+            cache_position = generate_idx
+            generate_idx = generate_idx + 1
+            model_inputs.update({"input_ids": input_ids})
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "generate_idx": generate_idx,
+                "pixel_values": pixel_values,
+                "pixel_values_videos": pixel_values_videos,
+                "image_grid_thw": image_grid_thw,
+                "video_grid_thw": video_grid_thw,
+            }
+        )
+
+        return model_inputs
+
     def _preprocess_decoder(
         self,
         input_ids: torch.LongTensor = None,
@@ -439,14 +573,14 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                 f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.rbln_config.batch_size}."
             )

-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.embed_tokens(input_ids).to(self.rbln_config.dtype)
         position_embeds = []
         for b_idx in range(self.rbln_config.batch_size):
             delta = cache_position[b_idx] + self.rope_deltas[b_idx]
             position_ids = torch.arange(1).view(1, -1)
             position_ids = position_ids.add(delta)
             position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-            position_embed = self._get_position_embeddings(torch.zeros(1, dtype=
+            position_embed = self._get_position_embeddings(torch.zeros(1, dtype=self.rbln_config.dtype), position_ids)
             position_embeds.append(position_embed)

         position_embeds = torch.cat(position_embeds, dim=1)
@@ -465,8 +599,10 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         cache_position: Optional[torch.LongTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
     ) -> RBLNDecoderOnlyOutput:
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
         # Prefill
         if cache_position is None:
             inputs_embeds, position_embed, rope_deltas = self._preprocess_prefill(
@@ -478,8 +614,21 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                 video_grid_thw,
             )

+            batch_size, seq_len = inputs_embeds.shape[:2]
+            all_hidden_states = (
+                tuple(
+                    torch.zeros(
+                        batch_size,
+                        seq_len,
+                        self.config.hidden_size,
+                        dtype=self.rbln_config.dtype,
+                    )
+                    for _ in range(self.config.num_hidden_layers + 1)
+                )
+                if output_hidden_states
+                else None
+            )
             self.rope_deltas = rope_deltas
-            batch_size = inputs_embeds.shape[0]

             logits = []
             for b_idx in range(batch_size):
@@ -493,8 +642,10 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                     position_embed=position_embed[:, b_idx : b_idx + 1],
                 )
                 logits.append(output.logits)
+                if self.rbln_config.output_hidden_states:
+                    for l_idx in range(self.config.num_hidden_layers + 1):
+                        all_hidden_states[l_idx][b_idx].copy_(output.hidden_states[l_idx][0])
             logits = torch.cat(logits, dim=0)
-
         # Decoder
         else:
             inputs_embeds, position_embed = self._preprocess_decoder(input_ids, cache_position)
@@ -504,11 +655,17 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                 position_embed=position_embed,
             )
             logits = output.logits
+            all_hidden_states = output.hidden_states

         if not return_dict:
-
+            return_value = (
+                logits,
+                generate_idx if not output_hidden_states else (logits, generate_idx, all_hidden_states),
+            )
+            return return_value
         else:
             return RBLNDecoderOnlyOutput(
                 logits=logits,
                 generate_idx=generate_idx,
+                hidden_states=all_hidden_states,
             )
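
The modeling hunks above push model-specific behavior into class attributes (`_decoder_wrapper_cls`, `_config_class`, `_rotary_emb_class`, `_get_rope_index_func`) so that `RBLNQwen2VLForConditionalGeneration` can reuse the `RBLNQwen2VLModel` forward path and override only what differs. A generic sketch of that pattern, with hypothetical names rather than the library's actual base classes; the wrapper classes that consume these attributes follow in qwen2_vl_architecture.py below.

```python
import torch


class BaseRunner:
    # Subclasses swap behavior by overriding these class attributes.
    _rotary_emb_class = None  # rotary-embedding module class
    _use_rotary_emb = True

    def __init__(self, config=None):
        self.config = config
        self.rotary_emb = self._rotary_emb_class(config) if self._use_rotary_emb else None


class DummyRotary(torch.nn.Module):
    """Stand-in for a real rotary embedding module."""

    def __init__(self, config=None):
        super().__init__()

    def forward(self, x, position_ids):
        return torch.cos(position_ids.float()), torch.sin(position_ids.float())


class Qwen2VLLikeRunner(BaseRunner):
    # Only the pieces that differ from the base are declared here.
    _rotary_emb_class = DummyRotary


runner = Qwen2VLLikeRunner()
cos, sin = runner.rotary_emb(None, torch.arange(4))
```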
optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35

@@ -9,19 +9,24 @@ from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyWrapper,
     apply_rotary_pos_emb,
 )
+from .configuration_qwen2_vl import RBLNQwen2VisionTransformerPretrainedModelConfig


 class Qwen2VisionTransformerWrapper(nn.Module):
-    def __init__(self, model: torch.nn.Module):
+    def __init__(self, model: torch.nn.Module, rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig):
         super().__init__()
-        self._original_mod = model
         self.merger = model.merger
-        self.
+        self.rbln_config = rbln_config
+        self.blocks = self.wrap_vision_blocks(model.blocks, rbln_config)

-    def wrap_vision_blocks(
+    def wrap_vision_blocks(
+        self,
+        blocks: torch.nn.ModuleList,
+        rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig,
+    ):
         wrapped_blocks = []
-        for
-            wrapped_blocks.append(Qwen2VLVisionBlock(block))
+        for _, block in enumerate(blocks):
+            wrapped_blocks.append(Qwen2VLVisionBlock(block, rbln_config))
         return nn.ModuleList(wrapped_blocks)

     def forward(
@@ -31,7 +36,7 @@ class Qwen2VisionTransformerWrapper(nn.Module):
         cos: torch.Tensor,
         sin: torch.Tensor,
     ):
-        full_attn_masks = (1 - full_attn_masks) * torch.finfo(
+        full_attn_masks = (1.0 - full_attn_masks) * torch.finfo(hidden_states.dtype).min

         for block in self.blocks:
             hidden_states = block(hidden_states, full_attn_masks, [cos, sin])
@@ -40,13 +45,13 @@ class Qwen2VisionTransformerWrapper(nn.Module):


 class Qwen2VLVisionBlock(torch.nn.Module):
-    def __init__(self, model: torch.nn.Module):
+    def __init__(self, model: torch.nn.Module, rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig):
         super().__init__()
         self._origin_model = model
+        self.rbln_config = rbln_config
         self.norm1 = model.norm1
         self.norm2 = model.norm2
-
-        self.attn = VisionAttention(model.attn)
+        self.attn = VisionAttention(model.attn, rbln_config)
         self.mlp = model.mlp

     def forward(
@@ -65,13 +70,15 @@ class Qwen2VLVisionBlock(torch.nn.Module):


 class VisionAttention(nn.Module):
-    def __init__(self, model: nn.Module) -> None:
+    def __init__(self, model: nn.Module, rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig) -> None:
         super().__init__()
         self._origin_model = model
+        self.rbln_config = rbln_config
         self.num_heads = model.num_heads
         self.head_dim = getattr(model, "head_dim", model.proj.in_features // model.num_heads)
         self.qkv = model.qkv
         self.proj = model.proj
+        self.scale = torch.tensor(1 / math.sqrt(self.head_dim), dtype=rbln_config.dtype)

     def forward(
         self,
@@ -88,9 +95,9 @@ class VisionAttention(nn.Module):
         cos, sin = position_embeddings
         q, k = apply_rotary_pos_emb(q, k, cos, sin)

-        attn_weights = torch.matmul(q, k.transpose(2, 3))
+        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
         attn_weights = attn_weights + attn_masks
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_output = torch.matmul(attn_weights, v)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(1, seq_length, -1)
@@ -100,6 +107,12 @@ class VisionAttention(nn.Module):


 class Qwen2VL_LanguageModelWrapper(DecoderOnlyWrapper):
+    def get_decoder_layers(self, model: PreTrainedModel):
+        return model.model.language_model.layers if hasattr(model, "model") else model.language_model.layers
+
+    def get_model_layer(self, model: PreTrainedModel):
+        return model.model.language_model if hasattr(model, "model") else model.language_model
+
     def prepare_forward_args(self, *args):
         args = list(args)
         input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
@@ -108,7 +121,7 @@ class Qwen2VL_LanguageModelWrapper(DecoderOnlyWrapper):
         global_block_tables = args.pop(0)
         local_block_tables = None
         position_embeds = args.pop(0)
-        query_position = args.pop(0) if self.phase == "prefill" else None
+        query_position = args.pop(0) if self.phase == "prefill" and self.rbln_config.logits_to_keep > 0 else None
         position_ids = None
         attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
         lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
@@ -142,24 +155,3 @@ class Qwen2VL_LanguageModelWrapper(DecoderOnlyWrapper):
             past_key_values,
             position_embeds,
         )
-
-    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
-        new_layers = []
-
-        for layer_idx, layer in enumerate(model.model.language_model.layers):
-            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
-            new_self_attn = self.get_rbln_attn_class()(
-                self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
-            )
-            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
-            new_layers.append(new_layer)
-
-        new_model = self.get_rbln_model_class()(
-            model.model.language_model,
-            new_layers,
-            self.rbln_config,
-            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
-        )
-
-        new_model = self.get_rbln_causal_lm_class()(model.model, new_model)
-        return new_model
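
A recurring detail in both files is making constants dtype-aware so the compiled graphs work in half precision: the additive attention mask is now derived from the activation dtype (`torch.finfo(hidden_states.dtype).min`) and the attention scale is stored as a tensor in `rbln_config.dtype`. A standalone illustration of the masking idiom in plain PyTorch (not RBLN-specific):

```python
import torch


def additive_attention_mask(keep_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Convert a 1/0 keep-mask into an additive attention bias in the working dtype."""
    # Kept positions contribute 0.0; masked positions get the most negative
    # representable value for the dtype (-65504 for float16), as in the hunks above.
    return (1.0 - keep_mask.to(dtype)) * torch.finfo(dtype).min


mask = torch.tensor([[1, 1, 0, 0]])
bias = additive_attention_mask(mask, torch.float16)
scores = torch.zeros(1, 4, dtype=torch.float16) + bias
probs = torch.softmax(scores.float(), dim=-1)  # masked positions end up with ~0 probability
```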