optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. optimum/rbln/__init__.py +96 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +153 -42
  5. optimum/rbln/diffusers/__init__.py +7 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  12. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  13. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  20. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  21. optimum/rbln/diffusers/models/__init__.py +3 -13
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  23. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
  24. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  25. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
  26. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
  27. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
  28. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
  29. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  30. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  31. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  32. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  34. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  36. optimum/rbln/modeling.py +71 -19
  37. optimum/rbln/modeling_base.py +99 -21
  38. optimum/rbln/ops/attn.py +158 -0
  39. optimum/rbln/ops/flash_attn.py +166 -0
  40. optimum/rbln/ops/kv_cache_update.py +5 -0
  41. optimum/rbln/ops/linear.py +7 -0
  42. optimum/rbln/transformers/__init__.py +92 -0
  43. optimum/rbln/transformers/configuration_generic.py +9 -7
  44. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  45. optimum/rbln/transformers/modeling_generic.py +51 -9
  46. optimum/rbln/transformers/modeling_outputs.py +37 -0
  47. optimum/rbln/transformers/models/__init__.py +91 -30
  48. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  49. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  50. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  51. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  52. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  53. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  54. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  55. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
  57. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  58. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
  59. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  60. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  61. optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
  62. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  63. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  64. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  65. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  66. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  67. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
  68. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  69. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
  71. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
  72. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  73. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
  74. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  75. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  76. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  77. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  78. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  79. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  80. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  81. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  82. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  83. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  84. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  85. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
  86. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  87. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  88. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  89. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  90. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  91. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  92. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  93. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  94. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  95. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
  96. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  97. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  98. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  99. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  100. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  101. optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
  102. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  103. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
  104. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  105. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  106. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  107. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  108. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  109. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  110. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  111. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  112. optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
  113. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  114. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  115. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  116. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  117. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  118. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  119. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  120. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  121. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  122. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  123. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  124. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
  125. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  126. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  127. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  128. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  129. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  130. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
  131. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  132. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  133. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  134. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
  135. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  136. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  137. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  138. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  139. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
  140. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
  141. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  142. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  143. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  144. optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
  145. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  146. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  147. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  148. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  149. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  150. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  151. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  152. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  153. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
  154. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  155. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  156. optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
  157. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  158. optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
  159. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  160. optimum/rbln/utils/depreacate_utils.py +16 -0
  161. optimum/rbln/utils/runtime_utils.py +28 -18
  162. optimum/rbln/utils/submodule.py +31 -9
  163. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
  164. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
  165. optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
  166. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
  167. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
@@ -12,20 +12,20 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional
15
+ from typing import Any, Optional
16
16
 
17
17
  from ....configuration_utils import RBLNModelConfig
18
18
 
19
19
 
20
20
  class RBLNCLIPTextModelConfig(RBLNModelConfig):
21
- def __init__(self, batch_size: Optional[int] = None, **kwargs: Dict[str, Any]):
21
+ def __init__(self, batch_size: Optional[int] = None, **kwargs: Any):
22
22
  """
23
23
  Args:
24
24
  batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
25
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
25
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
26
26
 
27
27
  Raises:
28
- ValueError: If batch_size is not a positive integer.
28
+ ValueError: If `batch_size` is not a positive integer.
29
29
  """
30
30
  super().__init__(**kwargs)
31
31
  self.batch_size = batch_size or 1
@@ -50,17 +50,20 @@ class RBLNCLIPVisionModelConfig(RBLNModelConfig):
50
50
  interpolate_pos_encoding: Optional[bool] = None,
51
51
  output_hidden_states: Optional[bool] = None,
52
52
  output_attentions: Optional[bool] = None,
53
- **kwargs: Dict[str, Any],
53
+ **kwargs: Any,
54
54
  ):
55
55
  """
56
56
  Args:
57
57
  batch_size (Optional[int]): The batch size for image processing. Defaults to 1.
58
58
  image_size (Optional[int]): The size of input images. Can be an integer for square images,
59
59
  a tuple/list (height, width), or a dictionary with 'height' and 'width' keys.
60
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
60
+ interpolate_pos_encoding (Optional[bool]): Whether or not to interpolate pre-trained position encodings. Defaults to `False`.
61
+ output_hidden_states (Optional[bool]): Whether or not to return the hidden states of all layers.
62
+ output_attentions (Optional[bool]): Whether or not to return the attentions tensors of all attention layers
63
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
61
64
 
62
65
  Raises:
63
- ValueError: If batch_size is not a positive integer.
66
+ ValueError: If `batch_size` is not a positive integer.
64
67
  """
65
68
  super().__init__(**kwargs)
66
69
  self.batch_size = batch_size or 1
@@ -51,6 +51,8 @@ class RBLNCLIPTextModel(RBLNModel):
51
51
  on RBLN devices, supporting text encoding for multimodal tasks.
52
52
  """
53
53
 
54
+ _tp_support = False
55
+
54
56
  @classmethod
55
57
  def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPTextModelConfig) -> torch.nn.Module:
56
58
  return _TextEncoder(model).eval()
@@ -83,7 +85,15 @@ class RBLNCLIPTextModel(RBLNModel):
83
85
  rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
84
86
  return rbln_config
85
87
 
86
- def forward(self, input_ids: torch.LongTensor, return_dict: bool = None, **kwargs) -> torch.FloatTensor:
88
+ def forward(self, input_ids: torch.LongTensor, return_dict: Optional[bool] = None, **kwargs) -> torch.FloatTensor:
89
+ """
90
+ Forward pass for the RBLN-optimized CLIP text encoder model.
91
+
92
+ Args:
93
+ input_ids (torch.LongTensor): The input ids to the model.
94
+ return_dict (Optional[bool]): Whether to return a dictionary of outputs.
95
+ """
96
+
87
97
  # To ignore using attention_mask, we override forward method.
88
98
  output = super().forward(input_ids, return_dict=return_dict)
89
99
  return output
@@ -144,6 +154,8 @@ class RBLNCLIPVisionModel(RBLNModel):
144
154
  on RBLN devices, supporting image encoding for multimodal tasks.
145
155
  """
146
156
 
157
+ _tp_support = False
158
+
147
159
  @classmethod
148
160
  def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
149
161
  wrapper_cfg = {
@@ -202,13 +214,24 @@ class RBLNCLIPVisionModel(RBLNModel):
202
214
 
203
215
  def forward(
204
216
  self,
205
- pixel_values: Optional[torch.FloatTensor] = None,
217
+ pixel_values: torch.FloatTensor,
206
218
  return_dict: bool = True,
207
- output_attentions: bool = None,
208
- output_hidden_states: bool = None,
219
+ output_attentions: Optional[bool] = None,
220
+ output_hidden_states: Optional[bool] = None,
209
221
  interpolate_pos_encoding: bool = False,
210
222
  **kwargs,
211
223
  ) -> Union[Tuple, BaseModelOutputWithPooling]:
224
+ """
225
+ Forward pass for the RBLN-optimized CLIP vision encoder model.
226
+
227
+ Args:
228
+ pixel_values (torch.Tensor): The pixel values to the model.
229
+ return_dict (bool): Whether to return a dictionary of outputs.
230
+ output_attentions (Optional[bool]): Whether to return attentions.
231
+ output_hidden_states (Optional[bool]): Whether to return hidden states.
232
+ interpolate_pos_encoding (bool): Whether to interpolate position encoding.
233
+ """
234
+
212
235
  if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
213
236
  logger.warning(
214
237
  f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
@@ -4,10 +4,7 @@ import torch
4
4
  from torch import nn
5
5
  from transformers import GemmaForCausalLM, GemmaModel
6
6
 
7
- from ..decoderonly.decoderonly_architecture import (
8
- RotaryEmbedding,
9
- apply_rotary_pos_emb,
10
- )
7
+ from ..decoderonly.decoderonly_architecture import RotaryEmbedding, apply_rotary_pos_emb
11
8
 
12
9
 
13
10
  def slice_and_unsqueeze_cos_sin(cos, sin, position_ids):
@@ -27,11 +24,11 @@ class RBLNColPaliForRetrievalWrapper(nn.Module):
27
24
  output_hidden_states: bool = False,
28
25
  ):
29
26
  super().__init__()
30
- self.text_config = causal_lm.config
27
+ self.text_config = causal_lm.config.text_config
31
28
  self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
32
29
 
33
30
  self.output_hidden_states = output_hidden_states
34
- self.language_model = self.convert_to_rbln_language_model(causal_lm.model, max_seq_len)
31
+ self.language_model = self.convert_to_rbln_language_model(causal_lm.model.language_model, max_seq_len)
35
32
 
36
33
  self.num_hidden_layers = getattr(self.text_config, "num_hidden_layers", None)
37
34
  self.embedding_proj_layer = embedding_proj_layer
@@ -11,9 +11,13 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- from typing import Any, Dict, List, Optional, Union
14
+ from typing import Any, List, Optional, Union
15
15
 
16
16
  from ....configuration_utils import RBLNModelConfig
17
+ from ....utils.logging import get_logger
18
+
19
+
20
+ logger = get_logger(__name__)
17
21
 
18
22
 
19
23
  class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
@@ -24,45 +28,57 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
24
28
  including vision tower settings and multi-sequence length support.
25
29
 
26
30
  Example usage:
27
- ```python
28
- from optimum.rbln import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig
29
-
30
- # Create a configuration object
31
- config = RBLNColPaliForRetrievalConfig(
32
- max_seq_lens=1152,
33
- output_hidden_states=False,
34
- tensor_parallel_size=4
35
- )
36
-
37
- # Use the configuration with from_pretrained
38
- model = RBLNColPaliForRetrieval.from_pretrained(
39
- "vidore/colpali-v1.3-hf",
40
- export=True,
41
- rbln_config=config
42
- )
43
- ```
31
+ ```python
32
+ from optimum.rbln import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig
33
+
34
+ # Create a configuration object
35
+ config = RBLNColPaliForRetrievalConfig(
36
+ max_seq_lens=1152,
37
+ output_hidden_states=False,
38
+ tensor_parallel_size=4
39
+ )
40
+
41
+ # Use the configuration with from_pretrained
42
+ model = RBLNColPaliForRetrieval.from_pretrained(
43
+ "vidore/colpali-v1.3-hf",
44
+ export=True,
45
+ rbln_config=config
46
+ )
47
+ ```
44
48
  """
45
49
 
46
50
  submodules = ["vision_tower"]
47
51
 
48
52
  def __init__(
49
53
  self,
54
+ batch_size: Optional[int] = None,
50
55
  max_seq_lens: Union[int, List[int]] = None,
51
56
  output_hidden_states: Optional[bool] = None,
52
57
  vision_tower: Optional[RBLNModelConfig] = None,
53
- **kwargs: Dict[str, Any],
58
+ **kwargs: Any,
54
59
  ):
55
60
  """
56
61
  Args:
62
+ batch_size (Optional[int]): The batch size for the model.
57
63
  vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
58
64
  max_seq_lens (Union[int, List[int]]): The maximum sequence lengths for the language model.
59
65
  This can be multiple values, and the model will be compiled for each max_seq_len, allowing selection of the most appropriate max_seq_len at inference time.
60
66
  output_hidden_states (Optional[bool]): Whether to output the hidden states of the language model.
61
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
67
+ vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
68
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
62
69
  Raises:
63
70
  ValueError: If batch_size is not a positive integer.
64
71
  """
65
72
  super().__init__(**kwargs)
66
- self.vision_tower = vision_tower
73
+ self.batch_size = batch_size or 1
74
+ if not isinstance(self.batch_size, int) or self.batch_size < 0:
75
+ raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
76
+
77
+ if self.batch_size != 1:
78
+ logger.warning("Ignore batch_size for ColPali vision tower. It will be set to 1.")
79
+
80
+ self.vision_tower = self.initialize_submodule_config(
81
+ submodule_config=vision_tower, batch_size=1, force_kwargs=True
82
+ )
67
83
  self.max_seq_lens = max_seq_lens
68
84
  self.output_hidden_states = output_hidden_states
@@ -14,13 +14,11 @@
14
14
 
15
15
  import bisect
16
16
  from pathlib import Path
17
- from typing import TYPE_CHECKING, Any, Optional, Union
17
+ from tempfile import TemporaryDirectory
18
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
18
19
 
19
20
  import torch
20
- from transformers import (
21
- PretrainedConfig,
22
- PreTrainedModel,
23
- )
21
+ from transformers import PretrainedConfig, PreTrainedModel
24
22
  from transformers.modeling_outputs import BaseModelOutputWithPooling
25
23
  from transformers.modeling_utils import no_init_weights
26
24
  from transformers.models.colpali.modeling_colpali import ColPaliForRetrievalOutput
@@ -28,105 +26,72 @@ from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModal
28
26
 
29
27
  from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
30
28
  from ....modeling import RBLNModel
29
+ from ...utils.rbln_runtime_wrapper import LoopProcessor
31
30
  from .colpali_architecture import RBLNColPaliForRetrievalWrapper
32
31
 
33
32
 
34
33
  if TYPE_CHECKING:
35
- from transformers import (
36
- AutoFeatureExtractor,
37
- AutoProcessor,
38
- AutoTokenizer,
39
- PretrainedConfig,
40
- )
34
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
41
35
 
42
36
 
43
- class LoopVisionTower:
44
- def __init__(self, vision_tower: RBLNModel) -> None:
45
- self.vision_tower = vision_tower
37
+ class LoopVisionTower(LoopProcessor):
38
+ def __init__(self, vision_tower: "RBLNModel"):
39
+ super().__init__(model=vision_tower.model[0])
46
40
 
47
- def forward(self, pixel_values, **kwargs):
48
- batch_size = pixel_values.shape[0]
49
- outputs = []
50
- for i in range(batch_size):
51
- outputs.append(self.vision_tower(pixel_values[i : i + 1]))
41
+ def _get_batch_size(self, pixel_values, **kwargs):
42
+ return pixel_values.shape[0]
52
43
 
53
- last_hidden_states = [output.last_hidden_state for output in outputs]
54
- last_hidden_states = torch.cat(last_hidden_states, dim=0)
44
+ def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
45
+ pixel_values_item = pixel_values[index : index + 1]
46
+ out_buffer = kwargs["out"][index : index + 1]
47
+ return ([pixel_values_item], {"out": out_buffer})
55
48
 
49
+ def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
56
50
  return BaseModelOutputWithPooling(
57
- last_hidden_state=last_hidden_states,
51
+ last_hidden_state=kwargs["out"],
58
52
  )
59
53
 
60
- def __call__(self, *args: Any, **kwds: Any) -> Any:
61
- return self.forward(*args, **kwds)
62
-
63
- def __repr__(self) -> str:
64
- return repr(self.vision_tower)
65
-
66
54
 
67
- class LoopLanguageModel:
68
- def __init__(self, language_model: RBLNModel, rbln_config: RBLNModelConfig) -> None:
69
- self.language_model = language_model
55
+ class LoopLanguageModel(LoopProcessor):
56
+ def __init__(self, language_model: RBLNModel, rbln_config: RBLNModelConfig):
57
+ super().__init__(model=language_model)
70
58
  self.rbln_config = rbln_config
71
59
 
72
- def prepare_inputs(self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor):
60
+ def _get_batch_size(self, inputs_embeds, **kwargs):
61
+ return inputs_embeds.shape[0]
62
+
63
+ def _prepare_inputs_before_loop(self, *, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, **kwargs):
73
64
  input_len = inputs_embeds.shape[1]
74
65
  idx = bisect.bisect_left(self.rbln_config.max_seq_lens, input_len)
75
66
  if idx == len(self.rbln_config.max_seq_lens):
76
67
  raise ValueError(
77
68
  f"Required seq_len({input_len}) is larger than available max_seq_lens({self.rbln_config.max_seq_lens})."
78
69
  )
79
- else:
80
- max_seq_len = self.rbln_config.max_seq_lens[idx]
81
-
82
- inputs_embed = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, max_seq_len - input_len))
83
- attn_mask = torch.nn.functional.pad(attention_mask, (0, max_seq_len - input_len)).to(torch.float32)
84
- position_ids = torch.arange(max_seq_len, dtype=torch.int32).view(1, -1)
85
-
86
- return inputs_embed, attn_mask, position_ids
87
-
88
- def forward(self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, **kwargs):
89
- padded_inputs_embed, padded_attn_mask, padded_position_ids = self.prepare_inputs(inputs_embeds, attention_mask)
90
- input_batch_size = inputs_embeds.shape[0]
91
- input_seq_len = inputs_embeds.shape[1]
92
-
93
- all_embeddings = []
94
- all_hidden_states = []
95
- for i in range(input_batch_size):
96
- outputs = self.language_model(
97
- inputs_embeds=padded_inputs_embed[i : i + 1],
98
- attention_mask=padded_attn_mask[i : i + 1],
99
- position_ids=padded_position_ids,
100
- )
101
-
102
- if self.rbln_config.output_hidden_states:
103
- embedding = outputs[0]
104
- hidden_states = outputs[1:]
105
- else:
106
- embedding = outputs
107
- hidden_states = None
70
+ max_seq_len = self.rbln_config.max_seq_lens[idx]
71
+ padded_inputs_embed = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, max_seq_len - input_len))
72
+ padded_attn_mask = torch.nn.functional.pad(attention_mask, (0, max_seq_len - input_len)).to(torch.float32)
73
+ padded_position_ids = torch.arange(max_seq_len, dtype=torch.int32).view(1, -1)
74
+
75
+ return {
76
+ "padded_inputs_embed": padded_inputs_embed,
77
+ "padded_attn_mask": padded_attn_mask,
78
+ "padded_position_ids": padded_position_ids,
79
+ }
108
80
 
109
- all_embeddings.append(embedding)
110
- all_hidden_states.append(hidden_states)
81
+ def _prepare_inputs_for_iteration(self, index: int, common_inputs, *args, **kwargs):
82
+ item_kwargs = {
83
+ "inputs_embeds": common_inputs["padded_inputs_embed"][index : index + 1],
84
+ "attention_mask": common_inputs["padded_attn_mask"][index : index + 1],
85
+ "position_ids": common_inputs["padded_position_ids"],
86
+ "out": [tensor[index : index + 1] for tensor in kwargs["out"]],
87
+ }
88
+ return ([], item_kwargs)
111
89
 
112
- embeddings = torch.cat(all_embeddings, dim=0)[:, :input_seq_len]
90
+ def _process_outputs(self, outputs: list, **kwargs):
113
91
  if self.rbln_config.output_hidden_states:
114
- hidden_states = [
115
- torch.cat(
116
- [batch_hidden_states[layer_idx][:, :input_seq_len] for batch_hidden_states in all_hidden_states],
117
- dim=0,
118
- )
119
- for layer_idx in range(len(all_hidden_states[0]))
120
- ]
121
- return embeddings, tuple(hidden_states)
92
+ return kwargs["out"][0], tuple(kwargs["out"][1:])
122
93
  else:
123
- return embeddings
124
-
125
- def __call__(self, *args: Any, **kwds: Any) -> Any:
126
- return self.forward(*args, **kwds)
127
-
128
- def __repr__(self) -> str:
129
- return repr(self.language_model)
94
+ return kwargs["out"]
130
95
 
131
96
 
132
97
  class RBLNColPaliForRetrieval(RBLNModel):
@@ -134,8 +99,8 @@ class RBLNColPaliForRetrieval(RBLNModel):
134
99
  The ColPali Model transformer for document retrieval using vision-language models.
135
100
  This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
136
101
 
137
- A class to convert and run pre-trained transformers based ColPaliForRetrieval model on RBLN devices.
138
- It implements the methods to convert a pre-trained transformers ColPaliForRetrieval model into a RBLN transformer model by:
102
+ A class to convert and run pre-trained transformers based `ColPaliForRetrieval` model on RBLN devices.
103
+ It implements the methods to convert a pre-trained transformers `ColPaliForRetrieval` model into a RBLN transformer model by:
139
104
 
140
105
  - transferring the checkpoint weights of the original into an optimized RBLN graph,
141
106
  - compiling the resulting graph using the RBLN compiler.
@@ -219,7 +184,7 @@ class RBLNColPaliForRetrieval(RBLNModel):
219
184
  @classmethod
220
185
  def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
221
186
  return RBLNColPaliForRetrievalWrapper(
222
- causal_lm=model.vlm.language_model,
187
+ causal_lm=model.vlm,
223
188
  embedding_proj_layer=model.embedding_proj_layer,
224
189
  max_seq_len=max(rbln_config.max_seq_lens),
225
190
  output_hidden_states=rbln_config.output_hidden_states,
@@ -259,9 +224,9 @@ class RBLNColPaliForRetrieval(RBLNModel):
259
224
  input_infos = []
260
225
  for max_seq_len in rbln_config.max_seq_lens:
261
226
  input_info = [
262
- ("inputs_embeds", [1, max_seq_len, hidden_size], "float32"),
263
- ("attention_mask", [1, max_seq_len], "float32"),
264
- ("position_ids", [1, max_seq_len], "int32"),
227
+ ("inputs_embeds", [rbln_config.vision_tower.batch_size, max_seq_len, hidden_size], "float32"),
228
+ ("attention_mask", [rbln_config.vision_tower.batch_size, max_seq_len], "float32"),
229
+ ("position_ids", [rbln_config.vision_tower.batch_size, max_seq_len], "int32"),
265
230
  ]
266
231
  input_infos.append(input_info)
267
232
 
@@ -271,19 +236,49 @@ class RBLNColPaliForRetrieval(RBLNModel):
271
236
  return rbln_config
272
237
 
273
238
  @classmethod
274
- def from_model(cls, model: "PreTrainedModel", *args, **kwargs):
239
+ def from_model(
240
+ cls,
241
+ model: "PreTrainedModel",
242
+ config: Optional[PretrainedConfig] = None,
243
+ rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
244
+ model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
245
+ subfolder: str = "",
246
+ **kwargs: Any,
247
+ ) -> "RBLNModel":
248
+ """
249
+ Converts and compiles a pre-trained HuggingFace library model into a RBLN model.
250
+ This method performs the actual model conversion and compilation process.
251
+
252
+ Args:
253
+ model (PreTrainedModel): The PyTorch model to be compiled.
254
+ The object must be an instance of the HuggingFace transformers PreTrainedModel class.
255
+ config (Optional[PretrainedConfig]): The configuration object associated with the model.
256
+ rbln_config (Optional[Union[RBLNModelConfig, Dict]]): Configuration for RBLN model compilation and runtime.
257
+ This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNLlamaForCausalLMConfig` for Llama models).
258
+ For detailed configuration options, see the specific model's configuration class documentation.
259
+ kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
260
+
261
+ The method performs the following steps:
262
+
263
+ 1. Compiles the PyTorch model into an optimized RBLN graph
264
+ 2. Configures the model for the specified NPU device
265
+ 3. Creates the necessary runtime objects if requested
266
+ 4. Saves the compiled model and configurations
267
+
268
+ Returns:
269
+ (RBLNModel): A RBLN model instance ready for inference on RBLN NPU devices.
270
+ """
275
271
  if not hasattr(model, "vision_tower"):
276
272
  model.vision_tower = model.vlm.vision_tower
277
- del model.vlm.vision_tower
278
- model = super().from_model(model, *args, **kwargs)
273
+ del model.vlm.model.vision_tower
274
+ model = super().from_model(model, config, rbln_config, model_save_dir, subfolder, **kwargs)
279
275
  return model
280
276
 
281
277
  @classmethod
282
278
  def get_pytorch_model(cls, *args, **kwargs):
283
279
  model = super().get_pytorch_model(*args, **kwargs)
284
280
  model.vision_tower = model.vlm.vision_tower
285
- del model.vlm.vision_tower
286
-
281
+ del model.vlm.model.vision_tower
287
282
  return model
288
283
 
289
284
  def get_image_features(self, pixel_values: torch.Tensor):
@@ -294,8 +289,14 @@ class RBLNColPaliForRetrieval(RBLNModel):
294
289
  # Returns:
295
290
  # image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
296
291
 
297
- vision_outputs = self.vision_tower(pixel_values).last_hidden_state
298
- image_features = self.multi_modal_projector(vision_outputs)
292
+ vision_output_size = [
293
+ pixel_values.shape[0],
294
+ self.config.vlm_config.vision_config.num_image_tokens,
295
+ self.config.vlm_config.vision_config.hidden_size,
296
+ ]
297
+ vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu")
298
+ self.vision_tower(pixel_values, out=vision_output)
299
+ image_features = self.multi_modal_projector(vision_output)
299
300
  image_features = image_features / (self.config.text_config.hidden_size**0.5)
300
301
  return image_features
301
302
 
@@ -342,7 +343,7 @@ class RBLNColPaliForRetrieval(RBLNModel):
342
343
  output_hidden_states: Optional[bool] = None,
343
344
  return_dict: Optional[bool] = None,
344
345
  **kwargs,
345
- ) -> ColPaliForRetrievalOutput:
346
+ ) -> Union[Tuple, ColPaliForRetrievalOutput]:
346
347
  if pixel_values is not None:
347
348
  pixel_values = pixel_values.to(dtype=self.dtype)
348
349
 
@@ -361,11 +362,27 @@ class RBLNColPaliForRetrieval(RBLNModel):
361
362
  input_ids=input_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values
362
363
  )
363
364
 
365
+ outputs = []
366
+ language_model_out_size = [inputs_embeds.shape[0], self.rbln_config.max_seq_lens[0], self.config.embedding_dim]
367
+ language_model_hidden_states_size = [
368
+ inputs_embeds.shape[0],
369
+ self.rbln_config.max_seq_lens[0],
370
+ self.rbln_config.max_seq_lens[0],
371
+ ]
372
+ outputs.append(torch.empty(size=language_model_out_size, dtype=torch.float32, device="cpu"))
373
+ if self.rbln_config.output_hidden_states:
374
+ for i in range(self.config.vlm_config.text_config.num_hidden_layers + 1):
375
+ outputs.append(torch.empty(size=language_model_hidden_states_size, dtype=torch.float32, device="cpu"))
376
+
364
377
  # Embedding_proj_layer is fused on the bottom of the language model.
365
- outputs = self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
378
+ self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, out=outputs)
366
379
 
367
- embeddings = outputs if not self.rbln_config.output_hidden_states else outputs[0]
368
- hidden_states = None if not self.rbln_config.output_hidden_states else outputs[1]
380
+ embeddings = outputs[0][:, : inputs_embeds.shape[1]]
381
+ hidden_states = (
382
+ None
383
+ if not self.rbln_config.output_hidden_states
384
+ else [tensor[0][:, : inputs_embeds.shape[1]] for tensor in outputs[1:]]
385
+ )
369
386
 
370
387
  # L2 normalization
371
388
  embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True) # (batch_size, sequence_length, dim)
@@ -0,0 +1,2 @@
1
+ from .configuration_colqwen2 import RBLNColQwen2ForRetrievalConfig
2
+ from .modeling_colqwen2 import RBLNColQwen2ForRetrieval