optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. optimum/rbln/__init__.py +36 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +35 -16
  4. optimum/rbln/modeling_base.py +6 -6
  5. optimum/rbln/ops/__init__.py +1 -0
  6. optimum/rbln/ops/attn.py +10 -0
  7. optimum/rbln/ops/flash_attn.py +8 -0
  8. optimum/rbln/ops/moe.py +180 -0
  9. optimum/rbln/ops/sliding_window_attn.py +9 -0
  10. optimum/rbln/transformers/__init__.py +36 -0
  11. optimum/rbln/transformers/modeling_attention_utils.py +118 -222
  12. optimum/rbln/transformers/modeling_outputs.py +25 -0
  13. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  14. optimum/rbln/transformers/models/__init__.py +28 -0
  15. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  16. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  17. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  18. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  19. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
  20. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  21. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  22. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
  23. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
  25. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
  27. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  29. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  30. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  31. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  32. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  33. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  34. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
  35. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  36. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  37. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  38. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  39. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  40. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  41. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  42. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  43. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  44. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  45. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  46. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  47. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  48. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  50. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  51. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  53. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  54. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  55. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  56. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  57. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  58. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  59. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  60. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  61. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  62. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  63. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  64. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  65. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  66. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  68. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  69. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  71. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  72. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  73. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  74. optimum/rbln/utils/import_utils.py +16 -1
  75. optimum/rbln/utils/runtime_utils.py +10 -6
  76. optimum/rbln/utils/submodule.py +24 -0
  77. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  78. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
  79. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  80. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
  81. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  82. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
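Beyond the Qwen2-VL refactor shown in the hunks below, the file list adds several new model families (Gemma2, GPT-OSS, PaliGemma, Qwen2-MoE, Qwen3-MoE) plus a new MoE op set. As a hedged sketch of how a newly supported family is typically compiled with this package: the class name `RBLNQwen3MoeForCausalLM` and the checkpoint id are assumptions inferred from the new file names, not confirmed by this diff; the `from_pretrained`/`export` pattern follows the documented Qwen2-VL example further down.

```python
# Sketch only: class name and checkpoint id are assumed, not taken from this diff.
from optimum.rbln import RBLNQwen3MoeForCausalLM  # assumed export name

model = RBLNQwen3MoeForCausalLM.from_pretrained(
    "Qwen/Qwen3-30B-A3B",            # hypothetical checkpoint id
    export=True,                      # compile for RBLN NPUs
    rbln_config={
        "tensor_parallel_size": 4,    # split the language model across devices
        "max_seq_len": 32_768,
    },
)
model.save_pretrained("compiled-qwen3-moe")
```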
optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -27,6 +27,7 @@ from transformers.modeling_utils import no_init_weights
 from transformers.models.qwen2_vl.modeling_qwen2_vl import (
     PatchEmbed,
     Qwen2VisionTransformerPretrainedModel,
+    Qwen2VLConfig,
     Qwen2VLModel,
     Qwen2VLRotaryEmbedding,
     VisionRotaryEmbedding,
@@ -35,7 +36,12 @@ from transformers.models.qwen2_vl.modeling_qwen2_vl import (
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput
+from ...modeling_outputs import _validate_output_hidden_states
+from ..decoderonly.modeling_decoderonly import (
+    RBLNDecoderOnlyModel,
+    RBLNDecoderOnlyModelForCausalLM,
+    RBLNDecoderOnlyOutput,
+)
 from .configuration_qwen2_vl import (
     RBLNQwen2VisionTransformerPretrainedModelConfig,
     RBLNQwen2VLForConditionalGenerationConfig,
@@ -56,6 +62,7 @@ if TYPE_CHECKING:
 
 class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
     auto_model_class = None
+    _supports_non_fp32 = True
 
     def __post_init__(self, **kwargs):
         self.transformer = self.model[0]
@@ -92,7 +99,7 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
     def _wrap_model_if_needed(
         cls, model: "PreTrainedModel", rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig
     ):
-        return Qwen2VisionTransformerWrapper(model).eval()
+        return Qwen2VisionTransformerWrapper(model, rbln_config).eval()
 
     def __getattr__(self, __name: str) -> Any:
         def redirect(func):
@@ -119,17 +126,17 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
         input_infos = []
         for max_seq_len in rbln_config.max_seq_lens:
             input_info = [
-                ("hidden_states", [max_seq_len, hidden_size], "float32"),
-                ("full_attn_masks", [1, 1, max_seq_len, max_seq_len], "float32"),
+                ("hidden_states", [max_seq_len, hidden_size], rbln_config.dtype),
+                ("full_attn_masks", [1, 1, max_seq_len, max_seq_len], rbln_config.dtype),
                 (
                     "cos",
                     [1, 1, max_seq_len, head_dim],
-                    "float32",
+                    rbln_config.dtype,
                 ),
                 (
                     "sin",
                     [1, 1, max_seq_len, head_dim],
-                    "float32",
+                    rbln_config.dtype,
                 ),
             ]
             input_infos.append(input_info)
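The vision encoder's compile-time input signature above now follows `rbln_config.dtype` instead of a hard-coded `"float32"`. For illustration only (the sizes and dtype below are assumed, typical of the Qwen2-VL vision tower, and are not taken from this diff), one `max_seq_len` entry would expand to:

```python
# Illustration only: example values, not part of the diff.
max_seq_len, hidden_size, head_dim = 6400, 1280, 80  # assumed Qwen2-VL ViT sizes
dtype = "bfloat16"                                   # assumed rbln_config.dtype

input_info = [
    ("hidden_states", [max_seq_len, hidden_size], dtype),
    ("full_attn_masks", [1, 1, max_seq_len, max_seq_len], dtype),
    ("cos", [1, 1, max_seq_len, head_dim], dtype),
    ("sin", [1, 1, max_seq_len, head_dim], dtype),
]
```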
@@ -166,7 +173,7 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
             1,
             max_seq_len,
             max_seq_len,
-            dtype=torch.float32,
+            dtype=hidden_state.dtype,
         )
 
         full_attn_masks[:, :, hidden_state.shape[0] : max_seq_len, :] = 0
@@ -177,10 +184,10 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
         # Processes a batch of images (or frames) through the vision transformer.
         # Each image is handled independently for padding and attention mask generation.
 
-        hidden_states = self.patch_embed(hidden_states)
+        hidden_states = self.patch_embed(hidden_states).to(self.rbln_config.dtype)
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
-        position_embeddings = (emb.cos(), emb.sin())
+        position_embeddings = (emb.cos().to(self.rbln_config.dtype), emb.sin().to(self.rbln_config.dtype))
 
         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
             dim=0,
@@ -230,63 +237,48 @@ class RBLNQwen2VisionTransformerPretrainedModel(RBLNModel):
         return hidden_states
 
 
-class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
-    """
-    RBLNQwen2VLForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
-    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
-
-    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
-
-    Important Note:
-        This model includes a Large Language Model (LLM). For optimal performance, it is highly recommended to use
-        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
-        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNQwen2VLForConditionalGenerationConfig class for details.
-
-    Examples:
-        ```python
-        from optimum.rbln import RBLNQwen2VLForConditionalGeneration
-
-        model = RBLNQwen2VLForConditionalGeneration.from_pretrained(
-            "Qwen/Qwen2-VL-7B-Instruct",
-            export=True,
-            rbln_config={
-                "visual": {
-                    "max_seq_lens": 6400,
-                    "device": 0,
-                },
-                "tensor_parallel_size": 8,
-                "max_seq_len": 32_768,
-                "device": [0, 1, 2, 3, 4, 5, 6, 7],
-            },
-        )
-
-        model.save_pretrained("compiled-qwen2-vl-7b-instruct")
-        ```
-    """
-
+class RBLNQwen2VLModel(RBLNDecoderOnlyModel):
     auto_model_class = AutoModelForVision2Seq
+    _decoder_wrapper_cls = Qwen2VL_LanguageModelWrapper
+    _supports_non_fp32 = True
+    _use_rotary_emb = False
    _rbln_submodules = [
         {"name": "visual"},
     ]
-    _decoder_wrapper_cls = Qwen2VL_LanguageModelWrapper
-    _use_rotary_emb = False
+    _config_class = Qwen2VLConfig
+    _rotary_emb_class = Qwen2VLRotaryEmbedding
+    _get_rope_index_func = Qwen2VLModel.get_rope_index
 
     def __post_init__(self, **kwargs):
+        if hasattr(self.config, "embedding_dim"):
+            self.embedding_dim = self.config.embedding_dim
+
+        if not isinstance(self.config.text_config, PretrainedConfig):
+            self.config = self._config_class(
+                text_config=self.config.text_config, vision_config=self.config.vision_config
+            )
+
         super().__post_init__(**kwargs)
         self.visual = self.rbln_submodules[0]
-        self.mrope_section = self.config.rope_scaling["mrope_section"]
-        self.rotary_emb = Qwen2VLRotaryEmbedding(self.config)
-        self.rope_deltas = torch.zeros(self.rbln_config.batch_size)
-
-    def can_generate(self):
-        return True
+        self.rotary_emb = self._rotary_emb_class(self.config)
+        if not self.can_generate():
+            self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+
+    @property
+    def logits_last_dim(self):
+        if self.can_generate():
+            return self.config.vocab_size
+        else:
+            return self.embedding_dim if hasattr(self, "embedding_dim") else self.config.hidden_size
 
-    @classmethod
-    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
-        model.model.lm_head = model.lm_head
-        model.lm_head = None
-        del model.lm_head
-        return model
+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.text_config.vocab_size,
+                self.config.text_config.hidden_size,
+                self.config.text_config.pad_token_id,
+            )
+        return embed_tokens
 
     @classmethod
     def get_input_info(
@@ -303,52 +295,25 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
             (
                 "position_emb",
                 [2, batch_size, 1, query_length, model_config.hidden_size // model_config.num_attention_heads],
-                "float32",
+                rbln_config.dtype,
             ),
         )
 
         return input_info
 
-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor,
-        generate_idx: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        pixel_values=None,
-        pixel_values_videos=None,
-        image_grid_thw=None,
-        video_grid_thw=None,
-        **kwargs,
-    ):
-        model_inputs = super().prepare_inputs_for_generation(
-            input_ids,
-            generate_idx,
-            attention_mask,
-            inputs_embeds,
-            **kwargs,
-        )
-
-        is_prefill_phase = generate_idx is None
-        if is_prefill_phase:
-            model_inputs.update({"input_ids": input_ids})
-
-        model_inputs.update(
-            {
-                "pixel_values": pixel_values,
-                "pixel_values_videos": pixel_values_videos,
-                "image_grid_thw": image_grid_thw,
-                "video_grid_thw": video_grid_thw,
-            }
-        )
-
-        return model_inputs
-
     def _get_position_embeddings(self, hidden_states, position_ids):
         cos, sin = self.rotary_emb(hidden_states, position_ids)
-        mrope_section = self.mrope_section * 2
-        cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
-        sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
+        mrope_section = self.config.rope_scaling["mrope_section"] * 2
+        cos = (
+            torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1)
+            .unsqueeze(1)
+            .to(self.rbln_config.dtype)
+        )
+        sin = (
+            torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1)
+            .unsqueeze(1)
+            .to(self.rbln_config.dtype)
+        )
         return torch.stack([cos, sin])
 
     def _preprocess_prefill(
@@ -361,7 +326,7 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         video_grid_thw: torch.LongTensor = None,
     ):
         batch_size = input_ids.shape[0]
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.embed_tokens(input_ids).to(self.rbln_config.dtype)
 
         if pixel_values is not None:
             image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
@@ -396,7 +361,7 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         max_inputs_len = input_ids.shape[1]
 
         head_dim = getattr(self.config, "head_dim", None) or self.config.hidden_size // self.config.num_attention_heads
-        all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim)
+        all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim, dtype=self.rbln_config.dtype)
        all_rope_deltas = []
 
         image_token_id = self.config.image_token_id
@@ -410,8 +375,7 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
             vision_tokens = input_id[0][vision_start_indices + 1]
             image_nums = (vision_tokens == image_token_id).sum()
             video_nums = (vision_tokens == video_token_id).sum()
-            position_ids, rope_deltas = Qwen2VLModel.get_rope_index(
-                self,
+            position_ids, rope_deltas = self._get_rope_index_func(
                 input_id,
                 image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
                 video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,
@@ -428,6 +392,177 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
 
         return inputs_embeds, all_position_embeds, rope_deltas
 
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        pixel_values: Optional[torch.Tensor] = None,
+        pixel_values_videos: Optional[torch.FloatTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
+    ) -> RBLNDecoderOnlyOutput:
+        inputs_embeds, position_embed, rope_deltas = self._preprocess_prefill(
+            input_ids,
+            attention_mask,
+            pixel_values,
+            pixel_values_videos,
+            image_grid_thw,
+            video_grid_thw,
+        )
+
+        self.rope_deltas = rope_deltas
+        batch_size, seq_len = inputs_embeds.shape[:2]
+
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
+
+        all_hidden_states = (
+            tuple(
+                torch.zeros(
+                    batch_size,
+                    seq_len,
+                    self.config.hidden_size,
+                    dtype=self.rbln_config.dtype,
+                )
+                for _ in range(self.config.num_hidden_layers + 1)
+            )
+            if output_hidden_states
+            else None
+        )
+
+        logits = []
+        for b_idx in range(batch_size):
+            query_length = attention_mask[b_idx].sum(dim=-1).int().item()
+            cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
+
+            outputs = self.prefill_decoder(
+                inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
+                attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                cache_position=cache_position,
+                batch_idx=b_idx,
+                position_embed=position_embed[:, b_idx : b_idx + 1],
+                block_tables=self.block_tables,
+            )
+
+            logits.append(outputs.logits)
+            if self.rbln_config.output_hidden_states:
+                for l_idx in range(self.config.num_hidden_layers + 1):
+                    all_hidden_states[l_idx][b_idx].copy_(outputs.hidden_states[l_idx][0])
+
+        logits = torch.cat(logits, dim=0)
+
+        if not return_dict:
+            return_value = logits if not output_hidden_states else (logits, all_hidden_states)
+            return return_value
+        else:
+            return (
+                RBLNDecoderOnlyOutput(logits=logits, hidden_states=all_hidden_states)
+                if output_hidden_states
+                else RBLNDecoderOnlyOutput(logits=logits)
+            )
+
+
+# MRO: RBLNQwen2VLForConditionalGeneration -> RBLNQwen2VLModel -> RBLNDecoderOnlyModelForCausalLM -> RBLNDecoderOnlyModel -> RBLNModel
+class RBLNQwen2VLForConditionalGeneration(RBLNQwen2VLModel, RBLNDecoderOnlyModelForCausalLM):
+    """
+    RBLNQwen2VLForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model includes a Large Language Model (LLM). For optimal performance, it is highly recommended to use
+        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNQwen2VLForConditionalGenerationConfig class for details.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNQwen2VLForConditionalGeneration
+
+        model = RBLNQwen2VLForConditionalGeneration.from_pretrained(
+            "Qwen/Qwen2-VL-7B-Instruct",
+            export=True,
+            rbln_config={
+                "visual": {
+                    "max_seq_lens": 6400,
+                    "device": 0,
+                },
+                "tensor_parallel_size": 8,
+                "max_seq_len": 32_768,
+                "device": [0, 1, 2, 3, 4, 5, 6, 7],
+            },
+        )
+
+        model.save_pretrained("compiled-qwen2-vl-7b-instruct")
+        ```
+    """
+
+    auto_model_class = AutoModelForVision2Seq
+    _decoder_wrapper_cls = Qwen2VL_LanguageModelWrapper
+    _supports_non_fp32 = True
+    _use_rotary_emb = False
+    _rbln_submodules = [
+        {"name": "visual"},
+    ]
+
+    def __post_init__(self, **kwargs):
+        super().__post_init__(**kwargs)
+        self.rope_deltas = torch.zeros(self.rbln_config.batch_size)
+
+    def can_generate(self):
+        return True
+
+    @classmethod
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+        model.model.lm_head = model.lm_head
+        return model
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        generate_idx: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        pixel_values=None,
+        pixel_values_videos=None,
+        image_grid_thw=None,
+        video_grid_thw=None,
+        **kwargs,
+    ):
+        model_inputs = {}
+        is_prefill_phase = generate_idx is None
+
+        if is_prefill_phase:
+            generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+            cache_position = None
+            model_inputs.update({"input_ids": input_ids})
+        else:
+            if inputs_embeds is not None:
+                raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
+
+            input_ids = input_ids[:, -1:]
+            cache_position = generate_idx
+            generate_idx = generate_idx + 1
+            model_inputs.update({"input_ids": input_ids})
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "generate_idx": generate_idx,
+                "pixel_values": pixel_values,
+                "pixel_values_videos": pixel_values_videos,
+                "image_grid_thw": image_grid_thw,
+                "video_grid_thw": video_grid_thw,
+            }
+        )
+
+        return model_inputs
+
     def _preprocess_decoder(
         self,
         input_ids: torch.LongTensor = None,
@@ -438,14 +573,14 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                 f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.rbln_config.batch_size}."
             )
 
-        inputs_embeds = self.embed_tokens(input_ids)
+        inputs_embeds = self.embed_tokens(input_ids).to(self.rbln_config.dtype)
         position_embeds = []
         for b_idx in range(self.rbln_config.batch_size):
             delta = cache_position[b_idx] + self.rope_deltas[b_idx]
             position_ids = torch.arange(1).view(1, -1)
             position_ids = position_ids.add(delta)
             position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
-            position_embed = self._get_position_embeddings(torch.zeros(1, dtype=torch.float32), position_ids)
+            position_embed = self._get_position_embeddings(torch.zeros(1, dtype=self.rbln_config.dtype), position_ids)
             position_embeds.append(position_embed)
 
         position_embeds = torch.cat(position_embeds, dim=1)
@@ -464,8 +599,10 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
         cache_position: Optional[torch.LongTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
         return_dict: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
         **kwargs,
     ) -> RBLNDecoderOnlyOutput:
+        output_hidden_states = _validate_output_hidden_states(output_hidden_states, self.rbln_config)
         # Prefill
         if cache_position is None:
             inputs_embeds, position_embed, rope_deltas = self._preprocess_prefill(
@@ -477,8 +614,21 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                 video_grid_thw,
             )
 
+            batch_size, seq_len = inputs_embeds.shape[:2]
+            all_hidden_states = (
+                tuple(
+                    torch.zeros(
+                        batch_size,
+                        seq_len,
+                        self.config.hidden_size,
+                        dtype=self.rbln_config.dtype,
+                    )
+                    for _ in range(self.config.num_hidden_layers + 1)
+                )
+                if output_hidden_states
+                else None
+            )
             self.rope_deltas = rope_deltas
-            batch_size = inputs_embeds.shape[0]
 
             logits = []
             for b_idx in range(batch_size):
@@ -492,8 +642,10 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                     position_embed=position_embed[:, b_idx : b_idx + 1],
                 )
                 logits.append(output.logits)
+                if self.rbln_config.output_hidden_states:
+                    for l_idx in range(self.config.num_hidden_layers + 1):
+                        all_hidden_states[l_idx][b_idx].copy_(output.hidden_states[l_idx][0])
             logits = torch.cat(logits, dim=0)
-
         # Decoder
         else:
             inputs_embeds, position_embed = self._preprocess_decoder(input_ids, cache_position)
@@ -503,11 +655,17 @@ class RBLNQwen2VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
                 position_embed=position_embed,
             )
             logits = output.logits
+            all_hidden_states = output.hidden_states
 
         if not return_dict:
-            return logits, generate_idx
+            return_value = (
+                logits,
+                generate_idx if not output_hidden_states else (logits, generate_idx, all_hidden_states),
+            )
+            return return_value
         else:
             return RBLNDecoderOnlyOutput(
                 logits=logits,
                 generate_idx=generate_idx,
+                hidden_states=all_hidden_states,
             )
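Taken together, the hunks above let the Qwen2-VL stack run in a non-fp32 compile dtype and return per-layer hidden states when the compiled configuration allows it. A minimal usage sketch follows; the `output_hidden_states` key is an assumption (mirroring `rbln_config.output_hidden_states` referenced in the code) and is not confirmed by this diff, while the other fields follow the documented example above.

```python
from optimum.rbln import RBLNQwen2VLForConditionalGeneration

# Sketch under assumptions: "output_hidden_states" is passed like the documented
# rbln_config fields; verify against the 0.9.5 release documentation.
model = RBLNQwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    export=True,
    rbln_config={
        "visual": {"max_seq_lens": 6400},
        "max_seq_len": 32_768,
        "tensor_parallel_size": 8,
        "output_hidden_states": True,  # assumed key
    },
)

# With return_dict=True, the forward shown above packs the per-layer states into
# RBLNDecoderOnlyOutput.hidden_states (num_hidden_layers + 1 tensors).
```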
optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py
@@ -9,19 +9,24 @@ from ..decoderonly.decoderonly_architecture import (
     DecoderOnlyWrapper,
     apply_rotary_pos_emb,
 )
+from .configuration_qwen2_vl import RBLNQwen2VisionTransformerPretrainedModelConfig
 
 
 class Qwen2VisionTransformerWrapper(nn.Module):
-    def __init__(self, model: torch.nn.Module):
+    def __init__(self, model: torch.nn.Module, rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig):
         super().__init__()
-        self._original_mod = model
         self.merger = model.merger
-        self.blocks = self.wrap_vision_blocks(model.blocks)
+        self.rbln_config = rbln_config
+        self.blocks = self.wrap_vision_blocks(model.blocks, rbln_config)
 
-    def wrap_vision_blocks(self, blocks: torch.nn.ModuleList):
+    def wrap_vision_blocks(
+        self,
+        blocks: torch.nn.ModuleList,
+        rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig,
+    ):
         wrapped_blocks = []
         for _, block in enumerate(blocks):
-            wrapped_blocks.append(Qwen2VLVisionBlock(block))
+            wrapped_blocks.append(Qwen2VLVisionBlock(block, rbln_config))
         return wrapped_blocks
 
     def forward(
@@ -31,7 +36,7 @@ class Qwen2VisionTransformerWrapper(nn.Module):
         cos: torch.Tensor,
         sin: torch.Tensor,
     ):
-        full_attn_masks = (1 - full_attn_masks) * torch.finfo(torch.float32).min
+        full_attn_masks = (1.0 - full_attn_masks) * torch.finfo(hidden_states.dtype).min
 
         for block in self.blocks:
             hidden_states = block(hidden_states, full_attn_masks, [cos, sin])
@@ -40,13 +45,13 @@
 
 
 class Qwen2VLVisionBlock(torch.nn.Module):
-    def __init__(self, model: torch.nn.Module):
+    def __init__(self, model: torch.nn.Module, rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig):
         super().__init__()
         self._origin_model = model
+        self.rbln_config = rbln_config
         self.norm1 = model.norm1
         self.norm2 = model.norm2
-
-        self.attn = VisionAttention(model.attn)
+        self.attn = VisionAttention(model.attn, rbln_config)
         self.mlp = model.mlp
 
     def forward(
@@ -65,13 +70,15 @@
 
 
 class VisionAttention(nn.Module):
-    def __init__(self, model: nn.Module) -> None:
+    def __init__(self, model: nn.Module, rbln_config: RBLNQwen2VisionTransformerPretrainedModelConfig) -> None:
         super().__init__()
         self._origin_model = model
+        self.rbln_config = rbln_config
         self.num_heads = model.num_heads
         self.head_dim = getattr(model, "head_dim", model.proj.in_features // model.num_heads)
         self.qkv = model.qkv
         self.proj = model.proj
+        self.scale = torch.tensor(1 / math.sqrt(self.head_dim), dtype=rbln_config.dtype)
 
     def forward(
         self,
@@ -88,9 +95,9 @@
         cos, sin = position_embeddings
         q, k = apply_rotary_pos_emb(q, k, cos, sin)
 
-        attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
+        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
         attn_weights = attn_weights + attn_masks
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
         attn_output = torch.matmul(attn_weights, v)
         attn_output = attn_output.transpose(1, 2)
         attn_output = attn_output.reshape(1, seq_length, -1)
@@ -100,6 +107,12 @@
 
 
 class Qwen2VL_LanguageModelWrapper(DecoderOnlyWrapper):
+    def get_decoder_layers(self, model: PreTrainedModel):
+        return model.model.language_model.layers if hasattr(model, "model") else model.language_model.layers
+
+    def get_model_layer(self, model: PreTrainedModel):
+        return model.model.language_model if hasattr(model, "model") else model.language_model
+
     def prepare_forward_args(self, *args):
         args = list(args)
         input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
@@ -108,7 +121,7 @@
         global_block_tables = args.pop(0)
         local_block_tables = None
         position_embeds = args.pop(0)
-        query_position = args.pop(0) if self.phase == "prefill" else None
+        query_position = args.pop(0) if self.phase == "prefill" and self.rbln_config.logits_to_keep > 0 else None
         position_ids = None
         attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
         lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
@@ -142,24 +155,3 @@
             past_key_values,
             position_embeds,
         )
-
-    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
-        new_layers = []
-
-        for layer_idx, layer in enumerate(model.model.language_model.layers):
-            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
-            new_self_attn = self.get_rbln_attn_class()(
-                self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
-            )
-            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
-            new_layers.append(new_layer)
-
-        new_model = self.get_rbln_model_class()(
-            model.model.language_model,
-            new_layers,
-            self.rbln_config,
-            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
-        )
-
-        new_model = self.get_rbln_causal_lm_class()(model.model, new_model)
-        return new_model
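The VisionAttention change above precomputes the softmax scale as a tensor in the compile dtype and drops the explicit float32 softmax, so the attention chain stays in the configured dtype end to end. A minimal standalone sketch of that pattern in plain PyTorch (shapes and dtype below are assumed, not taken from the diff):

```python
import math

import torch

# Sketch of the pattern above (assumed shapes/dtype): the scale is created once
# as a tensor in the target dtype, so matmul -> mask add -> softmax never upcast.
num_heads, head_dim, seq, dtype = 16, 80, 64, torch.bfloat16
q = torch.randn(1, num_heads, seq, head_dim, dtype=dtype)
k = torch.randn(1, num_heads, seq, head_dim, dtype=dtype)
v = torch.randn(1, num_heads, seq, head_dim, dtype=dtype)
attn_masks = torch.zeros(1, 1, seq, seq, dtype=dtype)

scale = torch.tensor(1 / math.sqrt(head_dim), dtype=dtype)
attn_weights = torch.matmul(q, k.transpose(2, 3)) * scale + attn_masks
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)  # no dtype=torch.float32
attn_output = torch.matmul(attn_weights, v)
```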
optimum/rbln/transformers/models/qwen3/qwen3_architecture.py
@@ -22,10 +22,10 @@ class Qwen3Wrapper(DecoderOnlyWrapper):
 
 
 class Qwen3Attention(DecoderOnlyAttention):
-    def __post_init__(self):
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.q_proj = self._original_mod.q_proj
-        self.o_proj = self._original_mod.o_proj
-        self.q_norm = self._original_mod.q_norm
-        self.k_norm = self._original_mod.k_norm
+    def __post_init__(self, self_attn):
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = self_attn.o_proj
+        self.q_norm = self_attn.q_norm
+        self.k_norm = self_attn.k_norm