optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,823 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import inspect
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
|
|
18
|
+
|
|
19
|
+
import rebel
|
|
20
|
+
import torch
|
|
21
|
+
from rebel.compile_context import CompileContext
|
|
22
|
+
from transformers import AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
|
|
23
|
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
|
24
|
+
from transformers.modeling_utils import no_init_weights
|
|
25
|
+
|
|
26
|
+
from ....configuration_utils import RBLNCompileConfig
|
|
27
|
+
from ....modeling import RBLNModel
|
|
28
|
+
from ....utils.logging import get_logger
|
|
29
|
+
from ...modeling_attention_utils import (
|
|
30
|
+
RBLNDecoderOnlyFlashAttentionMixin,
|
|
31
|
+
set_default_values,
|
|
32
|
+
validate_attention_method,
|
|
33
|
+
validate_sliding_window,
|
|
34
|
+
)
|
|
35
|
+
from ...modeling_outputs import RBLNDecoderOnlyOutput
|
|
36
|
+
from ...utils.rbln_quantization import get_quantized_model
|
|
37
|
+
from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
|
|
38
|
+
from .decoderonly_architecture import DecoderOnlyWrapper
|
|
39
|
+
from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
|
|
40
|
+
from .generation_decoderonly import RBLNDecoderOnlyGenerationMixin
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
logger = get_logger()
|
|
44
|
+
|
|
45
|
+
if TYPE_CHECKING:
|
|
46
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
|
|
50
|
+
"""
|
|
51
|
+
A base class for decoder-only transformer models outputting raw hidden-states without any specific head on top.
|
|
52
|
+
This class is used for RBLN-optimized models that are not causal language models.
|
|
53
|
+
This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
|
|
54
|
+
|
|
55
|
+
The class provides core functionality for:
|
|
56
|
+
|
|
57
|
+
1. Converting pre-trained transformer models to RBLN-optimized format
|
|
58
|
+
2. Handling the compilation process for RBLN devices
|
|
59
|
+
3. Managing inference operations for decoder-only architectures
|
|
60
|
+
This class inherits from RBLNModel and implements specific methods required for
|
|
61
|
+
decoder-only architectures.
|
|
62
|
+
|
|
63
|
+
Note:
|
|
64
|
+
- This class is designed to be subclassed by specific model implementations
|
|
65
|
+
(e.g., RBLNLlamaModel, RBLNQwen2Model)
|
|
66
|
+
- Subclasses should implement model-specific conversion logic.
|
|
67
|
+
- The class handles RBLN-specific optimizations automatically during compilation
|
|
68
|
+
"""
|
|
69
|
+
|
|
70
|
+
_tp_support = True
|
|
71
|
+
|
|
72
|
+
main_input_name = "input_ids"
|
|
73
|
+
auto_model_class = AutoModel
|
|
74
|
+
_decoder_wrapper_cls = DecoderOnlyWrapper
|
|
75
|
+
_use_rotary_emb = True
|
|
76
|
+
_supports_non_fp32 = True
|
|
77
|
+
|
|
78
|
+
def __post_init__(self, **kwargs):
    """Restore the host-side embedding table (when needed) and build the runtimes.

    If ``rbln_config.use_inputs_embeds`` is set, the ``"embed_tokens"`` state dict
    stored in ``<model_save_dir>/<subfolder>/torch_artifacts.pth`` (the file written
    by ``save_torch_artifacts``) is loaded into a fresh layer from
    ``_create_embedding_layer``. Otherwise ``self.embed_tokens`` is ``None``.
    ``setup_runtime`` is called in both cases.
    """
    if not self.rbln_config.use_inputs_embeds:
        self.embed_tokens = None
    else:
        artifact_path = self.model_save_dir / self.subfolder / "torch_artifacts.pth"
        # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted model dirs.
        saved_artifacts = torch.load(artifact_path, weights_only=False)
        self.embed_tokens = self._create_embedding_layer()
        self.embed_tokens.load_state_dict(saved_artifacts["embed_tokens"])

    self.setup_runtime()
|
|
87
|
+
|
|
88
|
+
def setup_runtime(self):
    """Create the prefill runtime and, if generation is enabled, the decode runtimes.

    ``self.model[0]`` is wrapped as ``self.prefill_decoder``. When
    ``can_generate()`` is true, ``self.model[1:]`` are paired in order with
    ``rbln_config.decoder_batch_sizes`` to fill ``self.decoders`` (batch size ->
    runtime). ``self.decoder`` is then set to the entry for
    ``rbln_config.batch_size``; a ``KeyError`` is raised if that size is not one
    of the decoder batch sizes. The same page-table manager and ``dec_attn_mask``
    objects are handed to every runtime.
    """
    cfg = self.rbln_config

    # Resources shared across Runtime instances (prefill and decode phases).
    page_table_manager = RBLNPageTableManager(cfg)
    dec_attn_mask = torch.zeros(cfg.batch_size, 1, 1, cfg.max_seq_len, dtype=self.dtype)
    # Only the prefill runtime receives a preallocated output buffer.
    out_buffers = [torch.empty(self.prefill_output_size, dtype=self.dtype)]

    common_kwargs = dict(
        main_input_name="inputs_embeds" if cfg.use_inputs_embeds else "input_ids",
        embed_tokens=self.embed_tokens,
        dec_attn_mask=dec_attn_mask,
        page_table_manager=page_table_manager,
        rbln_config=cfg,
    )

    self.prefill_decoder = RBLNRuntimeModel(
        runtime=self.model[0],
        phase="prefill",
        batch_size=cfg.batch_size,
        out_buffers=out_buffers,
        **common_kwargs,
    )

    if not self.can_generate():
        return

    self.decoders = {}
    # Runtime slots after index 0 follow the order of decoder_batch_sizes.
    for runtime_idx, dec_batch_size in enumerate(cfg.decoder_batch_sizes, start=1):
        self.decoders[dec_batch_size] = RBLNRuntimeModel(
            runtime=self.model[runtime_idx],
            phase="decode",
            batch_size=dec_batch_size,
            **common_kwargs,
        )

    # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
    self.decoder = self.decoders[cfg.batch_size]
|
|
120
|
+
|
|
121
|
+
@property
def prefill_output_size(self):
    """Shape of the prefill output buffer, ``(1, seq, config.hidden_size)``.

    ``seq`` is ``rbln_config.prefill_chunk_size`` when
    ``rbln_config.logits_to_keep == 0`` and ``1`` otherwise.
    """
    keep_whole_chunk = self.rbln_config.logits_to_keep == 0
    seq_len = self.rbln_config.prefill_chunk_size if keep_whole_chunk else 1
    return (1, seq_len, self.config.hidden_size)
|
|
128
|
+
|
|
129
|
+
@classmethod
def get_quantized_model(
    cls,
    model_id: str,
    config: Optional[PretrainedConfig] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    force_download: bool = False,
    cache_dir: Optional[str] = None,
    subfolder: str = "",
    local_files_only: bool = False,
    trust_remote_code: bool = False,
    rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
    **kwargs,
):
    """Load ``model_id`` through the RBLN quantization helper.

    Delegates to the module-level ``get_quantized_model`` with
    ``cls.auto_model_class``, the hub-download options, and
    ``rbln_config.quantization``. Extra ``kwargs`` go through
    ``cls.update_kwargs`` first and are then forwarded.

    Note:
        ``config``, ``subfolder`` and ``trust_remote_code`` are accepted but are
        not forwarded. ``rbln_config`` must not be ``None``, because its
        ``quantization`` attribute is read unconditionally.
    """
    forwarded_kwargs = cls.update_kwargs(kwargs)

    hub_options = dict(
        use_auth_token=use_auth_token,
        revision=revision,
        cache_dir=cache_dir,
        force_download=force_download,
        local_files_only=local_files_only,
    )
    # Inside this classmethod the bare name resolves to the imported helper, not to this method.
    return get_quantized_model(
        cls.auto_model_class,
        model_id,
        **hub_options,
        rbln_quantization=rbln_config.quantization,
        **forwarded_kwargs,
    )
|
|
157
|
+
|
|
158
|
+
def __getattr__(self, __name: str) -> Any:
    """Delegate failed attribute lookups to the original HuggingFace model class.

    Python calls this only when normal lookup on the instance fails. The
    attribute is looked up first on ``self.get_hf_class()`` and then on
    ``PreTrainedModel``. A value defined on the HF class is returned even if it
    is falsy (``False``, ``0``, ``None``, ...). If a callable's signature has a
    ``self`` parameter, it is returned wrapped so that this RBLN instance is
    passed as ``self``. Anything else, including callables without an
    introspectable signature, is returned unchanged.

    Raises:
        AttributeError: if neither class defines the attribute.
    """
    missing = object()

    val = getattr(self.get_hf_class(), __name, missing)
    if val is missing:
        # Raises AttributeError when absent here too, which is what __getattr__ must do.
        val = getattr(PreTrainedModel, __name)

    if callable(val):
        try:
            params = inspect.signature(val).parameters
        except (TypeError, ValueError):
            # Some builtins/C-implemented callables expose no signature; hand them back as-is.
            return val
        if "self" in params:

            def redirect(*pargs, **kwargs):
                # Bind this RBLN instance as `self` for the HF-defined method.
                return val(self, *pargs, **kwargs)

            return redirect
    return val
|
|
176
|
+
|
|
177
|
+
@classmethod
def save_torch_artifacts(
    cls,
    model: PreTrainedModel,
    save_dir_path: Path,
    subfolder: str,
    rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
):
    """Persist host-side (CPU) tensors alongside the compiled model.

    Does something only when ``rbln_config.use_inputs_embeds`` is set. In that
    case the state dict of ``model.get_input_embeddings()`` is saved under the
    key ``"embed_tokens"`` to ``<save_dir_path>/<subfolder>/torch_artifacts.pth``.
    Otherwise nothing is written.
    """
    # Anything that must run on a CPU rather than an RBLN device is stored here.
    if not rbln_config.use_inputs_embeds:
        return

    artifacts = {"embed_tokens": model.get_input_embeddings().state_dict()}
    torch.save(artifacts, save_dir_path / subfolder / "torch_artifacts.pth")
|
|
191
|
+
|
|
192
|
+
def _create_embedding_layer(self):
    """Build a token-embedding layer shaped from the HF config.

    The layer is ``torch.nn.Embedding(config.vocab_size, config.hidden_size)``
    with ``padding_idx=config.pad_token_id``. Callers are expected to load real
    weights into it afterwards via ``load_state_dict``.
    """
    hf_config = self.config
    # no_init_weights presumably suppresses random weight init, since the weights get overwritten anyway.
    with no_init_weights():
        layer = torch.nn.Embedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
            padding_idx=hf_config.pad_token_id,
        )
    return layer
|
|
200
|
+
|
|
201
|
+
def get_decoder(self):
    """Return the decode runtime that serves the main batch size.

    Raises:
        ValueError: if the model was not configured for generation.
    """
    if self.can_generate():
        return self.decoder
    raise ValueError("Decode stage is not supported in this model.")
|
|
205
|
+
|
|
206
|
+
def can_generate(self):
    """Report whether decode runtimes are available, as recorded in ``rbln_config.can_generate``."""
    config = self.rbln_config
    return config.can_generate
|
|
208
|
+
|
|
209
|
+
def get_input_embeddings(self):
    """Return the host-side embedding layer (``None`` when ``use_inputs_embeds`` is disabled)."""
    embeddings = self.embed_tokens
    return embeddings
|
|
211
|
+
|
|
212
|
+
def get_attn_impl(self) -> str:
    """Return the attention implementation name stored in ``rbln_config.attn_impl``."""
    config = self.rbln_config
    return config.attn_impl
|
|
214
|
+
|
|
215
|
+
def get_kvcache_num_blocks(self) -> int:
    """Return the number of KV-cache blocks configured in ``rbln_config.kvcache_num_blocks``."""
    config = self.rbln_config
    return config.kvcache_num_blocks
|
|
217
|
+
|
|
218
|
+
@classmethod
def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
    """Wrap ``model`` in ``cls._decoder_wrapper_cls`` and return the wrapper in eval mode.

    ``cls._use_rotary_emb`` is passed as the wrapper's third positional argument.
    """
    wrapper_cls = cls._decoder_wrapper_cls
    wrapper = wrapper_cls(model, rbln_config, cls._use_rotary_emb)
    return wrapper.eval()
|
|
221
|
+
|
|
222
|
+
@classmethod
def _compile_model(
    cls,
    wrapped_model,
    compile_config,
    example_inputs,
    compile_context,
    rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig",
    quantization=None,
    phase: str = "prefill",
):
    """Compile ``wrapped_model`` for a single phase ("prefill" or "decode").

    During compilation ``torch.nn.functional.linear`` is temporarily replaced by
    ``torch.ops.rbln_custom_ops.linear``. If ``quantization`` is given, its
    quantization environment is also set. Both are undone in ``finally``,
    whether or not compilation succeeds.

    Args:
        wrapped_model: Model wrapper; its ``phase`` attribute is set to ``phase``.
        compile_config: Compile config forwarded to ``cls.compile``.
        example_inputs: Example inputs forwarded to ``cls.compile``.
        compile_context: Compile context forwarded to ``cls.compile``.
        rbln_config: Supplies ``create_runtimes`` and ``device``.
        quantization: Optional object exposing ``maybe_set_quantization_env`` /
            ``maybe_reset_quantization_env``.
        phase: Phase name assigned to ``wrapped_model.phase``.

    Returns:
        Whatever ``cls.compile`` returns.
    """
    # Captured before `try` so the `finally` clause can always restore it.
    # Assigning it inside `try` would turn any earlier failure into an
    # UnboundLocalError that hides the real exception.
    original_linear = torch.nn.functional.linear
    try:
        wrapped_model.phase = phase
        if quantization:
            quantization.maybe_set_quantization_env()
        torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
        return cls.compile(
            wrapped_model,
            compile_config,
            create_runtimes=rbln_config.create_runtimes,
            device=rbln_config.device,
            example_inputs=example_inputs,
            compile_context=compile_context,
        )
    finally:
        torch.nn.functional.linear = original_linear
        if quantization:
            quantization.maybe_reset_quantization_env()
|
|
252
|
+
|
|
253
|
+
@classmethod
def _get_compile_context(
    cls,
    compile_config: RBLNCompileConfig,
    example_inputs: List[torch.Tensor],
):
    """Create a weight-sharing ``CompileContext`` and register the KV-cache inputs.

    Each input whose name (from ``compile_config.input_info``) contains
    ``"past_key_values"`` is registered with ``mark_static_address`` under the
    name ``kv_cache_<k>``. ``k`` counts only those matching inputs, in order.

    Returns:
        ``(context, static_tensors)``, where ``static_tensors`` maps each
        matching input name to its example tensor. ``get_compiled_model``
        reuses these tensors when building the decode inputs.
    """
    context = CompileContext(use_weight_sharing=True)

    # Mark static tensors (self kv states); the generator keeps zip/marking interleaved.
    kv_pairs = (
        (input_name, example)
        for (input_name, _, _), example in zip(compile_config.input_info, example_inputs)
        if "past_key_values" in input_name
    )
    static_tensors = {}
    for cache_idx, (input_name, example) in enumerate(kv_pairs):
        static_tensors[input_name] = example
        context.mark_static_address(example, f"kv_cache_{cache_idx}")

    return context, static_tensors
|
|
271
|
+
|
|
272
|
+
@classmethod
@torch.inference_mode()
def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
    """Compile the prefill graph and, if generation is enabled, one decode graph per decoder batch size.

    All graphs are compiled with the same compile context. Decode graphs receive the
    prefill graph's KV-cache example tensors as ``static_tensors``.

    Returns:
        Dict mapping ``"prefill"`` and ``"decoder_batch_<n>"`` to compiled models.
    """
    wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
    prefill_cfg = rbln_config.compile_cfgs[0]

    # KV-cache inputs are created as meta tensors for memory efficiency.
    kv_cache_names = [name for name, _, _ in prefill_cfg.input_info if "past_key_values" in name]
    prefill_inputs = prefill_cfg.get_dummy_inputs(fill=0, meta_tensor_names=kv_cache_names)
    context, static_tensors = cls._get_compile_context(prefill_cfg, prefill_inputs)

    compiled_models = {
        "prefill": cls._compile_model(
            wrapped_model,
            prefill_cfg,
            prefill_inputs,
            context,
            rbln_config,
            rbln_config.quantization,
            phase="prefill",
        )
    }

    if rbln_config.can_generate:
        wrapped_model.phase = "decode"
        decode_pairs = zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:])
        for batch_size, decode_cfg in decode_pairs:
            decode_inputs = decode_cfg.get_dummy_inputs(fill=0, static_tensors=static_tensors)
            compiled_models[f"decoder_batch_{batch_size}"] = cls._compile_model(
                wrapped_model,
                decode_cfg,
                decode_inputs,
                context,
                rbln_config,
                rbln_config.quantization,
                phase="decode",
            )

    # When fewer blocks than a full-length cache for every sequence were configured,
    # defer to maybe_suggest_kvcache_num_blocks (presumably checks whether memory would
    # allow additional blocks).
    full_cache_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
    if rbln_config.kvcache_num_blocks < full_cache_blocks:
        cls.maybe_suggest_kvcache_num_blocks(
            compiled_models=compiled_models,
            model_config=model.config,
            rbln_config=rbln_config,
        )

    return compiled_models
|
|
319
|
+
|
|
320
|
+
@classmethod
def get_pytorch_model(
    cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
) -> PreTrainedModel:
    """Load the source PyTorch model.

    If ``rbln_config`` carries a quantization config, the model is loaded with
    ``cls.get_quantized_model``. Otherwise loading is delegated to the parent class's
    ``get_pytorch_model``.
    """
    needs_quantized_load = bool(rbln_config and rbln_config.quantization)
    if not needs_quantized_load:
        return super().get_pytorch_model(*args, **kwargs)
    return cls.get_quantized_model(*args, rbln_config=rbln_config, **kwargs)
|
|
330
|
+
|
|
331
|
+
@classmethod
def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
    """Return whether a compiled graph takes a ``query_position`` input.

    In this base implementation the answer is just ``use_local_attention``.
    ``is_prefill`` is accepted so that overrides can make the decision phase-dependent.
    """
    needs_query_position = use_local_attention
    return needs_query_position
|
|
334
|
+
|
|
335
|
+
@classmethod
def get_input_info(
    cls,
    batch_size: int,
    query_length: int,
    rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
    model_config: PretrainedConfig,
):
    """Describe the inputs of one compiled graph as a list of ``(name, shape, dtype)`` tuples.

    The same routine builds the prefill graph and the decode graphs. A graph is treated
    as prefill when ``query_length > 1``. The ``use_*`` flags on ``rbln_config`` decide
    which optional inputs are present. The list always ends with two
    ``past_key_values_<i>`` entries per hidden layer.

    Args:
        batch_size: Batch dimension of this graph.
        query_length: Number of query tokens per call of this graph.
        rbln_config: RBLN configuration (feature flags, cache sizes, dtypes).
        model_config: Hugging Face model config, read for head/layer counts and hidden size.

    Returns:
        List of ``(name, shape, dtype)`` tuples.
    """
    # Accept both GPT-2-style names (n_head / n_layer / n_embd) and the standard ones.
    num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
    # Without num_key_value_heads, fall back to one KV head per attention head.
    num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
    num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
    hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
    # An explicit head_dim takes precedence over hidden_size // num_attention_heads.
    head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
    # Multi-token queries are treated as prefill, single-token queries as decode.
    is_prefill = query_length > 1

    # NOTE(review): entries are paired positionally with example tensors (see the zip in
    # _get_compile_context), so the append order below must stay stable.
    input_info = []
    if rbln_config.use_inputs_embeds:
        input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.torch_dtype))
    else:
        input_info.append(("input_ids", [batch_size, query_length], "int64"))

    input_info.append(("cache_position", [batch_size, query_length], "int32"))

    if rbln_config.use_global_attention:
        max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
        # Prefill graphs (compiled with batch_size=1 by _update_rbln_config) take a 1-D
        # block table; decode graphs take one row of block ids per batch entry.
        input_info.append(
            ("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")
        )
    if rbln_config.use_local_attention:
        input_info.append(("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16"))

    # Presence of the scalar query_position input is delegated to the overridable hook.
    if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
        input_info.append(("query_position", [], "int16"))

    if rbln_config.use_attention_mask:
        # With position_ids the mask is 2-D (batch, max_seq_len); otherwise it is
        # 4-D (batch, 1, query_length, max_seq_len).
        if rbln_config.use_position_ids:
            input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.torch_dtype))
        else:
            input_info.append(
                ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.torch_dtype)
            )

    if rbln_config.use_position_ids:
        input_info.append(("position_ids", [batch_size, query_length], "int32"))

    if rbln_config.use_lora:
        # One adapter id per batch entry.
        input_info.append(("lora_int_ids", [batch_size], "int32"))

    # KV caches use the model dtype unless fp8 KV-cache quantization is configured.
    kvcache_dtype = rbln_config.torch_dtype
    if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
        kvcache_dtype = "float8_e4m3fn"

    # Global-attention layers: kvcache_num_blocks blocks of kvcache_block_size tokens each.
    global_kvcache_shape = [
        rbln_config.kvcache_num_blocks,
        num_key_value_heads,
        rbln_config.kvcache_block_size,
        head_dim,
    ]
    # Sliding-window layers: a window of `sliding_window` tokens per sequence. Sized by
    # rbln_config.batch_size (not the batch_size argument), so every graph gets the same shape.
    local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
    # Two cache tensors per hidden layer (presumably key then value); entry i belongs to layer i // 2.
    input_info.extend(
        [
            (
                f"past_key_values_{i}",
                local_kvcache_shape
                if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
                else global_kvcache_shape,
                kvcache_dtype,
            )
            for i in range(num_hidden_layers * 2)
        ]
    )

    return input_info
|
|
408
|
+
|
|
409
|
+
@classmethod
def _update_sliding_window_config(
    cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
):
    """Update the sliding window configuration for the RBLN model.

    This method must be implemented by subclasses to handle their specific sliding window
    configurations, as Hugging Face models use different configuration keys to represent
    sliding window layers.

    Args:
        model_config (PretrainedConfig): The model configuration from Hugging Face.
        rbln_config (RBLNDecoderOnlyModelForCausalLMConfig): The RBLN model configuration.

    Notes:
        Required configuration settings:

        - ``cache_impl``: Must be one of:
            - "static": All layers use global attention (no sliding window)
            - "sliding_window": All layers use sliding window attention
            - "hybrid": A mix of global and sliding window attention layers
        - ``sliding_window``: Width of the sliding window (required if cache_impl is
          "sliding_window" or "hybrid")
        - ``sliding_window_layers``: List of layer indices using sliding window attention
          (required if cache_impl is "hybrid")

    Example implementation for a 'sliding_window' model:

    ```python
    rbln_config.cache_impl = "sliding_window"
    rbln_config.sliding_window = model_config.sliding_window
    rbln_config.sliding_window_layers = [i for i in range(model_config.num_hidden_layers)]
    return rbln_config
    ```

    Returns:
        RBLNDecoderOnlyModelForCausalLMConfig: The updated RBLN model configuration.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    # The documentation above used to live in `#` comments, so the "See method docstring"
    # hint in this error pointed at a docstring that did not exist.
    raise NotImplementedError(
        "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
        "See method docstring for required configuration details."
    )
|
|
446
|
+
|
|
447
|
+
@classmethod
def _update_attention_config(
    cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
):
    """Resolve attention defaults and choose ``rbln_config.kvcache_num_blocks``.

    First, ``attn_impl``, ``kvcache_partition_len`` and ``kvcache_block_size`` are filled
    in by ``set_default_values`` and checked by ``validate_attention_method``.

    Then the KV-cache block count is chosen. ``num_full_blocks`` is the number of blocks
    needed for every sequence in the batch to hold ``max_seq_len`` tokens.

    - ``flash_attn``: if unset and the estimate from ``get_maximum_num_blocks_by_model``
      is at least ``num_full_blocks``, use ``num_full_blocks``. If the estimate is below
      that, use the larger of the estimate and a lower bound of
      ``min(max_seq_len // kvcache_block_size + 1, num_full_blocks)``. A user-set value
      above the estimate only triggers a warning.
    - Other implementations: if unset, use ``num_full_blocks``. A user-set value above it
      only triggers a warning.

    The chosen block count is logged at the end.

    Returns:
        The updated ``rbln_config``.

    Raises:
        RuntimeError: Under ``flash_attn`` with an automatically reduced block count that is
            smaller than ``batch_size``.
    """
    rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
        attn_impl=rbln_config.attn_impl,
        kvcache_partition_len=rbln_config.kvcache_partition_len,
        kvcache_block_size=rbln_config.kvcache_block_size,
        max_seq_len=rbln_config.max_seq_len,
    )

    validate_attention_method(
        attn_impl=rbln_config.attn_impl,
        kvcache_partition_len=rbln_config.kvcache_partition_len,
        kvcache_block_size=rbln_config.kvcache_block_size,
        max_seq_len=rbln_config.max_seq_len,
    )

    num_full_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size

    # Update kvcache_num_blocks based on the attention implementation.
    if rbln_config.attn_impl == "flash_attn":
        estimated_max_num_blocks = cls.get_maximum_num_blocks_by_model(
            model=model, model_config=model_config, rbln_config=rbln_config
        )

        if rbln_config.kvcache_num_blocks is None:
            if estimated_max_num_blocks < num_full_blocks:
                # Lower bound of the number of blocks for flash attention.
                min_blocks_for_flash = min(
                    rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1, num_full_blocks
                )
                if min_blocks_for_flash > estimated_max_num_blocks:
                    # NOTE: Just try to compile with the lower bound of blocks for flash attention,
                    # even if it's larger than the estimated maximum number of blocks.
                    rbln_config.kvcache_num_blocks = min_blocks_for_flash
                else:
                    # (A log call used to sit here *before* this assignment and therefore
                    # always reported None; the final log below reports the real value.)
                    rbln_config.kvcache_num_blocks = estimated_max_num_blocks

                if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
                    raise RuntimeError(
                        f"Batch size ({rbln_config.batch_size}) exceeds num_blocks ({rbln_config.kvcache_num_blocks}). "
                        "Ensure the number of blocks is at least equal to the batch size."
                    )
            else:
                rbln_config.kvcache_num_blocks = num_full_blocks
        elif rbln_config.kvcache_num_blocks > estimated_max_num_blocks:
            logger.warning(
                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
                f" than the estimated maximum number of blocks ({estimated_max_num_blocks}). "
                "This can cause a failure during model compilation."
            )
    else:
        if rbln_config.kvcache_num_blocks is None:
            rbln_config.kvcache_num_blocks = num_full_blocks
        elif rbln_config.kvcache_num_blocks > num_full_blocks:
            logger.warning(
                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
                f" than the required number of blocks ({num_full_blocks}). "
                "This can cause a failure during model compilation."
            )

    logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")

    return rbln_config
|
|
512
|
+
|
|
513
|
+
@classmethod
def _update_rbln_config(
    cls,
    preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
    model: Optional[PreTrainedModel] = None,
    model_config: Optional[PretrainedConfig] = None,
    rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
) -> RBLNDecoderOnlyModelForCausalLMConfig:
    """Finalize ``rbln_config`` and attach its compile configurations.

    Steps:

    1. Infer ``max_seq_len`` from the model config when it is unset.
    2. Apply sliding-window settings when the model config enables them.
    3. Resolve the attention settings.
    4. Register the compile configs: one prefill config (batch size 1,
       ``prefill_chunk_size`` query tokens) and, when ``can_generate`` is set, one decode
       config (query length 1) per entry of ``decoder_batch_sizes``.

    Raises:
        ValueError: If ``max_seq_len`` is unset and cannot be inferred from ``model_config``.
    """
    if rbln_config.max_seq_len is None:
        # Prefer the standard attribute name, then the GPT-2-style one.
        inferred_len = getattr(model_config, "max_position_embeddings", None)
        rbln_config.max_seq_len = inferred_len or getattr(model_config, "n_positions", None)
        if rbln_config.max_seq_len is None:
            raise ValueError("`max_seq_len` should be specified.")

    model_uses_sliding_window = getattr(model_config, "sliding_window", None) is not None and getattr(
        model_config, "use_sliding_window", True
    )
    if model_uses_sliding_window:
        rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
        if rbln_config.sliding_window is not None:
            validate_sliding_window(rbln_config)

    rbln_config = cls._update_attention_config(model, model_config, rbln_config)

    compile_cfgs = [
        RBLNCompileConfig(
            compiled_model_name="prefill",
            input_info=cls.get_input_info(
                batch_size=1,
                query_length=rbln_config.prefill_chunk_size,
                rbln_config=rbln_config,
                model_config=model_config,
            ),
        )
    ]

    if rbln_config.can_generate:
        compile_cfgs.extend(
            RBLNCompileConfig(
                compiled_model_name=f"decoder_batch_{decoder_batch}",
                input_info=cls.get_input_info(
                    batch_size=decoder_batch,
                    query_length=1,
                    rbln_config=rbln_config,
                    model_config=model_config,
                ),
            )
            for decoder_batch in rbln_config.decoder_batch_sizes
        )

    rbln_config.set_compile_cfgs(compile_cfgs)
    return rbln_config
|
|
561
|
+
|
|
562
|
+
@classmethod
def _create_runtimes(
    cls,
    compiled_models: List[rebel.RBLNCompiledModel],
    rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
) -> List[rebel.Runtime]:
    """Create one ``rebel.Runtime`` per compiled graph.

    ``compiled_models`` is indexed in compile order: the prefill graph first, then (when
    ``can_generate`` is set) one decode graph per entry of ``decoder_batch_sizes``. Each
    runtime is placed on the device that ``rbln_config.device_map`` records for its graph
    name. If any graph name is missing from the map,
    ``cls._raise_missing_compiled_file_error`` is called.
    """
    graph_names = ["prefill"]
    if rbln_config.can_generate:
        graph_names += [f"decoder_batch_{decoder_batch}" for decoder_batch in rbln_config.decoder_batch_sizes]

    if any(name not in rbln_config.device_map for name in graph_names):
        cls._raise_missing_compiled_file_error(graph_names)

    return [
        rebel.Runtime(
            compiled_models[graph_idx],
            tensor_type="pt",
            device=rbln_config.device_map[name],
            activate_profiler=rbln_config.activate_profiler,
            timeout=rbln_config.timeout,
        )
        for graph_idx, name in enumerate(graph_names)
    ]
|
|
599
|
+
|
|
600
|
+
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    **kwargs,
) -> BaseModelOutputWithPast:
    """Run prefill for each sequence of the batch and return the concatenated hidden states.

    Each sequence is passed to ``self.prefill_decoder`` on its own, with its own
    ``batch_idx``. The ``.logits`` of each call are concatenated along dim 0.

    Args:
        input_ids (torch.LongTensor, optional): The input IDs to the model.
        inputs_embeds (torch.Tensor, optional): The input embeddings; used instead of
            ``input_ids`` when given.
        attention_mask (torch.LongTensor, optional): The attention mask. When given, its
            per-row sum is the number of cache positions generated for that row.
        kwargs (dict[str, Any], optional): Additional keyword arguments; only
            ``position_embed`` is read.

    Returns:
        ``BaseModelOutputWithPast`` whose ``last_hidden_state`` holds the concatenated outputs.

    Raises:
        ValueError: If the input batch size differs from the compiled ``batch_size``.
    """
    inputs = input_ids if inputs_embeds is None else inputs_embeds
    n_inputs = inputs.shape[0]
    position_embed = kwargs.get("position_embed", None)

    if n_inputs != self.rbln_config.batch_size:
        raise ValueError(
            f"Batch size ({n_inputs}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
        )

    per_sequence_outputs = []
    for seq_idx in range(self.rbln_config.batch_size):
        if attention_mask is None:
            valid_len = inputs.shape[1]
        else:
            valid_len = attention_mask[seq_idx].sum(dim=-1).int().item()
        seq_cache_position = torch.arange(valid_len, dtype=torch.int32).unsqueeze(0)
        seq_slice = slice(seq_idx, seq_idx + 1)

        prefill_out = self.prefill_decoder(
            inputs[seq_slice],
            attention_mask=None if attention_mask is None else attention_mask[seq_idx],
            position_embed=None if position_embed is None else position_embed[seq_slice],
            cache_position=seq_cache_position,
            batch_idx=seq_idx,
        )
        per_sequence_outputs.append(prefill_out.logits)

    return BaseModelOutputWithPast(last_hidden_state=torch.concat(per_sequence_outputs, dim=0))
|
|
644
|
+
|
|
645
|
+
|
|
646
|
+
class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
|
|
647
|
+
"""
|
|
648
|
+
A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
|
|
649
|
+
This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
|
|
650
|
+
|
|
651
|
+
The class provides core functionality for:
|
|
652
|
+
|
|
653
|
+
1. Converting pre-trained transformer models to RBLN-optimized format
|
|
654
|
+
2. Handling the compilation process for RBLN devices
|
|
655
|
+
3. Managing inference operations for causal language modeling
|
|
656
|
+
This class inherits from RBLNDecoderOnlyModel and RBLNDecoderOnlyGenerationMixin and implements specific methods required for
|
|
657
|
+
decoder-only architectures and causal language modeling tasks.
|
|
658
|
+
|
|
659
|
+
Note:
|
|
660
|
+
- This class is designed to be subclassed by specific model implementations
|
|
661
|
+
(e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
|
|
662
|
+
- Subclasses should implement model-specific conversion logic.
|
|
663
|
+
- The class handles RBLN-specific optimizations automatically during compilation
|
|
664
|
+
"""
|
|
665
|
+
|
|
666
|
+
# Transformers auto class associated with this model family; presumably used by the RBLN
# base class to instantiate the source PyTorch model — confirm in the base model class.
auto_model_class = AutoModelForCausalLM
|
|
667
|
+
|
|
668
|
+
@property
def prefill_output_size(self):
    """Shape of the prefill graph's logits output.

    ``(1, prefill_chunk_size, vocab_size)`` when ``rbln_config.logits_to_keep == 0``,
    otherwise ``(1, 1, vocab_size)``.
    """
    keep_every_position = self.rbln_config.logits_to_keep == 0
    seq_dim = self.rbln_config.prefill_chunk_size if keep_every_position else 1
    return 1, seq_dim, self.config.vocab_size
|
|
675
|
+
|
|
676
|
+
@classmethod
def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
    """Return whether a compiled graph takes a ``query_position`` input.

    For causal LM the prefill graph always takes it and decode graphs never do.
    ``use_local_attention`` does not influence the result.
    """
    needs_query_position = is_prefill
    return needs_query_position
|
|
679
|
+
|
|
680
|
+
def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
    """Store LoRA adapter ids on the model and on every runtime wrapper.

    An ``int`` or a ``list`` of ints is converted to an int32 tensor. Any other value,
    including a tensor or ``None``, is stored as is. The value is assigned to
    ``self.lora_int_ids`` and ``self.prefill_decoder``. When ``can_generate`` is set, it
    is also assigned to each decoder in ``self.decoders``.
    """
    if isinstance(lora_int_ids, int):
        lora_int_ids = [lora_int_ids]
    if isinstance(lora_int_ids, list):
        lora_int_ids = torch.tensor(lora_int_ids, dtype=torch.int32)

    self.lora_int_ids = lora_int_ids

    self.prefill_decoder.lora_int_ids = lora_int_ids
    if not self.rbln_config.can_generate:
        return
    for decoder_batch in self.rbln_config.decoder_batch_sizes:
        self.decoders[decoder_batch].lora_int_ids = lora_int_ids
|
|
692
|
+
|
|
693
|
+
def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
    """Activate LoRA adapter(s) by name.

    Names are resolved to ids through ``self.rbln_config.lora_config.adapters``
    (``lora_name`` -> ``lora_int_id``). The ids are then passed to ``set_lora_int_ids``
    as an int32 tensor.

    Args:
        adapter_name (Union[str, List[str]]): A single adapter name or a list of adapter names.

    Raises:
        ValueError: If the model is not configured with LoRA or if an adapter name is not found.
    """
    lora_config = getattr(self.rbln_config, "lora_config", None)
    if lora_config is None:
        raise ValueError("Model is not configured with LoRA. Cannot set adapter.")

    requested = [adapter_name] if isinstance(adapter_name, str) else adapter_name

    name_to_id = {adapter.lora_name: adapter.lora_int_id for adapter in lora_config.adapters}
    unknown = [name for name in requested if name not in name_to_id]
    if unknown:
        raise ValueError(f"Adapter(s) {unknown} not found. Available adapters: {list(name_to_id.keys())}")

    resolved_ids = [name_to_id[name] for name in requested]
    self.set_lora_int_ids(torch.tensor(resolved_ids, dtype=torch.int32))
|
|
726
|
+
|
|
727
|
+
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    cache_position: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.LongTensor] = None,
    generate_idx: Optional[torch.Tensor] = None,
    padded_cache_lengths: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    lora_int_ids: Optional[torch.Tensor] = None,
    return_dict: Optional[bool] = None,
    **kwargs,
) -> Tuple[torch.FloatTensor]:
    """Run one prefill or decode step; designed for integration with the HuggingFace generate API.

    * Prefill (``cache_position is None``): for continuous batching, each sequence is sent
      through ``self.prefill_decoder`` on its own, and the KV cache is updated using
      ``batch_idx``. A for-loop keeps this in sync with the generate API.
    * Decode (``cache_position`` given): the whole batch goes through the decoder runtime
      registered for that batch size in ``self.decoders``.

    Args:
        input_ids: Token ids; ignored when ``inputs_embeds`` is given.
        inputs_embeds: Input embeddings; take precedence over ``input_ids``.
        cache_position: Decode-step cache positions; ``None`` selects the prefill path.
        attention_mask: Per-row mask; during prefill, row ``b`` is passed to the prefill runtime.
        generate_idx: Per-sequence generated lengths. When omitted (plain ``forward`` use), it
            is derived as ``(batch, 1)`` from ``attention_mask`` or from the input length.
        padded_cache_lengths: Per-sequence padding reported by prefill; updated in place during
            prefill. Defaults to zeros when omitted.
        position_ids: Forwarded to the decoder only when ``rbln_config.use_position_ids`` is set.
        token_type_ids: Forwarded per sequence to the prefill runtime.
        lora_int_ids: Adapter id per sequence. Falls back to ``self.lora_int_ids`` when LoRA is enabled.
        return_dict: When truthy, return an ``RBLNDecoderOnlyOutput``; otherwise a tuple.

    Returns:
        ``(logits, generate_idx, padded_cache_lengths)``, or an ``RBLNDecoderOnlyOutput`` with
        the same fields when ``return_dict`` is truthy.

    Raises:
        ValueError: In any of these cases:
            - neither ``input_ids`` nor ``inputs_embeds`` is given;
            - LoRA is enabled but no adapter ids are available;
            - the batch is larger than the compiled batch size;
            - the input is longer than ``max_seq_len``;
            - no decoder exists for the batch size;
            - a cache position reaches ``max_seq_len``.
    """
    if self.rbln_config.use_lora and lora_int_ids is None:
        if self.lora_int_ids is None:
            raise ValueError(
                "lora_int_id is required when using LoRA. "
                "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
            )
        lora_int_ids = self.lora_int_ids

    inputs = inputs_embeds if inputs_embeds is not None else input_ids
    if inputs is None:
        raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")

    # For plain forward() use (outside generate), generate_idx is not supplied: derive it.
    if generate_idx is None:
        if attention_mask is not None:
            generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
        else:
            # Use `inputs` rather than `input_ids`: input_ids is None when only inputs_embeds is given.
            generate_idx = torch.full((inputs.shape[0], 1), inputs.shape[1], dtype=torch.int32)
        padded_cache_lengths = torch.zeros_like(generate_idx)

    batch_size = inputs.shape[0]

    if cache_position is None:
        # Prefill
        input_len = inputs.shape[1]
        if batch_size > self.rbln_config.batch_size:
            raise ValueError(
                f"Input's batch({batch_size}) exceeds compiled batch_size({self.rbln_config.batch_size})"
            )
        if input_len > self.rbln_config.max_seq_len:
            raise ValueError(
                f"Input's length({input_len}) exceeds compiled max_seq_len({self.rbln_config.max_seq_len})."
            )
        if padded_cache_lengths is None:
            # generate_idx was supplied without padded_cache_lengths; the loop below
            # accumulates into it, so start from zero padding.
            padded_cache_lengths = torch.zeros_like(generate_idx)

        logits = []
        for b_idx in range(batch_size):
            seq_cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
            output = self.prefill_decoder(
                input_ids=inputs[b_idx : b_idx + 1] if inputs_embeds is None else None,
                inputs_embeds=inputs[b_idx : b_idx + 1] if inputs_embeds is not None else None,
                attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
                cache_position=seq_cache_position,
                batch_idx=b_idx,
                token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
                lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
            )
            padded_cache_lengths[b_idx] += output.padded_cache_lengths
            logits.append(output.logits)
        logits = torch.cat(logits, dim=0)
    else:
        # Decoder
        if batch_size not in self.decoders:
            raise ValueError(
                f"No decoder runtime available for batch size {batch_size}. "
                f"Available batch sizes are: {list(self.decoders.keys())}. "
                f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
            )
        max_cache_position = int(cache_position.max().item())
        if max_cache_position >= self.rbln_config.max_seq_len:
            raise ValueError(
                f"Cache position exceeds the maximum sequence length.\n"
                f" - Current max cache position: {max_cache_position}\n"
                f" - Allowed max_seq_len: {self.rbln_config.max_seq_len}\n"
                f"Solution: Reduce the generation length by adjusting `max_new_tokens` "
                f"or `max_length` in the generation config."
            )

        logits = self.decoders[batch_size](
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids if self.rbln_config.use_position_ids else None,
            lora_int_ids=lora_int_ids,
        ).logits

    if not return_dict:
        return logits, generate_idx, padded_cache_lengths
    return RBLNDecoderOnlyOutput(
        logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
    )
|