optimum-rbln 0.8.1a0__py3-none-any.whl → 0.8.1a1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
Files changed (127)
  1. optimum/rbln/__init__.py +2 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +53 -33
  4. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  5. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
  11. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
  15. optimum/rbln/diffusers/modeling_diffusers.py +16 -26
  16. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
  17. optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
  18. optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
  19. optimum/rbln/diffusers/models/controlnet.py +13 -7
  20. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  21. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
  23. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  24. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  25. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  26. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  28. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  29. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  30. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  31. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  33. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  34. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  35. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  36. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  38. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  42. optimum/rbln/modeling.py +33 -35
  43. optimum/rbln/modeling_base.py +45 -107
  44. optimum/rbln/transformers/__init__.py +39 -47
  45. optimum/rbln/transformers/configuration_generic.py +16 -13
  46. optimum/rbln/transformers/modeling_generic.py +18 -19
  47. optimum/rbln/transformers/modeling_rope_utils.py +1 -1
  48. optimum/rbln/transformers/models/__init__.py +46 -4
  49. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  52. optimum/rbln/transformers/models/auto/auto_factory.py +30 -12
  53. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
  54. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  55. optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
  56. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  57. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  58. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +229 -175
  59. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  60. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
  61. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
  62. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  63. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  64. optimum/rbln/transformers/models/exaone/modeling_exaone.py +51 -5
  65. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  66. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  67. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  68. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  69. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  70. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +88 -236
  71. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  72. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  73. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  74. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  75. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  76. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  77. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  78. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +33 -4
  79. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  80. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  81. optimum/rbln/transformers/models/midm/modeling_midm.py +51 -5
  82. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  83. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  84. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  85. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  86. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  87. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  88. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  89. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  90. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  91. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  92. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
  93. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +47 -26
  94. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -2
  95. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  96. optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
  97. optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
  98. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  99. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
  100. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  102. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
  104. optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
  105. optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
  106. optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
  107. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  108. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
  109. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
  110. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  111. optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
  112. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  114. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  115. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
  116. optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
  117. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  118. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  119. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  120. optimum/rbln/utils/model_utils.py +20 -0
  121. optimum/rbln/utils/submodule.py +6 -8
  122. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a1.dist-info}/METADATA +1 -1
  123. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a1.dist-info}/RECORD +127 -114
  124. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  125. /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
  126. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a1.dist-info}/WHEEL +0 -0
  127. {optimum_rbln-0.8.1a0.dist-info → optimum_rbln-0.8.1a1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py

@@ -32,10 +32,6 @@ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbed
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from ..decoderonly.decoderonly_architecture import (
-    set_default_values,
-    validate_attention_method,
-)
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput, RBLNRuntimeModel
 from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
@@ -215,15 +211,16 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
 
         return model_kwargs
 
-    def get_image_features(self, pixel_values: torch.Tensor):
+    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
         """
         Projects the last hidden state from the vision model into language model space.
 
         Args:
-            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
-                The tensors corresponding to the input images.
+            pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
+                The tensors corresponding to the input images.
+
         Returns:
-            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
+            Image feature tensor of shape `(num_images, image_length, embed_dim)`.
         """
         vision_outputs = self.vision_tower(pixel_values).last_hidden_state
         image_features = self.multi_modal_projector(vision_outputs)
@@ -272,7 +269,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
-        **lm_kwargs,
+        **lm_kwargs: Dict[str, Any],
     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
         # prefill
         if cache_position is None:
@@ -352,16 +349,17 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         # Find image start positions
         image_starts = [
             s
-            for s in range(seq_len - self.prefill_chunk_size + 1)
-            if torch.all(token_type_ids[:, s : s + self.prefill_chunk_size] == 1)
+            for s in range(seq_len - self.rbln_config.prefill_chunk_size + 1)
+            if torch.all(token_type_ids[:, s : s + self.rbln_config.prefill_chunk_size] == 1)
         ]
 
         # Initialize padded tensors
         padded_input_len = seq_len
         for image_start in image_starts:
             pad_needed = (
-                self.prefill_chunk_size - (image_start + padded_input_len - seq_len) % self.prefill_chunk_size
-            ) % self.prefill_chunk_size
+                self.rbln_config.prefill_chunk_size
+                - (image_start + padded_input_len - seq_len) % self.rbln_config.prefill_chunk_size
+            ) % self.rbln_config.prefill_chunk_size
             padded_input_len += pad_needed
         total_padding = padded_input_len - seq_len
 
@@ -390,7 +388,9 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             src_pos = image_start
 
             # Padding
-            pad_needed = (self.prefill_chunk_size - dest_pos % self.prefill_chunk_size) % self.prefill_chunk_size
+            pad_needed = (
+                self.rbln_config.prefill_chunk_size - dest_pos % self.rbln_config.prefill_chunk_size
+            ) % self.rbln_config.prefill_chunk_size
             if pad_needed and dest_pos < padded_input_len:
                 position_ids_padded[:, dest_pos : dest_pos + pad_needed] = torch.arange(
                     last_pos_id + 1, last_pos_id + pad_needed + 1, dtype=position_ids.dtype
@@ -399,21 +399,21 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
 
             # Image segment
             if src_pos < seq_len and src_pos == image_start:
-                inputs_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = inputs[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                inputs_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = inputs[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                attention_mask_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = attention_mask[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                attention_mask_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = attention_mask[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                position_ids_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = position_ids[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                position_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = position_ids[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                token_type_ids_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = token_type_ids[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                token_type_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = token_type_ids[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                dest_pos += self.prefill_chunk_size
-                src_pos += self.prefill_chunk_size
-                last_pos_id = position_ids[0, image_start + self.prefill_chunk_size - 1].item()
+                dest_pos += self.rbln_config.prefill_chunk_size
+                src_pos += self.rbln_config.prefill_chunk_size
+                last_pos_id = position_ids[0, image_start + self.rbln_config.prefill_chunk_size - 1].item()
 
         return inputs_padded, attention_mask_padded, position_ids_padded, total_padding, token_type_ids_padded
 
@@ -444,11 +444,13 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
 
         seq_len = inputs.shape[1]
         # Initialize attention mask for chunked processing
-        if self.use_attention_mask:
+        if self.rbln_config.use_attention_mask:
             chunked_attention_mask = (
                 torch.ones(1, seq_len, dtype=torch.float32)
-                if self.use_position_ids
-                else torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+                if self.rbln_config.use_position_ids
+                else torch.zeros(
+                    1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
+                )
             )
         else:
             chunked_attention_mask = None
@@ -467,20 +469,21 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         )
 
         query_length = inputs.shape[1]
-        if query_length > self.max_seq_len:
+        if query_length > self.rbln_config.max_seq_len:
             raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
             )
 
         # Align attention_mask to compiled shape
-        if self.use_position_ids:
+        if self.rbln_config.use_position_ids:
             chunked_attention_mask = torch.nn.functional.pad(
-                chunked_attention_mask, (0, self.max_seq_len - query_length)
+                chunked_attention_mask, (0, self.rbln_config.max_seq_len - query_length)
             )
 
         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        if query_length % self.prefill_chunk_size != 0:
-            padding_size = self.prefill_chunk_size - query_length % self.prefill_chunk_size
+        padding_size = 0
+        if query_length % self.rbln_config.prefill_chunk_size != 0:
+            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
             # inputs_embeds
             if inputs.dim() == 3:
                 inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
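
Worth noting in the hunk above: `padding_size` was previously assigned only inside the `if` branch, so the new code initializes it to `0` first, keeping later padding operations well defined when `query_length` is already chunk-aligned. A standalone sketch of the alignment arithmetic (the chunk size of 128 is a hypothetical value, not taken from the package):

```python
# Standalone sketch of the chunk-alignment arithmetic in the hunk above.
# prefill_chunk_size = 128 is a hypothetical value for illustration.
prefill_chunk_size = 128

def padding_for(query_length: int) -> int:
    """Padding needed to round query_length up to a chunk boundary."""
    padding_size = 0
    if query_length % prefill_chunk_size != 0:
        padding_size = (prefill_chunk_size - query_length) % prefill_chunk_size
    return padding_size

assert padding_for(100) == 28  # 100 -> 128
assert padding_for(128) == 0   # already aligned, no padding
assert padding_for(300) == 84  # 300 -> 384
```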
@@ -548,65 +551,71 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         ) = self._prepare_prefill_inputs(
             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
         )
-        self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask[:1]
         if not is_external_block_tables:
             local_block_tables = torch.tensor([batch_idx], dtype=torch.int16)
+            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask[:1]
 
-        if self.use_attention_mask and self.use_position_ids:
-            chunked_attention_mask = torch.zeros(1, self.max_seq_len, dtype=torch.float32)
+        if self.rbln_config.use_attention_mask and self.rbln_config.use_position_ids:
+            chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=torch.float32)
 
         # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.prefill_chunk_size):
+        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
             # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
             position_ids_chunk = (
-                position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
+                if position_ids is not None
+                else None
             )
 
             # Not used in Gemma3 yet.
-            if self.use_attention_mask:
-                if self.use_position_ids:
-                    chunked_attention_mask[0, step : step + self.prefill_chunk_size] = self.dec_attn_mask[
-                        batch_idx, step : step + self.prefill_chunk_size
+            if self.rbln_config.use_attention_mask:
+                if self.rbln_config.use_position_ids:
+                    chunked_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size] = self.dec_attn_mask[
+                        batch_idx, step : step + self.rbln_config.prefill_chunk_size
                     ]
                 else:
                     # Update attention mask to ensure proper causal behavior
-                    if step >= self.prefill_chunk_size:
-                        chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-                    chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                    if step >= self.rbln_config.prefill_chunk_size:
+                        chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
+                    chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = (
+                        self.causal_mask
+                    )
 
             # Define query position
             query_position = (
                 torch.sum(
-                    chunked_attention_mask[0][step : step + self.prefill_chunk_size], dim=-1, dtype=torch.int16
+                    chunked_attention_mask[0][step : step + self.rbln_config.prefill_chunk_size],
+                    dim=-1,
+                    dtype=torch.int16,
                 ).squeeze(0)
                 - 1
             )
             if token_type_ids_padded[:, step] == 1:
-                if torch.any(token_type_ids_padded[:, step : step + self.prefill_chunk_size] == 0):
+                if torch.any(token_type_ids_padded[:, step : step + self.rbln_config.prefill_chunk_size] == 0):
                     raise ValueError("All tokens of image_prefill should be the same image.")
                 else:
                     logits = self.image_prefill(
                         input_chunk,
-                        chunked_attention_mask,
                         cache_pos_chunk,
-                        position_ids_chunk,
-                        query_position,
                         block_tables,
                         local_block_tables,
+                        query_position,
+                        chunked_attention_mask,
+                        position_ids_chunk,
                         out=out_buffers,
                     )
             else:
                 # Forward pass for the current chunk
                 logits = self.prefill(
                     input_chunk,
-                    chunked_attention_mask,
                     cache_pos_chunk,
-                    position_ids_chunk,
-                    query_position,
                     block_tables,
                     local_block_tables,
+                    query_position,
+                    chunked_attention_mask,
+                    position_ids_chunk,
                     out=out_buffers,
                 )
 
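For context on the loop above: prefill walks the padded input in fixed-size chunks so every NPU call matches the compiled static shape, and chunks whose `token_type_ids` are all 1 (image tokens) are dispatched to the separate `image_prefill` runtime. A minimal sketch of that dispatch pattern, with hypothetical stand-in functions for the two runtimes:

```python
# Minimal sketch of the chunked-prefill dispatch above; run_text_chunk and
# run_image_chunk are hypothetical stand-ins for self.prefill / self.image_prefill.
import torch

prefill_chunk_size = 4  # hypothetical; the real value lives in rbln_config

def run_text_chunk(chunk: torch.Tensor) -> None:
    print("text chunk: ", chunk.tolist())

def run_image_chunk(chunk: torch.Tensor) -> None:
    print("image chunk:", chunk.tolist())

inputs = torch.arange(12).unsqueeze(0)  # [1, 12], already chunk-aligned
token_type_ids = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0]])

for step in range(0, inputs.shape[1], prefill_chunk_size):
    chunk = inputs[:, step : step + prefill_chunk_size]
    ttids = token_type_ids[:, step : step + prefill_chunk_size]
    if ttids[0, 0] == 1:
        # A chunk that starts with an image token must contain only image tokens.
        assert torch.all(ttids == 1), "All tokens of image_prefill should be the same image."
        run_image_chunk(chunk)
    else:
        run_text_chunk(chunk)
```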
@@ -646,7 +655,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             if local_block_tables is not None
             else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
         )
-        if self.use_attention_mask and attention_mask is None:
+        if self.rbln_config.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
@@ -663,7 +672,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
             attention_mask = attention_mask[: self.batch_size]
 
-        logits = self.decode(inputs, attention_mask, cache_position, position_ids, block_tables, local_block_tables)
+        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
 
         return RBLNDecoderOnlyOutput(logits=logits)
 
@@ -700,7 +709,6 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             dtype=torch.int16,
         ).fill_(-1)
         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
-
         self.prefill_decoder = RBLNGemma3RuntimeModel(
             runtime=self.model[0],
             image_prefill=self.model[1],
@@ -710,14 +718,9 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
-            free_block_pool=free_block_pool,
-            kvcache_block_size=self.rbln_config.kvcache_block_size,
             vocab_size=self.config.vocab_size,
-            prefill_chunk_size=self.rbln_config.prefill_chunk_size,
-            max_seq_len=self.rbln_config.max_seq_len,
-            use_attention_mask=self.rbln_config.use_attention_mask,
-            attn_impl=self.rbln_config.attn_impl,
-            use_position_ids=self.rbln_config.use_position_ids,
+            free_block_pool=free_block_pool,
+            rbln_config=self.rbln_config,
         )
 
         self.decoders = {}
@@ -731,10 +734,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
                 dec_attn_mask=dec_attn_mask,
                 block_tables=block_tables,
                 free_block_pool=free_block_pool,
-                kvcache_block_size=self.rbln_config.kvcache_block_size,
-                use_attention_mask=self.rbln_config.use_attention_mask,
-                attn_impl=self.rbln_config.attn_impl,
-                use_position_ids=self.rbln_config.use_position_ids,
+                rbln_config=self.rbln_config,
            )
 
         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
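
The two constructor changes above are one refactor: instead of unpacking `kvcache_block_size`, `use_attention_mask`, `attn_impl`, and friends into separate keyword arguments, the runtime now receives the whole `rbln_config` and reads settings from it (as seen throughout the earlier hunks, e.g. `self.rbln_config.prefill_chunk_size`). A generic sketch of the pattern, with hypothetical names rather than the actual optimum-rbln classes:

```python
# Generic sketch of the "pass the whole config" refactor (hypothetical names).
from dataclasses import dataclass

@dataclass
class RuntimeConfig:
    prefill_chunk_size: int = 128
    max_seq_len: int = 4096
    use_attention_mask: bool = False
    use_position_ids: bool = False

class Runtime:
    # Before: __init__(self, prefill_chunk_size, max_seq_len, use_attention_mask, ...)
    # After: one config object; adding a setting no longer changes the signature.
    def __init__(self, rbln_config: RuntimeConfig):
        self.rbln_config = rbln_config

    def num_chunks(self, query_length: int) -> int:
        chunk = self.rbln_config.prefill_chunk_size
        return (query_length + chunk - 1) // chunk  # ceiling division

runtime = Runtime(RuntimeConfig())
assert runtime.num_chunks(300) == 3
```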
@@ -751,81 +751,17 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         return embed_tokens
 
     @classmethod
-    def get_input_info(
-        cls,
-        batch_size: int,
-        query_length: int,
-        use_inputs_embeds: bool,
-        use_attention_mask: bool,
-        use_position_ids: bool,
-        max_seq_len: int,
-        kvcache_block_size: int,
-        kvcache_num_blocks: int,
-        num_key_value_heads: int,
-        num_hidden_layers: int,
-        hidden_size: int,
-        head_dim: int,
-        sliding_window: int,
-        sliding_window_pattern: int,
-        dec_batch_size: int,
-    ):
-        if use_inputs_embeds:
-            main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
-        else:
-            main_input = ("input_ids", [batch_size, query_length], "int64")
-
-        input_info = [
-            main_input,
-            (
-                "attention_mask",
-                [batch_size, 1, query_length, max_seq_len] if not use_position_ids else [batch_size, max_seq_len],
-                "float32",
-            ),
-            (
-                "cache_position",
-                [batch_size, query_length],
-                "int32",
-            ),
-            (
-                "position_ids",
-                [batch_size, query_length],
-                "int32",
-            ),
-        ]
-
-        if query_length > 1:
-            input_info.extend(
-                [
-                    ("query_position", [], "int16"),
-                ]
-            )
-
-        max_block_cnt = max_seq_len // kvcache_block_size
-
-        if query_length > 1:
-            input_info.extend([("global_block_tables", [max_block_cnt], "int16")])
-            input_info.extend([("local_block_tables", [1], "int16")])
-        else:
-            input_info.extend([("global_block_tables", [batch_size, max_block_cnt], "int16")])
-            input_info.extend([("local_block_tables", [batch_size, 1], "int16")])
-
-        def is_sliding(layer_idx: int) -> bool:
-            return bool((layer_idx + 1) % sliding_window_pattern)
-
-        local_kvcache_shape = [dec_batch_size, num_key_value_heads, sliding_window, head_dim]
-        global_kvcache_shape = [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim]
-        input_info.extend(
-            [
-                (
-                    f"past_key_values_{i}",
-                    local_kvcache_shape if is_sliding(i // 2) else global_kvcache_shape,
-                    "float32",
-                )
-                for i in range(num_hidden_layers * 2)
+    def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
+        sliding_window = getattr(model_config, "sliding_window", None)
+        sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+        if sliding_window_pattern <= model_config.num_hidden_layers:
+            rbln_config.cache_impl = "hybrid"
+            rbln_config.sliding_window = sliding_window
+            rbln_config.sliding_window_layers = [
+                i for i in range(model_config.num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
             ]
-        )
 
-        return input_info
+        return rbln_config
 
     @classmethod
     def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
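
The replacement `_update_sliding_window_config` marks every layer as sliding-window except each `sliding_window_pattern`-th one, which stays global. A worked example of the index arithmetic, using hypothetical values (`num_hidden_layers=12`, `sliding_window_pattern=6`; the real numbers come from the Hugging Face model config):

```python
# Worked example of the sliding-window layer selection above
# (num_hidden_layers=12 and sliding_window_pattern=6 are hypothetical values).
num_hidden_layers = 12
sliding_window_pattern = 6

sliding_window_layers = [
    i for i in range(num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
]

# Layers 5 and 11 (the 6th and 12th) keep global attention; the rest slide.
assert sliding_window_layers == [0, 1, 2, 3, 4, 6, 7, 8, 9, 10]
```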
@@ -846,102 +782,18 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         model_config: Optional["PretrainedConfig"] = None,
         rbln_config: Optional[RBLNGemma3ForCausalLMConfig] = None,
     ) -> RBLNGemma3ForCausalLMConfig:
-        if rbln_config.max_seq_len is None:
-            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None)
-            if rbln_config.max_seq_len is None:
-                raise ValueError("`max_seq_len` should be specified.")
-
-        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
-
-        validate_attention_method(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
-
-        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-        max_num_blocks = required_num_blocks
+        # Update rbln_config with super class
+        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
 
-        if rbln_config.attn_impl == "flash_attn":
-            flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
-            if max_num_blocks < flash_min_blocks:
-                max_num_blocks = flash_min_blocks
-
-        if max_num_blocks < rbln_config.batch_size:
-            raise RuntimeError(
-                f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
-                "Ensure the number of blocks is at least equal to the batch size."
-            )
-
-        if rbln_config.kvcache_num_blocks is None:
-            rbln_config.kvcache_num_blocks = max_num_blocks
-        elif rbln_config.kvcache_num_blocks > max_num_blocks:
-            logger.warning(
-                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                f" than the estimated maximum number of blocks ({max_num_blocks})."
-                "This can cause a failure during model compilation."
-            )
-        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-
-        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
-        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
-        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
-        sliding_window = getattr(model_config, "sliding_window", None)
-        sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
-
-        prefill_input_info = cls.get_input_info(
-            batch_size=1,
-            query_length=rbln_config.prefill_chunk_size,
-            use_inputs_embeds=rbln_config.use_inputs_embeds,
-            use_attention_mask=rbln_config.use_attention_mask,
-            use_position_ids=rbln_config.use_position_ids,
-            max_seq_len=rbln_config.max_seq_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-            num_key_value_heads=num_key_value_heads,
-            num_hidden_layers=num_hidden_layers,
-            hidden_size=hidden_size,
-            head_dim=head_dim,
-            sliding_window=sliding_window,
-            sliding_window_pattern=sliding_window_pattern,
-            dec_batch_size=max(rbln_config.decoder_batch_sizes),
-        )
-        prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+        # Assume that prefill compile config is at index 0
+        compile_cfgs = rbln_config.compile_cfgs
         image_prefill_compile_config = RBLNCompileConfig(
-            compiled_model_name="image_prefill", input_info=prefill_input_info
+            compiled_model_name="image_prefill", input_info=compile_cfgs[0].input_info
         )
-
-        dec_compile_configs = []
-        for batch_size in rbln_config.decoder_batch_sizes:
-            dec_input_info = cls.get_input_info(
-                batch_size=batch_size,
-                query_length=1,
-                use_inputs_embeds=rbln_config.use_inputs_embeds,
-                use_attention_mask=rbln_config.use_attention_mask,
-                use_position_ids=rbln_config.use_position_ids,
-                max_seq_len=rbln_config.max_seq_len,
-                kvcache_block_size=rbln_config.kvcache_block_size,
-                kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-                num_key_value_heads=num_key_value_heads,
-                num_hidden_layers=num_hidden_layers,
-                hidden_size=hidden_size,
-                head_dim=head_dim,
-                sliding_window=sliding_window,
-                sliding_window_pattern=sliding_window_pattern,
-                dec_batch_size=batch_size,
-            )
-            dec_compile_configs.append(
-                RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
-            )
-        rbln_config.set_compile_cfgs([prefill_compile_config, image_prefill_compile_config, *dec_compile_configs])
+        # Insert image_prefill compile config at index 1
+        image_idx = 1
+        compile_cfgs.insert(image_idx, image_prefill_compile_config)
+        rbln_config.set_compile_cfgs(compile_cfgs)
 
         return rbln_config
 
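Net effect of the rewritten `_update_rbln_config`: the base class now builds the `prefill` and per-batch decoder compile configs, and this override only clones the prefill input info into an extra `image_prefill` config at index 1. A sketch of the resulting ordering, with plain strings standing in for `RBLNCompileConfig` objects (the decoder batch sizes shown are hypothetical):

```python
# Sketch of the compile-config ordering after the override above; strings stand
# in for RBLNCompileConfig objects, and the decoder batch sizes are hypothetical.
compile_cfgs = ["prefill", "decoder_batch_1", "decoder_batch_4"]  # built by the base class

image_idx = 1  # image_prefill sits right after prefill
compile_cfgs.insert(image_idx, "image_prefill")

# Matches the runtime lookup earlier in this diff: self.model[0] is prefill,
# self.model[1] is image_prefill.
assert compile_cfgs == ["prefill", "image_prefill", "decoder_batch_1", "decoder_batch_4"]
```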
optimum/rbln/transformers/models/gpt2/configuration_gpt2.py

@@ -16,4 +16,7 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausa
 
 
 class RBLNGPT2LMHeadModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
-    pass
+    """
+    Configuration class for GPT-2 causal language model.
+    Inherits from RBLNDecoderOnlyModelForCausalLMConfig with no additional parameters.
+    """
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py

@@ -45,7 +45,12 @@ class GPT2Wrapper(DecoderOnlyWrapper):
             )
             new_layer = GPT2Layer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = GPT2Model(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
+        new_model = GPT2Model(
+            causal_lm.transformer,
+            new_layers,
+            max_seq_len=max_seq_len,
+            sliding_window_layers=self.sliding_window_layers,
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 
optimum/rbln/transformers/models/idefics3/configuration_idefics3.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Any, Dict, Optional
 from ....configuration_utils import RBLNModelConfig
 
 
@@ -22,6 +22,16 @@ class RBLNIdefics3VisionTransformerConfig(RBLNModelConfig):
 
 
 class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNIdefics3ForConditionalGeneration models.
+
+    This class extends `RBLNModelConfig` to include settings specific to the Idefics3 vision-language model optimized for RBLN devices.
+    It allows configuration of the batch size and separate configurations for the vision and text submodules.
+
+    Attributes:
+        submodules (List[str]): List of submodules included in the model. Defaults to `["vision_model", "text_model"]`.
+    """
+
     submodules = ["vision_model", "text_model"]
 
     def __init__(
@@ -29,7 +39,7 @@ class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         vision_model: Optional[RBLNModelConfig] = None,
         text_model: Optional[RBLNModelConfig] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
optimum/rbln/transformers/models/idefics3/modeling_idefics3.py

@@ -102,10 +102,9 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
         subfolder: str,
         rbln_config: RBLNModelConfig,
     ):
-        """
-        If you are unavoidably running on a CPU rather than an RBLN device,
-        store the torch tensor, weight, etc. in this function.
-        """
+        # If you are unavoidably running on a CPU rather than an RBLN device,
+        # store the torch tensor, weight, etc. in this function.
+
         save_dict = {}
         save_dict["embeddings"] = model.get_input_embeddings().state_dict()
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
@@ -190,6 +189,44 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
 
 
 class RBLNIdefics3ForConditionalGeneration(RBLNModel):
+    """
+    RBLNIdefics3ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
+        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNIdefics3ForConditionalGenerationConfig class for details.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNIdefics3ForConditionalGeneration
+
+        model = RBLNIdefics3ForConditionalGeneration.from_pretrained(
+            "HuggingFaceM4/idefics3-8b",
+            export=True,
+            rbln_config={
+                "vision_model": {
+                    "device": 0,
+                },
+                "text_model": {
+                    "batch_size": 1,
+                    "max_seq_len": 131_072,
+                    "tensor_parallel_size": 8,
+                    "use_inputs_embeds": True,
+                    "attn_impl": "flash_attn",
+                    "kvcache_partition_len": 16_384,
+                    "device": [0, 1, 2, 3, 4, 5, 6, 7],
+                },
+            },
+        )
+
+        model.save_pretrained("compiled-idefics3-8b")
+        ```
+    """
+
     auto_model_class = AutoModelForVision2Seq
     _rbln_submodules = [{"name": "vision_model"}, {"name": "text_model"}]
     _rbln_submodule_prefix = "model"
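
The new docstring example compiles and saves the model. For completeness, a compiled directory can later be reloaded without recompiling; a minimal sketch, assuming the usual optimum-rbln convention that `export=False` loads already-compiled artifacts:

```python
# Reloading the compiled artifacts produced by the docstring example above.
# Assumes export=False loads a precompiled directory without recompiling.
from optimum.rbln import RBLNIdefics3ForConditionalGeneration

model = RBLNIdefics3ForConditionalGeneration.from_pretrained(
    "compiled-idefics3-8b",
    export=False,
)
```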
optimum/rbln/transformers/models/llama/configuration_llama.py

@@ -16,4 +16,27 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausa
 
 
 class RBLNLlamaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
-    pass
+    """
+    Configuration class for RBLN Llama models.
+
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig
+
+    # Create a configuration object
+    config = RBLNLlamaForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=4096,
+        tensor_parallel_size=4
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNLlamaForCausalLM.from_pretrained(
+        "meta-llama/Llama-2-7b-hf",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """