optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. optimum/rbln/__init__.py +96 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +153 -42
  5. optimum/rbln/diffusers/__init__.py +7 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  12. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  13. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  20. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  21. optimum/rbln/diffusers/models/__init__.py +3 -13
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  23. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
  24. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  25. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
  26. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
  27. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
  28. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
  29. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  30. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  31. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  32. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  34. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  36. optimum/rbln/modeling.py +71 -19
  37. optimum/rbln/modeling_base.py +99 -21
  38. optimum/rbln/ops/attn.py +158 -0
  39. optimum/rbln/ops/flash_attn.py +166 -0
  40. optimum/rbln/ops/kv_cache_update.py +5 -0
  41. optimum/rbln/ops/linear.py +7 -0
  42. optimum/rbln/transformers/__init__.py +92 -0
  43. optimum/rbln/transformers/configuration_generic.py +9 -7
  44. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  45. optimum/rbln/transformers/modeling_generic.py +51 -9
  46. optimum/rbln/transformers/modeling_outputs.py +37 -0
  47. optimum/rbln/transformers/models/__init__.py +91 -30
  48. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  49. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  50. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  51. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  52. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  53. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  54. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  55. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
  57. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  58. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
  59. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  60. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  61. optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
  62. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  63. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  64. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  65. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  66. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  67. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
  68. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  69. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
  71. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
  72. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  73. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
  74. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  75. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  76. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  77. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  78. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  79. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  80. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  81. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  82. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  83. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  84. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  85. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
  86. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  87. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  88. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  89. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  90. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  91. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  92. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  93. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  94. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  95. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
  96. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  97. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  98. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  99. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  100. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  101. optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
  102. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  103. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
  104. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  105. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  106. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  107. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  108. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  109. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  110. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  111. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  112. optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
  113. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  114. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  115. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  116. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  117. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  118. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  119. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  120. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  121. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  122. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  123. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  124. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
  125. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  126. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  127. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  128. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  129. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  130. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
  131. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  132. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  133. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  134. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
  135. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  136. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  137. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  138. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  139. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
  140. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
  141. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  142. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  143. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  144. optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
  145. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  146. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  147. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  148. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  149. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  150. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  151. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  152. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  153. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
  154. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  155. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  156. optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
  157. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  158. optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
  159. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  160. optimum/rbln/utils/depreacate_utils.py +16 -0
  161. optimum/rbln/utils/runtime_utils.py +28 -18
  162. optimum/rbln/utils/submodule.py +31 -9
  163. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
  164. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
  165. optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
  166. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
  167. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/llava/modeling_llava.py
@@ -0,0 +1,478 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import importlib
+ import inspect
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
+
+ import torch
+ from transformers import AutoModelForImageTextToText, LlavaForConditionalGeneration, PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
+ from transformers.modeling_utils import no_init_weights
+ from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
+
+ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from ...modeling_outputs import RBLNDecoderOnlyOutput
+ from ...utils.rbln_runtime_wrapper import LoopProcessor
+ from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
+
+
+ logger = get_logger(__name__)
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
+
+
+ class LoopVisionTower(LoopProcessor):
+     def __init__(self, vision_tower):
+         # FIXME: need to know RBLNModel or RuntimeWrapper
+         if hasattr(vision_tower.model, "runtime"):
+             super().__init__(model=vision_tower)
+         else:
+             super().__init__(model=vision_tower.model[0])
+
+         self.rbln_config = vision_tower.rbln_config
+
+     def _get_batch_size(self, pixel_values, **kwargs):
+         return pixel_values.shape[0]
+
+     def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+         pixel_values_item = pixel_values[index : index + 1]
+         if "image_sizes" in kwargs and kwargs["image_sizes"] is not None:
+             ret_val = [pixel_values_item, kwargs["image_sizes"][index : index + 1]]
+         else:
+             ret_val = [pixel_values_item]
+
+         out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]] if "out" in kwargs else None
+         return (ret_val, {"out": out_buffer})
+
+     def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
+         # when use another Wrapper
+         if hasattr(self.rbln_config, "max_image_size"):
+             last_hidden_states = [output.last_hidden_state for output in outputs]
+             last_hidden_states = torch.cat(last_hidden_states, dim=1)
+             hidden_states = tuple(
+                 torch.cat(
+                     [output.hidden_states[layer_idx] for output in outputs],
+                     dim=1,
+                 )
+                 for layer_idx in range(len(outputs[0].hidden_states))
+             )
+         else:
+             output = kwargs["out"]
+             last_hidden_states = output[0]
+
+             if not output[2:]:
+                 hidden_states = None
+             else:
+                 hidden_states = tuple(output[2:])
+
+         return BaseModelOutputWithPooling(
+             last_hidden_state=last_hidden_states,
+             pooler_output=None,
+             hidden_states=hidden_states,
+         )
+
+
+ class LoopProjector(LoopProcessor):
+     def __init__(self, multi_modal_projector: "RBLNModel"):
+         super().__init__(model=multi_modal_projector)
+
+     def _get_batch_size(self, image_feature, **kwargs):
+         return image_feature.shape[0]
+
+     def _prepare_inputs_for_iteration(self, index, common_inputs, image_feature, **kwargs):
+         image_feature_item = image_feature[index : index + 1]
+         out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+         return ([image_feature_item], {"out": out_buffer})
+
+     def _process_outputs(self, outputs: list, **kwargs):
+         output = kwargs["out"]
+         return output[0]
+
+
+ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
+     """
+     RBLNLlavaForConditionalGeneration is a multi-modal model that combines vision and language processing capabilities,
+     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+     Important Note:
+         This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
+         tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+         `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNLlavaForConditionalGeneration class for details.
+     Examples:
+         ```python
+         from optimum.rbln import RBLNLlavaForConditionalGeneration
+         model = RBLNLlavaForConditionalGeneration.from_pretrained(
+             "llava-hf/llava-1.5-7b-hf",
+             export=True,
+             rbln_config={
+                 "vision_tower": {"output_hidden_states": True},
+                 "language_model": {
+                     "tensor_parallel_size": 4,
+                     "use_inputs_embeds": True,  # In Llava, language model must use inputs_embeds as input.
+                 },
+             },
+         )
+         model.save_pretrained("compiled-llava-1.5-7b-hf")
+
+         # Using a RBLNLlavaForConditionalGenerationConfig instance (recommended for type checking)
+         from optimum.rbln import RBLNLlavaForConditionalGenerationConfig
+         vision_config = RBLNCLIPVisionModelConfig(
+             batch_size=1,
+             output_hidden_states=True
+         )
+         language_model_config = RBLNLlamaForCausalLMConfig(
+             batch_size=1,
+             max_seq_len=4096,
+             use_inputs_embeds=True,
+             tensor_parallel_size=4
+         )
+         llava_config = RBLNLlavaForConditionalGenerationConfig(
+             batch_size=1,
+             vision_tower=vision_config,
+             language_model=language_model_config
+         )
+         model = RBLNLlavaForConditionalGeneration.from_pretrained(
+             "llava-hf/llava-1.5-7b-hf",
+             export=True,
+             rbln_config=llava_config
+         )
+         ```
+     """
+
+     auto_model_class = AutoModelForImageTextToText
+     _rbln_submodules = [
+         {"name": "vision_tower"},
+         {"name": "language_model"},
+     ]
+
+     def __getattr__(self, __name: str) -> Any:
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(LlavaForConditionalGeneration, __name)
+
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+         return val
+
+     def can_generate(self):
+         return True
+
+     @classmethod
+     def get_pytorch_model(cls, *args, **kwargs):
+         model = super().get_pytorch_model(*args, **kwargs)
+
+         with no_init_weights():
+             model_cls_name = model.model.language_model.__class__.__name__
+             causal_model_cls_name = model_cls_name.replace("Model", "ForCausalLM")
+             causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
+             new_language_model = causal_model_cls(model.model.language_model.config)
+
+         new_language_model.lm_head = model.lm_head
+         new_language_model.model = model.model.language_model
+         model.model.language_model = new_language_model
+         model.lm_head = None
+         del model.lm_head
+         return model
+
+     def __post_init__(self, **kwargs):
+         self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
+         self.language_model = self.rbln_submodules[1]
+         self.multi_modal_projector = LoopProjector(self.model[0])
+         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
+         return super().__post_init__(**kwargs)
+
+     def get_attn_impl(self) -> str:
+         return self.rbln_config.language_model.attn_impl
+
+     def get_kvcache_num_blocks(self) -> int:
+         return self.rbln_config.language_model.kvcache_num_blocks
+
+     def get_input_embeddings(self):
+         return self.language_model.get_input_embeddings()
+
+     @classmethod
+     def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+         return model.multi_modal_projector
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNModelConfig] = None,
+     ) -> RBLNModelConfig:
+         # support for pixtral that needs padding
+         if hasattr(rbln_config.vision_tower, "max_image_size"):
+             num_positions = (
+                 rbln_config.batch_size
+                 * (rbln_config.vision_tower.max_image_size[0] // model_config.vision_config.patch_size)
+                 * (rbln_config.vision_tower.max_image_size[1] // model_config.vision_config.patch_size)
+             )
+             selected_image_feature_dim = num_positions
+
+         else:
+             num_positions = (model_config.vision_config.image_size // model_config.vision_config.patch_size) ** 2 + 1
+             if model_config.vision_feature_select_strategy == "default":
+                 selected_image_feature_dim = num_positions - 1
+             else:
+                 selected_image_feature_dim = num_positions
+
+         input_info = [
+             (
+                 "image_features",
+                 [
+                     1,
+                     selected_image_feature_dim,
+                     model_config.vision_config.hidden_size,
+                 ],
+                 "float32",
+             )
+         ]
+
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+         rbln_config.set_compile_cfgs([rbln_compile_config])
+         return rbln_config
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         inputs_embeds=None,
+         pixel_values=None,
+         attention_mask=None,
+         cache_position=None,
+         image_sizes=None,
+         generate_idx=None,
+         **kwargs,
+     ):
+         is_prefill_phase = generate_idx is None
+         model_inputs = {}
+
+         if is_prefill_phase:
+             generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+             cache_position = None
+             pixel_values = pixel_values
+             model_inputs.update({"image_sizes": image_sizes})
+         else:
+             if inputs_embeds is not None:
+                 raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
+
+             pixel_values = None
+             input_ids = input_ids[:, -1:]
+             cache_position = generate_idx
+             generate_idx = generate_idx + 1
+             model_inputs.update({"input_ids": input_ids})
+
+         if inputs_embeds is not None:
+             if self.rbln_config.use_inputs_embeds:
+                 model_inputs.update({"inputs_embeds": inputs_embeds})
+             else:
+                 raise ValueError(
+                     "The specifying inputs_embeds is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
+                 )
+         else:
+             model_inputs.update({"input_ids": input_ids})
+
+         model_inputs.update(
+             {
+                 "attention_mask": attention_mask,
+                 "pixel_values": pixel_values,
+                 "cache_position": cache_position,
+                 "generate_idx": generate_idx,
+             }
+         )
+         return model_inputs
+
+     def _update_model_kwargs_for_generation(self, outputs, model_kwargs, is_encoder_decoder, **kwargs):
+         model_kwargs["generate_idx"] = outputs.generate_idx
+         return model_kwargs
+
+     def get_image_features(
+         self,
+         pixel_values: torch.FloatTensor,
+         vision_feature_layer: Union[int, List[int]],
+         vision_feature_select_strategy: str,
+         **kwargs,
+     ):
+         if vision_feature_select_strategy not in ["default", "full"]:
+             raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
+
+         kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+         # prepare out buffer for pre-allocation
+         if hasattr(self.rbln_config.vision_tower, "max_image_size"):
+             vision_out_size = [
+                 pixel_values.shape[0],
+                 (self.rbln_config.vision_tower.max_image_size[0] // self.config.vision_config.patch_size)
+                 * (self.rbln_config.vision_tower.max_image_size[1] // self.config.vision_config.patch_size),
+                 self.config.vision_config.hidden_size,
+             ]
+             pooler_out_size = None
+         else:
+             vision_out_size = [
+                 pixel_values.shape[0],
+                 (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2 + 1,
+                 self.config.vision_config.hidden_size,
+             ]
+             pooler_out_size = [pixel_values.shape[0], self.config.vision_config.hidden_size]
+
+         vision_out_buffer = []
+         for i in range(self.config.vision_config.num_hidden_layers + 2):
+             vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
+         if pooler_out_size is not None:
+             vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=torch.float32, device="cpu"))
+
+         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, out=vision_out_buffer, **kwargs)
+
+         if isinstance(vision_feature_layer, int):
+             selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
+             if vision_feature_select_strategy == "default":
+                 selected_image_feature = selected_image_feature[:, 1:]
+         else:
+             hs_pool = [image_outputs.hidden_states[layer_idx] for layer_idx in vision_feature_layer]
+             if vision_feature_select_strategy == "default":
+                 hs_pool = [hs[:, 1:] for hs in hs_pool]
+             selected_image_feature = torch.cat(hs_pool, dim=-1)
+
+         if hasattr(self.rbln_config.vision_tower, "max_image_size"):
+             num_real_patches = selected_image_feature.shape[1]
+             max_patches = (
+                 (self.rbln_config.vision_tower.max_image_size[0] // self.config.vision_config.patch_size)
+                 * (self.rbln_config.vision_tower.max_image_size[1] // self.config.vision_config.patch_size)
+                 * pixel_values.shape[0]
+             )
+             num_padding_patches = max_patches - num_real_patches
+
+             projector_out_size = [1, max_patches, self.config.text_config.hidden_size]
+             projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+
+             padding_tensor = torch.zeros(
+                 (selected_image_feature.shape[0], num_padding_patches, selected_image_feature.shape[2]),
+                 dtype=selected_image_feature.dtype,
+             )
+             padded_feature = torch.cat([selected_image_feature, padding_tensor], dim=1)
+             padded_projected_feature = self.multi_modal_projector(padded_feature, out=projector_out_buffer)
+             image_features = padded_projected_feature[:, :num_real_patches, :]
+         else:
+             projector_out_size = [
+                 pixel_values.shape[0] * pixel_values.shape[1],
+                 (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2,
+                 self.config.text_config.hidden_size,
+             ]
+             projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+             image_features = self.multi_modal_projector(selected_image_feature, out=projector_out_buffer)
+
+         return image_features
+
+     def _preprocess_prefill(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         vision_feature_layer: Optional[Union[int, List[int]]] = None,
+         vision_feature_select_strategy: Optional[str] = None,
+         return_dict: Optional[bool] = None,
+         image_sizes: Optional[torch.Tensor] = None,
+         **lm_kwargs,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         vision_feature_layer = (
+             vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+         )
+         vision_feature_select_strategy = (
+             vision_feature_select_strategy
+             if vision_feature_select_strategy is not None
+             else self.config.vision_feature_select_strategy
+         )
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if pixel_values is not None and inputs_embeds is not None:
+             raise ValueError(
+                 "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
+             )
+
+         if inputs_embeds is None:
+             inputs_embeds = self.get_input_embeddings()(input_ids)
+
+         if pixel_values is not None:
+             image_features = self.get_image_features(
+                 pixel_values=pixel_values,
+                 vision_feature_layer=vision_feature_layer,
+                 vision_feature_select_strategy=vision_feature_select_strategy,
+                 image_sizes=image_sizes,
+             )
+
+             special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1)
+             special_image_mask = special_image_mask.expand_as(inputs_embeds)
+             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+         return inputs_embeds
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         image_sizes: Optional[torch.Tensor] = None,
+         generate_idx: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
+         # Prefill
+         if cache_position is None:
+             inputs_embeds = self._preprocess_prefill(
+                 input_ids=input_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values, image_sizes=image_sizes
+             )
+             logits = []
+             inputs = inputs_embeds if inputs_embeds is not None else input_ids
+             batch_size = inputs.shape[0]
+
+             for b_idx in range(batch_size):
+                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+                 output = self.language_model.prefill_decoder(
+                     input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
+                     inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
+                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                     cache_position=cache_position,
+                     batch_idx=b_idx,
+                 )
+                 logits.append(output.logits)
+
+             logits = torch.cat(logits, dim=0)
+
+         # Decoder
+         else:
+             logits = self.language_model.decoder(
+                 input_ids=input_ids,
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+             ).logits
+
+         if not return_dict:
+             return logits, generate_idx
+         else:
+             return RBLNDecoderOnlyOutput(
+                 logits=logits,
+                 generate_idx=generate_idx,
+             )
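
The new class above compiles the model and saves it, but the diff itself only shows the export path in its docstring. The following is a minimal inference sketch, not part of the diff: it assumes the artifacts produced by `save_pretrained("compiled-llava-1.5-7b-hf")` in the docstring example, the standard `optimum` convention of `export=False` for loading precompiled artifacts, and the upstream llava-1.5-hf prompt format; adjust names to your installed optimum-rbln version.

```python
# Minimal sketch (assumptions noted above): run generation with the compiled Llava model.
import requests
from PIL import Image
from transformers import AutoProcessor

from optimum.rbln import RBLNLlavaForConditionalGeneration

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
model = RBLNLlavaForConditionalGeneration.from_pretrained(
    "compiled-llava-1.5-7b-hf",  # directory written by save_pretrained in the docstring example
    export=False,                # assumption: load precompiled RBLN artifacts instead of recompiling
)

image = Image.open(
    requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw
)
prompt = "USER: <image>\nWhat do you see in this image? ASSISTANT:"
inputs = processor(images=image, text=prompt, return_tensors="pt")

# generate() is usable because can_generate() returns True and the class mixes in
# RBLNDecoderOnlyGenerationMixin; prefill and decode are dispatched inside forward().
output_ids = model.generate(**inputs, max_new_tokens=64)
print(processor.batch_decode(output_ids, skip_special_tokens=True)[0])
```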
optimum/rbln/transformers/models/llava_next/configuration_llava_next.py
@@ -12,11 +12,10 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Dict, Optional
+ from typing import Any, Optional

  from ....configuration_utils import RBLNModelConfig
  from ....utils.logging import get_logger
- from ...models.clip import RBLNCLIPVisionModelConfig


  logger = get_logger(__name__)
@@ -38,34 +37,33 @@ class RBLNLlavaNextForConditionalGenerationConfig(RBLNModelConfig):
          batch_size: Optional[int] = None,
          vision_tower: Optional[RBLNModelConfig] = None,
          language_model: Optional[RBLNModelConfig] = None,
-         **kwargs: Dict[str, Any],
+         **kwargs: Any,
      ):
          """
          Args:
              batch_size (Optional[int]): The batch size for inference. Defaults to 1.
              vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
              language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
-             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
+             kwargs: Additional arguments passed to the parent RBLNModelConfig.

          Raises:
-             ValueError: If batch_size is not a positive integer.
+             ValueError: If `batch_size` is not a positive integer.
          """
          super().__init__(**kwargs)
          self.batch_size = batch_size or 1
          if not isinstance(self.batch_size, int) or self.batch_size < 0:
              raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

-         self.vision_tower = self.init_submodule_config(
-             RBLNCLIPVisionModelConfig,
-             vision_tower,
-         )
+         if self.batch_size != 1:
+             logger.warning("Ignore batch_size for LlavaNext vision tower. It will be set to 1.")

-         if self.vision_tower.output_hidden_states is False:
-             raise ValueError(
-                 f"LlavaNext requires output_hidden_states to be True, but found output_hidden_states={self.vision_tower.output_hidden_states}. "
-                 f"Please compile again with the correct argument."
-             )
-         else:
-             self.vision_tower.output_hidden_states = True
+         self.vision_tower = self.initialize_submodule_config(
+             submodule_config=vision_tower,
+             batch_size=1,  # vision_tower batch_size is always 1 in LlavaNext
+             output_hidden_states=True,  # LlavaNext requires output_hidden_states to be True
+             force_kwargs=True,
+         )

-         self.language_model = language_model
+         self.language_model = self.initialize_submodule_config(
+             submodule_config=language_model,
+         )
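
A small usage sketch of the rewritten constructor, not part of the diff: it assumes `RBLNLlavaNextForConditionalGenerationConfig` is still exported from `optimum.rbln` and that `initialize_submodule_config` builds a default vision-tower config when none is passed, mirroring the removed `init_submodule_config` behavior. The point it illustrates comes straight from the hunk above: the explicit `RBLNCLIPVisionModelConfig` argument is gone, and the vision tower is now forced to `batch_size=1` with `output_hidden_states=True`.

```python
# Hypothetical sketch under the assumptions stated above.
from optimum.rbln import RBLNLlavaNextForConditionalGenerationConfig

# Requesting batch_size=2 keeps the top-level batch size, logs the warning shown in the
# diff, and pins the vision tower to batch_size=1 / output_hidden_states=True.
config = RBLNLlavaNextForConditionalGenerationConfig(batch_size=2)

print(config.batch_size)                         # 2
print(config.vision_tower.batch_size)            # expected: 1
print(config.vision_tower.output_hidden_states)  # expected: True
```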