optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/colpali/modeling_colpali.py

@@ -14,13 +14,10 @@
 
 import bisect
 from pathlib import Path
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 import torch
-from transformers import (
-    PretrainedConfig,
-    PreTrainedModel,
-)
+from transformers import PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.modeling_utils import no_init_weights
 from transformers.models.colpali.modeling_colpali import ColPaliForRetrievalOutput
@@ -28,105 +25,72 @@ from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModal
 
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
+from ...utils.rbln_runtime_wrapper import LoopProcessor
 from .colpali_architecture import RBLNColPaliForRetrievalWrapper
 
 
 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
 
-class LoopVisionTower:
-    def __init__(self, vision_tower: RBLNModel)
-
+class LoopVisionTower(LoopProcessor):
+    def __init__(self, vision_tower: "RBLNModel"):
+        super().__init__(model=vision_tower.model[0])
 
-    def
-
-        outputs = []
-        for i in range(batch_size):
-            outputs.append(self.vision_tower(pixel_values[i : i + 1]))
+    def _get_batch_size(self, pixel_values, **kwargs):
+        return pixel_values.shape[0]
 
-
-
+    def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+        pixel_values_item = pixel_values[index : index + 1]
+        out_buffer = kwargs["out"][index : index + 1]
+        return ([pixel_values_item], {"out": out_buffer})
 
+    def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
         return BaseModelOutputWithPooling(
-            last_hidden_state=
+            last_hidden_state=kwargs["out"],
         )
 
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
-
-    def __repr__(self) -> str:
-        return repr(self.vision_tower)
-
 
-class LoopLanguageModel:
-    def __init__(self, language_model: RBLNModel, rbln_config: RBLNModelConfig)
-
+class LoopLanguageModel(LoopProcessor):
+    def __init__(self, language_model: RBLNModel, rbln_config: RBLNModelConfig):
+        super().__init__(model=language_model)
         self.rbln_config = rbln_config
 
-    def
+    def _get_batch_size(self, inputs_embeds, **kwargs):
+        return inputs_embeds.shape[0]
+
+    def _prepare_inputs_before_loop(self, *, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, **kwargs):
         input_len = inputs_embeds.shape[1]
         idx = bisect.bisect_left(self.rbln_config.max_seq_lens, input_len)
         if idx == len(self.rbln_config.max_seq_lens):
             raise ValueError(
                 f"Required seq_len({input_len}) is larger than available max_seq_lens({self.rbln_config.max_seq_lens})."
             )
-
-
-
-
-
-
-
-
-
-
-        padded_inputs_embed, padded_attn_mask, padded_position_ids = self.prepare_inputs(inputs_embeds, attention_mask)
-        input_batch_size = inputs_embeds.shape[0]
-        input_seq_len = inputs_embeds.shape[1]
-
-        all_embeddings = []
-        all_hidden_states = []
-        for i in range(input_batch_size):
-            outputs = self.language_model(
-                inputs_embeds=padded_inputs_embed[i : i + 1],
-                attention_mask=padded_attn_mask[i : i + 1],
-                position_ids=padded_position_ids,
-            )
-
-            if self.rbln_config.output_hidden_states:
-                embedding = outputs[0]
-                hidden_states = outputs[1:]
-            else:
-                embedding = outputs
-                hidden_states = None
+        max_seq_len = self.rbln_config.max_seq_lens[idx]
+        padded_inputs_embed = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, max_seq_len - input_len))
+        padded_attn_mask = torch.nn.functional.pad(attention_mask, (0, max_seq_len - input_len)).to(torch.float32)
+        padded_position_ids = torch.arange(max_seq_len, dtype=torch.int32).view(1, -1)
+
+        return {
+            "padded_inputs_embed": padded_inputs_embed,
+            "padded_attn_mask": padded_attn_mask,
+            "padded_position_ids": padded_position_ids,
+        }
 
-
-
+    def _prepare_inputs_for_iteration(self, index: int, common_inputs, *args, **kwargs):
+        item_kwargs = {
+            "inputs_embeds": common_inputs["padded_inputs_embed"][index : index + 1],
+            "attention_mask": common_inputs["padded_attn_mask"][index : index + 1],
+            "position_ids": common_inputs["padded_position_ids"],
+            "out": [tensor[index : index + 1] for tensor in kwargs["out"]],
+        }
+        return ([], item_kwargs)
 
-
+    def _process_outputs(self, outputs: list, **kwargs):
         if self.rbln_config.output_hidden_states:
-
-                torch.cat(
-                    [batch_hidden_states[layer_idx][:, :input_seq_len] for batch_hidden_states in all_hidden_states],
-                    dim=0,
-                )
-                for layer_idx in range(len(all_hidden_states[0]))
-            ]
-            return embeddings, tuple(hidden_states)
+            return kwargs["out"][0], tuple(kwargs["out"][1:])
         else:
-            return
-
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
-
-    def __repr__(self) -> str:
-        return repr(self.language_model)
+            return kwargs["out"]
 
 
 class RBLNColPaliForRetrieval(RBLNModel):
@@ -134,8 +98,8 @@ class RBLNColPaliForRetrieval(RBLNModel):
     The ColPali Model transformer for document retrieval using vision-language models.
     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
-    A class to convert and run pre-trained transformers based ColPaliForRetrieval model on RBLN devices.
-    It implements the methods to convert a pre-trained transformers ColPaliForRetrieval model into a RBLN transformer model by:
+    A class to convert and run pre-trained transformers based `ColPaliForRetrieval` model on RBLN devices.
+    It implements the methods to convert a pre-trained transformers `ColPaliForRetrieval` model into a RBLN transformer model by:
 
     - transferring the checkpoint weights of the original into an optimized RBLN graph,
     - compiling the resulting graph using the RBLN compiler.
@@ -217,9 +181,9 @@ class RBLNColPaliForRetrieval(RBLNModel):
         return multi_modal_projector
 
     @classmethod
-    def
+    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
         return RBLNColPaliForRetrievalWrapper(
-            causal_lm=model.vlm
+            causal_lm=model.vlm,
             embedding_proj_layer=model.embedding_proj_layer,
             max_seq_len=max(rbln_config.max_seq_lens),
             output_hidden_states=rbln_config.output_hidden_states,
@@ -259,9 +223,9 @@ class RBLNColPaliForRetrieval(RBLNModel):
         input_infos = []
         for max_seq_len in rbln_config.max_seq_lens:
             input_info = [
-                ("inputs_embeds", [
-                ("attention_mask", [
-                ("position_ids", [
+                ("inputs_embeds", [rbln_config.vision_tower.batch_size, max_seq_len, hidden_size], "float32"),
+                ("attention_mask", [rbln_config.vision_tower.batch_size, max_seq_len], "float32"),
+                ("position_ids", [rbln_config.vision_tower.batch_size, max_seq_len], "int32"),
             ]
             input_infos.append(input_info)
 
@@ -271,19 +235,11 @@ class RBLNColPaliForRetrieval(RBLNModel):
         return rbln_config
 
     @classmethod
-    def
-        if
+    def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+        if hasattr(model, "vlm"):
             model.vision_tower = model.vlm.vision_tower
-            del model.vlm.vision_tower
-
-        return model
-
-    @classmethod
-    def get_pytorch_model(cls, *args, **kwargs):
-        model = super().get_pytorch_model(*args, **kwargs)
-        model.vision_tower = model.vlm.vision_tower
-        del model.vlm.vision_tower
-
+            del model.vlm.model.vision_tower
+            return model
         return model
 
     def get_image_features(self, pixel_values: torch.Tensor):
@@ -294,8 +250,14 @@ class RBLNColPaliForRetrieval(RBLNModel):
         # Returns:
         #     image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
 
-
-
+        vision_output_size = [
+            pixel_values.shape[0],
+            self.config.vlm_config.vision_config.num_image_tokens,
+            self.config.vlm_config.vision_config.hidden_size,
+        ]
+        vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu")
+        self.vision_tower(pixel_values, out=vision_output)
+        image_features = self.multi_modal_projector(vision_output)
         image_features = image_features / (self.config.text_config.hidden_size**0.5)
         return image_features
 
@@ -342,7 +304,7 @@ class RBLNColPaliForRetrieval(RBLNModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         **kwargs,
-    ) -> ColPaliForRetrievalOutput:
+    ) -> Union[Tuple, ColPaliForRetrievalOutput]:
         if pixel_values is not None:
             pixel_values = pixel_values.to(dtype=self.dtype)
 
@@ -361,11 +323,27 @@ class RBLNColPaliForRetrieval(RBLNModel):
             input_ids=input_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values
         )
 
+        outputs = []
+        language_model_out_size = [inputs_embeds.shape[0], self.rbln_config.max_seq_lens[0], self.config.embedding_dim]
+        language_model_hidden_states_size = [
+            inputs_embeds.shape[0],
+            self.rbln_config.max_seq_lens[0],
+            self.rbln_config.max_seq_lens[0],
+        ]
+        outputs.append(torch.empty(size=language_model_out_size, dtype=torch.float32, device="cpu"))
+        if self.rbln_config.output_hidden_states:
+            for i in range(self.config.vlm_config.text_config.num_hidden_layers + 1):
+                outputs.append(torch.empty(size=language_model_hidden_states_size, dtype=torch.float32, device="cpu"))
+
         # Embedding_proj_layer is fused on the bottom of the language model.
-
+        self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, out=outputs)
 
-        embeddings = outputs
-        hidden_states =
+        embeddings = outputs[0][:, : inputs_embeds.shape[1]]
+        hidden_states = (
+            None
+            if not self.rbln_config.output_hidden_states
+            else [tensor[0][:, : inputs_embeds.shape[1]] for tensor in outputs[1:]]
+        )
 
         # L2 normalization
         embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
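
Note: the rewrite above moves the per-sample batching loop out of `LoopVisionTower`/`LoopLanguageModel` and into the new `LoopProcessor` base class added in `optimum/rbln/transformers/utils/rbln_runtime_wrapper.py` (+79 lines, not shown in this diff). Judging from the hook methods used in this hunk, the base class follows a template-method shape roughly like the sketch below; this is inferred from the diff, not the actual implementation. Results are written into preallocated `out` buffers passed through `kwargs`, which is why `_process_outputs` reads `kwargs["out"]` rather than the collected return values.

```python
# Hypothetical sketch of the LoopProcessor template-method pattern, inferred
# from the hooks used in the hunk above. The real class in
# rbln_runtime_wrapper.py is not shown in this diff.
class LoopProcessor:
    def __init__(self, model):
        self.model = model

    def _get_batch_size(self, *args, **kwargs):
        raise NotImplementedError

    def _prepare_inputs_before_loop(self, *args, **kwargs):
        # Optional hook: compute inputs shared by every iteration (e.g. padding).
        return {}

    def _prepare_inputs_for_iteration(self, index, common_inputs, *args, **kwargs):
        # Returns (item_args, item_kwargs) for one sample, typically slicing
        # the inputs and the preallocated "out" buffers to batch index `index`.
        raise NotImplementedError

    def _process_outputs(self, outputs, **kwargs):
        raise NotImplementedError

    def forward(self, *args, **kwargs):
        batch_size = self._get_batch_size(*args, **kwargs)
        common_inputs = self._prepare_inputs_before_loop(*args, **kwargs)
        outputs = []
        for index in range(batch_size):
            item_args, item_kwargs = self._prepare_inputs_for_iteration(
                index, common_inputs, *args, **kwargs
            )
            # Each call writes its slice of the preallocated "out" buffers.
            outputs.append(self.model(*item_args, **item_kwargs))
        return self._process_outputs(outputs, **kwargs)

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
```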
optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py (new file)

@@ -0,0 +1,233 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from transformers import PreTrainedModel
+
+from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
+    DecoderOnlyLayer,
+    DecoderOnlyModel,
+    DecoderOnlyWrapper,
+)
+
+from .configuration_colqwen2 import (
+    RBLNColQwen2ForRetrievalConfig,
+)
+
+
+def slice_and_unsqueeze_cos_sin(cos, sin, position_ids):
+    """Slice cos[cache_position], sin[cache_position] vector for the query."""
+    cos = cos[position_ids[0]][None, None, None, :, :]
+    sin = sin[position_ids[0]][None, None, None, :, :]
+
+    return cos, sin
+
+
+class ColQwen2LanguageModelWrapper(DecoderOnlyWrapper):
+    def __init__(
+        self, model: PreTrainedModel, rbln_config: "RBLNColQwen2ForRetrievalConfig", use_rotary_emb: bool = True
+    ):
+        model.config = (
+            model.config.vlm_config.text_config if hasattr(model.config, "vlm_config") else model.config.text_config
+        )
+        super().__init__(model, rbln_config, use_rotary_emb)
+
+    def get_decoder_layers(self, model: PreTrainedModel):
+        return model.language_model.layers
+
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
+        new_layers = []
+        for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
+            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
+            new_self_attn = self.get_rbln_attn_class()(
+                self.get_attn_layer(layer),
+                self.rbln_config,
+                is_sliding=is_sliding,
+            )
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
+            new_layers.append(new_layer)
+
+        new_model = self.get_rbln_model_class()(
+            model.language_model,
+            new_layers,
+            self.rbln_config,
+            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
+        )
+
+        # text_projection layer from model
+        self.embedding_proj_layer = (
+            model.embedding_proj_layer if hasattr(model, "embedding_proj_layer") else model.custom_text_proj
+        )
+        return new_model
+
+    def get_rbln_model_class(self):
+        return RBLNColQwen2LanguageModel
+
+    def prepare_forward_args(self, *args):
+        args = list(args)
+        input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
+        inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
+        cache_position = args.pop(0)
+        global_block_tables = args.pop(0)
+        local_block_tables = None
+        position_embeds = args.pop(0)
+        position_ids = None
+        attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
+        past_key_values = args
+
+        if len(past_key_values) != 2 * self.num_hidden_layers:
+            raise ValueError(
+                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
+            )
+
+        _past_key_values = []
+        for i in range(self.config.num_hidden_layers):
+            key_states = past_key_values[i * 2]
+            value_states = past_key_values[i * 2 + 1]
+            past_key_value = [key_states, value_states]
+            _past_key_values.append(past_key_value)
+        past_key_values = _past_key_values
+
+        return (
+            input_ids,
+            inputs_embeds,
+            cache_position,
+            global_block_tables,
+            local_block_tables,
+            attention_mask,
+            position_ids,
+            past_key_values,
+            position_embeds,
+        )
+
+    def forward(self, *args):
+        (
+            input_ids,
+            inputs_embeds,
+            cache_position,
+            global_block_tables,
+            local_block_tables,
+            attention_mask,
+            position_ids,
+            past_key_values,
+            rotary_emb,
+        ) = self.prepare_forward_args(*args)
+
+        last_hidden_states = self.model(
+            input_ids=input_ids,
+            inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            rotary_emb=rotary_emb,
+            global_block_tables=global_block_tables,
+            local_block_tables=local_block_tables,
+        )
+
+        proj = self.embedding_proj_layer(last_hidden_states[0])
+        all_hidden_states = last_hidden_states[1] if self.rbln_config.output_hidden_states else None
+
+        if self.rbln_config.output_hidden_states:
+            return proj, all_hidden_states
+        else:
+            return proj
+
+
+class RBLNColQwen2LanguageModel(DecoderOnlyModel):
+    def __init__(
+        self,
+        model,
+        layers: List["DecoderOnlyLayer"],
+        rbln_config: "RBLNColQwen2ForRetrievalConfig",
+        use_learned_pos_emb=None,
+    ):
+        super().__init__(model, layers, rbln_config, use_learned_pos_emb)
+
+        self.output_hidden_states = rbln_config.output_hidden_states
+
+    def forward(
+        self,
+        input_ids: torch.Tensor = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+        position_ids: torch.Tensor = None,
+        query_position: torch.Tensor = None,
+        past_key_values: Tuple[Tuple[torch.Tensor]] = None,
+        rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
+        global_block_tables: Optional[torch.Tensor] = None,
+        local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
+    ):
+        # retrieve input_ids and inputs_embeds
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
+        # embed positions
+        if inputs_embeds is None:
+            inputs_embeds = self.get_embedding()(input_ids)
+
+        hidden_states = inputs_embeds * self.hidden_multiplier
+
+        # get cos,sin vector if needed
+        position_ids = position_ids if position_ids is not None else cache_position
+        if rotary_emb is not None:
+            if isinstance(rotary_emb, torch.Tensor):
+                cos = rotary_emb[0]
+                sin = rotary_emb[1]
+            else:
+                cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
+                cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+
+        # Get sequence positions for flash attention
+        if self.attn_impl == "flash_attn":
+            seq_positions = cache_position[:, 0]
+            seq_positions = self.convert_sequence_positions_for_flash_attn(
+                seq_positions=seq_positions, max_seq_len=self.max_seq_len
+            )
+        else:
+            seq_positions = cache_position[:, :1]
+
+        # Get local cache positions for sliding window layers
+        if len(self.sliding_window_layers) > 0:
+            sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
+
+        all_hidden_states = () if self.output_hidden_states else None
+        for layer_idx, layer in enumerate(self.layers):
+            if self.output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            is_sliding = True if layer_idx in self.sliding_window_layers else False
+            hidden_states = layer(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                seq_positions=sliding_cache_pos if is_sliding else seq_positions,
+                past_key_values=past_key_values,
+                cos=cos,
+                sin=sin,
+                block_tables=local_block_tables if is_sliding else global_block_tables,
+                lora_int_id=lora_int_id,
+            )
+
+        hidden_states = self.get_last_layernorm()(hidden_states)
+        if self.output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        return hidden_states, all_hidden_states
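
A note on `prepare_forward_args` above: the compiled RBLN graph takes flat positional tensor inputs, so the KV cache arrives interleaved as `[k0, v0, k1, v1, ...]` (one key and one value tensor per layer) and is re-paired per layer before being handed to the decoder. A toy illustration of that re-pairing (the tensor shapes here are made up for the example):

```python
import torch

# Flat positional inputs as the compiled graph passes them:
# [k0, v0, k1, v1, ...] for num_hidden_layers layers (toy shapes).
num_hidden_layers = 2
flat_cache = [torch.zeros(1, 4, 8, 16) for _ in range(2 * num_hidden_layers)]

# Re-pair per layer, mirroring prepare_forward_args above.
past_key_values = [
    [flat_cache[i * 2], flat_cache[i * 2 + 1]]  # [key_states, value_states]
    for i in range(num_hidden_layers)
]
assert len(past_key_values) == num_hidden_layers
```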
optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py (new file)

@@ -0,0 +1,74 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from optimum.rbln.configuration_utils import RBLNModelConfig
+
+from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig
+
+
+class RBLNColQwen2ForRetrievalConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN ColQwen2 models for document retrieval.
+
+    This class extends RBLNDecoderOnlyModelConfig with specific configurations for ColQwen2 models,
+    including vision tower settings and multi-sequence length support.
+
+    Example usage:
+    ```python
+    from optimum.rbln import RBLNColQwen2ForRetrieval, RBLNColQwen2ForRetrievalConfig
+
+    # Create a configuration object
+    config = RBLNColQwen2ForRetrievalConfig(
+        visual={
+            "max_seq_lens": 6400,
+            "device": 0,
+        },
+        max_seq_len=32_768,
+        tensor_parallel_size=4,
+        device=[0, 1, 2, 3],
+        output_hidden_states=False,
+    )
+
+    # Use the configuration with from_pretrained
+    model = RBLNColQwen2ForRetrieval.from_pretrained(
+        "vidore/colqwen2-v1.0-hf",
+        export=True,
+        rbln_config=config
+    )
+    ```
+    """
+
+    submodules = ["visual"]
+
+    def __init__(
+        self,
+        visual: Optional[RBLNModelConfig] = None,
+        batch_size: Optional[int] = None,
+        use_inputs_embeds: bool = True,
+        output_hidden_states: Optional[bool] = False,
+        **kwargs,
+    ):
+        super().__init__(use_inputs_embeds=use_inputs_embeds, **kwargs)
+        if not self.use_inputs_embeds:
+            raise ValueError(
+                "RBLNColQwen2ForRetrievalConfig does not allow `use_inputs_embeds` to be set to False, "
+                "as RBLNColQwen2ForRetrieval accepts only `inputs_embeds` as input."
+            )
+        if batch_size is not None and batch_size != 1:
+            raise ValueError("batch_size is not supported for RBLNColQwen2ForRetrievalConfig")
+
+        self.visual = visual
+        self.output_hidden_states = output_hidden_states