optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the contents of publicly released versions of this package as they appear in their public registry, and is provided for informational purposes only.
Potentially problematic release: this version of optimum-rbln has been flagged as possibly problematic.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/modeling_attention_utils.py
@@ -0,0 +1,385 @@

```python
import math
from collections import Counter, defaultdict
from typing import TYPE_CHECKING, Dict, Optional, Tuple

import rebel

from ..utils.logging import get_logger
from ..utils.runtime_utils import get_available_dram
from .models.decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig


logger = get_logger()

if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel


DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
MAX_SLIDING_WINDOW_SIZE = 32_768


def set_default_values(
    attn_impl: Optional[str] = None,
    kvcache_partition_len: Optional[int] = None,
    kvcache_block_size: Optional[int] = None,
    max_seq_len: Optional[int] = None,
) -> Tuple[str, int, int]:
    if attn_impl is None:
        attn_impl = "eager"

    if kvcache_partition_len is not None:
        if attn_impl == "eager":
            attn_impl = "flash_attn"
            logger.warning(
                "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
                "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
                "`attn_impl` has been automatically switched to 'flash_attn'."
            )

    if kvcache_partition_len is None and attn_impl == "flash_attn":
        kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH

    if kvcache_block_size is None:
        if attn_impl == "eager":
            kvcache_block_size = max_seq_len
        else:
            kvcache_block_size = kvcache_partition_len

    return attn_impl, kvcache_partition_len, kvcache_block_size
```
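As an illustration of how this defaulting resolves (not part of the package diff; the import path is inferred from the file listing above, and the values are hypothetical):

```python
from optimum.rbln.transformers.modeling_attention_utils import set_default_values

# No attn_impl given: defaults to eager, and the single KV-cache block spans
# the whole sequence.
assert set_default_values(max_seq_len=4_096) == ("eager", None, 4_096)

# A partition length implies flash attention (a warning is logged), and the
# block size defaults to the partition length.
assert set_default_values(kvcache_partition_len=8_192, max_seq_len=32_768) == (
    "flash_attn",
    8_192,
    8_192,
)
```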
```python
def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
    if attn_impl not in ["eager", "flash_attn"]:
        raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn')")

    ## Checking Constraints...
    # Constraint of eager attention:
    # - `max_seq_len` <= 32k

    # Constraints of flash attention:
    # 1. `max_seq_len` should be a multiple of `partition_len`.
    # 2. 4k <= `partition_len` <= 32k.
    # 3. `max_seq_len` should be larger than 8k.
    if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
        raise ValueError(
            f"`max_seq_len` is set to {max_seq_len}, "
            f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
            f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
            " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
        )

    if attn_impl == "flash_attn":
        if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
            raise ValueError(
                f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
                f"when using 'flash_attn'. Please adjust either value to meet this requirement."
            )
        elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
            raise ValueError(
                f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
                f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
                f"Please provide a valid value within this range."
            )
        elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
            raise ValueError(
                f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
                f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
                "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
            )

    if kvcache_block_size is not None:
        if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
            raise ValueError(
                f"When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
                f"must always be set equal to the `kvcache_partition_len` ({kvcache_partition_len})."
            )
        elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
            raise ValueError(
                f"When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
                f"must always be set equal to the `max_seq_len` ({max_seq_len})."
            )


def validate_sliding_window(rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
    if rbln_config.sliding_window > MAX_SLIDING_WINDOW_SIZE - rbln_config.prefill_chunk_size:
        raise ValueError(
            f"Sliding window size ({rbln_config.sliding_window}) must be less than 32768 - prefill_chunk_size ({32768 - rbln_config.prefill_chunk_size})"
        )

    if rbln_config.cache_impl == "sliding_window" and rbln_config.use_attention_mask:
        raise ValueError("`use_attention_mask` must be set to False when `cache_impl` is set to 'sliding_window'.")
```
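To see the flash-attention constraints in action, a quick sketch (hypothetical values, exercising the function above):

```python
# Passes: 32_768 is a multiple of 16_384 (quotient >= 2), the partition length
# is within [4_096, 32_768], and max_seq_len >= 8_192.
validate_attention_method(
    attn_impl="flash_attn",
    kvcache_partition_len=16_384,
    kvcache_block_size=16_384,
    max_seq_len=32_768,
)

# Raises ValueError: 20_000 is not a multiple of 16_384.
validate_attention_method(
    attn_impl="flash_attn",
    kvcache_partition_len=16_384,
    kvcache_block_size=16_384,
    max_seq_len=20_000,
)
```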
```python
def align(x: int, nbytes: int) -> int:
    return int(math.ceil(x / nbytes) * nbytes)


def align_2MB(x: int) -> int:
    return align(x, 2**21)


def get_alloc_memory_by_key(compiled_models: Dict[str, "rebel.RBLNCompiledModel"]) -> Dict[str, int]:
    alloc_memory_by_key = defaultdict(int)
    # Get the actual memory allocation of each node by key
    for compiled_model in compiled_models.values():
        alloc_per_node_by_key = compiled_model.get_alloc_per_node_by_key()
        for key, memory_per_node in alloc_per_node_by_key.items():
            alloc_memory_by_key[key] += sum(memory_per_node)

    return alloc_memory_by_key


def format_byte_size(nbytes: int) -> str:
    if nbytes < 1024:
        return f"{nbytes} B"
    elif nbytes < 1024**2:
        return f"{nbytes / 1024:.2f} KB"
    elif nbytes < 1024**3:
        return f"{nbytes / 1024**2:.2f} MB"
    else:
        return f"{nbytes / 1024**3:.2f} GB"
```
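A few worked values for the helpers above:

```python
assert align(100, 64) == 128            # rounds 100 up to the next multiple of 64
assert align_2MB(3 * 2**20) == 2**22    # 3 MiB rounds up to 4 MiB (2 MiB granularity)
assert format_byte_size(2**28) == "256.00 MB"
```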
```python
class RBLNDecoderOnlyFlashAttentionMixin:
    @classmethod
    def get_maximum_num_blocks_by_model(
        cls,
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
    ) -> int:
        tensor_parallel_size = rbln_config.tensor_parallel_size or 1
        available_dram = get_available_dram(rbln_config.npu) * tensor_parallel_size

        kernel_memory = cls._get_kernel_memory(model, model_config=model_config, rbln_config=rbln_config)
        buffer = cls._get_buffer(rbln_config)

        remaining_dram = available_dram - kernel_memory - buffer
        if remaining_dram <= 0:
            raise ValueError(
                "Insufficient available DRAM after accounting for kernel memory and buffer. "
                "Cannot allocate any KV cache blocks."
                f" (Available DRAM: {format_byte_size(available_dram)}, "
                f"Kernel Memory: {format_byte_size(kernel_memory)}, "
                f"Buffer: {format_byte_size(buffer)})"
            )
        estimated_num_blocks = cls._estimate_num_blocks(
            remaining_dram, model_config=model_config, rbln_config=rbln_config
        )

        return estimated_num_blocks

    @classmethod
    def _get_kernel_memory(
        cls,
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
    ) -> int:
        if model.get_output_embeddings() is None:
            lm_head_nbytes = 0
        else:
            lm_head_nbytes = cls._get_lm_head_memory(model_config, rbln_config)

        layer_nbytes = cls._get_layer_memory(model, model_config, rbln_config)
        return lm_head_nbytes + layer_nbytes

    @classmethod
    def _get_lm_head_memory(
        cls, model_config: "PretrainedConfig", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
    ) -> int:
        tensor_parallel_size = rbln_config.tensor_parallel_size or 1
        vocab_size = model_config.vocab_size
        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
        lm_head_params = align(vocab_size, 64) * hidden_size

        nbytes_per_param = 2  # Assuming lm_head is always not quantized
        lm_head_memory_in_bytes = (
            align_2MB(lm_head_params * nbytes_per_param / tensor_parallel_size) * tensor_parallel_size
        )

        return lm_head_memory_in_bytes

    @classmethod
    def _get_layer_memory(
        cls,
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
    ) -> int:
        # This is an *APPROXIMATE* calculation based on the number of parameters
        tensor_parallel_size = rbln_config.tensor_parallel_size or 1
        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")

        n_model_params = sum(p.numel() for p in model.parameters())
        embed_token_params = sum(p.numel() for p in model.get_input_embeddings().parameters())

        # Check : `embed_token` is same as `lm_head`
        if model.get_output_embeddings() is not None:
            params = n_model_params - 2 * embed_token_params
        else:
            params = n_model_params - embed_token_params

        # Assuming all layers have the same number of parameters
        # and all linear layers are quantized if quantization is enabled (This is not always true)
        # TODO(jongho): More accurate calculation
        nbits_per_param = rbln_config.nbits_per_param
        layer_nbytes = (
            (align_2MB(params // num_hidden_layers * nbits_per_param // 8 / tensor_parallel_size))
            * num_hidden_layers
            * tensor_parallel_size
        )

        return layer_nbytes

    @classmethod
    def _get_buffer(cls, rbln_config) -> int:
        # TODO(jongho): Accurate buffer estimation
        buffer_per_runtime_per_core = 2**28  # 256MB per runtime
        num_runtimes = 1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes)
        tensor_parallel_size = rbln_config.tensor_parallel_size or 1

        buffer_per_core = buffer_per_runtime_per_core * num_runtimes
        buffer = buffer_per_core * tensor_parallel_size
        return buffer
```
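Plugging hypothetical Llama-7B-like numbers into `_get_lm_head_memory` (vocab_size=32_000, hidden_size=4_096, fp16 weights, tensor_parallel_size=4) makes the estimate concrete:

```python
# 32_000 is already a multiple of 64, so align() leaves it unchanged.
lm_head_params = align(32_000, 64) * 4_096      # 131_072_000 parameters
per_device = align_2MB(lm_head_params * 2 / 4)  # 65_536_000 B rounds up to 64 MiB
total = per_device * 4                          # 268_435_456 B across all devices
assert format_byte_size(total) == "256.00 MB"
```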
```python
    @classmethod
    def get_maximum_num_blocks_by_compiled_model(
        cls,
        compiled_models: Dict[str, "rebel.RBLNCompiledModel"],
        model_config: "PretrainedConfig",
        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
    ) -> int:
        tensor_parallel_size = rbln_config.tensor_parallel_size or 1
        available_dram = get_available_dram(rbln_config.npu) * tensor_parallel_size

        alloc_memory_by_key = get_alloc_memory_by_key(compiled_models)
        alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache key
        alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
        used_memory = sum(alloc_memory_by_key.values())

        remaining_dram = available_dram - used_memory

        if remaining_dram <= 0:
            logger.warning(
                "Insufficient available DRAM after accounting for kernel memory and buffer. "
                "Model cannot allocate any KV cache blocks."
            )

        estimated_num_blocks = cls._estimate_num_blocks(
            remaining_dram, model_config=model_config, rbln_config=rbln_config
        )

        return estimated_num_blocks

    @classmethod
    def _estimate_num_blocks(
        cls, available_dram: int, model_config: "PretrainedConfig", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
    ) -> int:
        """
        Estimate the maximum number of KV cache blocks that can be allocated.

        If all of the layers use full attention, the number of blocks can be calculated
        simply as follows:
            num_blocks = available_dram // dram_per_block

        However, if the model contains a mix of full attention and sliding window attention layers,
        we need to account for the memory occupied by the sliding window attention layers first,
        since their memory usage is constant regardless of the number of blocks:
            num_blocks = (available_dram - swa_kv_nbytes) // dram_per_block
        """

        def get_dram_per_block(seq_len: int, num_key_value_heads: int, tensor_parallel_size: int) -> int:
            nbytes_per_param = 2  # Assuming kv-cache is always not quantized
            dram_per_block = (
                seq_len
                * align(head_dim, 64)
                * math.ceil(num_key_value_heads / tensor_parallel_size)
                * nbytes_per_param
                * tensor_parallel_size
                * 2
            )  # *2 for key and value

            return dram_per_block

        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
        head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
        tensor_parallel_size = rbln_config.tensor_parallel_size or 1

        # Consider layer types if available.
        # If layer types are not found, assume all layers are full attention.
        layer_types = getattr(model_config, "layer_types", None)
        if layer_types:
            layer_types_dict = Counter(layer_types)
            num_full_attention = layer_types_dict.pop("full_attention", 0)
            num_sliding_window_attention = layer_types_dict.pop("sliding_attention", 0)
            if len(layer_types_dict) > 0:
                raise ValueError(f"Unknown layer types found in the config: {layer_types_dict.keys()}")
        else:
            num_full_attention = num_hidden_layers
            num_sliding_window_attention = 0

        # Reduce available DRAM by the sliding window attention kv-cache,
        # since the memory occupied by SWA layers is constant regardless of num_blocks.
        swa_kv_nbytes = 0
        if num_sliding_window_attention > 0:
            sliding_window = getattr(model_config, "sliding_window", None)
            if sliding_window is None:
                logger.warning(
                    "`sliding_window` is not found in the config while `sliding_attention` layers are present. "
                    "Assuming maximum sliding window size for estimation."
                )
                sliding_window = rbln_config.kvcache_block_size

            swa_kv_nbytes = num_sliding_window_attention * get_dram_per_block(
                seq_len=sliding_window,
                num_key_value_heads=num_key_value_heads,
                tensor_parallel_size=tensor_parallel_size,
            )

        available_dram -= swa_kv_nbytes

        dram_per_block = num_full_attention * get_dram_per_block(
            seq_len=rbln_config.kvcache_block_size,
            num_key_value_heads=num_key_value_heads,
            tensor_parallel_size=tensor_parallel_size,
        )

        if dram_per_block == 0:
            raise ValueError("DRAM per block is calculated as zero, cannot estimate maximum number of blocks.")

        max_n_blocks = available_dram // dram_per_block
        return max_n_blocks
```
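A worked example of that per-block arithmetic, with hypothetical Llama-7B-like geometry (32 full-attention layers, 32 KV heads, head_dim=128, fp16 cache, kvcache_block_size=16_384, tensor_parallel_size=4):

```python
# seq_len * aligned head_dim * heads-per-device * bytes * TP * (K and V)
per_layer = 16_384 * 128 * 8 * 2 * 4 * 2      # 268_435_456 B (256 MiB) per layer
dram_per_block = 32 * per_layer               # 8 GiB for one 16K-token block
assert (64 * 2**30) // dram_per_block == 8    # ~64 GiB of free DRAM -> 8 blocks
```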
```python
    @classmethod
    def maybe_suggest_kvcache_num_blocks(
        cls,
        compiled_models: Dict[str, "rebel.RBLNCompiledModel"],
        model_config: "PretrainedConfig",
        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
    ) -> None:
        max_num_blocks = cls.get_maximum_num_blocks_by_compiled_model(
            compiled_models=compiled_models,
            model_config=model_config,
            rbln_config=rbln_config,
        )

        # Since our estimation logic is not always accurate,
        # users can set `kvcache_num_blocks` to `max_num_blocks`.
        # If the memory is not enough, the model will fail to compile.
        if rbln_config.kvcache_num_blocks < max_num_blocks:
            logger.warning(
                f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
                "Our analysis indicates that additional memory is available for more blocks. "
                f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
                "Please be advised that our memory estimation algorithm has limitations, "
                "and increasing this value may not guarantee successful model compilation."
            )
```
optimum/rbln/transformers/modeling_generic.py
@@ -0,0 +1,280 @@

```python
# Copyright 2025 Rebellions Inc. All rights reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file defines generic base classes for various RBLN models,
such as Question Answering, Image Classification, Audio Classification,
Sequence Classification, and Masked Language Modeling. These classes
implement common functionalities and configurations to be used across
different model architectures.
"""

import inspect
from typing import TYPE_CHECKING, Optional, Union

from torch import nn
from transformers import (
    AutoModel,
    AutoModelForDepthEstimation,
    AutoModelForImageClassification,
    AutoModelForMaskedLM,
    AutoModelForQuestionAnswering,
    AutoModelForSequenceClassification,
    AutoModelForTextEncoding,
    PretrainedConfig,
)
from transformers.modeling_outputs import BaseModelOutput, QuestionAnsweringModelOutput

from ..configuration_utils import RBLNCompileConfig
from ..modeling import RBLNModel
from ..utils.logging import get_logger
from .configuration_generic import (
    RBLNImageModelConfig,
    RBLNTransformerEncoderConfig,
)


if TYPE_CHECKING:
    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel

logger = get_logger()


class RBLNTransformerEncoder(RBLNModel):
    auto_model_class = AutoModel
    rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
    rbln_dtype = "int64"

    @classmethod
    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNTransformerEncoderConfig) -> nn.Module:
        class TransformerEncoderWrapper(nn.Module):
            # Parameters to disable for RBLN compilation
            DISABLED_PARAMS = {"return_dict", "use_cache"}

            def __init__(self, model: "PreTrainedModel", rbln_config: RBLNTransformerEncoderConfig):
                super().__init__()
                self.model = model
                self.rbln_config = rbln_config
                self._forward_signature = inspect.signature(model.forward)

            def forward(self, *args, **kwargs):
                # Disable parameters that are not compatible with RBLN compilation
                for param_name in self.DISABLED_PARAMS:
                    if param_name in self._forward_signature.parameters:
                        kwargs[param_name] = False

                return self.model(*args, **kwargs)

        return TransformerEncoderWrapper(model, rbln_config).eval()

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
        model: Optional["PreTrainedModel"] = None,
        model_config: Optional["PretrainedConfig"] = None,
        rbln_config: Optional[RBLNTransformerEncoderConfig] = None,
    ) -> RBLNTransformerEncoderConfig:
        return cls.update_rbln_config_for_transformers_encoder(
            preprocessors=preprocessors,
            model=model,
            model_config=model_config,
            rbln_config=rbln_config,
        )

    @classmethod
    def update_rbln_config_for_transformers_encoder(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
        model: Optional["PreTrainedModel"] = None,
        model_config: Optional["PretrainedConfig"] = None,
        rbln_config: Optional[RBLNTransformerEncoderConfig] = None,
    ) -> RBLNTransformerEncoderConfig:
        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
            model_config, "max_position_embeddings", None
        )

        if rbln_config.max_seq_len is None:
            rbln_config.max_seq_len = max_position_embeddings
            if rbln_config.max_seq_len is None:
                for tokenizer in preprocessors:
                    if hasattr(tokenizer, "model_max_length"):
                        rbln_config.max_seq_len = tokenizer.model_max_length
                        break
                if rbln_config.max_seq_len is None:
                    raise ValueError("`max_seq_len` should be specified!")

        if max_position_embeddings is not None and rbln_config.max_seq_len > max_position_embeddings:
            raise ValueError("`max_seq_len` should be less than or equal to max_position_embeddings!")

        signature_params = inspect.signature(model.forward).parameters.keys()

        if rbln_config.model_input_names is None:
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_input_names"):
                    rbln_config.model_input_names = [
                        name for name in signature_params if name in tokenizer.model_input_names
                    ]

                    invalid_params = set(rbln_config.model_input_names) - set(signature_params)
                    if invalid_params:
                        raise ValueError(f"Invalid model input names: {invalid_params}")
                    break
            if rbln_config.model_input_names is None and cls.rbln_model_input_names is not None:
                rbln_config.model_input_names = cls.rbln_model_input_names

        else:
            invalid_params = set(rbln_config.model_input_names) - set(signature_params)
            if invalid_params:
                raise ValueError(f"Invalid model input names: {invalid_params}")
            rbln_config.model_input_names = [
                name for name in signature_params if name in rbln_config.model_input_names
            ]

        if rbln_config.model_input_names is None or len(rbln_config.model_input_names) == 0:
            raise ValueError(
                "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`. "
                "This is an internal error. Please report it to the developers."
            )

        if rbln_config.model_input_shapes is None:
            input_info = [
                (model_input_name, [rbln_config.batch_size, rbln_config.max_seq_len], cls.rbln_dtype)
                for model_input_name in rbln_config.model_input_names
            ]
        else:
            input_info = [
                (model_input_name, model_input_shape, cls.rbln_dtype)
                for model_input_name, model_input_shape in zip(
                    rbln_config.model_input_names, rbln_config.model_input_shapes
                )
            ]

        rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
        return rbln_config
```
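For a BERT-base-style tokenizer with batch_size=1 and max_seq_len=512, the resolved `input_info` would look roughly like this (hypothetical values, sketched for illustration rather than captured output):

```python
input_info = [
    ("input_ids", [1, 512], "int64"),
    ("attention_mask", [1, 512], "int64"),
    ("token_type_ids", [1, 512], "int64"),
]
```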
```python
class RBLNImageModel(RBLNModel):
    auto_model_class = AutoModel
    main_input_name = "pixel_values"
    output_class = BaseModelOutput

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
        model: Optional["PreTrainedModel"] = None,
        model_config: Optional["PretrainedConfig"] = None,
        rbln_config: Optional[RBLNImageModelConfig] = None,
    ) -> RBLNImageModelConfig:
        return cls.update_rbln_config_for_image_model(
            preprocessors=preprocessors,
            model=model,
            model_config=model_config,
            rbln_config=rbln_config,
        )

    @classmethod
    def update_rbln_config_for_image_model(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
        model: Optional["PreTrainedModel"] = None,
        model_config: Optional["PretrainedConfig"] = None,
        rbln_config: Optional[RBLNImageModelConfig] = None,
    ) -> RBLNImageModelConfig:
        if rbln_config.image_size is None:
            for processor in preprocessors:
                if hasattr(processor, "size"):
                    if all(required_key in processor.size.keys() for required_key in ["height", "width"]):
                        rbln_config.image_size = (processor.size["height"], processor.size["width"])
                    elif "shortest_edge" in processor.size.keys():
                        rbln_config.image_size = (processor.size["shortest_edge"], processor.size["shortest_edge"])
                    elif "longest_edge" in processor.size.keys():
                        rbln_config.image_size = (processor.size["longest_edge"], processor.size["longest_edge"])
                    break

            if rbln_config.image_size is None:
                rbln_config.image_size = model_config.image_size

            if rbln_config.image_size is None:
                raise ValueError("`image_size` should be specified!")

        input_info = [
            (
                cls.main_input_name,
                [rbln_config.batch_size, 3, rbln_config.image_height, rbln_config.image_width],
                "float32",
            )
        ]

        rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
        return rbln_config
```
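The `image_size` fallback above accepts the three `size` layouts Hugging Face image processors commonly emit; a sketch with hypothetical values:

```python
processor_sizes = [
    {"height": 224, "width": 224},  # -> image_size == (224, 224)
    {"shortest_edge": 256},         # -> image_size == (256, 256)
    {"longest_edge": 384},          # -> image_size == (384, 384)
]
```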
```python
class RBLNModelForQuestionAnswering(RBLNTransformerEncoder):
    auto_model_class = AutoModelForQuestionAnswering
    rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
    output_class = QuestionAnsweringModelOutput

    def _prepare_output(self, output, return_dict):
        # Prepare QuestionAnswering specific output format.
        start_logits, end_logits = output

        if not return_dict:
            return (start_logits, end_logits)
        else:
            return QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits)


class RBLNModelForSequenceClassification(RBLNTransformerEncoder):
    auto_model_class = AutoModelForSequenceClassification
    rbln_model_input_names = ["input_ids", "attention_mask"]


class RBLNModelForMaskedLM(RBLNTransformerEncoder):
    auto_model_class = AutoModelForMaskedLM
    rbln_model_input_names = ["input_ids", "attention_mask"]


class RBLNModelForTextEncoding(RBLNTransformerEncoder):
    auto_model_class = AutoModelForTextEncoding
    rbln_model_input_names = ["input_ids", "attention_mask"]


class RBLNTransformerEncoderForFeatureExtraction(RBLNTransformerEncoder):
    # TODO: RBLNModel is also for feature extraction.
    auto_model_class = AutoModel
    rbln_model_input_names = ["input_ids", "attention_mask"]


class RBLNModelForImageClassification(RBLNImageModel):
    auto_model_class = AutoModelForImageClassification


class RBLNModelForDepthEstimation(RBLNImageModel):
    auto_model_class = AutoModelForDepthEstimation

    @classmethod
    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNImageModelConfig):
        class ImageModelWrapper(nn.Module):
            def __init__(self, model: "PreTrainedModel", rbln_config: RBLNImageModelConfig):
                super().__init__()
                self.model = model
                self.rbln_config = rbln_config

            def forward(self, *args, **kwargs):
                output = self.model(*args, return_dict=True, **kwargs)
                return output.predicted_depth

        return ImageModelWrapper(model, rbln_config).eval()
```