optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +96 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +153 -42
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
- optimum/rbln/diffusers/pipelines/__init__.py +11 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +71 -19
- optimum/rbln/modeling_base.py +99 -21
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +9 -7
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/modeling_generic.py +51 -9
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/depreacate_utils.py +16 -0
- optimum/rbln/utils/runtime_utils.py +28 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
- optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import math
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -21,106 +21,16 @@ from transformers import PretrainedConfig, PreTrainedModel
 
 from ....utils import logging
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
-from .
+from ...utils.rbln_quantization import RBLNQuantizationConfig
+from .configuration_lora import RBLNLoRAConfig
+from .lora_architecture import LoRALinear
 
 
-
-
-DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
-DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
-MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
-MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
-MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
-MAX_SLIDING_WINDOW_SIZE = 32_768
-
-
-def set_default_values(
-    attn_impl: Optional[str] = None,
-    kvcache_partition_len: Optional[int] = None,
-    kvcache_block_size: Optional[int] = None,
-    max_seq_len: Optional[int] = None,
-) -> Tuple[str, int, int]:
-    if attn_impl is None:
-        attn_impl = "eager"
-
-    if kvcache_partition_len is not None:
-        if attn_impl == "eager":
-            attn_impl = "flash_attn"
-            logger.warning(
-                "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
-                "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
-                "`attn_impl` has been automatically switched to 'flash_attn'."
-            )
-
-    if kvcache_partition_len is None and attn_impl == "flash_attn":
-        kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-
-    if kvcache_block_size is None:
-        if attn_impl == "eager":
-            kvcache_block_size = max_seq_len
-        else:
-            kvcache_block_size = kvcache_partition_len
-
-    return attn_impl, kvcache_partition_len, kvcache_block_size
-
-
-def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
-    if attn_impl not in ["eager", "flash_attn"]:
-        raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn`)")
-
-    ## Checking Constraints...
-    # Constraint of eager attention:
-    # - `max_seq_len` <= 32k
-
-    # Constraints of flash attention:
-    # 1. `max_seq_len` should be multiple of `partition_len`.
-    # 2. 4k <= `partition_len` <= 32k.
-    # 3. `max_seq_len` should be larger then 8k.
-    if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
-        raise ValueError(
-            f"`max_seq_len` is set to {max_seq_len}, "
-            f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
-            f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
-            " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
-        )
-
-    if attn_impl == "flash_attn":
-        if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
-            raise ValueError(
-                f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
-                f"when using 'flash_attn'. Please adjust either value to meet this requirement."
-            )
-        elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
-            raise ValueError(
-                f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
-                f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
-                f"Please provide a valid value within this range."
-            )
-        elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
-            raise ValueError(
-                f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
-                f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
-                "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
-            )
-
-    if kvcache_block_size is not None:
-        if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
-            raise ValueError(
-                f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
-                f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
-            )
-        elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
-            raise ValueError(
-                f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
-                f"must always be set equal to the `max_seq_len` {max_seq_len}."
-            )
+if TYPE_CHECKING:
+    from .configuration_decoderonly import RBLNDecoderOnlyModelConfig
 
 
-
-    if sliding_window > MAX_SLIDING_WINDOW_SIZE - prefill_chunk_size:
-        raise ValueError(
-            f"Sliding window size ({sliding_window}) must be less than 32768 - prefill_chunk_size ({32768 - prefill_chunk_size})"
-        )
+logger = logging.get_logger(__name__)
 
 
 class DecoderOnlyWrapper(nn.Module):
@@ -137,40 +47,22 @@ class DecoderOnlyWrapper(nn.Module):
     - Wrapper should not contain neural network graph operations (including memory view handling)
 
     Args:
-
-
+        model (PreTrainedModel): The Huggingface causal language model to wrap
+        rbln_config: The RBLN model configuration containing all necessary parameters
         use_rotary_emb (bool): Whether to use rotary position embeddings
-        attn_impl (str): The attention implementation to use.
-            - "eager": Uses the standard attention.
-            - "flash_attn": Uses flash attention. When set,
-                the key/value cache is partitioned into chunks of length
-                `kvcache_partition_len`.
-        kvcache_partition_len (Optional[int]): Length of KV cache partitions for flash attention.
-            This is only relevant if `attn_impl` is set to "flash_attn`
     """
 
     _use_learned_pos_emb = False
 
-    def __init__(
-        self,
-        causal_lm: PreTrainedModel,
-        max_seq_len: int,
-        use_rotary_emb: bool,
-        attn_impl: str,
-        cache_impl: CacheImplType,
-        use_inputs_embeds: bool,
-        use_attention_mask: bool,
-        use_position_ids: bool,
-        kvcache_partition_len: Optional[int] = None,
-        kvcache_block_size: Optional[int] = None,
-        sliding_window: Optional[int] = None,
-        sliding_window_layers: Optional[List[int]] = None,
-    ):
+    def __init__(self, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig", use_rotary_emb: bool):
         super().__init__()
-        self.
+        self.quantization = rbln_config.quantization
+        self.config = model.config
+        self.is_causal_lm = getattr(model, "lm_head", None) is not None
+        self.rbln_config = rbln_config
 
         if use_rotary_emb:
-            rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
+            rotary_embs = self.get_rotary_emb(max_seq_len=rbln_config.max_seq_len)
             if isinstance(rotary_embs, tuple):
                 self.rotary_emb_global, self.rotary_emb_local = rotary_embs
             else:
@@ -178,43 +70,27 @@ class DecoderOnlyWrapper(nn.Module):
         else:
             self.rotary_emb = None
 
-
-        self.kvcache_block_size = kvcache_block_size
-        self.use_attention_mask = use_attention_mask
-        self.use_position_ids = use_position_ids
-        self.use_inputs_embeds = use_inputs_embeds
-        self.sliding_window_layers = sliding_window_layers
-        self.cache_impl = cache_impl
-        self.sliding_window = sliding_window
-
-        if self.attn_impl == "flash_attn":
-            self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-        elif self.attn_impl == "eager":
-            self.kvcache_partition_len = None
-        else:
-            raise ValueError(f"Unknown attn_impl : {self.attn_impl}")
-
-        if kvcache_partition_len and kvcache_partition_len > max_seq_len:
+        if rbln_config.kvcache_partition_len and rbln_config.kvcache_partition_len > rbln_config.max_seq_len:
             raise ValueError(
-                f"kvcache_partition_len({kvcache_partition_len}) should be lower"
-                f" or equal to max_seq_len({max_seq_len})!"
+                f"kvcache_partition_len({rbln_config.kvcache_partition_len}) should be lower"
+                f" or equal to max_seq_len({rbln_config.max_seq_len})!"
             )
 
-        self.
+        self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
         self._phase = "prefill"
 
     def get_rotary_emb(self, max_seq_len):
         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
 
-    def get_decoder_layers(self,
-        return
+    def get_decoder_layers(self, model: PreTrainedModel):
+        return model.model.layers if self.is_causal_lm else model.layers
 
     def get_attn_layer(self, layer: nn.Module):
         return layer.self_attn
 
-    def get_model_layer(self,
-        return
+    def get_model_layer(self, model: PreTrainedModel):
+        return model.model if self.is_causal_lm else model
 
     def get_rbln_attn_class(self):
         return DecoderOnlyAttention
@@ -228,35 +104,28 @@ class DecoderOnlyWrapper(nn.Module):
     def get_rbln_causal_lm_class(self):
         return DecoderOnlyForCausalLM
 
-    def
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
         new_layers = []
-        for layer_idx, layer in enumerate(self.get_decoder_layers(
-            is_sliding = layer_idx in self.sliding_window_layers
+        for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
+            is_sliding = layer_idx in self.rbln_config.sliding_window_layers
             new_self_attn = self.get_rbln_attn_class()(
-                self.get_attn_layer(layer),
-                self.use_attention_mask if not is_sliding else True,
-                self.use_position_ids,
-                kvcache_block_size=self.sliding_window
-                if layer_idx in self.sliding_window_layers
-                else self.kvcache_block_size,
-                is_sliding=is_sliding,
-                attn_impl=self.attn_impl if not is_sliding else "eager",
-                kvcache_partition_len=self.kvcache_partition_len,
+                self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
             )
-            new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
+            new_layer = self.get_rbln_layer_class()(layer, new_self_attn, lora_config=self.rbln_config.lora_config)
            new_layers.append(new_layer)
 
         new_model = self.get_rbln_model_class()(
-            self.get_model_layer(
+            self.get_model_layer(model),
             new_layers,
-
-            max_seq_len=max_seq_len,
-            kvcache_block_size=self.kvcache_block_size,
+            self.rbln_config,
             use_learned_pos_emb=self.__class__._use_learned_pos_emb,
-            sliding_window_layers=self.sliding_window_layers,
         )
-
-
+
+        if self.is_causal_lm:
+            new_model = self.get_rbln_causal_lm_class()(model, new_model)
+            return new_model
+        else:
+            return new_model
 
     @property
     def phase(self) -> str:
@@ -265,18 +134,24 @@ class DecoderOnlyWrapper(nn.Module):
     @phase.setter
     def phase(self, phase: str):
         self._phase = phase
-        self.
+        self.model.phase = phase
 
     def prepare_forward_args(self, *args):
         args = list(args)
-        input_ids = None if self.use_inputs_embeds else args.pop(0)
-        inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
+        input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
+        inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
         cache_position = args.pop(0)
-        global_block_tables = args.pop(0) if self.
-        local_block_tables = args.pop(0) if self.
-        query_position =
-
-
+        global_block_tables = args.pop(0) if self.rbln_config.use_global_attention else None
+        local_block_tables = args.pop(0) if self.rbln_config.use_local_attention else None
+        query_position = (
+            args.pop(0)
+            # query_position usage: 1. causal_lm prefill or 2. sliding_window cache_position
+            if ("prefill" in self.phase and (self.is_causal_lm or self.rbln_config.use_local_attention))
+            else None
+        )
+        attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
+        position_ids = args.pop(0) if self.rbln_config.use_position_ids else None
+        lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
         past_key_values = args
 
         if len(past_key_values) != 2 * self.num_hidden_layers:
@@ -308,6 +183,7 @@ class DecoderOnlyWrapper(nn.Module):
             query_position,
             attention_mask,
             position_ids,
+            lora_int_id,
             past_key_values,
             rotary_emb,
         )
@@ -322,11 +198,12 @@ class DecoderOnlyWrapper(nn.Module):
             query_position,
             attention_mask,
             position_ids,
+            lora_int_id,
             past_key_values,
             rotary_emb,
         ) = self.prepare_forward_args(*args)
 
-        logit = self.
+        logit = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -337,6 +214,7 @@ class DecoderOnlyWrapper(nn.Module):
             rotary_emb=rotary_emb,
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
+            lora_int_id=lora_int_id,
         )
 
         return logit
@@ -393,6 +271,7 @@ class DecoderOnlyForCausalLM(nn.Module):
         rotary_emb: nn.Module = None,
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         # outputs
         hidden_states = self.model(
@@ -406,6 +285,7 @@ class DecoderOnlyForCausalLM(nn.Module):
             rotary_emb=rotary_emb,
             global_block_tables=global_block_tables,
             local_block_tables=local_block_tables,
+            lora_int_id=lora_int_id,
         )
 
         if "prefill" in self.phase:
@@ -428,6 +308,8 @@ class DecoderOnlyModel(nn.Module):
     Args:
         model: Original Huggingface model to adapt
         layers (List[DecoderOnlyLayer]): Modified transformer layers optimized for RBLN
+        rbln_config: RBLN model configuration
+        use_learned_pos_emb: Whether to use learned position embeddings (class-specific override)
 
     Attributes:
         _original_mod: Reference to original Huggingface model
@@ -439,21 +321,19 @@ class DecoderOnlyModel(nn.Module):
         self,
         model,
         layers: List["DecoderOnlyLayer"],
-
-        max_seq_len=None,
-        kvcache_block_size=None,
+        rbln_config: "RBLNDecoderOnlyModelConfig",
         use_learned_pos_emb=None,
-        sliding_window_layers=None,
     ):
         super().__init__()
         self._original_mod = model
         self.layers = nn.ModuleList(layers)
+        self.rbln_config = rbln_config
         self._phase = "prefill"
-        self.partition_len =
-        self.kvcache_block_size = kvcache_block_size
-        self.max_seq_len = max_seq_len
+        self.partition_len = rbln_config.kvcache_partition_len
+        self.kvcache_block_size = rbln_config.kvcache_block_size
+        self.max_seq_len = rbln_config.max_seq_len
         self.use_learned_pos_emb = use_learned_pos_emb
-        self.sliding_window_layers = sliding_window_layers
+        self.sliding_window_layers = rbln_config.sliding_window_layers
 
     @property
     def phase(self):
@@ -517,6 +397,7 @@ class DecoderOnlyModel(nn.Module):
         rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
         global_block_tables: Optional[torch.Tensor] = None,
         local_block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         # retrieve input_ids and inputs_embeds
         if (input_ids is None) ^ (inputs_embeds is not None):
@@ -589,6 +470,7 @@ class DecoderOnlyModel(nn.Module):
                 cos=cos,
                 sin=sin,
                 block_tables=local_block_tables if is_sliding else global_block_tables,
+                lora_int_id=lora_int_id,
             )
 
         hidden_states = self.get_last_layernorm()(hidden_states)
@@ -620,11 +502,27 @@ class DecoderOnlyLayer(nn.Module):
         phase: Current operation phase ("prefill" or "decode")
     """
 
-    def __init__(self, layer, self_attn: "DecoderOnlyAttention"):
+    def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
         super().__init__()
         self._original_mod = layer
         self.self_attn = self_attn
         self._phase = "prefill"
+        self.lora_config = lora_config
+
+        # Replace target Linear modules in MLP with LoRALinear if configured
+        if self.lora_config:
+            mlp = self.get_mlp()
+            for proj_name in ["gate_proj", "up_proj", "down_proj"]:
+                if hasattr(mlp, proj_name):
+                    original_linear = getattr(mlp, proj_name)
+                    if isinstance(original_linear, nn.Linear):
+                        lora_linear = LoRALinear(
+                            original_linear=original_linear,
+                            lora_config=self.lora_config,
+                            projection_name=proj_name,
+                            layer_idx=self.self_attn.layer_idx,
+                        )
+                        setattr(mlp, proj_name, lora_linear)
 
     @property
     def phase(self):
@@ -641,6 +539,25 @@ class DecoderOnlyLayer(nn.Module):
     def get_post_attention_layernorm(self) -> nn.LayerNorm:
         return self._original_mod.post_attention_layernorm
 
+    def get_mlp(self) -> nn.Module:
+        return self._original_mod.mlp
+
+    def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+        mlp = self.get_mlp()
+        if self.lora_config and lora_int_id is not None:
+            gate = mlp.gate_proj(hidden_states, lora_int_id)
+            up = mlp.up_proj(hidden_states, lora_int_id)
+            act_fn = getattr(mlp, "act_fn", None) or getattr(mlp, "activation_fn", None)
+            if act_fn is None:
+                gate = torch.nn.functional.silu(gate)
+            else:
+                gate = act_fn(gate)
+            fused = gate * up
+            hidden_states = mlp.down_proj(fused, lora_int_id)
+        else:
+            hidden_states = mlp(hidden_states)
+        return hidden_states
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -650,6 +567,7 @@ class DecoderOnlyLayer(nn.Module):
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         residual = hidden_states
         hidden_states = self.get_pre_attention_layernorm()(hidden_states)
@@ -662,13 +580,14 @@ class DecoderOnlyLayer(nn.Module):
             cos=cos,
             sin=sin,
             block_tables=block_tables,
+            lora_int_id=lora_int_id,
         )
         hidden_states = residual + hidden_states
 
         # Fully Connected
         residual = hidden_states
         hidden_states = self.get_post_attention_layernorm()(hidden_states)
-        hidden_states = self.
+        hidden_states = self.forward_mlp(hidden_states, lora_int_id)
         hidden_states = residual + hidden_states
 
         return hidden_states
@@ -683,32 +602,27 @@ class DecoderOnlyAttention(nn.Module):
 
     Args:
         self_attn: Original attention module from the base model
-
-        use_position_ids: Whether to use position ids
-        kvcache_block_size: Block size for KV cache
+        rbln_config: RBLN model configuration containing attention parameters
         is_sliding: Whether this is sliding window attention
-        attn_impl: Attention implementation type ("eager" or "flash_attn")
     """
 
     def __init__(
         self,
         self_attn,
-
-        use_position_ids,
-        kvcache_block_size,
+        rbln_config: "RBLNDecoderOnlyModelConfig",
         is_sliding=False,
-        attn_impl="eager",
-        kvcache_partition_len=None,
     ):
         super().__init__()
         self._original_mod = self_attn
+        self.rbln_config = rbln_config
         self.layer_idx = self_attn.layer_idx
         self.num_heads = getattr(self._original_mod, "num_heads", None) or getattr(
             self._original_mod.config, "num_attention_heads"
         )
         self.head_dim = self._original_mod.head_dim
         self._phase = "prefill"
-        self.scale = torch.tensor(self.get_attn_scale())
+        self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
+        self.quantization = rbln_config.quantization
 
         if hasattr(self._original_mod, "num_key_value_heads"):
             self.num_key_value_heads = self._original_mod.num_key_value_heads
@@ -717,16 +631,29 @@ class DecoderOnlyAttention(nn.Module):
         else:
             self.num_key_value_heads = self.num_heads
 
-        self.use_attention_mask = use_attention_mask
-        self.use_position_ids = use_position_ids
+        self.use_attention_mask = rbln_config.use_attention_mask if not is_sliding else True
+        self.use_position_ids = rbln_config.use_position_ids
         self.is_sliding = is_sliding
-        self.attn_impl = attn_impl
-        self.kvcache_partition_len = kvcache_partition_len
+        self.attn_impl = rbln_config.attn_impl if not is_sliding else "eager"
+        self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
+        self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
+        self.lora_config = rbln_config.lora_config
 
         setattr(self, self.get_attention_name(), self.create_attention_op())
-        self.kvcache_block_size = kvcache_block_size
         self.__post_init__()
 
+    def _init_lora_weights(self):
+        """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
+        for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+            original_linear = getattr(self._original_mod, proj_name)
+            lora_linear = LoRALinear(
+                original_linear=original_linear,
+                lora_config=self.lora_config,
+                projection_name=proj_name,
+                layer_idx=self.layer_idx,
+            )
+            setattr(self, proj_name, lora_linear)
+
     def get_attention_name(self):
         if self.is_sliding:
             return "sliding_window_attention"
@@ -764,6 +691,7 @@ class DecoderOnlyAttention(nn.Module):
                 self.kvcache_partition_len,
                 self.use_attention_mask,
                 self.use_position_ids,
+                self.quantization,
             )
         elif self.attn_impl == "eager":
             return AttentionOp(
@@ -772,28 +700,46 @@ class DecoderOnlyAttention(nn.Module):
                 self.num_key_value_heads,
                 self.use_attention_mask,
                 self.use_position_ids,
+                self.quantization,
             )
         else:
             raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
 
     def __post_init__(self):
-
-
-
-
-
-
+        # Initialize LoRA weights if configured, which will replace linear layers
+        if self.lora_config:
+            self._init_lora_weights()
+        else:
+            # Use original linear layers if no LoRA
+            self.q_proj = self._original_mod.q_proj
+            self.k_proj = self._original_mod.k_proj
+            self.v_proj = self._original_mod.v_proj
+            self.o_proj = self._original_mod.o_proj
+
+    def projection(
+        self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Projects input hidden states into query, key, and value representations.
 
         Args:
            hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
+            lora_int_id: Adapter ID tensor for LoRA selection [batch_size]
 
         Returns:
             Tuple of (query_states, key_states, value_states)
         """
-
-
-
+        # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+        if self.lora_config:
+            # LoRALinear handles both base projection and LoRA in one forward pass
+            query_states = self.q_proj(hidden_states, lora_int_id)
+            key_states = self.k_proj(hidden_states, lora_int_id)
+            value_states = self.v_proj(hidden_states, lora_int_id)
+        else:
+            # Standard linear projection without LoRA
+            query_states = self.q_proj(hidden_states)
+            key_states = self.k_proj(hidden_states)
+            value_states = self.v_proj(hidden_states)
+
         return query_states, key_states, value_states
 
     def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
@@ -802,6 +748,16 @@ class DecoderOnlyAttention(nn.Module):
     def get_attn_scale(self):
         return 1 / math.sqrt(self.head_dim)
 
+    def maybe_get_kvcache_scale(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        if hasattr(self, "k_proj") and hasattr(self, "v_proj"):
+            k_scale = getattr(self.k_proj, "k_scale", None)
+            v_scale = getattr(self.v_proj, "v_scale", None)
+        else:
+            k_scale = None
+            v_scale = None
+
+        return k_scale, v_scale
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -811,10 +767,11 @@ class DecoderOnlyAttention(nn.Module):
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
         block_tables: Optional[torch.Tensor] = None,
+        lora_int_id: Optional[torch.Tensor] = None,
     ):
         batch_size, query_length, _ = hidden_states.size()
 
-        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+        query_states, key_states, value_states = self.projection(hidden_states=hidden_states, lora_int_id=lora_int_id)
 
         query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
         key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
@@ -831,6 +788,8 @@ class DecoderOnlyAttention(nn.Module):
         if batch_size > 1 and "prefill" in self.phase:
             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
 
+        k_scale, v_scale = self.maybe_get_kvcache_scale()
+
         attn_output = self.get_attention_op()(
             query_states,
             key_states,
@@ -842,9 +801,18 @@ class DecoderOnlyAttention(nn.Module):
             scale=self.scale,
             block_tables=block_tables,
             block_size=self.kvcache_block_size,
+            k_scale=k_scale,
+            v_scale=v_scale,
         )
 
-
+        # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+        if self.lora_config:
+            # LoRALinear handles both base projection and LoRA in one forward pass
+            attn_outputs = self.o_proj(attn_output, lora_int_id)
+        else:
+            # Standard linear projection without LoRA
+            attn_outputs = self.o_proj(attn_output)
+
         return attn_outputs
 
 
@@ -858,7 +826,13 @@ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
 
 class AttentionOp(nn.Module):
     def __init__(
-        self,
+        self,
+        num_heads: int,
+        head_dim: int,
+        num_key_value_heads: int,
+        use_attention_mask: bool,
+        use_position_ids: bool,
+        quantization: Optional[RBLNQuantizationConfig] = None,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -867,10 +841,10 @@ class AttentionOp(nn.Module):
         self.phase = "prefill"
         self.use_attention_mask = use_attention_mask
         self.use_position_ids = use_position_ids
+        self.quantization = quantization
 
     def get_attn_op_name(self):
         phase = "decode" if self.phase == "decode" else "prefill"
-
         if self.use_attention_mask and not self.use_position_ids:
             attn_op_name = "paged_attn_"
         else:
@@ -878,6 +852,9 @@ class AttentionOp(nn.Module):
 
         attn_op_name += phase
 
+        if self.quantization and self.quantization.kv_caches == "fp8":
+            attn_op_name += "_kv_fp8"
+
         return attn_op_name
 
     def forward(
@@ -892,6 +869,8 @@ class AttentionOp(nn.Module):
         scale: torch.Tensor,
         block_tables: torch.Tensor,
         block_size: int,
+        k_scale: Optional[torch.Tensor] = None,
+        v_scale: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Compute attention with static shapes and explicit cache management.
 
@@ -904,6 +883,10 @@ class AttentionOp(nn.Module):
             past_value_state: Previous value cache states
             seq_position: Current position in sequence
             scale: Scale applied to attn weights
+            block_tables: Block tables for paged attention
+            block_size: Block size for paged attention
+            k_scale: Scale applied to key
+            v_scale: Scale applied to value
 
         Returns:
             Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
@@ -940,13 +923,19 @@ class AttentionOp(nn.Module):
             "block_size": block_size,
         }
 
-        if self.use_attention_mask
+        if self.use_attention_mask:
             op_args["mask"] = attn_mask
 
         if self.phase == "prefill" or self.phase == "image_prefill":
             if not self.use_attention_mask or self.use_position_ids:
                 op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
 
+        if self.quantization and self.quantization.kv_caches == "fp8":
+            if past_key_state.dtype != torch.float8_e4m3fn:
+                raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
+            op_args["k_scale"] = k_scale
+            op_args["v_scale"] = v_scale
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
@@ -960,97 +949,6 @@ class AttentionOp(nn.Module):
         return attn_output
 
 
-def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
-    """Slice cos[cache_position], sin[cache_position] vector for the query."""
-    if cache_position.shape[0] > 1:
-        cos_all = []
-        sin_all = []
-        for i in range(cache_position.shape[0]):
-            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-        cos = torch.cat(cos_all, dim=0)
-        sin = torch.cat(sin_all, dim=0)
-    else:
-        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
-        sin = sin[cache_position].unsqueeze(unsqueeze_dim)
-
-    return cos, sin
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin):
-    """Applies Rotary Position Embedding to the query and key tensors."""
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
-    # Partial rotary embedding
-    query_rot, query_pass = (
-        query_states[..., :ndim],
-        query_states[..., ndim:],
-    )
-    key_rot, key_pass = (
-        key_states[..., :ndim],
-        key_states[..., ndim:],
-    )
-
-    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
-    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
-
-    # [batch_size, seq_length, num_heads, head_dim]
-    query_states = torch.cat((query_rot, query_pass), dim=-1)
-    key_states = torch.cat((key_rot, key_pass), dim=-1)
-    return query_states, key_states
-
-
-class RotaryEmbedding(nn.Module):
-    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        max_seq_len_cached: int,
-    ):
-        super().__init__()
-
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            rope_type = "default"
-
-        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
-        cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
-        cache_position_expanded = cache_position[:, None]
-
-        if rope_type == "dynamic":
-            freqs = cache_position_expanded.float() * inv_freq.float()
-        else:
-            inv_freq_expanded = inv_freq[None, :]
-            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
-
-        emb = torch.cat((freqs, freqs), dim=-1)
-
-        cos = emb.cos() * attention_scaling
-        sin = emb.sin() * attention_scaling
-
-        self.register_buffer("_cos_cached", cos, persistent=False)
-        self.register_buffer("_sin_cached", sin, persistent=False)
-
-    def forward(self, x, seq_len):
-        return (
-            self._cos_cached[:seq_len].to(dtype=x.dtype),
-            self._sin_cached[:seq_len].to(dtype=x.dtype),
-        )
-
-
 class FlashAttentionOp(AttentionOp):
     def __init__(
         self,
@@ -1060,6 +958,7 @@ class FlashAttentionOp(AttentionOp):
         kvcache_partition_len: int,
         use_attention_mask: bool,
         use_position_ids: bool,
+        quantization: Optional[RBLNQuantizationConfig] = None,
     ):
         super().__init__(
             num_heads=num_heads,
@@ -1067,6 +966,7 @@ class FlashAttentionOp(AttentionOp):
             num_key_value_heads=num_key_value_heads,
             use_attention_mask=use_attention_mask,
             use_position_ids=use_position_ids,
+            quantization=quantization,
         )
         self.kvcache_partition_size = kvcache_partition_len
 
@@ -1079,6 +979,9 @@ class FlashAttentionOp(AttentionOp):
 
         attn_op_name += phase
 
+        if self.quantization and self.quantization.kv_caches == "fp8":
+            attn_op_name += "_kv_fp8"
+
         return attn_op_name
 
     def forward(
@@ -1093,6 +996,8 @@ class FlashAttentionOp(AttentionOp):
         scale,
         block_tables,
         block_size,
+        k_scale=None,
+        v_scale=None,
     ):
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
@@ -1133,6 +1038,12 @@ class FlashAttentionOp(AttentionOp):
            if not self.use_attention_mask or self.use_position_ids:
                 op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
 
+        if self.quantization and self.quantization.kv_caches == "fp8":
+            if past_key_state.dtype != torch.float8_e4m3fn:
+                raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
+            op_args["k_scale"] = k_scale
+            op_args["v_scale"] = v_scale
+
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
         if attn_op is None:
@@ -1160,14 +1071,19 @@ class SlidingWindowAttentionOp(AttentionOp):
         query_state: torch.Tensor,
         key_state: torch.Tensor,
         value_state: torch.Tensor,
-        attn_mask: torch.Tensor,
+        attn_mask: Optional[torch.Tensor],
         past_key_state: torch.Tensor,
         past_value_state: torch.Tensor,
         seq_position: Tuple[torch.Tensor],
         scale: torch.Tensor,
         block_tables: torch.Tensor,
         block_size: int,
+        k_scale: Optional[torch.Tensor] = None,
+        v_scale: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        assert self.quantization is None, "Sliding window attention does not support quantization"
+        assert k_scale is None and v_scale is None, "Sliding window attention does not support quantization"
+
         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
         key_state = key_state.unsqueeze(2)
         value_state = value_state.unsqueeze(2)
@@ -1199,8 +1115,7 @@ class SlidingWindowAttentionOp(AttentionOp):
         }
 
         if self.phase == "prefill" or self.phase == "image_prefill":
-
-            op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+            op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
 
         attn_op_name = self.get_attn_op_name()
         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
@@ -1213,3 +1128,97 @@ class SlidingWindowAttentionOp(AttentionOp):
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
 
         return attn_output
+
+
+class RotaryEmbedding(nn.Module):
+    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        max_seq_len_cached: int,
+    ):
+        super().__init__()
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            rope_type = "default"
+
+        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+        cache_position = torch.arange(0, max_seq_len_cached)
+        cache_position_expanded = cache_position[:, None]
+
+        if rope_type == "dynamic":
+            freqs = cache_position_expanded.float() * inv_freq.float()
+        else:
+            inv_freq_expanded = inv_freq[None, :]
+            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+
+        cos = emb.cos() * attention_scaling
+        sin = emb.sin() * attention_scaling
+
+        self.register_buffer("_cos_cached", cos, persistent=False)
+        self.register_buffer("_sin_cached", sin, persistent=False)
+
+    def forward(self, x, seq_len):
+        return (
+            self._cos_cached[:seq_len].to(dtype=torch.float32),
+            self._sin_cached[:seq_len].to(dtype=torch.float32),
+        )
+
+
+def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
+    """Slice cos[cache_position], sin[cache_position] vector for the query."""
+    if cache_position.shape[0] > 1:
+        cos_all = []
+        sin_all = []
+        for i in range(cache_position.shape[0]):
+            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+        cos = torch.cat(cos_all, dim=0)
+        sin = torch.cat(sin_all, dim=0)
+    else:
+        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
+        sin = sin[cache_position].unsqueeze(unsqueeze_dim)
+
+    return cos, sin
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+    """Applies Rotary Position Embedding to the query and key tensors."""
+    dtype = q.dtype
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    q_embed = q_embed.to(dtype)
+    k_embed = k_embed.to(dtype)
+    return q_embed, k_embed
+
+
+def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Partial rotary embedding
+    query_rot, query_pass = (
+        query_states[..., :ndim],
+        query_states[..., ndim:],
+    )
+    key_rot, key_pass = (
+        key_states[..., :ndim],
+        key_states[..., ndim:],
+    )
+
+    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+    # [batch_size, seq_length, num_heads, head_dim]
+    query_states = torch.cat((query_rot, query_pass), dim=-1)
+    key_states = torch.cat((key_rot, key_pass), dim=-1)
+    return query_states, key_states