optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1a1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
Files changed (127)
  1. optimum/rbln/__init__.py +2 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +45 -33
  4. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  5. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +33 -9
  11. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -12
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +22 -6
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +16 -6
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +16 -6
  15. optimum/rbln/diffusers/modeling_diffusers.py +16 -26
  16. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +11 -0
  17. optimum/rbln/diffusers/models/autoencoders/vae.py +1 -8
  18. optimum/rbln/diffusers/models/autoencoders/vq_model.py +11 -0
  19. optimum/rbln/diffusers/models/controlnet.py +13 -7
  20. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  21. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +7 -0
  23. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  24. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  25. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  26. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  28. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  29. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  30. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  31. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  32. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  33. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  34. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  35. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  36. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  37. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  38. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  39. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  42. optimum/rbln/modeling.py +33 -35
  43. optimum/rbln/modeling_base.py +45 -107
  44. optimum/rbln/transformers/__init__.py +39 -47
  45. optimum/rbln/transformers/configuration_generic.py +16 -13
  46. optimum/rbln/transformers/modeling_generic.py +18 -19
  47. optimum/rbln/transformers/modeling_rope_utils.py +1 -1
  48. optimum/rbln/transformers/models/__init__.py +46 -4
  49. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  52. optimum/rbln/transformers/models/auto/auto_factory.py +30 -12
  53. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +35 -4
  54. optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
  55. optimum/rbln/transformers/models/clip/modeling_clip.py +11 -12
  56. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  57. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  58. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +231 -175
  59. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  60. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +19 -0
  61. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +19 -0
  62. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  63. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  64. optimum/rbln/transformers/models/exaone/modeling_exaone.py +51 -5
  65. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  66. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  67. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  68. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  69. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  70. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +87 -236
  71. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  72. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  73. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  74. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  75. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  76. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  77. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
  78. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +33 -4
  79. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  80. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  81. optimum/rbln/transformers/models/midm/modeling_midm.py +51 -5
  82. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  83. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  84. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  85. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  86. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  87. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  88. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  89. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  90. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  91. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  92. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +15 -3
  93. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +46 -25
  94. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -2
  95. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  96. optimum/rbln/transformers/models/resnet/configuration_resnet.py +20 -0
  97. optimum/rbln/transformers/models/resnet/modeling_resnet.py +22 -0
  98. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  99. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +4 -30
  100. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +2 -32
  101. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  102. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  103. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -1
  104. optimum/rbln/transformers/models/siglip/configuration_siglip.py +3 -0
  105. optimum/rbln/transformers/models/siglip/modeling_siglip.py +62 -21
  106. optimum/rbln/transformers/models/t5/modeling_t5.py +46 -4
  107. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  108. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +2 -2
  109. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +14 -9
  110. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  111. optimum/rbln/transformers/models/vit/configuration_vit.py +19 -0
  112. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  113. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  114. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  115. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -1
  116. optimum/rbln/transformers/models/whisper/modeling_whisper.py +35 -15
  117. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  118. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  119. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  120. optimum/rbln/utils/model_utils.py +20 -0
  121. optimum/rbln/utils/submodule.py +6 -8
  122. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/METADATA +1 -1
  123. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/RECORD +127 -114
  124. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  125. /optimum/rbln/transformers/models/wav2vec2/{configuration_wav2vec.py → configuration_wav2vec2.py} +0 -0
  126. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/WHEEL +0 -0
  127. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1a1.dist-info}/licenses/LICENSE +0 -0
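
Much of the churn above is mechanical: module renames (`time_series_transformers` → `time_series_transformer`, `configuration_seq2seq2.py` → `configuration_seq2seq.py`, `configuration_wav2vec.py` → `configuration_wav2vec2.py`), retirement of the `modeling_alias.py`/`configuration_alias.py` shims in favor of per-model packages (RoBERTa), and new model families (Audio Spectrogram Transformer, DistilBERT, ResNet, ViT). The public classes remain importable from the package root; a minimal smoke test, assuming the wheel is installed (constructor arguments follow the Llama config example shown later in this diff):

```python
# Minimal import smoke test for 0.8.1a1; class names come from the file list above.
from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig

# Constructing a config offline exercises the renamed configuration modules
# without touching an RBLN device.
config = RBLNLlamaForCausalLMConfig(batch_size=1, max_seq_len=4096)
print(type(config).__bases__)  # -> (RBLNDecoderOnlyModelForCausalLMConfig,)
```
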
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py

@@ -32,10 +32,6 @@ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbed
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from ..decoderonly.decoderonly_architecture import (
-    set_default_values,
-    validate_attention_method,
-)
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput, RBLNRuntimeModel
 from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
@@ -215,15 +211,16 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):

         return model_kwargs

-    def get_image_features(self, pixel_values: torch.Tensor):
+    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
         """
         Projects the last hidden state from the vision model into language model space.

         Args:
-            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
-                The tensors corresponding to the input images.
+            pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
+                The tensors corresponding to the input images.
+
         Returns:
-            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
+            Image feature tensor of shape `(num_images, image_length, embed_dim)`.
         """
         vision_outputs = self.vision_tower(pixel_values).last_hidden_state
         image_features = self.multi_modal_projector(vision_outputs)
@@ -272,7 +269,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
-        **lm_kwargs,
+        **lm_kwargs: Dict[str, Any],
     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
         # prefill
         if cache_position is None:
@@ -352,16 +349,17 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         # Find image start positions
         image_starts = [
             s
-            for s in range(seq_len - self.prefill_chunk_size + 1)
-            if torch.all(token_type_ids[:, s : s + self.prefill_chunk_size] == 1)
+            for s in range(seq_len - self.rbln_config.prefill_chunk_size + 1)
+            if torch.all(token_type_ids[:, s : s + self.rbln_config.prefill_chunk_size] == 1)
         ]

         # Initialize padded tensors
         padded_input_len = seq_len
         for image_start in image_starts:
             pad_needed = (
-                self.prefill_chunk_size - (image_start + padded_input_len - seq_len) % self.prefill_chunk_size
-            ) % self.prefill_chunk_size
+                self.rbln_config.prefill_chunk_size
+                - (image_start + padded_input_len - seq_len) % self.rbln_config.prefill_chunk_size
+            ) % self.rbln_config.prefill_chunk_size
             padded_input_len += pad_needed
         total_padding = padded_input_len - seq_len

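In the hunk above, `pad_needed` rounds each image segment's start up to the next `prefill_chunk_size` boundary. A standalone sketch of that modular arithmetic (the chunk size of 128 is an assumed example value, and the shifted start position is simplified to a plain offset):

```python
# Sketch of the chunk-boundary padding math used above; 128 is an assumed chunk size.
prefill_chunk_size = 128

def pad_needed(image_start: int) -> int:
    # How many pad tokens to insert so `image_start` lands on a chunk boundary.
    return (prefill_chunk_size - image_start % prefill_chunk_size) % prefill_chunk_size

assert pad_needed(0) == 0     # already aligned
assert pad_needed(200) == 56  # 200 + 56 == 256 == 2 * 128
assert pad_needed(256) == 0
```

The outer modulo is what makes an already-aligned start contribute zero padding rather than a full extra chunk.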
@@ -390,7 +388,9 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             src_pos = image_start

             # Padding
-            pad_needed = (self.prefill_chunk_size - dest_pos % self.prefill_chunk_size) % self.prefill_chunk_size
+            pad_needed = (
+                self.rbln_config.prefill_chunk_size - dest_pos % self.rbln_config.prefill_chunk_size
+            ) % self.rbln_config.prefill_chunk_size
             if pad_needed and dest_pos < padded_input_len:
                 position_ids_padded[:, dest_pos : dest_pos + pad_needed] = torch.arange(
                     last_pos_id + 1, last_pos_id + pad_needed + 1, dtype=position_ids.dtype
@@ -399,21 +399,21 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):

             # Image segment
             if src_pos < seq_len and src_pos == image_start:
-                inputs_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = inputs[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                inputs_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = inputs[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                attention_mask_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = attention_mask[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                attention_mask_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = attention_mask[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                position_ids_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = position_ids[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                position_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = position_ids[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                token_type_ids_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = token_type_ids[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                token_type_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = token_type_ids[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                dest_pos += self.prefill_chunk_size
-                src_pos += self.prefill_chunk_size
-                last_pos_id = position_ids[0, image_start + self.prefill_chunk_size - 1].item()
+                dest_pos += self.rbln_config.prefill_chunk_size
+                src_pos += self.rbln_config.prefill_chunk_size
+                last_pos_id = position_ids[0, image_start + self.rbln_config.prefill_chunk_size - 1].item()

         return inputs_padded, attention_mask_padded, position_ids_padded, total_padding, token_type_ids_padded

@@ -444,11 +444,13 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):

         seq_len = inputs.shape[1]
         # Initialize attention mask for chunked processing
-        if self.use_attention_mask:
+        if self.rbln_config.use_attention_mask:
             chunked_attention_mask = (
                 torch.ones(1, seq_len, dtype=torch.float32)
-                if self.use_position_ids
-                else torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+                if self.rbln_config.use_position_ids
+                else torch.zeros(
+                    1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
+                )
             )
         else:
             chunked_attention_mask = None
@@ -467,21 +469,21 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         )

         query_length = inputs.shape[1]
-        if query_length > self.max_seq_len:
+        if query_length > self.rbln_config.max_seq_len:
             raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
             )

         # Align attention_mask to compiled shape
-        if self.use_position_ids:
+        if self.rbln_config.use_position_ids:
             chunked_attention_mask = torch.nn.functional.pad(
-                chunked_attention_mask, (0, self.max_seq_len - query_length)
+                chunked_attention_mask, (0, self.rbln_config.max_seq_len - query_length)
             )

         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
         padding_size = 0
-        if query_length % self.prefill_chunk_size != 0:
-            padding_size = (self.prefill_chunk_size - query_length) % self.prefill_chunk_size
+        if query_length % self.rbln_config.prefill_chunk_size != 0:
+            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
             # inputs_embeds
             if inputs.dim() == 3:
                 inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
@@ -549,65 +551,71 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         ) = self._prepare_prefill_inputs(
             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
         )
-        self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask[:1]
         if not is_external_block_tables:
             local_block_tables = torch.tensor([batch_idx], dtype=torch.int16)
+            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask[:1]

-        if self.use_attention_mask and self.use_position_ids:
-            chunked_attention_mask = torch.zeros(1, self.max_seq_len, dtype=torch.float32)
+        if self.rbln_config.use_attention_mask and self.rbln_config.use_position_ids:
+            chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=torch.float32)

         # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.prefill_chunk_size):
+        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
             # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
             position_ids_chunk = (
-                position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
+                if position_ids is not None
+                else None
             )

             # Not used in Gemma3 yet.
-            if self.use_attention_mask:
-                if self.use_position_ids:
-                    chunked_attention_mask[0, step : step + self.prefill_chunk_size] = self.dec_attn_mask[
-                        batch_idx, step : step + self.prefill_chunk_size
+            if self.rbln_config.use_attention_mask:
+                if self.rbln_config.use_position_ids:
+                    chunked_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size] = self.dec_attn_mask[
+                        batch_idx, step : step + self.rbln_config.prefill_chunk_size
                     ]
                 else:
                     # Update attention mask to ensure proper causal behavior
-                    if step >= self.prefill_chunk_size:
-                        chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-                    chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                    if step >= self.rbln_config.prefill_chunk_size:
+                        chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
+                    chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = (
+                        self.causal_mask
+                    )

             # Define query position
             query_position = (
                 torch.sum(
-                    chunked_attention_mask[0][step : step + self.prefill_chunk_size], dim=-1, dtype=torch.int16
+                    chunked_attention_mask[0][step : step + self.rbln_config.prefill_chunk_size],
+                    dim=-1,
+                    dtype=torch.int16,
                 ).squeeze(0)
                 - 1
             )
             if token_type_ids_padded[:, step] == 1:
-                if torch.any(token_type_ids_padded[:, step : step + self.prefill_chunk_size] == 0):
+                if torch.any(token_type_ids_padded[:, step : step + self.rbln_config.prefill_chunk_size] == 0):
                     raise ValueError("All tokens of image_prefill should be the same image.")
                 else:
                     logits = self.image_prefill(
                         input_chunk,
-                        chunked_attention_mask,
                         cache_pos_chunk,
-                        position_ids_chunk,
-                        query_position,
                         block_tables,
                         local_block_tables,
+                        query_position,
+                        chunked_attention_mask,
+                        position_ids_chunk,
                         out=out_buffers,
                     )
             else:
                 # Forward pass for the current chunk
                 logits = self.prefill(
                     input_chunk,
-                    chunked_attention_mask,
                     cache_pos_chunk,
-                    position_ids_chunk,
-                    query_position,
                     block_tables,
                     local_block_tables,
+                    query_position,
+                    chunked_attention_mask,
+                    position_ids_chunk,
                     out=out_buffers,
                 )
@@ -647,7 +655,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             if local_block_tables is not None
             else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
         )
-        if self.use_attention_mask and attention_mask is None:
+        if self.rbln_config.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
@@ -664,7 +672,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
             attention_mask = attention_mask[: self.batch_size]

-        logits = self.decode(inputs, attention_mask, cache_position, position_ids, block_tables, local_block_tables)
+        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)

         return RBLNDecoderOnlyOutput(logits=logits)

@@ -701,7 +709,6 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             dtype=torch.int16,
         ).fill_(-1)
         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
-
         self.prefill_decoder = RBLNGemma3RuntimeModel(
             runtime=self.model[0],
             image_prefill=self.model[1],
@@ -711,14 +718,9 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
-            free_block_pool=free_block_pool,
-            kvcache_block_size=self.rbln_config.kvcache_block_size,
             vocab_size=self.config.vocab_size,
-            prefill_chunk_size=self.rbln_config.prefill_chunk_size,
-            max_seq_len=self.rbln_config.max_seq_len,
-            use_attention_mask=self.rbln_config.use_attention_mask,
-            attn_impl=self.rbln_config.attn_impl,
-            use_position_ids=self.rbln_config.use_position_ids,
+            free_block_pool=free_block_pool,
+            rbln_config=self.rbln_config,
         )

         self.decoders = {}
@@ -732,10 +734,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
                 dec_attn_mask=dec_attn_mask,
                 block_tables=block_tables,
                 free_block_pool=free_block_pool,
-                kvcache_block_size=self.rbln_config.kvcache_block_size,
-                use_attention_mask=self.rbln_config.use_attention_mask,
-                attn_impl=self.rbln_config.attn_impl,
-                use_position_ids=self.rbln_config.use_position_ids,
+                rbln_config=self.rbln_config,
             )

         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -752,81 +751,17 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         return embed_tokens

     @classmethod
-    def get_input_info(
-        cls,
-        batch_size: int,
-        query_length: int,
-        use_inputs_embeds: bool,
-        use_attention_mask: bool,
-        use_position_ids: bool,
-        max_seq_len: int,
-        kvcache_block_size: int,
-        kvcache_num_blocks: int,
-        num_key_value_heads: int,
-        num_hidden_layers: int,
-        hidden_size: int,
-        head_dim: int,
-        sliding_window: int,
-        sliding_window_pattern: int,
-        dec_batch_size: int,
-    ):
-        if use_inputs_embeds:
-            main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
-        else:
-            main_input = ("input_ids", [batch_size, query_length], "int64")
-
-        input_info = [
-            main_input,
-            (
-                "attention_mask",
-                [batch_size, 1, query_length, max_seq_len] if not use_position_ids else [batch_size, max_seq_len],
-                "float32",
-            ),
-            (
-                "cache_position",
-                [batch_size, query_length],
-                "int32",
-            ),
-            (
-                "position_ids",
-                [batch_size, query_length],
-                "int32",
-            ),
-        ]
-
-        if query_length > 1:
-            input_info.extend(
-                [
-                    ("query_position", [], "int16"),
-                ]
-            )
-
-        max_block_cnt = max_seq_len // kvcache_block_size
-
-        if query_length > 1:
-            input_info.extend([("global_block_tables", [max_block_cnt], "int16")])
-            input_info.extend([("local_block_tables", [1], "int16")])
-        else:
-            input_info.extend([("global_block_tables", [batch_size, max_block_cnt], "int16")])
-            input_info.extend([("local_block_tables", [batch_size, 1], "int16")])
-
-        def is_sliding(layer_idx: int) -> bool:
-            return bool((layer_idx + 1) % sliding_window_pattern)
-
-        local_kvcache_shape = [dec_batch_size, num_key_value_heads, sliding_window, head_dim]
-        global_kvcache_shape = [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim]
-        input_info.extend(
-            [
-                (
-                    f"past_key_values_{i}",
-                    local_kvcache_shape if is_sliding(i // 2) else global_kvcache_shape,
-                    "float32",
-                )
-                for i in range(num_hidden_layers * 2)
+    def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
+        sliding_window = getattr(model_config, "sliding_window", None)
+        sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+        if sliding_window_pattern <= model_config.num_hidden_layers:
+            rbln_config.cache_impl = "hybrid"
+            rbln_config.sliding_window = sliding_window
+            rbln_config.sliding_window_layers = [
+                i for i in range(model_config.num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
             ]
-        )

-        return input_info
+        return rbln_config

     @classmethod
     def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
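
With `get_input_info` gone, Gemma3's layer layout is now expressed through the shared decoder-only config: `cache_impl = "hybrid"` plus an explicit `sliding_window_layers` list replaces the old Gemma3-local `is_sliding` predicate. A worked example of the selection rule with illustrative numbers (a Gemma3-style pattern of 6; the layer count is made up):

```python
# Worked example of the sliding-window layer selection above; values are illustrative.
num_hidden_layers = 12
sliding_window_pattern = 6  # every 6th layer uses global attention

sliding_window_layers = [
    i for i in range(num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
]
print(sliding_window_layers)  # [0, 1, 2, 3, 4, 6, 7, 8, 9, 10]
# Layers 5 and 11 (the 6th and 12th) are global; all others use the sliding window.
```
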
@@ -847,102 +782,18 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         model_config: Optional["PretrainedConfig"] = None,
         rbln_config: Optional[RBLNGemma3ForCausalLMConfig] = None,
     ) -> RBLNGemma3ForCausalLMConfig:
-        if rbln_config.max_seq_len is None:
-            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None)
-            if rbln_config.max_seq_len is None:
-                raise ValueError("`max_seq_len` should be specified.")
-
-        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
-
-        validate_attention_method(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
-
-        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-        max_num_blocks = required_num_blocks
+        # Update rbln_config with super class
+        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)

-        if rbln_config.attn_impl == "flash_attn":
-            flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
-            if max_num_blocks < flash_min_blocks:
-                max_num_blocks = flash_min_blocks
-
-        if max_num_blocks < rbln_config.batch_size:
-            raise RuntimeError(
-                f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
-                "Ensure the number of blocks is at least equal to the batch size."
-            )
-
-        if rbln_config.kvcache_num_blocks is None:
-            rbln_config.kvcache_num_blocks = max_num_blocks
-        elif rbln_config.kvcache_num_blocks > max_num_blocks:
-            logger.warning(
-                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                f" than the estimated maximum number of blocks ({max_num_blocks})."
-                "This can cause a failure during model compilation."
-            )
-        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-
-        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
-        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
-        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
-        sliding_window = getattr(model_config, "sliding_window", None)
-        sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
-
-        prefill_input_info = cls.get_input_info(
-            batch_size=1,
-            query_length=rbln_config.prefill_chunk_size,
-            use_inputs_embeds=rbln_config.use_inputs_embeds,
-            use_attention_mask=rbln_config.use_attention_mask,
-            use_position_ids=rbln_config.use_position_ids,
-            max_seq_len=rbln_config.max_seq_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-            num_key_value_heads=num_key_value_heads,
-            num_hidden_layers=num_hidden_layers,
-            hidden_size=hidden_size,
-            head_dim=head_dim,
-            sliding_window=sliding_window,
-            sliding_window_pattern=sliding_window_pattern,
-            dec_batch_size=max(rbln_config.decoder_batch_sizes),
-        )
-        prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+        # Assume that prefill compile config is at index 0
+        compile_cfgs = rbln_config.compile_cfgs
         image_prefill_compile_config = RBLNCompileConfig(
-            compiled_model_name="image_prefill", input_info=prefill_input_info
+            compiled_model_name="image_prefill", input_info=compile_cfgs[0].input_info
         )
-
-        dec_compile_configs = []
-        for batch_size in rbln_config.decoder_batch_sizes:
-            dec_input_info = cls.get_input_info(
-                batch_size=batch_size,
-                query_length=1,
-                use_inputs_embeds=rbln_config.use_inputs_embeds,
-                use_attention_mask=rbln_config.use_attention_mask,
-                use_position_ids=rbln_config.use_position_ids,
-                max_seq_len=rbln_config.max_seq_len,
-                kvcache_block_size=rbln_config.kvcache_block_size,
-                kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-                num_key_value_heads=num_key_value_heads,
-                num_hidden_layers=num_hidden_layers,
-                hidden_size=hidden_size,
-                head_dim=head_dim,
-                sliding_window=sliding_window,
-                sliding_window_pattern=sliding_window_pattern,
-                dec_batch_size=batch_size,
-            )
-            dec_compile_configs.append(
-                RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
-            )
-        rbln_config.set_compile_cfgs([prefill_compile_config, image_prefill_compile_config, *dec_compile_configs])
+        # Insert image_prefill compile config at index 1
+        image_idx = 1
+        compile_cfgs.insert(image_idx, image_prefill_compile_config)
+        rbln_config.set_compile_cfgs(compile_cfgs)

         return rbln_config

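After this refactor the Gemma3 override no longer rebuilds input infos by hand; it lets the base class produce the prefill and decoder compile configs, then splices in an `image_prefill` graph that reuses the prefill `input_info`. A small sketch of the resulting ordering, with strings standing in for `RBLNCompileConfig` objects (the decoder batch sizes are illustrative):

```python
# Sketch of the compile-config list surgery above; strings stand in for
# RBLNCompileConfig objects and the decoder batch sizes are made up.
compile_cfgs = ["prefill", "decoder_batch_1", "decoder_batch_4"]  # from the base class

compile_cfgs.insert(1, "image_prefill")  # image_prefill reuses prefill's input_info

print(compile_cfgs)  # ['prefill', 'image_prefill', 'decoder_batch_1', 'decoder_batch_4']
```
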
optimum/rbln/transformers/models/gpt2/configuration_gpt2.py

@@ -16,4 +16,7 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausa


 class RBLNGPT2LMHeadModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
-    pass
+    """
+    Configuration class for GPT-2 causal language model.
+    Inherits from RBLNDecoderOnlyModelForCausalLMConfig with no additional parameters.
+    """
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py

@@ -45,7 +45,12 @@ class GPT2Wrapper(DecoderOnlyWrapper):
             )
             new_layer = GPT2Layer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = GPT2Model(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
+        new_model = GPT2Model(
+            causal_lm.transformer,
+            new_layers,
+            max_seq_len=max_seq_len,
+            sliding_window_layers=self.sliding_window_layers,
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm

optimum/rbln/transformers/models/idefics3/configuration_idefics3.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional
+from typing import Any, Dict, Optional

 from ....configuration_utils import RBLNModelConfig

@@ -22,6 +22,16 @@ class RBLNIdefics3VisionTransformerConfig(RBLNModelConfig):


 class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNIdefics3ForConditionalGeneration models.
+
+    This class extends `RBLNModelConfig` to include settings specific to the Idefics3 vision-language model optimized for RBLN devices.
+    It allows configuration of the batch size and separate configurations for the vision and text submodules.
+
+    Attributes:
+        submodules (List[str]): List of submodules included in the model. Defaults to `["vision_model", "text_model"]`.
+    """
+
     submodules = ["vision_model", "text_model"]

     def __init__(
@@ -29,7 +39,7 @@ class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         vision_model: Optional[RBLNModelConfig] = None,
         text_model: Optional[RBLNModelConfig] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
optimum/rbln/transformers/models/idefics3/modeling_idefics3.py

@@ -102,10 +102,9 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
         subfolder: str,
         rbln_config: RBLNModelConfig,
     ):
-        """
-        If you are unavoidably running on a CPU rather than an RBLN device,
-        store the torch tensor, weight, etc. in this function.
-        """
+        # If you are unavoidably running on a CPU rather than an RBLN device,
+        # store the torch tensor, weight, etc. in this function.
+
         save_dict = {}
         save_dict["embeddings"] = model.get_input_embeddings().state_dict()
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
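
The hook above stashes the input-embedding weights as a plain torch artifact beside the compiled model so they can be reloaded on the host CPU. A hedged sketch of reading the artifact back; the directory layout mirrors the `torch.save` call, but the paths here are assumptions, not the library's API:

```python
import torch
from pathlib import Path

# Assumed layout per the torch.save call above:
#   <save_dir>/<subfolder>/torch_artifacts.pth
artifact_path = Path("compiled-model") / "vision_model" / "torch_artifacts.pth"  # hypothetical path
artifacts = torch.load(artifact_path, map_location="cpu")
embeddings_state = artifacts["embeddings"]  # state_dict of the input embeddings
print(list(embeddings_state.keys()))
```
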
@@ -190,6 +189,44 @@ class RBLNIdefics3VisionTransformer(RBLNModel):


 class RBLNIdefics3ForConditionalGeneration(RBLNModel):
+    """
+    RBLNIdefics3ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
+        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNIdefics3ForConditionalGenerationConfig class for details.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNIdefics3ForConditionalGeneration
+
+        model = RBLNIdefics3ForConditionalGeneration.from_pretrained(
+            "HuggingFaceM4/idefics3-8b",
+            export=True,
+            rbln_config={
+                "vision_model": {
+                    "device": 0,
+                },
+                "text_model": {
+                    "batch_size": 1,
+                    "max_seq_len": 131_072,
+                    "tensor_parallel_size": 8,
+                    "use_inputs_embeds": True,
+                    "attn_impl": "flash_attn",
+                    "kvcache_partition_len": 16_384,
+                    "device": [0, 1, 2, 3, 4, 5, 6, 7],
+                },
+            },
+        )
+
+        model.save_pretrained("compiled-idefics3-8b")
+        ```
+    """
+
     auto_model_class = AutoModelForVision2Seq
     _rbln_submodules = [{"name": "vision_model"}, {"name": "text_model"}]
     _rbln_submodule_prefix = "model"
optimum/rbln/transformers/models/llama/configuration_llama.py

@@ -16,4 +16,27 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausa


 class RBLNLlamaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
-    pass
+    """
+    Configuration class for RBLN Llama models.
+
+    This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNLlamaForCausalLM, RBLNLlamaForCausalLMConfig
+
+    # Create a configuration object
+    config = RBLNLlamaForCausalLMConfig(
+        batch_size=1,
+        max_seq_len=4096,
+        tensor_parallel_size=4
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNLlamaForCausalLM.from_pretrained(
+        "meta-llama/Llama-2-7b-hf",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """