optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -27,6 +27,7 @@ from transformers.modeling_utils import no_init_weights
 from transformers.models.qwen2_vl.modeling_qwen2_vl import (
     PatchEmbed,
     Qwen2VisionTransformerPretrainedModel,
+    Qwen2VLConfig,
     Qwen2VLModel,
     Qwen2VLRotaryEmbedding,
     VisionRotaryEmbedding,
@@ -35,7 +36,12 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import (
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput
+from ...modeling_outputs import _validate_output_hidden_states
+from ..decoderonly.modeling_decoderonly import (
+    RBLNDecoderOnlyModel,
+    RBLNDecoderOnlyModelForCausalLM,
+    RBLNDecoderOnlyOutput,
+)
 from .configuration_qwen2_vl import (
     RBLNQwen2VisionTransformerPretrainedModelConfig,
     RBLNQwen2VLForConditionalGenerationConfig,
@@ -56,6 +62,7 @@ if TYPE_CHECKING:

 class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
     auto_model_class = None
+    _supports_non_fp32 = True

     def __post_init__(self, **kwargs):
         self.transformer = self.model[0]
@@ -89,10 +96,10 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

     @classmethod
-    def wrap_model_if_needed(
+    def _wrap_model_if_needed(
         cls, model: "PreTrainedModel", rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig
     ):
-        return Qwen2VisionTransformerWrapper(model).eval()
+        return Qwen2VisionTransformerWrapper(model, rbln_config).eval()

     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
@@ -112,24 +119,24 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
         model_config: "PretrainedConfig" = None,
         rbln_config: Optional[RBLNQwen2VisionTransformerPretrainedModelConfig] = None,
     ) -> RBLNQwen2VisionTransformerPretrainedModelConfig:
-        hidden_size = getattr(model_config, "embed_dim")
-        num_heads = getattr(model_config, "num_heads")
+        hidden_size = model_config.embed_dim
+        num_heads = model_config.num_heads
         head_dim = hidden_size // num_heads

         input_infos = []
         for max_seq_len in rbln_config.max_seq_lens:
             input_info = [
-                ("hidden_states", [max_seq_len, hidden_size], "float32"),
-                ("full_attn_masks", [1, 1, max_seq_len, max_seq_len], "float32"),
+                ("hidden_states", [max_seq_len, hidden_size], rbln_config.dtype),
+                ("full_attn_masks", [1, 1, max_seq_len, max_seq_len], rbln_config.dtype),
                 (
                     "cos",
                     [1, 1, max_seq_len, head_dim],
-                    "float32",
+                    rbln_config.dtype,
                 ),
                 (
                     "sin",
                     [1, 1, max_seq_len, head_dim],
-                    "float32",
+                    rbln_config.dtype,
                 ),
             ]
             input_infos.append(input_info)
@@ -166,7 +173,7 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
             1,
             max_seq_len,
             max_seq_len,
-            dtype=torch.float32,
+            dtype=hidden_state.dtype,
         )

         full_attn_masks[:, :, hidden_state.shape[0] : max_seq_len, :] = 0
@@ -177,10 +184,10 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
         # Processes a batch of images (or frames) through the vision transformer.
         # Each image is handled independently for padding and attention mask generation.

-        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = self.patch_embed(hidden_states).to(self.rbln_config.dtype)
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-        position_embeddings = (emb.cos(), emb.sin())
+        position_embeddings = (emb.cos().to(self.rbln_config.dtype), emb.sin().to(self.rbln_config.dtype))

         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
             dim=0,
@@ -200,10 +207,10 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
             try:
                 cu_index = torch.searchsorted(self.max_seq_lens, cu_seq_len).item()
                 max_seq_len = self.max_seq_lens[cu_index]
-            except Exception:
+            except Exception as e:
                 raise ValueError(
                     f"Required seq_len({cu_seq_len}) is larger than available max_seq_lens({self.max_seq_lens.tolist()})."
-                )
+                ) from e

             # Padding for Full Attention Layers
             hidden_state_full_padded, cos_full_padded, sin_full_padded, full_attn_masks = (
@@ -230,64 +237,48 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
         return hidden_states


-class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
-    """
-    RBLNQwen2VLForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
-    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
-
-    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
-
-    Important Note:
-        This model includes a Large Language Model (LLM). For optimal performance, it is highly recommended to use
-        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
-        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNQwen2VLForConditionalGenerationConfig class for details.
-
-    Examples:
-        ```python
-        from optimum.rbln import RBLNQwen2VLForConditionalGeneration
-
-        model = RBLNQwen2VLForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2-VL-7B-Instruct",
-            export=True,
-            rbln_config={
-                "visual": {
-                    "max_seq_lens": 6400,
-                    "device": 0,
-                },
-                "tensor_parallel_size": 8,
-                "max_seq_len": 32_768,
-                "device": [0, 1, 2, 3, 4, 5, 6, 7],
-            },
-        )
-
-        model.save_pretrained("compiled-qwen2-vl-7b-instruct")
-        ```
-    """
-
+class RBLNQwen2VLModel(RBLNDecoderOnlyModel):
     auto_model_class = AutoModelForVision2Seq
+    _decoder_wrapper_cls = Qwen2VL_LanguageModelWrapper
+    _supports_non_fp32 = True
+    _use_rotary_emb = False
     _rbln_submodules = [
         {"name": "visual"},
     ]
-    _decoder_wrapper_cls = Qwen2VL_LanguageModelWrapper
-    _use_rotary_emb = False
+    _config_class = Qwen2VLConfig
+    _rotary_emb_class = Qwen2VLRotaryEmbedding
+    _get_rope_index_func = Qwen2VLModel.get_rope_index

     def __post_init__(self, **kwargs):
+        if hasattr(self.config, "embedding_dim"):
+            self.embedding_dim = self.config.embedding_dim
+
+        if not isinstance(self.config.text_config, PretrainedConfig):
+            self.config = self._config_class(
+                text_config=self.config.text_config, vision_config=self.config.vision_config
+            )
+
         super().__post_init__(**kwargs)
         self.visual = self.rbln_submodules[0]
-        self.mrope_section = self.config.rope_scaling["mrope_section"]
-        self.rotary_emb = Qwen2VLRotaryEmbedding(self.config)
-        self.rope_deltas = torch.zeros(self.rbln_config.batch_size)
-
-    def can_generate(self):
-        return True
+        self.rotary_emb = self._rotary_emb_class(self.config)
+        if not self.can_generate():
+            self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+
+    @property
+    def logits_last_dim(self):
+        if self.can_generate():
+            return self.config.vocab_size
+        else:
+            return self.embedding_dim if hasattr(self, "embedding_dim") else self.config.hidden_size

-    @classmethod
-    def get_pytorch_model(cls, *args, **kwargs):
-        model = super().get_pytorch_model(*args, **kwargs)
-        model.model.lm_head = model.lm_head
-        model.lm_head = None
-        del model.lm_head
-        return model
+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.text_config.vocab_size,
+                self.config.text_config.hidden_size,
+                self.config.text_config.pad_token_id,
+            )
+        return embed_tokens

     @classmethod
     def get_input_info(
@@ -304,52 +295,25 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
            (
                "position_emb",
                [2, batch_size, 1, query_length, model_config.hidden_size // model_config.num_attention_heads],
-                "float32",
+                rbln_config.dtype,
            ),
        )

        return input_info

-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor,
-        generate_idx: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        pixel_values=None,
-        pixel_values_videos=None,
-        image_grid_thw=None,
-        video_grid_thw=None,
-        **kwargs,
-    ):
-        model_inputs = super().prepare_inputs_for_generation(
-            input_ids,
-            generate_idx,
-            attention_mask,
-            inputs_embeds,
-            **kwargs,
-        )
-
-        is_prefill_phase = generate_idx is None
-        if is_prefill_phase:
-            model_inputs.update({"input_ids": input_ids})
-
-        model_inputs.update(
-            {
-                "pixel_values": pixel_values,
-                "pixel_values_videos": pixel_values_videos,
-                "image_grid_thw": image_grid_thw,
-                "video_grid_thw": video_grid_thw,
-            }
-        )
-
-        return model_inputs
-
    def _get_position_embeddings(self, hidden_states, position_ids):
        cos, sin = self.rotary_emb(hidden_states, position_ids)
-        mrope_section = self.mrope_section * 2
-        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
-        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
+        mrope_section = self.config.rope_scaling["mrope_section"] * 2
+        cos = (
+            torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
+            .unsqueeze(1)
+            .to(self.rbln_config.dtype)
+        )
+        sin = (
+            torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
+            .unsqueeze(1)
+            .to(self.rbln_config.dtype)
+        )
        return torch.stack([cos, sin])

    def _preprocess_prefill(
@@ -362,7 +326,7 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
        video_grid_thw: torch.LongTensor = None,
    ):
        batch_size = input_ids.shape[0]
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.embed_tokens(input_ids).to(self.rbln_config.dtype)

        if pixel_values is not None:
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
@@ -397,7 +361,7 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
        max_inputs_len = input_ids.shape[1]

        head_dim = getattr(self.config, "head_dim", None) or self.config.hidden_size // self.config.num_attention_heads
-        all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim)
+        all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim, dtype=self.rbln_config.dtype)
        all_rope_deltas = []

        image_token_id = self.config.image_token_id
@@ -411,8 +375,7 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
            vision_tokens = input_id[0][vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()
-            position_ids, rope_deltas = Qwen2VLModel.get_rope_index(
-                self,
+            position_ids, rope_deltas = self._get_rope_index_func(
                input_id,
                image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
                video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,
@@ -429,6 +392,177 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):

        return inputs_embeds, all_position_embeds, rope_deltas

+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> RBLNDecoderOnlyOutput:
+        inputs_embeds, position_embed, rope_deltas = self._preprocess_prefill(
+            input_ids,
+            attention_mask,
+            pixel_values,
+            pixel_values_videos,
+            image_grid_thw,
+            video_grid_thw,
+        )
+
+        self.rope_deltas = rope_deltas
+        batch_size, seq_len = inputs_embeds.shape[:2]
+
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
+
+        all_hidden_states = (
+            tuple(
+                torch.zeros(
+                    batch_size,
+                    seq_len,
+                    self.config.hidden_size,
+                    dtype=self.rbln_config.dtype,
+                )
+                for _ in range(self.config.num_hidden_layers + 1)
+            )
+            if output_hidden_states
+            else None
+        )
+
+        logits = []
+        for b_idx in range(batch_size):
+            query_length = attention_mask[b_idx].sum(dim=-1).int().item()
+            cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
+
+            outputs = self.prefill_decoder(
+                inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
+                attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                cache_position=cache_position,
+                batch_idx=b_idx,
+                position_embed=position_embed[:, b_idx : b_idx + 1],
+                block_tables=self.block_tables,
+            )
+
+            logits.append(outputs.logits)
+            if self.rbln_config.output_hidden_states:
+                for l_idx in range(self.config.num_hidden_layers + 1):
+                    all_hidden_states[l_idx][b_idx].copy_(outputs.hidden_states[l_idx][0])
+
+        logits = torch.cat(logits, dim=0)
+
+        if not return_dict:
+            return_value = logits if not output_hidden_states else (logits, all_hidden_states)
+            return return_value
+        else:
+            return (
+                RBLNDecoderOnlyOutput(logits=logits, hidden_states=all_hidden_states)
+                if output_hidden_states
+                else RBLNDecoderOnlyOutput(logits=logits)
+            )
+
+
+# MRO: RBLNQwen2VLForConditionalGeneration -> RBLNQwen2VLModel -> RBLNDecoderOnlyModelForCausalLM -> RBLNDecoderOnlyModel -> RBLNModel
+class RBLNQwen2VLForConditionalGeneration(RBLNQwen2VLModel, RBLNDecoderOnlyModelForCausalLM):
+    """
+    RBLNQwen2VLForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model includes a Large Language Model (LLM). For optimal performance, it is highly recommended to use
+        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNQwen2VLForConditionalGenerationConfig class for details.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNQwen2VLForConditionalGeneration
+
+        model = RBLNQwen2VLForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2-VL-7B-Instruct",
+            export=True,
+            rbln_config={
+                "visual": {
+                    "max_seq_lens": 6400,
+                    "device": 0,
+                },
+                "tensor_parallel_size": 8,
+                "max_seq_len": 32_768,
+                "device": [0, 1, 2, 3, 4, 5, 6, 7],
+            },
+        )
+
+        model.save_pretrained("compiled-qwen2-vl-7b-instruct")
+        ```
+    """
+
+    auto_model_class = AutoModelForVision2Seq
+    _decoder_wrapper_cls = Qwen2VL_LanguageModelWrapper
+    _supports_non_fp32 = True
+    _use_rotary_emb = False
+    _rbln_submodules = [
+        {"name": "visual"},
+    ]
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+        self.rope_deltas = torch.zeros(self.rbln_config.batch_size)
+
+    def can_generate(self):
+        return True
+
+    @classmethod
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+        model.model.lm_head = model.lm_head
+        return model
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        generate_idx: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        pixel_values=None,
+        pixel_values_videos=None,
+        image_grid_thw=None,
+        video_grid_thw=None,
+        **kwargs,
+    ):
+        model_inputs = {}
+        is_prefill_phase = generate_idx is None
+
+        if is_prefill_phase:
+            generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+            cache_position = None
+            model_inputs.update({"input_ids": input_ids})
+        else:
+            if inputs_embeds is not None:
+                raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
+
+            input_ids = input_ids[:, -1:]
+            cache_position = generate_idx
+            generate_idx = generate_idx + 1
+            model_inputs.update({"input_ids": input_ids})
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "generate_idx": generate_idx,
+                "pixel_values": pixel_values,
+                "pixel_values_videos": pixel_values_videos,
+                "image_grid_thw": image_grid_thw,
+                "video_grid_thw": video_grid_thw,
+            }
+        )
+
+        return model_inputs
+
    def _preprocess_decoder(
        self,
        input_ids: torch.LongTensor = None,
@@ -439,14 +573,14 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.rbln_config.batch_size}."
            )

-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.embed_tokens(input_ids).to(self.rbln_config.dtype)
        position_embeds = []
        for b_idx in range(self.rbln_config.batch_size):
            delta = cache_position[b_idx] + self.rope_deltas[b_idx]
            position_ids = torch.arange(1).view(1, -1)
            position_ids = position_ids.add(delta)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-            position_embed = self._get_position_embeddings(torch.zeros(1, dtype=torch.float32), position_ids)
+            position_embed = self._get_position_embeddings(torch.zeros(1, dtype=self.rbln_config.dtype), position_ids)
            position_embeds.append(position_embed)

        position_embeds = torch.cat(position_embeds, dim=1)
@@ -465,8 +599,10 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
        cache_position: Optional[torch.LongTensor] = None,
        generate_idx: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> RBLNDecoderOnlyOutput:
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
        # Prefill
        if cache_position is None:
            inputs_embeds, position_embed, rope_deltas = self._preprocess_prefill(
@@ -478,8 +614,21 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                video_grid_thw,
            )

+            batch_size, seq_len = inputs_embeds.shape[:2]
+            all_hidden_states = (
+                tuple(
+                    torch.zeros(
+                        batch_size,
+                        seq_len,
+                        self.config.hidden_size,
+                        dtype=self.rbln_config.dtype,
+                    )
+                    for _ in range(self.config.num_hidden_layers + 1)
+                )
+                if output_hidden_states
+                else None
+            )
            self.rope_deltas = rope_deltas
-            batch_size = inputs_embeds.shape[0]

            logits = []
            for b_idx in range(batch_size):
@@ -493,8 +642,10 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                    position_embed=position_embed[:, b_idx : b_idx + 1],
                )
                logits.append(output.logits)
+                if self.rbln_config.output_hidden_states:
+                    for l_idx in range(self.config.num_hidden_layers + 1):
+                        all_hidden_states[l_idx][b_idx].copy_(output.hidden_states[l_idx][0])
            logits = torch.cat(logits, dim=0)
-
        # Decoder
        else:
            inputs_embeds, position_embed = self._preprocess_decoder(input_ids, cache_position)
@@ -504,11 +655,17 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                position_embed=position_embed,
            )
            logits = output.logits
+            all_hidden_states = output.hidden_states

        if not return_dict:
-            return logits, generate_idx
+            return_value = (
+                logits,
+                generate_idx if not output_hidden_states else (logits, generate_idx, all_hidden_states),
+            )
+            return return_value
        else:
            return RBLNDecoderOnlyOutput(
                logits=logits,
                generate_idx=generate_idx,
+                hidden_states=all_hidden_states,
            )
optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py
@@ -9,19 +9,24 @@ from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyWrapper,
     apply_rotary_pos_emb,
 )
+from .configuration_qwen2_vl import RBLNQwen2VisionTransformerPretrainedModelConfig


 class Qwen2VisionTransformerWrapper(nn.Module):
-    def __init__(self, model: torch.nn.Module):
+    def __init__(self, model: torch.nn.Module, rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig):
         super().__init__()
-        self._original_mod = model
         self.merger = model.merger
-        self.blocks = self.wrap_vision_blocks(model.blocks)
+        self.rbln_config = rbln_config
+        self.blocks = self.wrap_vision_blocks(model.blocks, rbln_config)

-    def wrap_vision_blocks(self, blocks: torch.nn.ModuleList):
+    def wrap_vision_blocks(
+        self,
+        blocks: torch.nn.ModuleList,
+        rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig,
+    ):
         wrapped_blocks = []
-        for i, block in enumerate(blocks):
-            wrapped_blocks.append(Qwen2VLVisionBlock(block))
+        for _, block in enumerate(blocks):
+            wrapped_blocks.append(Qwen2VLVisionBlock(block, rbln_config))
         return nn.ModuleList(wrapped_blocks)

     def forward(
31
36
  cos: torch.Tensor,
32
37
  sin: torch.Tensor,
33
38
  ):
34
- full_attn_masks = (1 - full_attn_masks) * torch.finfo(torch.float32).min
39
+ full_attn_masks = (1.0 - full_attn_masks) * torch.finfo(hidden_states.dtype).min
35
40
 
36
41
  for block in self.blocks:
37
42
  hidden_states = block(hidden_states, full_attn_masks, [cos, sin])
@@ -40,13 +45,13 @@


 class Qwen2VLVisionBlock(torch.nn.Module):
-    def __init__(self, model: torch.nn.Module):
+    def __init__(self, model: torch.nn.Module, rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig):
         super().__init__()
         self._origin_model = model
+        self.rbln_config = rbln_config
         self.norm1 = model.norm1
         self.norm2 = model.norm2
-
-        self.attn = VisionAttention(model.attn)
+        self.attn = VisionAttention(model.attn, rbln_config)
         self.mlp = model.mlp

     def forward(
@@ -65,13 +70,15 @@


 class VisionAttention(nn.Module):
-    def __init__(self, model: nn.Module) -> None:
+    def __init__(self, model: nn.Module, rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig) -> None:
         super().__init__()
         self._origin_model = model
+        self.rbln_config = rbln_config
         self.num_heads = model.num_heads
         self.head_dim = getattr(model, "head_dim", model.proj.in_features // model.num_heads)
         self.qkv = model.qkv
         self.proj = model.proj
+        self.scale = torch.tensor(1 / math.sqrt(self.head_dim), dtype=rbln_config.dtype)

     def forward(
         self,
@@ -88,9 +95,9 @@ class VisionAttention(nn.Module):
         cos, sin = position_embeddings
         q, k = apply_rotary_pos_emb(q, k, cos, sin)

-        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
         attn_weights = attn_weights + attn_masks
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_output = torch.matmul(attn_weights, v)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(1, seq_length, -1)
@@ -100,6 +107,12 @@


 class Qwen2VL_LanguageModelWrapper(DecoderOnlyWrapper):
+    def get_decoder_layers(self, model: PreTrainedModel):
+        return model.model.language_model.layers if hasattr(model, "model") else model.language_model.layers
+
+    def get_model_layer(self, model: PreTrainedModel):
+        return model.model.language_model if hasattr(model, "model") else model.language_model
+
     def prepare_forward_args(self, *args):
         args = list(args)
         input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
@@ -108,7 +121,7 @@ class Qwen2VL_LanguageModelWrapper(DecoderOnlyWrapper):
         global_block_tables = args.pop(0)
         local_block_tables = None
         position_embeds = args.pop(0)
-        query_position = args.pop(0) if self.phase == "prefill" else None
+        query_position = args.pop(0) if self.phase == "prefill" and self.rbln_config.logits_to_keep > 0 else None
         position_ids = None
         attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
         lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
@@ -142,24 +155,3 @@
             past_key_values,
             position_embeds,
         )
-
-    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
-        new_layers = []
-
-        for layer_idx, layer in enumerate(model.model.language_model.layers):
-            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
-            new_self_attn = self.get_rbln_attn_class()(
-                self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
-            )
-            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
-            new_layers.append(new_layer)
-
-        new_model = self.get_rbln_model_class()(
-            model.model.language_model,
-            new_layers,
-            self.rbln_config,
-            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
-        )
-
-        new_model = self.get_rbln_causal_lm_class()(model.model, new_model)
-        return new_model