optimum-rbln 0.8.1a0__py3-none-any.whl → 0.8.1a2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. optimum/rbln/__init__.py +2 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +53 -33
  4. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  5. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
  11. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
  15. optimum/rbln/diffusers/modeling_diffusers.py +16 -26
  16. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
  17. optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
  18. optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
  19. optimum/rbln/diffusers/models/controlnet.py +13 -7
  20. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  21. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
  23. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  24. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  25. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  26. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  28. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  29. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  30. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  31. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  33. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  34. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  35. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  36. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  38. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  42. optimum/rbln/modeling.py +33 -35
  43. optimum/rbln/modeling_base.py +45 -107
  44. optimum/rbln/transformers/__init__.py +39 -47
  45. optimum/rbln/transformers/configuration_generic.py +16 -13
  46. optimum/rbln/transformers/modeling_generic.py +18 -19
  47. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  48. optimum/rbln/transformers/models/__init__.py +46 -4
  49. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  52. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  54. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
  55. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  56. optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
  57. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  58. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  59. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +229 -175
  60. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  61. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
  62. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
  63. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  64. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  65. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  66. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  67. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  68. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  69. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  70. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  71. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +106 -236
  72. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  73. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  74. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  75. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  76. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  77. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  78. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  79. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  80. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  81. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  82. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  83. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  84. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  85. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  86. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  87. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  88. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  89. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  90. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  91. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  92. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  93. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
  94. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +58 -27
  95. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +47 -2
  96. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  97. optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
  98. optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
  99. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  100. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
  101. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
  102. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  103. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  104. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
  105. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  106. optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
  107. optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
  108. optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
  109. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  110. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  111. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
  112. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
  113. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  114. optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
  115. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  116. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  117. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  118. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
  119. optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
  120. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  121. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  122. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  123. optimum/rbln/utils/model_utils.py +20 -0
  124. optimum/rbln/utils/submodule.py +6 -8
  125. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/METADATA +2 -2
  126. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/RECORD +130 -117
  127. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  128. /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
  129. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/WHEEL +0 -0
  130. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py
@@ -11,6 +11,7 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ import importlib
  import inspect
  from collections import deque
  from dataclasses import dataclass
@@ -32,10 +33,6 @@ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbed
  from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
  from ....modeling import RBLNModel
  from ....utils.logging import get_logger
- from ..decoderonly.decoderonly_architecture import (
-     set_default_values,
-     validate_attention_method,
- )
  from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput, RBLNRuntimeModel
  from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
  from .gemma3_architecture import Gemma3ForCausalLMWrapper
@@ -127,6 +124,23 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
      def can_generate(self):
          return True

+     @classmethod
+     def get_pytorch_model(cls, *args, **kwargs):
+         model = super().get_pytorch_model(*args, **kwargs)
+
+         with no_init_weights():
+             model_cls_name = model.model.language_model.__class__.__name__
+             causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")
+             causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
+             new_language_model = causal_model_cls(model.model.language_model.config)
+
+         new_language_model.lm_head = model.lm_head
+         new_language_model.model = model.model.language_model
+         model.model.language_model = new_language_model
+         model.lm_head = None
+         del model.lm_head
+         return model
+
      def __post_init__(self, **kwargs):
          self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
          self.language_model = self.rbln_submodules[1]
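
Reviewer note on the new `get_pytorch_model` override: it resolves the companion causal-LM class by string substitution on the text model's class name, then rebuilds the submodule under `no_init_weights()` so no weights are re-initialized before the existing `language_model` and `lm_head` are attached. A minimal sketch of just the name-resolution step, assuming an installed `transformers` version that exports the Gemma3 classes:

```python
# Hedged sketch of the dynamic class lookup used in the override above.
# Assumes a transformers release that exports Gemma3ForCausalLM.
import importlib

model_cls_name = "Gemma3TextModel"
causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")

# Fetch the class from the top-level transformers namespace by its derived name.
causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
print(causal_model_cls.__name__)  # -> "Gemma3ForCausalLM"
```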
@@ -215,15 +229,16 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

          return model_kwargs

-     def get_image_features(self, pixel_values: torch.Tensor):
+     def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
          """
          Projects the last hidden state from the vision model into language model space.

          Args:
-             pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
-                 The tensors corresponding to the input images.
+             pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
+                 The tensors corresponding to the input images.
+
          Returns:
-             image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
+             Image feature tensor of shape `(num_images, image_length, embed_dim)`.
          """
          vision_outputs = self.vision_tower(pixel_values).last_hidden_state
          image_features = self.multi_modal_projector(vision_outputs)
@@ -272,7 +287,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
          padded_cache_lengths: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
-         **lm_kwargs,
+         **lm_kwargs: Dict[str, Any],
      ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
          # prefill
          if cache_position is None:
@@ -352,16 +367,17 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
          # Find image start positions
          image_starts = [
              s
-             for s in range(seq_len - self.prefill_chunk_size + 1)
-             if torch.all(token_type_ids[:, s : s + self.prefill_chunk_size] == 1)
+             for s in range(seq_len - self.rbln_config.prefill_chunk_size + 1)
+             if torch.all(token_type_ids[:, s : s + self.rbln_config.prefill_chunk_size] == 1)
          ]

          # Initialize padded tensors
          padded_input_len = seq_len
          for image_start in image_starts:
              pad_needed = (
-                 self.prefill_chunk_size - (image_start + padded_input_len - seq_len) % self.prefill_chunk_size
-             ) % self.prefill_chunk_size
+                 self.rbln_config.prefill_chunk_size
+                 - (image_start + padded_input_len - seq_len) % self.rbln_config.prefill_chunk_size
+             ) % self.rbln_config.prefill_chunk_size
              padded_input_len += pad_needed
          total_padding = padded_input_len - seq_len

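For intuition, a small self-contained check of the alignment arithmetic above: runs of `prefill_chunk_size` tokens with `token_type_ids == 1` mark image segments, and padding accumulates so each segment starts on a chunk boundary (toy sizes, not real model values):

```python
# Toy reproduction of the image-start detection and padding accumulation.
import torch

prefill_chunk_size = 4
token_type_ids = torch.tensor([[0, 0, 1, 1, 1, 1, 0, 0]])  # one 4-token image at s=2
seq_len = token_type_ids.shape[1]

image_starts = [
    s
    for s in range(seq_len - prefill_chunk_size + 1)
    if torch.all(token_type_ids[:, s : s + prefill_chunk_size] == 1)
]
print(image_starts)  # [2]

padded_input_len = seq_len
for image_start in image_starts:
    pad_needed = (
        prefill_chunk_size - (image_start + padded_input_len - seq_len) % prefill_chunk_size
    ) % prefill_chunk_size
    padded_input_len += pad_needed

# 2 pad tokens push the image from offset 2 to offset 4, a chunk boundary.
print(padded_input_len - seq_len)  # 2
```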
@@ -390,7 +406,9 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
              src_pos = image_start

              # Padding
-             pad_needed = (self.prefill_chunk_size - dest_pos % self.prefill_chunk_size) % self.prefill_chunk_size
+             pad_needed = (
+                 self.rbln_config.prefill_chunk_size - dest_pos % self.rbln_config.prefill_chunk_size
+             ) % self.rbln_config.prefill_chunk_size
              if pad_needed and dest_pos < padded_input_len:
                  position_ids_padded[:, dest_pos : dest_pos + pad_needed] = torch.arange(
                      last_pos_id + 1, last_pos_id + pad_needed + 1, dtype=position_ids.dtype
@@ -399,21 +417,21 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):

              # Image segment
              if src_pos < seq_len and src_pos == image_start:
-                 inputs_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = inputs[
-                     :, src_pos : src_pos + self.prefill_chunk_size
+                 inputs_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = inputs[
+                     :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                  ]
-                 attention_mask_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = attention_mask[
-                     :, src_pos : src_pos + self.prefill_chunk_size
+                 attention_mask_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = attention_mask[
+                     :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                  ]
-                 position_ids_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = position_ids[
-                     :, src_pos : src_pos + self.prefill_chunk_size
+                 position_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = position_ids[
+                     :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                  ]
-                 token_type_ids_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = token_type_ids[
-                     :, src_pos : src_pos + self.prefill_chunk_size
+                 token_type_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = token_type_ids[
+                     :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                  ]
-                 dest_pos += self.prefill_chunk_size
-                 src_pos += self.prefill_chunk_size
-                 last_pos_id = position_ids[0, image_start + self.prefill_chunk_size - 1].item()
+                 dest_pos += self.rbln_config.prefill_chunk_size
+                 src_pos += self.rbln_config.prefill_chunk_size
+                 last_pos_id = position_ids[0, image_start + self.rbln_config.prefill_chunk_size - 1].item()

          return inputs_padded, attention_mask_padded, position_ids_padded, total_padding, token_type_ids_padded

@@ -444,11 +462,13 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):

          seq_len = inputs.shape[1]
          # Initialize attention mask for chunked processing
-         if self.use_attention_mask:
+         if self.rbln_config.use_attention_mask:
              chunked_attention_mask = (
                  torch.ones(1, seq_len, dtype=torch.float32)
-                 if self.use_position_ids
-                 else torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+                 if self.rbln_config.use_position_ids
+                 else torch.zeros(
+                     1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
+                 )
              )
          else:
              chunked_attention_mask = None
@@ -467,20 +487,21 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
          )

          query_length = inputs.shape[1]
-         if query_length > self.max_seq_len:
+         if query_length > self.rbln_config.max_seq_len:
              raise ValueError(
-                 f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+                 f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
              )

          # Align attention_mask to compiled shape
-         if self.use_position_ids:
+         if self.rbln_config.use_position_ids:
              chunked_attention_mask = torch.nn.functional.pad(
-                 chunked_attention_mask, (0, self.max_seq_len - query_length)
+                 chunked_attention_mask, (0, self.rbln_config.max_seq_len - query_length)
              )

          # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-         if query_length % self.prefill_chunk_size != 0:
-             padding_size = self.prefill_chunk_size - query_length % self.prefill_chunk_size
+         padding_size = 0
+         if query_length % self.rbln_config.prefill_chunk_size != 0:
+             padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
          # inputs_embeds
          if inputs.dim() == 3:
              inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
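
The rewritten padding expression is equivalent to the old one whenever the guard fires, and the new `padding_size = 0` default makes the exact-multiple case explicit instead of leaving `padding_size` unbound on that path. A quick sanity check:

```python
# Verify (c - q) % c == c - q % c when q % c != 0, and == 0 on exact multiples.
prefill_chunk_size = 128
for query_length in (1, 100, 128, 129, 255, 256):
    old = (
        prefill_chunk_size - query_length % prefill_chunk_size
        if query_length % prefill_chunk_size != 0
        else 0
    )
    new = (prefill_chunk_size - query_length) % prefill_chunk_size
    assert old == new, (query_length, old, new)
```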
@@ -548,65 +569,71 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
          ) = self._prepare_prefill_inputs(
              inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
          )
-         self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask[:1]
          if not is_external_block_tables:
              local_block_tables = torch.tensor([batch_idx], dtype=torch.int16)
+             self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask[:1]

-         if self.use_attention_mask and self.use_position_ids:
-             chunked_attention_mask = torch.zeros(1, self.max_seq_len, dtype=torch.float32)
+         if self.rbln_config.use_attention_mask and self.rbln_config.use_position_ids:
+             chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=torch.float32)

          # Process input in chunks of size `prefill_chunk_size`
-         for step in range(0, query_length, self.prefill_chunk_size):
+         for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
              # Extract the current chunk of inputs and cache positions
-             input_chunk = inputs[:, step : step + self.prefill_chunk_size]
-             cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+             input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
+             cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
              position_ids_chunk = (
-                 position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+                 position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
+                 if position_ids is not None
+                 else None
              )

              # Not used in Gemma3 yet.
-             if self.use_attention_mask:
-                 if self.use_position_ids:
-                     chunked_attention_mask[0, step : step + self.prefill_chunk_size] = self.dec_attn_mask[
-                         batch_idx, step : step + self.prefill_chunk_size
+             if self.rbln_config.use_attention_mask:
+                 if self.rbln_config.use_position_ids:
+                     chunked_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size] = self.dec_attn_mask[
+                         batch_idx, step : step + self.rbln_config.prefill_chunk_size
                      ]
                  else:
                      # Update attention mask to ensure proper causal behavior
-                     if step >= self.prefill_chunk_size:
-                         chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-                     chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                     if step >= self.rbln_config.prefill_chunk_size:
+                         chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
+                     chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = (
+                         self.causal_mask
+                     )

              # Define query position
              query_position = (
                  torch.sum(
-                     chunked_attention_mask[0][step : step + self.prefill_chunk_size], dim=-1, dtype=torch.int16
+                     chunked_attention_mask[0][step : step + self.rbln_config.prefill_chunk_size],
+                     dim=-1,
+                     dtype=torch.int16,
                  ).squeeze(0)
                  - 1
              )
              if token_type_ids_padded[:, step] == 1:
-                 if torch.any(token_type_ids_padded[:, step : step + self.prefill_chunk_size] == 0):
+                 if torch.any(token_type_ids_padded[:, step : step + self.rbln_config.prefill_chunk_size] == 0):
                      raise ValueError("All tokens of image_prefill should be the same image.")
                  else:
                      logits = self.image_prefill(
                          input_chunk,
-                         chunked_attention_mask,
                          cache_pos_chunk,
-                         position_ids_chunk,
-                         query_position,
                          block_tables,
                          local_block_tables,
+                         query_position,
+                         chunked_attention_mask,
+                         position_ids_chunk,
                          out=out_buffers,
                      )
              else:
                  # Forward pass for the current chunk
                  logits = self.prefill(
                      input_chunk,
-                     chunked_attention_mask,
                      cache_pos_chunk,
-                     position_ids_chunk,
-                     query_position,
                      block_tables,
                      local_block_tables,
+                     query_position,
+                     chunked_attention_mask,
+                     position_ids_chunk,
                      out=out_buffers,
                  )

@@ -646,7 +673,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
              if local_block_tables is not None
              else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
          )
-         if self.use_attention_mask and attention_mask is None:
+         if self.rbln_config.use_attention_mask and attention_mask is None:
              for b_idx in range(batch_size):
                  decoding_step = cache_position[b_idx].item()
                  if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
@@ -663,7 +690,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
          if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
              attention_mask = attention_mask[: self.batch_size]

-         logits = self.decode(inputs, attention_mask, cache_position, position_ids, block_tables, local_block_tables)
+         logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)

          return RBLNDecoderOnlyOutput(logits=logits)

@@ -700,7 +727,6 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
              dtype=torch.int16,
          ).fill_(-1)
          free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
-
          self.prefill_decoder = RBLNGemma3RuntimeModel(
              runtime=self.model[0],
              image_prefill=self.model[1],
@@ -710,14 +736,9 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
              batch_size=self.rbln_config.batch_size,
              dec_attn_mask=dec_attn_mask,
              block_tables=block_tables,
-             free_block_pool=free_block_pool,
-             kvcache_block_size=self.rbln_config.kvcache_block_size,
              vocab_size=self.config.vocab_size,
-             prefill_chunk_size=self.rbln_config.prefill_chunk_size,
-             max_seq_len=self.rbln_config.max_seq_len,
-             use_attention_mask=self.rbln_config.use_attention_mask,
-             attn_impl=self.rbln_config.attn_impl,
-             use_position_ids=self.rbln_config.use_position_ids,
+             free_block_pool=free_block_pool,
+             rbln_config=self.rbln_config,
          )

          self.decoders = {}
@@ -731,10 +752,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
                  dec_attn_mask=dec_attn_mask,
                  block_tables=block_tables,
                  free_block_pool=free_block_pool,
-                 kvcache_block_size=self.rbln_config.kvcache_block_size,
-                 use_attention_mask=self.rbln_config.use_attention_mask,
-                 attn_impl=self.rbln_config.attn_impl,
-                 use_position_ids=self.rbln_config.use_position_ids,
+                 rbln_config=self.rbln_config,
              )

          # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -751,81 +769,17 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
          return embed_tokens

      @classmethod
-     def get_input_info(
-         cls,
-         batch_size: int,
-         query_length: int,
-         use_inputs_embeds: bool,
-         use_attention_mask: bool,
-         use_position_ids: bool,
-         max_seq_len: int,
-         kvcache_block_size: int,
-         kvcache_num_blocks: int,
-         num_key_value_heads: int,
-         num_hidden_layers: int,
-         hidden_size: int,
-         head_dim: int,
-         sliding_window: int,
-         sliding_window_pattern: int,
-         dec_batch_size: int,
-     ):
-         if use_inputs_embeds:
-             main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
-         else:
-             main_input = ("input_ids", [batch_size, query_length], "int64")
-
-         input_info = [
-             main_input,
-             (
-                 "attention_mask",
-                 [batch_size, 1, query_length, max_seq_len] if not use_position_ids else [batch_size, max_seq_len],
-                 "float32",
-             ),
-             (
-                 "cache_position",
-                 [batch_size, query_length],
-                 "int32",
-             ),
-             (
-                 "position_ids",
-                 [batch_size, query_length],
-                 "int32",
-             ),
-         ]
-
-         if query_length > 1:
-             input_info.extend(
-                 [
-                     ("query_position", [], "int16"),
-                 ]
-             )
-
-         max_block_cnt = max_seq_len // kvcache_block_size
-
-         if query_length > 1:
-             input_info.extend([("global_block_tables", [max_block_cnt], "int16")])
-             input_info.extend([("local_block_tables", [1], "int16")])
-         else:
-             input_info.extend([("global_block_tables", [batch_size, max_block_cnt], "int16")])
-             input_info.extend([("local_block_tables", [batch_size, 1], "int16")])
-
-         def is_sliding(layer_idx: int) -> bool:
-             return bool((layer_idx + 1) % sliding_window_pattern)
-
-         local_kvcache_shape = [dec_batch_size, num_key_value_heads, sliding_window, head_dim]
-         global_kvcache_shape = [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim]
-         input_info.extend(
-             [
-                 (
-                     f"past_key_values_{i}",
-                     local_kvcache_shape if is_sliding(i // 2) else global_kvcache_shape,
-                     "float32",
-                 )
-                 for i in range(num_hidden_layers * 2)
+     def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
+         sliding_window = getattr(model_config, "sliding_window", None)
+         sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+         if sliding_window_pattern <= model_config.num_hidden_layers:
+             rbln_config.cache_impl = "hybrid"
+             rbln_config.sliding_window = sliding_window
+             rbln_config.sliding_window_layers = [
+                 i for i in range(model_config.num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
             ]
-         )

-         return input_info
+         return rbln_config

      @classmethod
      def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
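
The replacement hook derives the hybrid cache layout from the Hugging Face config instead of hand-building input shapes: every layer whose 1-based index is not a multiple of `sliding_window_pattern` stays on the sliding window, and the rest remain global. A toy illustration of the index selection (the layer count and pattern here are illustrative, not Gemma3's actual values):

```python
# Sliding-window layer selection as in _update_sliding_window_config above.
num_hidden_layers = 12
sliding_window_pattern = 6

sliding_window_layers = [
    i for i in range(num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
]
print(sliding_window_layers)  # [0, 1, 2, 3, 4, 6, 7, 8, 9, 10]; layers 5 and 11 stay global
```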
@@ -846,102 +800,18 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
          model_config: Optional["PretrainedConfig"] = None,
          rbln_config: Optional[RBLNGemma3ForCausalLMConfig] = None,
      ) -> RBLNGemma3ForCausalLMConfig:
-         if rbln_config.max_seq_len is None:
-             rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None)
-             if rbln_config.max_seq_len is None:
-                 raise ValueError("`max_seq_len` should be specified.")
-
-         rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
-             attn_impl=rbln_config.attn_impl,
-             kvcache_partition_len=rbln_config.kvcache_partition_len,
-             kvcache_block_size=rbln_config.kvcache_block_size,
-             max_seq_len=rbln_config.max_seq_len,
-         )
-
-         validate_attention_method(
-             attn_impl=rbln_config.attn_impl,
-             kvcache_partition_len=rbln_config.kvcache_partition_len,
-             kvcache_block_size=rbln_config.kvcache_block_size,
-             max_seq_len=rbln_config.max_seq_len,
-         )
+         # Update rbln_config with super class
+         rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)

-         required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-         max_num_blocks = required_num_blocks
-
-         if rbln_config.attn_impl == "flash_attn":
-             flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
-             if max_num_blocks < flash_min_blocks:
-                 max_num_blocks = flash_min_blocks
-
-         if max_num_blocks < rbln_config.batch_size:
-             raise RuntimeError(
-                 f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
-                 "Ensure the number of blocks is at least equal to the batch size."
-             )
-
-         if rbln_config.kvcache_num_blocks is None:
-             rbln_config.kvcache_num_blocks = max_num_blocks
-         elif rbln_config.kvcache_num_blocks > max_num_blocks:
-             logger.warning(
-                 f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                 f" than the estimated maximum number of blocks ({max_num_blocks})."
-                 "This can cause a failure during model compilation."
-             )
-         logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-
-         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
-         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
-         head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
-         sliding_window = getattr(model_config, "sliding_window", None)
-         sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
-
-         prefill_input_info = cls.get_input_info(
-             batch_size=1,
-             query_length=rbln_config.prefill_chunk_size,
-             use_inputs_embeds=rbln_config.use_inputs_embeds,
-             use_attention_mask=rbln_config.use_attention_mask,
-             use_position_ids=rbln_config.use_position_ids,
-             max_seq_len=rbln_config.max_seq_len,
-             kvcache_block_size=rbln_config.kvcache_block_size,
-             kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-             num_key_value_heads=num_key_value_heads,
-             num_hidden_layers=num_hidden_layers,
-             hidden_size=hidden_size,
-             head_dim=head_dim,
-             sliding_window=sliding_window,
-             sliding_window_pattern=sliding_window_pattern,
-             dec_batch_size=max(rbln_config.decoder_batch_sizes),
-         )
-         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+         # Assume that prefill compile config is at index 0
+         compile_cfgs = rbln_config.compile_cfgs
          image_prefill_compile_config = RBLNCompileConfig(
-             compiled_model_name="image_prefill", input_info=prefill_input_info
+             compiled_model_name="image_prefill", input_info=compile_cfgs[0].input_info
          )
-
-         dec_compile_configs = []
-         for batch_size in rbln_config.decoder_batch_sizes:
-             dec_input_info = cls.get_input_info(
-                 batch_size=batch_size,
-                 query_length=1,
-                 use_inputs_embeds=rbln_config.use_inputs_embeds,
-                 use_attention_mask=rbln_config.use_attention_mask,
-                 use_position_ids=rbln_config.use_position_ids,
-                 max_seq_len=rbln_config.max_seq_len,
-                 kvcache_block_size=rbln_config.kvcache_block_size,
-                 kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-                 num_key_value_heads=num_key_value_heads,
-                 num_hidden_layers=num_hidden_layers,
-                 hidden_size=hidden_size,
-                 head_dim=head_dim,
-                 sliding_window=sliding_window,
-                 sliding_window_pattern=sliding_window_pattern,
-                 dec_batch_size=batch_size,
-             )
-             dec_compile_configs.append(
-                 RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
-             )
-         rbln_config.set_compile_cfgs([prefill_compile_config, image_prefill_compile_config, *dec_compile_configs])
+         # Insert image_prefill compile config at index 1
+         image_idx = 1
+         compile_cfgs.insert(image_idx, image_prefill_compile_config)
+         rbln_config.set_compile_cfgs(compile_cfgs)

          return rbln_config

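The compile-config surgery above now reuses the prefill entry's `input_info` for the `image_prefill` variant and slots it in at index 1, keeping the decoder configs behind it. A hedged sketch with `CompileCfg` as a hypothetical stand-in for `RBLNCompileConfig`:

```python
# Sketch of the insert-at-index-1 pattern; CompileCfg is a stand-in, not the real class.
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class CompileCfg:
    compiled_model_name: str
    input_info: List[Tuple]


cfgs = [
    CompileCfg("prefill", [("input_ids", [1, 128], "int64")]),
    CompileCfg("decoder_batch_1", [("input_ids", [1, 1], "int64")]),
]
image_prefill = CompileCfg("image_prefill", cfgs[0].input_info)  # share prefill shapes
cfgs.insert(1, image_prefill)
print([c.compiled_model_name for c in cfgs])
# ['prefill', 'image_prefill', 'decoder_batch_1']
```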
optimum/rbln/transformers/models/gpt2/configuration_gpt2.py
@@ -16,4 +16,7 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausa


  class RBLNGPT2LMHeadModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
-     pass
+     """
+     Configuration class for GPT-2 causal language model.
+     Inherits from RBLNDecoderOnlyModelForCausalLMConfig with no additional parameters.
+     """
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py
@@ -45,7 +45,12 @@ class GPT2Wrapper(DecoderOnlyWrapper):
              )
              new_layer = GPT2Layer(layer, new_self_attn)
              new_layers.append(new_layer)
-         new_model = GPT2Model(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
+         new_model = GPT2Model(
+             causal_lm.transformer,
+             new_layers,
+             max_seq_len=max_seq_len,
+             sliding_window_layers=self.sliding_window_layers,
+         )
          new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
          return new_causal_lm

optimum/rbln/transformers/models/idefics3/configuration_idefics3.py
@@ -12,7 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Optional
+ from typing import Any, Dict, Optional

  from ....configuration_utils import RBLNModelConfig

@@ -22,6 +22,16 @@ class RBLNIdefics3VisionTransformerConfig(RBLNModelConfig):


  class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
+     """
+     Configuration class for RBLNIdefics3ForConditionalGeneration models.
+
+     This class extends `RBLNModelConfig` to include settings specific to the Idefics3 vision-language model optimized for RBLN devices.
+     It allows configuration of the batch size and separate configurations for the vision and text submodules.
+
+     Attributes:
+         submodules (List[str]): List of submodules included in the model. Defaults to `["vision_model", "text_model"]`.
+     """
+
      submodules = ["vision_model", "text_model"]

      def __init__(
@@ -29,7 +39,7 @@ class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
          batch_size: Optional[int] = None,
          vision_model: Optional[RBLNModelConfig] = None,
          text_model: Optional[RBLNModelConfig] = None,
-         **kwargs,
+         **kwargs: Dict[str, Any],
      ):
          """
          Args:
optimum/rbln/transformers/models/idefics3/modeling_idefics3.py
@@ -102,10 +102,9 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
          subfolder: str,
          rbln_config: RBLNModelConfig,
      ):
-         """
-         If you are unavoidably running on a CPU rather than an RBLN device,
-         store the torch tensor, weight, etc. in this function.
-         """
+         # If you are unavoidably running on a CPU rather than an RBLN device,
+         # store the torch tensor, weight, etc. in this function.
+
          save_dict = {}
          save_dict["embeddings"] = model.get_input_embeddings().state_dict()
          torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
@@ -190,6 +189,44 @@ class RBLNIdefics3VisionTransformer(RBLNModel):


  class RBLNIdefics3ForConditionalGeneration(RBLNModel):
+     """
+     RBLNIdefics3ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+     Important Note:
+         This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
+         tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+         `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNIdefics3ForConditionalGenerationConfig class for details.
+
+     Examples:
+         ```python
+         from optimum.rbln import RBLNIdefics3ForConditionalGeneration
+
+         model = RBLNIdefics3ForConditionalGeneration.from_pretrained(
+             "HuggingFaceM4/idefics3-8b",
+             export=True,
+             rbln_config={
+                 "vision_model": {
+                     "device": 0,
+                 },
+                 "text_model": {
+                     "batch_size": 1,
+                     "max_seq_len": 131_072,
+                     "tensor_parallel_size": 8,
+                     "use_inputs_embeds": True,
+                     "attn_impl": "flash_attn",
+                     "kvcache_partition_len": 16_384,
+                     "device": [0, 1, 2, 3, 4, 5, 6, 7],
+                 },
+             },
+         )
+
+         model.save_pretrained("compiled-idefics3-8b")
+         ```
+     """
+
      auto_model_class = AutoModelForVision2Seq
      _rbln_submodules = [{"name": "vision_model"}, {"name": "text_model"}]
      _rbln_submodule_prefix = "model"