optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. optimum/rbln/__init__.py +116 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +171 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +33 -18
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +50 -24
  52. optimum/rbln/modeling_base.py +116 -35
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +100 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +93 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  158. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  159. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  160. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  161. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  162. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  163. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  164. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
  165. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  166. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  167. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  168. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  169. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  170. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  171. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  172. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  173. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  174. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  175. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  176. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
  177. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  178. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  179. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  180. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  181. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  182. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  183. optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
  184. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  185. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  186. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  187. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  188. optimum/rbln/utils/deprecation.py +213 -0
  189. optimum/rbln/utils/hub.py +22 -50
  190. optimum/rbln/utils/runtime_utils.py +85 -17
  191. optimum/rbln/utils/submodule.py +31 -9
  192. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  193. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  194. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  195. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  196. optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
  197. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
--- a/optimum/rbln/transformers/models/clip/modeling_clip.py
+++ b/optimum/rbln/transformers/models/clip/modeling_clip.py
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import torch
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPVisionConfig, CLIPVisionModel
+from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.models.clip.modeling_clip import CLIPTextModelOutput, CLIPVisionModelOutput
 
 from ....configuration_utils import RBLNCompileConfig
@@ -50,8 +51,10 @@ class RBLNCLIPTextModel(RBLNModel):
     on RBLN devices, supporting text encoding for multimodal tasks.
     """
 
+    _tp_support = False
+
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPTextModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPTextModelConfig) -> torch.nn.Module:
         return _TextEncoder(model).eval()
 
     @classmethod
@@ -82,7 +85,18 @@ class RBLNCLIPTextModel(RBLNModel):
         rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
         return rbln_config
 
-    def forward(self, input_ids: torch.LongTensor, return_dict: bool = None, **kwargs) -> torch.FloatTensor:
+    def forward(self, input_ids: torch.LongTensor, return_dict: Optional[bool] = None, **kwargs) -> torch.FloatTensor:
+        """
+        Forward pass for the RBLN-optimized CLIP text encoder model.
+
+        Args:
+            input_ids (torch.LongTensor): The input ids to the model.
+            return_dict (Optional[bool]): Whether to return a dictionary of outputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a CLIPTextModelOutput object.
+        """
+
         # To ignore using attention_mask, we override forward method.
         output = super().forward(input_ids, return_dict=return_dict)
         return output
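The new docstring spells out the contract: only input_ids reaches the compiled graph, and attention_mask is deliberately ignored. A minimal usage sketch; the checkpoint id, tokenizer settings, and export flow are assumptions, not part of this diff:

    from transformers import CLIPTokenizer
    from optimum.rbln import RBLNCLIPTextModel

    # Compile the text encoder for RBLN once; afterwards the compiled artifact
    # can be reloaded without export=True.
    model = RBLNCLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", export=True)
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    # Pad to the fixed sequence length the graph was traced with; the
    # attention_mask produced here is never forwarded to the model.
    inputs = tokenizer("a photo of a cat", padding="max_length", return_tensors="pt")
    outputs = model(input_ids=inputs.input_ids, return_dict=True)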
@@ -111,12 +125,27 @@ class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
 
 
 class _VisionEncoder(torch.nn.Module):
-    def __init__(self, enc: CLIPVisionModel):
+    def __init__(
+        self,
+        enc: CLIPVisionModel,
+        interpolate_pos_encoding: bool,
+        output_hidden_states: bool,
+        output_attentions: bool,
+    ):
         super().__init__()
         self.enc = enc
+        self.interpolate_pos_encoding = interpolate_pos_encoding
+        self.output_hidden_states = output_hidden_states
+        self.output_attentions = output_attentions
 
     def forward(self, inp):
-        enc_out = self.enc(inp, output_hidden_states=True, return_dict=False)
+        enc_out = self.enc(
+            inp,
+            output_hidden_states=self.output_hidden_states,
+            interpolate_pos_encoding=self.interpolate_pos_encoding,
+            output_attentions=self.output_attentions,
+            return_dict=False,
+        )
         return enc_out
 
 
@@ -128,9 +157,16 @@ class RBLNCLIPVisionModel(RBLNModel):
     on RBLN devices, supporting image encoding for multimodal tasks.
     """
 
+    _tp_support = False
+
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
-        return _VisionEncoder(model).eval()
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
+        wrapper_cfg = {
+            "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
+            "output_hidden_states": rbln_config.output_hidden_states,
+            "output_attentions": rbln_config.output_attentions,
+        }
+        return _VisionEncoder(model, **wrapper_cfg).eval()
 
     @classmethod
     def update_rbln_config_using_pipe(
@@ -155,6 +191,12 @@ class RBLNCLIPVisionModel(RBLNModel):
         if rbln_config.image_size is None:
             raise ValueError("`rbln_image_size` should be specified!")
 
+        if rbln_config.output_attentions is None:
+            rbln_config.output_attentions = getattr(model_config, "output_attentions", False)
+
+        if rbln_config.output_hidden_states is None:
+            rbln_config.output_hidden_states = getattr(model_config, "output_hidden_states", False)
+
         rbln_compile_config = RBLNCompileConfig(
             input_info=[
                 (
@@ -175,28 +217,91 @@ class RBLNCLIPVisionModel(RBLNModel):
 
     def forward(
         self,
-        pixel_values: Optional[torch.FloatTensor] = None,
-        return_dict: bool = None,
+        pixel_values: torch.FloatTensor,
+        return_dict: bool = True,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         **kwargs,
-    ) -> Union[Tuple, CLIPVisionModelOutput]:
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        """
+        Forward pass for the RBLN-optimized CLIP vision encoder model.
+
+        Args:
+            pixel_values (torch.Tensor): The pixel values to the model.
+            return_dict (bool): Whether to return a dictionary of outputs.
+            output_attentions (Optional[bool]): Whether to return attentions.
+            output_hidden_states (Optional[bool]): Whether to return hidden states.
+            interpolate_pos_encoding (bool): Whether to interpolate position encoding.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
+        """
+
         if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
             logger.warning(
                 f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
             )
+
+        output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+
+        if output_attentions != self.rbln_config.output_attentions:
+            raise ValueError(
+                f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
+                f"Please compile again with the correct argument."
+            )
+
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
+
+        if interpolate_pos_encoding != self.rbln_config.interpolate_pos_encoding:
+            raise ValueError(
+                f"Variable interpolate_pos_encoding {interpolate_pos_encoding} is not equal to rbln_config.interpolate_pos_encoding {self.rbln_config.interpolate_pos_encoding} "
+                f"Please compile again with the correct argument."
+            )
+
         output = super().forward(pixel_values, return_dict=return_dict)
         return output
 
     def _prepare_output(self, output, return_dict):
         # Prepare model output based on return_dict flag.
         # This method can be overridden by subclasses to provide task-specific output handling.
+        last_hidden_state = output.pop(0)
+        pooler_output = output.pop(0)
+        vision_config = self.config.vision_config if hasattr(self.config, "vision_config") else self.config
+
+        if self.rbln_config.output_hidden_states:
+            hidden_states = ()
+            num_hidden_layers = vision_config.num_hidden_layers
+            for _ in range(num_hidden_layers + 1):
+                hidden_states += (output.pop(0),)
+        else:
+            hidden_states = None
+
+        if self.rbln_config.output_attentions:
+            attentions = ()
+            num_hidden_layers = vision_config.num_hidden_layers
+            for _ in range(num_hidden_layers):
+                attentions += (output.pop(0),)
+        else:
+            attentions = None
 
         if not return_dict:
-            return (output,) if not isinstance(output, (tuple, list)) else output
+            return tuple(
+                item for item in (last_hidden_state, pooler_output, hidden_states, attentions) if item is not None
+            )
         else:
-            return CLIPVisionModelOutput(
-                image_embeds=output[0],
-                last_hidden_state=output[1],
-                hidden_states=output[2:],
+            return BaseModelOutputWithPooling(
+                last_hidden_state=last_hidden_state,
+                pooler_output=pooler_output,
+                hidden_states=hidden_states,
+                attentions=attentions,
             )
 
 
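The upshot is that output_attentions, output_hidden_states, and interpolate_pos_encoding are now compile-time properties of the graph: forward still accepts them, but raises if they disagree with the flags the model was compiled with. A sketch of the intended pairing; the checkpoint id and rbln_config kwargs are assumptions:

    import torch
    from optimum.rbln import RBLNCLIPVisionModel

    # The flags are fixed when the vision tower is compiled...
    model = RBLNCLIPVisionModel.from_pretrained(
        "openai/clip-vit-base-patch32",
        export=True,
        rbln_config={"output_hidden_states": True, "output_attentions": False},
    )

    pixel_values = torch.randn(1, 3, 224, 224)

    # ...and must either be repeated verbatim or omitted (falling back to rbln_config).
    out = model(pixel_values, output_hidden_states=True)
    out = model(pixel_values, output_hidden_states=False)  # raises ValueError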
@@ -210,19 +315,70 @@ class RBLNCLIPVisionModelWithProjection(RBLNCLIPVisionModel):
 
     def forward(
         self,
-        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_values: torch.FloatTensor,
+        return_dict: bool = True,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
         **kwargs,
     ) -> Union[Tuple, CLIPVisionModelOutput]:
-        if len(kwargs) > 0 and any(kwargs.values()):
-            logger.warning(f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__}.")
-
-        output = super().forward(pixel_values)
-        image_embeds = output[0]
-        last_hidden_state = output[1]
-        hidden_states = output[2:]
-
-        return CLIPVisionModelOutput(
-            image_embeds=image_embeds,
-            last_hidden_state=last_hidden_state,
-            hidden_states=hidden_states,
+        """
+        Forward pass for the RBLN-optimized CLIP vision encoder model with projection.
+
+        Args:
+            pixel_values (torch.Tensor): The pixel values to the model.
+            return_dict (bool): Whether to return a dictionary of outputs.
+            output_attentions (Optional[bool]): Whether to return attentions.
+            output_hidden_states (Optional[bool]): Whether to return hidden states.
+            interpolate_pos_encoding (bool): Whether to interpolate position encoding.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a CLIPVisionModelOutput object.
+        """
+
+        return super().forward(
+            pixel_values=pixel_values,
+            return_dict=return_dict,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            **kwargs,
         )
+
+    def _prepare_output(self, output, return_dict):
+        # Prepare model output based on return_dict flag.
+        # This method can be overridden by subclasses to provide task-specific output handling.
+
+        image_embeds = output.pop(0) if isinstance(output, (tuple, list)) else output
+        last_hidden_state = output.pop(0)
+
+        vision_config = self.config.vision_config if hasattr(self.config, "vision_config") else self.config
+
+        if self.rbln_config.output_hidden_states:
+            hidden_states = ()
+            num_hidden_layers = vision_config.num_hidden_layers
+            for _ in range(num_hidden_layers + 1):
+                hidden_states += (output.pop(0),)
+        else:
+            hidden_states = None
+
+        if self.rbln_config.output_attentions:
+            attentions = ()
+            num_hidden_layers = vision_config.num_hidden_layers
+            for _ in range(num_hidden_layers):
+                attentions += (output.pop(0),)
+        else:
+            attentions = None
+
+        if not return_dict:
+            return tuple(
+                item for item in (image_embeds, last_hidden_state, hidden_states, attentions) if item is not None
+            )
+
+        else:
+            return CLIPVisionModelOutput(
+                image_embeds=image_embeds,
+                last_hidden_state=last_hidden_state,
+                hidden_states=hidden_states,
+                attentions=attentions,
+            )
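_prepare_output assumes the runtime hands back one flat sequence of tensors in a fixed order: image_embeds (projection variant only), last_hidden_state, then the optional hidden states and attentions. A self-contained sketch of the positional unpacking for the base vision model, with dummy tensors and an assumed layer count:

    import torch

    num_layers = 12  # vision_config.num_hidden_layers for a ViT-B CLIP tower
    # Flat runtime output in the order the wrapper emits it, with both
    # optional groups enabled at compile time:
    flat = [torch.zeros(1) for _ in range(2 + (num_layers + 1) + num_layers)]

    last_hidden_state = flat.pop(0)
    pooler_output = flat.pop(0)
    hidden_states = tuple(flat.pop(0) for _ in range(num_layers + 1))  # embeddings + one state per layer
    attentions = tuple(flat.pop(0) for _ in range(num_layers))  # one attention map per layer
    assert not flat  # every tensor in the flat output is accounted for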
--- a/optimum/rbln/transformers/models/colpali/colpali_architecture.py
+++ b/optimum/rbln/transformers/models/colpali/colpali_architecture.py
@@ -4,10 +4,7 @@ import torch
 from torch import nn
 from transformers import GemmaForCausalLM, GemmaModel
 
-from ..decoderonly.decoderonly_architecture import (
-    RotaryEmbedding,
-    apply_rotary_pos_emb,
-)
+from ..decoderonly.decoderonly_architecture import RotaryEmbedding, apply_rotary_pos_emb
 
 
 def slice_and_unsqueeze_cos_sin(cos, sin, position_ids):
@@ -27,11 +24,11 @@ class RBLNColPaliForRetrievalWrapper(nn.Module):
         output_hidden_states: bool = False,
     ):
         super().__init__()
-        self.text_config = causal_lm.config
+        self.text_config = causal_lm.config.text_config
        self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
 
         self.output_hidden_states = output_hidden_states
-        self.language_model = self.convert_to_rbln_language_model(causal_lm.model, max_seq_len)
+        self.language_model = self.convert_to_rbln_language_model(causal_lm.model.language_model, max_seq_len)
 
         self.num_hidden_layers = getattr(self.text_config, "num_hidden_layers", None)
         self.embedding_proj_layer = embedding_proj_layer
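Both one-line changes appear to track the upstream transformers refactor of composite vision-language models, where the text config and the decoder moved one level deeper. A hypothetical helper sketching the attribute layout the wrapper now assumes:

    def check_colpali_wrapper_assumptions(causal_lm) -> None:
        # Hypothetical illustration only: the wrapper now reads the nested text
        # config and the language model under .model, where older releases read
        # causal_lm.config and causal_lm.model directly.
        assert hasattr(causal_lm.config, "text_config")
        assert hasattr(causal_lm.model, "language_model")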
--- a/optimum/rbln/transformers/models/colpali/configuration_colpali.py
+++ b/optimum/rbln/transformers/models/colpali/configuration_colpali.py
@@ -11,9 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, List, Optional, Union
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
@@ -24,45 +28,57 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
     including vision tower settings and multi-sequence length support.
 
     Example usage:
-    ```python
-    from optimum.rbln import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig
-
-    # Create a configuration object
-    config = RBLNColPaliForRetrievalConfig(
-        max_seq_lens=1152,
-        output_hidden_states=False,
-        tensor_parallel_size=4
-    )
-
-    # Use the configuration with from_pretrained
-    model = RBLNColPaliForRetrieval.from_pretrained(
-        "vidore/colpali-v1.3-hf",
-        export=True,
-        rbln_config=config
-    )
-    ```
+        ```python
+        from optimum.rbln import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig
+
+        # Create a configuration object
+        config = RBLNColPaliForRetrievalConfig(
+            max_seq_lens=1152,
+            output_hidden_states=False,
+            tensor_parallel_size=4
+        )
+
+        # Use the configuration with from_pretrained
+        model = RBLNColPaliForRetrieval.from_pretrained(
+            "vidore/colpali-v1.3-hf",
+            export=True,
+            rbln_config=config
+        )
+        ```
     """
 
     submodules = ["vision_tower"]
 
     def __init__(
         self,
+        batch_size: Optional[int] = None,
         max_seq_lens: Union[int, List[int]] = None,
         output_hidden_states: Optional[bool] = None,
         vision_tower: Optional[RBLNModelConfig] = None,
-        **kwargs: Dict[str, Any],
+        **kwargs: Any,
     ):
         """
         Args:
+            batch_size (Optional[int]): The batch size for the model.
             vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
             max_seq_lens (Union[int, List[int]]): The maximum sequence lengths for the language model.
                 This can be multiple values, and the model will be compiled for each max_seq_len, allowing selection of the most appropriate max_seq_len at inference time.
             output_hidden_states (Optional[bool]): Whether to output the hidden states of the language model.
-            **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+            vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
        Raises:
             ValueError: If batch_size is not a positive integer.
         """
         super().__init__(**kwargs)
-        self.vision_tower = vision_tower
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for ColPali vision tower. It will be set to 1.")
+
+        self.vision_tower = self.initialize_submodule_config(
+            submodule_config=vision_tower, batch_size=1, force_kwargs=True
+        )
         self.max_seq_lens = max_seq_lens
         self.output_hidden_states = output_hidden_states
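The config now validates batch_size itself and always compiles the vision tower with batch size 1, warning when the two disagree. A usage sketch reusing the checkpoint from the docstring above; the batch_size value is chosen only for illustration:

    from optimum.rbln import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig

    config = RBLNColPaliForRetrievalConfig(
        batch_size=4,  # accepted, but the vision tower is still pinned to batch_size=1
        max_seq_lens=1152,
        tensor_parallel_size=4,
    )
    model = RBLNColPaliForRetrieval.from_pretrained(
        "vidore/colpali-v1.3-hf",
        export=True,
        rbln_config=config,
    )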