optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (39)
  1. optimum/rbln/__init__.py +14 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/controlnet.py +3 -0
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  11. optimum/rbln/modeling_alias.py +14 -0
  12. optimum/rbln/modeling_base.py +110 -0
  13. optimum/rbln/transformers/__init__.py +6 -0
  14. optimum/rbln/transformers/cache_utils.py +111 -0
  15. optimum/rbln/transformers/generation/utils.py +0 -2
  16. optimum/rbln/transformers/models/__init__.py +2 -0
  17. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  18. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  19. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  20. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  21. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  22. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  23. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  24. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
  25. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
  26. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
  27. optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
  28. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  29. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  30. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  31. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
  32. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  33. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  34. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  35. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
  36. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
  37. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
  38. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
  39. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py (new file)
@@ -0,0 +1,515 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import math
+from typing import Dict, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+)
+
+from ...cache_utils import RebelDynamicCache
+
+
+class DecoderOnlyWrapper(torch.nn.Module):
+    def __init__(self, model, max_seq_len):
+        super().__init__()
+        self.config = model.config
+        self.model = model.model
+        self.lm_head = model.lm_head
+
+        self.head_dim = (
+            self.config.head_dim
+            if hasattr(self.config, "head_dim")
+            else self.config.hidden_size // self.config.num_attention_heads
+        )
+        self.max_position_embeddings = (
+            self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
+        )
+        self.max_seq_len = max_seq_len
+        self.rotary_emb = self._init_rope()
+
+    def _init_rope(self):
+        if self.config.rope_scaling is None:
+            rotary_emb = RotaryEmbedding(
+                self.head_dim,
+                max_position_embeddings=self.max_position_embeddings,
+                base=self.config.rope_theta,
+            )
+        else:
+            scaling_type = self.config.rope_scaling["type"]
+            scaling_factor = self.config.rope_scaling["factor"]
+            if scaling_type == "linear":
+                rotary_emb = LinearScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.config.rope_theta,
+                    max_seq_len=self.max_seq_len,
+                )
+            elif scaling_type == "dynamic":
+                rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                    self.head_dim,
+                    max_position_embeddings=self.max_position_embeddings,
+                    scaling_factor=scaling_factor,
+                    base=self.config.rope_theta,
+                    max_seq_len=self.max_seq_len,
+                )
+            else:
+                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
+
+        return rotary_emb
+
+    def get_forward_dict(self):
+        forward_dict = {
+            "wrapper": DecoderOnlyModel.forward,
+            "model": DecoderOnlyDecoderLayer.forward,
+            "decoder_layer": DecoderOnlyAttention.forward,
+        }
+        return forward_dict
+
+    def forward(
+        self,
+        input_ids,
+        attention_mask,
+        cache_position,
+        batch_position,
+        *past_key_values,
+    ):
+        if input_ids.shape[1] == 1:
+            rbln_batch_position = None
+        else:
+            rbln_batch_position = batch_position
+
+        # Formatting list of past_kv to DynamicCache class.
+        past_key_values = RebelDynamicCache.from_input_format(
+            cache_position,
+            self.config.num_hidden_layers,
+            *past_key_values,
+        )
+
+        forward_dict = self.get_forward_dict()
+        outputs = forward_dict["wrapper"](
+            self.model,
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=cache_position,
+            past_key_values=past_key_values,
+            batch_ids=rbln_batch_position,
+            rotary_pos_emb=self.rotary_emb,
+            forward_dict=forward_dict,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        output = (logits,) + outputs[1:]
+
+        return output, batch_position
+
+
+class DecoderOnlyAttention:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        past_key_value: Optional[RebelDynamicCache] = None,
+        batch_index: Optional[int] = None,
+        output_attentions: bool = False,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        # Decoder
+        if (batch_index is None or batch_index == -1) and bsz > 1:
+            all_key_states = []
+            all_value_states = []
+            all_attn_output = []
+
+            for b in range(bsz):
+                query_state = query_states[b].unsqueeze(0)
+                attn_mask = attention_mask[b].unsqueeze(0)
+                key_state = key_states[b].unsqueeze(0)
+                value_state = value_states[b].unsqueeze(0)
+
+                # reshape for removing repeat_kv (batch=1, num_head, 1, q_len=1, head_dim)
+                key_state = key_state.unsqueeze(2)
+                value_state = value_state.unsqueeze(2)
+                attn_mask = attn_mask.unsqueeze(2)
+
+                query_state = query_state.view(
+                    1,
+                    self.num_key_value_heads,
+                    self.num_heads // self.num_key_value_heads,
+                    q_len,
+                    self.head_dim,
+                )
+
+                key_state, value_state = past_key_value.update(
+                    key_state,
+                    value_state,
+                    self.layer_idx,
+                    b,
+                )
+
+                # reshape for removing repeat_kv
+                attn_weight = torch.matmul(query_state, key_state.transpose(3, 4)) / math.sqrt(self.head_dim)
+
+                attn_weight = attn_weight + attn_mask
+
+                # upcast attention to fp32
+                attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_states.dtype)
+                attn_output = torch.matmul(attn_weight, value_state)
+
+                # reshape for removing repeat_kv
+                attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
+
+                attn_output = attn_output.transpose(1, 2).contiguous()
+                attn_output = attn_output.reshape(1, q_len, self.num_heads * self.head_dim)
+
+                all_key_states.append(key_state)
+                all_value_states.append(value_state)
+                all_attn_output.append(attn_output)
+
+            key_states = torch.cat(all_key_states, dim=0)
+            value_states = torch.cat(all_value_states, dim=0)
+            attn_output = torch.cat(all_attn_output, dim=0)
+
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
+
+            # reshape for removing repeat_kv
+            key_states = key_states.unsqueeze(2)
+            value_states = value_states.unsqueeze(2)
+            attention_mask = attention_mask.unsqueeze(2)
+            query_states = query_states.view(
+                1,
+                self.num_key_value_heads,
+                self.num_heads // self.num_key_value_heads,
+                q_len,
+                self.head_dim,
+            )
+
+            key_states, value_states = past_key_value.update(
+                key_states,
+                value_states,
+                self.layer_idx,
+                batch_index,
+                read_first_step=True,
+            )
+
+            attn_weight = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)
+            attn_weight = attn_weight + attention_mask
+
+            # upcast attention to fp32
+            attn_weight = nn.functional.softmax(attn_weight, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weight, value_states)
+
+            # reshape for removing repeat_kv
+            attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
+            attn_output = attn_output.transpose(1, 2).contiguous()
+            attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weight = None
+
+        return attn_output, attn_weight, key_states, value_states
+
+
+class DecoderOnlyDecoderLayer:
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        layer_idx: int,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[RebelDynamicCache] = None,
+        output_attentions: Optional[bool] = None,
+        use_cache: Optional[bool] = None,
+        batch_ids: Optional[torch.LongTensor] = None,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
+        forward_dict: Optional[Dict[str, classmethod]] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+        residual = hidden_states
+
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states, self_attn_weight, k, v = forward_dict["decoder_layer"](
+            self.self_attn,
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            batch_index=batch_ids,
+            use_cache=use_cache,
+            cos=cos,
+            sin=sin,
+            **kwargs,
+        )
+        past_key_value.assign(k, v, layer_idx)
+
+        hidden_states = residual + hidden_states
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+
+        outputs = (hidden_states,)
+
+        if output_attentions:
+            outputs += (self_attn_weight,)
+
+        if use_cache:
+            outputs += (past_key_value,)
+
+        return outputs
+
+
+class DecoderOnlyModel:
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[RebelDynamicCache] = None,
+        batch_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = True,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        forward_dict: Optional[Dict[str, classmethod]] = None,
+        rotary_pos_emb=None,
+    ) -> BaseModelOutputWithPast:
+        # embed positions
+        inputs_embeds = self.embed_tokens(input_ids)
+        hidden_states = inputs_embeds
+        attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
+
+        # get cos,sin vector
+        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
+        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+
+        for layer_idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+            layer_outputs = forward_dict["model"](
+                decoder_layer,
+                hidden_states,
+                layer_idx,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                batch_ids=batch_ids,
+                cos=cos,
+                sin=sin,
+                forward_dict=forward_dict,
+            )
+
+            hidden_states = layer_outputs[0]
+
+            updated_cache = layer_outputs[2 if output_attentions else 1]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
+        next_cache = updated_cache.to_legacy_cache()
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+def slice_and_unsqueeze_cos_sin(cos, sin, position_ids, unsqueeze_dim=1):
+    """Slice cos[position_ids], sin[position_ids] vector for the query."""
+    if position_ids.shape[0] > 1:
+        cos_all = []
+        sin_all = []
+        for i in range(position_ids.shape[0]):
+            cos_all.append(cos[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
+            sin_all.append(sin[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
+        cos = torch.cat(cos_all, dim=0)
+        sin = torch.cat(sin_all, dim=0)
+    else:
+        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+
+    return cos, sin
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin):
+    """Applies Rotary Position Embedding to the query and key tensors."""
+
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class RotaryEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+        scaling_factor=1.0,
+    ):
+        super().__init__()
+
+        self.scaling_factor = scaling_factor
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
+        device = self.inv_freq.device
+
+        positions_ids = torch.arange(self.max_position_embeddings, device=device, dtype=self.inv_freq.dtype)
+        freqs = torch.outer(positions_ids, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+
+        self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
+        self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)
+
+    def forward(self, x, seq_len):
+        return (
+            self._cos_cached[:seq_len].to(dtype=x.dtype),
+            self._sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+
+class LinearScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
+
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+        scaling_factor=1.0,
+        max_seq_len=2048,
+    ):
+        super().__init__(
+            dim,
+            max_position_embeddings=max_position_embeddings,
+            base=base,
+            scaling_factor=scaling_factor,
+        )
+        # difference to the original RoPE: a scaling factor is applied to the position ids
+        if max_seq_len > max_position_embeddings:
+            positions_ids = torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
+            positions_ids = positions_ids / self.scaling_factor
+            freqs = torch.outer(positions_ids, self.inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+
+            self._cos_cached = torch.cat([self._cos_cached, cos[max_position_embeddings:]], dim=0)
+            self._sin_cached = torch.cat([self._sin_cached, sin[max_position_embeddings:]], dim=0)
+
+
+class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+    def __init__(
+        self,
+        dim,
+        max_position_embeddings=2048,
+        base=10000,
+        device=None,
+        scaling_factor=1.0,
+        max_seq_len=2048,
+    ):
+        super().__init__(
+            dim,
+            max_position_embeddings=max_position_embeddings,
+            base=base,
+            scaling_factor=scaling_factor,
+        )
+        # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length
+        device = self.inv_freq.device
+        dtype = self.inv_freq.dtype
+        if max_seq_len > max_position_embeddings:
+            position_ids = torch.arange(max_position_embeddings, max_seq_len, dtype=dtype).view(-1, 1)
+            seq_len = position_ids + 1
+            base = self.base * (
+                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
+            ) ** (self.dim / (self.dim - 2))
+
+            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+
+            freqs = position_ids * inv_freq
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+
+            self._cos_cached = torch.cat([self._cos_cached, cos], dim=0)
+            self._sin_cached = torch.cat([self._sin_cached, sin], dim=0)
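For orientation, the snippet below is a minimal sketch (not part of the wheel) that exercises the RoPE helpers defined in the new file above. It assumes optimum-rbln 0.1.8 is installed so that the module path of the file listed above resolves, and it uses arbitrary toy dimensions; the wrapper classes themselves additionally require a RebelDynamicCache built from runtime inputs and are not exercised here.

import torch

from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
    RotaryEmbedding,
    apply_rotary_pos_emb,
    slice_and_unsqueeze_cos_sin,
)

batch, num_heads, seq_len, head_dim = 1, 8, 16, 64

# cos/sin tables are precomputed up to max_position_embeddings at construction time
rope = RotaryEmbedding(head_dim, max_position_embeddings=128)

q = torch.randn(batch, num_heads, seq_len, head_dim)
k = torch.randn(batch, num_heads, seq_len, head_dim)

# forward(x, seq_len) slices the cached tables and casts them to x's dtype
cos, sin = rope(q, seq_len)  # each: (seq_len, head_dim)

# index the tables by position id and add the head dim expected by apply_rotary_pos_emb
position_ids = torch.arange(seq_len).unsqueeze(0)  # (1, seq_len)
cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)  # each: (1, 1, seq_len, head_dim)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)  # both torch.Size([1, 8, 16, 64])

This mirrors how DecoderOnlyModel.forward consumes the rotary embedding in the diff: it requests cos/sin for the attended length, slices them per position id, and passes them down to DecoderOnlyAttention, which applies them to the query and key projections before updating the KV cache.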