optimum-rbln 0.8.2a4__py3-none-any.whl → 0.8.2a6__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those published versions.

This release of optimum-rbln has been flagged as potentially problematic.

Files changed (64)
  1. optimum/rbln/__init__.py +44 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +4 -0
  4. optimum/rbln/ops/kv_cache_update.py +5 -0
  5. optimum/rbln/ops/linear.py +7 -0
  6. optimum/rbln/transformers/__init__.py +48 -0
  7. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  8. optimum/rbln/transformers/models/__init__.py +35 -14
  9. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  10. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
  11. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -205
  12. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +569 -366
  13. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  14. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  15. optimum/rbln/transformers/models/gemma/modeling_gemma.py +13 -1
  16. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +7 -5
  17. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +82 -59
  18. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  19. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  20. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -7
  21. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +16 -1
  22. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -2
  23. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  24. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  25. optimum/rbln/transformers/models/llama/modeling_llama.py +13 -1
  26. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  27. optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
  28. optimum/rbln/transformers/models/llava/modeling_llava.py +379 -0
  29. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
  30. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  31. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  32. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  33. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  34. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  35. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  36. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  37. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  38. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  39. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  40. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  41. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +163 -0
  42. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  43. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  44. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  45. optimum/rbln/transformers/models/phi/phi_architecture.py +6 -6
  46. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  47. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  48. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +318 -0
  49. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  50. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  51. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  52. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  53. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -3
  54. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  55. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +10 -328
  56. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +0 -241
  57. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +0 -10
  58. optimum/rbln/transformers/models/whisper/configuration_whisper.py +1 -10
  59. optimum/rbln/transformers/models/whisper/modeling_whisper.py +5 -1
  60. optimum/rbln/utils/depreacate_utils.py +16 -0
  61. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/METADATA +1 -1
  62. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/RECORD +64 -51
  63. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/WHEEL +0 -0
  64. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -20,108 +20,13 @@ from torch import nn
 from transformers import PretrainedConfig, PreTrainedModel
 
 from ....utils import logging
+from ...modeling_attention_utils import DEFAULT_FLASH_ATTN_PARTITION_LENGTH
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from .configuration_decoderonly import CacheImplType
 
 
 logger = logging.get_logger(__name__)
 
-DEFAULT_FLASH_ATTN_PARTITION_LENGTH = 16_384
-DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH = 32_768
-MIN_FLASH_ATTN_MAX_SEQ_LEN = 8_192
-MIN_FLASH_ATTN_PARTITION_LENGTH = 4_096
-MAX_FLASH_ATTN_PARTITION_LENGTH = 32_768
-MAX_SLIDING_WINDOW_SIZE = 32_768
-
-
-def set_default_values(
-    attn_impl: Optional[str] = None,
-    kvcache_partition_len: Optional[int] = None,
-    kvcache_block_size: Optional[int] = None,
-    max_seq_len: Optional[int] = None,
-) -> Tuple[str, int, int]:
-    if attn_impl is None:
-        attn_impl = "eager"
-
-    if kvcache_partition_len is not None:
-        if attn_impl == "eager":
-            attn_impl = "flash_attn"
-            logger.warning(
-                "A non-null `kvcache_partition_len` was provided, but `attn_impl` was not explicitly set or "
-                "set to 'eager'. Since KV cache partitioning is only supported with flash attention, "
-                "`attn_impl` has been automatically switched to 'flash_attn'."
-            )
-
-    if kvcache_partition_len is None and attn_impl == "flash_attn":
-        kvcache_partition_len = DEFAULT_FLASH_ATTN_PARTITION_LENGTH
-
-    if kvcache_block_size is None:
-        if attn_impl == "eager":
-            kvcache_block_size = max_seq_len
-        else:
-            kvcache_block_size = kvcache_partition_len
-
-    return attn_impl, kvcache_partition_len, kvcache_block_size
-
-
-def validate_attention_method(attn_impl: str, kvcache_partition_len: int, kvcache_block_size: int, max_seq_len: int):
-    if attn_impl not in ["eager", "flash_attn"]:
-        raise ValueError(f"Unknown `attn_impl` : {attn_impl}. (Available : 'eager', 'flash_attn`)")
-
-    ## Checking Constraints...
-    # Constraint of eager attention:
-    # - `max_seq_len` <= 32k
-
-    # Constraints of flash attention:
-    # 1. `max_seq_len` should be multiple of `partition_len`.
-    # 2. 4k <= `partition_len` <= 32k.
-    # 3. `max_seq_len` should be larger then 8k.
-    if attn_impl == "eager" and max_seq_len > DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH:
-        raise ValueError(
-            f"`max_seq_len` is set to {max_seq_len}, "
-            f"which exceeds the limit of {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} for 'eager' attention. "
-            f"Please reduce the `max_seq_len` to {DEFAULT_MAX_EAGER_ATTN_SEQUENCE_LENGTH} or lower,"
-            " or consider switching `attn_impl` to 'flash_attn' for larger sequence lengths."
-        )
-
-    if attn_impl == "flash_attn":
-        if max_seq_len // kvcache_partition_len < 2 or max_seq_len % kvcache_partition_len != 0:
-            raise ValueError(
-                f"`max_seq_len` ({max_seq_len}) must be a multiple of `kvcache_partition_len` ({kvcache_partition_len}) "
-                f"when using 'flash_attn'. Please adjust either value to meet this requirement."
-            )
-        elif not (MIN_FLASH_ATTN_PARTITION_LENGTH <= kvcache_partition_len <= MAX_FLASH_ATTN_PARTITION_LENGTH):
-            raise ValueError(
-                f"`kvcache_partition_len` ({kvcache_partition_len}) is out of the supported range for 'flash_attn' "
-                f"({MIN_FLASH_ATTN_PARTITION_LENGTH} <= `kvcache_partition_len` <= {MAX_FLASH_ATTN_PARTITION_LENGTH}). "
-                f"Please provide a valid value within this range."
-            )
-        elif max_seq_len < MIN_FLASH_ATTN_MAX_SEQ_LEN:
-            raise ValueError(
-                f"`max_seq_len` ({max_seq_len}) is too small for 'flash_attn'. The minimum "
-                f"supported value is {MIN_FLASH_ATTN_MAX_SEQ_LEN}. Please increase `max_seq_len` to meet "
-                "this requirement, or consider switching `attn_impl` to 'eager' for shorter lengths."
-            )
-
-    if kvcache_block_size is not None:
-        if attn_impl == "flash_attn" and kvcache_partition_len != kvcache_block_size:
-            raise ValueError(
-                f" When using 'flash attention', the `kvcache_block_size` ({kvcache_block_size}) "
-                f"must always be set equal to the `kvcache_partition_len` {kvcache_partition_len}."
-            )
-        elif attn_impl == "eager" and kvcache_block_size != max_seq_len:
-            raise ValueError(
-                f" When using 'eager attention', the `kvcache_block_size` ({kvcache_block_size}) "
-                f"must always be set equal to the `max_seq_len` {max_seq_len}."
-            )
-
-
-def validate_sliding_window_size(sliding_window: int, prefill_chunk_size: int):
-    if sliding_window > MAX_SLIDING_WINDOW_SIZE - prefill_chunk_size:
-        raise ValueError(
-            f"Sliding window size ({sliding_window}) must be less than 32768 - prefill_chunk_size ({32768 - prefill_chunk_size})"
-        )
-
 
 class DecoderOnlyWrapper(nn.Module):
     """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.
@@ -153,7 +58,7 @@ class DecoderOnlyWrapper(nn.Module):
 
     def __init__(
         self,
-        causal_lm: PreTrainedModel,
+        model: PreTrainedModel,
         max_seq_len: int,
         use_rotary_emb: bool,
         attn_impl: str,
@@ -167,7 +72,8 @@ class DecoderOnlyWrapper(nn.Module):
         sliding_window_layers: Optional[List[int]] = None,
     ):
         super().__init__()
-        self.config = causal_lm.config
+        self.config = model.config
+        self.is_causal_lm = getattr(model, "lm_head", None) is not None
 
         if use_rotary_emb:
             rotary_embs = self.get_rotary_emb(max_seq_len=max_seq_len)
@@ -185,6 +91,8 @@ class DecoderOnlyWrapper(nn.Module):
         self.use_inputs_embeds = use_inputs_embeds
         self.sliding_window_layers = sliding_window_layers
         self.cache_impl = cache_impl
+        self.use_global_attention = cache_impl in ["static", "hybrid"]
+        self.use_local_attention = cache_impl in ["hybrid", "sliding_window"]
         self.sliding_window = sliding_window
 
         if self.attn_impl == "flash_attn":
@@ -200,21 +108,21 @@ class DecoderOnlyWrapper(nn.Module):
                 f" or equal to max_seq_len({max_seq_len})!"
             )
 
-        self.causal_lm = self.convert_to_rbln_causal_lm(causal_lm, max_seq_len)
+        self.model = self.convert_to_rbln_class(model, max_seq_len)
         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
         self._phase = "prefill"
 
     def get_rotary_emb(self, max_seq_len):
         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
 
-    def get_decoder_layers(self, causal_lm: PreTrainedModel):
-        return causal_lm.model.layers
+    def get_decoder_layers(self, model: PreTrainedModel):
+        return model.model.layers if self.is_causal_lm else model.layers
 
     def get_attn_layer(self, layer: nn.Module):
         return layer.self_attn
 
-    def get_model_layer(self, causal_lm: PreTrainedModel):
-        return causal_lm.model
+    def get_model_layer(self, model: PreTrainedModel):
+        return model.model if self.is_causal_lm else model
 
     def get_rbln_attn_class(self):
         return DecoderOnlyAttention
@@ -228,9 +136,9 @@ class DecoderOnlyWrapper(nn.Module):
     def get_rbln_causal_lm_class(self):
         return DecoderOnlyForCausalLM
 
-    def convert_to_rbln_causal_lm(self, causal_lm: PreTrainedModel, max_seq_len: int):
+    def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
         new_layers = []
-        for layer_idx, layer in enumerate(self.get_decoder_layers(causal_lm)):
+        for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
             is_sliding = layer_idx in self.sliding_window_layers
             new_self_attn = self.get_rbln_attn_class()(
                 self.get_attn_layer(layer),
@@ -247,7 +155,7 @@ class DecoderOnlyWrapper(nn.Module):
             new_layers.append(new_layer)
 
         new_model = self.get_rbln_model_class()(
-            self.get_model_layer(causal_lm),
+            self.get_model_layer(model),
             new_layers,
             partition_len=self.kvcache_partition_len,
             max_seq_len=max_seq_len,
@@ -255,8 +163,12 @@ class DecoderOnlyWrapper(nn.Module):
             use_learned_pos_emb=self.__class__._use_learned_pos_emb,
             sliding_window_layers=self.sliding_window_layers,
         )
-        new_causal_lm = self.get_rbln_causal_lm_class()(causal_lm, new_model)
-        return new_causal_lm
+
+        if self.is_causal_lm:
+            new_model = self.get_rbln_causal_lm_class()(model, new_model)
+            return new_model
+        else:
+            return new_model
 
     @property
     def phase(self) -> str:
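
Note: the causal_lm → model rename throughout this class reflects that the wrapper can now also wrap a bare decoder backbone without an lm_head; the new is_causal_lm flag decides where the decoder layers live and whether the converted module is wrapped in the RBLN causal-LM class. A minimal sketch of that dispatch using tiny randomly initialized Llama modules (the configuration values below are made up for illustration):

    # Sketch of the new dispatch, not code from this package.
    from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel

    cfg = LlamaConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2,
                      num_attention_heads=4, num_key_value_heads=4, vocab_size=128)
    causal_lm = LlamaForCausalLM(cfg)   # has .lm_head, decoder layers under .model.layers
    base_model = LlamaModel(cfg)        # no .lm_head, decoder layers directly under .layers

    for m in (causal_lm, base_model):
        is_causal_lm = getattr(m, "lm_head", None) is not None
        layers = m.model.layers if is_causal_lm else m.layers
        print(type(m).__name__, "is_causal_lm =", is_causal_lm, "layers =", len(layers))
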
@@ -265,16 +177,21 @@ class DecoderOnlyWrapper(nn.Module):
     @phase.setter
     def phase(self, phase: str):
         self._phase = phase
-        self.causal_lm.phase = phase
+        self.model.phase = phase
 
     def prepare_forward_args(self, *args):
         args = list(args)
         input_ids = None if self.use_inputs_embeds else args.pop(0)
         inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
         cache_position = args.pop(0)
-        global_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "static"] else None
-        local_block_tables = args.pop(0) if self.cache_impl in ["hybrid", "sliding_window"] else None
-        query_position = args.pop(0) if "prefill" in self.phase else None
+        global_block_tables = args.pop(0) if self.use_global_attention else None
+        local_block_tables = args.pop(0) if self.use_local_attention else None
+        query_position = (
+            args.pop(0)
+            # query_position usage: 1. causal_lm prefill or 2. sliding_window cache_position
+            if ("prefill" in self.phase and (self.is_causal_lm or self.use_local_attention))
+            else None
+        )
         attention_mask = args.pop(0) if self.use_attention_mask else None
         position_ids = args.pop(0) if self.use_position_ids else None
         past_key_values = args
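
Note: the compiled runtime still passes a flat tuple of tensors, so the pop order above defines the input layout; block tables are now gated by the precomputed use_global_attention / use_local_attention flags, and query_position is only consumed during prefill for causal-LM or sliding-window cases. A hypothetical helper that spells out the resulting layout (not part of the package):

    # Sketch of the flattened argument order that prepare_forward_args() expects
    # (mirrors the pops above; illustrative only).
    def expected_forward_layout(use_inputs_embeds, use_global_attention, use_local_attention,
                                is_prefill, is_causal_lm, use_attention_mask, use_position_ids):
        layout = ["inputs_embeds" if use_inputs_embeds else "input_ids", "cache_position"]
        if use_global_attention:            # cache_impl in ("static", "hybrid")
            layout.append("global_block_tables")
        if use_local_attention:             # cache_impl in ("hybrid", "sliding_window")
            layout.append("local_block_tables")
        if is_prefill and (is_causal_lm or use_local_attention):
            layout.append("query_position")
        if use_attention_mask:
            layout.append("attention_mask")
        if use_position_ids:
            layout.append("position_ids")
        return layout + ["*past_key_values"]

    print(expected_forward_layout(False, True, False, True, True, False, False))
    # ['input_ids', 'cache_position', 'global_block_tables', 'query_position', '*past_key_values']
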
@@ -326,7 +243,7 @@ class DecoderOnlyWrapper(nn.Module):
             rotary_emb,
         ) = self.prepare_forward_args(*args)
 
-        logit = self.causal_lm(
+        logit = self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
@@ -940,7 +857,7 @@ class AttentionOp(nn.Module):
             "block_size": block_size,
         }
 
-        if self.use_attention_mask != self.use_position_ids:
+        if self.use_attention_mask:
             op_args["mask"] = attn_mask
 
         if self.phase == "prefill" or self.phase == "image_prefill":
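
Note: the mask is now attached whenever use_attention_mask is set; the old condition skipped it when use_attention_mask and use_position_ids were both true. A quick comparison of the two conditions, for illustration only:

    for use_attention_mask in (False, True):
        for use_position_ids in (False, True):
            old = use_attention_mask != use_position_ids  # previous gate
            new = use_attention_mask                      # gate after this change
            print(f"mask={use_attention_mask} pos_ids={use_position_ids} -> old={old} new={new}")
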
@@ -960,97 +877,6 @@ class AttentionOp(nn.Module):
         return attn_output
 
 
-def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
-    """Slice cos[cache_position], sin[cache_position] vector for the query."""
-    if cache_position.shape[0] > 1:
-        cos_all = []
-        sin_all = []
-        for i in range(cache_position.shape[0]):
-            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
-        cos = torch.cat(cos_all, dim=0)
-        sin = torch.cat(sin_all, dim=0)
-    else:
-        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
-        sin = sin[cache_position].unsqueeze(unsqueeze_dim)
-
-    return cos, sin
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin):
-    """Applies Rotary Position Embedding to the query and key tensors."""
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
-    # Partial rotary embedding
-    query_rot, query_pass = (
-        query_states[..., :ndim],
-        query_states[..., ndim:],
-    )
-    key_rot, key_pass = (
-        key_states[..., :ndim],
-        key_states[..., ndim:],
-    )
-
-    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
-    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
-
-    # [batch_size, seq_length, num_heads, head_dim]
-    query_states = torch.cat((query_rot, query_pass), dim=-1)
-    key_states = torch.cat((key_rot, key_pass), dim=-1)
-    return query_states, key_states
-
-
-class RotaryEmbedding(nn.Module):
-    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        max_seq_len_cached: int,
-    ):
-        super().__init__()
-
-        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
-            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
-        else:
-            rope_type = "default"
-
-        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
-        cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
-        cache_position_expanded = cache_position[:, None]
-
-        if rope_type == "dynamic":
-            freqs = cache_position_expanded.float() * inv_freq.float()
-        else:
-            inv_freq_expanded = inv_freq[None, :]
-            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
-
-        emb = torch.cat((freqs, freqs), dim=-1)
-
-        cos = emb.cos() * attention_scaling
-        sin = emb.sin() * attention_scaling
-
-        self.register_buffer("_cos_cached", cos, persistent=False)
-        self.register_buffer("_sin_cached", sin, persistent=False)
-
-    def forward(self, x, seq_len):
-        return (
-            self._cos_cached[:seq_len].to(dtype=x.dtype),
-            self._sin_cached[:seq_len].to(dtype=x.dtype),
-        )
-
-
 class FlashAttentionOp(AttentionOp):
     def __init__(
         self,
@@ -1213,3 +1039,94 @@ class SlidingWindowAttentionOp(AttentionOp):
         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
 
         return attn_output
+
+
+class RotaryEmbedding(nn.Module):
+    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        max_seq_len_cached: int,
+    ):
+        super().__init__()
+
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            rope_type = "default"
+
+        inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+        cache_position = torch.arange(0, max_seq_len_cached, dtype=torch.float32)
+        cache_position_expanded = cache_position[:, None]
+
+        if rope_type == "dynamic":
+            freqs = cache_position_expanded.float() * inv_freq.float()
+        else:
+            inv_freq_expanded = inv_freq[None, :]
+            freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+
+        cos = emb.cos() * attention_scaling
+        sin = emb.sin() * attention_scaling
+
+        self.register_buffer("_cos_cached", cos, persistent=False)
+        self.register_buffer("_sin_cached", sin, persistent=False)
+
+    def forward(self, x, seq_len):
+        return (
+            self._cos_cached[:seq_len].to(dtype=x.dtype),
+            self._sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+
+def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
+    """Slice cos[cache_position], sin[cache_position] vector for the query."""
+    if cache_position.shape[0] > 1:
+        cos_all = []
+        sin_all = []
+        for i in range(cache_position.shape[0]):
+            cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+            sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+        cos = torch.cat(cos_all, dim=0)
+        sin = torch.cat(sin_all, dim=0)
+    else:
+        cos = cos[cache_position].unsqueeze(unsqueeze_dim)
+        sin = sin[cache_position].unsqueeze(unsqueeze_dim)
+
+    return cos, sin
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+    """Applies Rotary Position Embedding to the query and key tensors."""
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Partial rotary embedding
+    query_rot, query_pass = (
+        query_states[..., :ndim],
+        query_states[..., ndim:],
+    )
+    key_rot, key_pass = (
+        key_states[..., :ndim],
+        key_states[..., ndim:],
+    )
+
+    # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+    query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+    # [batch_size, seq_length, num_heads, head_dim]
+    query_states = torch.cat((query_rot, query_pass), dim=-1)
+    key_states = torch.cat((key_rot, key_pass), dim=-1)
+    return query_states, key_states
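
Note: the rotary-embedding block appended here is the code removed from the middle of the file earlier in this diff, relocated below the attention ops rather than rewritten. A standalone sanity check of what these helpers compute (reimplemented here for illustration; the assertion uses the fact that rotary embedding is a norm-preserving rotation of each query vector):

    import torch

    def rotate_half(x):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(q, k, cos, sin):
        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

    batch, heads, seq, head_dim = 1, 2, 4, 8
    q = torch.randn(batch, heads, seq, head_dim)
    k = torch.randn(batch, heads, seq, head_dim)
    # cos/sin laid out the way RotaryEmbedding caches them, sliced to the query positions
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.arange(seq).float()[:, None] @ inv_freq[None, :]
    emb = torch.cat((freqs, freqs), dim=-1)
    cos, sin = emb.cos()[None, None], emb.sin()[None, None]   # broadcast over batch/heads
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    # Rotation preserves the per-position norm of every query vector
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)
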