optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,527 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Optional, Tuple
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torch import nn
|
|
19
|
+
from transformers.utils import logging
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
# Module-level logger, obtained from the transformers logging utilities and keyed by this module's name.
logger = logging.get_logger(__name__)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Seq2SeqWrapper:
    """Bundle the RBLN encoder and decoder wrappers built from one Seq2Seq model.

    The same ``model`` is handed to both wrappers. The encoder wrapper
    additionally receives ``enc_max_seq_len``; the decoder wrapper receives
    every extra keyword argument unchanged.

    Args:
        model (nn.Module): The Seq2Seq model to wrap.
        enc_max_seq_len (int): Maximum encoder sequence length, forwarded to
            ``Seq2SeqEncoderWrapper``.
        kwargs: Additional keyword arguments forwarded to ``Seq2SeqDecoderWrapper``.
    """

    def __init__(self, model: nn.Module, enc_max_seq_len: int, **kwargs):
        # Build the encoder wrapper first, then the decoder wrapper, and expose both as attributes.
        encoder_wrapper = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
        decoder_wrapper = Seq2SeqDecoderWrapper(model, **kwargs)
        self.encoder = encoder_wrapper
        self.decoder = decoder_wrapper
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Seq2SeqEncoderWrapper(nn.Module):
    """A wrapper for the encoder component of a Seq2Seq model, designed for RBLN optimization.

    The wrapper runs the model's encoder and then applies the cross-attention
    key/value projections of the decoder layers to the encoder output. The
    projected tensors are written into externally supplied cache tensors
    through ``torch.ops.rbln_custom_ops.rbln_cache_update`` so they can be
    reused during decoding.

    Args:
        model (nn.Module): The Seq2Seq model containing the encoder.
        enc_max_seq_len (int): Encoder sequence length used when reshaping the
            projected cross-attention keys and values.
    """

    def __init__(self, model: nn.Module, enc_max_seq_len: int):
        super().__init__()
        self.config = model.config
        self.encoder = model.get_encoder()
        self.encoder_max_length = enc_max_seq_len
        self.__post_init__(model)

    def __post_init__(self, model: nn.Module):
        """
        Post-initialization to extract and configure encoder-related attributes.

        It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
        by subclasses to modify or add custom attributes as necessary.

        Sets ``n_layer``, ``cross_k_projects``, ``cross_v_projects``, ``num_heads`` and ``d_kv``.
        """
        decoder_layers = model.get_decoder().layers
        configured_layers = getattr(self.config, "decoder_layers", None)
        # Without this fallback a config lacking `decoder_layers` left `n_layer` as None and the
        # later `range(self.n_layer)` / `self.n_layer * 2` failed with an opaque TypeError.
        self.n_layer = len(decoder_layers) if configured_layers is None else configured_layers
        self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(decoder_layers)
        self.num_heads = self.config.decoder_attention_heads
        self.d_kv = self.config.d_model // self.num_heads

    def _extract_cross_kv_projects(self, decoder_layers: nn.Module):
        """
        Extract cross-attention key and value projection layers from the decoder.

        Only the first ``self.n_layer`` entries of ``decoder_layers`` are used.

        Returns:
            A pair ``(k_projections, v_projections)`` of ``nn.ModuleList`` objects, one entry per layer.
        """
        cross_attentions = [decoder_layers[i].encoder_attn for i in range(self.n_layer)]
        return (
            nn.ModuleList(attn.k_proj for attn in cross_attentions),
            nn.ModuleList(attn.v_proj for attn in cross_attentions),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        b_idx: torch.Tensor,
        *cross_key_values: Tuple[torch.Tensor],
    ) -> Tuple[torch.Tensor]:
        """Run the encoder and refresh the cross-attention key/value caches.

        Args:
            input_ids: Encoder input passed as ``input_ids`` to the wrapped encoder.
            attention_mask: Encoder attention mask passed to the wrapped encoder.
            b_idx: Tensor whose first element is handed to the cache-update op
                (presumably the batch slot to overwrite — confirm against the custom op).
            *cross_key_values: Cache tensors ordered ``k0, v0, k1, v1, ...``; at least
                ``2 * n_layer`` of them are expected.

        Returns:
            The updated cache tensors. NOTE: despite the ``Tuple`` annotation, a ``list`` is returned.
        """
        # 1. get encoder last_hidden_states
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_states = encoder_outputs[0]

        # 2. pre-compute cross_attention's past_key_value which is used in the decoder phase.
        # The view hard-codes a batch size of 1 and the configured encoder length.
        kv_shape = (1, self.encoder_max_length, self.num_heads, self.d_kv)
        cross_kv = []
        for k_proj, v_proj in zip(self.cross_k_projects, self.cross_v_projects):
            past_k = k_proj(last_hidden_states).view(*kv_shape).transpose(1, 2)
            past_v = v_proj(last_hidden_states).view(*kv_shape).transpose(1, 2)

            cross_kv.append(past_k)
            cross_kv.append(past_v)

        # 3. update the cross_attention's past_key_value directly in the device DRAM for optimization.
        batch_axis = torch.tensor(0, dtype=torch.int16)
        cross_key_values = list(cross_key_values)
        for i in range(self.n_layer * 2):
            cross_key_values[i] = torch.ops.rbln_custom_ops.rbln_cache_update(
                cross_key_values[i], cross_kv[i], b_idx[0], batch_axis
            )

        return cross_key_values
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class Seq2SeqDecoderWrapper(nn.Module):
    """
    A wrapper for the decoder component of a Seq2Seq model, designed for RBLN optimization.

    This wrapper handles tasks such as:
    1. Converting decoder components to support RBLN-specific conditional generation.
    2. Customizing attention mechanisms, including self-attention and cross-attention.
    3. Managing the decoder's key-value caches for both self and cross-attention.

    Args:
        model (nn.Module): The Seq2Seq model containing the decoder.
        use_attention_mask (bool): Whether ``forward`` expects a decoder attention mask
            as its second positional argument.
        kwargs: Additional arguments for decoder configuration, forwarded to ``__post_init__``.
    """

    def __init__(self, model: nn.Module, use_attention_mask: bool = True, **kwargs):
        super().__init__()
        self.config = model.config
        self.use_attention_mask = use_attention_mask
        self.__post_init__(model, **kwargs)

    def __post_init__(self, model: nn.Module, **kwargs):
        """
        Post-initialization to extract and configure decoder-related attributes.

        It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
        by subclasses to modify or add custom attributes as necessary. This base implementation does
        not use ``kwargs``.
        """
        self.num_layers = self.config.decoder_layers
        self.conditional_generation = self.convert_to_rbln_conditional_generation(model)

    def convert_to_rbln_conditional_generation(self, model: nn.Module):
        """Rebuild the model's decoder from RBLN-specific layers and wrap it for conditional generation.

        Each original decoder layer is wrapped together with RBLN variants of its
        ``self_attn`` and ``encoder_attn`` modules.

        Returns:
            A ``Seq2SeqForConditionalGeneration`` built from ``model`` and the new decoder.
        """
        decoder = model.get_decoder()
        new_layers = [
            Seq2SeqDecoderLayer(
                layer,
                Seq2SeqSelfAttention(layer.self_attn),
                Seq2SeqCrossAttention(layer.encoder_attn),
            )
            for layer in decoder.layers
        ]
        decoder_model = Seq2SeqDecoder(decoder, new_layers)
        return Seq2SeqForConditionalGeneration(model, decoder_model)

    def forward(
        self,
        *args,
    ) -> torch.FloatTensor:
        """Unpack the flat positional inputs, regroup the KV caches per layer and decode.

        Positional layout when ``use_attention_mask`` is true:
        ``input_ids, attention_mask, encoder_attention_mask, cache_position, block_tables, *kv_cache``.
        Otherwise ``attention_mask`` is omitted from the inputs and ``None`` is used for it.

        ``kv_cache`` holds the cross-attention caches first (``2 * num_layers`` tensors ordered
        ``k0, v0, k1, v1, ...``), followed by the self-attention caches in the same layout.

        Returns:
            The language-model logits produced by ``self.conditional_generation``.
        """
        if self.use_attention_mask:
            (
                input_ids,
                attention_mask,
                encoder_attention_mask,
                cache_position,
                block_tables,
                *kv_cache,
            ) = args
        else:
            attention_mask = None
            (input_ids, encoder_attention_mask, cache_position, block_tables, *kv_cache) = args

        # Split the flat cache list: cross-attention entries come first, self-attention entries after.
        kv_per_kind = self.num_layers * 2
        cross_kv_cache = kv_cache[:kv_per_kind]
        self_kv_cache = kv_cache[kv_per_kind:]

        # Pair consecutive (key, value) tensors into one tuple per layer.
        pair_starts = range(0, kv_per_kind, 2)
        self_past_key_values = tuple((self_kv_cache[i], self_kv_cache[i + 1]) for i in pair_starts)
        cross_past_key_values = tuple((cross_kv_cache[i], cross_kv_cache[i + 1]) for i in pair_starts)

        # decode
        lm_logits = self.conditional_generation(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            self_past_key_values=self_past_key_values,
            cross_past_key_values=cross_past_key_values,
            cache_position=cache_position,
            block_tables=block_tables,
        )

        return lm_logits
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class Seq2SeqForConditionalGeneration(nn.Module):
    """
    Conditional-generation head on top of an RBLN-adapted Seq2Seq decoder.

    The module keeps the original model's configuration and ``lm_head`` and pairs
    them with the wrapped decoder. ``forward`` runs the decoder, optionally
    rescales its output, and projects the result to logits.

    Attributes:
        has_rescaling (bool): Indicates if output rescaling is applied.
        config (PretrainedConfig): Configuration from the original Seq2Seq model.
        lm_head (nn.Linear): The language modeling head for output logits.
        decoder (nn.Module): The wrapped decoder model.
    """

    has_rescaling = False

    def __init__(self, model, decoder_model):
        super().__init__()
        self.config = model.config
        self.lm_head = model.lm_head
        self.decoder = decoder_model
        self.__post_init__()

    def __post_init__(self):
        """
        Hook invoked at the end of ``__init__``; the base implementation does nothing.

        Subclasses can override it to adjust or add attributes after initialization.
        """

    def forward(
        self,
        input_ids,
        attention_mask,
        encoder_attention_mask,
        self_past_key_values,
        cross_past_key_values,
        cache_position,
        block_tables: Optional[torch.Tensor] = None,
    ):
        """Run the wrapped decoder and return the language-model logits.

        Every argument is forwarded to ``self.decoder`` by keyword.
        """
        decoder_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "encoder_attention_mask": encoder_attention_mask,
            "self_past_key_values": self_past_key_values,
            "cross_past_key_values": cross_past_key_values,
            "cache_position": cache_position,
            "block_tables": block_tables,
        }
        hidden_states = self.decoder(**decoder_inputs)

        # Rescaling is applied only when the class opts in AND the word embeddings are tied.
        # `self.scaling` is not defined in this class — presumably provided by subclasses.
        if self.has_rescaling:
            if self.config.tie_word_embeddings:
                hidden_states = hidden_states * self.scaling

        return self.lm_head(hidden_states)
|
|
257
|
+
|
|
258
|
+
|
|
259
|
+
class Seq2SeqDecoder(torch.nn.Module):
    """Decoder stack of a Seq2Seq model, rebuilt around layers adapted for RBLN compilation.

    The token embedding and (when present) the final layer norm are reused from
    the original decoder. Mask preparation and position embedding are left to
    subclasses.

    Args:
        model: Original Huggingface model to adapt
        layers (List[Seq2SeqDecoderLayer]): Modified transformer layers optimized for RBLN
    """

    # When true, `forward` calls `apply_position_embedding` after the token embedding.
    has_pos_emb = True

    def __init__(self, model, layers, **kwargs):
        super().__init__()
        self._original_mod = model
        self.layers = nn.ModuleList(layers)
        self.embed_tokens = model.embed_tokens
        # The final layer norm is optional on the original decoder.
        self.final_layer_norm = model.final_layer_norm if hasattr(model, "final_layer_norm") else None
        self.__post_init__(**kwargs)

    def __post_init__(self, **kwargs):
        """
        Hook invoked at the end of ``__init__``; the base implementation does nothing.

        Subclasses can override it to adjust or add attributes after initialization.
        """

    def get_embedding(self):
        """Return the token embedding module used by ``forward``."""
        return self.embed_tokens

    def prepare_attn_mask(self, *args, **kwargs):
        """Build the decoder and encoder attention masks; must be provided by a subclass."""
        raise NotImplementedError(
            "The 'prepare_attn_mask' method is not implemented. Please define this method in a subclass."
        )

    def apply_position_embedding(self, *args, **kwargs):
        """Add position information to the hidden states; must be provided by a subclass."""
        raise NotImplementedError(
            "The 'apply_position_embedding' method is not implemented. Please define this method in a subclass."
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        self_past_key_values: torch.Tensor,
        cross_past_key_values: torch.Tensor,
        cache_position: torch.Tensor,
        block_tables: Optional[torch.Tensor] = None,
    ):
        """Embed the inputs, run every decoder layer in order and apply the final norm if any.

        ``self_past_key_values`` and ``cross_past_key_values`` are iterated in lockstep
        with ``self.layers``, one entry per layer.

        Returns:
            The hidden states after the last layer (and the final layer norm, when present).
        """
        # Token embedding.
        embed = self.get_embedding()
        hidden_states = embed(input_ids)

        masks = self.prepare_attn_mask(attention_mask, encoder_attention_mask, cache_position=cache_position)
        attention_mask, encoder_attention_mask = masks

        if self.has_pos_emb:
            hidden_states = self.apply_position_embedding(hidden_states, cache_position)

        # Keyword arguments shared by every layer call.
        shared_kwargs = {
            "attention_mask": attention_mask,
            "encoder_attention_mask": encoder_attention_mask,
            "cache_position": cache_position,
            "block_tables": block_tables,
        }

        # Walk the layers together with their per-layer caches.
        per_layer = zip(self.layers, self_past_key_values, cross_past_key_values)
        for decoder_layer, self_kv, cross_kv in per_layer:
            hidden_states = decoder_layer(
                hidden_states,
                self_past_key_value=self_kv,
                cross_past_key_value=cross_kv,
                **shared_kwargs,
            )

        final_norm = self.final_layer_norm
        if final_norm is None:
            return hidden_states
        return final_norm(hidden_states)
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
class Seq2SeqDecoderLayer(torch.nn.Module):
    """A single sequence-to-sequence decoder layer adapted for RBLN compilation.

    ``forward`` chains three sub-blocks: self-attention, cross-attention and a
    feed-forward block. The layer-norm placement is delegated to four hooks
    (``pre_``/``post_`` x ``self_attn``/``cross_attn`` ``_layer_norm``) that a
    subclass must implement. ``forward`` also calls ``self.ff_layer``, which
    this base class does not define, so a subclass has to provide it (for
    instance from ``__post_init__``).

    Args:
        decoder_layer: Original Huggingface decoder layer to adapt.
        self_attn (Seq2SeqSelfAttention): Modified self-attention layer optimized for RBLN.
        cross_attn: Modified cross-attention layer; ``forward`` uses the first
            element of whatever it returns.
        **kwargs: Passed through unchanged to ``__post_init__``.
    """

    def __init__(self, decoder_layer, self_attn, cross_attn, **kwargs):
        super().__init__()
        self._original_mod = decoder_layer
        self.self_attn = self_attn
        self.cross_attn = cross_attn
        # Forward extra keyword arguments like the sibling wrapper classes do.
        self.__post_init__(**kwargs)

    def __post_init__(self, **kwargs):
        """
        Abstract method intended to be overridden by subclasses to modify or override
        the attributes of the original model after initialization.
        """
        pass

    def pre_self_attn_layer_norm(self, hidden_states):
        """Normalization applied before self-attention; must be provided by a subclass."""
        raise NotImplementedError(
            "The 'pre_self_attn_layer_norm' method is not implemented. Please define this method in a subclass."
        )

    def post_self_attn_layer_norm(self, hidden_states):
        """Normalization applied after the self-attention residual; must be provided by a subclass."""
        raise NotImplementedError(
            "The 'post_self_attn_layer_norm' method is not implemented. Please define this method in a subclass."
        )

    def pre_cross_attn_layer_norm(self, hidden_states):
        """Normalization applied before cross-attention; must be provided by a subclass."""
        raise NotImplementedError(
            "The 'pre_cross_attn_layer_norm' method is not implemented. Please define this method in a subclass."
        )

    def post_cross_attn_layer_norm(self, hidden_states):
        """Normalization applied after the cross-attention residual; must be provided by a subclass."""
        raise NotImplementedError(
            "The 'post_cross_attn_layer_norm' method is not implemented. Please define this method in a subclass."
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        self_past_key_value: Tuple[torch.Tensor],
        cross_past_key_value: Tuple[torch.Tensor],
        cache_position: torch.Tensor,
        block_tables: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply self-attention, cross-attention and the feed-forward block.

        Args:
            hidden_states: Input to the layer.
            attention_mask: Mask handed to the self-attention module.
            encoder_attention_mask: Mask handed to the cross-attention module;
                its last dimension also sizes the placeholder encoder states.
            self_past_key_value: Cache entry handed to the self-attention module.
            cross_past_key_value: Cache entry handed to the cross-attention module.
            cache_position: Forwarded to the self-attention module.
            block_tables: Optional value forwarded to the self-attention module.

        Returns:
            The hidden states produced by ``self.ff_layer``.
        """
        # Placeholder passed as ``key_value_states``: only zeros, never real
        # encoder output. Seq2SeqCrossAttention in this file merely checks it
        # for ``None`` and reads keys/values from ``past_key_value`` instead.
        dummy_encoder_hidden_states = torch.zeros(1, encoder_attention_mask.shape[-1])

        # Self Attention Block
        residual = hidden_states
        hidden_states = self.pre_self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            past_key_value=self_past_key_value,
            attention_mask=attention_mask,
            cache_position=cache_position,
            block_tables=block_tables,
        )
        hidden_states = residual + hidden_states
        hidden_states = self.post_self_attn_layer_norm(hidden_states)

        # Cross-Attention Block
        residual = hidden_states
        hidden_states = self.pre_cross_attn_layer_norm(hidden_states)

        cross_attn_output = self.cross_attn(
            hidden_states=hidden_states,
            past_key_value=cross_past_key_value,
            attention_mask=encoder_attention_mask,
            key_value_states=dummy_encoder_hidden_states,
        )
        # The cross-attention module returns a sequence; element 0 is the attention output.
        hidden_states = residual + cross_attn_output[0]
        hidden_states = self.post_cross_attn_layer_norm(hidden_states)

        # Feed-Forward Block (``ff_layer`` is supplied by the subclass)
        hidden_states = self.ff_layer(hidden_states)

        return hidden_states
|
|
421
|
+
|
|
422
|
+
|
|
423
|
+
class Seq2SeqSelfAttention(nn.Module):
    """Decoder self-attention wrapper that delegates the attention math to ``attn_decode``.

    This base class only stores the original attention module. ``forward``
    additionally reads ``num_heads``, ``head_dim``, ``q_proj``, ``k_proj``,
    ``v_proj``, ``out_proj`` and ``attn_decode`` from ``self``; none of them is
    defined here, so a subclass has to provide them (for instance from
    ``__post_init__``).

    Args:
        attn: Original attention module to adapt.
        **kwargs: Passed through unchanged to ``__post_init__``.
    """

    def __init__(self, attn, **kwargs):
        super().__init__()
        self._original_mod = attn
        self.__post_init__(**kwargs)

    def __post_init__(self, **kwargs):
        """
        Abstract method intended to be overridden by subclasses to modify or override
        the attributes of the original model after initialization.
        """
        pass

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        """Split the last dimension into heads and move the head axis forward.

        The tensor is viewed as ``(bsz, seq_len, 1, num_heads, head_dim)`` and
        dimensions 1 and 3 are swapped, giving
        ``(bsz, num_heads, 1, seq_len, head_dim)``.
        """
        return tensor.view(bsz, seq_len, 1, self.num_heads, self.head_dim).transpose(1, 3)

    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Projects input hidden states into query, key, and value representations.

        Args:
            hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]

        Returns:
            Tuple of (query_states, key_states, value_states)
        """
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        return query_states, key_states, value_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Tuple[torch.Tensor],
        attention_mask: torch.Tensor,
        cache_position: torch.Tensor,
        block_tables: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Project the input, call ``attn_decode`` and apply the output projection.

        Args:
            hidden_states: Three-dimensional input; its first two sizes are
                used as batch size and target length.
            past_key_value: Indexable pair; element 0 and element 1 are passed
                to ``attn_decode`` as the key and value caches.
            attention_mask: Optional mask; may be ``None``.
            cache_position: Forwarded to ``attn_decode``.
            block_tables: Optional value forwarded to ``attn_decode``.

        Returns:
            The attention output after ``out_proj``.
        """
        bsz, tgt_len, _ = hidden_states.size()

        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
        query_states = self._shape(query_states, tgt_len, bsz)
        key_states = self._shape(key_states, -1, bsz)
        value_states = self._shape(value_states, -1, bsz)

        # The second-to-last dimension of the key cache is used as the block size.
        block_size = past_key_value[0].shape[-2]
        args = [
            query_states,
            key_states,
            value_states,
            past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
            past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
            cache_position,
            torch.tensor(1.0, dtype=torch.float32),  # scale
            block_tables,
            block_size,
        ]
        # The positional layout differs by case: a mask goes right after
        # value_states (index 3), otherwise ``None`` is appended at the very end.
        # NOTE(review): this must match the signature of the subclass-provided
        # ``attn_decode`` - confirm before reordering anything here.
        if attention_mask is not None:
            args.insert(3, attention_mask.unsqueeze(2))
        else:
            args.append(None)

        attn_output = self.attn_decode(*args)

        # Merge the heads back into a single hidden dimension.
        attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
        attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output
|
|
493
|
+
|
|
494
|
+
|
|
495
|
+
class Seq2SeqCrossAttention(nn.Module):
    """Decoder cross-attention wrapper that attends over precomputed keys and values.

    Keys and values are never projected here: ``forward`` takes them directly
    from ``past_key_value``. ``forward`` reads ``q_proj``, ``out_proj``,
    ``num_heads``, ``head_dim`` and ``embed_dim`` from ``self``; none of them
    is defined in this base class, so a subclass has to provide them (for
    instance from ``__post_init__``).

    Args:
        attn: Original attention module to adapt.
        **kwargs: Passed through unchanged to ``__post_init__``.
    """

    def __init__(self, attn, **kwargs):
        super().__init__()
        self._original_mod = attn
        self.__post_init__(**kwargs)

    def __post_init__(self, **kwargs):
        """
        Abstract method intended to be overridden by subclasses to modify or override
        the attributes of the original model after initialization.
        """
        # Defined as a no-op so that ``__init__`` works even when a subclass
        # does not override the hook, matching the sibling wrapper classes.
        pass

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: torch.Tensor = None,
        past_key_value: Optional[object] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """Attend from ``hidden_states`` over the cached keys and values.

        Args:
            hidden_states: Three-dimensional input; its first two sizes are
                used as batch size and target length.
            key_value_states: Must not be ``None``. Its contents are not read;
                it only signals that cross-attention is requested.
            past_key_value: Indexable pair; element 0 is used as keys and
                element 1 as values.
            attention_mask: Optional mask passed to
                ``scaled_dot_product_attention`` as ``attn_mask``.

        Returns:
            A 3-tuple ``(attn_output, None, past_key_value)``.

        Raises:
            ValueError: If ``key_value_states`` is ``None``.
        """
        if key_value_states is None:
            # Without this guard the key/value lookup below would be skipped and
            # the attention call would fail with an opaque UnboundLocalError.
            raise ValueError(
                "Seq2SeqCrossAttention requires 'key_value_states'; only cross-attention is supported."
            )

        bsz, tgt_len, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Keys and values come straight from the cache; no projection happens here.
        key_states = past_key_value[0]
        value_states = past_key_value[1]

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )

        # Merge the heads back into a single hidden dimension.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, None, past_key_value
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from .configuration_siglip import RBLNSiglipVisionModelConfig
|
|
16
|
+
from .modeling_siglip import RBLNSiglipVisionModel
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
from typing import Optional
|
|
16
|
+
|
|
17
|
+
from ....configuration_utils import RBLNModelConfig
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class RBLNSiglipVisionModelConfig(RBLNModelConfig):
    """
    Configuration for RBLNSiglipVisionModel.

    Holds the settings specific to RBLN-optimized SigLIP vision models
    (batch size, image size and output options) and exposes the image width
    and height as derived properties.
    """

    def __init__(
        self,
        batch_size: Optional[int] = None,
        image_size: Optional[int] = None,
        interpolate_pos_encoding: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        **kwargs,
    ):
        """
        Args:
            batch_size (Optional[int]): Batch size for image processing. A falsy value
                (``None`` or ``0``) is replaced by 1.
            image_size (Optional[int]): Size of the input images. Stored as given; the
                ``image_width`` / ``image_height`` properties understand an integer
                (square image), a list/tuple ``(height, width)``, or a mapping with
                ``'height'`` and ``'width'`` keys.
            interpolate_pos_encoding (Optional[bool]): Whether to interpolate the position
                encoding. A falsy value is stored as ``False``.
            output_hidden_states (Optional[bool]): Whether to return hidden states.
            output_attentions (Optional[bool]): Whether to return attentions.
            kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If the resolved batch size is not an integer or is negative.
        """
        super().__init__(**kwargs)

        self.batch_size = batch_size if batch_size else 1
        batch_size_ok = isinstance(self.batch_size, int) and self.batch_size >= 0
        if not batch_size_ok:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        self.image_size = image_size
        self.interpolate_pos_encoding = interpolate_pos_encoding if interpolate_pos_encoding else False
        self.output_hidden_states = output_hidden_states
        self.output_attentions = output_attentions

    @property
    def image_width(self):
        """Width derived from ``image_size``: the int itself, index 1, or the ``"width"`` key."""
        size = self.image_size
        if isinstance(size, int):
            return size
        if isinstance(size, (list, tuple)):
            return size[1]
        return size["width"]

    @property
    def image_height(self):
        """Height derived from ``image_size``: the int itself, index 0, or the ``"height"`` key."""
        size = self.image_size
        if isinstance(size, int):
            return size
        if isinstance(size, (list, tuple)):
            return size[0]
        return size["height"]
|