optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff shows the changes between two publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (162)
  1. optimum/rbln/__init__.py +24 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +45 -33
  4. optimum/rbln/diffusers/__init__.py +21 -1
  5. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  13. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  14. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  15. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
  22. optimum/rbln/diffusers/modeling_diffusers.py +72 -65
  23. optimum/rbln/diffusers/models/__init__.py +4 -0
  24. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  25. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
  26. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  27. optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
  28. optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
  29. optimum/rbln/diffusers/models/controlnet.py +14 -8
  30. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  31. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  32. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  33. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  34. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
  35. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  36. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  37. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  38. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  39. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  42. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  43. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  45. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  46. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  47. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  49. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  50. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  51. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  52. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  53. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  54. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  55. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  56. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  57. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  58. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  59. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  60. optimum/rbln/modeling.py +71 -37
  61. optimum/rbln/modeling_base.py +63 -109
  62. optimum/rbln/transformers/__init__.py +41 -47
  63. optimum/rbln/transformers/configuration_generic.py +16 -13
  64. optimum/rbln/transformers/modeling_generic.py +21 -22
  65. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  66. optimum/rbln/transformers/models/__init__.py +54 -4
  67. optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
  68. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  69. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  70. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  71. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  72. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  73. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  74. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  75. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  76. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
  77. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
  78. optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
  79. optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
  80. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
  82. optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
  83. optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
  84. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  85. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  86. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
  87. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  88. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  89. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
  90. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  91. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  92. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  93. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  94. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  95. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  96. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  97. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  98. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  99. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  100. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
  101. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  102. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  103. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  104. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  105. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  106. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  107. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
  108. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  109. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  110. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  111. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  112. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  113. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  114. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  115. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  116. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  117. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  118. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  119. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  120. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  121. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  122. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
  123. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
  124. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
  125. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  126. optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
  127. optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
  128. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  129. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
  130. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
  131. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  132. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  133. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
  134. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  135. optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
  136. optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
  137. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  138. optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
  139. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  140. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  141. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
  142. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
  143. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  144. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  145. optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
  146. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  147. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
  148. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  149. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
  150. optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
  151. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  152. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  153. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  154. optimum/rbln/utils/model_utils.py +20 -0
  155. optimum/rbln/utils/runtime_utils.py +49 -1
  156. optimum/rbln/utils/submodule.py +6 -8
  157. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
  158. optimum_rbln-0.8.1.dist-info/RECORD +211 -0
  159. optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
  160. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  161. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
  162. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py

@@ -32,10 +32,6 @@ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbedding
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from ..decoderonly.decoderonly_architecture import (
-    set_default_values,
-    validate_attention_method,
-)
 from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM, RBLNDecoderOnlyOutput, RBLNRuntimeModel
 from .configuration_gemma3 import RBLNGemma3ForCausalLMConfig
 from .gemma3_architecture import Gemma3ForCausalLMWrapper
@@ -215,15 +211,16 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
 
         return model_kwargs
 
-    def get_image_features(self, pixel_values: torch.Tensor):
+    def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
         """
         Projects the last hidden state from the vision model into language model space.
 
         Args:
-            pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
-                The tensors corresponding to the input images.
+            pixel_values: (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`)
+                The tensors corresponding to the input images.
+
         Returns:
-            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
+            Image feature tensor of shape `(num_images, image_length, embed_dim)`.
         """
         vision_outputs = self.vision_tower(pixel_values).last_hidden_state
         image_features = self.multi_modal_projector(vision_outputs)
@@ -272,7 +269,7 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
-        **lm_kwargs,
+        **lm_kwargs: Dict[str, Any],
     ) -> Union[Tuple, RBLNDecoderOnlyOutput]:
         # prefill
         if cache_position is None:
@@ -329,7 +326,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         attention_mask: torch.Tensor,
         position_ids: torch.Tensor,
         token_type_ids: Optional[torch.Tensor] = None,
-    ):
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, torch.Tensor]:
         """
         Pads inputs, attention_mask, and position_ids so image token groups (256 tokens with token_type_ids == 1)
         start at multiples of prefill_chunk_size (256). Returns padded tensors and total padded length.
@@ -341,7 +338,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             token_type_ids: (1, seq_len) tensor, 0 for text, 1 for image.
 
         Returns:
-            Tuple: (inputs_padded, attention_mask_padded, position_ids_padded, padded_len, token_type_ids_padded).
+            (inputs_padded, attention_mask_padded, position_ids_padded, padded_len, token_type_ids_padded).
         """
 
         if token_type_ids is None:
@@ -352,16 +349,17 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         # Find image start positions
         image_starts = [
             s
-            for s in range(seq_len - self.prefill_chunk_size + 1)
-            if torch.all(token_type_ids[:, s : s + self.prefill_chunk_size] == 1)
+            for s in range(seq_len - self.rbln_config.prefill_chunk_size + 1)
+            if torch.all(token_type_ids[:, s : s + self.rbln_config.prefill_chunk_size] == 1)
         ]
 
         # Initialize padded tensors
         padded_input_len = seq_len
         for image_start in image_starts:
             pad_needed = (
-                self.prefill_chunk_size - (image_start + padded_input_len - seq_len) % self.prefill_chunk_size
-            ) % self.prefill_chunk_size
+                self.rbln_config.prefill_chunk_size
+                - (image_start + padded_input_len - seq_len) % self.rbln_config.prefill_chunk_size
+            ) % self.rbln_config.prefill_chunk_size
             padded_input_len += pad_needed
         total_padding = padded_input_len - seq_len
 
@@ -390,7 +388,9 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             src_pos = image_start
 
             # Padding
-            pad_needed = (self.prefill_chunk_size - dest_pos % self.prefill_chunk_size) % self.prefill_chunk_size
+            pad_needed = (
+                self.rbln_config.prefill_chunk_size - dest_pos % self.rbln_config.prefill_chunk_size
+            ) % self.rbln_config.prefill_chunk_size
             if pad_needed and dest_pos < padded_input_len:
                 position_ids_padded[:, dest_pos : dest_pos + pad_needed] = torch.arange(
                     last_pos_id + 1, last_pos_id + pad_needed + 1, dtype=position_ids.dtype
@@ -399,21 +399,21 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
 
             # Image segment
             if src_pos < seq_len and src_pos == image_start:
-                inputs_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = inputs[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                inputs_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = inputs[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                attention_mask_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = attention_mask[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                attention_mask_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = attention_mask[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                position_ids_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = position_ids[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                position_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = position_ids[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                token_type_ids_padded[:, dest_pos : dest_pos + self.prefill_chunk_size] = token_type_ids[
-                    :, src_pos : src_pos + self.prefill_chunk_size
+                token_type_ids_padded[:, dest_pos : dest_pos + self.rbln_config.prefill_chunk_size] = token_type_ids[
+                    :, src_pos : src_pos + self.rbln_config.prefill_chunk_size
                 ]
-                dest_pos += self.prefill_chunk_size
-                src_pos += self.prefill_chunk_size
-                last_pos_id = position_ids[0, image_start + self.prefill_chunk_size - 1].item()
+                dest_pos += self.rbln_config.prefill_chunk_size
+                src_pos += self.rbln_config.prefill_chunk_size
+                last_pos_id = position_ids[0, image_start + self.rbln_config.prefill_chunk_size - 1].item()
 
         return inputs_padded, attention_mask_padded, position_ids_padded, total_padding, token_type_ids_padded
 
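The alignment rule these hunks migrate onto `rbln_config` is plain modular arithmetic: an image token group must start at a multiple of `prefill_chunk_size`, so the gap up to the next chunk boundary is filled with padding. A minimal standalone sketch (the chunk size of 256 comes from the docstring above; the helper name is ours, not the library's):

```python
# Number of filler tokens needed before a segment starting at dest_pos,
# so that the segment begins exactly on a chunk boundary.
PREFILL_CHUNK_SIZE = 256  # Gemma3 image token group size, per the docstring above


def pad_to_chunk_boundary(dest_pos: int) -> int:
    return (PREFILL_CHUNK_SIZE - dest_pos % PREFILL_CHUNK_SIZE) % PREFILL_CHUNK_SIZE


assert pad_to_chunk_boundary(0) == 0      # already aligned
assert pad_to_chunk_boundary(100) == 156  # pad up to 256
assert pad_to_chunk_boundary(300) == 212  # pad up to 512
```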
@@ -444,11 +444,13 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
 
         seq_len = inputs.shape[1]
         # Initialize attention mask for chunked processing
-        if self.use_attention_mask:
+        if self.rbln_config.use_attention_mask:
             chunked_attention_mask = (
                 torch.ones(1, seq_len, dtype=torch.float32)
-                if self.use_position_ids
-                else torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
+                if self.rbln_config.use_position_ids
+                else torch.zeros(
+                    1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32
+                )
             )
         else:
             chunked_attention_mask = None
@@ -467,21 +469,21 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         )
 
         query_length = inputs.shape[1]
-        if query_length > self.max_seq_len:
+        if query_length > self.rbln_config.max_seq_len:
             raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
             )
 
         # Align attention_mask to compiled shape
-        if self.use_position_ids:
+        if self.rbln_config.use_position_ids:
             chunked_attention_mask = torch.nn.functional.pad(
-                chunked_attention_mask, (0, self.max_seq_len - query_length)
+                chunked_attention_mask, (0, self.rbln_config.max_seq_len - query_length)
             )
 
         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
         padding_size = 0
-        if query_length % self.prefill_chunk_size != 0:
-            padding_size = (self.prefill_chunk_size - query_length) % self.prefill_chunk_size
+        if query_length % self.rbln_config.prefill_chunk_size != 0:
+            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
             # inputs_embeds
             if inputs.dim() == 3:
                 inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
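Note that `(prefill_chunk_size - query_length) % prefill_chunk_size` relies on Python's nonnegative `%`, so it also gives the right padding when `query_length` spans more than one chunk. A quick check with illustrative values:

```python
PREFILL_CHUNK_SIZE = 256

for query_length, expected in [(256, 0), (300, 212), (100, 156)]:
    padding_size = (PREFILL_CHUNK_SIZE - query_length) % PREFILL_CHUNK_SIZE
    assert padding_size == expected
    # The padded length is always a whole number of chunks.
    assert (query_length + padding_size) % PREFILL_CHUNK_SIZE == 0
```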
@@ -539,7 +541,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             (
                 inputs,
                 cache_position,
-                chunked_attention_mask,
+                padded_attention_mask,
                 out_buffers,
                 position_ids,
                 position_embed,
@@ -549,65 +551,63 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             ) = self._prepare_prefill_inputs(
                 inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
             )
-            self.dec_attn_mask[batch_idx : batch_idx + 1] = chunked_attention_mask[:1]
             if not is_external_block_tables:
                 local_block_tables = torch.tensor([batch_idx], dtype=torch.int16)
+                self.dec_attn_mask[batch_idx : batch_idx + 1] = padded_attention_mask[:1]
 
-            if self.use_attention_mask and self.use_position_ids:
-                chunked_attention_mask = torch.zeros(1, self.max_seq_len, dtype=torch.float32)
+            if self.rbln_config.use_attention_mask and self.rbln_config.use_position_ids:
+                chunked_attention_mask = torch.zeros(1, self.rbln_config.max_seq_len, dtype=torch.float32)
 
             # Process input in chunks of size `prefill_chunk_size`
-            for step in range(0, query_length, self.prefill_chunk_size):
+            for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
                 # Extract the current chunk of inputs and cache positions
-                input_chunk = inputs[:, step : step + self.prefill_chunk_size]
-                cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+                input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
+                cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
                 position_ids_chunk = (
-                    position_ids[:, step : step + self.prefill_chunk_size] if position_ids is not None else None
+                    position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
+                    if position_ids is not None
+                    else None
                 )
 
-                # Not used in Gemma3 yet.
-                if self.use_attention_mask:
-                    if self.use_position_ids:
-                        chunked_attention_mask[0, step : step + self.prefill_chunk_size] = self.dec_attn_mask[
-                            batch_idx, step : step + self.prefill_chunk_size
-                        ]
-                    else:
-                        # Update attention mask to ensure proper causal behavior
-                        if step >= self.prefill_chunk_size:
-                            chunked_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
-                        chunked_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                if self.rbln_config.use_attention_mask:
+                    if self.rbln_config.use_position_ids:
+                        chunked_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size] = (
+                            padded_attention_mask[0, step : step + self.rbln_config.prefill_chunk_size]
+                        )
 
                 # Define query position
                 query_position = (
                     torch.sum(
-                        chunked_attention_mask[0][step : step + self.prefill_chunk_size], dim=-1, dtype=torch.int16
+                        chunked_attention_mask[0][step : step + self.rbln_config.prefill_chunk_size],
+                        dim=-1,
+                        dtype=torch.int16,
                     ).squeeze(0)
                     - 1
                 )
                 if token_type_ids_padded[:, step] == 1:
-                    if torch.any(token_type_ids_padded[:, step : step + self.prefill_chunk_size] == 0):
+                    if torch.any(token_type_ids_padded[:, step : step + self.rbln_config.prefill_chunk_size] == 0):
                         raise ValueError("All tokens of image_prefill should be the same image.")
                     else:
                         logits = self.image_prefill(
                             input_chunk,
-                            chunked_attention_mask,
                             cache_pos_chunk,
-                            position_ids_chunk,
-                            query_position,
                             block_tables,
                             local_block_tables,
+                            query_position,
+                            chunked_attention_mask,
+                            position_ids_chunk,
                             out=out_buffers,
                         )
                 else:
                     # Forward pass for the current chunk
                     logits = self.prefill(
                         input_chunk,
-                        chunked_attention_mask,
                         cache_pos_chunk,
-                        position_ids_chunk,
-                        query_position,
                         block_tables,
                         local_block_tables,
+                        query_position,
+                        chunked_attention_mask,
+                        position_ids_chunk,
                         out=out_buffers,
                     )
@@ -647,7 +647,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
             if local_block_tables is not None
             else torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
         )
-        if self.use_attention_mask and attention_mask is None:
+        if self.rbln_config.use_attention_mask and attention_mask is None:
             for b_idx in range(batch_size):
                 decoding_step = cache_position[b_idx].item()
                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
@@ -664,7 +664,7 @@ class RBLNGemma3RuntimeModel(RBLNRuntimeModel):
         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
             attention_mask = attention_mask[: self.batch_size]
 
-        logits = self.decode(inputs, attention_mask, cache_position, position_ids, block_tables, local_block_tables)
+        logits = self.decode(inputs, cache_position, block_tables, local_block_tables, attention_mask, position_ids)
 
         return RBLNDecoderOnlyOutput(logits=logits)
 
@@ -701,7 +701,6 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             dtype=torch.int16,
         ).fill_(-1)
         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
-
         self.prefill_decoder = RBLNGemma3RuntimeModel(
             runtime=self.model[0],
             image_prefill=self.model[1],
@@ -711,14 +710,9 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
-            free_block_pool=free_block_pool,
-            kvcache_block_size=self.rbln_config.kvcache_block_size,
             vocab_size=self.config.vocab_size,
-            prefill_chunk_size=self.rbln_config.prefill_chunk_size,
-            max_seq_len=self.rbln_config.max_seq_len,
-            use_attention_mask=self.rbln_config.use_attention_mask,
-            attn_impl=self.rbln_config.attn_impl,
-            use_position_ids=self.rbln_config.use_position_ids,
+            free_block_pool=free_block_pool,
+            rbln_config=self.rbln_config,
         )
 
         self.decoders = {}
@@ -732,10 +726,7 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
                 dec_attn_mask=dec_attn_mask,
                 block_tables=block_tables,
                 free_block_pool=free_block_pool,
-                kvcache_block_size=self.rbln_config.kvcache_block_size,
-                use_attention_mask=self.rbln_config.use_attention_mask,
-                attn_impl=self.rbln_config.attn_impl,
-                use_position_ids=self.rbln_config.use_position_ids,
+                rbln_config=self.rbln_config,
             )
 
         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
@@ -752,81 +743,17 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         return embed_tokens
 
     @classmethod
-    def get_input_info(
-        cls,
-        batch_size: int,
-        query_length: int,
-        use_inputs_embeds: bool,
-        use_attention_mask: bool,
-        use_position_ids: bool,
-        max_seq_len: int,
-        kvcache_block_size: int,
-        kvcache_num_blocks: int,
-        num_key_value_heads: int,
-        num_hidden_layers: int,
-        hidden_size: int,
-        head_dim: int,
-        sliding_window: int,
-        sliding_window_pattern: int,
-        dec_batch_size: int,
-    ):
-        if use_inputs_embeds:
-            main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
-        else:
-            main_input = ("input_ids", [batch_size, query_length], "int64")
-
-        input_info = [
-            main_input,
-            (
-                "attention_mask",
-                [batch_size, 1, query_length, max_seq_len] if not use_position_ids else [batch_size, max_seq_len],
-                "float32",
-            ),
-            (
-                "cache_position",
-                [batch_size, query_length],
-                "int32",
-            ),
-            (
-                "position_ids",
-                [batch_size, query_length],
-                "int32",
-            ),
-        ]
-
-        if query_length > 1:
-            input_info.extend(
-                [
-                    ("query_position", [], "int16"),
-                ]
-            )
-
-        max_block_cnt = max_seq_len // kvcache_block_size
-
-        if query_length > 1:
-            input_info.extend([("global_block_tables", [max_block_cnt], "int16")])
-            input_info.extend([("local_block_tables", [1], "int16")])
-        else:
-            input_info.extend([("global_block_tables", [batch_size, max_block_cnt], "int16")])
-            input_info.extend([("local_block_tables", [batch_size, 1], "int16")])
-
-        def is_sliding(layer_idx: int) -> bool:
-            return bool((layer_idx + 1) % sliding_window_pattern)
-
-        local_kvcache_shape = [dec_batch_size, num_key_value_heads, sliding_window, head_dim]
-        global_kvcache_shape = [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim]
-        input_info.extend(
-            [
-                (
-                    f"past_key_values_{i}",
-                    local_kvcache_shape if is_sliding(i // 2) else global_kvcache_shape,
-                    "float32",
-                )
-                for i in range(num_hidden_layers * 2)
+    def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
+        sliding_window = getattr(model_config, "sliding_window", None)
+        sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+        if sliding_window_pattern <= model_config.num_hidden_layers:
+            rbln_config.cache_impl = "hybrid"
+            rbln_config.sliding_window = sliding_window
+            rbln_config.sliding_window_layers = [
+                i for i in range(model_config.num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
             ]
-        )
 
-        return input_info
+        return rbln_config
 
     @classmethod
     def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
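The new `_update_sliding_window_config` switches the cache to "hybrid" and marks every layer whose 1-based index is not a multiple of `sliding_window_pattern` as a sliding-window layer. A standalone sketch of the selection with illustrative values (not read from a real Gemma3 config):

```python
# With a pattern of 6, every 6th layer keeps the global KV cache and the
# rest use the sliding-window cache. Values are illustrative.
num_hidden_layers = 12
sliding_window_pattern = 6

sliding_window_layers = [
    i for i in range(num_hidden_layers) if (i + 1) % sliding_window_pattern > 0
]

# Layers 5 and 11 (the 6th and 12th layers) are excluded:
assert sliding_window_layers == [0, 1, 2, 3, 4, 6, 7, 8, 9, 10]
```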
@@ -847,102 +774,18 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
         model_config: Optional["PretrainedConfig"] = None,
         rbln_config: Optional[RBLNGemma3ForCausalLMConfig] = None,
     ) -> RBLNGemma3ForCausalLMConfig:
-        if rbln_config.max_seq_len is None:
-            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None)
-            if rbln_config.max_seq_len is None:
-                raise ValueError("`max_seq_len` should be specified.")
-
-        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
-
-        validate_attention_method(
-            attn_impl=rbln_config.attn_impl,
-            kvcache_partition_len=rbln_config.kvcache_partition_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            max_seq_len=rbln_config.max_seq_len,
-        )
+        # Update rbln_config with super class
+        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
 
-        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-        max_num_blocks = required_num_blocks
-
-        if rbln_config.attn_impl == "flash_attn":
-            flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
-            if max_num_blocks < flash_min_blocks:
-                max_num_blocks = flash_min_blocks
-
-        if max_num_blocks < rbln_config.batch_size:
-            raise RuntimeError(
-                f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
-                "Ensure the number of blocks is at least equal to the batch size."
-            )
-
-        if rbln_config.kvcache_num_blocks is None:
-            rbln_config.kvcache_num_blocks = max_num_blocks
-        elif rbln_config.kvcache_num_blocks > max_num_blocks:
-            logger.warning(
-                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
-                f" than the estimated maximum number of blocks ({max_num_blocks})."
-                "This can cause a failure during model compilation."
-            )
-        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
-
-        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
-        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
-        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
-        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
-        sliding_window = getattr(model_config, "sliding_window", None)
-        sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
-
-        prefill_input_info = cls.get_input_info(
-            batch_size=1,
-            query_length=rbln_config.prefill_chunk_size,
-            use_inputs_embeds=rbln_config.use_inputs_embeds,
-            use_attention_mask=rbln_config.use_attention_mask,
-            use_position_ids=rbln_config.use_position_ids,
-            max_seq_len=rbln_config.max_seq_len,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-            num_key_value_heads=num_key_value_heads,
-            num_hidden_layers=num_hidden_layers,
-            hidden_size=hidden_size,
-            head_dim=head_dim,
-            sliding_window=sliding_window,
-            sliding_window_pattern=sliding_window_pattern,
-            dec_batch_size=max(rbln_config.decoder_batch_sizes),
-        )
-        prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+        # Assume that prefill compile config is at index 0
+        compile_cfgs = rbln_config.compile_cfgs
         image_prefill_compile_config = RBLNCompileConfig(
-            compiled_model_name="image_prefill", input_info=prefill_input_info
+            compiled_model_name="image_prefill", input_info=compile_cfgs[0].input_info
         )
-
-        dec_compile_configs = []
-        for batch_size in rbln_config.decoder_batch_sizes:
-            dec_input_info = cls.get_input_info(
-                batch_size=batch_size,
-                query_length=1,
-                use_inputs_embeds=rbln_config.use_inputs_embeds,
-                use_attention_mask=rbln_config.use_attention_mask,
-                use_position_ids=rbln_config.use_position_ids,
-                max_seq_len=rbln_config.max_seq_len,
-                kvcache_block_size=rbln_config.kvcache_block_size,
-                kvcache_num_blocks=rbln_config.kvcache_num_blocks,
-                num_key_value_heads=num_key_value_heads,
-                num_hidden_layers=num_hidden_layers,
-                hidden_size=hidden_size,
-                head_dim=head_dim,
-                sliding_window=sliding_window,
-                sliding_window_pattern=sliding_window_pattern,
-                dec_batch_size=batch_size,
-            )
-            dec_compile_configs.append(
-                RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
-            )
-        rbln_config.set_compile_cfgs([prefill_compile_config, image_prefill_compile_config, *dec_compile_configs])
+        # Insert image_prefill compile config at index 1
+        image_idx = 1
+        compile_cfgs.insert(image_idx, image_prefill_compile_config)
+        rbln_config.set_compile_cfgs(compile_cfgs)
 
         return rbln_config
 
@@ -973,9 +816,11 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
             quantization.maybe_set_quantization_env()
             original_linear = torch.nn.functional.linear
             torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
-            compiled_model = RBLNModel.compile(
+            compiled_model = cls.compile(
                 wrapped_model,
                 compile_config,
+                create_runtimes=rbln_config.create_runtimes,
+                device=rbln_config.device,
                 example_inputs=example_inputs,
                 compile_context=compile_context,
             )
optimum/rbln/transformers/models/gpt2/configuration_gpt2.py

@@ -16,4 +16,7 @@ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 
 
 class RBLNGPT2LMHeadModelConfig(RBLNDecoderOnlyModelForCausalLMConfig):
-    pass
+    """
+    Configuration class for GPT-2 causal language model.
+    Inherits from RBLNDecoderOnlyModelForCausalLMConfig with no additional parameters.
+    """
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py

@@ -45,7 +45,12 @@ class GPT2Wrapper(DecoderOnlyWrapper):
             )
             new_layer = GPT2Layer(layer, new_self_attn)
             new_layers.append(new_layer)
-        new_model = GPT2Model(causal_lm.transformer, new_layers, max_seq_len=max_seq_len)
+        new_model = GPT2Model(
+            causal_lm.transformer,
+            new_layers,
+            max_seq_len=max_seq_len,
+            sliding_window_layers=self.sliding_window_layers,
+        )
         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
         return new_causal_lm
 
optimum/rbln/transformers/models/idefics3/configuration_idefics3.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from ....configuration_utils import RBLNModelConfig
 
@@ -22,6 +22,16 @@ class RBLNIdefics3VisionTransformerConfig(RBLNModelConfig):
 
 
 class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNIdefics3ForConditionalGeneration models.
+
+    This class extends `RBLNModelConfig` to include settings specific to the Idefics3 vision-language model optimized for RBLN devices.
+    It allows configuration of the batch size and separate configurations for the vision and text submodules.
+
+    Attributes:
+        submodules (List[str]): List of submodules included in the model. Defaults to `["vision_model", "text_model"]`.
+    """
+
     submodules = ["vision_model", "text_model"]
 
     def __init__(
@@ -29,7 +39,7 @@ class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         vision_model: Optional[RBLNModelConfig] = None,
         text_model: Optional[RBLNModelConfig] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
optimum/rbln/transformers/models/idefics3/modeling_idefics3.py

@@ -102,10 +102,9 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
         subfolder: str,
         rbln_config: RBLNModelConfig,
     ):
-        """
-        If you are unavoidably running on a CPU rather than an RBLN device,
-        store the torch tensor, weight, etc. in this function.
-        """
+        # If you are unavoidably running on a CPU rather than an RBLN device,
+        # store the torch tensor, weight, etc. in this function.
+
         save_dict = {}
         save_dict["embeddings"] = model.get_input_embeddings().state_dict()
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
@@ -190,6 +189,44 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
 
 
 class RBLNIdefics3ForConditionalGeneration(RBLNModel):
+    """
+    RBLNIdefics3ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
+        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNIdefics3ForConditionalGenerationConfig class for details.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNIdefics3ForConditionalGeneration
+
+        model = RBLNIdefics3ForConditionalGeneration.from_pretrained(
+            "HuggingFaceM4/idefics3-8b",
+            export=True,
+            rbln_config={
+                "vision_model": {
+                    "device": 0,
+                },
+                "text_model": {
+                    "batch_size": 1,
+                    "max_seq_len": 131_072,
+                    "tensor_parallel_size": 8,
+                    "use_inputs_embeds": True,
+                    "attn_impl": "flash_attn",
+                    "kvcache_partition_len": 16_384,
+                    "device": [0, 1, 2, 3, 4, 5, 6, 7],
+                },
+            },
+        )
+
+        model.save_pretrained("compiled-idefics3-8b")
+        ```
+    """
+
     auto_model_class = AutoModelForVision2Seq
     _rbln_submodules = [{"name": "vision_model"}, {"name": "text_model"}]
     _rbln_submodule_prefix = "model"