optimum_rbln-0.9.3.post1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
@@ -0,0 +1,300 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, List, Literal, Optional, Union, get_args
+
+from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+from ...utils.rbln_quantization import RBLNQuantizationConfig
+from .configuration_lora import RBLNLoRAConfig
+
+
+logger = get_logger()
+
+CacheImplType = Literal["static", "sliding_window", "hybrid"]
+PhaseType = Literal["prefill", "image_prefill", "decode"]
+
+
+class RBLNDecoderOnlyModelConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLN decoder-only models.
+
+    This class extends RBLNModelConfig with parameters specific to decoder-only transformer
+    architectures optimized for RBLN devices. It controls aspects like attention implementation,
+    KV cache management, and batching for inference.
+    """
+
+    _default_phases = ["prefill"]
+    _default_logits_to_keep = 0
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        max_seq_len: Optional[int] = None,
+        use_inputs_embeds: Optional[bool] = None,
+        use_attention_mask: Optional[bool] = None,
+        use_position_ids: Optional[bool] = None,
+        attn_impl: Optional[str] = None,
+        kvcache_partition_len: Optional[int] = None,
+        kvcache_block_size: Optional[int] = None,
+        quantization: Optional[Union[Dict[str, Any], RBLNQuantizationConfig]] = None,
+        lora_config: Optional[Union[Dict[str, Any], RBLNLoRAConfig]] = None,
+        prefill_chunk_size: Optional[int] = None,
+        kvcache_num_blocks: Optional[int] = None,
+        decoder_batch_sizes: Optional[List[int]] = None,
+        cache_impl: Optional[CacheImplType] = None,
+        sliding_window: Optional[int] = None,
+        sliding_window_layers: Optional[List[int]] = None,
+        phases: Optional[List[PhaseType]] = None,
+        logits_to_keep: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Args:
+            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+            max_seq_len (Optional[int]): The maximum sequence length supported by the model.
+                If not provided, it attempts to infer from the model's configuration
+                (`max_position_embeddings` or `n_positions`). Must be specified if not available
+                in the model config.
+            use_inputs_embeds (Optional[bool]): Whether to use input embeddings (`inputs_embeds`)
+                directly instead of `input_ids`. Defaults to False. Requires the model to be
+                compiled with this option enabled.
+            use_attention_mask (Optional[bool]): Whether the model requires attention masks during
+                inference. This is typically determined based on the target device and model
+                architecture. Defaults are often set automatically based on the model and RBLN NPU.
+            use_position_ids (Optional[bool]): Whether to use position IDs. Defaults to False.
+            attn_impl (Optional[str]): Specifies the attention implementation to use.
+                See the "Attention Implementation (`attn_impl`)" section below for details.
+            kvcache_partition_len (Optional[int]): Defines the partition length for the KV cache
+                when using "flash_attn". See the "KV Cache Partition Length (`kvcache_partition_len`)"
+                section below for details.
+            kvcache_block_size (Optional[int]): Sets the size (in number of tokens) of each block
+                in the PagedAttention KV cache. See the "KV Cache Block Size (`kvcache_block_size`)"
+                section below for details.
+            quantization (Optional[Dict[str, Any]]): Configuration dictionary for applying model
+                quantization. Specifies format, etc.
+            lora_config (Optional[Union[Dict[str, Any], RBLNLoRAConfig]]): Configuration for LoRA
+                (Low-Rank Adaptation) settings when using (multi-)LoRA support. Can be provided as
+                a dictionary or an RBLNLoRAConfig instance. When provided, enables LoRA functionality
+                for the model compilation. Defaults to None (no LoRA).
+            prefill_chunk_size (Optional[int]): The chunk size used during the prefill phase for
+                processing input sequences. Defaults to 128. Must be a positive integer
+                divisible by 64. Affects prefill performance and memory usage.
+            kvcache_num_blocks (Optional[int]): The total number of blocks to allocate for the
+                PagedAttention KV cache. See the "KV Cache Number of Blocks (`kvcache_num_blocks`)"
+                section below for details.
+            decoder_batch_sizes (Optional[List[int]]): A list of batch sizes for which separate decoder models will be compiled.
+                This allows the model to handle varying batch sizes efficiently during generation. If not specified,
+                defaults to a list containing only the model's main batch size. When specifying multiple batch sizes:
+                1) All values must be less than or equal to the main batch size.
+                2) The list will be sorted in descending order (larger batch sizes first).
+                3) If using multiple decoders, at least one batch size should match the main batch size.
+            cache_impl (Optional[CacheImplType]): Specifies the KV cache implementation strategy. Defaults to "static".
+                - "static": Uses a fixed-size global KV cache for all layers, suitable for standard attention patterns.
+                - "sliding_window": Implements a sliding window KV cache, where each layer maintains a local cache of recent tokens.
+                - "hybrid": Combines both static and sliding window approaches, allowing different layers to use different cache strategies.
+                The choice affects memory usage and attention patterns. When using "sliding_window" or "hybrid",
+                you must specify the `sliding_window` size and optionally `sliding_window_layers` for hybrid mode.
+            sliding_window (Optional[int]): The size of the sliding window. Defaults to None.
+            sliding_window_layers (Optional[List[int]]): The layers to use for the sliding window used in the hybrid model. Defaults to None.
+            phases (Optional[List[PhaseType]]): The phases to compile the model for. Defaults to ["prefill"] if DecoderOnlyModel is used,
+                ["prefill", "decode"] if DecoderOnlyModelForCausalLM is used.
+            logits_to_keep (Optional[int]): The number of logits to keep for the decoder. If set to 0, the decoder will keep all logits.
+                Defaults to 0 if DecoderOnlyModel is used, 1 if DecoderOnlyModelForCausalLM is used.
+            kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+        Raises:
+            ValueError: If `batch_size` is not a positive integer.
+            ValueError: If `prefill_chunk_size` is not a positive integer divisible by 64.
+            ValueError: If `max_seq_len` cannot be determined and is required.
+            ValueError: If attention parameter constraints are violated (e.g., `max_seq_len` vs
+                `kvcache_partition_len` for flash attention).
+
+
+        Attention Implementation:
+            `attn_impl` determines the underlying attention mechanism used by the model.
+
+            - **`"eager"`** (Default if `kvcache_partition_len` is not set): Uses the standard PyTorch
+              attention implementation. Suitable for sequences up to a certain limit (e.g., 32,768 tokens).
+            - **`"flash_attn"`**: Utilizes an optimized Flash Attention implementation, beneficial for
+              longer sequences and potentially faster execution. Requires `max_seq_len` to be at least
+              8,192. If `kvcache_partition_len` is specified, `attn_impl` automatically defaults
+              to `"flash_attn"`. When using `"flash_attn"`, `kvcache_block_size` must equal
+              `kvcache_partition_len`.
+
+            The choice impacts performance and memory usage, especially for long sequences.
+            Constraints related to `max_seq_len` and `kvcache_partition_len` apply when using
+            `"flash_attn"`.
+
+
+        KV Cache Partition Length:
+            `kvcache_partition_len` is relevant **only** when `attn_impl` is `"flash_attn"`.
+
+            - It defines the length (number of tokens) of each partition within the Key-Value (KV) cache.
+            - Must be between 4,096 and 32,768 (inclusive).
+            - When using `"flash_attn"`, `max_seq_len` must be a multiple of `kvcache_partition_len`
+              and at least twice its value (`max_seq_len >= 2 * kvcache_partition_len`).
+            - If `attn_impl` is `"flash_attn"` and `kvcache_partition_len` is `None`, it defaults to
+              16,384.
+
+
+        KV Cache Number of Blocks:
+            `kvcache_num_blocks` controls the total number of memory blocks allocated for the PagedAttention KV cache.
+            Each block holds `kvcache_block_size` tokens of Key and Value states.
+
+            - **Automatic Estimation (Default)**: If `kvcache_num_blocks` is `None`, the system estimates
+              the maximum number of blocks that can fit into the available RBLN device memory. This
+              calculation considers the model size (kernel memory), required buffer memory, the number
+              of layers and heads, `kvcache_block_size`, tensor parallelism, and available RBLN NPU DRAM.
+              This aims to maximize cache capacity for potentially better performance with long sequences
+              or larger batches without manual tuning.
+            - **Manual Setting**: You can explicitly set the number of blocks. This provides finer control
+              but requires careful consideration of memory limits. Setting it too high may lead to
+              compilation errors if it exceeds available memory. The system will issue warnings if your
+              setting exceeds the estimated maximum.
+            - **Performance Impact**: A larger number of blocks reduces the likelihood of cache eviction,
+              which is beneficial for tasks involving many long sequences or large batch sizes, enabling
+              higher throughput. However, allocating more blocks consumes more memory.
+            - **Minimum Requirement**: The system requires a minimum number of blocks to function,
+              calculated based on `max_seq_len`, `kvcache_block_size`, and `batch_size`. The number of
+              allocated blocks must be sufficient to hold at least one full sequence length per item
+              in the batch concurrently. The system will log warnings or raise errors if constraints
+              are violated (e.g., if `kvcache_num_blocks` is less than `batch_size` when using Flash Attention).
+
+            The optimal value depends on the specific model, task, hardware, and desired trade-off
+            between performance and memory usage. The automatic estimation provides a robust starting point.
+        """
+
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        self.max_seq_len = max_seq_len
+        self.use_inputs_embeds = use_inputs_embeds or False
+        self.use_position_ids = use_position_ids or False
+        self.use_attention_mask = use_attention_mask or False
+
+        if self.use_position_ids and not self.use_attention_mask:
+            raise ValueError("Position IDs should be used with attention mask.")
+
+        self.quantization = quantization or {}
+        if self.quantization and isinstance(self.quantization, dict):
+            self.quantization = RBLNQuantizationConfig(**self.quantization)
+
+        self.lora_config = lora_config
+        if self.lora_config and isinstance(self.lora_config, dict):
+            self.lora_config = RBLNLoRAConfig(**self.lora_config)
+
+        # Validate LoRA adapters if LoRA is enabled
+        if self.lora_config is not None:
+            validation_results = self.lora_config.validate_adapter_weights()
+            failed_adapters = [adapter_id for adapter_id, is_valid in validation_results.items() if not is_valid]
+
+            if failed_adapters:
+                raise ValueError(
+                    f"Some LoRA adapters failed validation and may not be accessible at compile time: {failed_adapters}. "
+                    "Please ensure all adapter weights are available and properly formatted."
+                )
+
+            logger.info(
+                f"LoRA configuration initialized with {self.lora_config.num_adapters} adapters: "
+                f"{self.lora_config.adapter_ids}. Max rank: {self.lora_config.max_lora_rank}"
+            )
+
+        self.attn_impl = attn_impl
+        self.kvcache_partition_len = kvcache_partition_len
+        self.kvcache_block_size = kvcache_block_size
+        self.prefill_chunk_size = prefill_chunk_size or 128
+        if self.prefill_chunk_size % 64 != 0 or self.prefill_chunk_size <= 0:
+            raise ValueError("`prefill_chunk_size` must be a positive integer divisible by 64.")
+
+        self.kvcache_num_blocks = kvcache_num_blocks
+        self.cache_impl = cache_impl or "static"
+        self.sliding_window = sliding_window
+        self.sliding_window_layers = sliding_window_layers or []
+
+        if phases is not None:
+            self.validate_phases_type(phases)
+        self.phases = phases or self._default_phases
+        self.logits_to_keep = logits_to_keep or self._default_logits_to_keep
+        if self.logits_to_keep is not None and self.logits_to_keep > 1:
+            raise NotImplementedError("`logits_to_keep` > 1 is currently not supported for RBLN models.")
+
+        self.decoder_batch_sizes = None
+        if "decode" in self.phases:
+            self.decoder_batch_sizes = decoder_batch_sizes
+            if self.decoder_batch_sizes is None:
+                self.decoder_batch_sizes = [self.batch_size]
+
+            if self.use_multiple_decoder:
+                if max(self.decoder_batch_sizes) > self.batch_size:
+                    raise ValueError(
+                        f"Decoder batch size ({max(self.decoder_batch_sizes)}) must be less than or equal to the runtime batch size ({self.batch_size})."
+                    )
+                if max(self.decoder_batch_sizes) < self.batch_size:
+                    logger.warning(
+                        f"Maximum decoder batch size ({max(self.decoder_batch_sizes)}) is less than the model's batch size ({self.batch_size}). "
+                        "Appending the model's batch size to the decoder batch size."
+                    )
+                    self.decoder_batch_sizes.append(self.batch_size)
+
+            # Larger batch size should be at the beginning of the list.
+            self.decoder_batch_sizes.sort(reverse=True)
+
+    @staticmethod
+    def validate_phases_type(phases: List[PhaseType]):
+        if not isinstance(phases, list):
+            raise ValueError("`phases` must be a list.")
+        if not all(phase in get_args(PhaseType) for phase in phases):
+            raise ValueError(f"All elements in `phases` must be of type `PhaseType`({get_args(PhaseType)}).")
+
+    @property
+    def use_global_attention(self) -> bool:
+        return self.cache_impl in ["static", "hybrid"]
+
+    @property
+    def use_local_attention(self) -> bool:
+        return self.cache_impl in ["sliding_window", "hybrid"]
+
+    @property
+    def use_multiple_decoder(self) -> bool:
+        return isinstance(self.decoder_batch_sizes, list) and len(self.decoder_batch_sizes) > 1
+
+    @property
+    def use_lora(self):
+        return self.lora_config is not None
+
+    @property
+    def can_generate(self) -> bool:
+        return "decode" in self.phases
+
+    @property
+    def nbits_per_param(self) -> int:
+        if self.quantization:
+            return self.quantization.nbits_per_param
+        return 16
+
+
+class RBLNDecoderOnlyModelForCausalLMConfig(RBLNDecoderOnlyModelConfig):
+    """
+    Configuration class for RBLN decoder-only models for Causal Language Modeling.
+
+    This class extends RBLNModelConfig with parameters specific to decoder-only transformer
+    architectures optimized for RBLN devices. It controls aspects like attention implementation,
+    KV cache management, and batching for inference.
+    """
+
+    _default_phases = ["prefill", "decode"]
+    _default_logits_to_keep = 1