optimum-rbln 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (49)
  1. optimum/rbln/__init__.py +17 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +3 -3
  5. optimum/rbln/diffusers/models/controlnet.py +7 -3
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +5 -5
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +23 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/modeling_alias.py +19 -1
  13. optimum/rbln/modeling_base.py +162 -18
  14. optimum/rbln/transformers/__init__.py +8 -0
  15. optimum/rbln/transformers/cache_utils.py +111 -0
  16. optimum/rbln/transformers/generation/utils.py +0 -2
  17. optimum/rbln/transformers/models/__init__.py +3 -0
  18. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  19. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  20. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  21. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +516 -0
  22. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +464 -0
  23. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  24. optimum/rbln/transformers/models/gemma/gemma_architecture.py +123 -0
  25. optimum/rbln/transformers/models/gemma/modeling_gemma.py +67 -0
  26. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
  27. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +10 -257
  28. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
  29. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -440
  30. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  31. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  32. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  33. optimum/rbln/transformers/models/midm/modeling_midm.py +10 -325
  34. optimum/rbln/transformers/models/mistral/__init__.py +24 -0
  35. optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
  36. optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
  37. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  38. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  39. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  40. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +131 -0
  41. optimum/rbln/transformers/utils/__init__.py +0 -0
  42. optimum/rbln/transformers/utils/rbln_quantization.py +109 -0
  43. optimum/rbln/utils/import_utils.py +1 -4
  44. optimum/rbln/utils/runtime_utils.py +2 -1
  45. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA +11 -5
  46. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/RECORD +48 -35
  47. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
  48. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/WHEEL +0 -0
  49. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py (new file)
@@ -0,0 +1,516 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ import math
+ from typing import Dict, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers.modeling_outputs import (
+     BaseModelOutputWithPast,
+ )
+
+ from ...cache_utils import RebelDynamicCache
+
+
+ class DecoderOnlyWrapper(torch.nn.Module):
+     def __init__(self, model, max_seq_len):
+         super().__init__()
+         self.config = model.config
+         self.model = model.model
+         self.lm_head = model.lm_head
+
+         self.head_dim = (
+             self.config.head_dim
+             if hasattr(self.config, "head_dim")
+             else self.config.hidden_size // self.config.num_attention_heads
+         )
+         self.max_position_embeddings = (
+             self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
+         )
+         self.max_seq_len = max_seq_len
+         self.rope_scaling = getattr(self.config, "rope_scaling", None)
+         self.rotary_emb = self._init_rope()
+
+     def _init_rope(self):
+         if self.rope_scaling is None:
+             rotary_emb = RotaryEmbedding(
+                 self.head_dim,
+                 max_position_embeddings=self.max_position_embeddings,
+                 base=self.config.rope_theta,
+             )
+         else:
+             scaling_type = self.rope_scaling["type"]
+             scaling_factor = self.rope_scaling["factor"]
+             if scaling_type == "linear":
+                 rotary_emb = LinearScalingRotaryEmbedding(
+                     self.head_dim,
+                     max_position_embeddings=self.max_position_embeddings,
+                     scaling_factor=scaling_factor,
+                     base=self.config.rope_theta,
+                     max_seq_len=self.max_seq_len,
+                 )
+             elif scaling_type == "dynamic":
+                 rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                     self.head_dim,
+                     max_position_embeddings=self.max_position_embeddings,
+                     scaling_factor=scaling_factor,
+                     base=self.config.rope_theta,
+                     max_seq_len=self.max_seq_len,
+                 )
+             else:
+                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+         return rotary_emb
+
+     def get_forward_dict(self):
+         forward_dict = {
+             "wrapper": DecoderOnlyModel.forward,
+             "model": DecoderOnlyDecoderLayer.forward,
+             "decoder_layer": DecoderOnlyAttention.forward,
+         }
+         return forward_dict
+
+     def forward(
+         self,
+         input_ids,
+         attention_mask,
+         cache_position,
+         batch_position,
+         *past_key_values,
+     ):
+         if input_ids.shape[1] == 1:
+             rbln_batch_position = None
+         else:
+             rbln_batch_position = batch_position
+
+         # Formatting list of past_kv to DynamicCache class.
+         past_key_values = RebelDynamicCache.from_input_format(
+             cache_position,
+             self.config.num_hidden_layers,
+             *past_key_values,
+         )
+
+         forward_dict = self.get_forward_dict()
+         outputs = forward_dict["wrapper"](
+             self.model,
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=cache_position,
+             past_key_values=past_key_values,
+             batch_ids=rbln_batch_position,
+             rotary_pos_emb=self.rotary_emb,
+             forward_dict=forward_dict,
+         )
+
+         hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+
+         output = (logits,) + outputs[1:]
+
+         return output, batch_position
+
+
+ class DecoderOnlyAttention:
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         past_key_value: Optional[RebelDynamicCache] = None,
+         batch_index: Optional[int] = None,
+         output_attentions: bool = False,
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         bsz, q_len, _ = hidden_states.size()
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+         # Decoder
+         if (batch_index is None or batch_index == -1) and bsz > 1:
+             all_key_states = []
+             all_value_states = []
+             all_attn_output = []
+
+             for b in range(bsz):
+                 query_state = query_states[b].unsqueeze(0)
+                 attn_mask = attention_mask[b].unsqueeze(0)
+                 key_state = key_states[b].unsqueeze(0)
+                 value_state = value_states[b].unsqueeze(0)
+
+                 # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+                 key_state = key_state.unsqueeze(2)
+                 value_state = value_state.unsqueeze(2)
+                 attn_mask = attn_mask.unsqueeze(2)
+
+                 query_state = query_state.view(
+                     1,
+                     self.num_key_value_heads,
+                     self.num_heads // self.num_key_value_heads,
+                     q_len,
+                     self.head_dim,
+                 )
+
+                 key_state, value_state = past_key_value.update(
+                     key_state,
+                     value_state,
+                     self.layer_idx,
+                     b,
+                 )
+
+                 # reshape for removing repeat_kv
+                 attn_weight = torch.matmul(query_state, key_state.transpose(3, 4)) / math.sqrt(self.head_dim)
+
+                 attn_weight = attn_weight + attn_mask
+
+                 # upcast attention to fp32
+                 attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_states.dtype)
+                 attn_output = torch.matmul(attn_weight, value_state)
+
+                 # reshape for removing repeat_kv
+                 attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
+
+                 attn_output = attn_output.transpose(1, 2).contiguous()
+                 attn_output = attn_output.reshape(1, q_len, self.num_heads * self.head_dim)
+
+                 all_key_states.append(key_state)
+                 all_value_states.append(value_state)
+                 all_attn_output.append(attn_output)
+
+             key_states = torch.cat(all_key_states, dim=0)
+             value_states = torch.cat(all_value_states, dim=0)
+             attn_output = torch.cat(all_attn_output, dim=0)
+
+         else:
+             if batch_index is None or batch_index == -1:
+                 batch_index = 0
+
+             # reshape for removing repeat_kv
+             key_states = key_states.unsqueeze(2)
+             value_states = value_states.unsqueeze(2)
+             attention_mask = attention_mask.unsqueeze(2)
+             query_states = query_states.view(
+                 1,
+                 self.num_key_value_heads,
+                 self.num_heads // self.num_key_value_heads,
+                 q_len,
+                 self.head_dim,
+             )
+
+             key_states, value_states = past_key_value.update(
+                 key_states,
+                 value_states,
+                 self.layer_idx,
+                 batch_index,
+                 read_first_step=True,
+             )
+
+             attn_weight = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)
+             attn_weight = attn_weight + attention_mask
+
+             # upcast attention to fp32
+             attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_states.dtype)
+             attn_output = torch.matmul(attn_weight, value_states)
+
+             # reshape for removing repeat_kv
+             attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
+             attn_output = attn_output.transpose(1, 2).contiguous()
+             attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
+
+         attn_output = self.o_proj(attn_output)
+
+         if not output_attentions:
+             attn_weight = None
+
+         return attn_output, attn_weight, key_states, value_states
+
+
+ class DecoderOnlyDecoderLayer:
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         layer_idx: int,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_value: Optional[RebelDynamicCache] = None,
+         output_attentions: Optional[bool] = None,
+         use_cache: Optional[bool] = None,
+         batch_ids: Optional[torch.LongTensor] = None,
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         forward_dict: Optional[Dict[str, classmethod]] = None,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+         residual = hidden_states
+
+         hidden_states = self.input_layernorm(hidden_states)
+
+         hidden_states, self_attn_weight, k, v = forward_dict["decoder_layer"](
+             self.self_attn,
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_value=past_key_value,
+             output_attentions=output_attentions,
+             batch_index=batch_ids,
+             use_cache=use_cache,
+             cos=cos,
+             sin=sin,
+             **kwargs,
+         )
+         past_key_value.assign(k, v, layer_idx)
+
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         outputs = (hidden_states,)
+
+         if output_attentions:
+             outputs += (self_attn_weight,)
+
+         if use_cache:
+             outputs += (past_key_value,)
+
+         return outputs
+
+
+ class DecoderOnlyModel:
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[RebelDynamicCache] = None,
+         batch_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = True,
+         output_attentions: Optional[bool] = False,
+         output_hidden_states: Optional[bool] = False,
+         forward_dict: Optional[Dict[str, classmethod]] = None,
+         rotary_pos_emb=None,
+     ) -> BaseModelOutputWithPast:
+         # embed positions
+         inputs_embeds = self.embed_tokens(input_ids)
+         hidden_states = inputs_embeds
+         attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
+
+         # get cos,sin vector
+         cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
+         cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+
+         # decoder layers
+         all_hidden_states = () if output_hidden_states else None
+         all_self_attns = () if output_attentions else None
+
+         for layer_idx, decoder_layer in enumerate(self.layers):
+             if output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+             layer_outputs = forward_dict["model"](
+                 decoder_layer,
+                 hidden_states,
+                 layer_idx,
+                 attention_mask=attention_mask,
+                 position_ids=position_ids,
+                 past_key_value=past_key_values,
+                 output_attentions=output_attentions,
+                 use_cache=use_cache,
+                 batch_ids=batch_ids,
+                 cos=cos,
+                 sin=sin,
+                 forward_dict=forward_dict,
+             )
+
+             hidden_states = layer_outputs[0]
+
+             updated_cache = layer_outputs[2 if output_attentions else 1]
+
+             if output_attentions:
+                 all_self_attns += (layer_outputs[1],)
+
+         hidden_states = self.norm(hidden_states)
+
+         # add hidden states from the last decoder layer
+         if output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
+         next_cache = updated_cache.to_legacy_cache()
+
+         return BaseModelOutputWithPast(
+             last_hidden_state=hidden_states,
+             past_key_values=next_cache,
+             hidden_states=all_hidden_states,
+             attentions=all_self_attns,
+         )
+
+
+ def slice_and_unsqueeze_cos_sin(cos, sin, position_ids, unsqueeze_dim=1):
+     """Slice cos[position_ids], sin[position_ids] vector for the query."""
+     if position_ids.shape[0] > 1:
+         cos_all = []
+         sin_all = []
+         for i in range(position_ids.shape[0]):
+             cos_all.append(cos[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
+             sin_all.append(sin[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
+         cos = torch.cat(cos_all, dim=0)
+         sin = torch.cat(sin_all, dim=0)
+     else:
+         cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+         sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+
+     return cos, sin
+
+
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin):
+     """Applies Rotary Position Embedding to the query and key tensors."""
+
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ class RotaryEmbedding(nn.Module):
+     def __init__(
+         self,
+         dim,
+         max_position_embeddings=2048,
+         base=10000,
+         device=None,
+         scaling_factor=1.0,
+     ):
+         super().__init__()
+
+         self.scaling_factor = scaling_factor
+         self.dim = dim
+         self.max_position_embeddings = max_position_embeddings
+         self.base = base
+         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+         # Build here to make `torch.jit.trace` work.
+         device = self.inv_freq.device
+
+         positions_ids = torch.arange(self.max_position_embeddings, device=device, dtype=self.inv_freq.dtype)
+         freqs = torch.outer(positions_ids, self.inv_freq)
+         # Different from paper, but it uses a different permutation in order to obtain the same calculation
+         emb = torch.cat((freqs, freqs), dim=-1)
+
+         self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
+         self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
+
+     def forward(self, x, seq_len):
+         return (
+             self._cos_cached[:seq_len].to(dtype=x.dtype),
+             self._sin_cached[:seq_len].to(dtype=x.dtype),
+         )
+
+
+ class LinearScalingRotaryEmbedding(RotaryEmbedding):
+     """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+     def __init__(
+         self,
+         dim,
+         max_position_embeddings=2048,
+         base=10000,
+         device=None,
+         scaling_factor=1.0,
+         max_seq_len=2048,
+     ):
+         super().__init__(
+             dim,
+             max_position_embeddings=max_position_embeddings,
+             base=base,
+             scaling_factor=scaling_factor,
+         )
+         # difference to the original RoPE: a scaling factor is applied to the position ids
+         if max_seq_len > max_position_embeddings:
+             positions_ids = torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+             positions_ids = positions_ids / self.scaling_factor
+             freqs = torch.outer(positions_ids, self.inv_freq)
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos()
+             sin = emb.sin()
+
+             self._cos_cached = torch.cat([self._cos_cached, cos[max_position_embeddings:]], dim=0)
+             self._sin_cached = torch.cat([self._sin_cached, sin[max_position_embeddings:]], dim=0)
+
+
+ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
+     """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+     def __init__(
+         self,
+         dim,
+         max_position_embeddings=2048,
+         base=10000,
+         device=None,
+         scaling_factor=1.0,
+         max_seq_len=2048,
+     ):
+         super().__init__(
+             dim,
+             max_position_embeddings=max_position_embeddings,
+             base=base,
+             scaling_factor=scaling_factor,
+         )
+         # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+         device = self.inv_freq.device
+         dtype = self.inv_freq.dtype
+         if max_seq_len > max_position_embeddings:
+             position_ids = torch.arange(max_position_embeddings, max_seq_len, dtype=dtype).view(-1, 1)
+             seq_len = position_ids + 1
+             base = self.base * (
+                 (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+             ) ** (self.dim / (self.dim - 2))
+
+             inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+
+             freqs = position_ids * inv_freq
+             emb = torch.cat((freqs, freqs), dim=-1)
+             cos = emb.cos()
+             sin = emb.sin()
+
+             self._cos_cached = torch.cat([self._cos_cached, cos], dim=0)
+             self._sin_cached = torch.cat([self._sin_cached, sin], dim=0)
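
For orientation, the RoPE helpers introduced in this file can be exercised on their own. The sketch below is not part of the diff; the head counts, head_dim, and sequence length are assumed example values, chosen to match the (batch, heads, q_len, head_dim) layout and the broadcasting used by slice_and_unsqueeze_cos_sin and apply_rotary_pos_emb.

# Illustrative sketch only -- not shipped in optimum-rbln 0.1.9; shapes are assumed example values.
import torch

head_dim, num_heads, num_kv_heads, q_len = 128, 32, 8, 7

rope = RotaryEmbedding(dim=head_dim, max_position_embeddings=2048, base=10000)

# (batch, heads, q_len, head_dim), matching the .view(...).transpose(1, 2) reshapes above
q = torch.randn(1, num_heads, q_len, head_dim)
k = torch.randn(1, num_kv_heads, q_len, head_dim)
position_ids = torch.arange(q_len).unsqueeze(0)  # (batch, q_len)

cos, sin = rope(q, seq_len=2048)                                # cached (seq_len, head_dim) tables
cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)  # -> (1, 1, q_len, head_dim)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)             # broadcasts over the head axis

assert q_rot.shape == q.shape and k_rot.shape == k.shape

The resulting cos/sin pair is the same object that DecoderOnlyModel.forward computes once per call and hands to every decoder layer through forward_dict.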