optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. optimum/rbln/__init__.py +108 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +156 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +48 -21
  52. optimum/rbln/modeling_base.py +99 -22
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +92 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +91 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  157. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  158. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  159. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  160. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  161. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  162. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  163. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
  164. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  165. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  166. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  167. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  168. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  169. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  170. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  171. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  172. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  173. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  174. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  175. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
  176. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  177. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  178. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  179. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  180. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  181. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  182. optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
  183. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  184. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  185. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  186. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  187. optimum/rbln/utils/deprecation.py +213 -0
  188. optimum/rbln/utils/hub.py +14 -3
  189. optimum/rbln/utils/runtime_utils.py +60 -18
  190. optimum/rbln/utils/submodule.py +31 -9
  191. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  192. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  193. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  194. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  195. optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
  196. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/idefics3/configuration_idefics3.py
@@ -12,13 +12,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 class RBLNIdefics3VisionTransformerConfig(RBLNModelConfig):
-    pass
+    """
+    Configuration class for RBLNIdefics3VisionTransformer.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized Idefics3 vision transformer.
+    """
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
 
 class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
@@ -39,17 +58,21 @@ class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
         batch_size: Optional[int] = None,
         vision_model: Optional[RBLNModelConfig] = None,
         text_model: Optional[RBLNModelConfig] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
             vision_model (Optional[RBLNModelConfig]): Configuration for the vision transformer component.
+                This can include settings specific to the vision encoder, such as input resolution or other vision-related parameters.
+                If not provided, default settings will be used.
             text_model (Optional[RBLNModelConfig]): Configuration for the text model component.
-            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+                This can include settings specific to the language model, such as tensor parallelism or other text-related parameters.
+                If not provided, default settings will be used.
+            kwargs: Additional arguments passed to the parent `RBLNModelConfig`.
 
         Raises:
-            ValueError: If batch_size is not a positive integer.
+            ValueError: If `batch_size` is not a positive integer.
         """
 
         super().__init__(**kwargs)
@@ -57,5 +80,10 @@ class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
-        self.vision_model = vision_model
-        self.text_model = text_model
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for Idefics3 vision transformer. It will be set to 1.")
+
+        self.vision_model = self.initialize_submodule_config(
+            submodule_config=vision_model, batch_size=1, force_kwargs=True
+        )
+        self.text_model = self.initialize_submodule_config(submodule_config=text_model)
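Editor's note: with `initialize_submodule_config` replacing the bare attribute assignments, per-submodule options can now be supplied as config objects (or, presumably, plain dicts), while the vision transformer's batch size is force-pinned to 1. A minimal usage sketch under those assumptions (the dict form, the `tensor_parallel_size` kwarg, and the top-level re-export from `optimum.rbln` are assumptions, not confirmed by this diff):

    # Hedged sketch: constructing the composite Idefics3 config.
    from optimum.rbln import RBLNIdefics3ForConditionalGenerationConfig

    config = RBLNIdefics3ForConditionalGenerationConfig(
        batch_size=2,  # applies to the text model; batch_size != 1 triggers the warning above
        text_model={"tensor_parallel_size": 4},  # assumed dict form for the text submodule
    )
    assert config.vision_model.batch_size == 1  # pinned via force_kwargs=True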
optimum/rbln/transformers/models/idefics3/modeling_idefics3.py
@@ -34,17 +34,12 @@ from transformers.models.idefics3.modeling_idefics3 import Idefics3CausalLMOutpu
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ..decoderonly.modeling_decoderonly import (
-    RBLNDecoderOnlyOutput,
-)
+from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
 
 
 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-    )
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
 class RBLNRuntimeVisionModel(RBLNPytorchRuntime):
@@ -81,10 +76,12 @@ class RBLNRuntimeVisionModel(RBLNPytorchRuntime):
 
         hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
 
-        return super().forward(hidden_states.contiguous())
+        return super().forward(hidden_states.contiguous(), **kwargs)
 
 
 class RBLNIdefics3VisionTransformer(RBLNModel):
+    _tp_support = False
+
     def __post_init__(self, **kwargs):
         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
         with no_init_weights():
@@ -113,7 +110,7 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
         return self.embeddings
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         class Idefics3VisionTransformerWrapper(torch.nn.Module):
             def __init__(self, model: "Idefics3VisionTransformer"):
                 super().__init__()
@@ -124,9 +121,6 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
                 encoder_outputs = self.encoder(
                     inputs_embeds=hidden_states,
                     attention_mask=patch_attention_mask,
-                    output_attentions=None,
-                    output_hidden_states=None,
-                    return_dict=False,
                 )
                 last_hidden_state = encoder_outputs[0]
                 last_hidden_state = self.post_layernorm(last_hidden_state)
@@ -146,8 +140,7 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
             (
                 "hidden_states",
                 [
-                    # batch_size * num_patches (dependent on image size) -> compile with 1 and use for loop
-                    1,
+                    rbln_config.batch_size,
                     (model_config.image_size // model_config.patch_size) ** 2,
                     model_config.hidden_size,
                 ],
@@ -166,29 +159,31 @@
         return_dict: Optional[bool] = None,
         **kwargs,
     ) -> Union[Tuple, BaseModelOutput]:
-        batch_size = pixel_values.shape[0]
-        last_hidden_state = []
-        for i in range(batch_size):
+        last_hidden_state_size = [
+            pixel_values.shape[0],
+            (self.config.image_size // self.config.patch_size) ** 2,
+            self.config.hidden_size,
+        ]
+        last_hidden_state = torch.empty(size=last_hidden_state_size, dtype=torch.float32, device="cpu")
+        for i in range(pixel_values.shape[0]):
             if patch_attention_mask is not None:
                 batch_attention_mask = patch_attention_mask[i : i + 1,]
             else:
                 batch_attention_mask = None
 
-            batch_hidden_state = self.model(
+            self.model(
                 pixel_values[i : i + 1,],
                 batch_attention_mask,
+                out=last_hidden_state[i : i + 1,],
                 return_dict=False,
             )
-            last_hidden_state.append(batch_hidden_state)
-        last_hidden_state = torch.cat(last_hidden_state, dim=0)
-
         if not return_dict:
             return (last_hidden_state,)
         else:
             return BaseModelOutput(last_hidden_state=last_hidden_state)
 
 
-class RBLNIdefics3ForConditionalGeneration(RBLNModel):
+class RBLNIdefics3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     """
     RBLNIdefics3ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
@@ -245,9 +240,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):
         return True
 
     @classmethod
-    def get_pytorch_model(cls, *args, **kwargs):
-        model = super().get_pytorch_model(*args, **kwargs)
-
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
         with no_init_weights():
             model_cls_name = model.model.text_model.__class__.__name__
             causal_model_cls_name = model_cls_name.replace("Model", "ForCausalLM")
@@ -276,7 +269,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):
         return self.text_model.get_input_embeddings()
 
     @classmethod
-    def wrap_model_if_needed(cls, model, rbln_config):
+    def _wrap_model_if_needed(cls, model, rbln_config):
         return model.model.connector
 
     @classmethod
@@ -291,8 +284,7 @@
             (
                 "image_hidden_states",
                 [
-                    # batch_size * num_patches (dependent on image size) -> compile with 1 and use for loop
-                    1,
+                    rbln_config.vision_model.batch_size,
                     (model_config.vision_config.image_size // model_config.vision_config.patch_size) ** 2,
                     model_config.vision_config.hidden_size,
                 ],
@@ -431,10 +423,15 @@
                 pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, return_dict=True
             ).last_hidden_state
 
-            connector_outputs = []
+            connector_output_size = [
+                image_hidden_states.shape[0],
+                image_hidden_states.shape[1] // self.config.scale_factor**2,
+                self.config.text_config.hidden_size,
+            ]
+            connector_outputs = torch.empty(size=connector_output_size, dtype=torch.float32, device="cpu")
             for i in range(image_hidden_states.shape[0]):
-                connector_outputs.append(self.connector(image_hidden_states[i : i + 1,]))
-            image_hidden_states = torch.cat(connector_outputs, dim=0)
+                self.connector(image_hidden_states[i : i + 1,], out=connector_outputs[i : i + 1,])
+            image_hidden_states = connector_outputs
 
         elif image_hidden_states is not None:
             image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
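Editor's note: the two loops changed above share one refactor: a Python-list `append` followed by `torch.cat` is replaced with a preallocated output tensor that each per-sample runtime call fills through its `out=` argument, so results land directly in the batch buffer instead of being concatenated afterwards. A self-contained sketch of the same pattern using plain `torch.matmul` (which also accepts `out=`) in place of the RBLN runtime call:

    import torch

    def loop_into_preallocated(xs: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
        # xs: [batch, n, d], weight: [d, d_out]
        out = torch.empty(xs.shape[0], xs.shape[1], weight.shape[1], dtype=torch.float32)
        for i in range(xs.shape[0]):
            # each iteration writes straight into its slice of the batch buffer,
            # avoiding the intermediate list and the final torch.cat copy
            torch.matmul(xs[i : i + 1], weight, out=out[i : i + 1])
        return out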
optimum/rbln/transformers/models/llama/__init__.py
@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .configuration_llama import RBLNLlamaForCausalLMConfig
-from .modeling_llama import RBLNLlamaForCausalLM
+from .configuration_llama import RBLNLlamaForCausalLMConfig, RBLNLlamaModelConfig
+from .modeling_llama import RBLNLlamaForCausalLM, RBLNLlamaModel
optimum/rbln/transformers/models/llama/configuration_llama.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
 
 
 class RBLNLlamaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -40,3 +40,11 @@ class RBLNLlamaForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
         )
     ```
     """
+
+
+class RBLNLlamaModelConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN Llama models.
+
+    This class is an alias of RBLNDecoderOnlyModelConfig.
+    """
optimum/rbln/transformers/models/llama/modeling_llama.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from ....utils import logging
-from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
+from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
 from .llama_architecture import LlamaWrapper
 
 
@@ -81,3 +81,24 @@ class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     """
 
     _decoder_wrapper_cls = LlamaWrapper
+
+
+class RBLNLlamaModel(RBLNDecoderOnlyModel):
+    """
+    The Llama Model transformer outputting raw hidden-states without any specific head on top.
+    This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based LlamaModel on RBLN devices.
+    It implements the methods to convert a pre-trained transformers LlamaModel into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    **Configuration:**
+    This model uses [`RBLNLlamaModelConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+    the `rbln_config` parameter should be an instance of [`RBLNLlamaModelConfig`] or a dictionary conforming to its structure.
+
+    See the [`RBLNLlamaModelConfig`] class for all available configuration options.
+    """
+
+    _decoder_wrapper_cls = LlamaWrapper
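Editor's note: the new headless entry point mirrors the existing causal-LM one. A hedged usage sketch; `from_pretrained` with a `rbln_config` instance follows the docstring above, while the checkpoint id, the `export=True` compile flag, the `tensor_parallel_size` kwarg, and the output path are illustrative assumptions:

    # Compile a headless LlamaModel for RBLN NPUs and save the artifacts.
    from optimum.rbln import RBLNLlamaModel, RBLNLlamaModelConfig

    model = RBLNLlamaModel.from_pretrained(
        "meta-llama/Llama-3.1-8B",  # illustrative checkpoint
        export=True,  # assumed flag: compile from the PyTorch checkpoint
        rbln_config=RBLNLlamaModelConfig(tensor_parallel_size=4),  # assumed kwarg
    )
    model.save_pretrained("llama-3.1-8b-rbln")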
optimum/rbln/transformers/models/llava/__init__.py (new file)
@@ -0,0 +1,16 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_llava import RBLNLlavaForConditionalGenerationConfig
+from .modeling_llava import RBLNLlavaForConditionalGeneration
optimum/rbln/transformers/models/llava/configuration_llava.py (new file)
@@ -0,0 +1,72 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional
+
+from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
+
+
+class RBLNLlavaForConditionalGenerationConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNLlavaForConditionalGenerationConfig.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized LLaVA models for multimodal conditional generation tasks
+    that combine vision and language processing capabilities.
+    """
+
+    submodules = ["vision_tower", "language_model"]
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        vision_tower: Optional[RBLNModelConfig] = None,
+        language_model: Optional[RBLNModelConfig] = None,
+        **kwargs: Any,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+                This can include settings specific to the vision encoder, such as input resolution or other vision-related parameters.
+                If not provided, default settings will be used.
+            language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
+                This can include settings specific to the language model, such as tensor parallelism or other text-related parameters.
+                If not provided, default settings will be used.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If `batch_size` is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for Llava vision tower. It will be set to 1.")
+
+        self.vision_tower = self.initialize_submodule_config(
+            submodule_config=vision_tower,
+            batch_size=1,  # vision_tower batch_size is always 1 in Llava
+            force_kwargs=True,
+        )
+
+        self.language_model = self.initialize_submodule_config(
+            submodule_config=language_model,
+        )
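Editor's note: as with the Idefics3 change above, submodule settings route through `initialize_submodule_config`, with `vision_tower` force-pinned to batch size 1. A hedged construction sketch; the dict-form submodule config and the `tensor_parallel_size` kwarg are assumptions based on the patterns elsewhere in this diff, not confirmed API:

    from optimum.rbln import RBLNLlavaForConditionalGenerationConfig

    config = RBLNLlavaForConditionalGenerationConfig(
        batch_size=1,
        language_model={"tensor_parallel_size": 4},  # assumed decoder-only kwarg
    )
    assert config.vision_tower.batch_size == 1  # always pinned in Llava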