optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the package versions exactly as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
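The file-level diff reproduced below is for `optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py` (entry 109, +278/-130). It splits the former `RBLNQwen2_5_VLForConditionalGeneration` into a reusable `RBLNQwen2_5_VLModel` base class plus a generation subclass, and it replaces the hard-coded `"float32"` compile dtypes with `rbln_config.dtype` (the class now sets `_supports_non_fp32 = True`). As a minimal, hedged sketch of what a reduced-precision export might look like: the `"dtype"` key and its value format below are assumptions inferred from the diff, not a documented API; the rest mirrors the class docstring example.

```python
# Hedged sketch only: the "dtype" key and its string value are assumptions inferred from
# rbln_config.dtype / _supports_non_fp32 = True in the diff below, not a documented API.
from optimum.rbln import RBLNQwen2_5_VLForConditionalGeneration

model = RBLNQwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",
    export=True,
    rbln_config={
        "visual": {"max_seq_lens": 6400, "device": 0},
        "tensor_parallel_size": 8,
        "kvcache_partition_len": 16_384,
        "max_seq_len": 114_688,
        "device": [0, 1, 2, 3, 4, 5, 6, 7],
        "dtype": "float16",  # assumption: compile in half precision instead of the old fixed float32
    },
)
model.save_pretrained("compiled-qwen2.5-vl-7b-instruct")
```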
@@ -17,7 +17,13 @@ from pathlib import Path
  from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union

  import torch
- from transformers import AutoModelForVision2Seq, PretrainedConfig, PreTrainedModel, Qwen2_5_VLForConditionalGeneration
+ from transformers import (
+ AutoModelForVision2Seq,
+ PretrainedConfig,
+ PreTrainedModel,
+ Qwen2_5_VLConfig,
+ Qwen2_5_VLForConditionalGeneration,
+ )
  from transformers.modeling_utils import no_init_weights
  from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
  Qwen2_5_VisionPatchEmbed,
@@ -30,8 +36,8 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
  from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
  from ....utils.logging import get_logger
- from ...modeling_outputs import RBLNDecoderOnlyOutput
- from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from ...modeling_outputs import RBLNDecoderOnlyOutput, _validate_output_hidden_states
+ from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
  from .configuration_qwen2_5_vl import (
  RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
  RBLNQwen2_5_VLForConditionalGenerationConfig,
@@ -42,7 +48,7 @@ from .qwen2_5_vl_architecture import Qwen2_5_VisionTransformerWrapper, Qwen2_5_V
  logger = get_logger(__name__)

  if TYPE_CHECKING:
- from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer


  class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
@@ -55,6 +61,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
  """

  auto_model_class = None
+ _supports_non_fp32 = True

  def __post_init__(self, **kwargs):
  self.transformer = self.model[0]
@@ -88,10 +95,10 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
  torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

  @classmethod
- def wrap_model_if_needed(
+ def _wrap_model_if_needed(
  cls, model: "PreTrainedModel", rbln_config: RBLNQwen2_5_VisionTransformerPretrainedModelConfig
  ):
- return Qwen2_5_VisionTransformerWrapper(model).eval()
+ return Qwen2_5_VisionTransformerWrapper(model, rbln_config).eval()

  def __getattr__(self, __name: str) -> Any:
  def redirect(func):
@@ -111,10 +118,10 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
  model_config: "PretrainedConfig" = None,
  rbln_config: Optional[RBLNQwen2_5_VisionTransformerPretrainedModelConfig] = None,
  ) -> RBLNQwen2_5_VisionTransformerPretrainedModelConfig:
- window_size = getattr(model_config, "window_size")
- patch_size = getattr(model_config, "patch_size")
- hidden_size = getattr(model_config, "hidden_size")
- num_heads = getattr(model_config, "num_heads")
+ window_size = model_config.window_size
+ patch_size = model_config.patch_size
+ hidden_size = model_config.hidden_size
+ num_heads = model_config.num_heads
  head_dim = hidden_size // num_heads
  window_seq_len = (window_size // patch_size) ** 2

@@ -126,22 +133,22 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
  )

  input_info = [
- ("hidden_states", [max_seq_len, hidden_size], "float32"),
- ("full_attn_masks", [1, 1, max_seq_len, max_seq_len], "float32"),
+ ("hidden_states", [max_seq_len, hidden_size], rbln_config.dtype),
+ ("full_attn_masks", [1, 1, max_seq_len, max_seq_len], rbln_config.dtype),
  (
  "window_attn_masks",
  [max_seq_len // window_seq_len, 1, window_seq_len, window_seq_len],
- "float32",
+ rbln_config.dtype,
  ),
  (
  "cos",
  [1, 1, max_seq_len, head_dim],
- "float32",
+ rbln_config.dtype,
  ),
  (
  "sin",
  [1, 1, max_seq_len, head_dim],
- "float32",
+ rbln_config.dtype,
  ),
  ]
  input_infos.append(input_info)
@@ -203,7 +210,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
  1,
  window_seq_len,
  window_seq_len,
- dtype=torch.float32,
+ dtype=hidden_states.dtype,
  )
  for i, valid_len in enumerate(window_valid_lengths):
  if valid_len < window_seq_len:
@@ -242,7 +249,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
  1,
  max_seq_len,
  max_seq_len,
- dtype=torch.float32,
+ dtype=hidden_state_padded.dtype,
  )
  for i, valid_len in enumerate(window_valid_lengths):
  start = i * window_seq_len
@@ -253,7 +260,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
  return hidden_state_full_padded, cos_full_padded, sin_full_padded, full_attn_masks

  def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
- hidden_states = self.patch_embed(hidden_states)
+ hidden_states = self.patch_embed(hidden_states).to(self.rbln_config.dtype)
  rotary_pos_emb = self.rot_pos_emb(grid_thw)
  window_index, cu_window_seqlens = self.get_window_index(grid_thw)
  cu_window_seqlens = torch.tensor(
@@ -270,7 +277,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
  rotary_pos_emb = rotary_pos_emb[window_index, :, :]
  rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
  emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
- position_embeddings = (emb.cos(), emb.sin())
+ position_embeddings = (emb.cos().to(self.rbln_config.dtype), emb.sin().to(self.rbln_config.dtype))

  cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
  dim=0,
@@ -294,10 +301,10 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
  try:
  ws_index = torch.searchsorted(self.max_seq_lens, window_padded_len).item()
  max_seq_len = self.max_seq_lens[ws_index]
- except Exception:
+ except Exception as e:
  raise ValueError(
  f"Required seq_len({window_padded_len}) is larger than available max_seq_lens({self.max_seq_lens.tolist()})."
- )
+ ) from e

  # Padding for Window Attention Layers
  hidden_state_padded, cos_padded, sin_padded, window_attn_masks, window_valid_lengths = (
@@ -338,67 +345,47 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
  return hidden_states


- class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
- """
- RBLNQwen2_5_VLForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
- optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
-
- This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
-
- Important Note:
- This model includes a Large Language Model (LLM). For optimal performance, it is highly recommended to use
- tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
- `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNQwen2_5_VLForConditionalGenerationConfig class for details.
-
- Examples:
- ```python
- from optimum.rbln import RBLNQwen2_5_VLForConditionalGeneration
-
- model = RBLNQwen2_5_VLForConditionalGeneration.from_pretrained(
- "Qwen/Qwen2.5-VL-7B-Instruct",
- export=True,
- rbln_config={
- "visual": {
- "max_seq_lens": 6400,
- "device": 0,
- },
- "tensor_parallel_size": 8,
- "kvcache_partition_len": 16_384,
- "max_seq_len": 114_688,
- "device": [0, 1, 2, 3, 4, 5, 6, 7],
- },
- )
-
- model.save_pretrained("compiled-qwen2.5-vl-7b-instruct")
- ```
- """
-
- _supports_non_fp32 = False
-
+ class RBLNQwen2_5_VLModel(RBLNDecoderOnlyModel):
  auto_model_class = AutoModelForVision2Seq
+ _decoder_wrapper_cls = Qwen2_5_VL_LanguageModelWrapper
+ _use_rotary_emb = False
  _rbln_submodules = [
  {"name": "visual"},
  ]
- _decoder_wrapper_cls = Qwen2_5_VL_LanguageModelWrapper
- _use_rotary_emb = False
+ _config_class = Qwen2_5_VLConfig
+ _rotary_emb_class = Qwen2_5_VLRotaryEmbedding
+ _get_rope_index_func = Qwen2_5_VLModel.get_rope_index

  def __post_init__(self, **kwargs):
+ if hasattr(self.config, "embedding_dim"):
+ self.embedding_dim = self.config.embedding_dim
+
+ if not isinstance(self.config.text_config, PretrainedConfig):
+ self.config = self._config_class(
+ text_config=self.config.text_config, vision_config=self.config.vision_config
+ )
+
  super().__post_init__(**kwargs)
  self.visual = self.rbln_submodules[0]
- self.mrope_section = self.config.rope_scaling["mrope_section"]
- self.rotary_emb = Qwen2_5_VLRotaryEmbedding(self.config)
- self.rope_deltas = torch.zeros(self.rbln_config.batch_size)
-
- def can_generate(self):
- return True
+ self.rotary_emb = self._rotary_emb_class(self.config)
+ if not self.can_generate():
+ self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+
+ @property
+ def logits_last_dim(self):
+ if self.can_generate():
+ return self.config.vocab_size
+ else:
+ return self.embedding_dim if hasattr(self, "embedding_dim") else self.config.hidden_size

- @classmethod
- def get_pytorch_model(cls, *args, **kwargs):
- model = super().get_pytorch_model(*args, **kwargs)
- model.model.lm_head = model.lm_head
- model.lm_head = None
- del model.lm_head
- return model
+ def _create_embedding_layer(self):
+ with no_init_weights():
+ embed_tokens = torch.nn.Embedding(
+ self.config.text_config.vocab_size,
+ self.config.text_config.hidden_size,
+ self.config.text_config.pad_token_id,
+ )
+ return embed_tokens

  @classmethod
  def get_input_info(
@@ -415,61 +402,25 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
  (
  "position_emb",
  [2, batch_size, 1, query_length, model_config.hidden_size // model_config.num_attention_heads],
- "float32",
+ rbln_config.dtype,
  ),
  )

  return input_info

- def prepare_inputs_for_generation(
- self,
- input_ids: torch.LongTensor,
- generate_idx: Optional[torch.Tensor] = None,
- attention_mask: Optional[torch.LongTensor] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- pixel_values=None,
- pixel_values_videos=None,
- image_grid_thw=None,
- video_grid_thw=None,
- second_per_grid_ts=None,
- **kwargs,
- ):
- model_inputs = {}
- is_prefill_phase = generate_idx is None
-
- if is_prefill_phase:
- generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
- cache_position = None
- model_inputs.update({"input_ids": input_ids})
- else:
- if inputs_embeds is not None:
- raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
-
- input_ids = input_ids[:, -1:]
- cache_position = generate_idx
- generate_idx = generate_idx + 1
- model_inputs.update({"input_ids": input_ids})
-
- model_inputs.update(
- {
- "attention_mask": attention_mask,
- "cache_position": cache_position,
- "generate_idx": generate_idx,
- "pixel_values": pixel_values,
- "pixel_values_videos": pixel_values_videos,
- "image_grid_thw": image_grid_thw,
- "video_grid_thw": video_grid_thw,
- "second_per_grid_ts": second_per_grid_ts,
- }
- )
-
- return model_inputs
-
  def _get_position_embeddings(self, hidden_states, position_ids):
  cos, sin = self.rotary_emb(hidden_states, position_ids)
- mrope_section = self.mrope_section * 2
- cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
- sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
+ mrope_section = self.config.rope_scaling["mrope_section"] * 2
+ cos = (
+ torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
+ .unsqueeze(1)
+ .to(self.rbln_config.dtype)
+ )
+ sin = (
+ torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
+ .unsqueeze(1)
+ .to(self.rbln_config.dtype)
+ )
  return torch.stack([cos, sin])

  def _preprocess_prefill(
@@ -483,7 +434,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
  second_per_grid_ts: torch.Tensor = None,
  ):
  batch_size = input_ids.shape[0]
- inputs_embeds = self.embed_tokens(input_ids)
+ inputs_embeds = self.embed_tokens(input_ids).to(self.rbln_config.dtype)

  if pixel_values is not None:
  image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
@@ -518,7 +469,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
  max_inputs_len = input_ids.shape[1]

  head_dim = getattr(self.config, "head_dim", None) or self.config.hidden_size // self.config.num_attention_heads
- all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim)
+ all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim, dtype=self.rbln_config.dtype)
  all_rope_deltas = []

  image_token_id = self.config.image_token_id
@@ -532,8 +483,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
  vision_tokens = input_id[0][vision_start_indices + 1]
  image_nums = (vision_tokens == image_token_id).sum()
  video_nums = (vision_tokens == video_token_id).sum()
- position_ids, rope_deltas = Qwen2_5_VLModel.get_rope_index(
- self,
+ position_ids, rope_deltas = self._get_rope_index_func(
  input_id,
  image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
  video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,
@@ -551,6 +501,180 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):

  return inputs_embeds, all_position_embeds, rope_deltas

+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ pixel_values: Optional[torch.Tensor] = None,
+ pixel_values_videos: Optional[torch.FloatTensor] = None,
+ image_grid_thw: Optional[torch.LongTensor] = None,
+ video_grid_thw: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ second_per_grid_ts: Optional[torch.Tensor] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **kwargs,
+ ) -> RBLNDecoderOnlyOutput:
+ inputs_embeds, position_embed, rope_deltas = self._preprocess_prefill(
+ input_ids,
+ attention_mask,
+ pixel_values,
+ pixel_values_videos,
+ image_grid_thw,
+ video_grid_thw,
+ second_per_grid_ts,
+ )
+
+ self.rope_deltas = rope_deltas
+ batch_size, seq_len = inputs_embeds.shape[:2]
+
+ output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
+
+ all_hidden_states = (
+ tuple(
+ torch.zeros(
+ batch_size,
+ seq_len,
+ self.config.hidden_size,
+ dtype=self.rbln_config.dtype,
+ )
+ for _ in range(self.config.num_hidden_layers + 1)
+ )
+ if output_hidden_states
+ else None
+ )
+
+ logits = []
+ for b_idx in range(batch_size):
+ query_length = attention_mask[b_idx].sum(dim=-1).int().item()
+ cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
+
+ output = self.prefill_decoder(
+ inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
+ attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+ cache_position=cache_position,
+ batch_idx=b_idx,
+ position_embed=position_embed[:, b_idx : b_idx + 1],
+ block_tables=self.block_tables,
+ )
+ logits.append(output.logits)
+ if self.rbln_config.output_hidden_states:
+ for l_idx in range(self.config.num_hidden_layers + 1):
+ all_hidden_states[l_idx][b_idx].copy_(output.hidden_states[l_idx][0])
+ logits = torch.cat(logits, dim=0)
+
+ if not return_dict:
+ return_value = logits if not output_hidden_states else (logits, all_hidden_states)
+ return return_value
+ else:
+ return (
+ RBLNDecoderOnlyOutput(logits=logits, hidden_states=all_hidden_states)
+ if output_hidden_states
+ else RBLNDecoderOnlyOutput(logits=logits)
+ )
+
+
+ # MRO: RBLNQwen2_5_VLForConditionalGeneration -> RBLNQwen2_5_VLModel -> RBLNDecoderOnlyModelForCausalLM -> RBLNDecoderOnlyModel -> RBLNModel
+ class RBLNQwen2_5_VLForConditionalGeneration(RBLNQwen2_5_VLModel, RBLNDecoderOnlyModelForCausalLM):
+ """
+ RBLNQwen2_5_VLForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+ optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+ This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+ Important Note:
+ This model includes a Large Language Model (LLM). For optimal performance, it is highly recommended to use
+ tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+ `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNQwen2_5_VLForConditionalGenerationConfig class for details.
+
+ Examples:
+ ```python
+ from optimum.rbln import RBLNQwen2_5_VLForConditionalGeneration
+
+ model = RBLNQwen2_5_VLForConditionalGeneration.from_pretrained(
+ "Qwen/Qwen2.5-VL-7B-Instruct",
+ export=True,
+ rbln_config={
+ "visual": {
+ "max_seq_lens": 6400,
+ "device": 0,
+ },
+ "tensor_parallel_size": 8,
+ "kvcache_partition_len": 16_384,
+ "max_seq_len": 114_688,
+ "device": [0, 1, 2, 3, 4, 5, 6, 7],
+ },
+ )
+
+ model.save_pretrained("compiled-qwen2.5-vl-7b-instruct")
+ ```
+ """
+
+ auto_model_class = AutoModelForVision2Seq
+ _decoder_wrapper_cls = Qwen2_5_VL_LanguageModelWrapper
+ _supports_non_fp32 = True
+ _use_rotary_emb = False
+ _rbln_submodules = [
+ {"name": "visual"},
+ ]
+
+ def __post_init__(self, **kwargs):
+ super().__post_init__(**kwargs)
+ self.rope_deltas = torch.zeros(self.rbln_config.batch_size)
+
+ def can_generate(self):
+ return True
+
+ @classmethod
+ def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+ model.model.lm_head = model.lm_head
+ return model
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids: torch.LongTensor,
+ generate_idx: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ pixel_values=None,
+ pixel_values_videos=None,
+ image_grid_thw=None,
+ video_grid_thw=None,
+ second_per_grid_ts=None,
+ **kwargs,
+ ):
+ model_inputs = {}
+ is_prefill_phase = generate_idx is None
+
+ if is_prefill_phase:
+ generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+ cache_position = None
+ model_inputs.update({"input_ids": input_ids})
+ else:
+ if inputs_embeds is not None:
+ raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
+
+ input_ids = input_ids[:, -1:]
+ cache_position = generate_idx
+ generate_idx = generate_idx + 1
+ model_inputs.update({"input_ids": input_ids})
+
+ model_inputs.update(
+ {
+ "attention_mask": attention_mask,
+ "cache_position": cache_position,
+ "generate_idx": generate_idx,
+ "pixel_values": pixel_values,
+ "pixel_values_videos": pixel_values_videos,
+ "image_grid_thw": image_grid_thw,
+ "video_grid_thw": video_grid_thw,
+ "second_per_grid_ts": second_per_grid_ts,
+ }
+ )
+
+ return model_inputs
+
  def _preprocess_decoder(
  self,
  input_ids: torch.LongTensor = None,
@@ -561,14 +685,14 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
  f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.rbln_config.batch_size}."
  )

- inputs_embeds = self.embed_tokens(input_ids)
+ inputs_embeds = self.embed_tokens(input_ids).to(self.rbln_config.dtype)
  position_embeds = []
  for b_idx in range(self.rbln_config.batch_size):
  delta = cache_position[b_idx] + self.rope_deltas[b_idx]
  position_ids = torch.arange(1).view(1, -1)
  position_ids = position_ids.add(delta)
  position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
- position_embed = self._get_position_embeddings(torch.zeros(1, dtype=torch.float32), position_ids)
+ position_embed = self._get_position_embeddings(torch.zeros(1, dtype=self.rbln_config.dtype), position_ids)
  position_embeds.append(position_embed)

  position_embeds = torch.cat(position_embeds, dim=1)
@@ -588,8 +712,10 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
  second_per_grid_ts: Optional[torch.Tensor] = None,
  generate_idx: Optional[torch.Tensor] = None,
  return_dict: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
  **kwargs,
  ) -> RBLNDecoderOnlyOutput:
+ output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
  # Prefill
  if cache_position is None:
  inputs_embeds, position_embed, rope_deltas = self._preprocess_prefill(
@@ -602,8 +728,21 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
  second_per_grid_ts,
  )

+ batch_size, seq_len = inputs_embeds.shape[:2]
+ all_hidden_states = (
+ tuple(
+ torch.zeros(
+ batch_size,
+ seq_len,
+ self.config.hidden_size,
+ dtype=self.rbln_config.dtype,
+ )
+ for _ in range(self.config.num_hidden_layers + 1)
+ )
+ if output_hidden_states
+ else None
+ )
  self.rope_deltas = rope_deltas
- batch_size = inputs_embeds.shape[0]

  logits = []
  for b_idx in range(batch_size):
@@ -617,8 +756,11 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
  position_embed=position_embed[:, b_idx : b_idx + 1],
  )
  logits.append(output.logits)
+ if self.rbln_config.output_hidden_states:
+ for l_idx in range(self.config.num_hidden_layers + 1):
+ all_hidden_states[l_idx][b_idx].copy_(output.hidden_states[l_idx][0])
  logits = torch.cat(logits, dim=0)
- # Decoder
+ # Decoder
  else:
  inputs_embeds, position_embed = self._preprocess_decoder(input_ids, cache_position)
  output = self.decoder(
@@ -627,11 +769,17 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
  position_embed=position_embed,
  )
  logits = output.logits
+ all_hidden_states = output.hidden_states

  if not return_dict:
- return logits, generate_idx
+ return_value = (
+ logits,
+ generate_idx if not output_hidden_states else (logits, generate_idx, all_hidden_states),
+ )
+ return return_value
  else:
  return RBLNDecoderOnlyOutput(
  logits=logits,
  generate_idx=generate_idx,
+ hidden_states=all_hidden_states,
  )
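For reference, a hedged usage sketch of the new hidden-states path added above: `forward` now accepts `output_hidden_states`, checks it against the compiled configuration via `_validate_output_hidden_states`, and returns the per-layer states on `RBLNDecoderOnlyOutput.hidden_states`. The sketch assumes the model was compiled with hidden-state output enabled in its `rbln_config` and saved to the directory used in the docstring example; `forward` is called directly because that is the method shown in the diff.

```python
# Hedged sketch; assumes a model compiled with output_hidden_states enabled in rbln_config.
import torch
from transformers import AutoProcessor
from optimum.rbln import RBLNQwen2_5_VLForConditionalGeneration

model = RBLNQwen2_5_VLForConditionalGeneration.from_pretrained("compiled-qwen2.5-vl-7b-instruct")
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

inputs = processor(text=["Describe the image."], return_tensors="pt", padding=True)

with torch.no_grad():
    out = model.forward(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        output_hidden_states=True,  # validated against rbln_config by _validate_output_hidden_states
        return_dict=True,
    )

print(out.logits.shape)
# One tensor per transformer layer plus the embedding output: num_hidden_layers + 1 entries.
print(len(out.hidden_states))
```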