optimum-rbln 0.8.2a4__py3-none-any.whl → 0.8.2a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of optimum-rbln might be problematic.
- optimum/rbln/__init__.py +44 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +4 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +48 -0
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/models/__init__.py +35 -14
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -205
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +569 -366
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +13 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +7 -5
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +82 -59
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -7
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +16 -1
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -2
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +13 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +379 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +163 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +6 -6
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +318 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -3
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +10 -328
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +0 -241
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +0 -10
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +1 -10
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +5 -1
- optimum/rbln/utils/depreacate_utils.py +16 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/RECORD +64 -51
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py

@@ -20,108 +20,13 @@ from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel

 from ....utils import logging
+from ...modeling_attention_utils import DEFAULT_FLASH_ATTN_PARTITION_LENGTH
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from .configuration_decoderonly import CacheImplType


 logger = logging.get_logger(__name__)

-DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
-DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
-MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
-MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
-MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
-MAX_SLIDING_WINDOW_SIZE = 32_768
-
-
-def set_default_values(
-    attn_impl: Optional[str] = None,
-    kvcache_partition_len: Optional[int] = None,
-    kvcache_block_size: Optional[int] = None,
-    max_seq_len: Optional[int] = None,
-) -> Tuple[str, int, int]:
-    if attn_impl is None:
-        attn_impl = "eager"
-
-    if kvcache_partition_len is not None:
-        if attn_impl == "eager":
-            attn_impl = "flash_attn"
-            logger.warning(
-                "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
-                "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
-                "`attn_impl` has been automatically switched to 'flash_attn'."
-            )
-
-    if kvcache_partition_len is None and attn_impl == "flash_attn":
-        kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-
-    if kvcache_block_size is None:
-        if attn_impl == "eager":
-            kvcache_block_size = max_seq_len
-        else:
-            kvcache_block_size = kvcache_partition_len
-
-    return attn_impl, kvcache_partition_len, kvcache_block_size
-
-
-def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
-    if attn_impl not in ["eager", "flash_attn"]:
-        raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn`)")
-
-    ## Checking Constraints...
-    # Constraint of eager attention:
-    # - `max_seq_len` <= 32k
-
-    # Constraints of flash attention:
-    # 1. `max_seq_len` should be multiple of `partition_len`.
-    # 2. 4k <= `partition_len` <= 32k.
-    # 3. `max_seq_len` should be larger then 8k.
-    if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
-        raise ValueError(
-            f"`max_seq_len` is set to {max_seq_len}, "
-            f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
-            f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
-            " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
-        )
-
-    if attn_impl == "flash_attn":
-        if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
-            raise ValueError(
-                f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
-                f"when using 'flash_attn'. Please adjust either value to meet this requirement."
-            )
-        elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
-            raise ValueError(
-                f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
-                f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
-                f"Please provide a valid value within this range."
-            )
-        elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
-            raise ValueError(
-                f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
-                f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
-                "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
-            )
-
-    if kvcache_block_size is not None:
-        if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
-            raise ValueError(
-                f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
-                f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
-            )
-        elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
-            raise ValueError(
-                f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
-                f"must always be set equal to the `max_seq_len` {max_seq_len}."
-            )
-
-
-def validate_sliding_window_size(sliding_window: int, prefill_chunk_size: int):
-    if sliding_window > MAX_SLIDING_WINDOW_SIZE - prefill_chunk_size:
-        raise ValueError(
-            f"Sliding window size ({sliding_window}) must be less than 32768 - prefill_chunk_size ({32768 - prefill_chunk_size})"
-        )
-

 class DecoderOnlyWrapper(nn.Module):
     """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.
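The helpers removed above do not disappear from the package: the new import of DEFAULT_FLASH_ATTN_PARTITION_LENGTH and the new optimum/rbln/transformers/modeling_attention_utils.py (+252) in the file list indicate that the attention defaulting and validation logic now lives in that shared module. As a reference, here is a minimal standalone sketch of the defaulting rules encoded in the removed set_default_values (the function name below is illustrative, not the packaged API):

from typing import Optional, Tuple

DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384

def resolve_attn_settings(
    attn_impl: Optional[str] = None,
    kvcache_partition_len: Optional[int] = None,
    kvcache_block_size: Optional[int] = None,
    max_seq_len: Optional[int] = None,
) -> Tuple[str, Optional[int], Optional[int]]:
    # Mirrors the removed set_default_values(): an explicit partition length implies
    # flash attention, flash attention falls back to the default partition length,
    # and the KV-cache block size follows the attention kind (whole sequence for
    # eager, one partition for flash).
    attn_impl = attn_impl or "eager"
    if kvcache_partition_len is not None and attn_impl == "eager":
        attn_impl = "flash_attn"
    if kvcache_partition_len is None and attn_impl == "flash_attn":
        kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
    if kvcache_block_size is None:
        kvcache_block_size = max_seq_len if attn_impl == "eager" else kvcache_partition_len
    return attn_impl, kvcache_partition_len, kvcache_block_size

assert resolve_attn_settings(kvcache_partition_len=16_384, max_seq_len=32_768) == ("flash_attn", 16_384, 16_384)
assert resolve_attn_settings(max_seq_len=8_192) == ("eager", None, 8_192)

The removed validate_attention_method constraints still bound these values: 'eager' attention caps max_seq_len at 32,768, while 'flash_attn' needs a max_seq_len of at least 8,192 that is an exact multiple (and at least double) of a kvcache_partition_len between 4,096 and 32,768, with kvcache_block_size equal to the partition length.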
@@ -153,7 +58,7 @@ class DecoderOnlyWrapper(nn.Module):

     def __init__(
         self,
-
+        model: PreTrainedModel,
         max_seq_len: int,
         use_rotary_emb: bool,
         attn_impl: str,
@@ -167,7 +72,8 @@ class DecoderOnlyWrapper(nn.Module):
         sliding_window_layers: Optional[List[int]] = None,
     ):
         super().__init__()
-        self.config =
+        self.config = model.config
+        self.is_causal_lm = getattr(model, "lm_head", None) is not None

         if use_rotary_emb:
             rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
@@ -185,6 +91,8 @@ class DecoderOnlyWrapper(nn.Module):
         self.use_inputs_embeds = use_inputs_embeds
         self.sliding_window_layers = sliding_window_layers
         self.cache_impl = cache_impl
+        self.use_global_attention = cache_impl in ["static", "hybrid"]
+        self.use_local_attention = cache_impl in ["hybrid", "sliding_window"]
         self.sliding_window = sliding_window

         if self.attn_impl == "flash_attn":
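The two flags added here are a fixed function of cache_impl: a "static" cache drives only the global (full-sequence) attention path, a "sliding_window" cache only the local path, and "hybrid" both. A standalone sketch of the mapping (plain Python, not the wrapper itself):

def attention_flags(cache_impl: str) -> tuple:
    # Same membership tests as the added lines above.
    use_global_attention = cache_impl in ["static", "hybrid"]
    use_local_attention = cache_impl in ["hybrid", "sliding_window"]
    return use_global_attention, use_local_attention

assert attention_flags("static") == (True, False)
assert attention_flags("sliding_window") == (False, True)
assert attention_flags("hybrid") == (True, True)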
@@ -200,21 +108,21 @@ class DecoderOnlyWrapper(nn.Module):
                 f" or equal to max_seq_len({max_seq_len})!"
             )

-        self.
+        self.model = self.convert_to_rbln_class(model, max_seq_len)
         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
         self._phase = "prefill"

     def get_rotary_emb(self, max_seq_len):
         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)

-    def get_decoder_layers(self,
-        return
+    def get_decoder_layers(self, model: PreTrainedModel):
+        return model.model.layers if self.is_causal_lm else model.layers

     def get_attn_layer(self, layer: nn.Module):
         return layer.self_attn

-    def get_model_layer(self,
-        return
+    def get_model_layer(self, model: PreTrainedModel):
+        return model.model if self.is_causal_lm else model

     def get_rbln_attn_class(self):
         return DecoderOnlyAttention
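Since __init__ now accepts either a *ForCausalLM model (detected through its lm_head) or a bare base model, these accessors choose the matching attribute path: decoder layers live under model.model for a causal LM and directly on the model otherwise. A standalone sketch of that dispatch using stand-in modules (the stub class names are illustrative, not optimum-rbln types):

import torch.nn as nn

class BaseModelStub(nn.Module):
    # Stand-in for a Hugging Face base model: exposes .layers directly.
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Identity(), nn.Identity()])

class CausalLMStub(nn.Module):
    # Stand-in for a *ForCausalLM model: nests the base model and adds lm_head.
    def __init__(self):
        super().__init__()
        self.model = BaseModelStub()
        self.lm_head = nn.Linear(16, 16)

def decoder_layers(model: nn.Module) -> nn.ModuleList:
    # Same dispatch as get_decoder_layers()/get_model_layer() above.
    is_causal_lm = getattr(model, "lm_head", None) is not None
    return model.model.layers if is_causal_lm else model.layers

assert len(decoder_layers(BaseModelStub())) == len(decoder_layers(CausalLMStub())) == 2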
@@ -228,9 +136,9 @@ class DecoderOnlyWrapper(nn.Module):
     def get_rbln_causal_lm_class(self):
         return DecoderOnlyForCausalLM

-    def
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
         new_layers = []
-        for layer_idx, layer in enumerate(self.get_decoder_layers(
+        for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
             is_sliding = layer_idx in self.sliding_window_layers
             new_self_attn = self.get_rbln_attn_class()(
                 self.get_attn_layer(layer),
@@ -247,7 +155,7 @@ class DecoderOnlyWrapper(nn.Module):
             new_layers.append(new_layer)

         new_model = self.get_rbln_model_class()(
-            self.get_model_layer(
+            self.get_model_layer(model),
             new_layers,
             partition_len=self.kvcache_partition_len,
             max_seq_len=max_seq_len,
@@ -255,8 +163,12 @@ class DecoderOnlyWrapper(nn.Module):
             use_learned_pos_emb=self.__class__._use_learned_pos_emb,
             sliding_window_layers=self.sliding_window_layers,
         )
-
-
+
+        if self.is_causal_lm:
+            new_model = self.get_rbln_causal_lm_class()(model, new_model)
+            return new_model
+        else:
+            return new_model

     @property
     def phase(self) -> str:
@@ -265,16 +177,21 @@ class DecoderOnlyWrapper(nn.Module):
     @phase.setter
     def phase(self, phase: str):
         self._phase = phase
-        self.
+        self.model.phase = phase

     def prepare_forward_args(self, *args):
         args = list(args)
         input_ids = None if self.use_inputs_embeds else args.pop(0)
         inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
         cache_position = args.pop(0)
-        global_block_tables = args.pop(0) if self.
-        local_block_tables = args.pop(0) if self.
-        query_position =
+        global_block_tables = args.pop(0) if self.use_global_attention else None
+        local_block_tables = args.pop(0) if self.use_local_attention else None
+        query_position = (
+            args.pop(0)
+            # query_position usage: 1. causal_lm prefill or 2. sliding_window cache_position
+            if ("prefill" in self.phase and (self.is_causal_lm or self.use_local_attention))
+            else None
+        )
         attention_mask = args.pop(0) if self.use_attention_mask else None
         position_ids = args.pop(0) if self.use_position_ids else None
         past_key_values = args
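prepare_forward_args consumes one flat argument list, so the order of the compiled graph inputs is fixed by the pops above: ids or embeddings, cache_position, the block tables and query_position only when the corresponding mode is active, then the optional mask and position ids, with everything left over treated as past_key_values. A standalone sketch that reproduces the expected layout (the function and its defaults are illustrative):

def expected_forward_layout(
    phase: str,
    is_causal_lm: bool,
    use_inputs_embeds: bool = False,
    use_global_attention: bool = True,
    use_local_attention: bool = False,
    use_attention_mask: bool = False,
    use_position_ids: bool = False,
) -> list:
    layout = ["inputs_embeds" if use_inputs_embeds else "input_ids", "cache_position"]
    if use_global_attention:
        layout.append("global_block_tables")
    if use_local_attention:
        layout.append("local_block_tables")
    # query_position is only consumed during prefill, for causal LMs or sliding-window caches.
    if "prefill" in phase and (is_causal_lm or use_local_attention):
        layout.append("query_position")
    if use_attention_mask:
        layout.append("attention_mask")
    if use_position_ids:
        layout.append("position_ids")
    layout.append("*past_key_values")
    return layout

assert expected_forward_layout("prefill", is_causal_lm=True) == [
    "input_ids", "cache_position", "global_block_tables", "query_position", "*past_key_values"
]
assert expected_forward_layout("decode", is_causal_lm=True) == [
    "input_ids", "cache_position", "global_block_tables", "*past_key_values"
]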
@@ -326,7 +243,7 @@ class DecoderOnlyWrapper(nn.Module):
             rotary_emb,
         ) = self.prepare_forward_args(*args)

-        logit = self.
+        logit = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -940,7 +857,7 @@ class AttentionOp(nn.Module):
            "block_size": block_size,
        }

-        if self.use_attention_mask
+        if self.use_attention_mask:
            op_args["mask"] = attn_mask

        if self.phase == "prefill" or self.phase == "image_prefill":
@@ -960,97 +877,6 @@ class AttentionOp(nn.Module):
         return attn_output


-def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
-    """Slice cos[cache_position], sin[cache_position] vector for the query."""
-    if cache_position.shape[0] > 1:
-        cos_all = []
-        sin_all = []
-        for i in range(cache_position.shape[0]):
-            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-        cos = torch.cat(cos_all, dim=0)
-        sin = torch.cat(sin_all, dim=0)
-    else:
-        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
-        sin = sin[cache_position].unsqueeze(unsqueeze_dim)
-
-    return cos, sin
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin):
-    """Applies Rotary Position Embedding to the query and key tensors."""
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
-    # Partial rotary embedding
-    query_rot, query_pass = (
-        query_states[..., :ndim],
-        query_states[..., ndim:],
-    )
-    key_rot, key_pass = (
-        key_states[..., :ndim],
-        key_states[..., ndim:],
-    )
-
-    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
-    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
-
-    # [batch_size, seq_length, num_heads, head_dim]
-    query_states = torch.cat((query_rot, query_pass), dim=-1)
-    key_states = torch.cat((key_rot, key_pass), dim=-1)
-    return query_states, key_states
-
-
-class RotaryEmbedding(nn.Module):
-    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        max_seq_len_cached: int,
-    ):
-        super().__init__()
-
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            rope_type = "default"
-
-        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
-        cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
-        cache_position_expanded = cache_position[:, None]
-
-        if rope_type == "dynamic":
-            freqs = cache_position_expanded.float() * inv_freq.float()
-        else:
-            inv_freq_expanded = inv_freq[None, :]
-            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
-
-        emb = torch.cat((freqs, freqs), dim=-1)
-
-        cos = emb.cos() * attention_scaling
-        sin = emb.sin() * attention_scaling
-
-        self.register_buffer("_cos_cached", cos, persistent=False)
-        self.register_buffer("_sin_cached", sin, persistent=False)
-
-    def forward(self, x, seq_len):
-        return (
-            self._cos_cached[:seq_len].to(dtype=x.dtype),
-            self._sin_cached[:seq_len].to(dtype=x.dtype),
-        )
-
-
 class FlashAttentionOp(AttentionOp):
     def __init__(
         self,
@@ -1213,3 +1039,94 @@ class SlidingWindowAttentionOp(AttentionOp):
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)

         return attn_output
+
+
+class RotaryEmbedding(nn.Module):
+    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        max_seq_len_cached: int,
+    ):
+        super().__init__()
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            rope_type = "default"
+
+        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+        cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
+        cache_position_expanded = cache_position[:, None]
+
+        if rope_type == "dynamic":
+            freqs = cache_position_expanded.float() * inv_freq.float()
+        else:
+            inv_freq_expanded = inv_freq[None, :]
+            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+
+        cos = emb.cos() * attention_scaling
+        sin = emb.sin() * attention_scaling
+
+        self.register_buffer("_cos_cached", cos, persistent=False)
+        self.register_buffer("_sin_cached", sin, persistent=False)
+
+    def forward(self, x, seq_len):
+        return (
+            self._cos_cached[:seq_len].to(dtype=x.dtype),
+            self._sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+
+def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
+    """Slice cos[cache_position], sin[cache_position] vector for the query."""
+    if cache_position.shape[0] > 1:
+        cos_all = []
+        sin_all = []
+        for i in range(cache_position.shape[0]):
+            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+        cos = torch.cat(cos_all, dim=0)
+        sin = torch.cat(sin_all, dim=0)
+    else:
+        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
+        sin = sin[cache_position].unsqueeze(unsqueeze_dim)
+
+    return cos, sin
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+    """Applies Rotary Position Embedding to the query and key tensors."""
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Partial rotary embedding
+    query_rot, query_pass = (
+        query_states[..., :ndim],
+        query_states[..., ndim:],
+    )
+    key_rot, key_pass = (
+        key_states[..., :ndim],
+        key_states[..., ndim:],
+    )
+
+    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+    # [batch_size, seq_length, num_heads, head_dim]
+    query_states = torch.cat((query_rot, query_pass), dim=-1)
+    key_states = torch.cat((key_rot, key_pass), dim=-1)
+    return query_states, key_states
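As a quick sanity check of the relocated RoPE helpers, the snippet below copies rotate_half and apply_rotary_pos_emb exactly as they appear above (rather than importing them from optimum-rbln) and verifies that the rotation preserves query and key norms, as expected since RoPE applies a pure rotation to each two-channel pair:

import torch

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Applies Rotary Position Embedding to the query and key tensors."""
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

# cos/sin are duplicated across the two halves of head_dim, matching
# emb = torch.cat((freqs, freqs), dim=-1) in RotaryEmbedding above.
theta = torch.rand(1, 1, 3, 2)  # [batch, heads, seq, head_dim // 2]
cos = torch.cat((theta.cos(), theta.cos()), dim=-1)
sin = torch.cat((theta.sin(), theta.sin()), dim=-1)
q = torch.randn(1, 1, 3, 4)
k = torch.randn(1, 1, 3, 4)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
assert torch.allclose(k.norm(dim=-1), k_rot.norm(dim=-1), atol=1e-5)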