optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
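Among the 157 files above, items 2 and 151-153 (__version__.py, METADATA, RECORD, WHEEL) carry the packaging-level version bump. A trivial check of which side of this diff is installed, assuming the wheel is installed under its PyPI name optimum-rbln:

from importlib.metadata import version

# Expects "0.9.3rc0" before the upgrade and "0.9.5a4" after.
print(version("optimum-rbln"))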
optimum/rbln/transformers/models/colpali/modeling_colpali.py

@@ -14,24 +14,16 @@
 
 import bisect
 from pathlib import Path
-from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
-from transformers import PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.modeling_utils import no_init_weights
-from transformers.models.colpali.modeling_colpali import ColPaliForRetrievalOutput
-from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModalProjector
+from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliForRetrievalOutput
 
-from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+from ....configuration_utils import RBLNModelConfig
 from ....modeling import RBLNModel
 from ...utils.rbln_runtime_wrapper import LoopProcessor
-from .colpali_architecture import RBLNColPaliForRetrievalWrapper
-
-
-if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
 
 class LoopVisionTower(LoopProcessor):
@@ -116,17 +108,25 @@ class RBLNColPaliForRetrieval(RBLNModel):
     from optimum.rbln import RBLNColPaliForRetrieval
 
     # Simple usage using rbln_* arguments
-    # `max_seq_lens` is automatically inferred from the model config
     model = RBLNColPaliForRetrieval.from_pretrained(
         "vidore/colpali-v1.3-hf",
         export=True,
-        rbln_max_seq_lens=1152,
+        rbln_config={
+            "vlm": {
+                "language_model": {
+                    "prefill_chunk_size": 8192, # same as model's max_position_embeddings (max_seq_len)
+                }
+            }
+        }
     )
 
     # Using a config dictionary
    rbln_config = {
-        "max_seq_lens": 1152,
-        "output_hidden_states": False,
+        "vlm": {
+            "language_model": {
+                "prefill_chunk_size": 8192, # same as model's max_position_embeddings (max_seq_len)
+            }
+        }
    }
    model = RBLNColPaliForRetrieval.from_pretrained(
        "vidore/colpali-v1.3-hf",
@@ -138,7 +138,9 @@ class RBLNColPaliForRetrieval(RBLNModel):
     from optimum.rbln import RBLNColPaliForRetrievalConfig
 
     config = RBLNColPaliForRetrievalConfig(
-        max_seq_lens=1152,
+        vlm={
+            "language_model": {"prefill_chunk_size": 8192},
+        },
         output_hidden_states=False,
         tensor_parallel_size=4
     )
@@ -151,250 +153,93 @@ class RBLNColPaliForRetrieval(RBLNModel):
     """
 
     auto_model_class = None
+    _rbln_submodule_postfix = "model"
     _rbln_submodules = [
-        {"name": "vision_tower"},
+        {"name": "vlm"},
     ]
 
     def __post_init__(self, **kwargs):
-        self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
-        self.language_model = LoopLanguageModel(self.model[0], self.rbln_config)
-
+        self.vlm_model = self.rbln_submodules[0]
         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-        self.embed_tokens = self._create_embedding_layer()
-        self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-        self.multi_modal_projector = self._create_multi_modal_projector()
-        self.multi_modal_projector.load_state_dict(artifacts["multi_modal_projector"])
-
+        self.embedding_proj_layer = self._create_embedding_proj_layer()
+        self.embedding_proj_layer.load_state_dict(artifacts["embedding_proj_layer"])
         return super().__post_init__(**kwargs)
 
-    def _create_embedding_layer(self):
+    def _create_embedding_proj_layer(self):
         with no_init_weights():
-            embed_tokens = torch.nn.Embedding(
-                self.config.text_config.vocab_size,
-                self.config.text_config.hidden_size,
-                self.config.text_config.pad_token_id,
+            embedding_proj_layer = torch.nn.Linear(
+                self.config.vlm_config.text_config.hidden_size, self.config.embedding_dim
             )
-        return embed_tokens
-
-    def _create_multi_modal_projector(self):
-        with no_init_weights():
-            multi_modal_projector = PaliGemmaMultiModalProjector(self.config.vlm_config)
-        return multi_modal_projector
-
-    @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
-        return RBLNColPaliForRetrievalWrapper(
-            causal_lm=model.vlm,
-            embedding_proj_layer=model.embedding_proj_layer,
-            max_seq_len=max(rbln_config.max_seq_lens),
-            output_hidden_states=rbln_config.output_hidden_states,
-        )
+        return embedding_proj_layer
 
     @classmethod
     def save_torch_artifacts(
         cls,
-        model: "PreTrainedModel",
+        model: "ColPaliForRetrieval",
         save_dir_path: Path,
         subfolder: str,
         rbln_config: RBLNModelConfig,
     ):
         save_dict = {}
-        save_dict["embed_tokens"] = model.vlm.get_input_embeddings().state_dict()
-        save_dict["multi_modal_projector"] = model.vlm.multi_modal_projector.state_dict()
+        save_dict["embedding_proj_layer"] = model.embedding_proj_layer.state_dict()
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
 
-    @classmethod
-    def _update_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-        model: Optional["PreTrainedModel"] = None,
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[RBLNModelConfig] = None,
-    ) -> RBLNModelConfig:
-        hidden_size = model_config.vlm_config.text_config.hidden_size
-        if rbln_config.max_seq_lens is None:
-            rbln_config.max_seq_lens = [model_config.vlm_config.text_config.max_position_embeddings]
-        if isinstance(rbln_config.max_seq_lens, int):
-            rbln_config.max_seq_lens = [rbln_config.max_seq_lens]
-        rbln_config.max_seq_lens = sorted(set(rbln_config.max_seq_lens))
-
-        if rbln_config.output_hidden_states is None:
-            rbln_config.output_hidden_states = model_config.vlm_config.text_config.output_hidden_states
-
-        input_infos = []
-        for max_seq_len in rbln_config.max_seq_lens:
-            input_info = [
-                ("inputs_embeds", [rbln_config.vision_tower.batch_size, max_seq_len, hidden_size], "float32"),
-                ("attention_mask", [rbln_config.vision_tower.batch_size, max_seq_len], "float32"),
-                ("position_ids", [rbln_config.vision_tower.batch_size, max_seq_len], "int32"),
-            ]
-            input_infos.append(input_info)
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_infos)
-        rbln_config.set_compile_cfgs([rbln_compile_config])
-
-        return rbln_config
-
-    @classmethod
-    def from_model(
-        cls,
-        model: "PreTrainedModel",
-        config: Optional[PretrainedConfig] = None,
-        rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        subfolder: str = "",
-        **kwargs: Any,
-    ) -> "RBLNModel":
-        """
-        Converts and compiles a pre-trained HuggingFace library model into a RBLN model.
-        This method performs the actual model conversion and compilation process.
-
-        Args:
-            model (PreTrainedModel): The PyTorch model to be compiled.
-                The object must be an instance of the HuggingFace transformers PreTrainedModel class.
-            config (Optional[PretrainedConfig]): The configuration object associated with the model.
-            rbln_config (Optional[Union[RBLNModelConfig, Dict]]): Configuration for RBLN model compilation and runtime.
-                This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
-                For detailed configuration options, see the specific model's configuration class documentation.
-            kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
-
-        The method performs the following steps:
-
-        1. Compiles the PyTorch model into an optimized RBLN graph
-        2. Configures the model for the specified NPU device
-        3. Creates the necessary runtime objects if requested
-        4. Saves the compiled model and configurations
-
-        Returns:
-            (RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
-        """
-        if not hasattr(model, "vision_tower"):
-            model.vision_tower = model.vlm.vision_tower
-            del model.vlm.model.vision_tower
-        model = super().from_model(model, config, rbln_config, model_save_dir, subfolder, **kwargs)
-        return model
-
-    @classmethod
-    def get_pytorch_model(cls, *args, **kwargs):
-        model = super().get_pytorch_model(*args, **kwargs)
-        model.vision_tower = model.vlm.vision_tower
-        del model.vlm.model.vision_tower
-        return model
-
-    def get_image_features(self, pixel_values: torch.Tensor):
-        # Projects the last hidden state from the vision model into language model space.
-        # Args:
-        #     pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
-        #         The tensors corresponding to the input images.
-        # Returns:
-        #     image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
-
-        vision_output_size = [
-            pixel_values.shape[0],
-            self.config.vlm_config.vision_config.num_image_tokens,
-            self.config.vlm_config.vision_config.hidden_size,
-        ]
-        vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu")
-        self.vision_tower(pixel_values, out=vision_output)
-        image_features = self.multi_modal_projector(vision_output)
-        image_features = image_features / (self.config.text_config.hidden_size**0.5)
-        return image_features
-
-    def _preprocess_inputs(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        pixel_values: Optional[torch.FloatTensor] = None,
-        **kwargs,
-    ):
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-        # Replace image id woth PAD if the image token if OOV, to avoid index-errors
-        if input_ids is not None and self.config.vlm_config.image_token_index >= self.config.text_config.vocab_size:
-            special_image_mask = input_ids == self.config.vlm_config.image_token_index
-            llm_input_ids = input_ids.clone()
-            llm_input_ids[special_image_mask] = 0
-        else:
-            llm_input_ids = input_ids
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(llm_input_ids)
-
-        # Merge text and images
-        image_features = None
-        if pixel_values is not None:
-            image_features = self.get_image_features(pixel_values)
-            special_image_mask = (input_ids == self.config.vlm_config.image_token_index).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-
-            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
-        return inputs_embeds, image_features
-
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
         pixel_values: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, ColPaliForRetrievalOutput]:
+        """
+        Forward pass for the RBLN-optimized ColPaliForRetrieval model.
+
+        Args:
+            input_ids (torch.LongTensor of shape (batch_size, sequence_length)): Indices of input sequence tokens in the vocabulary.
+            pixel_values (torch.Tensor of shape (batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length)): Mask to avoid performing attention on padding token indices.
+            output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+            return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            ColPaliForRetrievalOutput or tuple(torch.FloatTensor)
+        """
         if pixel_values is not None:
             pixel_values = pixel_values.to(dtype=self.dtype)
 
-        if output_attentions:
-            raise ValueError("output_attentions is not supported for RBLNColPaliForRetrieval")
-
-        if output_hidden_states is not None and output_hidden_states != self.rbln_config.output_hidden_states:
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.output_hidden_states:
             raise ValueError(
                 f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
                 f"Please compile again with the correct argument."
             )
 
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        inputs_embeds, image_features = self._preprocess_inputs(
-            input_ids=input_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values
+        vlm_output = self.vlm_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+            **kwargs,
        )
+        vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
+        vlm_image_hidden_states = vlm_output.image_hidden_states if pixel_values is not None else None
 
-        outputs = []
-        language_model_out_size = [inputs_embeds.shape[0], self.rbln_config.max_seq_lens[0], self.config.embedding_dim]
-        language_model_hidden_states_size = [
-            inputs_embeds.shape[0],
-            self.rbln_config.max_seq_lens[0],
-            self.rbln_config.max_seq_lens[0],
-        ]
-        outputs.append(torch.empty(size=language_model_out_size, dtype=torch.float32, device="cpu"))
-        if self.rbln_config.output_hidden_states:
-            for i in range(self.config.vlm_config.text_config.num_hidden_layers + 1):
-                outputs.append(torch.empty(size=language_model_hidden_states_size, dtype=torch.float32, device="cpu"))
-
-        # Embedding_proj_layer is fused on the bottom of the language model.
-        self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, out=outputs)
-
-        embeddings = outputs[0][:, : inputs_embeds.shape[1]]
-        hidden_states = (
-            None
-            if not self.rbln_config.output_hidden_states
-            else [tensor[0][:, : inputs_embeds.shape[1]] for tensor in outputs[1:]]
-        )
-
-        # L2 normalization
-        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
+        last_hidden_states = vlm_output[0]
+        proj_dtype = self.embedding_proj_layer.weight.dtype
+        embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype))
+        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
 
         if attention_mask is not None:
-            embeddings = embeddings * attention_mask.unsqueeze(-1) # (batch_size, sequence_length, dim)
+            embeddings = embeddings * attention_mask.unsqueeze(-1)
 
-        if not return_dict:
-            return (embeddings, hidden_states, image_features)
-        else:
-            return ColPaliForRetrievalOutput(
-                embeddings=embeddings,
-                hidden_states=hidden_states,
-                image_hidden_states=image_features,
-            )
+        return ColPaliForRetrievalOutput(
+            embeddings=embeddings,
+            hidden_states=vlm_hidden_states,
+            image_hidden_states=vlm_image_hidden_states,
+        )
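Taken together, the modeling_colpali.py hunks above replace the hand-rolled vision-tower and language-model runtimes with a single compiled `vlm` submodule; only the `embedding_proj_layer`, L2 normalization, and attention masking still run on the host. A minimal usage sketch of the updated interface, assuming the `vidore/colpali-v1.3-hf` checkpoint from the docstring and that `transformers.ColPaliProcessor` is used for preprocessing (the query text is illustrative only):

import torch
from transformers import ColPaliProcessor

from optimum.rbln import RBLNColPaliForRetrieval

# Compile once; prefill_chunk_size is set to the model's
# max_position_embeddings, as the docstring above recommends.
model = RBLNColPaliForRetrieval.from_pretrained(
    "vidore/colpali-v1.3-hf",
    export=True,
    rbln_config={"vlm": {"language_model": {"prefill_chunk_size": 8192}}},
)
processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.3-hf")

# Text-only query; document images would be passed as pixel_values.
batch = processor(text=["Where is the revenue table?"], return_tensors="pt")
with torch.no_grad():
    out = model(**batch)

# L2-normalized multi-vector embeddings, zeroed on padding positions.
print(out.embeddings.shape)  # (batch_size, sequence_length, embedding_dim)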
optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py

@@ -32,14 +32,16 @@ class RBLNColQwen2ForRetrievalConfig(RBLNDecoderOnlyModelConfig):
 
     # Create a configuration object
     config = RBLNColQwen2ForRetrievalConfig(
-        visual={
-            "max_seq_lens": 6400,
-            "device": 0,
-        },
-        max_seq_len=32_768,
-        tensor_parallel_size=4,
-        device=[0, 1, 2, 3],
-        output_hidden_states=False,
+        vlm = {
+            "visual": {
+                "max_seq_lens": 6400,
+                "device": 0,
+            },
+            "max_seq_len": 32_768,
+            "tensor_parallel_size": 4,
+            "device": [0, 1, 2, 3],
+            "output_hidden_states": False,
+        }
    )
 
     # Use the configuration with from_pretrained
@@ -51,24 +53,37 @@ class RBLNColQwen2ForRetrievalConfig(RBLNDecoderOnlyModelConfig):
     ```
     """
 
-    submodules = ["visual"]
+    submodules = ["vlm"]
+    _allow_no_compile_cfgs = True
 
     def __init__(
         self,
-        visual: Optional[RBLNModelConfig] = None,
         batch_size: Optional[int] = None,
-        use_inputs_embeds: bool = True,
-        output_hidden_states: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = None,
+        vlm: Optional[RBLNModelConfig] = None,
        **kwargs,
    ):
-        super().__init__(use_inputs_embeds=use_inputs_embeds, **kwargs)
-        if not self.use_inputs_embeds:
-            raise ValueError(
-                "RBLNColQwen2ForRetrievalConfig does not allow `use_inputs_embeds` to be set to False, "
-                "as RBLNColQwen2ForRetrieval accepts only `inputs_embeds` as input."
-            )
-        if batch_size is not None and batch_size != 1:
-            raise ValueError("batch_size is not supported for RBLNColQwen2ForRetrievalConfig")
-
-        self.visual = visual
-        self.output_hidden_states = output_hidden_states
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for the model.
+            output_hidden_states (Optional[bool]): Whether to output the hidden states of the VLM model.
+            vlm (Optional[RBLNModelConfig]): Configuration for the VLM component.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        self.output_hidden_states = output_hidden_states or False
+
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.vlm = self.initialize_submodule_config(
+            submodule_config=vlm,
+            batch_size=batch_size,
+            output_hidden_states=output_hidden_states,
+            force_kwargs=True,
+            logits_to_keep=0,
+            use_inputs_embeds=True,
+        )
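The configuration_colqwen2.py hunks mirror the same refactor on the ColQwen2 side: options that used to be top-level constructor arguments (`visual`, `max_seq_len`, `tensor_parallel_size`, `device`, `output_hidden_states`) are now grouped under a nested `vlm` config forwarded through `initialize_submodule_config`. A short sketch of building the new config and passing it to `from_pretrained`, assuming the `RBLNColQwen2ForRetrieval` model class exported alongside this config (the checkpoint id is a placeholder):

from optimum.rbln import RBLNColQwen2ForRetrieval, RBLNColQwen2ForRetrievalConfig

# All VLM-level options are grouped under `vlm`, matching the updated docstring above.
config = RBLNColQwen2ForRetrievalConfig(
    vlm={
        "visual": {"max_seq_lens": 6400, "device": 0},
        "max_seq_len": 32_768,
        "tensor_parallel_size": 4,
        "device": [0, 1, 2, 3],
        "output_hidden_states": False,
    },
)

model = RBLNColQwen2ForRetrieval.from_pretrained(
    "vidore/colqwen2-v1.0-hf",  # placeholder checkpoint id
    export=True,
    rbln_config=config,
)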