optimum-rbln 0.9.4a2__py3-none-any.whl → 0.9.5a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +36 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +35 -16
- optimum/rbln/modeling_base.py +6 -6
- optimum/rbln/ops/__init__.py +1 -0
- optimum/rbln/ops/attn.py +10 -0
- optimum/rbln/ops/flash_attn.py +8 -0
- optimum/rbln/ops/moe.py +180 -0
- optimum/rbln/ops/sliding_window_attn.py +9 -0
- optimum/rbln/transformers/__init__.py +36 -0
- optimum/rbln/transformers/modeling_attention_utils.py +118 -222
- optimum/rbln/transformers/modeling_outputs.py +25 -0
- optimum/rbln/transformers/modeling_rope_utils.py +78 -42
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -182
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -21
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +118 -16
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +1 -1
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +121 -48
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +5 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +75 -107
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
- optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
- optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +16 -18
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
- optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
- optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
- optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +6 -4
- optimum/rbln/transformers/models/llava/modeling_llava.py +0 -1
- optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
- optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
- optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
- optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
- optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
- optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +271 -122
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
- optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
- optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
- optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +263 -105
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +26 -34
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
- optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
- optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
- optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +4 -18
- optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
- optimum/rbln/transformers/models/t5/t5_architecture.py +15 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
- optimum/rbln/transformers/utils/rbln_quantization.py +20 -12
- optimum/rbln/utils/import_utils.py +16 -1
- optimum/rbln/utils/runtime_utils.py +10 -6
- optimum/rbln/utils/submodule.py +24 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +81 -62
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.9.4a2.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/colpali/modeling_colpali.py

@@ -14,23 +14,16 @@
 
 import bisect
 from pathlib import Path
-from typing import
+from typing import Optional, Tuple, Union
 
 import torch
-from transformers import PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.modeling_utils import no_init_weights
-from transformers.models.colpali.modeling_colpali import ColPaliForRetrievalOutput
-from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModalProjector
+from transformers.models.colpali.modeling_colpali import ColPaliForRetrieval, ColPaliForRetrievalOutput
 
-from ....configuration_utils import
+from ....configuration_utils import RBLNModelConfig
 from ....modeling import RBLNModel
 from ...utils.rbln_runtime_wrapper import LoopProcessor
-from .colpali_architecture import RBLNColPaliForRetrievalWrapper
-
-
-if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
 
 class LoopVisionTower(LoopProcessor):
@@ -115,17 +108,25 @@ class RBLNColPaliForRetrieval(RBLNModel):
     from optimum.rbln import RBLNColPaliForRetrieval
 
     # Simple usage using rbln_* arguments
-    # `max_seq_lens` is automatically inferred from the model config
     model = RBLNColPaliForRetrieval.from_pretrained(
         "vidore/colpali-v1.3-hf",
         export=True,
-
+        rbln_config={
+            "vlm": {
+                "language_model": {
+                    "prefill_chunk_size": 8192,  # same as model's max_position_embeddings (max_seq_len)
+                }
+            }
+        }
     )
 
     # Using a config dictionary
     rbln_config = {
-        "
-
+        "vlm": {
+            "language_model": {
+                "prefill_chunk_size": 8192,  # same as model's max_position_embeddings (max_seq_len)
+            }
+        }
     }
     model = RBLNColPaliForRetrieval.from_pretrained(
         "vidore/colpali-v1.3-hf",
@@ -137,7 +138,9 @@ class RBLNColPaliForRetrieval(RBLNModel):
     from optimum.rbln import RBLNColPaliForRetrievalConfig
 
     config = RBLNColPaliForRetrievalConfig(
-
+        vlm={
+            "language_model": {"prefill_chunk_size": 8192},
+        },
         output_hidden_states=False,
         tensor_parallel_size=4
     )
@@ -150,212 +153,93 @@ class RBLNColPaliForRetrieval(RBLNModel):
     """
 
     auto_model_class = None
+    _rbln_submodule_postfix = "model"
     _rbln_submodules = [
-        {"name": "
+        {"name": "vlm"},
     ]
 
     def __post_init__(self, **kwargs):
-        self.
-        self.language_model = LoopLanguageModel(self.model[0], self.rbln_config)
-
+        self.vlm_model = self.rbln_submodules[0]
         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
-        self.
-        self.
-        self.multi_modal_projector = self._create_multi_modal_projector()
-        self.multi_modal_projector.load_state_dict(artifacts["multi_modal_projector"])
-
+        self.embedding_proj_layer = self._create_embedding_proj_layer()
+        self.embedding_proj_layer.load_state_dict(artifacts["embedding_proj_layer"])
         return super().__post_init__(**kwargs)
 
-    def
+    def _create_embedding_proj_layer(self):
         with no_init_weights():
-
-                self.config.text_config.
-                self.config.text_config.hidden_size,
-                self.config.text_config.pad_token_id,
+            embedding_proj_layer = torch.nn.Linear(
+                self.config.vlm_config.text_config.hidden_size, self.config.embedding_dim
             )
-        return
-
-    def _create_multi_modal_projector(self):
-        with no_init_weights():
-            multi_modal_projector = PaliGemmaMultiModalProjector(self.config.vlm_config)
-        return multi_modal_projector
-
-    @classmethod
-    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
-        return RBLNColPaliForRetrievalWrapper(
-            causal_lm=model.vlm,
-            embedding_proj_layer=model.embedding_proj_layer,
-            max_seq_len=max(rbln_config.max_seq_lens),
-            output_hidden_states=rbln_config.output_hidden_states,
-        )
+        return embedding_proj_layer
 
     @classmethod
     def save_torch_artifacts(
         cls,
-        model: "
+        model: "ColPaliForRetrieval",
         save_dir_path: Path,
         subfolder: str,
         rbln_config: RBLNModelConfig,
     ):
         save_dict = {}
-        save_dict["
-        save_dict["multi_modal_projector"] = model.vlm.multi_modal_projector.state_dict()
+        save_dict["embedding_proj_layer"] = model.embedding_proj_layer.state_dict()
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
 
-    @classmethod
-    def _update_rbln_config(
-        cls,
-        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
-        model: Optional["PreTrainedModel"] = None,
-        model_config: Optional["PretrainedConfig"] = None,
-        rbln_config: Optional[RBLNModelConfig] = None,
-    ) -> RBLNModelConfig:
-        hidden_size = model_config.vlm_config.text_config.hidden_size
-        if rbln_config.max_seq_lens is None:
-            rbln_config.max_seq_lens = [model_config.vlm_config.text_config.max_position_embeddings]
-        if isinstance(rbln_config.max_seq_lens, int):
-            rbln_config.max_seq_lens = [rbln_config.max_seq_lens]
-        rbln_config.max_seq_lens = sorted(set(rbln_config.max_seq_lens))
-
-        if rbln_config.output_hidden_states is None:
-            rbln_config.output_hidden_states = model_config.vlm_config.text_config.output_hidden_states
-
-        input_infos = []
-        for max_seq_len in rbln_config.max_seq_lens:
-            input_info = [
-                ("inputs_embeds", [rbln_config.vision_tower.batch_size, max_seq_len, hidden_size], "float32"),
-                ("attention_mask", [rbln_config.vision_tower.batch_size, max_seq_len], "float32"),
-                ("position_ids", [rbln_config.vision_tower.batch_size, max_seq_len], "int32"),
-            ]
-            input_infos.append(input_info)
-
-        rbln_compile_config = RBLNCompileConfig(input_info=input_infos)
-        rbln_config.set_compile_cfgs([rbln_compile_config])
-
-        return rbln_config
-
-    @classmethod
-    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
-        if hasattr(model, "vlm"):
-            model.vision_tower = model.vlm.vision_tower
-            del model.vlm.model.vision_tower
-            return model
-        return model
-
-    def get_image_features(self, pixel_values: torch.Tensor):
-        # Projects the last hidden state from the vision model into language model space.
-        # Args:
-        #     pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
-        #         The tensors corresponding to the input images.
-        # Returns:
-        #     image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
-
-        vision_output_size = [
-            pixel_values.shape[0],
-            self.config.vlm_config.vision_config.num_image_tokens,
-            self.config.vlm_config.vision_config.hidden_size,
-        ]
-        vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu")
-        self.vision_tower(pixel_values, out=vision_output)
-        image_features = self.multi_modal_projector(vision_output)
-        image_features = image_features / (self.config.text_config.hidden_size**0.5)
-        return image_features
-
-    def _preprocess_inputs(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        pixel_values: Optional[torch.FloatTensor] = None,
-        **kwargs,
-    ):
-        if (input_ids is None) ^ (inputs_embeds is not None):
-            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
-
-        # Replace image id woth PAD if the image token if OOV, to avoid index-errors
-        if input_ids is not None and self.config.vlm_config.image_token_index >= self.config.text_config.vocab_size:
-            special_image_mask = input_ids == self.config.vlm_config.image_token_index
-            llm_input_ids = input_ids.clone()
-            llm_input_ids[special_image_mask] = 0
-        else:
-            llm_input_ids = input_ids
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(llm_input_ids)
-
-        # Merge text and images
-        image_features = None
-        if pixel_values is not None:
-            image_features = self.get_image_features(pixel_values)
-            special_image_mask = (input_ids == self.config.vlm_config.image_token_index).unsqueeze(-1)
-            special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-
-            image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
-            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
-
-        return inputs_embeds, image_features
-
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
         pixel_values: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[Tuple, ColPaliForRetrievalOutput]:
+        """
+        Forward pass for the RBLN-optimized ColPaliForRetrieval model.
+
+        Args:
+            input_ids (torch.LongTensor of shape (batch_size, sequence_length)): Indices of input sequence tokens in the vocabulary.
+            pixel_values (torch.Tensor of shape (batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images.
+            attention_mask (torch.Tensor of shape (batch_size, sequence_length)): Mask to avoid performing attention on padding token indices.
+            output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
+            return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+
+        Returns:
+            ColPaliForRetrievalOutput or tuple(torch.FloatTensor)
+        """
         if pixel_values is not None:
             pixel_values = pixel_values.to(dtype=self.dtype)
 
-        if
-
-
-
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+        if output_hidden_states != self.rbln_config.output_hidden_states:
            raise ValueError(
                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
                f"Please compile again with the correct argument."
            )
 
-
-
-
-
+        vlm_output = self.vlm_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            output_hidden_states=output_hidden_states,
+            return_dict=True,
+            **kwargs,
         )
+        vlm_hidden_states = vlm_output.hidden_states if output_hidden_states else None
+        vlm_image_hidden_states = vlm_output.image_hidden_states if pixel_values is not None else None
 
-
-
-
-
-            self.rbln_config.max_seq_lens[0],
-            self.rbln_config.max_seq_lens[0],
-        ]
-        outputs.append(torch.empty(size=language_model_out_size, dtype=torch.float32, device="cpu"))
-        if self.rbln_config.output_hidden_states:
-            for _ in range(self.config.vlm_config.text_config.num_hidden_layers + 1):
-                outputs.append(torch.empty(size=language_model_hidden_states_size, dtype=torch.float32, device="cpu"))
-
-        # Embedding_proj_layer is fused on the bottom of the language model.
-        self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, out=outputs)
-
-        embeddings = outputs[0][:, : inputs_embeds.shape[1]]
-        hidden_states = (
-            None
-            if not self.rbln_config.output_hidden_states
-            else [tensor[0][:, : inputs_embeds.shape[1]] for tensor in outputs[1:]]
-        )
-
-        # L2 normalization
-        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
+        last_hidden_states = vlm_output[0]
+        proj_dtype = self.embedding_proj_layer.weight.dtype
+        embeddings = self.embedding_proj_layer(last_hidden_states.to(proj_dtype))
+        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
 
         if attention_mask is not None:
-            embeddings = embeddings * attention_mask.unsqueeze(-1)
+            embeddings = embeddings * attention_mask.unsqueeze(-1)
 
-
-
-
-
-
-            hidden_states=hidden_states,
-            image_hidden_states=image_features,
-        )
+        return ColPaliForRetrievalOutput(
+            embeddings=embeddings,
+            hidden_states=vlm_hidden_states,
+            image_hidden_states=vlm_image_hidden_states,
+        )
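To make the refactor above easier to follow: `RBLNColPaliForRetrieval` no longer compiles its own vision tower, projector, and language model, but drives a single compiled `vlm` submodule and keeps only the retrieval projection head (`embedding_proj_layer`) on the host. A minimal sketch of the new embedding path, using the attribute names introduced in this diff (the `model` object with a loaded `vlm_model` and `embedding_proj_layer` is assumed):

```python
import torch


def colpali_embeddings(model, input_ids, attention_mask, pixel_values=None):
    # 1) Run the compiled VLM submodule (vision tower + language model) end to end.
    vlm_output = model.vlm_model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        pixel_values=pixel_values,
        return_dict=True,
    )

    # 2) Project the last hidden states into the retrieval embedding space on the host.
    last_hidden_states = vlm_output[0]
    proj_dtype = model.embedding_proj_layer.weight.dtype
    embeddings = model.embedding_proj_layer(last_hidden_states.to(proj_dtype))

    # 3) L2-normalize per token and zero out padded positions.
    embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
    if attention_mask is not None:
        embeddings = embeddings * attention_mask.unsqueeze(-1)
    return embeddings
```

This mirrors the body of the new `forward` and replaces the previous manual output-buffer bookkeeping (`outputs.append(torch.empty(...))`) with a plain call into the submodule.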
optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py

@@ -32,14 +32,16 @@ class RBLNColQwen2ForRetrievalConfig(RBLNDecoderOnlyModelConfig):
 
     # Create a configuration object
     config = RBLNColQwen2ForRetrievalConfig(
-
-        "
-
-
-
-
-
-
+        vlm = {
+            "visual": {
+                "max_seq_lens": 6400,
+                "device": 0,
+            },
+            "max_seq_len": 32_768,
+            "tensor_parallel_size": 4,
+            "device": [0, 1, 2, 3],
+            "output_hidden_states": False,
+        }
     )
 
     # Use the configuration with from_pretrained
@@ -51,22 +53,37 @@ class RBLNColQwen2ForRetrievalConfig(RBLNDecoderOnlyModelConfig):
     ```
     """
 
-    submodules = ["
+    submodules = ["vlm"]
+    _allow_no_compile_cfgs = True
 
     def __init__(
         self,
-        visual: Optional[RBLNModelConfig] = None,
         batch_size: Optional[int] = None,
-
+        output_hidden_states: Optional[bool] = None,
+        vlm: Optional[RBLNModelConfig] = None,
        **kwargs,
    ):
-
-
-
-
-
-
-
-
-
-
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for the model.
+            output_hidden_states (Optional[bool]): Whether to output the hidden states of the VLM model.
+            vlm (Optional[RBLNModelConfig]): Configuration for the VLM component.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+        Raises:
+            ValueError: If batch_size is not a positive integer.
+        """
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        self.output_hidden_states = output_hidden_states or False
+
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.vlm = self.initialize_submodule_config(
+            submodule_config=vlm,
+            batch_size=batch_size,
+            output_hidden_states=output_hidden_states,
+            force_kwargs=True,
+            logits_to_keep=0,
+            use_inputs_embeds=True,
+        )
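The `RBLNColQwen2ForRetrievalConfig` change follows the same pattern as the ColPali refactor: the former top-level `visual`/`max_seq_lens` options move under a single `vlm` submodule, which `__init__` now populates via `initialize_submodule_config`. A hedged usage sketch based on the docstring shown in the hunk above; the retrieval model class name and checkpoint ID in the commented-out load call are assumptions, not confirmed by this diff:

```python
from optimum.rbln import RBLNColQwen2ForRetrievalConfig

# All VLM-related options now live under the `vlm` submodule.
config = RBLNColQwen2ForRetrievalConfig(
    vlm={
        "visual": {"max_seq_lens": 6400, "device": 0},
        "max_seq_len": 32_768,
        "tensor_parallel_size": 4,
        "device": [0, 1, 2, 3],
        "output_hidden_states": False,
    },
)

# Hypothetical load call; class name and checkpoint ID are illustrative only.
# model = RBLNColQwen2ForRetrieval.from_pretrained(
#     "vidore/colqwen2-v1.0-hf", export=True, rbln_config=config
# )
```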