optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import inspect
|
|
16
|
+
from typing import Any, Callable, Optional, Tuple, Union
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from transformers import BartForConditionalGeneration, PreTrainedModel
|
|
20
|
+
from transformers.modeling_outputs import Seq2SeqModelOutput
|
|
21
|
+
|
|
22
|
+
from ....utils.logging import get_logger
|
|
23
|
+
from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
|
|
24
|
+
from ...models.seq2seq import RBLNModelForSeq2SeqLM
|
|
25
|
+
from .bart_architecture import BartWrapper
|
|
26
|
+
from .configuration_bart import RBLNBartForConditionalGenerationConfig
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
logger = get_logger()
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class RBLNBartModel(RBLNTransformerEncoderForFeatureExtraction):
    """
    RBLN-accelerated BART encoder for feature extraction.

    Runs the BART encoder stack on RBLN devices to produce contextualized
    token representations from input text sequences.
    """

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[Tuple, Seq2SeqModelOutput]:
        """
        Run the RBLN-compiled BART encoder.

        Args:
            input_ids: Token indices of shape ``(batch_size, sequence_length)``.
            attention_mask: Mask of shape ``(batch_size, sequence_length)`` used
                to skip attention over padding positions.

        Returns:
            A ``Seq2SeqModelOutput``, or a plain tuple of tensors when
            ``return_dict=False`` is passed through ``kwargs``.
        """
        # Delegate to the generic encoder runtime; inputs are positional in
        # the order the model was compiled with.
        outputs = super().forward(input_ids, attention_mask, **kwargs)
        return outputs
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
    """
    RBLN optimized BART model for conditional text generation tasks.

    Provides hardware-accelerated inference for BART models on RBLN devices,
    supporting sequence-to-sequence generation tasks such as summarization,
    translation, and text generation.
    """

    # The decoder path supports a causal attention mask on RBLN hardware.
    support_causal_attn = True

    @classmethod
    def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNBartForConditionalGenerationConfig):
        """Wrap the HuggingFace model for RBLN compilation.

        Args:
            model: The source ``BartForConditionalGeneration``-style model.
            rbln_config: Compilation config supplying ``enc_max_seq_len`` and
                ``use_attention_mask``.

        Returns:
            A ``BartWrapper`` ready to be compiled for the RBLN runtime.
        """
        # FIX: this is a classmethod, so its first parameter is the class
        # object; it was previously misnamed `self`.
        return BartWrapper(
            model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
        )

    def __getattr__(self, __name: str) -> Any:
        """Fall back to ``BartForConditionalGeneration`` for missing attributes.

        Functions found on the HF class that accept ``self`` are re-bound to
        this instance so generation utilities defined there work transparently
        on the RBLN wrapper; other attributes are returned as-is.

        Raises:
            AttributeError: If the HF class has no such attribute either.
        """

        def redirect(func):
            # Bind the unbound HF function to this wrapper instance.
            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

        val = getattr(BartForConditionalGeneration, __name)

        # Only redirect true instance methods (callables taking `self`);
        # `callable()` is the idiomatic check here.
        if callable(val) and "self" in set(inspect.signature(val).parameters):
            return redirect(val)

        return val
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .configuration_bert import RBLNBertForMaskedLMConfig, RBLNBertForQuestionAnsweringConfig, RBLNBertModelConfig
|
|
16
|
+
from .modeling_bert import RBLNBertForMaskedLM, RBLNBertForQuestionAnswering, RBLNBertModel
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class BertModelWrapper(torch.nn.Module):
    """Thin wrapper that normalizes BERT outputs for RBLN compilation.

    Tensor outputs pass through unchanged; tuple outputs are stripped of
    ``None`` placeholders so the traced graph carries only concrete tensors.
    """

    def __init__(self, model, rbln_config):
        super().__init__()
        # Underlying HF model and the RBLN compilation config it was built for.
        self.model = model
        self.rbln_config = rbln_config

    def forward(self, *args, **kwargs):
        result = self.model(*args, **kwargs)
        if isinstance(result, tuple):
            # Drop absent optional outputs (returned as None by HF models).
            return tuple(item for item in result if item is not None)
        # Tensors (and any other output type) are returned untouched.
        return result
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from ...configuration_generic import (
|
|
16
|
+
RBLNModelForMaskedLMConfig,
|
|
17
|
+
RBLNModelForQuestionAnsweringConfig,
|
|
18
|
+
RBLNTransformerEncoderForFeatureExtractionConfig,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class RBLNBertModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
    """
    Configuration class for RBLNBertModel.

    Stores the RBLN-specific parameters for BERT models used for feature
    extraction. The body is intentionally empty: all parameters are inherited
    from RBLNTransformerEncoderForFeatureExtractionConfig, and this subclass
    exists so BERT models resolve to a distinct, dedicated config type.
    """
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class RBLNBertForMaskedLMConfig(RBLNModelForMaskedLMConfig):
    """
    Configuration class for RBLNBertForMaskedLM.

    Stores the RBLN-specific parameters for BERT masked language models. The
    body is intentionally empty: all parameters are inherited from
    RBLNModelForMaskedLMConfig, and this subclass exists so BERT models
    resolve to a distinct, dedicated config type.
    """
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class RBLNBertForQuestionAnsweringConfig(RBLNModelForQuestionAnsweringConfig):
    """
    Configuration class for RBLNBertForQuestionAnswering.

    Stores the RBLN-specific parameters for BERT question-answering models.
    The body is intentionally empty: all parameters are inherited from
    RBLNModelForQuestionAnsweringConfig, and this subclass exists so BERT
    models resolve to a distinct, dedicated config type.
    """
|
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Optional, Tuple, Union
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from transformers.modeling_outputs import (
|
|
19
|
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
|
20
|
+
MaskedLMOutput,
|
|
21
|
+
QuestionAnsweringModelOutput,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
from ...modeling_generic import (
|
|
25
|
+
RBLNModelForMaskedLM,
|
|
26
|
+
RBLNModelForQuestionAnswering,
|
|
27
|
+
RBLNTransformerEncoderForFeatureExtraction,
|
|
28
|
+
)
|
|
29
|
+
from .bert_architecture import BertModelWrapper
|
|
30
|
+
from .configuration_bert import RBLNBertModelConfig
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
    """
    RBLN-accelerated BERT model for feature extraction.

    Executes a BERT encoder on RBLN devices to compute contextualized token
    embeddings and features from input text sequences.
    """

    # Default positional order of inputs expected by the compiled model.
    rbln_model_input_names = ["input_ids", "attention_mask"]

    @classmethod
    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNBertModelConfig) -> torch.nn.Module:
        """Wrap the HF model so its outputs are normalized for RBLN tracing."""
        return BertModelWrapper(model, rbln_config)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, Tuple]:
        """
        Run the RBLN-compiled BERT encoder.

        Args:
            input_ids: Token indices of shape ``(batch_size, sequence_length)``.
            attention_mask: Mask of shape ``(batch_size, sequence_length)``
                marking padding positions to be ignored by attention.
            token_type_ids: Segment indices of shape
                ``(batch_size, sequence_length)``.
            position_ids: Position indices of shape
                ``(batch_size, sequence_length)``.

        Returns:
            A ``BaseModelOutputWithPoolingAndCrossAttentions``, or a plain
            tuple of tensors when ``return_dict=False`` is passed through
            ``kwargs``.
        """
        available = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
            "position_ids": position_ids,
        }

        # The compiled runtime takes inputs positionally in a fixed order;
        # prefer the order recorded in the config, else the class default.
        names = getattr(self.rbln_config, "model_input_names", None)
        if names is None:
            names = self.rbln_model_input_names

        positional = [available[name] for name in names if name in available]
        return super().forward(*positional, **kwargs)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
    """
    RBLN-accelerated BERT model for masked language modeling.

    Runs BERT on RBLN devices to predict masked tokens, supporting token
    prediction and text completion workloads.
    """

    # Positional order of inputs expected by the compiled model.
    rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[MaskedLMOutput, Tuple]:
        """
        Run the RBLN-compiled masked-LM model.

        Args:
            input_ids: Token indices of shape ``(batch_size, sequence_length)``.
            attention_mask: Mask of shape ``(batch_size, sequence_length)``
                marking padding positions to be ignored by attention.
            token_type_ids: Segment indices of shape
                ``(batch_size, sequence_length)``.

        Returns:
            A ``MaskedLMOutput``, or a plain tuple of tensors when
            ``return_dict=False`` is passed through ``kwargs``.
        """
        # Inputs are forwarded positionally, matching rbln_model_input_names.
        result = super().forward(input_ids, attention_mask, token_type_ids, **kwargs)
        return result
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
    """
    BERT extractive question-answering head compiled for RBLN NPUs.

    Provides hardware-accelerated inference on RBLN devices for span
    extraction, where the model predicts the start and end positions of an
    answer inside the given context.
    """

    # Positional input order expected by the compiled RBLN graph.
    rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[QuestionAnsweringModelOutput, Tuple]:
        """
        Run question-answering inference on the compiled BERT model.

        Args:
            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Indices of input sequence tokens in the vocabulary.
            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.

        Returns:
            The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a QuestionAnsweringModelOutput object.
        """
        # Forward the tensors positionally in the graph's expected order.
        ordered_inputs = (input_ids, attention_mask, token_type_ids)
        return super().forward(*ordered_inputs, **kwargs)
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .configuration_blip_2 import (
|
|
16
|
+
RBLNBlip2ForConditionalGenerationConfig,
|
|
17
|
+
RBLNBlip2QFormerModelConfig,
|
|
18
|
+
RBLNBlip2VisionModelConfig,
|
|
19
|
+
)
|
|
20
|
+
from .modeling_blip_2 import RBLNBlip2ForConditionalGeneration, RBLNBlip2QFormerModel, RBLNBlip2VisionModel
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Any, Optional
|
|
16
|
+
|
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
|
18
|
+
from ....utils.logging import get_logger
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
logger = get_logger(__name__)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class RBLNBlip2VisionModelConfig(RBLNModelConfig):
    """
    Configuration class for RBLNBlip2VisionModel.

    This configuration class stores the configuration parameters specific to
    RBLN-optimized BLIP-2 vision encoder models for multimodal tasks.
    """

    def __init__(
        self,
        batch_size: Optional[int] = None,
        **kwargs: Any,
    ):
        """
        Args:
            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
            kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        super().__init__(**kwargs)
        self.batch_size = batch_size or 1
        # Fix: use `<= 0` so the check agrees with the "positive integer"
        # error message (the previous `< 0` would nominally admit 0, though
        # `batch_size or 1` already coerces 0 to 1 above).
        if not isinstance(self.batch_size, int) or self.batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
    """
    Configuration class for RBLNBlip2QFormerModel.

    This configuration class stores the configuration parameters specific to
    RBLN-optimized BLIP-2 Q-Former models that bridge vision and language modalities.
    """

    def __init__(
        self,
        batch_size: Optional[int] = None,
        num_query_tokens: Optional[int] = None,
        image_text_hidden_size: Optional[int] = None,
        **kwargs,
    ):
        """
        Args:
            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
            num_query_tokens (Optional[int]): The number of query tokens passed through the Transformer.
            image_text_hidden_size (Optional[int]): Dimensionality of the hidden state of the image-text fusion layer.
            kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        super().__init__(**kwargs)
        self.batch_size = batch_size or 1
        # Fix: use `<= 0` so the check agrees with the "positive integer"
        # error message (the previous `< 0` would nominally admit 0, though
        # `batch_size or 1` already coerces 0 to 1 above).
        if not isinstance(self.batch_size, int) or self.batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        # Optional model-shape hints; left as None when not provided so the
        # compiler can fall back to values from the HF model config.
        self.num_query_tokens = num_query_tokens
        self.image_text_hidden_size = image_text_hidden_size
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
    """
    Configuration class for RBLNBlip2ForConditionalGeneration.

    This configuration class stores the configuration parameters specific to
    RBLN-optimized BLIP-2 models for conditional generation tasks that involve both image and text inputs.
    """

    # Submodule configs that are compiled and saved alongside this model.
    submodules = ["vision_model", "qformer", "language_model"]

    def __init__(
        self,
        batch_size: Optional[int] = None,
        vision_model: Optional[RBLNModelConfig] = None,
        qformer: Optional[RBLNModelConfig] = None,
        language_model: Optional[RBLNModelConfig] = None,
        **kwargs: Any,
    ):
        """
        Args:
            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
            vision_model (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
            qformer (Optional[RBLNModelConfig]): Configuration for the RBLN-optimized BLIP-2 Q-Former model.
            language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
            kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        super().__init__(**kwargs)
        self.batch_size = batch_size or 1
        # Fix: use `<= 0` so the check agrees with the "positive integer"
        # error message (the previous `< 0` would nominally admit 0, though
        # `batch_size or 1` already coerces 0 to 1 above).
        if not isinstance(self.batch_size, int) or self.batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        # Vision model and Q-Former are always compiled with batch size 1;
        # only the language model honors the requested batch size.
        if self.batch_size != 1:
            logger.warning("Ignore batch_size for Blip2 vision model. It will be set to 1.")
            logger.warning("Ignore batch_size for Blip2 qformer. It will be set to 1.")

        self.vision_model = self.initialize_submodule_config(
            submodule_config=vision_model, batch_size=1, force_kwargs=True
        )
        self.qformer = self.initialize_submodule_config(submodule_config=qformer, batch_size=1, force_kwargs=True)
        self.language_model = self.initialize_submodule_config(submodule_config=language_model)
|