optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (82)
  1. optimum/rbln/__init__.py +36 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +35 -16
  4. optimum/rbln/modeling_base.py +6 -6
  5. optimum/rbln/ops/__init__.py +1 -0
  6. optimum/rbln/ops/attn.py +10 -0
  7. optimum/rbln/ops/flash_attn.py +8 -0
  8. optimum/rbln/ops/moe.py +180 -0
  9. optimum/rbln/ops/sliding_window_attn.py +9 -0
  10. optimum/rbln/transformers/__init__.py +36 -0
  11. optimum/rbln/transformers/modeling_attention_utils.py +118 -222
  12. optimum/rbln/transformers/modeling_outputs.py +25 -0
  13. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  14. optimum/rbln/transformers/models/__init__.py +28 -0
  15. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  16. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  17. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  18. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
  19. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
  20. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  21. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  22. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
  23. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
  24. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
  25. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
  26. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
  27. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  28. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  29. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  30. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  31. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  32. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  33. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
  34. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
  35. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  36. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  37. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  38. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  39. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  40. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  41. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
  42. optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
  43. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  44. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  45. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  46. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  47. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  48. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  50. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  51. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
  53. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  54. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  55. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  56. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  57. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  58. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  59. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  60. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
  61. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
  62. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  63. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  64. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  65. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  66. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  67. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  68. optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
  69. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  70. optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
  71. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  72. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  73. optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
  74. optimum/rbln/utils/import_utils.py +16 -1
  75. optimum/rbln/utils/runtime_utils.py +10 -6
  76. optimum/rbln/utils/submodule.py +24 -0
  77. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  78. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
  79. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  80. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
  81. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  82. {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/colpali/modeling_colpali.py
@@ -14,23 +14,16 @@
  
  import bisect
  from pathlib import Path
- from typing import TYPE_CHECKING, Optional, Tuple, Union
+ from typing import Optional, Tuple, Union
  
  import torch
- from transformers import PretrainedConfig, PreTrainedModel
  from transformers.modeling_outputs import BaseModelOutputWithPooling
  from transformers.modeling_utils import no_init_weights
- from transformers.models.colpali.modeling_colpali import ColPaliForRetrievalOutput
- from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModalProjector
+ from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliForRetrievalOutput
  
- from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+ from ....configuration_utils import RBLNModelConfig
  from ....modeling import RBLNModel
  from ...utils.rbln_runtime_wrapper import LoopProcessor
- from .colpali_architecture import RBLNColPaliForRetrievalWrapper
-
-
- if TYPE_CHECKING:
-     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
  
  
  class LoopVisionTower(LoopProcessor):
@@ -115,17 +108,25 @@ class RBLNColPaliForRetrieval(RBLNModel):
          from optimum.rbln import RBLNColPaliForRetrieval
  
          # Simple usage using rbln_* arguments
-         # `max_seq_lens` is automatically inferred from the model config
          model = RBLNColPaliForRetrieval.from_pretrained(
              "vidore/colpali-v1.3-hf",
              export=True,
-             rbln_max_seq_lens=1152,
+             rbln_config={
+                 "vlm": {
+                     "language_model": {
+                         "prefill_chunk_size": 8192,  # same as model's max_position_embeddings (max_seq_len)
+                     }
+                 }
+             }
          )
  
          # Using a config dictionary
          rbln_config = {
-             "max_seq_lens": 1152,
-             "output_hidden_states": False,
+             "vlm": {
+                 "language_model": {
+                     "prefill_chunk_size": 8192,  # same as model's max_position_embeddings (max_seq_len)
+                 }
+             }
          }
          model = RBLNColPaliForRetrieval.from_pretrained(
              "vidore/colpali-v1.3-hf",
@@ -137,7 +138,9 @@ class RBLNColPaliForRetrieval(RBLNModel):
          from optimum.rbln import RBLNColPaliForRetrievalConfig
  
          config = RBLNColPaliForRetrievalConfig(
-             max_seq_lens=1152,
+             vlm={
+                 "language_model": {"prefill_chunk_size": 8192},
+             },
              output_hidden_states=False,
              tensor_parallel_size=4
          )
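
The two docstring hunks above replace the old flat `rbln_max_seq_lens` argument with a nested `vlm` submodule configuration. A minimal end-to-end sketch of the new pattern, assuming the usual optimum.rbln compile-and-save flow (the `save_pretrained` call and the output directory name are not part of this diff):

```python
# Sketch only: mirrors the updated docstring example; the nested "vlm" /
# "language_model" keys and the prefill_chunk_size value come from the diff above.
from optimum.rbln import RBLNColPaliForRetrieval

model = RBLNColPaliForRetrieval.from_pretrained(
    "vidore/colpali-v1.3-hf",
    export=True,  # compile the checkpoint for RBLN NPUs
    rbln_config={
        "vlm": {
            "language_model": {
                # same as the model's max_position_embeddings (max_seq_len)
                "prefill_chunk_size": 8192,
            }
        }
    },
)

# Persist the compiled artifacts so they can be reloaded without recompiling.
model.save_pretrained("colpali-v1.3-rbln")
```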
@@ -150,212 +153,93 @@ class RBLNColPaliForRetrieval(RBLNModel):
      """
  
      auto_model_class = None
+     _rbln_submodule_postfix = "model"
      _rbln_submodules = [
-         {"name": "vision_tower"},
+         {"name": "vlm"},
      ]
  
      def __post_init__(self, **kwargs):
-         self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
-         self.language_model = LoopLanguageModel(self.model[0], self.rbln_config)
-
+         self.vlm_model = self.rbln_submodules[0]
          artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-         self.embed_tokens = self._create_embedding_layer()
-         self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-         self.multi_modal_projector = self._create_multi_modal_projector()
-         self.multi_modal_projector.load_state_dict(artifacts["multi_modal_projector"])
-
+         self.embedding_proj_layer = self._create_embedding_proj_layer()
+         self.embedding_proj_layer.load_state_dict(artifacts["embedding_proj_layer"])
          return super().__post_init__(**kwargs)
  
-     def _create_embedding_layer(self):
+     def _create_embedding_proj_layer(self):
          with no_init_weights():
-             embed_tokens = torch.nn.Embedding(
-                 self.config.text_config.vocab_size,
-                 self.config.text_config.hidden_size,
-                 self.config.text_config.pad_token_id,
+             embedding_proj_layer = torch.nn.Linear(
+                 self.config.vlm_config.text_config.hidden_size, self.config.embedding_dim
              )
-         return embed_tokens
-
-     def _create_multi_modal_projector(self):
-         with no_init_weights():
-             multi_modal_projector = PaliGemmaMultiModalProjector(self.config.vlm_config)
-         return multi_modal_projector
-
-     @classmethod
-     def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
-         return RBLNColPaliForRetrievalWrapper(
-             causal_lm=model.vlm,
-             embedding_proj_layer=model.embedding_proj_layer,
-             max_seq_len=max(rbln_config.max_seq_lens),
-             output_hidden_states=rbln_config.output_hidden_states,
-         )
+         return embedding_proj_layer
  
      @classmethod
      def save_torch_artifacts(
          cls,
-         model: "PreTrainedModel",
+         model: "ColPaliForRetrieval",
          save_dir_path: Path,
          subfolder: str,
          rbln_config: RBLNModelConfig,
      ):
          save_dict = {}
-         save_dict["embed_tokens"] = model.vlm.get_input_embeddings().state_dict()
-         save_dict["multi_modal_projector"] = model.vlm.multi_modal_projector.state_dict()
+         save_dict["embedding_proj_layer"] = model.embedding_proj_layer.state_dict()
          torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
  
-     @classmethod
-     def _update_rbln_config(
-         cls,
-         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-         model: Optional["PreTrainedModel"] = None,
-         model_config: Optional["PretrainedConfig"] = None,
-         rbln_config: Optional[RBLNModelConfig] = None,
-     ) -> RBLNModelConfig:
-         hidden_size = model_config.vlm_config.text_config.hidden_size
-         if rbln_config.max_seq_lens is None:
-             rbln_config.max_seq_lens = [model_config.vlm_config.text_config.max_position_embeddings]
-         if isinstance(rbln_config.max_seq_lens, int):
-             rbln_config.max_seq_lens = [rbln_config.max_seq_lens]
-         rbln_config.max_seq_lens = sorted(set(rbln_config.max_seq_lens))
-
-         if rbln_config.output_hidden_states is None:
-             rbln_config.output_hidden_states = model_config.vlm_config.text_config.output_hidden_states
-
-         input_infos = []
-         for max_seq_len in rbln_config.max_seq_lens:
-             input_info = [
-                 ("inputs_embeds", [rbln_config.vision_tower.batch_size, max_seq_len, hidden_size], "float32"),
-                 ("attention_mask", [rbln_config.vision_tower.batch_size, max_seq_len], "float32"),
-                 ("position_ids", [rbln_config.vision_tower.batch_size, max_seq_len], "int32"),
-             ]
-             input_infos.append(input_info)
-
-         rbln_compile_config = RBLNCompileConfig(input_info=input_infos)
-         rbln_config.set_compile_cfgs([rbln_compile_config])
-
-         return rbln_config
-
-     @classmethod
-     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
-         if hasattr(model, "vlm"):
-             model.vision_tower = model.vlm.vision_tower
-             del model.vlm.model.vision_tower
-             return model
-         return model
-
-     def get_image_features(self, pixel_values: torch.Tensor):
-         # Projects the last hidden state from the vision model into language model space.
-         # Args:
-         #     pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
-         #         The tensors corresponding to the input images.
-         # Returns:
-         #     image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
-
-         vision_output_size = [
-             pixel_values.shape[0],
-             self.config.vlm_config.vision_config.num_image_tokens,
-             self.config.vlm_config.vision_config.hidden_size,
-         ]
-         vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu")
-         self.vision_tower(pixel_values, out=vision_output)
-         image_features = self.multi_modal_projector(vision_output)
-         image_features = image_features / (self.config.text_config.hidden_size**0.5)
-         return image_features
-
-     def _preprocess_inputs(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         pixel_values: Optional[torch.FloatTensor] = None,
-         **kwargs,
-     ):
-         if (input_ids is None) ^ (inputs_embeds is not None):
-             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-         # Replace image id woth PAD if the image token if OOV, to avoid index-errors
-         if input_ids is not None and self.config.vlm_config.image_token_index >= self.config.text_config.vocab_size:
-             special_image_mask = input_ids == self.config.vlm_config.image_token_index
-             llm_input_ids = input_ids.clone()
-             llm_input_ids[special_image_mask] = 0
-         else:
-             llm_input_ids = input_ids
-
-         if inputs_embeds is None:
-             inputs_embeds = self.embed_tokens(llm_input_ids)
-
-         # Merge text and images
-         image_features = None
-         if pixel_values is not None:
-             image_features = self.get_image_features(pixel_values)
-             special_image_mask = (input_ids == self.config.vlm_config.image_token_index).unsqueeze(-1)
-             special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-
-             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
-         return inputs_embeds, image_features
-
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
         pixel_values: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         **kwargs,
     ) -> Union[Tuple, ColPaliForRetrievalOutput]:
+         """
+         Forward pass for the RBLN-optimized ColPaliForRetrieval model.
+
+         Args:
+             input_ids (torch.LongTensor of shape (batch_size, sequence_length)): Indices of input sequence tokens in the vocabulary.
+             pixel_values (torch.Tensor of shape (batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images.
+             attention_mask (torch.Tensor of shape (batch_size, sequence_length)): Mask to avoid performing attention on padding token indices.
+             output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+             return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+
+         Returns:
+             ColPaliForRetrievalOutput or tuple(torch.FloatTensor)
+         """
         if pixel_values is not None:
             pixel_values = pixel_values.to(dtype=self.dtype)
 
-         if output_attentions:
-             raise ValueError("output_attentions is not supported for RBLNColPaliForRetrieval")
-
-         if output_hidden_states is not None and output_hidden_states != self.rbln_config.output_hidden_states:
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+         )
+         if output_hidden_states != self.rbln_config.output_hidden_states:
             raise ValueError(
                 f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
                 f"Please compile again with the correct argument."
             )
 
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         inputs_embeds, image_features = self._preprocess_inputs(
-             input_ids=input_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values
+         vlm_output = self.vlm_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             pixel_values=pixel_values,
+             output_hidden_states=output_hidden_states,
+             return_dict=True,
+             **kwargs,
         )
+         vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
+         vlm_image_hidden_states = vlm_output.image_hidden_states if pixel_values is not None else None
 
-         outputs = []
-         language_model_out_size = [inputs_embeds.shape[0], self.rbln_config.max_seq_lens[0], self.config.embedding_dim]
-         language_model_hidden_states_size = [
-             inputs_embeds.shape[0],
-             self.rbln_config.max_seq_lens[0],
-             self.rbln_config.max_seq_lens[0],
-         ]
-         outputs.append(torch.empty(size=language_model_out_size, dtype=torch.float32, device="cpu"))
-         if self.rbln_config.output_hidden_states:
-             for _ in range(self.config.vlm_config.text_config.num_hidden_layers + 1):
-                 outputs.append(torch.empty(size=language_model_hidden_states_size, dtype=torch.float32, device="cpu"))
-
-         # Embedding_proj_layer is fused on the bottom of the language model.
-         self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, out=outputs)
-
-         embeddings = outputs[0][:, : inputs_embeds.shape[1]]
-         hidden_states = (
-             None
-             if not self.rbln_config.output_hidden_states
-             else [tensor[0][:, : inputs_embeds.shape[1]] for tensor in outputs[1:]]
-         )
-
-         # L2 normalization
-         embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
+         last_hidden_states = vlm_output[0]
+         proj_dtype = self.embedding_proj_layer.weight.dtype
+         embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype))
+         embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
 
         if attention_mask is not None:
-             embeddings = embeddings * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)
+             embeddings = embeddings * attention_mask.unsqueeze(-1)
 
-         if not return_dict:
-             return (embeddings, hidden_states, image_features)
-         else:
-             return ColPaliForRetrievalOutput(
-                 embeddings=embeddings,
-                 hidden_states=hidden_states,
-                 image_hidden_states=image_features,
-             )
+         return ColPaliForRetrievalOutput(
+             embeddings=embeddings,
+             hidden_states=vlm_hidden_states,
+             image_hidden_states=vlm_image_hidden_states,
+         )
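
The rewritten forward above delegates the whole prefill to the compiled VLM submodule, projects its last hidden states through `embedding_proj_layer`, and L2-normalizes the result, so callers receive per-token multi-vector embeddings directly in `ColPaliForRetrievalOutput.embeddings`. A hedged usage sketch follows; the `ColPaliProcessor` call pattern and the checkpoint/directory names are assumptions, and the MaxSim reduction is the standard ColPali late-interaction score written out with plain torch ops rather than anything defined in this diff:

```python
# Sketch only: assumes a ColPali checkpoint compiled and saved as in the
# earlier example; processor usage follows the transformers ColPaliProcessor.
import torch
from PIL import Image
from transformers import ColPaliProcessor

from optimum.rbln import RBLNColPaliForRetrieval

processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.3-hf")
model = RBLNColPaliForRetrieval.from_pretrained("colpali-v1.3-rbln")  # previously compiled artifacts

page_image = Image.new("RGB", (448, 448), "white")  # stand-in for a real document page
query_inputs = processor(text=["Which figure shows revenue by quarter?"], return_tensors="pt")
doc_inputs = processor(images=[page_image], return_tensors="pt")

with torch.no_grad():
    q = model(**query_inputs).embeddings  # (1, q_len, dim), already L2-normalized
    d = model(**doc_inputs).embeddings    # (1, d_len, dim)

# MaxSim: for each query token, keep its best-matching document token, then sum.
scores = torch.einsum("bqd,bkd->bqk", q, d).max(dim=-1).values.sum(dim=-1)
print(scores)
```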
optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py
@@ -32,14 +32,16 @@ class RBLNColQwen2ForRetrievalConfig(RBLNDecoderOnlyModelConfig):
  
          # Create a configuration object
          config = RBLNColQwen2ForRetrievalConfig(
-             visual={
-                 "max_seq_lens": 6400,
-                 "device": 0,
-             },
-             max_seq_len=32_768,
-             tensor_parallel_size=4,
-             device=[0, 1, 2, 3],
-             output_hidden_states=False,
+             vlm = {
+                 "visual": {
+                     "max_seq_lens": 6400,
+                     "device": 0,
+                 },
+                 "max_seq_len": 32_768,
+                 "tensor_parallel_size": 4,
+                 "device": [0, 1, 2, 3],
+                 "output_hidden_states": False,
+             }
          )
  
          # Use the configuration with from_pretrained
@@ -51,22 +53,37 @@
          ```
      """
  
-     submodules = ["visual"]
+     submodules = ["vlm"]
+     _allow_no_compile_cfgs = True
  
      def __init__(
          self,
-         visual: Optional[RBLNModelConfig] = None,
         batch_size: Optional[int] = None,
-         use_inputs_embeds: bool = True,
+         output_hidden_states: Optional[bool] = None,
+         vlm: Optional[RBLNModelConfig] = None,
         **kwargs,
     ):
-         super().__init__(use_inputs_embeds=use_inputs_embeds, **kwargs)
-         if not self.use_inputs_embeds:
-             raise ValueError(
-                 "RBLNColQwen2ForRetrievalConfig does not allow `use_inputs_embeds` to be set to False, "
-                 "as RBLNColQwen2ForRetrieval accepts only `inputs_embeds` as input."
-             )
-         if batch_size is not None and batch_size != 1:
-             raise ValueError("batch_size is not supported for RBLNColQwen2ForRetrievalConfig")
-
-         self.visual = visual
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for the model.
+             output_hidden_states (Optional[bool]): Whether to output the hidden states of the VLM model.
+             vlm (Optional[RBLNModelConfig]): Configuration for the VLM component.
+             kwargs: Additional arguments passed to the parent RBLNModelConfig.
+         Raises:
+             ValueError: If batch_size is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         self.output_hidden_states = output_hidden_states or False
+
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         self.vlm = self.initialize_submodule_config(
+             submodule_config=vlm,
+             batch_size=batch_size,
+             output_hidden_states=output_hidden_states,
+             force_kwargs=True,
+             logits_to_keep=0,
+             use_inputs_embeds=True,
+         )
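
With the separate `visual` submodule replaced by a single `vlm` submodule config, every ColQwen2 retrieval setting now nests one level deeper and is forwarded through `initialize_submodule_config`. A minimal sketch of the updated flow, assuming `RBLNColQwen2ForRetrieval` is exported from `optimum.rbln` like the other model classes; the checkpoint name is a placeholder, and the nested keys mirror the docstring example above:

```python
# Sketch only: the nested "vlm" keys come from the updated docstring example;
# "<colqwen2-retrieval-checkpoint>" is a placeholder, not taken from this diff.
from optimum.rbln import RBLNColQwen2ForRetrieval, RBLNColQwen2ForRetrievalConfig

config = RBLNColQwen2ForRetrievalConfig(
    vlm={
        "visual": {"max_seq_lens": 6400, "device": 0},
        "max_seq_len": 32_768,
        "tensor_parallel_size": 4,
        "device": [0, 1, 2, 3],
        "output_hidden_states": False,
    },
)

model = RBLNColQwen2ForRetrieval.from_pretrained(
    "<colqwen2-retrieval-checkpoint>",
    export=True,
    rbln_config=config,
)
```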