optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py (+323 -316)

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -21,106 +21,16 @@ from transformers import PretrainedConfig, PreTrainedModel
 
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from .
+from ...utils.rbln_quantization import RBLNQuantizationConfig
+from .configuration_lora import RBLNLoRAConfig
+from .lora_architecture import LoRALinear
 
 
-
-
-DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
-DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
-MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
-MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
-MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
-MAX_SLIDING_WINDOW_SIZE = 32_768
-
-
-def set_default_values(
-    attn_impl: Optional[str] = None,
-    kvcache_partition_len: Optional[int] = None,
-    kvcache_block_size: Optional[int] = None,
-    max_seq_len: Optional[int] = None,
-) -> Tuple[str, int, int]:
-    if attn_impl is None:
-        attn_impl = "eager"
-
-    if kvcache_partition_len is not None:
-        if attn_impl == "eager":
-            attn_impl = "flash_attn"
-            logger.warning(
-                "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
-                "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
-                "`attn_impl` has been automatically switched to 'flash_attn'."
-            )
-
-    if kvcache_partition_len is None and attn_impl == "flash_attn":
-        kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-
-    if kvcache_block_size is None:
-        if attn_impl == "eager":
-            kvcache_block_size = max_seq_len
-        else:
-            kvcache_block_size = kvcache_partition_len
-
-    return attn_impl, kvcache_partition_len, kvcache_block_size
-
-
-def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
-    if attn_impl not in ["eager", "flash_attn"]:
-        raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn`)")
+if TYPE_CHECKING:
+    from .configuration_decoderonly import RBLNDecoderOnlyModelConfig
 
-    ## Checking Constraints...
-    # Constraint of eager attention:
-    # - `max_seq_len` <= 32k
 
-
-    # 1. `max_seq_len` should be multiple of `partition_len`.
-    # 2. 4k <= `partition_len` <= 32k.
-    # 3. `max_seq_len` should be larger then 8k.
-    if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
-        raise ValueError(
-            f"`max_seq_len` is set to {max_seq_len}, "
-            f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
-            f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
-            " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
-        )
-
-    if attn_impl == "flash_attn":
-        if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
-            raise ValueError(
-                f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
-                f"when using 'flash_attn'. Please adjust either value to meet this requirement."
-            )
-        elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
-            raise ValueError(
-                f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
-                f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
-                f"Please provide a valid value within this range."
-            )
-        elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
-            raise ValueError(
-                f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
-                f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
-                "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
-            )
-
-    if kvcache_block_size is not None:
-        if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
-            raise ValueError(
-                f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
-                f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
-            )
-        elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
-            raise ValueError(
-                f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
-                f"must always be set equal to the `max_seq_len` {max_seq_len}."
-            )
-
-
-def validate_sliding_window_size(sliding_window: int, prefill_chunk_size: int):
-    if sliding_window > MAX_SLIDING_WINDOW_SIZE - prefill_chunk_size:
-        raise ValueError(
-            f"Sliding window size ({sliding_window}) must be less than 32768 - prefill_chunk_size ({32768 - prefill_chunk_size})"
-        )
+logger = logging.get_logger(__name__)
 
 
 class DecoderOnlyWrapper(nn.Module):
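The defaults and validators removed above (set_default_values, validate_attention_method, validate_sliding_window_size) no longer live in this file; judging by the file list, they presumably move into the new optimum/rbln/transformers/modeling_attention_utils.py added in 0.9.3. A minimal sketch of the flash-attention constraints the removed validator enforced; check_flash_attn_settings is an illustrative helper name, not part of the package:

```python
# Illustrative only: restates the constraints from the removed validate_attention_method().
MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192


def check_flash_attn_settings(max_seq_len: int, kvcache_partition_len: int) -> None:
    # max_seq_len must be a multiple of the partition length, and at least two partitions long.
    if max_seq_len % kvcache_partition_len != 0 or max_seq_len // kvcache_partition_len < 2:
        raise ValueError("max_seq_len must be a multiple (>= 2x) of kvcache_partition_len")
    # The partition length must stay within the supported 4k..32k range.
    if not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
        raise ValueError("kvcache_partition_len must be between 4,096 and 32,768")
    # Flash attention requires a minimum context of 8k.
    if max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
        raise ValueError("max_seq_len must be at least 8,192 for flash_attn")


check_flash_attn_settings(max_seq_len=32_768, kvcache_partition_len=16_384)  # passes
```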
@@ -137,40 +47,22 @@ class DecoderOnlyWrapper(nn.Module):
     - Wrapper should not contain neural network graph operations (including memory view handling)
 
     Args:
-
-
+        model (PreTrainedModel): The Huggingface causal language model to wrap
+        rbln_config: The RBLN model configuration containing all necessary parameters
         use_rotary_emb (bool): Whether to use rotary position embeddings
-        attn_impl (str): The attention implementation to use.
-            - "eager": Uses the standard attention.
-            - "flash_attn": Uses flash attention. When set,
-                the key/value cache is partitioned into chunks of length
-                `kvcache_partition_len`.
-        kvcache_partition_len (Optional[int]): Length of KV cache partitions for flash attention.
-            This is only relevant if `attn_impl` is set to "flash_attn`
     """
 
     _use_learned_pos_emb = False
 
-    def __init__(
-        self,
-        causal_lm: PreTrainedModel,
-        max_seq_len: int,
-        use_rotary_emb: bool,
-        attn_impl: str,
-        cache_impl: CacheImplType,
-        use_inputs_embeds: bool,
-        use_attention_mask: bool,
-        use_position_ids: bool,
-        kvcache_partition_len: Optional[int] = None,
-        kvcache_block_size: Optional[int] = None,
-        sliding_window: Optional[int] = None,
-        sliding_window_layers: Optional[List[int]] = None,
-    ):
+    def __init__(self, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig", use_rotary_emb: bool):
         super().__init__()
-        self.
+        self.quantization = rbln_config.quantization
+        self.config = model.config
+        self.is_causal_lm = getattr(model, "lm_head", None) is not None
+        self.rbln_config = rbln_config
 
         if use_rotary_emb:
-            rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
+            rotary_embs = self.get_rotary_emb(max_seq_len=rbln_config.max_seq_len)
             if isinstance(rotary_embs, tuple):
                 self.rotary_emb_global, self.rotary_emb_local = rotary_embs
             else:
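The wrapper constructor collapses its long keyword list into a single rbln_config object. A rough sketch of the fields the wrapper now reads from that object, assuming RBLNDecoderOnlyModelConfig (defined in configuration_decoderonly.py, not shown in this diff) exposes them as plain attributes; DecoderOnlyConfigSketch is illustrative only:

```python
# Not the real class: a minimal sketch of the attributes DecoderOnlyWrapper reads from
# rbln_config instead of taking them as separate keyword arguments.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class DecoderOnlyConfigSketch:
    max_seq_len: int = 8_192
    attn_impl: str = "eager"                  # "eager" or "flash_attn"
    kvcache_partition_len: Optional[int] = None
    kvcache_block_size: Optional[int] = None
    use_inputs_embeds: bool = False
    use_attention_mask: bool = False
    use_position_ids: bool = False
    sliding_window: Optional[int] = None
    sliding_window_layers: List[int] = field(default_factory=list)
    quantization: Optional[object] = None     # e.g. an RBLNQuantizationConfig
    lora_config: Optional[object] = None      # e.g. an RBLNLoRAConfig


# Construction then becomes (sketch):
# wrapper = DecoderOnlyWrapper(model, rbln_config, use_rotary_emb=True)
```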
@@ -178,43 +70,27 @@ class DecoderOnlyWrapper(nn.Module):
         else:
             self.rotary_emb = None
 
-
-        self.kvcache_block_size = kvcache_block_size
-        self.use_attention_mask = use_attention_mask
-        self.use_position_ids = use_position_ids
-        self.use_inputs_embeds = use_inputs_embeds
-        self.sliding_window_layers = sliding_window_layers
-        self.cache_impl = cache_impl
-        self.sliding_window = sliding_window
-
-        if self.attn_impl == "flash_attn":
-            self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-        elif self.attn_impl == "eager":
-            self.kvcache_partition_len = None
-        else:
-            raise ValueError(f"Unknown attn_impl : {self.attn_impl}")
-
-        if kvcache_partition_len and kvcache_partition_len > max_seq_len:
+        if rbln_config.kvcache_partition_len and rbln_config.kvcache_partition_len > rbln_config.max_seq_len:
             raise ValueError(
-                f"kvcache_partition_len({kvcache_partition_len}) should be lower"
-                f" or equal to max_seq_len({max_seq_len})!"
+                f"kvcache_partition_len({rbln_config.kvcache_partition_len}) should be lower"
+                f" or equal to max_seq_len({rbln_config.max_seq_len})!"
             )
 
-        self.
+        self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
         self._phase = "prefill"
 
     def get_rotary_emb(self, max_seq_len):
         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
 
-    def get_decoder_layers(self,
-        return
+    def get_decoder_layers(self, model: PreTrainedModel):
+        return model.model.layers if self.is_causal_lm else model.layers
 
     def get_attn_layer(self, layer: nn.Module):
         return layer.self_attn
 
-    def get_model_layer(self,
-        return
+    def get_model_layer(self, model: PreTrainedModel):
+        return model.model if self.is_causal_lm else model
 
     def get_rbln_attn_class(self):
         return DecoderOnlyAttention
@@ -228,34 +104,28 @@ class DecoderOnlyWrapper(nn.Module):
     def get_rbln_causal_lm_class(self):
         return DecoderOnlyForCausalLM
 
-    def
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
         new_layers = []
-        for layer_idx, layer in enumerate(self.get_decoder_layers(
+        for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
+            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
             new_self_attn = self.get_rbln_attn_class()(
-                self.get_attn_layer(layer),
-                self.use_attention_mask,
-                self.use_position_ids,
-                kvcache_block_size=self.sliding_window
-                if layer_idx in self.sliding_window_layers
-                else self.kvcache_block_size,
-                is_sliding=layer_idx in self.sliding_window_layers,
-                attn_impl=self.attn_impl,
-                kvcache_partition_len=self.kvcache_partition_len,
+                self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
             )
-            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn, lora_config=self.rbln_config.lora_config)
            new_layers.append(new_layer)
 
         new_model = self.get_rbln_model_class()(
-            self.get_model_layer(
+            self.get_model_layer(model),
             new_layers,
-
-            max_seq_len=max_seq_len,
-            kvcache_block_size=self.kvcache_block_size,
+            self.rbln_config,
             use_learned_pos_emb=self.__class__._use_learned_pos_emb,
-            sliding_window_layers=self.sliding_window_layers,
         )
-
-
+
+        if self.is_causal_lm:
+            new_model = self.get_rbln_causal_lm_class()(model, new_model)
+            return new_model
+        else:
+            return new_model
 
     @property
     def phase(self) -> str:
@@ -264,18 +134,24 @@ class DecoderOnlyWrapper(nn.Module):
     @phase.setter
     def phase(self, phase: str):
         self._phase = phase
-        self.
+        self.model.phase = phase
 
     def prepare_forward_args(self, *args):
         args = list(args)
-        input_ids = None if self.use_inputs_embeds else args.pop(0)
-        inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
+        input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
+        inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
         cache_position = args.pop(0)
-        global_block_tables = args.pop(0) if self.
-        local_block_tables = args.pop(0) if self.
-        query_position =
-
-
+        global_block_tables = args.pop(0) if self.rbln_config.use_global_attention else None
+        local_block_tables = args.pop(0) if self.rbln_config.use_local_attention else None
+        query_position = (
+            args.pop(0)
+            # query_position usage: 1. causal_lm prefill or 2. sliding_window cache_position
+            if ("prefill" in self.phase and (self.is_causal_lm or self.rbln_config.use_local_attention))
+            else None
+        )
+        attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
+        position_ids = args.pop(0) if self.rbln_config.use_position_ids else None
+        lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
         past_key_values = args
 
         if len(past_key_values) != 2 * self.num_hidden_layers:
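prepare_forward_args pops a flat, position-dependent argument list whose layout is now derived from rbln_config flags rather than separate attributes. A hedged sketch of how a caller would have to order those tensors to match the pops above; build_forward_args and its inputs are illustrative, not part of the package:

```python
def build_forward_args(cfg, input_tensor, cache_position, tensors):
    """Assemble positional args in the order prepare_forward_args() pops them (sketch)."""
    args = []
    args.append(input_tensor)                      # input_ids OR inputs_embeds (cfg.use_inputs_embeds)
    args.append(cache_position)
    if cfg.use_global_attention:
        args.append(tensors["global_block_tables"])
    if cfg.use_local_attention:
        args.append(tensors["local_block_tables"])
    if tensors.get("query_position") is not None:  # only present for causal-LM / sliding-window prefill
        args.append(tensors["query_position"])
    if cfg.use_attention_mask:
        args.append(tensors["attention_mask"])
    if cfg.use_position_ids:
        args.append(tensors["position_ids"])
    if cfg.lora_config:
        args.append(tensors["lora_int_id"])        # new in 0.9.3: per-request LoRA adapter id
    args.extend(tensors["past_key_values"])        # 2 * num_hidden_layers KV-cache tensors
    return args
```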
@@ -307,6 +183,7 @@ class DecoderOnlyWrapper(nn.Module):
             query_position,
             attention_mask,
             position_ids,
+            lora_int_id,
             past_key_values,
             rotary_emb,
         )
@@ -321,11 +198,12 @@ class DecoderOnlyWrapper(nn.Module):
             query_position,
             attention_mask,
             position_ids,
+            lora_int_id,
             past_key_values,
             rotary_emb,
         ) = self.prepare_forward_args(*args)
 
-        logit = self.
+        logit = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -336,6 +214,7 @@ class DecoderOnlyWrapper(nn.Module):
             rotary_emb=rotary_emb,
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
+            lora_int_id=lora_int_id,
         )
 
         return logit
@@ -392,6 +271,7 @@ class DecoderOnlyForCausalLM(nn.Module):
         rotary_emb: nn.Module = None,
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         # outputs
         hidden_states = self.model(
@@ -405,6 +285,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             rotary_emb=rotary_emb,
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
+            lora_int_id=lora_int_id,
         )
 
         if "prefill" in self.phase:
@@ -427,6 +308,8 @@ class DecoderOnlyModel(nn.Module):
     Args:
         model: Original Huggingface model to adapt
         layers (List[DecoderOnlyLayer]): Modified transformer layers optimized for RBLN
+        rbln_config: RBLN model configuration
+        use_learned_pos_emb: Whether to use learned position embeddings (class-specific override)
 
     Attributes:
         _original_mod: Reference to original Huggingface model
@@ -438,21 +321,19 @@ class DecoderOnlyModel(nn.Module):
         self,
         model,
         layers: List["DecoderOnlyLayer"],
-
-        max_seq_len=None,
-        kvcache_block_size=None,
+        rbln_config: "RBLNDecoderOnlyModelConfig",
         use_learned_pos_emb=None,
-        sliding_window_layers=None,
     ):
         super().__init__()
         self._original_mod = model
         self.layers = nn.ModuleList(layers)
+        self.rbln_config = rbln_config
         self._phase = "prefill"
-        self.partition_len =
-        self.kvcache_block_size = kvcache_block_size
-        self.max_seq_len = max_seq_len
+        self.partition_len = rbln_config.kvcache_partition_len
+        self.kvcache_block_size = rbln_config.kvcache_block_size
+        self.max_seq_len = rbln_config.max_seq_len
         self.use_learned_pos_emb = use_learned_pos_emb
-        self.sliding_window_layers = sliding_window_layers
+        self.sliding_window_layers = rbln_config.sliding_window_layers
 
     @property
     def phase(self):
@@ -516,6 +397,7 @@ class DecoderOnlyModel(nn.Module):
         rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -588,6 +470,7 @@ class DecoderOnlyModel(nn.Module):
                 cos=cos,
                 sin=sin,
                 block_tables=local_block_tables if is_sliding else global_block_tables,
+                lora_int_id=lora_int_id,
             )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
@@ -619,11 +502,27 @@ class DecoderOnlyLayer(nn.Module):
         phase: Current operation phase ("prefill" or "decode")
     """
 
-    def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
+    def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
         super().__init__()
         self._original_mod = layer
         self.self_attn = self_attn
         self._phase = "prefill"
+        self.lora_config = lora_config
+
+        # Replace target Linear modules in MLP with LoRALinear if configured
+        if self.lora_config:
+            mlp = self.get_mlp()
+            for proj_name in ["gate_proj", "up_proj", "down_proj"]:
+                if hasattr(mlp, proj_name):
+                    original_linear = getattr(mlp, proj_name)
+                    if isinstance(original_linear, nn.Linear):
+                        lora_linear = LoRALinear(
+                            original_linear=original_linear,
+                            lora_config=self.lora_config,
+                            projection_name=proj_name,
+                            layer_idx=self.self_attn.layer_idx,
+                        )
+                        setattr(mlp, proj_name, lora_linear)
 
     @property
     def phase(self):
@@ -640,6 +539,25 @@ class DecoderOnlyLayer(nn.Module):
     def get_post_attention_layernorm(self) -> nn.LayerNorm:
         return self._original_mod.post_attention_layernorm
 
+    def get_mlp(self) -> nn.Module:
+        return self._original_mod.mlp
+
+    def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+        mlp = self.get_mlp()
+        if self.lora_config and lora_int_id is not None:
+            gate = mlp.gate_proj(hidden_states, lora_int_id)
+            up = mlp.up_proj(hidden_states, lora_int_id)
+            act_fn = getattr(mlp, "act_fn", None) or getattr(mlp, "activation_fn", None)
+            if act_fn is None:
+                gate = torch.nn.functional.silu(gate)
+            else:
+                gate = act_fn(gate)
+            fused = gate * up
+            hidden_states = mlp.down_proj(fused, lora_int_id)
+        else:
+            hidden_states = mlp(hidden_states)
+        return hidden_states
+
     def forward(
         self,
         hidden_states: torch.Tensor,
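forward_mlp and the attention projections below call the new LoRALinear with an extra lora_int_id argument. lora_architecture.py itself is not part of this hunk, so the following is only an illustrative stand-in for how a per-request adapter id could select stacked LoRA weights; the real class takes lora_config, projection_name, and layer_idx instead of the constructor used here:

```python
import torch
from torch import nn


class LoRALinearSketch(nn.Module):
    """Illustrative stand-in for LoRALinear: wraps a base nn.Linear and adds an
    adapter delta selected at runtime by an integer adapter id per batch element."""

    def __init__(self, original_linear: nn.Linear, num_adapters: int, rank: int):
        super().__init__()
        self.base = original_linear
        # One (A, B) pair per adapter, stacked so lora_int_id can index them.
        self.lora_a = nn.Parameter(torch.zeros(num_adapters, original_linear.in_features, rank))
        self.lora_b = nn.Parameter(torch.zeros(num_adapters, rank, original_linear.out_features))

    def forward(self, x: torch.Tensor, lora_int_id: torch.Tensor = None) -> torch.Tensor:
        out = self.base(x)
        if lora_int_id is not None:
            a = self.lora_a[lora_int_id]  # [batch, in_features, rank]
            b = self.lora_b[lora_int_id]  # [batch, rank, out_features]
            out = out + torch.bmm(torch.bmm(x, a), b)
        return out
```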
@@ -649,6 +567,7 @@ class DecoderOnlyLayer(nn.Module):
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
         hidden_states = self.get_pre_attention_layernorm()(hidden_states)
@@ -661,13 +580,14 @@ class DecoderOnlyLayer(nn.Module):
             cos=cos,
             sin=sin,
             block_tables=block_tables,
+            lora_int_id=lora_int_id,
         )
         hidden_states = residual + hidden_states
 
         # Fully Connected
         residual = hidden_states
         hidden_states = self.get_post_attention_layernorm()(hidden_states)
-        hidden_states = self.
+        hidden_states = self.forward_mlp(hidden_states, lora_int_id)
         hidden_states = residual + hidden_states
 
         return hidden_states
@@ -682,32 +602,27 @@ class DecoderOnlyAttention(nn.Module):
 
     Args:
         self_attn: Original attention module from the base model
-
-        use_position_ids: Whether to use position ids
-        kvcache_block_size: Block size for KV cache
+        rbln_config: RBLN model configuration containing attention parameters
         is_sliding: Whether this is sliding window attention
-        attn_impl: Attention implementation type ("eager" or "flash_attn")
     """
 
     def __init__(
         self,
         self_attn,
-
-        use_position_ids,
-        kvcache_block_size,
+        rbln_config: "RBLNDecoderOnlyModelConfig",
         is_sliding=False,
-        attn_impl="eager",
-        kvcache_partition_len=None,
     ):
         super().__init__()
         self._original_mod = self_attn
+        self.rbln_config = rbln_config
         self.layer_idx = self_attn.layer_idx
         self.num_heads = getattr(self._original_mod, "num_heads", None) or getattr(
             self._original_mod.config, "num_attention_heads"
         )
         self.head_dim = self._original_mod.head_dim
         self._phase = "prefill"
-        self.scale = torch.tensor(self.get_attn_scale())
+        self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
+        self.quantization = rbln_config.quantization
 
         if hasattr(self._original_mod, "num_key_value_heads"):
             self.num_key_value_heads = self._original_mod.num_key_value_heads
@@ -716,20 +631,29 @@ class DecoderOnlyAttention(nn.Module):
         else:
             self.num_key_value_heads = self.num_heads
 
-        self.use_attention_mask = use_attention_mask
-        self.use_position_ids = use_position_ids
+        self.use_attention_mask = rbln_config.use_attention_mask if not is_sliding else True
+        self.use_position_ids = rbln_config.use_position_ids
         self.is_sliding = is_sliding
-        self.attn_impl = attn_impl
-
-
-
-
-        self.kvcache_partition_len = kvcache_partition_len
+        self.attn_impl = rbln_config.attn_impl if not is_sliding else "eager"
+        self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
+        self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
+        self.lora_config = rbln_config.lora_config
 
         setattr(self, self.get_attention_name(), self.create_attention_op())
-        self.kvcache_block_size = kvcache_block_size
         self.__post_init__()
 
+    def _init_lora_weights(self):
+        """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
+        for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+            original_linear = getattr(self._original_mod, proj_name)
+            lora_linear = LoRALinear(
+                original_linear=original_linear,
+                lora_config=self.lora_config,
+                projection_name=proj_name,
+                layer_idx=self.layer_idx,
+            )
+            setattr(self, proj_name, lora_linear)
+
     def get_attention_name(self):
         if self.is_sliding:
             return "sliding_window_attention"
@@ -767,6 +691,7 @@ class DecoderOnlyAttention(nn.Module):
                 self.kvcache_partition_len,
                 self.use_attention_mask,
                 self.use_position_ids,
+                self.quantization,
             )
         elif self.attn_impl == "eager":
             return AttentionOp(
@@ -775,28 +700,46 @@ class DecoderOnlyAttention(nn.Module):
                 self.num_key_value_heads,
                 self.use_attention_mask,
                 self.use_position_ids,
+                self.quantization,
             )
         else:
             raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
 
     def __post_init__(self):
-
-
-
-
-
-
+        # Initialize LoRA weights if configured, which will replace linear layers
+        if self.lora_config:
+            self._init_lora_weights()
+        else:
+            # Use original linear layers if no LoRA
+            self.q_proj = self._original_mod.q_proj
+            self.k_proj = self._original_mod.k_proj
+            self.v_proj = self._original_mod.v_proj
+            self.o_proj = self._original_mod.o_proj
+
+    def projection(
+        self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Projects input hidden states into query, key, and value representations.
 
         Args:
             hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
+            lora_int_id: Adapter ID tensor for LoRA selection [batch_size]
 
         Returns:
             Tuple of (query_states, key_states, value_states)
         """
-
-
-
+        # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+        if self.lora_config:
+            # LoRALinear handles both base projection and LoRA in one forward pass
+            query_states = self.q_proj(hidden_states, lora_int_id)
+            key_states = self.k_proj(hidden_states, lora_int_id)
+            value_states = self.v_proj(hidden_states, lora_int_id)
+        else:
+            # Standard linear projection without LoRA
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
         return query_states, key_states, value_states
 
     def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
@@ -805,6 +748,16 @@ class DecoderOnlyAttention(nn.Module):
     def get_attn_scale(self):
         return 1 / math.sqrt(self.head_dim)
 
+    def maybe_get_kvcache_scale(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        if hasattr(self, "k_proj") and hasattr(self, "v_proj"):
+            k_scale = getattr(self.k_proj, "k_scale", None)
+            v_scale = getattr(self.v_proj, "v_scale", None)
+        else:
+            k_scale = None
+            v_scale = None
+
+        return k_scale, v_scale
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -814,10 +767,11 @@ class DecoderOnlyAttention(nn.Module):
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         batch_size, query_length, _ = hidden_states.size()
 
-        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states, lora_int_id=lora_int_id)
 
         query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -834,6 +788,8 @@ class DecoderOnlyAttention(nn.Module):
         if batch_size > 1 and "prefill" in self.phase:
             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
 
+        k_scale, v_scale = self.maybe_get_kvcache_scale()
+
         attn_output = self.get_attention_op()(
             query_states,
             key_states,
@@ -845,9 +801,18 @@ class DecoderOnlyAttention(nn.Module):
             scale=self.scale,
             block_tables=block_tables,
             block_size=self.kvcache_block_size,
+            k_scale=k_scale,
+            v_scale=v_scale,
         )
 
-
+        # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+        if self.lora_config:
+            # LoRALinear handles both base projection and LoRA in one forward pass
+            attn_outputs = self.o_proj(attn_output, lora_int_id)
+        else:
+            # Standard linear projection without LoRA
+            attn_outputs = self.o_proj(attn_output)
+
         return attn_outputs
 
 
@@ -861,7 +826,13 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
 
 class AttentionOp(nn.Module):
     def __init__(
-        self,
+        self,
+        num_heads: int,
+        head_dim: int,
+        num_key_value_heads: int,
+        use_attention_mask: bool,
+        use_position_ids: bool,
+        quantization: Optional[RBLNQuantizationConfig] = None,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -870,16 +841,20 @@ class AttentionOp(nn.Module):
         self.phase = "prefill"
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
+        self.quantization = quantization
 
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-        if self.use_attention_mask:
+        if self.use_attention_mask and not self.use_position_ids:
            attn_op_name = "paged_attn_"
         else:
            attn_op_name = "paged_causal_attn_"
 
         attn_op_name += phase
 
+        if self.quantization and self.quantization.kv_caches == "fp8":
+            attn_op_name += "_kv_fp8"
+
         return attn_op_name
 
     def forward(
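With this change the custom attention op name encodes three things: the masking mode, the phase, and (new in 0.9.3) whether the KV cache is fp8-quantized. A small standalone mirror of get_attn_op_name above, showing the names it produces:

```python
def attn_op_name(use_attention_mask: bool, use_position_ids: bool, phase: str, kv_fp8: bool) -> str:
    # Same selection logic as AttentionOp.get_attn_op_name() in the hunk above.
    name = "paged_attn_" if (use_attention_mask and not use_position_ids) else "paged_causal_attn_"
    name += "decode" if phase == "decode" else "prefill"
    if kv_fp8:
        name += "_kv_fp8"
    return name


assert attn_op_name(False, False, "prefill", False) == "paged_causal_attn_prefill"
assert attn_op_name(False, False, "decode", True) == "paged_causal_attn_decode_kv_fp8"
assert attn_op_name(True, False, "decode", False) == "paged_attn_decode"
```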
@@ -894,6 +869,8 @@ class AttentionOp(nn.Module):
         scale: torch.Tensor,
         block_tables: torch.Tensor,
         block_size: int,
+        k_scale: Optional[torch.Tensor] = None,
+        v_scale: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute attention with static shapes and explicit cache management.
 
@@ -906,6 +883,10 @@ class AttentionOp(nn.Module):
             past_value_state: Previous value cache states
             seq_position: Current position in sequence
             scale: Scale applied to attn weights
+            block_tables: Block tables for paged attention
+            block_size: Block size for paged attention
+            k_scale: Scale applied to key
+            v_scale: Scale applied to value
 
         Returns:
             Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
@@ -942,13 +923,19 @@ class AttentionOp(nn.Module):
             "block_size": block_size,
         }
 
-        if self.use_attention_mask
+        if self.use_attention_mask:
             op_args["mask"] = attn_mask
 
         if self.phase == "prefill" or self.phase == "image_prefill":
             if not self.use_attention_mask or self.use_position_ids:
                 op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
 
+        if self.quantization and self.quantization.kv_caches == "fp8":
+            if past_key_state.dtype != torch.float8_e4m3fn:
+                raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
+            op_args["k_scale"] = k_scale
+            op_args["v_scale"] = v_scale
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
@@ -962,97 +949,6 @@ class AttentionOp(nn.Module):
         return attn_output
 
 
-def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
-    """Slice cos[cache_position], sin[cache_position] vector for the query."""
-    if cache_position.shape[0] > 1:
-        cos_all = []
-        sin_all = []
-        for i in range(cache_position.shape[0]):
-            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-        cos = torch.cat(cos_all, dim=0)
-        sin = torch.cat(sin_all, dim=0)
-    else:
-        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
-        sin = sin[cache_position].unsqueeze(unsqueeze_dim)
-
-    return cos, sin
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin):
-    """Applies Rotary Position Embedding to the query and key tensors."""
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
-    # Partial rotary embedding
-    query_rot, query_pass = (
-        query_states[..., :ndim],
-        query_states[..., ndim:],
-    )
-    key_rot, key_pass = (
-        key_states[..., :ndim],
-        key_states[..., ndim:],
-    )
-
-    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
-    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
-
-    # [batch_size, seq_length, num_heads, head_dim]
-    query_states = torch.cat((query_rot, query_pass), dim=-1)
-    key_states = torch.cat((key_rot, key_pass), dim=-1)
-    return query_states, key_states
-
-
-class RotaryEmbedding(nn.Module):
-    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        max_seq_len_cached: int,
-    ):
-        super().__init__()
-
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            rope_type = "default"
-
-        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
-        cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
-        cache_position_expanded = cache_position[:, None]
-
-        if rope_type == "dynamic":
-            freqs = cache_position_expanded.float() * inv_freq.float()
-        else:
-            inv_freq_expanded = inv_freq[None, :]
-            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
-
-        emb = torch.cat((freqs, freqs), dim=-1)
-
-        cos = emb.cos() * attention_scaling
-        sin = emb.sin() * attention_scaling
-
-        self.register_buffer("_cos_cached", cos, persistent=False)
-        self.register_buffer("_sin_cached", sin, persistent=False)
-
-    def forward(self, x, seq_len):
-        return (
-            self._cos_cached[:seq_len].to(dtype=x.dtype),
-            self._sin_cached[:seq_len].to(dtype=x.dtype),
-        )
-
-
 class FlashAttentionOp(AttentionOp):
     def __init__(
         self,
@@ -1062,6 +958,7 @@ class FlashAttentionOp(AttentionOp):
         kvcache_partition_len: int,
         use_attention_mask: bool,
         use_position_ids: bool,
+        quantization: Optional[RBLNQuantizationConfig] = None,
     ):
         super().__init__(
             num_heads=num_heads,
@@ -1069,18 +966,22 @@ class FlashAttentionOp(AttentionOp):
             num_key_value_heads=num_key_value_heads,
             use_attention_mask=use_attention_mask,
             use_position_ids=use_position_ids,
+            quantization=quantization,
         )
         self.kvcache_partition_size = kvcache_partition_len
 
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-        if self.use_attention_mask:
+        if self.use_attention_mask and not self.use_position_ids:
            attn_op_name = "paged_flash_attn_"
         else:
            attn_op_name = "paged_flash_causal_attn_"
 
         attn_op_name += phase
 
+        if self.quantization and self.quantization.kv_caches == "fp8":
+            attn_op_name += "_kv_fp8"
+
         return attn_op_name
 
     def forward(
@@ -1095,6 +996,8 @@ class FlashAttentionOp(AttentionOp):
         scale,
         block_tables,
         block_size,
+        k_scale=None,
+        v_scale=None,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
@@ -1128,13 +1031,19 @@ class FlashAttentionOp(AttentionOp):
             "partition": self.kvcache_partition_size,
         }
 
-        if self.use_attention_mask
+        if self.use_attention_mask:
             op_args["mask"] = attn_mask
 
         if self.phase == "prefill" or self.phase == "image_prefill":
             if not self.use_attention_mask or self.use_position_ids:
                 op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
 
+        if self.quantization and self.quantization.kv_caches == "fp8":
+            if past_key_state.dtype != torch.float8_e4m3fn:
+                raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
+            op_args["k_scale"] = k_scale
+            op_args["v_scale"] = v_scale
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
@@ -1151,8 +1060,8 @@ class SlidingWindowAttentionOp(AttentionOp):
 class SlidingWindowAttentionOp(AttentionOp):
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-        if self.use_attention_mask:
-            raise NotImplementedError("Attention mask is
+        if not self.use_attention_mask:
+            raise NotImplementedError("Attention mask is needed for sliding window attention.")
 
         attn_op_name = "paged_sliding_window_attn_" + phase
         return attn_op_name
@@ -1162,14 +1071,19 @@ class SlidingWindowAttentionOp(AttentionOp):
         query_state: torch.Tensor,
         key_state: torch.Tensor,
         value_state: torch.Tensor,
-        attn_mask: torch.Tensor,
+        attn_mask: Optional[torch.Tensor],
         past_key_state: torch.Tensor,
         past_value_state: torch.Tensor,
         seq_position: Tuple[torch.Tensor],
         scale: torch.Tensor,
         block_tables: torch.Tensor,
         block_size: int,
+        k_scale: Optional[torch.Tensor] = None,
+        v_scale: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        assert self.quantization is None, "Sliding window attention does not support quantization"
+        assert k_scale is None and v_scale is None, "Sliding window attention does not support quantization"
+
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
@@ -1201,8 +1115,7 @@ class SlidingWindowAttentionOp(AttentionOp):
         }
 
         if self.phase == "prefill" or self.phase == "image_prefill":
-
-                op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+            op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
 
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
@@ -1215,3 +1128,97 @@ class SlidingWindowAttentionOp(AttentionOp):
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
 
         return attn_output
+
+
+class RotaryEmbedding(nn.Module):
+    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        max_seq_len_cached: int,
+    ):
+        super().__init__()
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            rope_type = "default"
+
+        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+        cache_position = torch.arange(0, max_seq_len_cached)
+        cache_position_expanded = cache_position[:, None]
+
+        if rope_type == "dynamic":
+            freqs = cache_position_expanded.float() * inv_freq.float()
+        else:
+            inv_freq_expanded = inv_freq[None, :]
+            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+
+        cos = emb.cos() * attention_scaling
+        sin = emb.sin() * attention_scaling
+
+        self.register_buffer("_cos_cached", cos, persistent=False)
+        self.register_buffer("_sin_cached", sin, persistent=False)
+
+    def forward(self, x, seq_len):
+        return (
+            self._cos_cached[:seq_len].to(dtype=torch.float32),
+            self._sin_cached[:seq_len].to(dtype=torch.float32),
+        )
+
+
+def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
+    """Slice cos[cache_position], sin[cache_position] vector for the query."""
+    if cache_position.shape[0] > 1:
+        cos_all = []
+        sin_all = []
+        for i in range(cache_position.shape[0]):
+            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+        cos = torch.cat(cos_all, dim=0)
+        sin = torch.cat(sin_all, dim=0)
+    else:
+        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
+        sin = sin[cache_position].unsqueeze(unsqueeze_dim)
+
+    return cos, sin
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+    """Applies Rotary Position Embedding to the query and key tensors."""
+    dtype = q.dtype
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(dtype)
+    k_embed = k_embed.to(dtype)
+    return q_embed, k_embed
+
+
+def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Partial rotary embedding
+    query_rot, query_pass = (
+        query_states[..., :ndim],
+        query_states[..., ndim:],
+    )
+    key_rot, key_pass = (
+        key_states[..., :ndim],
+        key_states[..., ndim:],
+    )
+
+    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+    # [batch_size, seq_length, num_heads, head_dim]
+    query_states = torch.cat((query_rot, query_pass), dim=-1)
+    key_states = torch.cat((key_rot, key_pass), dim=-1)
+    return query_states, key_states