optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/bert/modeling_bert.py

@@ -12,7 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional, Tuple, Union
+
 import torch
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    MaskedLMOutput,
+    QuestionAnsweringModelOutput,
+)
 
 from ...modeling_generic import (
     RBLNModelForMaskedLM,
@@ -35,9 +42,45 @@ class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNBertModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNBertModelConfig) -> torch.nn.Module:
         return BertModelWrapper(model, rbln_config)
 
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, Tuple]:
+        """
+        Forward pass for the RBLN-optimized BERT model for feature extraction tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
+            position_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of positions of each input sequence tokens in the position embeddings.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPoolingAndCrossAttentions object.
+        """
+
+        input_map = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "token_type_ids": token_type_ids,
+            "position_ids": position_ids,
+        }
+
+        model_input_names = getattr(self.rbln_config, "model_input_names", None)
+        if model_input_names is None:
+            model_input_names = self.rbln_model_input_names
+
+        ordered_inputs = [input_map[name] for name in model_input_names if name in input_map]
+
+        return super().forward(*ordered_inputs, **kwargs)
+
 
 class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
     """
@@ -50,6 +93,27 @@ class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
 
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
 
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[MaskedLMOutput, Tuple]:
+        """
+        Forward pass for the RBLN-optimized BERT model for masked language modeling tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a MaskedLMOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, token_type_ids, **kwargs)
+
 
 class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
     """
@@ -61,3 +125,24 @@ class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
     """
 
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[QuestionAnsweringModelOutput, Tuple]:
+        """
+        Forward pass for the RBLN-optimized BERT model for question answering tasks.
+
+        Args:
+            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a QuestionAnsweringModelOutput object.
+        """
+
+        return super().forward(input_ids, attention_mask, token_type_ids, **kwargs)
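
The forward overrides above only document and reorder the keyword inputs before dispatching to the compiled runtime, so existing calling code is unaffected. A minimal usage sketch, assuming the usual optimum-rbln from_pretrained(..., export=True) compile-and-run flow; the checkpoint name and call pattern are illustrative and not taken from this diff:

# Hypothetical usage sketch (not part of the diff): compile a BERT checkpoint
# for the RBLN NPU, then call forward with the documented keyword arguments.
from transformers import AutoTokenizer
from optimum.rbln import RBLNBertForMaskedLM

model_id = "bert-base-uncased"  # any BERT checkpoint; illustrative only
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True compiles the PyTorch model for the NPU (assumed optimum-rbln convention).
model = RBLNBertForMaskedLM.from_pretrained(model_id, export=True)

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    token_type_ids=inputs["token_type_ids"],
)

# Decode the highest-scoring token at the [MASK] position.
mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero()[0, 1]
print(tokenizer.decode([outputs.logits[0, mask_pos].argmax(-1).item()]))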
optimum/rbln/transformers/models/blip_2/modeling_blip_2.py

@@ -14,7 +14,7 @@
 
 import inspect
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
 
 import torch
 from transformers import (
@@ -71,7 +71,7 @@ class RBLNBlip2VisionModel(RBLNModel):
         return self.embeddings
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         class Blip2VisionModelWrapper(torch.nn.Module):
             def __init__(self, model: "Blip2VisionModel") -> None:
                 super().__init__()
@@ -111,11 +111,20 @@ class RBLNBlip2VisionModel(RBLNModel):
     def forward(
         self,
         pixel_values: torch.FloatTensor,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
         interpolate_pos_encoding: bool = False,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        """
+        Forward pass for the RBLN-optimized Blip2VisionModel model.
+
+        Args:
+            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)): The tensors corresponding to the input images.
+            interpolate_pos_encoding (bool, optional): Whether to interpolate the positional encoding of the image embeddings. Defaults to False.
+            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            BaseModelOutputWithPooling or tuple(torch.FloatTensor): The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
+        """
         batch_size = pixel_values.shape[0]
         outputs = []
         for i in range(batch_size):
@@ -151,7 +160,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
         return self.embeddings.word_embeddings
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
         class Blip2QFormerModelWrapper(torch.nn.Module):
             def __init__(self, model: "Blip2QFormerModel"):
                 super().__init__()
@@ -231,17 +240,22 @@ class RBLNBlip2QFormerModel(RBLNModel):
     def forward(
         self,
         query_embeds: torch.FloatTensor,
-        query_length: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        head_mask: Optional[torch.FloatTensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
-        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        """
+        The forward pass for the RBLN-optimized Blip2QFormerModel model.
+
+        Args:
+            query_embeds (torch.FloatTensor): Hidden states to be used in the attention computation.
+            encoder_hidden_states (torch.FloatTensor, optional): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder.
+            encoder_attention_mask (torch.FloatTensor, optional): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder.
+            return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            BaseModelOutputWithPoolingAndCrossAttentions or tuple(torch.FloatTensor): The model outputs. If `return_dict=False` is passed, returns a tuple of tensors. Otherwise, returns a `BaseModelOutputWithPoolingAndCrossAttentions` object.
+        """
         batch_size = query_embeds.shape[0]
         outputs = []
         for i in range(batch_size):
@@ -349,7 +363,7 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
         return self.language_model.get_input_embeddings()
 
     @classmethod
-    def wrap_model_if_needed(cls, model, rbln_config):
+    def _wrap_model_if_needed(cls, model, rbln_config):
        return model.language_projection
 
     @classmethod
@@ -444,7 +458,20 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixi
         inputs_embeds: Optional[torch.FloatTensor] = None,
         interpolate_pos_encoding: bool = False,
         **generate_kwargs,
-    ) -> torch.LongTensor:
+    ) -> List[torch.LongTensor]:
+        """
+        The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
+        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/model_doc/blip-2#transformers.Blip2ForConditionalGeneration.generate) for more details.
+
+        Args:
+            pixel_values (torch.FloatTensor): Input images to be processed.
+            input_ids (torch.LongTensor, optional): The sequence used as a prompt for the generation.
+            attention_mask (torch.LongTensor, optional): Mask to avoid performing attention on padding token indices
+            inputs_embeds (torch.FloatTensor, optional): Embedded representation of the inputs. Should be float, not int tokens.
+            interpolate_pos_encoding (bool, optional, defaults to False) — Whether to interpolate the positional encoding of the image embeddings.
+        Returns:
+            A list of strings of length batch_size * num_captions.
+        """
         batch_size = pixel_values.shape[0]
         image_embeds = self.vision_model(
             pixel_values,
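
The generate() docstring added above describes the standard transformers-style entry point for BLIP-2 captioning on the compiled model. A minimal usage sketch, assuming the usual optimum-rbln from_pretrained(..., export=True) convention; the checkpoint and image URL are illustrative, not taken from this diff:

# Hypothetical usage sketch (not part of the diff): image captioning with the
# RBLN BLIP-2 wrapper via generate().
import requests
from PIL import Image
from transformers import Blip2Processor
from optimum.rbln import RBLNBlip2ForConditionalGeneration

model_id = "Salesforce/blip2-opt-2.7b"  # illustrative checkpoint
processor = Blip2Processor.from_pretrained(model_id)
# export=True compiles the vision tower, Q-Former, and language model for the NPU
# (assumed optimum-rbln convention).
model = RBLNBlip2ForConditionalGeneration.from_pretrained(model_id, export=True)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(images=image, return_tensors="pt")

generated_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())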
optimum/rbln/transformers/models/clip/modeling_clip.py

@@ -54,7 +54,7 @@ class RBLNCLIPTextModel(RBLNModel):
     _tp_support = False
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPTextModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPTextModelConfig) -> torch.nn.Module:
         return _TextEncoder(model).eval()
 
     @classmethod
@@ -92,6 +92,9 @@ class RBLNCLIPTextModel(RBLNModel):
         Args:
             input_ids (torch.LongTensor): The input ids to the model.
             return_dict (Optional[bool]): Whether to return a dictionary of outputs.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a CLIPTextModelOutput object.
         """
 
         # To ignore using attention_mask, we override forward method.
@@ -157,7 +160,7 @@ class RBLNCLIPVisionModel(RBLNModel):
     _tp_support = False
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
+    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
         wrapper_cfg = {
             "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
             "output_hidden_states": rbln_config.output_hidden_states,
@@ -230,6 +233,9 @@ class RBLNCLIPVisionModel(RBLNModel):
             output_attentions (Optional[bool]): Whether to return attentions.
             output_hidden_states (Optional[bool]): Whether to return hidden states.
             interpolate_pos_encoding (bool): Whether to interpolate position encoding.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
         """
 
         if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
@@ -307,6 +313,38 @@ class RBLNCLIPVisionModelWithProjection(RBLNCLIPVisionModel):
     multimodal embedding alignment tasks.
     """
 
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        return_dict: bool = True,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        interpolate_pos_encoding: bool = False,
+        **kwargs,
+    ) -> Union[Tuple, CLIPVisionModelOutput]:
+        """
+        Forward pass for the RBLN-optimized CLIP vision encoder model with projection.
+
+        Args:
+            pixel_values (torch.Tensor): The pixel values to the model.
+            return_dict (bool): Whether to return a dictionary of outputs.
+            output_attentions (Optional[bool]): Whether to return attentions.
+            output_hidden_states (Optional[bool]): Whether to return hidden states.
+            interpolate_pos_encoding (bool): Whether to interpolate position encoding.
+
+        Returns:
+            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a CLIPVisionModelOutput object.
+        """
+
+        return super().forward(
+            pixel_values=pixel_values,
+            return_dict=return_dict,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            **kwargs,
+        )
+
     def _prepare_output(self, output, return_dict):
         # Prepare model output based on return_dict flag.
         # This method can be overridden by subclasses to provide task-specific output handling.
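
RBLNCLIPVisionModelWithProjection now exposes an explicit forward signature that simply forwards to the parent vision model, with output handling delegated to _prepare_output. A minimal usage sketch for pulling projected image embeddings, assuming the usual optimum-rbln from_pretrained(..., export=True) convention; checkpoint, image URL, and any extra compile-time options are illustrative, not taken from this diff:

# Hypothetical usage sketch (not part of the diff): projected CLIP image embeddings.
import requests
from PIL import Image
from transformers import CLIPImageProcessor
from optimum.rbln import RBLNCLIPVisionModelWithProjection

model_id = "openai/clip-vit-base-patch32"  # illustrative checkpoint
processor = CLIPImageProcessor.from_pretrained(model_id)
# export=True compiles the vision encoder for the NPU; additional compile-time
# options (e.g. image size) may be required depending on the checkpoint.
model = RBLNCLIPVisionModelWithProjection.from_pretrained(model_id, export=True)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
pixel_values = processor(images=image, return_tensors="pt").pixel_values

outputs = model(pixel_values=pixel_values, return_dict=True)
print(outputs.image_embeds.shape)  # projected image embedding, e.g. (1, 512)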
optimum/rbln/transformers/models/colpali/colpali_architecture.py

@@ -77,11 +77,11 @@ class ColPaliModel(nn.Module):
         self, model, layers: List["ColPaliLayer"], output_hidden_states: bool = False, max_seq_len: int = 2048
     ):
         super().__init__()
-        self._original_mod = model
         self.layers = nn.ModuleList(layers)
         self.output_hidden_states = output_hidden_states
-        self.norm = self._original_mod.norm
-        self.hidden_size = self._original_mod.config.hidden_size
+        self.config = model.config
+        self.norm = model.norm
+        self.hidden_size = self.config.hidden_size
         self.max_seq_len = max_seq_len
 
     def forward(
@@ -118,7 +118,6 @@ class ColPaliModel(nn.Module):
 class ColPaliLayer(nn.Module):
     def __init__(self, layer, self_attn: "ColPaliAttention"):
         super().__init__()
-        self._original_mod = layer
         self.self_attn = self_attn
         self.mlp = layer.mlp
         self.input_layernorm = layer.input_layernorm
@@ -155,27 +154,22 @@
 class ColPaliAttention(nn.Module):
     def __init__(self, self_attn):
         super().__init__()
-        self._original_mod = self_attn
-        self.num_heads = getattr(self._original_mod, "num_heads", None) or getattr(
-            self._original_mod.config, "num_attention_heads"
-        )
-        self.head_dim = self._original_mod.head_dim
+        self.config = self_attn.config
+        self.num_heads = getattr(self_attn, "num_heads", None) or self_attn.config.num_attention_heads
+        self.head_dim = self_attn.head_dim
         self.scaling = self.head_dim**-0.5
 
-        if hasattr(self._original_mod, "num_key_value_heads"):
-            self.num_key_value_heads = self._original_mod.num_key_value_heads
-        elif hasattr(self._original_mod, "config") and hasattr(self._original_mod.config, "num_key_value_heads"):
-            self.num_key_value_heads = self._original_mod.config.num_key_value_heads
+        if hasattr(self_attn, "num_key_value_heads"):
+            self.num_key_value_heads = self_attn.num_key_value_heads
+        elif hasattr(self_attn, "config") and hasattr(self_attn.config, "num_key_value_heads"):
+            self.num_key_value_heads = self_attn.config.num_key_value_heads
         else:
             self.num_key_value_heads = self.num_heads
 
-        self.__post_init__()
-
-    def __post_init__(self):
-        self.q_proj = self._original_mod.q_proj
-        self.k_proj = self._original_mod.k_proj
-        self.v_proj = self._original_mod.v_proj
-        self.o_proj = self._original_mod.o_proj
+        self.q_proj = self_attn.q_proj
+        self.k_proj = self_attn.k_proj
+        self.v_proj = self_attn.v_proj
+        self.o_proj = self_attn.o_proj
 
     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         query_states = self.q_proj(hidden_states)
optimum/rbln/transformers/models/colpali/configuration_colpali.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, List, Optional, Union
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
 from ....utils.logging import get_logger
@@ -33,7 +33,9 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
 
     # Create a configuration object
     config = RBLNColPaliForRetrievalConfig(
-        max_seq_lens=1152,
+        vlm={
+            "language_model": {"prefill_chunk_size": 8192},
+        }
         output_hidden_states=False,
         tensor_parallel_size=4
     )
@@ -47,24 +49,21 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
     ```
     """
 
-    submodules = ["vision_tower"]
+    _allow_no_compile_cfgs = True
+    submodules = ["vlm"]
 
     def __init__(
         self,
         batch_size: Optional[int] = None,
-        max_seq_lens: Union[int, List[int]] = None,
+        vlm: Optional[RBLNModelConfig] = None,
         output_hidden_states: Optional[bool] = None,
-        vision_tower: Optional[RBLNModelConfig] = None,
         **kwargs: Any,
     ):
         """
        Args:
            batch_size (Optional[int]): The batch size for the model.
-            vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
-            max_seq_lens (Union[int, List[int]]): The maximum sequence lengths for the language model.
-                This can be multiple values, and the model will be compiled for each max_seq_len, allowing selection of the most appropriate max_seq_len at inference time.
-            output_hidden_states (Optional[bool]): Whether to output the hidden states of the language model.
-            vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+            vlm (Optional[RBLNModelConfig]): Configuration for the VLM component.
+            output_hidden_states (Optional[bool]): Whether to output the hidden states of the decoder. Defaults to False.
            kwargs: Additional arguments passed to the parent RBLNModelConfig.
        Raises:
            ValueError: If batch_size is not a positive integer.
@@ -74,11 +73,7 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
-        if self.batch_size != 1:
-            logger.warning("Ignore batch_size for ColPali vision tower. It will be set to 1.")
-
-        self.vision_tower = self.initialize_submodule_config(
-            submodule_config=vision_tower, batch_size=1, force_kwargs=True
+        self.output_hidden_states = output_hidden_states or False
+        self.vlm = self.initialize_submodule_config(
+            submodule_config=vlm, batch_size=batch_size, output_hidden_states=output_hidden_states
         )
-        self.max_seq_lens = max_seq_lens
-        self.output_hidden_states = output_hidden_states
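
The hunks above replace the top-level max_seq_lens and vision_tower options of RBLNColPaliForRetrievalConfig with a single nested vlm submodule config, which now also carries batch_size and output_hidden_states down to the wrapped VLM. A minimal usage sketch mirroring the docstring example in the diff; the model class name is inferred from the config class, and the checkpoint and rbln_config keyword follow the usual optimum-rbln convention rather than anything stated in this diff:

# Hypothetical usage sketch (not part of the diff): the new nested `vlm` config.
from optimum.rbln import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig

rbln_config = RBLNColPaliForRetrievalConfig(
    batch_size=1,
    vlm={"language_model": {"prefill_chunk_size": 8192}},
    output_hidden_states=False,
    tensor_parallel_size=4,
)

# export=True compiles the checkpoint with the config above
# (assumed optimum-rbln convention).
model = RBLNColPaliForRetrieval.from_pretrained(
    "vidore/colpali-v1.3-hf",  # illustrative checkpoint
    export=True,
    rbln_config=rbln_config,
)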