optimum-rbln 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (64)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +63 -32
  5. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +30 -14
  6. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +11 -8
  7. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +23 -13
  8. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +10 -6
  9. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +14 -10
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +14 -7
  11. optimum/rbln/diffusers/modeling_diffusers.py +5 -7
  12. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +9 -11
  13. optimum/rbln/modeling.py +50 -0
  14. optimum/rbln/modeling_base.py +1 -2
  15. optimum/rbln/transformers/__init__.py +8 -0
  16. optimum/rbln/transformers/modeling_generic.py +37 -1
  17. optimum/rbln/transformers/models/__init__.py +9 -0
  18. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +35 -3
  19. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +86 -23
  20. optimum/rbln/transformers/models/clip/modeling_clip.py +4 -0
  21. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  22. optimum/rbln/transformers/models/colpali/configuration_colpali.py +34 -18
  23. optimum/rbln/transformers/models/colpali/modeling_colpali.py +73 -80
  24. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  25. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  26. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  27. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  28. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  29. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
  30. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
  32. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +50 -2
  33. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  34. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +65 -3
  35. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +11 -3
  36. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  37. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
  38. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +67 -44
  39. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  40. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +27 -3
  41. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +24 -19
  42. optimum/rbln/transformers/models/llava/configuration_llava.py +16 -2
  43. optimum/rbln/transformers/models/llava/modeling_llava.py +108 -50
  44. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +11 -13
  45. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -343
  46. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  47. optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
  48. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +6 -11
  49. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +9 -8
  50. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +24 -0
  51. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +24 -0
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  54. optimum/rbln/transformers/models/siglip/modeling_siglip.py +3 -14
  55. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  56. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -1
  57. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  58. optimum/rbln/utils/runtime_utils.py +25 -15
  59. optimum/rbln/utils/submodule.py +21 -5
  60. {optimum_rbln-0.9.1.dist-info → optimum_rbln-0.9.2.dist-info}/METADATA +5 -5
  61. {optimum_rbln-0.9.1.dist-info → optimum_rbln-0.9.2.dist-info}/RECORD +64 -55
  62. optimum_rbln-0.9.2.dist-info/entry_points.txt +2 -0
  63. {optimum_rbln-0.9.1.dist-info → optimum_rbln-0.9.2.dist-info}/WHEEL +0 -0
  64. {optimum_rbln-0.9.1.dist-info → optimum_rbln-0.9.2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/__init__.py

@@ -54,6 +54,8 @@ _import_structure = {
         "RBLNBlip2VisionModelConfig",
         "RBLNColPaliForRetrieval",
         "RBLNColPaliForRetrievalConfig",
+        "RBLNColQwen2ForRetrieval",
+        "RBLNColQwen2ForRetrievalConfig",
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelConfig",
         "RBLNCLIPTextModelWithProjection",
@@ -110,6 +112,8 @@ _import_structure = {
         "RBLNPegasusModelConfig",
         "RBLNLlavaNextForConditionalGeneration",
         "RBLNLlavaNextForConditionalGenerationConfig",
+        "RBLNLoRAAdapterConfig",
+        "RBLNLoRAConfig",
         "RBLNMidmLMHeadModel",
         "RBLNMidmLMHeadModelConfig",
         "RBLNMistralForCausalLM",
@@ -216,6 +220,8 @@ if TYPE_CHECKING:
         RBLNCLIPVisionModelWithProjectionConfig,
         RBLNColPaliForRetrieval,
         RBLNColPaliForRetrievalConfig,
+        RBLNColQwen2ForRetrieval,
+        RBLNColQwen2ForRetrievalConfig,
         RBLNDecoderOnlyModel,
         RBLNDecoderOnlyModelConfig,
         RBLNDecoderOnlyModelForCausalLM,
@@ -258,6 +264,8 @@ if TYPE_CHECKING:
         RBLNLlavaForConditionalGenerationConfig,
         RBLNLlavaNextForConditionalGeneration,
         RBLNLlavaNextForConditionalGenerationConfig,
+        RBLNLoRAAdapterConfig,
+        RBLNLoRAConfig,
         RBLNMidmLMHeadModel,
         RBLNMidmLMHeadModelConfig,
         RBLNMistralForCausalLM,
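The four hunks above add `RBLNColQwen2ForRetrieval`, `RBLNColQwen2ForRetrievalConfig`, `RBLNLoRAAdapterConfig`, and `RBLNLoRAConfig` to the package's top-level exports. A minimal sketch of loading the new ColQwen2 retriever, modeled on the ColPali docstring example that appears later in this diff; the checkpoint id and config values are illustrative assumptions, not values taken from this release:

```python
from optimum.rbln import RBLNColQwen2ForRetrieval, RBLNColQwen2ForRetrievalConfig

# Hypothetical checkpoint and settings, shown only to illustrate the
# from_pretrained(export=True, rbln_config=...) pattern used elsewhere in this package.
config = RBLNColQwen2ForRetrievalConfig(tensor_parallel_size=4)
model = RBLNColQwen2ForRetrieval.from_pretrained(
    "vidore/colqwen2-v1.0-hf",  # assumed model id
    export=True,
    rbln_config=config,
)
```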
optimum/rbln/transformers/modeling_generic.py

@@ -23,6 +23,7 @@ different model architectures.
 import inspect
 from typing import TYPE_CHECKING, Optional, Union
 
+from torch import nn
 from transformers import (
     AutoModel,
     AutoModelForAudioClassification,
@@ -57,6 +58,28 @@ class RBLNTransformerEncoder(RBLNModel):
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
     rbln_dtype = "int64"
 
+    @classmethod
+    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNTransformerEncoderConfig) -> nn.Module:
+        class TransformerEncoderWrapper(nn.Module):
+            # Parameters to disable for RBLN compilation
+            DISABLED_PARAMS = {"return_dict", "use_cache"}
+
+            def __init__(self, model: "PreTrainedModel", rbln_config: RBLNTransformerEncoderConfig):
+                super().__init__()
+                self.model = model
+                self.rbln_config = rbln_config
+                self._forward_signature = inspect.signature(model.forward)
+
+            def forward(self, *args, **kwargs):
+                # Disable parameters that are not compatible with RBLN compilation
+                for param_name in self.DISABLED_PARAMS:
+                    if param_name in self._forward_signature.parameters:
+                        kwargs[param_name] = False
+
+                return self.model(*args, **kwargs)
+
+        return TransformerEncoderWrapper(model, rbln_config).eval()
+
     @classmethod
     def _update_rbln_config(
         cls,
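The new wrapper inspects the wrapped model's `forward` signature and force-disables flags that interfere with compilation. A self-contained sketch of the same signature-inspection pattern using a toy module (illustration only, not RBLN code):

```python
import inspect
import torch
from torch import nn

class ToyEncoder(nn.Module):
    def forward(self, input_ids, attention_mask=None, return_dict=True, use_cache=True):
        # pretend encoder: returns a dict when return_dict is True, a plain tensor otherwise
        hidden = input_ids.float().unsqueeze(-1)
        return {"hidden": hidden} if return_dict else hidden

class DisableFlagsWrapper(nn.Module):
    DISABLED_PARAMS = {"return_dict", "use_cache"}

    def __init__(self, model):
        super().__init__()
        self.model = model
        self._forward_signature = inspect.signature(model.forward)

    def forward(self, *args, **kwargs):
        # only override flags that the wrapped forward actually accepts
        for name in self.DISABLED_PARAMS:
            if name in self._forward_signature.parameters:
                kwargs[name] = False
        return self.model(*args, **kwargs)

wrapped = DisableFlagsWrapper(ToyEncoder()).eval()
out = wrapped(torch.tensor([[1, 2, 3]]))  # plain tensor: return_dict was forced to False
```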
@@ -208,7 +231,6 @@ class RBLNModelForQuestionAnswering(RBLNTransformerEncoder):
 
     def _prepare_output(self, output, return_dict):
         # Prepare QuestionAnswering specific output format.
-
         start_logits, end_logits = output
 
         if not return_dict:
@@ -245,6 +267,20 @@ class RBLNModelForImageClassification(RBLNImageModel):
 class RBLNModelForDepthEstimation(RBLNImageModel):
     auto_model_class = AutoModelForDepthEstimation
 
+    @classmethod
+    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNImageModelConfig):
+        class ImageModelWrapper(nn.Module):
+            def __init__(self, model: "PreTrainedModel", rbln_config: RBLNImageModelConfig):
+                super().__init__()
+                self.model = model
+                self.rbln_config = rbln_config
+
+            def forward(self, *args, **kwargs):
+                output = self.model(*args, return_dict=True, **kwargs)
+                return output.predicted_depth
+
+        return ImageModelWrapper(model, rbln_config).eval()
+
 
 class RBLNModelForAudioClassification(RBLNModel):
     """
optimum/rbln/transformers/models/__init__.py

@@ -75,6 +75,10 @@ _import_structure = {
         "RBLNColPaliForRetrieval",
         "RBLNColPaliForRetrievalConfig",
     ],
+    "colqwen2": [
+        "RBLNColQwen2ForRetrieval",
+        "RBLNColQwen2ForRetrievalConfig",
+    ],
     "distilbert": [
         "RBLNDistilBertForQuestionAnswering",
         "RBLNDistilBertForQuestionAnsweringConfig",
@@ -96,6 +100,8 @@ _import_structure = {
         "RBLNDecoderOnlyModel",
         "RBLNDecoderOnlyModelForCausalLM",
         "RBLNDecoderOnlyModelForCausalLMConfig",
+        "RBLNLoRAAdapterConfig",
+        "RBLNLoRAConfig",
     ],
     "depth_anything": ["RBLNDepthAnythingForDepthEstimationConfig", "RBLNDepthAnythingForDepthEstimation"],
     "dpt": [
@@ -234,11 +240,14 @@ if TYPE_CHECKING:
         RBLNCLIPVisionModelWithProjectionConfig,
     )
     from .colpali import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig
+    from .colqwen2 import RBLNColQwen2ForRetrieval, RBLNColQwen2ForRetrievalConfig
     from .decoderonly import (
         RBLNDecoderOnlyModel,
         RBLNDecoderOnlyModelConfig,
         RBLNDecoderOnlyModelForCausalLM,
         RBLNDecoderOnlyModelForCausalLMConfig,
+        RBLNLoRAAdapterConfig,
+        RBLNLoRAConfig,
     )
     from .depth_anything import RBLNDepthAnythingForDepthEstimation, RBLNDepthAnythingForDepthEstimationConfig
     from .distilbert import RBLNDistilBertForQuestionAnswering, RBLNDistilBertForQuestionAnsweringConfig
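The LoRA configuration classes re-exported here are defined in the new configuration_lora.py (+411 lines) under decoderonly, which is not excerpted in this section. Only the import paths are visible in this diff; a minimal sketch of where the names can be imported from (constructor arguments are not shown in this excerpt and are therefore not illustrated):

```python
# New in 0.9.2: LoRA configuration classes exported at the package root.
# Their constructor arguments live in
# optimum/rbln/transformers/models/decoderonly/configuration_lora.py (not shown in this diff).
from optimum.rbln import RBLNLoRAAdapterConfig, RBLNLoRAConfig
```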
optimum/rbln/transformers/models/blip_2/configuration_blip_2.py

@@ -15,6 +15,10 @@
 from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 class RBLNBlip2VisionModelConfig(RBLNModelConfig):
@@ -25,6 +29,16 @@ class RBLNBlip2VisionModelConfig(RBLNModelConfig):
     RBLN-optimized BLIP-2 vision encoder models for multimodal tasks.
     """
 
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
 
 
 class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
@@ -36,6 +50,7 @@ class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
 
     def __init__(
         self,
+        batch_size: Optional[int] = None,
         num_query_tokens: Optional[int] = None,
         image_text_hidden_size: Optional[int] = None,
         **kwargs,
@@ -47,11 +62,22 @@
             kwargs: Additional arguments passed to the parent RBLNModelConfig.
         """
         super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
         self.num_query_tokens = num_query_tokens
         self.image_text_hidden_size = image_text_hidden_size
 
 
 class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNBlip2ForConditionalGeneration.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BLIP-2 models for conditional generation tasks that involve both image and text inputs.
+    """
+
     submodules = ["vision_model", "qformer", "language_model"]
 
     def __init__(
@@ -78,6 +104,12 @@ class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
-        self.vision_model = self.init_submodule_config(RBLNBlip2VisionModelConfig, vision_model)
-        self.language_model = language_model
-        self.qformer = self.init_submodule_config(RBLNBlip2QFormerModelConfig, qformer)
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for Blip2 vision model. It will be set to 1.")
+            logger.warning("Ignore batch_size for Blip2 qformer. It will be set to 1.")
+
+        self.vision_model = self.initialize_submodule_config(
+            submodule_config=vision_model, batch_size=1, force_kwargs=True
+        )
+        self.qformer = self.initialize_submodule_config(submodule_config=qformer, batch_size=1, force_kwargs=True)
+        self.language_model = self.initialize_submodule_config(submodule_config=language_model)
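With this change, the pipeline-level batch_size is kept for the language model while the vision model and Q-Former submodule configs are pinned to batch size 1 (with a warning when a larger value is requested). A minimal sketch of building the config under that assumption; the import path follows the package's usual top-level exports:

```python
from optimum.rbln import RBLNBlip2ForConditionalGenerationConfig

# batch_size applies to the language model; vision_model and qformer are forced to
# batch_size=1 via initialize_submodule_config(..., force_kwargs=True).
config = RBLNBlip2ForConditionalGenerationConfig(batch_size=2)
# Expected (per the hunk above): config.vision_model.batch_size == 1 and
# config.qformer.batch_size == 1, while the language model keeps batch_size=2,
# and two warnings are logged because batch_size != 1.
```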
optimum/rbln/transformers/models/blip_2/modeling_blip_2.py

@@ -30,34 +30,31 @@ from transformers.utils import logging
 
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
+from ...utils.rbln_runtime_wrapper import LoopProcessor
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
 
 
 logger = logging.get_logger(__name__)
 
 if TYPE_CHECKING:
+    import rebel
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
-class LoopProjector:
-    def __init__(self, language_projection) -> None:
-        self.language_projection = language_projection
+class LoopProjector(LoopProcessor):
+    def __init__(self, language_projection: Union[RBLNModel, "rebel.Runtime"]):
+        super().__init__(model=language_projection)
 
-    def forward(self, *args, **kwargs):
-        query_output = args[0]
+    def _get_batch_size(self, query_output, **kwargs):
+        return query_output.shape[0]
 
-        batch_size = query_output.shape[0]
-        outputs = []
-        for i in range(batch_size):
-            outputs.append(self.language_projection(query_output[i : i + 1]))
-
-        outputs = torch.cat(outputs, dim=0)
-        return outputs
-
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
+    def _prepare_inputs_for_iteration(self, index, common_inputs, query_output, **kwargs):
+        query_output_item = query_output[index : index + 1]
+        return ([query_output_item], {})
 
-    def __repr__(self) -> str:
-        return repr(self.language_projection)
+    def _process_outputs(self, outputs: list, **kwargs):
+        output = torch.cat(outputs, dim=0)
+        return output
 
 
 class RBLNBlip2VisionModel(RBLNModel):
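The per-sample loop that used to be written out inside LoopProjector now comes from a shared LoopProcessor base class in the new optimum/rbln/transformers/utils/rbln_runtime_wrapper.py (+79 lines, not excerpted here). Judging only from the three hooks overridden above, the base class plausibly looks something like the sketch below; this is an assumption about its shape, not the actual implementation:

```python
class LoopProcessor:
    """Sketch: run a wrapped model sample-by-sample and merge the per-sample results."""

    def __init__(self, model):
        self.model = model

    # hooks overridden by subclasses such as LoopProjector
    def _get_batch_size(self, *args, **kwargs):
        raise NotImplementedError

    def _prepare_inputs_for_iteration(self, index, common_inputs, *args, **kwargs):
        raise NotImplementedError

    def _process_outputs(self, outputs, **kwargs):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        batch_size = self._get_batch_size(*args, **kwargs)
        outputs = []
        for i in range(batch_size):
            item_args, item_kwargs = self._prepare_inputs_for_iteration(i, {}, *args, **kwargs)
            outputs.append(self.model(*item_args, **item_kwargs))
        return self._process_outputs(outputs, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    def __repr__(self):
        return repr(self.model)
```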
@@ -68,6 +65,8 @@ class RBLNBlip2VisionModel(RBLNModel):
     on RBLN devices, supporting image encoding for multimodal vision-language tasks.
     """
 
+    _tp_support = False
+
     def get_input_embeddings(self):
         return self.embeddings
 
@@ -96,8 +95,7 @@ class RBLNBlip2VisionModel(RBLNModel):
             (
                 "pixel_values",
                 [
-                    # support for vllm CB (prefill)
-                    1,
+                    rbln_config.batch_size,
                     model_config.num_channels,
                     model_config.image_size,
                     model_config.image_size,
@@ -147,6 +145,8 @@ class RBLNBlip2QFormerModel(RBLNModel):
     mechanisms for multimodal understanding tasks.
     """
 
+    _tp_support = False
+
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
 
@@ -200,7 +200,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
             (
                 "query_embeds",
                 [
-                    1,
+                    rbln_config.batch_size,
                     rbln_config.num_query_tokens,
                     model_config.hidden_size,
                 ],
@@ -209,7 +209,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
             (
                 "encoder_hidden_states",
                 [
-                    1,
+                    rbln_config.batch_size,
                     # image_text_hidden_size + cls token
                     rbln_config.image_text_hidden_size + 1,
                     model_config.encoder_hidden_size,
@@ -219,7 +219,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
             (
                 "encoder_attention_mask",
                 # image_text_hidden_size + cls token
-                [1, rbln_config.image_text_hidden_size + 1],
+                [rbln_config.batch_size, rbln_config.image_text_hidden_size + 1],
                 "int64",
             ),
         ]
@@ -266,7 +266,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
         )
 
 
-class RBLNBlip2ForConditionalGeneration(RBLNModel):
+class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     """
     RBLNBlip2ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
@@ -434,3 +434,66 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel):
         )
 
         return inputs_embeds
+
+    @torch.no_grad()
+    def generate(
+        self,
+        pixel_values: torch.FloatTensor,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        interpolate_pos_encoding: bool = False,
+        **generate_kwargs,
+    ) -> torch.LongTensor:
+        batch_size = pixel_values.shape[0]
+        image_embeds = self.vision_model(
+            pixel_values,
+            return_dict=True,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+        ).last_hidden_state
+        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        query_outputs = self.qformer(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_attention_mask,
+            return_dict=True,
+        )
+        query_output = query_outputs.last_hidden_state
+
+        if query_output.dtype != image_embeds.dtype:
+            query_output = query_output.to(image_embeds.dtype)
+
+        language_model_inputs = self.language_projection(query_output)
+
+        if inputs_embeds is None:
+            if input_ids is None:
+                image_tokens = [self.config.image_token_index] * self.config.num_query_tokens
+                start_tokens = image_tokens + [self.config.text_config.bos_token_id]
+                input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
+                input_ids = input_ids.repeat(batch_size, 1)
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_id
+
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
+        if not self.language_model.config.is_encoder_decoder:
+            inputs["input_ids"] = input_ids
+
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)
+
+        return outputs
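The new generate() override reproduces the upstream transformers BLIP-2 flow (vision encoder → Q-Former → language projection → language_model.generate) on the RBLN runtimes. A minimal usage sketch; the checkpoint id, image path, and generation arguments are illustrative assumptions:

```python
import torch
from PIL import Image
from transformers import AutoProcessor
from optimum.rbln import RBLNBlip2ForConditionalGeneration

# Assumed checkpoint id; any BLIP-2 checkpoint compiled for RBLN would follow the same pattern.
model_id = "Salesforce/blip2-opt-2.7b"
processor = AutoProcessor.from_pretrained(model_id)
model = RBLNBlip2ForConditionalGeneration.from_pretrained(model_id, export=True)

image = Image.open("example.jpg")  # placeholder image path
inputs = processor(images=image, text="Question: what is in the photo? Answer:", return_tensors="pt")

with torch.no_grad():
    generated_ids = model.generate(**inputs, max_new_tokens=32)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```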
optimum/rbln/transformers/models/clip/modeling_clip.py

@@ -51,6 +51,8 @@ class RBLNCLIPTextModel(RBLNModel):
     on RBLN devices, supporting text encoding for multimodal tasks.
     """
 
+    _tp_support = False
+
     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPTextModelConfig) -> torch.nn.Module:
         return _TextEncoder(model).eval()
@@ -152,6 +154,8 @@ class RBLNCLIPVisionModel(RBLNModel):
     on RBLN devices, supporting image encoding for multimodal tasks.
     """
 
+    _tp_support = False
+
     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
         wrapper_cfg = {
optimum/rbln/transformers/models/colpali/colpali_architecture.py

@@ -24,11 +24,11 @@ class RBLNColPaliForRetrievalWrapper(nn.Module):
         output_hidden_states: bool = False,
     ):
         super().__init__()
-        self.text_config = causal_lm.config
+        self.text_config = causal_lm.config.text_config
         self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
 
         self.output_hidden_states = output_hidden_states
-        self.language_model = self.convert_to_rbln_language_model(causal_lm.model, max_seq_len)
+        self.language_model = self.convert_to_rbln_language_model(causal_lm.model.language_model, max_seq_len)
 
         self.num_hidden_layers = getattr(self.text_config, "num_hidden_layers", None)
         self.embedding_proj_layer = embedding_proj_layer
optimum/rbln/transformers/models/colpali/configuration_colpali.py

@@ -14,6 +14,10 @@
 from typing import Any, List, Optional, Union
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
@@ -24,29 +28,30 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
     including vision tower settings and multi-sequence length support.
 
     Example usage:
-    ```python
-    from optimum.rbln import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig
-
-    # Create a configuration object
-    config = RBLNColPaliForRetrievalConfig(
-        max_seq_lens=1152,
-        output_hidden_states=False,
-        tensor_parallel_size=4
-    )
-
-    # Use the configuration with from_pretrained
-    model = RBLNColPaliForRetrieval.from_pretrained(
-        "vidore/colpali-v1.3-hf",
-        export=True,
-        rbln_config=config
-    )
-    ```
+        ```python
+        from optimum.rbln import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig
+
+        # Create a configuration object
+        config = RBLNColPaliForRetrievalConfig(
+            max_seq_lens=1152,
+            output_hidden_states=False,
+            tensor_parallel_size=4
+        )
+
+        # Use the configuration with from_pretrained
+        model = RBLNColPaliForRetrieval.from_pretrained(
+            "vidore/colpali-v1.3-hf",
+            export=True,
+            rbln_config=config
+        )
+        ```
     """
 
     submodules = ["vision_tower"]
 
     def __init__(
         self,
+        batch_size: Optional[int] = None,
         max_seq_lens: Union[int, List[int]] = None,
         output_hidden_states: Optional[bool] = None,
         vision_tower: Optional[RBLNModelConfig] = None,
@@ -54,6 +59,8 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
     ):
         """
         Args:
+            batch_size (Optional[int]): The batch size for the model.
+            vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
            max_seq_lens (Union[int, List[int]]): The maximum sequence lengths for the language model.
                This can be multiple values, and the model will be compiled for each max_seq_len, allowing selection of the most appropriate max_seq_len at inference time.
            output_hidden_states (Optional[bool]): Whether to output the hidden states of the language model.
@@ -63,6 +70,15 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
            ValueError: If batch_size is not a positive integer.
        """
         super().__init__(**kwargs)
-        self.vision_tower = vision_tower
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for ColPali vision tower. It will be set to 1.")
+
+        self.vision_tower = self.initialize_submodule_config(
+            submodule_config=vision_tower, batch_size=1, force_kwargs=True
+        )
         self.max_seq_lens = max_seq_lens
         self.output_hidden_states = output_hidden_states