optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. optimum/rbln/__init__.py +36 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +35 -16
  4. optimum/rbln/modeling_base.py +6 -6
  5. optimum/rbln/ops/__init__.py +1 -0
  6. optimum/rbln/ops/attn.py +10 -0
  7. optimum/rbln/ops/flash_attn.py +8 -0
  8. optimum/rbln/ops/moe.py +180 -0
  9. optimum/rbln/ops/sliding_window_attn.py +9 -0
  10. optimum/rbln/transformers/__init__.py +36 -0
  11. optimum/rbln/transformers/modeling_attention_utils.py +118 -222
  12. optimum/rbln/transformers/modeling_outputs.py +25 -0
  13. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  14. optimum/rbln/transformers/models/__init__.py +28 -0
  15. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  16. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  17. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  18. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  19. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
  20. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  21. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  22. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
  23. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
  25. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
  27. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  29. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  30. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  31. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  32. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  33. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  34. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
  35. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  36. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  37. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  38. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  39. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  40. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  41. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  42. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  43. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  44. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  45. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  46. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  47. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  48. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  50. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  51. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  53. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  54. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  55. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  56. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  57. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  58. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  59. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  60. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  61. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  62. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  63. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  64. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  65. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  66. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  68. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  69. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  71. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  72. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  73. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  74. optimum/rbln/utils/import_utils.py +16 -1
  75. optimum/rbln/utils/runtime_utils.py +10 -6
  76. optimum/rbln/utils/submodule.py +24 -0
  77. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  78. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
  79. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  80. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
  81. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  82. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
@@ -20,7 +20,6 @@ from transformers import PhiForCausalLM
 from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyAttention,
     DecoderOnlyLayer,
-    DecoderOnlyModel,
     DecoderOnlyWrapper,
     apply_rotary_pos_emb_partial,
 )
@@ -37,9 +36,6 @@ class PhiWrapper(DecoderOnlyWrapper):
     def get_rbln_layer_class(self):
         return PhiLayer

-    def get_rbln_model_class(self):
-        return PhiModel
-
     def get_model_layer(self, model: Union["PhiForCausalLM", "PhiModel"]):
         return model.model if self.is_causal_lm else model

@@ -48,13 +44,15 @@ class PhiWrapper(DecoderOnlyWrapper):


 class PhiAttention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.dense
-        self.qk_layernorm = self._original_mod.qk_layernorm
-        self.rotary_ndims = self._original_mod.rotary_ndims
+    def __post_init__(self, self_attn):
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = self_attn.dense
+        self.qk_layernorm = self_attn.qk_layernorm
+        self.rotary_ndims = self_attn.rotary_ndims
+        self.q_layernorm = getattr(self_attn, "q_layernorm", None)
+        self.k_layernorm = getattr(self_attn, "k_layernorm", None)

     def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         if lora_int_id is not None:
@@ -65,8 +63,8 @@ class PhiAttention(DecoderOnlyAttention):
         value_states = self.v_proj(hidden_states)

         if self.qk_layernorm:
-            query_states = self._original_mod.q_layernorm(query_states)
-            key_states = self._original_mod.k_layernorm(key_states)
+            query_states = self.q_layernorm(query_states)
+            key_states = self.k_layernorm(key_states)

         return query_states, key_states, value_states

@@ -75,8 +73,7 @@ class PhiAttention(DecoderOnlyAttention):


 class PhiLayer(DecoderOnlyLayer):
-    def get_post_attention_layernorm(self):
-        raise NotImplementedError
+    _POST_ATTN_LAYERNORM = None

     def forward(
         self,
@@ -103,13 +100,8 @@ class PhiLayer(DecoderOnlyLayer):
             block_tables=block_tables,
         )

-        feed_forward_hidden_states = self._original_mod.mlp(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)

         hidden_states = attn_output + feed_forward_hidden_states + residual

         return hidden_states
-
-
-class PhiModel(DecoderOnlyModel):
-    def get_last_layernorm(self):
-        return self._original_mod.final_layernorm
@@ -15,5 +15,10 @@
 from .configuration_qwen2_5_vl import (
     RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
     RBLNQwen2_5_VLForConditionalGenerationConfig,
+    RBLNQwen2_5_VLModelConfig,
+)
+from .modeling_qwen2_5_vl import (
+    RBLNQwen2_5_VisionTransformerPretrainedModel,
+    RBLNQwen2_5_VLForConditionalGeneration,
+    RBLNQwen2_5_VLModel,
 )
-from .modeling_qwen2_5_vl import RBLNQwen2_5_VisionTransformerPretrainedModel, RBLNQwen2_5_VLForConditionalGeneration
@@ -15,7 +15,7 @@
 from typing import Any, List, Optional, Union

 from ....configuration_utils import RBLNModelConfig
-from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig


 class RBLNQwen2_5_VLForConditionalGenerationConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -56,6 +56,16 @@ class RBLNQwen2_5_VLForConditionalGenerationConfig(RBLNDecoderOnlyModelForCausal
         self.visual = visual


+class RBLNQwen2_5_VLModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLNQwen2_5_VLModel.
+    """
+
+    def __init__(self, visual: Optional[RBLNModelConfig] = None, **kwargs: Any):
+        super().__init__(**kwargs)
+        self.visual = self.initialize_submodule_config(submodule_config=visual)
+
+
 class RBLNQwen2_5_VisionTransformerPretrainedModelConfig(RBLNModelConfig):
     """
     Configuration class for RBLNQwen2_5_VisionTransformerPretrainedModel.
@@ -17,7 +17,13 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union

 import torch
-from transformers import AutoModelForVision2Seq, PretrainedConfig, PreTrainedModel, Qwen2_5_VLForConditionalGeneration
+from transformers import (
+    AutoModelForVision2Seq,
+    PretrainedConfig,
+    PreTrainedModel,
+    Qwen2_5_VLConfig,
+    Qwen2_5_VLForConditionalGeneration,
+)
 from transformers.modeling_utils import no_init_weights
 from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VisionPatchEmbed,
@@ -30,8 +36,8 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from ...modeling_outputs import RBLNDecoderOnlyOutput
-from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
+from ...modeling_outputs import RBLNDecoderOnlyOutput, _validate_output_hidden_states
+from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
 from .configuration_qwen2_5_vl import (
     RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
     RBLNQwen2_5_VLForConditionalGenerationConfig,
@@ -42,7 +48,7 @@ from .qwen2_5_vl_architecture import Qwen2_5_VisionTransformerWrapper, Qwen2_5_V
 logger = get_logger(__name__)

 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer


 class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
@@ -55,6 +61,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
     """

     auto_model_class = None
+    _supports_non_fp32 = True

     def __post_init__(self, **kwargs):
         self.transformer = self.model[0]
@@ -91,7 +98,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
     def _wrap_model_if_needed(
         cls, model: "PreTrainedModel", rbln_config: RBLNQwen2_5_VisionTransformerPretrainedModelConfig
     ):
-        return Qwen2_5_VisionTransformerWrapper(model).eval()
+        return Qwen2_5_VisionTransformerWrapper(model, rbln_config).eval()

     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
@@ -126,22 +133,22 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
             )

             input_info = [
-                ("hidden_states", [max_seq_len, hidden_size], "float32"),
-                ("full_attn_masks", [1, 1, max_seq_len, max_seq_len], "float32"),
+                ("hidden_states", [max_seq_len, hidden_size], rbln_config.dtype),
+                ("full_attn_masks", [1, 1, max_seq_len, max_seq_len], rbln_config.dtype),
                 (
                     "window_attn_masks",
                     [max_seq_len // window_seq_len, 1, window_seq_len, window_seq_len],
-                    "float32",
+                    rbln_config.dtype,
                 ),
                 (
                     "cos",
                     [1, 1, max_seq_len, head_dim],
-                    "float32",
+                    rbln_config.dtype,
                 ),
                 (
                     "sin",
                     [1, 1, max_seq_len, head_dim],
-                    "float32",
+                    rbln_config.dtype,
                 ),
             ]
             input_infos.append(input_info)
@@ -203,7 +210,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
             1,
             window_seq_len,
             window_seq_len,
-            dtype=torch.float32,
+            dtype=hidden_states.dtype,
         )
         for i, valid_len in enumerate(window_valid_lengths):
             if valid_len < window_seq_len:
@@ -242,7 +249,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
             1,
             max_seq_len,
             max_seq_len,
-            dtype=torch.float32,
+            dtype=hidden_state_padded.dtype,
         )
         for i, valid_len in enumerate(window_valid_lengths):
             start = i * window_seq_len
@@ -253,7 +260,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
         return hidden_state_full_padded, cos_full_padded, sin_full_padded, full_attn_masks

     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = self.patch_embed(hidden_states).to(self.rbln_config.dtype)
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
         window_index, cu_window_seqlens = self.get_window_index(grid_thw)
         cu_window_seqlens = torch.tensor(
@@ -270,7 +277,7 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-        position_embeddings = (emb.cos(), emb.sin())
+        position_embeddings = (emb.cos().to(self.rbln_config.dtype), emb.sin().to(self.rbln_config.dtype))

         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
             dim=0,
@@ -338,66 +345,47 @@ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
         return hidden_states


-class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
-    """
-    RBLNQwen2_5_VLForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
-    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
-
-    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
-
-    Important Note:
-        This model includes a Large Language Model (LLM). For optimal performance, it is highly recommended to use
-        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
-        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNQwen2_5_VLForConditionalGenerationConfig class for details.
-
-    Examples:
-        ```python
-        from optimum.rbln import RBLNQwen2_5_VLForConditionalGeneration
-
-        model = RBLNQwen2_5_VLForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2.5-VL-7B-Instruct",
-            export=True,
-            rbln_config={
-                "visual": {
-                    "max_seq_lens": 6400,
-                    "device": 0,
-                },
-                "tensor_parallel_size": 8,
-                "kvcache_partition_len": 16_384,
-                "max_seq_len": 114_688,
-                "device": [0, 1, 2, 3, 4, 5, 6, 7],
-            },
-        )
-
-        model.save_pretrained("compiled-qwen2.5-vl-7b-instruct")
-        ```
-    """
-
-    _supports_non_fp32 = False
-
+class RBLNQwen2_5_VLModel(RBLNDecoderOnlyModel):
     auto_model_class = AutoModelForVision2Seq
+    _decoder_wrapper_cls = Qwen2_5_VL_LanguageModelWrapper
+    _use_rotary_emb = False
     _rbln_submodules = [
         {"name": "visual"},
     ]
-    _decoder_wrapper_cls = Qwen2_5_VL_LanguageModelWrapper
-    _use_rotary_emb = False
+    _config_class = Qwen2_5_VLConfig
+    _rotary_emb_class = Qwen2_5_VLRotaryEmbedding
+    _get_rope_index_func = Qwen2_5_VLModel.get_rope_index

     def __post_init__(self, **kwargs):
+        if hasattr(self.config, "embedding_dim"):
+            self.embedding_dim = self.config.embedding_dim
+
+        if not isinstance(self.config.text_config, PretrainedConfig):
+            self.config = self._config_class(
+                text_config=self.config.text_config, vision_config=self.config.vision_config
+            )
+
         super().__post_init__(**kwargs)
         self.visual = self.rbln_submodules[0]
-        self.mrope_section = self.config.rope_scaling["mrope_section"]
-        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(self.config)
-        self.rope_deltas = torch.zeros(self.rbln_config.batch_size)
-
-    def can_generate(self):
-        return True
+        self.rotary_emb = self._rotary_emb_class(self.config)
+        if not self.can_generate():
+            self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+
+    @property
+    def logits_last_dim(self):
+        if self.can_generate():
+            return self.config.vocab_size
+        else:
+            return self.embedding_dim if hasattr(self, "embedding_dim") else self.config.hidden_size

-    @classmethod
-    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
-        model.model.lm_head = model.lm_head
-        model.lm_head = None
-        del model.lm_head
-        return model
+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.text_config.vocab_size,
+                self.config.text_config.hidden_size,
+                self.config.text_config.pad_token_id,
+            )
+        return embed_tokens

     @classmethod
     def get_input_info(
@@ -414,61 +402,25 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
             (
                 "position_emb",
                 [2, batch_size, 1, query_length, model_config.hidden_size // model_config.num_attention_heads],
-                "float32",
+                rbln_config.dtype,
             ),
         )

         return input_info

-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor,
-        generate_idx: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        pixel_values=None,
-        pixel_values_videos=None,
-        image_grid_thw=None,
-        video_grid_thw=None,
-        second_per_grid_ts=None,
-        **kwargs,
-    ):
-        model_inputs = {}
-        is_prefill_phase = generate_idx is None
-
-        if is_prefill_phase:
-            generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
-            cache_position = None
-            model_inputs.update({"input_ids": input_ids})
-        else:
-            if inputs_embeds is not None:
-                raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
-
-            input_ids = input_ids[:, -1:]
-            cache_position = generate_idx
-            generate_idx = generate_idx + 1
-            model_inputs.update({"input_ids": input_ids})
-
-        model_inputs.update(
-            {
-                "attention_mask": attention_mask,
-                "cache_position": cache_position,
-                "generate_idx": generate_idx,
-                "pixel_values": pixel_values,
-                "pixel_values_videos": pixel_values_videos,
-                "image_grid_thw": image_grid_thw,
-                "video_grid_thw": video_grid_thw,
-                "second_per_grid_ts": second_per_grid_ts,
-            }
-        )
-
-        return model_inputs
-
     def _get_position_embeddings(self, hidden_states, position_ids):
         cos, sin = self.rotary_emb(hidden_states, position_ids)
-        mrope_section = self.mrope_section * 2
-        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
-        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
+        mrope_section = self.config.rope_scaling["mrope_section"] * 2
+        cos = (
+            torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
+            .unsqueeze(1)
+            .to(self.rbln_config.dtype)
+        )
+        sin = (
+            torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
+            .unsqueeze(1)
+            .to(self.rbln_config.dtype)
+        )
         return torch.stack([cos, sin])

     def _preprocess_prefill(
@@ -482,7 +434,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         second_per_grid_ts: torch.Tensor = None,
     ):
         batch_size = input_ids.shape[0]
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.embed_tokens(input_ids).to(self.rbln_config.dtype)

         if pixel_values is not None:
             image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
@@ -517,7 +469,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         max_inputs_len = input_ids.shape[1]

         head_dim = getattr(self.config, "head_dim", None) or self.config.hidden_size // self.config.num_attention_heads
-        all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim)
+        all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim, dtype=self.rbln_config.dtype)
         all_rope_deltas = []

         image_token_id = self.config.image_token_id
@@ -531,8 +483,7 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
             vision_tokens = input_id[0][vision_start_indices + 1]
             image_nums = (vision_tokens == image_token_id).sum()
             video_nums = (vision_tokens == video_token_id).sum()
-            position_ids, rope_deltas = Qwen2_5_VLModel.get_rope_index(
-                self,
+            position_ids, rope_deltas = self._get_rope_index_func(
                 input_id,
                 image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
                 video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,
@@ -550,6 +501,180 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):

         return inputs_embeds, all_position_embeds, rope_deltas

+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> RBLNDecoderOnlyOutput:
+        inputs_embeds, position_embed, rope_deltas = self._preprocess_prefill(
+            input_ids,
+            attention_mask,
+            pixel_values,
+            pixel_values_videos,
+            image_grid_thw,
+            video_grid_thw,
+            second_per_grid_ts,
+        )
+
+        self.rope_deltas = rope_deltas
+        batch_size, seq_len = inputs_embeds.shape[:2]
+
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
+
+        all_hidden_states = (
+            tuple(
+                torch.zeros(
+                    batch_size,
+                    seq_len,
+                    self.config.hidden_size,
+                    dtype=self.rbln_config.dtype,
+                )
+                for _ in range(self.config.num_hidden_layers + 1)
+            )
+            if output_hidden_states
+            else None
+        )
+
+        logits = []
+        for b_idx in range(batch_size):
+            query_length = attention_mask[b_idx].sum(dim=-1).int().item()
+            cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
+
+            output = self.prefill_decoder(
+                inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
+                attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                cache_position=cache_position,
+                batch_idx=b_idx,
+                position_embed=position_embed[:, b_idx : b_idx + 1],
+                block_tables=self.block_tables,
+            )
+            logits.append(output.logits)
+            if self.rbln_config.output_hidden_states:
+                for l_idx in range(self.config.num_hidden_layers + 1):
+                    all_hidden_states[l_idx][b_idx].copy_(output.hidden_states[l_idx][0])
+        logits = torch.cat(logits, dim=0)
+
+        if not return_dict:
+            return_value = logits if not output_hidden_states else (logits, all_hidden_states)
+            return return_value
+        else:
+            return (
+                RBLNDecoderOnlyOutput(logits=logits, hidden_states=all_hidden_states)
+                if output_hidden_states
+                else RBLNDecoderOnlyOutput(logits=logits)
+            )
+
+
+# MRO: RBLNQwen2_5_VLForConditionalGeneration -> RBLNQwen2_5_VLModel -> RBLNDecoderOnlyModelForCausalLM -> RBLNDecoderOnlyModel -> RBLNModel
+class RBLNQwen2_5_VLForConditionalGeneration(RBLNQwen2_5_VLModel, RBLNDecoderOnlyModelForCausalLM):
+    """
+    RBLNQwen2_5_VLForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model includes a Large Language Model (LLM). For optimal performance, it is highly recommended to use
+        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNQwen2_5_VLForConditionalGenerationConfig class for details.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNQwen2_5_VLForConditionalGeneration
+
+        model = RBLNQwen2_5_VLForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2.5-VL-7B-Instruct",
+            export=True,
+            rbln_config={
+                "visual": {
+                    "max_seq_lens": 6400,
+                    "device": 0,
+                },
+                "tensor_parallel_size": 8,
+                "kvcache_partition_len": 16_384,
+                "max_seq_len": 114_688,
+                "device": [0, 1, 2, 3, 4, 5, 6, 7],
+            },
+        )
+
+        model.save_pretrained("compiled-qwen2.5-vl-7b-instruct")
+        ```
+    """
+
+    auto_model_class = AutoModelForVision2Seq
+    _decoder_wrapper_cls = Qwen2_5_VL_LanguageModelWrapper
+    _supports_non_fp32 = True
+    _use_rotary_emb = False
+    _rbln_submodules = [
+        {"name": "visual"},
+    ]
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+        self.rope_deltas = torch.zeros(self.rbln_config.batch_size)
+
+    def can_generate(self):
+        return True
+
+    @classmethod
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+        model.model.lm_head = model.lm_head
+        return model
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        generate_idx: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        pixel_values=None,
+        pixel_values_videos=None,
+        image_grid_thw=None,
+        video_grid_thw=None,
+        second_per_grid_ts=None,
+        **kwargs,
+    ):
+        model_inputs = {}
+        is_prefill_phase = generate_idx is None
+
+        if is_prefill_phase:
+            generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+            cache_position = None
+            model_inputs.update({"input_ids": input_ids})
+        else:
+            if inputs_embeds is not None:
+                raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
+
+            input_ids = input_ids[:, -1:]
+            cache_position = generate_idx
+            generate_idx = generate_idx + 1
+            model_inputs.update({"input_ids": input_ids})
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "generate_idx": generate_idx,
+                "pixel_values": pixel_values,
+                "pixel_values_videos": pixel_values_videos,
+                "image_grid_thw": image_grid_thw,
+                "video_grid_thw": video_grid_thw,
+                "second_per_grid_ts": second_per_grid_ts,
+            }
+        )
+
+        return model_inputs
+
     def _preprocess_decoder(
         self,
         input_ids: torch.LongTensor = None,
@@ -560,14 +685,14 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                 f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.rbln_config.batch_size}."
             )

-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.embed_tokens(input_ids).to(self.rbln_config.dtype)
         position_embeds = []
         for b_idx in range(self.rbln_config.batch_size):
             delta = cache_position[b_idx] + self.rope_deltas[b_idx]
             position_ids = torch.arange(1).view(1, -1)
             position_ids = position_ids.add(delta)
             position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-            position_embed = self._get_position_embeddings(torch.zeros(1, dtype=torch.float32), position_ids)
+            position_embed = self._get_position_embeddings(torch.zeros(1, dtype=self.rbln_config.dtype), position_ids)
             position_embeds.append(position_embed)

         position_embeds = torch.cat(position_embeds, dim=1)
@@ -587,8 +712,10 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         second_per_grid_ts: Optional[torch.Tensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
     ) -> RBLNDecoderOnlyOutput:
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
         # Prefill
         if cache_position is None:
             inputs_embeds, position_embed, rope_deltas = self._preprocess_prefill(
@@ -601,8 +728,21 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                 second_per_grid_ts,
             )

+            batch_size, seq_len = inputs_embeds.shape[:2]
+            all_hidden_states = (
+                tuple(
+                    torch.zeros(
+                        batch_size,
+                        seq_len,
+                        self.config.hidden_size,
+                        dtype=self.rbln_config.dtype,
+                    )
+                    for _ in range(self.config.num_hidden_layers + 1)
+                )
+                if output_hidden_states
+                else None
+            )
             self.rope_deltas = rope_deltas
-            batch_size = inputs_embeds.shape[0]

             logits = []
             for b_idx in range(batch_size):
@@ -616,8 +756,11 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                     position_embed=position_embed[:, b_idx : b_idx + 1],
                 )
                 logits.append(output.logits)
+                if self.rbln_config.output_hidden_states:
+                    for l_idx in range(self.config.num_hidden_layers + 1):
+                        all_hidden_states[l_idx][b_idx].copy_(output.hidden_states[l_idx][0])
             logits = torch.cat(logits, dim=0)
-            # Decoder
+        # Decoder
         else:
             inputs_embeds, position_embed = self._preprocess_decoder(input_ids, cache_position)
             output = self.decoder(
@@ -626,11 +769,17 @@ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                 position_embed=position_embed,
             )
             logits = output.logits
+            all_hidden_states = output.hidden_states

         if not return_dict:
-            return logits, generate_idx
+            return_value = (
+                logits,
+                generate_idx if not output_hidden_states else (logits, generate_idx, all_hidden_states),
+            )
+            return return_value
         else:
             return RBLNDecoderOnlyOutput(
                 logits=logits,
                 generate_idx=generate_idx,
+                hidden_states=all_hidden_states,
             )