optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,526 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import inspect
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
|
|
18
|
+
|
|
19
|
+
import torch
|
|
20
|
+
from transformers import (
|
|
21
|
+
AutoModelForVisualQuestionAnswering,
|
|
22
|
+
Blip2ForConditionalGeneration,
|
|
23
|
+
Blip2QFormerModel,
|
|
24
|
+
Blip2VisionModel,
|
|
25
|
+
PretrainedConfig,
|
|
26
|
+
PreTrainedModel,
|
|
27
|
+
)
|
|
28
|
+
from transformers.modeling_outputs import BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions
|
|
29
|
+
from transformers.utils import logging
|
|
30
|
+
|
|
31
|
+
from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
|
|
32
|
+
from ....modeling import RBLNModel
|
|
33
|
+
from ...utils.rbln_runtime_wrapper import LoopProcessor
|
|
34
|
+
from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
|
|
38
|
+
|
|
39
|
+
if TYPE_CHECKING:
|
|
40
|
+
import rebel
|
|
41
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class LoopProjector(LoopProcessor):
    """Adapt a language-projection module to the ``LoopProcessor`` hook interface.

    The hooks below read the batch size from the leading dimension of
    ``query_output``, hand each iteration a length-1 slice of it as the only
    positional argument, and concatenate the collected outputs along dim 0.
    The iteration itself is driven by ``LoopProcessor``.
    """

    def __init__(self, language_projection: Union[RBLNModel, "rebel.Runtime"]):
        super().__init__(model=language_projection)

    def _get_batch_size(self, query_output, **kwargs):
        # Leading dimension of query_output is the number of samples.
        return query_output.shape[0]

    def _prepare_inputs_for_iteration(self, index, common_inputs, query_output, **kwargs):
        # Slice (index:index+1) rather than index so the batch dimension is kept.
        sample = query_output[index : index + 1]
        positional_args = [sample]
        keyword_args = {}
        return (positional_args, keyword_args)

    def _process_outputs(self, outputs: list, **kwargs):
        # Stitch the per-iteration results back into one batch.
        return torch.cat(outputs, dim=0)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class RBLNBlip2VisionModel(RBLNModel):
    """
    RBLN optimized BLIP-2 vision encoder model.

    This class provides hardware-accelerated inference for BLIP-2 vision encoders
    on RBLN devices, supporting image encoding for multimodal vision-language tasks.

    The encoder is compiled for a fixed ``pixel_values`` shape of
    ``(rbln_config.batch_size, num_channels, image_size, image_size)``.
    """

    _tp_support = False

    def get_input_embeddings(self):
        return self.embeddings

    @classmethod
    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
        class Blip2VisionModelWrapper(torch.nn.Module):
            def __init__(self, model: "Blip2VisionModel") -> None:
                super().__init__()
                self.model = model

            def forward(self, *args, **kwargs):
                # Force the tuple form so the traced graph has positional outputs,
                # overriding any caller-supplied return_dict.
                kwargs.pop("return_dict", None)
                return self.model(*args, **kwargs, return_dict=False)

        return Blip2VisionModelWrapper(model).eval()

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
        model: Optional["PreTrainedModel"] = None,
        model_config: Optional["PretrainedConfig"] = None,
        rbln_config: Optional[RBLNModelConfig] = None,
    ) -> RBLNModelConfig:
        """Register the single compile config: a float32 square image input."""
        input_info = [
            (
                "pixel_values",
                [
                    rbln_config.batch_size,
                    model_config.num_channels,
                    model_config.image_size,
                    model_config.image_size,
                ],
                "float32",
            ),
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config.set_compile_cfgs([rbln_compile_config])
        return rbln_config

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        interpolate_pos_encoding: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """
        Forward pass for the RBLN-optimized Blip2VisionModel model.

        Each image is sent to the compiled runtime as its own batch of 1, and the
        per-image outputs are concatenated along dim 0.

        Args:
            pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)): The tensors corresponding to the input images.
                Height and width must match the compiled ``image_size``.
            interpolate_pos_encoding (bool, optional): Accepted for API compatibility with the Hugging Face model but
                not forwarded to the runtime. The compiled graph has a fixed input resolution. Defaults to False.
            return_dict (bool, optional): If truthy, return a ModelOutput. If falsy, including the default ``None``,
                return a plain tuple.

        Returns:
            tuple(torch.FloatTensor) or BaseModelOutputWithPooling: ``(last_hidden_state, pooler_output)`` when
            ``return_dict`` is falsy (the default). Otherwise, a ``BaseModelOutputWithPooling`` object.
        """
        # NOTE(review): slices of size 1 are fed to a runtime compiled with
        # rbln_config.batch_size, so this presumably assumes batch_size == 1 -- confirm.
        per_image = [self.model[0](pixel_values[i : i + 1]) for i in range(pixel_values.shape[0])]

        last_hidden_state = torch.cat([output[0] for output in per_image], dim=0)
        pooler_output = torch.cat([output[1] for output in per_image], dim=0)

        if not return_dict:
            return (last_hidden_state, pooler_output)

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooler_output,
        )
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
class RBLNBlip2QFormerModel(RBLNModel):
    """
    RBLN optimized BLIP-2 Q-Former model.

    This class provides hardware-accelerated inference for BLIP-2 Q-Former models
    on RBLN devices, which bridge vision and language modalities through cross-attention
    mechanisms for multimodal understanding tasks.

    The compiled graph takes three fixed-shape inputs: ``query_embeds``,
    ``encoder_hidden_states`` and ``encoder_attention_mask``.
    """

    _tp_support = False

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    @classmethod
    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
        class Blip2QFormerModelWrapper(torch.nn.Module):
            def __init__(self, model: "Blip2QFormerModel"):
                super().__init__()
                self.model = model

            def forward(
                self,
                query_embeds: torch.FloatTensor,
                encoder_hidden_states: Optional[torch.FloatTensor] = None,
                encoder_attention_mask: Optional[torch.FloatTensor] = None,
            ) -> torch.Tensor:
                # Trace the tuple form (return_dict=False) so outputs are positional.
                qformer_out = self.model(
                    query_embeds=query_embeds,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    return_dict=False,
                )
                return qformer_out

        return Blip2QFormerModelWrapper(model).eval()

    @classmethod
    def _update_submodule_config(
        cls,
        model: "PreTrainedModel",
        rbln_config: RBLNModelConfig,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
    ):
        """Fill unset Q-Former compile settings from ``model.config``.

        Only ``num_query_tokens`` and ``image_text_hidden_size`` are filled, and only
        when they are still ``None`` on ``rbln_config``. ``model`` is presumably the
        parent BLIP-2 model whose config carries these fields (verify against caller).
        """
        if rbln_config.num_query_tokens is None:
            rbln_config.num_query_tokens = model.config.num_query_tokens

        if rbln_config.image_text_hidden_size is None:
            rbln_config.image_text_hidden_size = model.config.image_text_hidden_size

        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
        model: Optional["PreTrainedModel"] = None,
        model_config: Optional["PretrainedConfig"] = None,
        rbln_config: Optional[RBLNModelConfig] = None,
    ) -> RBLNModelConfig:
        """Register the single compile config for the three Q-Former inputs."""
        # NOTE(review): the sequence axis of encoder_hidden_states / encoder_attention_mask is the
        # vision encoder's token count, yet it is derived from image_text_hidden_size (+1 for the
        # cls token). Confirm these two quantities agree for the target checkpoint.
        input_info = [
            (
                "query_embeds",
                [
                    rbln_config.batch_size,
                    rbln_config.num_query_tokens,
                    model_config.hidden_size,
                ],
                "float32",
            ),
            (
                "encoder_hidden_states",
                [
                    rbln_config.batch_size,
                    # image_text_hidden_size + cls token
                    rbln_config.image_text_hidden_size + 1,
                    model_config.encoder_hidden_size,
                ],
                "float32",
            ),
            (
                "encoder_attention_mask",
                # image_text_hidden_size + cls token
                [rbln_config.batch_size, rbln_config.image_text_hidden_size + 1],
                "int64",
            ),
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config.set_compile_cfgs([rbln_compile_config])
        return rbln_config

    def forward(
        self,
        query_embeds: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """
        The forward pass for the RBLN-optimized Blip2QFormerModel model.

        Each sample is sent to the compiled runtime as its own batch of 1, and the
        per-sample outputs are concatenated along dim 0.

        Args:
            query_embeds (torch.FloatTensor): Hidden states to be used in the attention computation.
            encoder_hidden_states (torch.FloatTensor): Sequence of hidden-states at the output of the last layer of
                the encoder, used in the cross-attention. Required: it is a fixed input of the compiled graph.
            encoder_attention_mask (torch.Tensor): Mask to avoid performing attention on the padding token indices of
                the encoder input. Required: it is a fixed input of the compiled graph, declared as int64.
            return_dict (bool, optional): If truthy, return a ModelOutput. If falsy, including the default ``None``,
                return a plain tuple.

        Returns:
            tuple(torch.FloatTensor) or BaseModelOutputWithPoolingAndCrossAttentions: ``(sequence_output,
            pooled_output)`` when ``return_dict`` is falsy (the default). Otherwise, a
            ``BaseModelOutputWithPoolingAndCrossAttentions`` object.

        Raises:
            ValueError: If ``encoder_hidden_states`` or ``encoder_attention_mask`` is ``None``.
        """
        # Both encoder inputs are part of the compiled input signature; fail clearly instead
        # of with a TypeError from slicing None below.
        if encoder_hidden_states is None or encoder_attention_mask is None:
            raise ValueError(
                "RBLNBlip2QFormerModel requires both `encoder_hidden_states` and `encoder_attention_mask`; "
                "they are fixed inputs of the compiled graph."
            )

        # NOTE(review): slices of size 1 are fed to a runtime compiled with
        # rbln_config.batch_size, so this presumably assumes batch_size == 1 -- confirm.
        outputs = [
            self.model[0](
                query_embeds[i : i + 1], encoder_hidden_states[i : i + 1], encoder_attention_mask[i : i + 1]
            )
            for i in range(query_embeds.shape[0])
        ]

        sequence_output = torch.cat([output[0] for output in outputs], dim=0)
        pooled_output = torch.cat([output[1] for output in outputs], dim=0)

        if not return_dict:
            return (sequence_output, pooled_output)

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
        )
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
    """
    RBLNBlip2ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.

    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.

    Important Note:
        This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNBlip2ForConditionalGeneration class for details.

    Examples:
        ```python
        from optimum.rbln import RBLNBlip2ForConditionalGeneration

        model = RBLNBlip2ForConditionalGeneration.from_pretrained(
            "Salesforce/blip2-opt-2.7b",
            export=True,
            rbln_config={
                "language_model": {
                    "batch_size": 1,
                    "max_seq_len": 2048,
                    "tensor_parallel_size": 1,
                    "use_inputs_embeds": True,
                },
            },
        )

        model.save_pretrained("compiled-blip2-opt-2.7b")
        ```
    """

    auto_model_class = AutoModelForVisualQuestionAnswering
    # Compiled sub-models; __post_init__ reads them back from self.rbln_submodules by index in this order.
    _rbln_submodules = [{"name": "vision_model"}, {"name": "qformer"}, {"name": "language_model"}]

    def __getattr__(self, __name: str) -> Any:
        """Fall back to attributes of the HuggingFace `Blip2ForConditionalGeneration` class.

        Callables whose signature contains a `self` parameter are returned bound to this RBLN
        instance, so upstream helper methods can run against the RBLN sub-modules.

        Raises:
            AttributeError: If the HuggingFace class does not define the attribute either.
        """

        def redirect(func):
            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

        val = getattr(Blip2ForConditionalGeneration, __name)

        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
            return redirect(val)
        return val

    def can_generate(self):
        """Report that this model supports `generate`."""
        return True

    @classmethod
    def save_torch_artifacts(
        cls,
        model: "Blip2ForConditionalGeneration",
        save_dir_path: Path,
        subfolder: str,
        rbln_config: RBLNModelConfig,
    ):
        """Persist the learned `query_tokens` tensor next to the compiled artifacts.

        `query_tokens` is not part of any compiled sub-module, so it is saved separately to
        `<save_dir_path>/<subfolder>/query_tokens.pth` and reloaded in `__post_init__`.
        """
        # If you are unavoidably running on a CPU rather than an RBLN device,
        # store the torch tensor, weight, etc. in this function.

        save_dict = {}
        save_dict["query_tokens"] = model.query_tokens
        torch.save(save_dict, save_dir_path / subfolder / "query_tokens.pth")

    def __post_init__(self, **kwargs):
        """Wire up compiled sub-modules and reload the CPU-side `query_tokens` artifact."""
        # Indices follow the order of `_rbln_submodules`: vision_model, qformer, language_model.
        self.vision_model = self.rbln_submodules[0]
        self.language_model = self.rbln_submodules[2]
        self.qformer = self.rbln_submodules[1]
        # The compiled language projection is looped over the batch by LoopProjector
        # (presumably one sample at a time, since it is compiled with batch 1 -- see _update_rbln_config).
        self.language_projection = LoopProjector(self.model[0])

        # SECURITY NOTE(review): weights_only=False performs full pickle deserialization. Only load
        # model directories from trusted sources; consider weights_only=True once verified that the
        # saved nn.Parameter round-trips under the restricted unpickler.
        artifacts = torch.load(self.model_save_dir / self.subfolder / "query_tokens.pth", weights_only=False)
        self.query_tokens = artifacts["query_tokens"]

    def get_attn_impl(self) -> str:
        """Return the attention implementation configured for the language model."""
        return self.rbln_config.language_model.attn_impl

    def get_kvcache_num_blocks(self) -> int:
        """Return the number of KV-cache blocks configured for the language model."""
        return self.rbln_config.language_model.kvcache_num_blocks

    def get_input_embeddings(self):
        """Return the token embedding module of the language model."""
        return self.language_model.get_input_embeddings()

    @classmethod
    def _wrap_model_if_needed(cls, model, rbln_config):
        """Select the language projection layer as the graph compiled for this top-level model."""
        return model.language_projection

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
        model: Optional["PreTrainedModel"] = None,
        model_config: Optional["PretrainedConfig"] = None,
        rbln_config: Optional[RBLNModelConfig] = None,
    ) -> RBLNModelConfig:
        """Register the compile input spec for the language projection.

        The projection is compiled for a single float32 input named `query_output` of shape
        `[1, num_query_tokens, qformer_hidden_size]`.
        """
        input_info = [
            (
                "query_output",
                [
                    1,
                    model_config.num_query_tokens,
                    model_config.qformer_config.hidden_size,
                ],
                "float32",
            ),
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config.set_compile_cfgs([rbln_compile_config])

        return rbln_config

    def _preprocess_prefill(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        interpolate_pos_encoding: bool = False,
        **kwargs,
    ):
        """Build language-model input embeddings with the projected image features merged in.

        Args:
            pixel_values: Input images forwarded to the vision model.
            input_ids: Token ids used for the embedding lookup (must be an integer tensor).
            attention_mask: Optional text attention mask. NOTE(review): it is computed/extended
                below but not returned, so it does not affect the result.
            return_dict: Forwarded to the vision model and Q-Former; defaults to `config.use_return_dict`.
            interpolate_pos_encoding: Forwarded to the vision model.

        Returns:
            The merged `inputs_embeds` tensor.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            return_dict=return_dict,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        image_embeds = vision_outputs[0]

        # Every image position is attended to by the Q-Former.
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)

        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=return_dict,
        )
        query_output = query_outputs[0]

        if query_output.dtype != image_embeds.dtype:
            query_output = query_output.to(image_embeds.dtype)

        language_model_inputs = self.language_projection(query_output)
        language_model_attention_mask = torch.ones(
            language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
        )
        inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        if getattr(self.config, "image_token_index", None) is not None:
            # Processor already inserted image placeholder tokens: scatter image features into them.
            special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
            language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
        else:
            logger.warning_once(
                "Expanding inputs for image tokens in BLIP-2 should be done in processing. "
                "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
                "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
            )
            # Legacy path: prepend image features to the text embeddings.
            inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
            attention_mask = torch.cat(
                [language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
            )

        return inputs_embeds

    @torch.no_grad()
    def generate(
        self,
        pixel_values: torch.FloatTensor,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        interpolate_pos_encoding: bool = False,
        **generate_kwargs,
    ) -> List[torch.LongTensor]:
        """
        The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
        Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/model_doc/blip-2#transformers.Blip2ForConditionalGeneration.generate) for more details.

        Args:
            pixel_values (torch.FloatTensor): Input images to be processed.
            input_ids (torch.LongTensor, optional): The sequence used as a prompt for the generation. When both
                `input_ids` and `inputs_embeds` are omitted, a prompt of image placeholder tokens followed by BOS is built.
            attention_mask (torch.LongTensor, optional): Mask to avoid performing attention on padding token indices.
                When omitted, every position of `input_ids` (or of `inputs_embeds` if no ids are given) is attended to.
            inputs_embeds (torch.FloatTensor, optional): Embedded representation of the inputs. Should be float, not int tokens.
            interpolate_pos_encoding (bool, optional, defaults to False): Whether to interpolate the positional encoding of the image embeddings.

        Returns:
            The output of `self.language_model.generate` -- generated token ids (not decoded strings).
        """
        batch_size = pixel_values.shape[0]
        image_embeds = self.vision_model(
            pixel_values,
            return_dict=True,
            interpolate_pos_encoding=interpolate_pos_encoding,
        ).last_hidden_state
        image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)

        query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_outputs = self.qformer(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_attention_mask,
            return_dict=True,
        )
        query_output = query_outputs.last_hidden_state

        if query_output.dtype != image_embeds.dtype:
            query_output = query_output.to(image_embeds.dtype)

        language_model_inputs = self.language_projection(query_output)

        if inputs_embeds is None:
            if input_ids is None:
                # Default prompt: one placeholder per query token, then BOS.
                image_tokens = [self.config.image_token_index] * self.config.num_query_tokens
                start_tokens = image_tokens + [self.config.text_config.bos_token_id]
                input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
                input_ids = input_ids.repeat(batch_size, 1)
            inputs_embeds = self.get_input_embeddings()(input_ids)

        if attention_mask is None:
            if input_ids is not None:
                attention_mask = torch.ones_like(input_ids)
            else:
                # Only `inputs_embeds` was supplied; `torch.ones_like(None)` would fail, so build the
                # mask over the (batch, seq) dimensions of the embeddings instead.
                attention_mask = torch.ones(
                    inputs_embeds.shape[:-1], dtype=torch.long, device=inputs_embeds.device
                )

        if input_ids is None:
            # Locate image placeholders by comparing against the embedding of the image token.
            special_image_mask = inputs_embeds == self.get_input_embeddings()(
                torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            special_image_mask = special_image_mask.all(-1)
        else:
            special_image_mask = input_ids == self.config.image_token_id

        special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)

        inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
        if not self.language_model.config.is_encoder_decoder:
            inputs["input_ids"] = input_ids

        outputs = self.language_model.generate(**inputs, **generate_kwargs)

        return outputs
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .configuration_clip import (
|
|
16
|
+
RBLNCLIPTextModelConfig,
|
|
17
|
+
RBLNCLIPTextModelWithProjectionConfig,
|
|
18
|
+
RBLNCLIPVisionModelConfig,
|
|
19
|
+
RBLNCLIPVisionModelWithProjectionConfig,
|
|
20
|
+
)
|
|
21
|
+
from .modeling_clip import (
|
|
22
|
+
RBLNCLIPTextModel,
|
|
23
|
+
RBLNCLIPTextModelWithProjection,
|
|
24
|
+
RBLNCLIPVisionModel,
|
|
25
|
+
RBLNCLIPVisionModelWithProjection,
|
|
26
|
+
)
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any, Optional
|
|
16
|
+
|
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RBLNCLIPTextModelConfig(RBLNModelConfig):
    def __init__(self, batch_size: Optional[int] = None, **kwargs: Any):
        """
        Args:
            batch_size (Optional[int]): Number of text sequences processed per call. A falsy value
                (None or 0) falls back to 1.
            kwargs: Extra arguments forwarded unchanged to the parent RBLNModelConfig.

        Raises:
            ValueError: If the resolved `batch_size` is not an int or is negative.
        """
        super().__init__(**kwargs)
        self.batch_size = batch_size or 1
        batch_is_valid = isinstance(self.batch_size, int) and self.batch_size >= 0
        if not batch_is_valid:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class RBLNCLIPTextModelWithProjectionConfig(RBLNCLIPTextModelConfig):
    """
    Configuration for RBLNCLIPTextModelWithProjection.

    Adds nothing on top of RBLNCLIPTextModelConfig; it exists so that CLIP text models
    carrying a projection head have their own configuration type.
    """
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class RBLNCLIPVisionModelConfig(RBLNModelConfig):
    def __init__(
        self,
        batch_size: Optional[int] = None,
        image_size: Optional[int] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs: Any,
    ):
        """
        Args:
            batch_size (Optional[int]): Number of images processed per call. A falsy value (None or 0)
                falls back to 1.
            image_size (Optional[Union[int, Tuple[int, int], Dict[str, int]]]): Input image resolution.
                An int describes a square image, a tuple/list is read as (height, width), and a dict is
                read through its 'height' and 'width' keys.
            interpolate_pos_encoding (Optional[bool]): Whether pre-trained position encodings are
                interpolated. Falls back to `False`.
            output_hidden_states (Optional[bool]): Whether the hidden states of every layer are returned.
            output_attentions (Optional[bool]): Whether the attention tensors of every attention layer are returned.
            kwargs: Extra arguments forwarded unchanged to the parent RBLNModelConfig.

        Raises:
            ValueError: If the resolved `batch_size` is not an int or is negative.
        """
        super().__init__(**kwargs)
        self.batch_size = batch_size or 1
        batch_is_valid = isinstance(self.batch_size, int) and self.batch_size >= 0
        if not batch_is_valid:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        self.image_size = image_size
        self.interpolate_pos_encoding = interpolate_pos_encoding or False
        self.output_hidden_states = output_hidden_states
        self.output_attentions = output_attentions

    @property
    def image_width(self):
        """Width of the input image derived from `image_size`.

        Raises TypeError if `image_size` is still None.
        """
        size = self.image_size
        if isinstance(size, int):
            return size
        if isinstance(size, (list, tuple)):
            return size[1]
        return size["width"]

    @property
    def image_height(self):
        """Height of the input image derived from `image_size`.

        Raises TypeError if `image_size` is still None.
        """
        size = self.image_size
        if isinstance(size, int):
            return size
        if isinstance(size, (list, tuple)):
            return size[0]
        return size["height"]
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
class RBLNCLIPVisionModelWithProjectionConfig(RBLNCLIPVisionModelConfig):
    """
    Configuration for RBLNCLIPVisionModelWithProjection.

    Adds nothing on top of RBLNCLIPVisionModelConfig; it exists so that CLIP vision models
    carrying a projection head have their own configuration type.
    """
|