optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +96 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +153 -42
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
- optimum/rbln/diffusers/pipelines/__init__.py +11 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +71 -19
- optimum/rbln/modeling_base.py +99 -21
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +9 -7
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/modeling_generic.py +51 -9
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/depreacate_utils.py +16 -0
- optimum/rbln/utils/runtime_utils.py +28 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
- optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
@@ -14,13 +14,14 @@
 import importlib
 import inspect
 import warnings
-from
+from pathlib import Path
+from typing import Any, Dict, Optional, Type, Union
 
-from transformers import AutoConfig, PretrainedConfig
+from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
 from transformers.models.auto.auto_factory import _get_model_class
 
-from optimum.rbln.configuration_utils import RBLNAutoConfig
+from optimum.rbln.configuration_utils import RBLNAutoConfig, RBLNModelConfig
 from optimum.rbln.modeling_base import RBLNBaseModel
 from optimum.rbln.utils.model_utils import (
     MODEL_MAPPING,
@@ -43,10 +44,10 @@ class _BaseAutoModelClass:
     @classmethod
     def get_rbln_cls(
         cls,
-        pretrained_model_name_or_path,
-        *args,
-        export=
-        **kwargs,
+        pretrained_model_name_or_path: Union[str, Path],
+        *args: Any,
+        export: bool = None,
+        **kwargs: Any,
     ):
         """
         Determine the appropriate RBLN model class based on the given model ID and configuration.
@@ -59,6 +60,20 @@ class _BaseAutoModelClass:
         Returns:
             RBLNBaseModel: The corresponding RBLN model class.
         """
+        if isinstance(pretrained_model_name_or_path, Path):
+            pretrained_model_name_or_path = pretrained_model_name_or_path.as_posix()
+
+        if export is None:
+            export = not RBLNBaseModel._is_compiled(
+                model_id=pretrained_model_name_or_path,
+                token=kwargs.get("token"),
+                revision=kwargs.get("revision"),
+                force_download=kwargs.get("force_download", False),
+                cache_dir=kwargs.get("cache_dir"),
+                subfolder=kwargs.get("subfolder", ""),
+                local_files_only=kwargs.get("local_files_only", False),
+            )
+
         if export:
             hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
             rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
@@ -85,9 +100,9 @@ class _BaseAutoModelClass:
     @classmethod
     def infer_hf_model_class(
         cls,
-        pretrained_model_name_or_path,
-        *args,
-        **kwargs,
+        pretrained_model_name_or_path: Union[str, Path],
+        *args: Any,
+        **kwargs: Any,
     ):
         """
         Infer the HuggingFace model class based on the configuration or model name.
@@ -140,7 +155,7 @@ class _BaseAutoModelClass:
         return model_class
 
     @classmethod
-    def get_rbln_model_cls_name(cls, pretrained_model_name_or_path, **kwargs):
+    def get_rbln_model_cls_name(cls, pretrained_model_name_or_path: Union[str, Path], **kwargs):
         """
         Retrieve the path to the compiled model directory for a given RBLN model.
 
@@ -163,17 +178,77 @@ class _BaseAutoModelClass:
         return rbln_config.rbln_model_cls_name
 
     @classmethod
-    def from_pretrained(
-
-
+    def from_pretrained(
+        cls,
+        model_id: Union[str, Path],
+        export: bool = None,
+        rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
+        **kwargs,
+    ):
+        """
+        Load an RBLN-accelerated model from a pretrained checkpoint or a compiled RBLN artifact.
+
+        This convenience method determines the concrete `RBLN*` model class that matches the
+        underlying HuggingFace architecture and dispatches to that class's
+        `from_pretrained()` implementation. Depending on whether a compiled RBLN folder is
+        detected (or if `export=True` is passed), it will either:
+
+        - Compile from a HuggingFace checkpoint to an RBLN model
+        - Or load an already-compiled RBLN model directory/repository
+
+        Args:
+            model_id:
+                HF repo id or local path. For compiled models, this should point to a directory
+                (optionally under `subfolder`) that contains `*.rbln` files and `rbln_config.json`.
+            export:
+                Force compilation from a HuggingFace checkpoint. When `None`, this is inferred by
+                checking whether compiled artifacts exist at `model_id`.
+            rbln_config:
+                RBLN compilation/runtime configuration. May be provided as a dictionary or as an
+                instance of the specific model's config class (e.g., `RBLNLlamaForCausalLMConfig`).
+            kwargs: Additional keyword arguments.
+                - Arguments prefixed with `rbln_` are forwarded to the RBLN config.
+                - Remaining arguments are forwarded to the HuggingFace loader (e.g., `revision`,
+                  `token`, `trust_remote_code`, `cache_dir`, `subfolder`, `local_files_only`).
+
+        Returns:
+            An instantiated RBLN model ready for inference on RBLN NPUs.
+        """
+        rbln_cls = cls.get_rbln_cls(model_id, export=export, **kwargs)
+        return rbln_cls.from_pretrained(model_id, export=export, rbln_config=rbln_config, **kwargs)
 
     @classmethod
-    def from_model(
+    def from_model(
+        cls,
+        model: PreTrainedModel,
+        config: Optional[PretrainedConfig] = None,
+        rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
+        **kwargs: Any,
+    ) -> RBLNBaseModel:
+        """
+        Convert and compile an in-memory HuggingFace model into an RBLN model.
+
+        This method resolves the appropriate concrete `RBLN*` class from the input model's class
+        name (e.g., `LlamaForCausalLM` -> `RBLNLlamaForCausalLM`) and then delegates to that
+        class's `from_model()` implementation.
+
+        Args:
+            model: A HuggingFace model instance to convert.
+            config: The configuration object associated with the model.
+            rbln_config:
+                RBLN compilation/runtime configuration. May be provided as a dictionary or as an
+                instance of the specific model's config class.
+            kwargs: Additional keyword arguments.
+                - Arguments prefixed with `rbln_` are forwarded to the RBLN config.
+
+        Returns:
+            An instantiated RBLN model ready for inference on RBLN NPUs.
+        """
         rbln_cls = get_rbln_model_cls(f"RBLN{model.__class__.__name__}")
-        return rbln_cls.from_model(model,
+        return rbln_cls.from_model(model, config=config, rbln_config=rbln_config, **kwargs)
 
     @staticmethod
-    def register(rbln_cls: Type[RBLNBaseModel], exist_ok=False):
+    def register(rbln_cls: Type[RBLNBaseModel], exist_ok: bool = False):
         """
         Register a new RBLN model class.
 
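For reference, a minimal usage sketch of the new auto-class from_pretrained() entry point shown above. The checkpoint id, the top-level optimum.rbln import, and the rbln_config keys are illustrative assumptions, not confirmed by this diff:

    from optimum.rbln import RBLNAutoModelForCausalLM  # assumed top-level re-export

    # export=None (the default) now auto-detects whether model_id already holds
    # compiled *.rbln artifacts plus rbln_config.json; export=True forces compilation.
    model = RBLNAutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",        # placeholder HF repo id or local path
        export=True,
        rbln_config={"batch_size": 1},     # dict or a model-specific RBLN*Config instance
    )
    model.save_pretrained("./llama-rbln")  # the saved directory can then be reloaded with export=None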
@@ -35,8 +35,12 @@ from transformers.models.auto.modeling_auto import (
     MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
     MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
+    MODEL_FOR_TEXT_ENCODING_MAPPING,
+    MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES,
     MODEL_FOR_VISION_2_SEQ_MAPPING,
     MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
+    MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
+    MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES,
     MODEL_MAPPING,
     MODEL_MAPPING_NAMES,
 )
@@ -53,65 +57,106 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.update(
 
 
 class RBLNAutoModel(_BaseAutoModelClass):
+    """Automatically detect all supported transformers models."""
+
     _model_mapping = MODEL_MAPPING
     _model_mapping_names = MODEL_MAPPING_NAMES
 
 
 class RBLNAutoModelForCTC(_BaseAutoModelClass):
+    """Automatically detect Connectionist Temporal Classification (CTC) head Models."""
+
     _model_mapping = MODEL_FOR_CTC_MAPPING
     _model_mapping_names = MODEL_FOR_CTC_MAPPING_NAMES
 
 
 class RBLNAutoModelForCausalLM(_BaseAutoModelClass):
+    """Automatically detect Causal Language Models."""
+
     _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
     _model_mapping_names = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 
 class RBLNAutoModelForSeq2SeqLM(_BaseAutoModelClass):
+    """Automatically detect Sequence to Sequence Language Models."""
+
     _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
     _model_mapping_names = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
 
 
 class RBLNAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
+    """Automatically detect Speech Sequence to Sequence Generation Models."""
+
     _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
     _model_mapping_names = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
 
 
 class RBLNAutoModelForDepthEstimation(_BaseAutoModelClass):
+    """Automatically detect Depth Estimation Models."""
+
     _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
     _model_mapping_names = MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForSequenceClassification(_BaseAutoModelClass):
+    """Automatically detect Sequence Classification Models."""
+
     _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
     _model_mapping_names = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForVision2Seq(_BaseAutoModelClass):
+    """Automatically detect Vision to Sequence Generation Models."""
+
     _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
     _model_mapping_names = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
 
 
 class RBLNAutoModelForImageTextToText(_BaseAutoModelClass):
+    """Automatically detect Image and Text to Text Generation Models."""
+
     _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING
     _model_mapping_names = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
 
 
 class RBLNAutoModelForMaskedLM(_BaseAutoModelClass):
+    """Automatically detect Masked Language Models."""
+
     _model_mapping = MODEL_FOR_MASKED_LM_MAPPING
     _model_mapping_names = MODEL_FOR_MASKED_LM_MAPPING_NAMES
 
 
 class RBLNAutoModelForAudioClassification(_BaseAutoModelClass):
+    """Automatically detect Audio Classification Models."""
+
     _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
     _model_mapping_names = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForImageClassification(_BaseAutoModelClass):
+    """Automatically detect Image Classification Models."""
+
     _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
     _model_mapping_names = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
 
 
 class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
+    """Automatically detect Question Answering Models."""
+
     _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
     _model_mapping_names = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
+
+
+class RBLNAutoModelForTextEncoding(_BaseAutoModelClass):
+    """Automatically detect Text Encoding Models."""
+
+    _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
+    _model_mapping_names = MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES
+
+
+class RBLNAutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
+    """Automatically detect Zero Shot Object Detection Models."""
+
+    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
+    _model_mapping_names = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
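The two newly mapped task classes follow the same contract as the existing auto classes. A hedged sketch, assuming they are re-exported from the package root like the others and using a placeholder checkpoint:

    from optimum.rbln import RBLNAutoModelForTextEncoding  # assumed top-level re-export

    # BertModel-style checkpoints fall under MODEL_FOR_TEXT_ENCODING_MAPPING, so this
    # resolves to the corresponding RBLN text-encoder class and compiles it for the NPU.
    text_encoder = RBLNAutoModelForTextEncoding.from_pretrained(
        "google-bert/bert-base-uncased",  # placeholder checkpoint id
        export=True,
    )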
@@ -16,9 +16,7 @@ from typing import Tuple
 
 import torch
 from torch import nn
-from transformers.modeling_attn_mask_utils import (
-    _prepare_4d_attention_mask,
-)
+from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 from transformers.utils import logging
 
 from ..seq2seq.seq2seq_architecture import (
@@ -0,0 +1,16 @@
+import torch
+
+
+class BertModelWrapper(torch.nn.Module):
+    def __init__(self, model, rbln_config):
+        super().__init__()
+        self.model = model
+        self.rbln_config = rbln_config
+
+    def forward(self, *args, **kwargs):
+        output = self.model(*args, **kwargs)
+        if isinstance(output, torch.Tensor):
+            return output
+        elif isinstance(output, tuple):
+            return tuple(x for x in output if x is not None)
+        return output
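A small, self-contained illustration of what the new BertModelWrapper does at wrap time: tensor outputs pass through unchanged, and None entries are dropped from tuple outputs. The toy module below is invented purely for the demonstration:

    import torch

    from optimum.rbln.transformers.models.bert.bert_architecture import BertModelWrapper

    class _ToyEncoder(torch.nn.Module):
        # Stand-in for a HF model whose output tuple contains optional None fields.
        def forward(self, x):
            return (x * 2, None)

    wrapped = BertModelWrapper(_ToyEncoder(), rbln_config=None)
    print(wrapped(torch.ones(2)))  # (tensor([2., 2.]),) -- the None entry is filtered out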
@@ -12,15 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import torch
+
 from ...modeling_generic import (
     RBLNModelForMaskedLM,
     RBLNModelForQuestionAnswering,
     RBLNTransformerEncoderForFeatureExtraction,
 )
-
-
-logger = get_logger(__name__)
+from .bert_architecture import BertModelWrapper
+from .configuration_bert import RBLNBertModelConfig
 
 
 class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
@@ -34,6 +34,10 @@ class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
 
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
+    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNBertModelConfig) -> torch.nn.Module:
+        return BertModelWrapper(model, rbln_config)
+
 
 class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
     """
@@ -12,9 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any,
+from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 class RBLNBlip2VisionModelConfig(RBLNModelConfig):
@@ -25,6 +29,16 @@ class RBLNBlip2VisionModelConfig(RBLNModelConfig):
     RBLN-optimized BLIP-2 vision encoder models for multimodal tasks.
     """
 
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
 
 class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
     """
@@ -36,24 +50,34 @@ class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
 
     def __init__(
         self,
+        batch_size: Optional[int] = None,
         num_query_tokens: Optional[int] = None,
         image_text_hidden_size: Optional[int] = None,
         **kwargs,
     ):
         """
         Args:
-
-
-
-        Raises:
-            ValueError: If batch_size is not a positive integer.
+            num_query_tokens (Optional[int]): The number of query tokens passed through the Transformer.
+            image_text_hidden_size (Optional[int]): Dimensionality of the hidden state of the image-text fusion layer.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
         """
         super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
         self.num_query_tokens = num_query_tokens
         self.image_text_hidden_size = image_text_hidden_size
 
 
 class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNBlip2ForConditionalGeneration.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BLIP-2 models for conditional generation tasks that involve both image and text inputs.
+    """
+
     submodules = ["vision_model", "qformer", "language_model"]
 
     def __init__(
@@ -62,14 +86,15 @@ class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
         vision_model: Optional[RBLNModelConfig] = None,
         qformer: Optional[RBLNModelConfig] = None,
         language_model: Optional[RBLNModelConfig] = None,
-        **kwargs:
+        **kwargs: Any,
     ):
         """
         Args:
             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
             vision_model (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+            qformer (Optional[RBLNModelConfig]): Configuration for the RBLN-optimized BLIP-2 Q-Former model.
             language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
-
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
             ValueError: If batch_size is not a positive integer.
@@ -79,6 +104,12 @@ class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
-        self.
-
-
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for Blip2 vision model. It will be set to 1.")
+            logger.warning("Ignore batch_size for Blip2 qformer. It will be set to 1.")
+
+        self.vision_model = self.initialize_submodule_config(
+            submodule_config=vision_model, batch_size=1, force_kwargs=True
+        )
+        self.qformer = self.initialize_submodule_config(submodule_config=qformer, batch_size=1, force_kwargs=True)
+        self.language_model = self.initialize_submodule_config(submodule_config=language_model)
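A hedged construction sketch for the updated BLIP-2 config: the top-level batch_size is validated, while the vision tower and the Q-Former are pinned to batch size 1 through initialize_submodule_config(..., force_kwargs=True), with a warning when a different value is requested. Standalone construction with no submodule configs is assumed to be valid here:

    from optimum.rbln.transformers.models.blip_2.configuration_blip_2 import (
        RBLNBlip2ForConditionalGenerationConfig,
    )

    # Requesting batch_size=2 logs the "Ignore batch_size ..." warnings and still
    # builds the vision_model and qformer submodule configs with batch_size=1.
    config = RBLNBlip2ForConditionalGenerationConfig(batch_size=2)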
@@ -30,38 +30,31 @@ from transformers.utils import logging
 
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
+from ...utils.rbln_runtime_wrapper import LoopProcessor
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
 
 
 logger = logging.get_logger(__name__)
 
 if TYPE_CHECKING:
-
-
-        AutoProcessor,
-        AutoTokenizer,
-    )
+    import rebel
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
-class LoopProjector:
-    def __init__(self, language_projection
-
+class LoopProjector(LoopProcessor):
+    def __init__(self, language_projection: Union[RBLNModel, "rebel.Runtime"]):
+        super().__init__(model=language_projection)
 
-    def
-        query_output
+    def _get_batch_size(self, query_output, **kwargs):
+        return query_output.shape[0]
 
-
-
-
-            outputs.append(self.language_projection(query_output[i : i + 1]))
-
-        outputs = torch.cat(outputs, dim=0)
-        return outputs
-
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
+    def _prepare_inputs_for_iteration(self, index, common_inputs, query_output, **kwargs):
+        query_output_item = query_output[index : index + 1]
+        return ([query_output_item], {})
 
-    def
-
+    def _process_outputs(self, outputs: list, **kwargs):
+        output = torch.cat(outputs, dim=0)
+        return output
 
 
 class RBLNBlip2VisionModel(RBLNModel):
@@ -72,6 +65,8 @@ class RBLNBlip2VisionModel(RBLNModel):
     on RBLN devices, supporting image encoding for multimodal vision-language tasks.
     """
 
+    _tp_support = False
+
     def get_input_embeddings(self):
         return self.embeddings
 
@@ -100,8 +95,7 @@ class RBLNBlip2VisionModel(RBLNModel):
             (
                 "pixel_values",
                 [
-
-                    1,
+                    rbln_config.batch_size,
                     model_config.num_channels,
                     model_config.image_size,
                     model_config.image_size,
@@ -116,7 +110,7 @@ class RBLNBlip2VisionModel(RBLNModel):
 
     def forward(
         self,
-        pixel_values,
+        pixel_values: torch.FloatTensor,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
@@ -151,6 +145,8 @@ class RBLNBlip2QFormerModel(RBLNModel):
     mechanisms for multimodal understanding tasks.
     """
 
+    _tp_support = False
+
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
 
@@ -178,7 +174,12 @@ class RBLNBlip2QFormerModel(RBLNModel):
         return Blip2QFormerModelWrapper(model).eval()
 
     @classmethod
-    def _update_submodule_config(
+    def _update_submodule_config(
+        cls,
+        model: "PreTrainedModel",
+        rbln_config: RBLNModelConfig,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+    ):
         if rbln_config.num_query_tokens is None:
             rbln_config.num_query_tokens = model.config.num_query_tokens
 
@@ -199,7 +200,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
             (
                 "query_embeds",
                 [
-
+                    rbln_config.batch_size,
                     rbln_config.num_query_tokens,
                     model_config.hidden_size,
                 ],
@@ -208,7 +209,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
             (
                 "encoder_hidden_states",
                 [
-
+                    rbln_config.batch_size,
                     # image_text_hidden_size + cls token
                     rbln_config.image_text_hidden_size + 1,
                     model_config.encoder_hidden_size,
@@ -218,7 +219,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
             (
                 "encoder_attention_mask",
                 # image_text_hidden_size + cls token
-                [
+                [rbln_config.batch_size, rbln_config.image_text_hidden_size + 1],
                 "int64",
             ),
         ]
@@ -265,7 +266,7 @@ class RBLNBlip2QFormerModel(RBLNModel):
         )
 
 
-class RBLNBlip2ForConditionalGeneration(RBLNModel):
+class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     """
     RBLNBlip2ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
@@ -433,3 +434,66 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel):
         )
 
         return inputs_embeds
+
+    @torch.no_grad()
+    def generate(
+        self,
+        pixel_values: torch.FloatTensor,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        interpolate_pos_encoding: bool = False,
+        **generate_kwargs,
+    ) -> torch.LongTensor:
+        batch_size = pixel_values.shape[0]
+        image_embeds = self.vision_model(
+            pixel_values,
+            return_dict=True,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+        ).last_hidden_state
+        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+        query_outputs = self.qformer(
+            query_embeds=query_tokens,
+            encoder_hidden_states=image_embeds,
+            encoder_attention_mask=image_attention_mask,
+            return_dict=True,
+        )
+        query_output = query_outputs.last_hidden_state
+
+        if query_output.dtype != image_embeds.dtype:
+            query_output = query_output.to(image_embeds.dtype)
+
+        language_model_inputs = self.language_projection(query_output)
+
+        if inputs_embeds is None:
+            if input_ids is None:
+                image_tokens = [self.config.image_token_index] * self.config.num_query_tokens
+                start_tokens = image_tokens + [self.config.text_config.bos_token_id]
+                input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
+                input_ids = input_ids.repeat(batch_size, 1)
+            inputs_embeds = self.get_input_embeddings()(input_ids)
+
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+
+        if input_ids is None:
+            special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+            )
+            special_image_mask = special_image_mask.all(-1)
+        else:
+            special_image_mask = input_ids == self.config.image_token_id
+
+        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+
+        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
+        if not self.language_model.config.is_encoder_decoder:
+            inputs["input_ids"] = input_ids
+
+        outputs = self.language_model.generate(**inputs, **generate_kwargs)
+
+        return outputs
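To close out, a minimal end-to-end sketch of the new generate() path on RBLNBlip2ForConditionalGeneration, assuming the class is re-exported from optimum.rbln; the checkpoint id and image path are placeholders:

    from PIL import Image
    from transformers import Blip2Processor

    from optimum.rbln import RBLNBlip2ForConditionalGeneration  # assumed top-level re-export

    processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
    model = RBLNBlip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b",  # placeholder checkpoint id
        export=True,                  # compile the vision tower, Q-Former, and language model
    )

    image = Image.open("example.jpg")  # placeholder image path
    inputs = processor(images=image, return_tensors="pt")

    # pixel_values is the only required argument; extra kwargs go to language_model.generate().
    generated_ids = model.generate(pixel_values=inputs["pixel_values"], max_new_tokens=20)
    print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())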