optimum-rbln 0.9.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +505 -0
- optimum/rbln/__version__.py +34 -0
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +968 -0
- optimum/rbln/diffusers/__init__.py +198 -0
- optimum/rbln/diffusers/configurations/__init__.py +37 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +451 -0
- optimum/rbln/diffusers/models/__init__.py +64 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
- optimum/rbln/diffusers/models/controlnet.py +281 -0
- optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
- optimum/rbln/diffusers/models/unets/__init__.py +16 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +113 -0
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
- optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +364 -0
- optimum/rbln/modeling_base.py +637 -0
- optimum/rbln/ops/__init__.py +19 -0
- optimum/rbln/ops/attn.py +455 -0
- optimum/rbln/ops/flash_attn.py +350 -0
- optimum/rbln/ops/kv_cache_update.py +29 -0
- optimum/rbln/ops/linear.py +32 -0
- optimum/rbln/ops/sliding_window_attn.py +111 -0
- optimum/rbln/transformers/__init__.py +340 -0
- optimum/rbln/transformers/configuration_generic.py +120 -0
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +280 -0
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/modeling_rope_utils.py +314 -0
- optimum/rbln/transformers/models/__init__.py +343 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
- optimum/rbln/transformers/models/auto/__init__.py +31 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
- optimum/rbln/transformers/models/bart/__init__.py +17 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
- optimum/rbln/transformers/models/bert/__init__.py +16 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
- optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
- optimum/rbln/transformers/models/clip/__init__.py +26 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
- optimum/rbln/transformers/models/colpali/__init__.py +2 -0
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
- optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
- optimum/rbln/transformers/models/dpt/__init__.py +16 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
- optimum/rbln/transformers/models/exaone/__init__.py +24 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
- optimum/rbln/transformers/models/gemma/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
- optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
- optimum/rbln/transformers/models/llama/__init__.py +16 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
- optimum/rbln/transformers/models/midm/__init__.py +24 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
- optimum/rbln/transformers/models/mistral/__init__.py +16 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
- optimum/rbln/transformers/models/opt/__init__.py +16 -0
- optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
- optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
- optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +16 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/__init__.py +23 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
- optimum/rbln/transformers/models/roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
- optimum/rbln/transformers/models/siglip/__init__.py +16 -0
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/__init__.py +17 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
- optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
- optimum/rbln/transformers/models/vit/__init__.py +19 -0
- optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
- optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
- optimum/rbln/transformers/models/whisper/__init__.py +17 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
- optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/__init__.py +16 -0
- optimum/rbln/utils/decorator_utils.py +86 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +94 -0
- optimum/rbln/utils/import_utils.py +170 -0
- optimum/rbln/utils/logging.py +110 -0
- optimum/rbln/utils/model_utils.py +63 -0
- optimum/rbln/utils/runtime_utils.py +249 -0
- optimum/rbln/utils/save_utils.py +102 -0
- optimum/rbln/utils/submodule.py +152 -0
- optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
- optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
- optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
- optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,1224 @@
|
|
|
1
|
+
# Copyright 2025 Rebellions Inc. All rights reserved.
|
|
2
|
+
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at:
|
|
6
|
+
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
import math
|
|
16
|
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch import nn
|
|
20
|
+
from transformers import PretrainedConfig, PreTrainedModel
|
|
21
|
+
|
|
22
|
+
from ....utils import logging
|
|
23
|
+
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
|
|
24
|
+
from ...utils.rbln_quantization import RBLNQuantizationConfig
|
|
25
|
+
from .configuration_lora import RBLNLoRAConfig
|
|
26
|
+
from .lora_architecture import LoRALinear
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from .configuration_decoderonly import RBLNDecoderOnlyModelConfig
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# Module-level logger obtained from the package's own logging utilities (imported above from ``....utils``).
logger = logging.get_logger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class DecoderOnlyWrapper(nn.Module):
    """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.

    This wrapper is designed to:
    1. Convert Huggingface decoder models for RBLN compilation with static shapes
    2. Handle input/model mapping and additional information supply (e.g., positional embeddings)
    3. Manage different attention implementations (standard/flash attention)
    4. Support both prefill and decode phases

    Notes:
    - Wrapper must only receive positional arguments in forward() due to torch.jit.trace dependency
    - Wrapper should not contain neural network graph operations (including memory view handling)

    Args:
        model (PreTrainedModel): The Huggingface causal language model to wrap
        rbln_config: The RBLN model configuration containing all necessary parameters
        use_rotary_emb (bool): Whether to use rotary position embeddings

    Raises:
        ValueError: If ``rbln_config.kvcache_partition_len`` is set and exceeds ``rbln_config.max_seq_len``.
    """

    # Forwarded to the RBLN model class as ``use_learned_pos_emb``; subclasses override it to opt in.
    _use_learned_pos_emb = False

    def __init__(self, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig", use_rotary_emb: bool):
        super().__init__()
        self.quantization = rbln_config.quantization
        self.config = model.config
        # A model is treated as a causal LM only when it exposes a non-None ``lm_head``.
        self.is_causal_lm = getattr(model, "lm_head", None) is not None
        self.rbln_config = rbln_config

        # Validate first so an invalid config fails before rotary tables are built and layers are converted.
        if rbln_config.kvcache_partition_len and rbln_config.kvcache_partition_len > rbln_config.max_seq_len:
            raise ValueError(
                f"kvcache_partition_len({rbln_config.kvcache_partition_len}) should be lower"
                f" or equal to max_seq_len({rbln_config.max_seq_len})!"
            )

        if use_rotary_emb:
            rotary_embs = self.get_rotary_emb(max_seq_len=rbln_config.max_seq_len)
            # An overriding ``get_rotary_emb`` may return a (global, local) pair, presumably for models that mix
            # global and local attention layers; ``prepare_forward_args`` forwards both in that case.
            if isinstance(rotary_embs, tuple):
                self.rotary_emb_global, self.rotary_emb_local = rotary_embs
            else:
                self.rotary_emb = rotary_embs
        else:
            self.rotary_emb = None

        self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
        # Some configs (e.g. GPT-2 style) name the layer count ``n_layer`` instead of ``num_hidden_layers``.
        self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
        self._phase = "prefill"

    def get_rotary_emb(self, max_seq_len):
        """Build the rotary embedding module covering ``max_seq_len`` positions."""
        return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)

    def get_decoder_layers(self, model: PreTrainedModel):
        """Return the decoder layer list (under ``model.model`` for causal LMs, else directly on ``model``)."""
        return model.model.layers if self.is_causal_lm else model.layers

    def get_attn_layer(self, layer: nn.Module):
        """Return the self-attention submodule of a decoder layer."""
        return layer.self_attn

    def get_model_layer(self, model: PreTrainedModel):
        """Return the base decoder module (``model.model`` for causal LMs, else ``model`` itself)."""
        return model.model if self.is_causal_lm else model

    def get_rbln_attn_class(self):
        """Return the RBLN attention class used to wrap each layer's self-attention."""
        return DecoderOnlyAttention

    def get_rbln_layer_class(self):
        """Return the RBLN decoder layer class."""
        return DecoderOnlyLayer

    def get_rbln_model_class(self):
        """Return the RBLN decoder model class that stacks the converted layers."""
        return DecoderOnlyModel

    def get_rbln_causal_lm_class(self):
        """Return the RBLN class that adds the LM head on top of the decoder model."""
        return DecoderOnlyForCausalLM

    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
        """Rebuild ``model`` out of the RBLN attention/layer/model classes.

        Each layer's self-attention is wrapped with ``get_rbln_attn_class()``. It is flagged as sliding-window
        when its index is in ``rbln_config.sliding_window_layers``. Each layer is then wrapped with
        ``get_rbln_layer_class()``, and the stack with ``get_rbln_model_class()``. Causal LMs are additionally
        wrapped with ``get_rbln_causal_lm_class()``.

        Args:
            model: The original Huggingface model.
            max_seq_len: Maximum sequence length. Unused by this base implementation (presumably kept for
                subclass overrides).

        Returns:
            The converted ``nn.Module``.
        """
        sliding_window_layers = self.rbln_config.sliding_window_layers
        attn_cls = self.get_rbln_attn_class()
        layer_cls = self.get_rbln_layer_class()

        new_layers = []
        for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
            new_self_attn = attn_cls(
                self.get_attn_layer(layer), self.rbln_config, is_sliding=layer_idx in sliding_window_layers
            )
            new_layers.append(layer_cls(layer, new_self_attn, lora_config=self.rbln_config.lora_config))

        new_model = self.get_rbln_model_class()(
            self.get_model_layer(model),
            new_layers,
            self.rbln_config,
            use_learned_pos_emb=self.__class__._use_learned_pos_emb,
        )

        if self.is_causal_lm:
            new_model = self.get_rbln_causal_lm_class()(model, new_model)
        return new_model

    @property
    def phase(self) -> str:
        """Current execution phase; setting it also propagates the value to the converted model."""
        return self._phase

    @phase.setter
    def phase(self, phase: str):
        self._phase = phase
        self.model.phase = phase

    def prepare_forward_args(self, *args):
        """Unpack the flat positional inputs of ``forward`` into named model inputs.

        Optional inputs are consumed only when the matching ``rbln_config`` flag (or phase) enables them.
        They are read in this order:
        ``input_ids`` or ``inputs_embeds``, ``cache_position``, [``global_block_tables``],
        [``local_block_tables``], [``query_position``], [``attention_mask``], [``position_ids``],
        [``lora_int_id``]. The remaining arguments are the KV cache as ``key_0, value_0, key_1, value_1, ...``.

        Returns:
            A tuple ``(input_ids, inputs_embeds, cache_position, global_block_tables, local_block_tables,
            query_position, attention_mask, position_ids, lora_int_id, past_key_values, rotary_emb)``.
            Disabled inputs are ``None``. ``past_key_values`` is a list of ``[key, value]`` lists, one per
            layer. ``rotary_emb`` is a single module, ``None``, or a ``(global, local)`` tuple.

        Raises:
            ValueError: If the number of KV-cache tensors is not ``2 * num_hidden_layers``.
        """
        args = list(args)
        # The pop order below defines the positional input layout and must not be reordered.
        input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
        inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
        cache_position = args.pop(0)
        global_block_tables = args.pop(0) if self.rbln_config.use_global_attention else None
        local_block_tables = args.pop(0) if self.rbln_config.use_local_attention else None
        # query_position usage: 1. causal_lm prefill or 2. sliding_window cache_position
        takes_query_position = "prefill" in self.phase and (
            self.is_causal_lm or self.rbln_config.use_local_attention
        )
        query_position = args.pop(0) if takes_query_position else None
        attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
        position_ids = args.pop(0) if self.rbln_config.use_position_ids else None
        lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
        past_key_values = args

        if len(past_key_values) != 2 * self.num_hidden_layers:
            raise ValueError(
                f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
            )

        # [key, value] * n_layer -> [[key, value]] * n_layer
        # Use self.num_hidden_layers (which honours the ``n_layer`` fallback) so grouping matches the check above.
        # Per the original note, each cache tensor is laid out as (batch, n_heads, 1, max_seq_len, head_dim).
        past_key_values = [
            [past_key_values[2 * i], past_key_values[2 * i + 1]] for i in range(self.num_hidden_layers)
        ]

        if hasattr(self, "rotary_emb_global") and hasattr(self, "rotary_emb_local"):
            rotary_emb = (self.rotary_emb_global, self.rotary_emb_local)
        else:
            rotary_emb = self.rotary_emb

        return (
            input_ids,
            inputs_embeds,
            cache_position,
            global_block_tables,
            local_block_tables,
            query_position,
            attention_mask,
            position_ids,
            lora_int_id,
            past_key_values,
            rotary_emb,
        )

    def forward(self, *args):
        """Run the converted model on flat positional inputs.

        ``args`` follow the layout described in ``prepare_forward_args``. They are positional-only because the
        wrapper is traced with ``torch.jit.trace``.

        Returns:
            The converted model's output (logits when the base causal-LM conversion is used).
        """
        (
            input_ids,
            inputs_embeds,
            cache_position,
            global_block_tables,
            local_block_tables,
            query_position,
            attention_mask,
            position_ids,
            lora_int_id,
            past_key_values,
            rotary_emb,
        ) = self.prepare_forward_args(*args)

        logit = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            position_ids=position_ids,
            query_position=query_position,
            past_key_values=past_key_values,
            rotary_emb=rotary_emb,
            global_block_tables=global_block_tables,
            local_block_tables=local_block_tables,
            lora_int_id=lora_int_id,
        )

        return logit
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class DecoderOnlyForCausalLM(nn.Module):
    """A specialized wrapper for Causal Language Models optimized for RBLN compilation.

    This class adapts Huggingface's CausalLM (or similar models) for RBLN deployment by:
    1. Managing model phases (prefill/decode) throughout the computation graph
    2. Handling output shape alignments for static compilation
    3. Coordinating between the original model and RBLN-optimized components

    The class serves as an intermediate layer between DecoderOnlyWrapper and the core model,
    focusing on maintaining correct model behavior while enabling RBLN-specific optimizations.

    Args:
        causal_lm (PreTrainedModel): Original Huggingface causal language model
        model (DecoderOnlyModel): RBLN-optimized model instance

    Attributes:
        config: Configuration from the original causal language model
        _original_mod: Reference to the original model for components like lm_head
        model: RBLN-optimized decoder model instance
        lm_head: The original model's LM head, applied to the decoder's hidden states
        _phase: Current processing phase. Any phase name containing "prefill" takes the prefill path.
    """

    def __init__(self, causal_lm: PreTrainedModel, model: nn.Module):
        super().__init__()
        self.config = causal_lm.config
        self._original_mod = causal_lm
        self.model = model
        self._phase = "prefill"
        self.lm_head = self._original_mod.lm_head

    @property
    def phase(self):
        """Current processing phase; setting it also propagates the value to the wrapped decoder model."""
        return self._phase

    @phase.setter
    def phase(self, phase: str):
        self._phase = phase
        self.model.phase = phase

    def forward(
        self,
        input_ids: torch.Tensor = None,
        inputs_embeds: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        cache_position: torch.Tensor = None,
        position_ids: torch.Tensor = None,
        query_position: torch.Tensor = None,
        past_key_values: Tuple[Tuple[torch.Tensor]] = None,
        rotary_emb: nn.Module = None,
        global_block_tables: Optional[torch.Tensor] = None,
        local_block_tables: Optional[torch.Tensor] = None,
        lora_int_id: Optional[torch.Tensor] = None,
    ):
        """Run the decoder and project its hidden states to vocabulary logits.

        All arguments are forwarded unchanged to ``self.model``. In a prefill phase, only the hidden state at
        ``query_position`` (indexed along dim 1) is kept before the LM head, so ``query_position`` must be
        provided then.

        Returns:
            Logits from ``lm_head``. When ``config.final_logit_softcapping`` is set, the logits are soft-capped
            as ``cap * tanh(logits / cap)``.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            position_ids=position_ids,
            query_position=query_position,
            past_key_values=past_key_values,
            rotary_emb=rotary_emb,
            global_block_tables=global_block_tables,
            local_block_tables=local_block_tables,
            lora_int_id=lora_int_id,
        )

        if "prefill" in self.phase:
            # Select only the queried position(s) so the LM head does not run on every prefill position.
            hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]

        logits = self.lm_head(hidden_states)

        # Apply final logit soft-capping (not softmax) if configured, e.g. for Gemma2: cap * tanh(logits / cap).
        softcap = getattr(self.config, "final_logit_softcapping", None)
        if softcap is not None:
            logits = torch.tanh(logits / softcap) * softcap

        return logits
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
class DecoderOnlyModel(nn.Module):
    """A modified decoder-only model implementation optimized for RBLN compilation.

    Embeds the inputs, adds positional information (rotary cos/sin, learned
    position embeddings, or a position-embedding module), runs every wrapped
    layer, and applies the final layer norm of the original model.

    Args:
        model: Original Huggingface model to adapt.
        layers (List[DecoderOnlyLayer]): Modified transformer layers optimized for RBLN.
        rbln_config: RBLN model configuration. This class reads
            ``kvcache_partition_len``, ``kvcache_block_size``, ``max_seq_len`` and
            ``sliding_window_layers`` from it.
        use_learned_pos_emb: Whether to use learned position embeddings (class-specific override).

    Attributes:
        _original_mod: Reference to original Huggingface model.
        layers: ModuleList of RBLN-optimized transformer layers.
        _phase: Current processing phase ("prefill" or "decode").
    """

    def __init__(
        self,
        model,
        layers: List["DecoderOnlyLayer"],
        rbln_config: "RBLNDecoderOnlyModelConfig",
        use_learned_pos_emb=None,
    ):
        super().__init__()
        self._original_mod = model
        self.layers = nn.ModuleList(layers)
        self.rbln_config = rbln_config
        self._phase = "prefill"
        self.partition_len = rbln_config.kvcache_partition_len
        self.kvcache_block_size = rbln_config.kvcache_block_size
        self.max_seq_len = rbln_config.max_seq_len
        self.use_learned_pos_emb = use_learned_pos_emb
        self.sliding_window_layers = rbln_config.sliding_window_layers

    @property
    def phase(self):
        return self._phase

    @phase.setter
    def phase(self, phase: str):
        # Keep every wrapped layer in the same phase as the model.
        self._phase = phase
        for layer in self.layers:
            layer.phase = phase

    @property
    def attn_impl(self) -> str:
        """"eager" when no KV-cache partition length is configured, otherwise "flash_attn"."""
        return "eager" if self.partition_len is None else "flash_attn"

    @property
    def hidden_multiplier(self):
        """Factor applied to the input embeddings; 1 here, subclasses may override."""
        return 1

    def convert_sequence_positions_for_flash_attn(self, seq_positions, max_seq_len):
        """Split absolute cache positions into per-partition fill levels.

        For each entry ``s`` of ``seq_positions`` and each partition ``p`` in
        ``range(max_seq_len // partition_len)``, returns
        ``clamp(s - p * partition_len, 0, partition_len)``. For a 1-D
        ``seq_positions`` of length ``B`` the result has shape ``[B, num_partition]``.

        Raises:
            NotImplementedError: if ``attn_impl`` is not "flash_attn".
        """
        if self.attn_impl not in ["flash_attn"]:
            raise NotImplementedError(f"Unknown attn_impl ({self.attn_impl}).")
        partition_len = self.partition_len
        num_partition = max_seq_len // partition_len

        cs = seq_positions.repeat(num_partition, 1).transpose(0, 1)
        pidx = torch.arange(num_partition)
        cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
        return cache_pos_for_partitions

    def get_local_cache_positions(self, position_ids, query_position):
        """Compute sliding-window cache positions.

        Returns:
            ``(cache_seq_len, cache_offset)``:
            - ``cache_seq_len`` is the first column of ``position_ids`` clamped to
              the original config's ``sliding_window`` (past seen tokens).
            - ``cache_offset`` is that value plus ``query_position + 1``, or plus 1
              when ``query_position`` is None (cache offset for next steps).
        """
        max_cache_len = self._original_mod.config.sliding_window
        valid_input_len = 1 if query_position is None else query_position + 1
        cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1]  # past seen tokens
        cache_offset = cache_seq_len + valid_input_len  # cache offset for next steps

        return cache_seq_len, cache_offset

    def get_last_layernorm(self) -> nn.LayerNorm:
        return self._original_mod.norm

    def get_embedding(self) -> nn.Embedding:
        return self._original_mod.embed_tokens

    def get_pos_embedding(self) -> nn.Embedding:
        raise NotImplementedError(
            "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
        )

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        query_position: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
        global_block_tables: Optional[torch.Tensor] = None,
        local_block_tables: Optional[torch.Tensor] = None,
        lora_int_id: Optional[torch.Tensor] = None,
    ):
        """Run the decoder stack and return the final normalized hidden states.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given.
        ``position_ids`` falls back to ``cache_position`` when omitted.

        Positional information is chosen in this order:
          1. ``rotary_emb`` given:
             - a tensor supplies ``cos = rotary_emb[0]`` and ``sin = rotary_emb[1]``;
             - a module is called with ``max_seq_len`` and its output is gathered at
               ``position_ids`` via ``slice_and_unsqueeze_cos_sin``.
          2. ``use_learned_pos_emb``: rows of the position-embedding weight (skipping
             the first two) indexed by ``position_ids`` are added per batch element.
          3. otherwise: ``get_pos_embedding()(position_ids)`` is added.

        Layers listed in ``sliding_window_layers`` receive the sliding-window cache
        positions and ``local_block_tables``. All other layers receive the global
        positions and ``global_block_tables``.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given.
        """
        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if inputs_embeds is None:
            inputs_embeds = self.get_embedding()(input_ids)

        hidden_states = inputs_embeds * self.hidden_multiplier

        # get cos,sin vector if needed
        position_ids = position_ids if position_ids is not None else cache_position
        if rotary_emb is not None:
            if isinstance(rotary_emb, torch.Tensor):
                # Precomputed pair stacked along dim 0: used as-is.
                cos, sin = rotary_emb[0], rotary_emb[1]
            else:
                cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
                cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)

        elif self.use_learned_pos_emb:
            # NOTE(review): the first two weight rows are skipped, presumably an OPT-style
            # position offset. This path starts from inputs_embeds, so hidden_multiplier
            # is not applied here. Confirm both against the subclass that sets this flag.
            position_weight = self.get_pos_embedding().weight[2:]
            hidden_states = torch.stack(
                [position_weight[position_ids[i]] + inputs_embeds[i] for i in range(inputs_embeds.shape[0])],
                dim=0,
            )
            cos, sin = None, None

        else:
            pos_embedding = self.get_pos_embedding()
            if position_ids.shape[0] > 1:
                position_embeds = torch.cat(
                    [pos_embedding(position_ids[b_idx]) for b_idx in range(inputs_embeds.shape[0])],
                    dim=0,
                ).unsqueeze(1)
            else:
                position_embeds = pos_embedding(position_ids)
            hidden_states = hidden_states + position_embeds
            cos, sin = None, None

        # Get sequence positions for flash attention
        if self.attn_impl == "flash_attn":
            seq_positions = self.convert_sequence_positions_for_flash_attn(
                seq_positions=cache_position[:, 0], max_seq_len=self.max_seq_len
            )
        else:
            seq_positions = cache_position[:, :1]

        # Get local cache positions for sliding window layers (only read by sliding layers).
        sliding_cache_pos = None
        if len(self.sliding_window_layers) > 0:
            sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)

        for layer_idx, layer in enumerate(self.layers):
            is_sliding = layer_idx in self.sliding_window_layers
            hidden_states = layer(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                seq_positions=sliding_cache_pos if is_sliding else seq_positions,
                past_key_values=past_key_values,
                cos=cos,
                sin=sin,
                block_tables=local_block_tables if is_sliding else global_block_tables,
                lora_int_id=lora_int_id,
            )

        hidden_states = self.get_last_layernorm()(hidden_states)
        return hidden_states
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
class DecoderOnlyLayer(nn.Module):
    """A single transformer layer adapted for RBLN compilation with static shapes.

    Computes a pre-norm residual block using the norms and MLP of the wrapped
    Huggingface layer and the supplied RBLN attention module:

        x = x + self_attn(input_layernorm(x))
        x = x + mlp(post_attention_layernorm(x))

    Args:
        layer: Original transformer layer module to wrap.
        self_attn (DecoderOnlyAttention): Modified attention module optimized for RBLN.
        lora_config: Optional LoRA configuration. When truthy, every ``nn.Linear``
            among the MLP's ``gate_proj`` / ``up_proj`` / ``down_proj`` is replaced
            in place by a ``LoRALinear`` wrapper.

    Attributes:
        _original_mod: Reference to original layer for accessing components.
        self_attn: Modified attention mechanism mapped to RBLN ops at compile time.
        phase: Current operation phase ("prefill" or "decode"); setting it also
            updates ``self_attn.phase``.
    """

    def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
        super().__init__()
        self._original_mod = layer
        self.self_attn = self_attn
        self._phase = "prefill"
        self.lora_config = lora_config

        if self.lora_config:
            self._wrap_mlp_with_lora()

    def _wrap_mlp_with_lora(self) -> None:
        """Swap the MLP's gate/up/down ``nn.Linear`` projections for ``LoRALinear`` (mutates the MLP)."""
        mlp = self.get_mlp()
        for name in ("gate_proj", "up_proj", "down_proj"):
            linear = getattr(mlp, name, None)
            if not isinstance(linear, nn.Linear):
                # Missing attribute or non-Linear module: leave untouched.
                continue
            setattr(
                mlp,
                name,
                LoRALinear(
                    original_linear=linear,
                    lora_config=self.lora_config,
                    projection_name=name,
                    layer_idx=self.self_attn.layer_idx,
                ),
            )

    @property
    def phase(self):
        return self._phase

    @phase.setter
    def phase(self, phase: str):
        self._phase = phase
        self.self_attn.phase = phase

    def get_pre_attention_layernorm(self) -> nn.LayerNorm:
        return self._original_mod.input_layernorm

    def get_post_attention_layernorm(self) -> nn.LayerNorm:
        return self._original_mod.post_attention_layernorm

    def get_mlp(self) -> nn.Module:
        return self._original_mod.mlp

    def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Apply the MLP.

        Without LoRA, or when ``lora_int_id`` is None, the original MLP module is
        called directly. Otherwise a gated MLP is computed by hand so the adapter id
        reaches each projection:

            down_proj(act(gate_proj(x)) * up_proj(x))

        ``act`` is the MLP's ``act_fn`` or ``activation_fn``, falling back to SiLU.
        """
        mlp = self.get_mlp()
        if not (self.lora_config and lora_int_id is not None):
            return mlp(hidden_states)

        gate = mlp.gate_proj(hidden_states, lora_int_id)
        up = mlp.up_proj(hidden_states, lora_int_id)
        activation = getattr(mlp, "act_fn", None) or getattr(mlp, "activation_fn", None)
        gate = torch.nn.functional.silu(gate) if activation is None else activation(gate)
        return mlp.down_proj(gate * up, lora_int_id)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        seq_positions: torch.LongTensor,
        past_key_values: Tuple[Tuple[torch.Tensor]],
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
        lora_int_id: Optional[torch.Tensor] = None,
    ):
        # Attention sub-block with residual connection.
        attn_input = self.get_pre_attention_layernorm()(hidden_states)
        attn_out = self.self_attn(
            hidden_states=attn_input,
            attention_mask=attention_mask,
            seq_positions=seq_positions,
            past_key_values=past_key_values,
            cos=cos,
            sin=sin,
            block_tables=block_tables,
            lora_int_id=lora_int_id,
        )
        hidden_states = hidden_states + attn_out

        # Fully connected sub-block with residual connection.
        mlp_input = self.get_post_attention_layernorm()(hidden_states)
        return hidden_states + self.forward_mlp(mlp_input, lora_int_id)
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
class DecoderOnlyAttention(nn.Module):
    """Attention implementation for decoder-only models optimized for RBLN compilation.

    Reuses, or LoRA-wraps, the q/k/v/o projections of a Huggingface attention
    module. The attention itself is delegated to one RBLN attention op
    (sliding-window, flash, or eager), which is stored on ``self`` under the
    attribute name returned by :meth:`get_attention_name`.

    Args:
        self_attn: Original attention module from the base model.
        rbln_config: RBLN model configuration containing attention parameters.
        is_sliding: Whether this is sliding window attention.
    """

    def __init__(
        self,
        self_attn,
        rbln_config: "RBLNDecoderOnlyModelConfig",
        is_sliding=False,
    ):
        super().__init__()
        self._original_mod = self_attn
        self.rbln_config = rbln_config
        self.layer_idx = self_attn.layer_idx

        original = self._original_mod
        # Head count from the module itself, else from its config.
        self.num_heads = getattr(original, "num_heads", None) or getattr(original.config, "num_attention_heads")
        self.head_dim = original.head_dim
        self._phase = "prefill"
        self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
        self.quantization = rbln_config.quantization

        # KV head count: module attribute, then module config, else num_heads.
        if hasattr(original, "num_key_value_heads"):
            self.num_key_value_heads = original.num_key_value_heads
        elif hasattr(getattr(original, "config", None), "num_key_value_heads"):
            self.num_key_value_heads = original.config.num_key_value_heads
        else:
            self.num_key_value_heads = self.num_heads

        # Sliding-window layers force a mask, record attn_impl as "eager",
        # and use the sliding window size as their block size.
        self.use_attention_mask = True if is_sliding else rbln_config.use_attention_mask
        self.use_position_ids = rbln_config.use_position_ids
        self.is_sliding = is_sliding
        self.attn_impl = "eager" if is_sliding else rbln_config.attn_impl
        self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
        self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
        self.lora_config = rbln_config.lora_config

        attention_name = self.get_attention_name()
        setattr(self, attention_name, self.create_attention_op())
        self.__post_init__()

    def _init_lora_weights(self):
        """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
        for name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            wrapped = LoRALinear(
                original_linear=getattr(self._original_mod, name),
                lora_config=self.lora_config,
                projection_name=name,
                layer_idx=self.layer_idx,
            )
            setattr(self, name, wrapped)

    def get_attention_name(self):
        """Attribute name under which the attention op is stored."""
        if self.is_sliding:
            return "sliding_window_attention"
        return "flash_attention" if self.attn_impl == "flash_attn" else "attention"

    def get_attention_op(self):
        return getattr(self, self.get_attention_name())

    @property
    def phase(self):
        return self._phase

    @phase.setter
    def phase(self, phase: str):
        # Mirror the phase onto the stored attention op.
        self._phase = phase
        attention_op = getattr(self, self.get_attention_name())
        attention_op.phase = phase

    def create_attention_op(self):
        """Instantiate the attention op that matches ``is_sliding`` / ``attn_impl``.

        Raises:
            NotImplementedError: for an ``attn_impl`` other than "flash_attn" or "eager"
                on a non-sliding layer.
        """
        head_args = (self.num_heads, self.head_dim, self.num_key_value_heads)
        if self.is_sliding:
            return SlidingWindowAttentionOp(*head_args, self.use_attention_mask, self.use_position_ids)
        if self.attn_impl == "flash_attn":
            return FlashAttentionOp(
                *head_args,
                self.kvcache_partition_len,
                self.use_attention_mask,
                self.use_position_ids,
                self.quantization,
            )
        if self.attn_impl == "eager":
            return AttentionOp(*head_args, self.use_attention_mask, self.use_position_ids, self.quantization)
        raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")

    def __post_init__(self):
        # With LoRA the projections are replaced by LoRALinear wrappers.
        # Otherwise the original modules are aliased.
        if self.lora_config:
            self._init_lora_weights()
            return
        for name in ("q_proj", "k_proj", "v_proj", "o_proj"):
            setattr(self, name, getattr(self._original_mod, name))

    def projection(
        self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Projects input hidden states into query, key, and value representations.

        Args:
            hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
            lora_int_id: Adapter ID tensor for LoRA selection [batch_size]; only
                passed to the projections when ``lora_config`` is set.

        Returns:
            Tuple of (query_states, key_states, value_states)
        """
        # LoRALinear takes the adapter id as a second positional argument; plain Linear does not.
        lora_args = (lora_int_id,) if self.lora_config else ()
        query_states = self.q_proj(hidden_states, *lora_args)
        key_states = self.k_proj(hidden_states, *lora_args)
        value_states = self.v_proj(hidden_states, *lora_args)
        return query_states, key_states, value_states

    def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
        return apply_rotary_pos_emb(query_states, key_states, cos, sin)

    def get_attn_scale(self):
        return 1 / math.sqrt(self.head_dim)

    def maybe_get_kvcache_scale(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Return ``(k_proj.k_scale, v_proj.v_scale)``, using None for whatever is missing."""
        if not (hasattr(self, "k_proj") and hasattr(self, "v_proj")):
            return None, None
        return getattr(self.k_proj, "k_scale", None), getattr(self.v_proj, "v_scale", None)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        seq_positions: torch.LongTensor,
        past_key_values: Tuple[Tuple[torch.Tensor]],
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
        lora_int_id: Optional[torch.Tensor] = None,
    ):
        """Project, optionally normalize and rotate, attend via the RBLN op, then apply ``o_proj``.

        Raises:
            NotImplementedError: if batch size > 1 while the phase name contains "prefill".
        """
        batch_size, query_length, _ = hidden_states.size()

        query_states, key_states, value_states = self.projection(hidden_states=hidden_states, lora_int_id=lora_int_id)

        def split_heads(states, n_heads):
            # [batch, seq, n_heads * head_dim] -> [batch, n_heads, seq, head_dim]
            return states.view(batch_size, query_length, n_heads, self.head_dim).transpose(1, 2)

        query_states = split_heads(query_states, self.num_heads)
        key_states = split_heads(key_states, self.num_key_value_heads)
        value_states = split_heads(value_states, self.num_key_value_heads)

        if hasattr(self, "q_norm") and hasattr(self, "k_norm"):
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)

        if cos is not None and sin is not None:
            query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)

        if batch_size > 1 and "prefill" in self.phase:
            raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")

        k_scale, v_scale = self.maybe_get_kvcache_scale()

        layer_cache = past_key_values[self.layer_idx]
        attn_output = self.get_attention_op()(
            query_states,
            key_states,
            value_states,
            attention_mask,
            past_key_state=layer_cache[0],
            past_value_state=layer_cache[1],
            seq_position=seq_positions,
            scale=self.scale,
            block_tables=block_tables,
            block_size=self.kvcache_block_size,
            k_scale=k_scale,
            v_scale=v_scale,
        )

        lora_args = (lora_int_id,) if self.lora_config else ()
        return self.o_proj(attn_output, *lora_args)
|
|
817
|
+
|
|
818
|
+
|
|
819
|
+
class DecoderOnlyFlashAttention(DecoderOnlyAttention):
    """Deprecated alias of ``DecoderOnlyAttention``.

    Construction is forwarded unchanged to the parent class, after which a
    deprecation warning is logged.
    """

    def __init__(self, *args, **kwargs):
        super(DecoderOnlyFlashAttention, self).__init__(*args, **kwargs)
        deprecation_message = (
            "DecoderOnlyFlashAttention is deprecated and may not work as expected. Use DecoderOnlyAttention instead."
        )
        logger.warning(deprecation_message)
|
|
825
|
+
|
|
826
|
+
|
|
827
|
+
class AttentionOp(nn.Module):
|
|
828
|
+
def __init__(
|
|
829
|
+
self,
|
|
830
|
+
num_heads: int,
|
|
831
|
+
head_dim: int,
|
|
832
|
+
num_key_value_heads: int,
|
|
833
|
+
use_attention_mask: bool,
|
|
834
|
+
use_position_ids: bool,
|
|
835
|
+
quantization: Optional[RBLNQuantizationConfig] = None,
|
|
836
|
+
):
|
|
837
|
+
super().__init__()
|
|
838
|
+
self.num_heads = num_heads
|
|
839
|
+
self.head_dim = head_dim
|
|
840
|
+
self.num_key_value_heads = num_key_value_heads
|
|
841
|
+
self.phase = "prefill"
|
|
842
|
+
self.use_attention_mask = use_attention_mask
|
|
843
|
+
self.use_position_ids = use_position_ids
|
|
844
|
+
self.quantization = quantization
|
|
845
|
+
|
|
846
|
+
def get_attn_op_name(self):
|
|
847
|
+
phase = "decode" if self.phase == "decode" else "prefill"
|
|
848
|
+
if self.use_attention_mask and not self.use_position_ids:
|
|
849
|
+
attn_op_name = "paged_attn_"
|
|
850
|
+
else:
|
|
851
|
+
attn_op_name = "paged_causal_attn_"
|
|
852
|
+
|
|
853
|
+
attn_op_name += phase
|
|
854
|
+
|
|
855
|
+
if self.quantization and self.quantization.kv_caches == "fp8":
|
|
856
|
+
attn_op_name += "_kv_fp8"
|
|
857
|
+
|
|
858
|
+
return attn_op_name
|
|
859
|
+
|
|
860
|
+
def forward(
|
|
861
|
+
self,
|
|
862
|
+
query_state: torch.Tensor,
|
|
863
|
+
key_state: torch.Tensor,
|
|
864
|
+
value_state: torch.Tensor,
|
|
865
|
+
attn_mask: torch.Tensor,
|
|
866
|
+
past_key_state: torch.Tensor,
|
|
867
|
+
past_value_state: torch.Tensor,
|
|
868
|
+
seq_position: torch.Tensor,
|
|
869
|
+
scale: torch.Tensor,
|
|
870
|
+
block_tables: torch.Tensor,
|
|
871
|
+
block_size: int,
|
|
872
|
+
k_scale: Optional[torch.Tensor] = None,
|
|
873
|
+
v_scale: Optional[torch.Tensor] = None,
|
|
874
|
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
875
|
+
"""Compute attention with static shapes and explicit cache management.
|
|
876
|
+
|
|
877
|
+
Args:
|
|
878
|
+
query_state: Query tensor [1, num_heads, 1, head_dim]
|
|
879
|
+
key_state: Key tensor [1, num_heads, seq_len, head_dim]
|
|
880
|
+
value_state: Value tensor [1, num_heads, seq_len, head_dim]
|
|
881
|
+
attn_mask: Attention mask tensor ∈ {0, 1}
|
|
882
|
+
past_key_state: Previous key cache states
|
|
883
|
+
past_value_state: Previous value cache states
|
|
884
|
+
seq_position: Current position in sequence
|
|
885
|
+
scale: Scale applied to attn weights
|
|
886
|
+
block_tables: Block tables for paged attention
|
|
887
|
+
block_size: Block size for paged attention
|
|
888
|
+
k_scale: Scale applied to key
|
|
889
|
+
v_scale: Scale applied to value
|
|
890
|
+
|
|
891
|
+
Returns:
|
|
892
|
+
Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
|
|
893
|
+
"""
|
|
894
|
+
# reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
|
|
895
|
+
key_state = key_state.unsqueeze(2) # 1, 32, 1, 128, 128
|
|
896
|
+
value_state = value_state.unsqueeze(2)
|
|
897
|
+
|
|
898
|
+
if self.use_attention_mask and not self.use_position_ids:
|
|
899
|
+
attn_mask = attn_mask.unsqueeze(2)
|
|
900
|
+
|
|
901
|
+
if self.phase == "decode":
|
|
902
|
+
batch_size = key_state.shape[0]
|
|
903
|
+
else:
|
|
904
|
+
batch_size = 1
|
|
905
|
+
|
|
906
|
+
query_state = query_state.view(
|
|
907
|
+
batch_size,
|
|
908
|
+
self.num_key_value_heads,
|
|
909
|
+
self.num_heads // self.num_key_value_heads,
|
|
910
|
+
-1, # seq len
|
|
911
|
+
self.head_dim,
|
|
912
|
+
)
|
|
913
|
+
|
|
914
|
+
op_args = {
|
|
915
|
+
"q": query_state,
|
|
916
|
+
"k": key_state,
|
|
917
|
+
"v": value_state,
|
|
918
|
+
"kcache": past_key_state.unsqueeze(2),
|
|
919
|
+
"vcache": past_value_state.unsqueeze(2),
|
|
920
|
+
"seq": seq_position,
|
|
921
|
+
"scale": scale,
|
|
922
|
+
"block_table": block_tables,
|
|
923
|
+
"block_size": block_size,
|
|
924
|
+
}
|
|
925
|
+
|
|
926
|
+
if self.use_attention_mask:
|
|
927
|
+
op_args["mask"] = attn_mask
|
|
928
|
+
|
|
929
|
+
if self.phase == "prefill" or self.phase == "image_prefill":
|
|
930
|
+
if not self.use_attention_mask or self.use_position_ids:
|
|
931
|
+
op_args["is_bidirectional"] = self.phase == "image_prefill" # FIXME, Hard-coded for Gemma3.
|
|
932
|
+
|
|
933
|
+
if self.quantization and self.quantization.kv_caches == "fp8":
|
|
934
|
+
if past_key_state.dtype != torch.float8_e4m3fn:
|
|
935
|
+
raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
|
|
936
|
+
op_args["k_scale"] = k_scale
|
|
937
|
+
op_args["v_scale"] = v_scale
|
|
938
|
+
|
|
939
|
+
attn_op_name = self.get_attn_op_name()
|
|
940
|
+
attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
|
|
941
|
+
if attn_op is None:
|
|
942
|
+
raise ValueError(f"Attention operator {attn_op_name} not found.")
|
|
943
|
+
|
|
944
|
+
attn_output = attn_op(**op_args)
|
|
945
|
+
attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
|
|
946
|
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
|
947
|
+
attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
|
|
948
|
+
|
|
949
|
+
return attn_output
|
|
950
|
+
|
|
951
|
+
|
|
952
|
+
class FlashAttentionOp(AttentionOp):
    """Paged flash-attention variant of ``AttentionOp``.

    Differs from the parent in two ways:
    - it dispatches to ``paged_flash_attn_*`` / ``paged_flash_causal_attn_*`` operators;
    - it additionally passes the KV-cache partition length as the ``partition`` argument.
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        num_key_value_heads: int,
        kvcache_partition_len: int,
        use_attention_mask: bool,
        use_position_ids: bool,
        quantization: Optional[RBLNQuantizationConfig] = None,
    ):
        super().__init__(
            num_heads=num_heads,
            head_dim=head_dim,
            num_key_value_heads=num_key_value_heads,
            use_attention_mask=use_attention_mask,
            use_position_ids=use_position_ids,
            quantization=quantization,
        )
        self.kvcache_partition_size = kvcache_partition_len

    def get_attn_op_name(self):
        """Name of the flash operator for the current phase, mask mode and KV-cache dtype."""
        stage = "decode" if self.phase == "decode" else "prefill"
        masked = self.use_attention_mask and not self.use_position_ids
        prefix = "paged_flash_attn_" if masked else "paged_flash_causal_attn_"
        suffix = "_kv_fp8" if (self.quantization and self.quantization.kv_caches == "fp8") else ""
        return prefix + stage + suffix

    def forward(
        self,
        query_state,
        key_state,
        value_state,
        attn_mask,
        past_key_state,
        past_value_state,
        seq_position,
        scale,
        block_tables,
        block_size,
        k_scale=None,
        v_scale=None,
    ):
        """Same contract as ``AttentionOp.forward``, plus the ``partition`` operator argument.

        Returns a tensor reshaped to ``[batch, seq_len, num_heads * head_dim]``.
        """
        # Singleton group axis on K/V so grouped query heads share a KV head (no repeat_kv).
        key_state = key_state.unsqueeze(2)
        value_state = value_state.unsqueeze(2)
        masked = self.use_attention_mask and not self.use_position_ids
        if masked:
            attn_mask = attn_mask.unsqueeze(2)

        batch_size = key_state.shape[0] if self.phase == "decode" else 1
        heads_per_group = self.num_heads // self.num_key_value_heads
        query_state = query_state.view(batch_size, self.num_key_value_heads, heads_per_group, -1, self.head_dim)

        op_args = dict(
            q=query_state,
            k=key_state,
            v=value_state,
            kcache=past_key_state.unsqueeze(2),
            vcache=past_value_state.unsqueeze(2),
            seq=seq_position,
            scale=scale,
            block_table=block_tables,
            block_size=block_size,
            partition=self.kvcache_partition_size,
        )

        if self.use_attention_mask:
            op_args["mask"] = attn_mask

        if self.phase in ("prefill", "image_prefill") and not masked:
            op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.

        if self.quantization and self.quantization.kv_caches == "fp8":
            if past_key_state.dtype != torch.float8_e4m3fn:
                raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
            op_args.update(k_scale=k_scale, v_scale=v_scale)

        attn_op_name = self.get_attn_op_name()
        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
        if attn_op is None:
            raise ValueError(f"Attention operator {attn_op_name} not found.")

        result = attn_op(**op_args)
        result = result.view(batch_size, self.num_heads, -1, self.head_dim).transpose(1, 2).contiguous()
        return result.reshape(batch_size, -1, self.num_heads * self.head_dim)
|
|
1058
|
+
|
|
1059
|
+
|
|
1060
|
+
class SlidingWindowAttentionOp(AttentionOp):
    """Attention op that dispatches to the paged sliding-window RBLN custom kernels.

    Relies on attributes prepared by ``AttentionOp`` (``phase``, ``num_heads``,
    ``num_key_value_heads``, ``head_dim``, ``use_attention_mask``,
    ``quantization``). It calls ``torch.ops.rbln_custom_ops.paged_sliding_window_attn_*``.
    Instead of a single ``seq`` argument, the cache position is passed as a
    ``(cache_seq_len, cache_offset)`` pair.
    """

    def get_attn_op_name(self):
        """Return the name of the custom kernel to look up in ``torch.ops.rbln_custom_ops``.

        ``"decode"`` maps to the decode kernel. Every other phase
        (``"prefill"``, ``"image_prefill"``) maps to the prefill kernel.

        Raises:
            NotImplementedError: if ``self.use_attention_mask`` is falsy.
        """
        phase = "decode" if self.phase == "decode" else "prefill"
        if not self.use_attention_mask:
            raise NotImplementedError("Attention mask is needed for sliding window attention.")

        attn_op_name = "paged_sliding_window_attn_" + phase
        return attn_op_name

    def forward(
        self,
        query_state: torch.Tensor,
        key_state: torch.Tensor,
        value_state: torch.Tensor,
        attn_mask: Optional[torch.Tensor],
        past_key_state: torch.Tensor,
        past_value_state: torch.Tensor,
        seq_position: Tuple[torch.Tensor, torch.Tensor],
        scale: torch.Tensor,
        block_tables: torch.Tensor,
        block_size: int,
        k_scale: Optional[torch.Tensor] = None,
        v_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run sliding-window paged attention through the RBLN custom kernel.

        Args:
            query_state: query tensor. It is viewed as
                ``(batch_size, num_key_value_heads, num_heads // num_key_value_heads, -1, head_dim)``.
            key_state, value_state: new K/V. A singleton axis is inserted at dim 2
                before they are handed to the kernel.
            attn_mask: accepted for signature compatibility with ``AttentionOp.forward``.
                It is NOT forwarded to the kernel here.
                NOTE(review): confirm the sliding-window kernel derives its mask internally.
            past_key_state, past_value_state: paged K/V caches (unsqueezed at dim 2).
            seq_position: pair ``(cache_seq_len, cache_offset)`` passed to the kernel.
            scale: attention scale passed through to the kernel.
            block_tables: paged-cache block table passed through to the kernel.
            block_size: paged-cache block size passed through to the kernel.
            k_scale, v_scale: must be ``None``; quantized KV caches are unsupported.

        Returns:
            The attention output reshaped to ``(batch_size, -1, num_heads * head_dim)``.

        Raises:
            NotImplementedError: if quantization is configured or K/V scales are given.
            ValueError: if the selected custom operator is not registered.
        """
        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if self.quantization is not None or k_scale is not None or v_scale is not None:
            raise NotImplementedError("Sliding window attention does not support quantization")

        # Insert a singleton group axis at dim 2 so K/V pair with the grouped query
        # layout below without materializing repeat_kv.
        key_state = key_state.unsqueeze(2)
        value_state = value_state.unsqueeze(2)

        # Decode takes the batch size from the inputs; prefill phases run one sequence at a time.
        batch_size = key_state.shape[0] if self.phase == "decode" else 1

        query_state = query_state.view(
            batch_size,
            self.num_key_value_heads,
            self.num_heads // self.num_key_value_heads,
            -1,  # seq len
            self.head_dim,
        )

        op_args = {
            "q": query_state,
            "k": key_state,
            "v": value_state,
            "kcache": past_key_state.unsqueeze(2),
            "vcache": past_value_state.unsqueeze(2),
            "cache_seq_len": seq_position[0],
            "cache_offset": seq_position[1],
            "scale": scale,
            "block_table": block_tables,
            "block_size": block_size,
        }

        if self.phase in ("prefill", "image_prefill"):
            op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.

        attn_op_name = self.get_attn_op_name()
        attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
        if attn_op is None:
            raise ValueError(f"Attention operator {attn_op_name} not found.")

        attn_output = attn_op(**op_args)
        attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

        return attn_output
|
|
1131
|
+
|
|
1132
|
+
|
|
1133
|
+
class RotaryEmbedding(nn.Module):
    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla

    Precomputes cos/sin tables for positions ``0 .. max_seq_len_cached - 1`` and
    stores them as non-persistent buffers ``_cos_cached`` / ``_sin_cached``.
    Both tables are multiplied by the ``attention_scaling`` factor returned by the
    rope init function.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        max_seq_len_cached: int,
    ):
        super().__init__()

        # Pick the rope variant: "rope_type" wins over the legacy "type" key;
        # no rope_scaling at all means the default variant.
        rope_scaling = getattr(config, "rope_scaling", None)
        if rope_scaling is None:
            rope_type = "default"
        else:
            rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))

        init_fn = ROPE_INIT_FUNCTIONS[rope_type]
        inv_freq, attention_scaling = init_fn(config, max_seq_len_cached)

        # Column vector of positions, shape (max_seq_len_cached, 1).
        positions = torch.arange(0, max_seq_len_cached)[:, None].float()

        # The dynamic variant uses broadcasting multiply; the others use an explicit
        # outer product via matmul. The op choice is kept as-is per branch.
        if rope_type == "dynamic":
            freqs = positions * inv_freq.float()
        else:
            freqs = positions @ inv_freq[None, :].float()

        table = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer("_cos_cached", table.cos() * attention_scaling, persistent=False)
        self.register_buffer("_sin_cached", table.sin() * attention_scaling, persistent=False)

    def forward(self, x, seq_len):
        """Return the first ``seq_len`` rows of the cos/sin tables as float32.

        ``x`` is unused; it is kept for call-site compatibility.
        """
        cos_rows = self._cos_cached[:seq_len]
        sin_rows = self._sin_cached[:seq_len]
        return cos_rows.to(dtype=torch.float32), sin_rows.to(dtype=torch.float32)
|
|
1171
|
+
|
|
1172
|
+
|
|
1173
|
+
def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
    """Slice cos[cache_position], sin[cache_position] vector for the query.

    If ``cache_position`` has more than one entry along dim 0, the gather is done
    one row at a time (``cache_position[i : i + 1]``). Each gathered piece is
    unsqueezed at ``unsqueeze_dim`` and the pieces are concatenated along dim 0.
    Otherwise a single gather is unsqueezed at ``unsqueeze_dim``.
    """
    num_rows = cache_position.shape[0]

    if num_rows <= 1:
        return (
            cos[cache_position].unsqueeze(unsqueeze_dim),
            sin[cache_position].unsqueeze(unsqueeze_dim),
        )

    row_indices = [cache_position[row : row + 1] for row in range(num_rows)]
    cos = torch.cat([cos[idx].unsqueeze(unsqueeze_dim) for idx in row_indices], dim=0)
    sin = torch.cat([sin[idx].unsqueeze(unsqueeze_dim) for idx in row_indices], dim=0)
    return cos, sin
|
|
1188
|
+
|
|
1189
|
+
|
|
1190
|
+
def rotate_half(x):
    """Rotates half the hidden dims of the input.

    Splits the last axis at ``x.shape[-1] // 2`` into ``(first, second)`` and
    returns ``cat(-second, first)`` along the last axis.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat([second.neg(), first], dim=-1)
|
|
1195
|
+
|
|
1196
|
+
|
|
1197
|
+
def apply_rotary_pos_emb(q, k, cos, sin):
    """Applies Rotary Position Embedding to the query and key tensors.

    Computes ``t * cos + rotate_half(t) * sin`` for ``t`` in ``(q, k)``. Both
    results are cast to ``q``'s dtype, including the key result.
    """
    out_dtype = q.dtype

    def _embed(t):
        return ((t * cos) + (rotate_half(t) * sin)).to(out_dtype)

    return _embed(q), _embed(k)
|
|
1205
|
+
|
|
1206
|
+
|
|
1207
|
+
def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary embedding to only the first ``ndim`` channels of Q and K.

    The last axis of each tensor is split at ``ndim``. The leading part goes
    through ``apply_rotary_pos_emb``; the trailing part is passed through
    unchanged. The two parts are re-joined along the last axis.
    """
    # Leading ``ndim`` channels get RoPE; the rest pass through untouched.
    q_rot, q_keep = query_states[..., :ndim], query_states[..., ndim:]
    k_rot, k_keep = key_states[..., :ndim], key_states[..., ndim:]

    # NOTE(review): ndim is presumably head_dim * partial_rotary_factor — confirm at call site.
    q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)

    return (
        torch.cat((q_rot, q_keep), dim=-1),
        torch.cat((k_rot, k_keep), dim=-1),
    )
|