optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. optimum/rbln/__init__.py +14 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/controlnet.py +3 -0
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  11. optimum/rbln/modeling_alias.py +14 -0
  12. optimum/rbln/modeling_base.py +110 -0
  13. optimum/rbln/transformers/__init__.py +6 -0
  14. optimum/rbln/transformers/cache_utils.py +111 -0
  15. optimum/rbln/transformers/generation/utils.py +0 -2
  16. optimum/rbln/transformers/models/__init__.py +2 -0
  17. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  18. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  19. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  20. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  21. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  22. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  23. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  24. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
  25. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
  26. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
  27. optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
  28. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  29. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  30. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  31. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
  32. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  33. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  34. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  35. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
  36. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
  37. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
  38. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
  39. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
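Most of this release is a consolidation: the model-specific Llama, GPT-2, and Midm graph-building code is folded into the new shared optimum/rbln/transformers/models/decoderonly module, the custom KV cache moves into optimum/rbln/transformers/cache_utils.py, and Gemma and XLM-RoBERTa support is added on top of that base. The llama_architecture.py hunk below shows the pattern: the hand-written wrapper is deleted and replaced by a thin subclass of DecoderOnlyWrapper. A minimal sketch of that pattern follows; the base-class body here only mirrors the forward logic visible in the removed 0.1.7 code, and the real DecoderOnlyWrapper in 0.1.8 is considerably larger.

# Sketch only: constructor/forward mirror the LlamaWrapper removed below;
# not the actual DecoderOnlyWrapper implementation shipped in 0.1.8.
import torch


class DecoderOnlyWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, cache_position, *past_key_values):
        # Regroup the flat (k0, v0, k1, v1, ...) cache inputs into per-layer
        # [key, value] pairs before calling the underlying Hugging Face model.
        num_layers = self.model.config.num_hidden_layers
        past_kv_list = [
            [past_key_values[2 * i], past_key_values[2 * i + 1]] for i in range(num_layers)
        ]
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=cache_position,
            past_key_values=past_kv_list,
        )


class LlamaWrapper(DecoderOnlyWrapper):
    # Model-specific wrappers become thin aliases over the shared base.
    pass

Keeping the per-model classes as subclasses preserves the public per-model names while the tracing logic lives in one place.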
optimum/rbln/transformers/models/llama/llama_architecture.py
@@ -21,616 +21,9 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.

- import math
- from typing import Any, Dict, List, Optional, Tuple, Union

- import torch
- import torch.nn.functional as F
- from torch import nn
- from transformers.cache_utils import Cache, DynamicCache
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
- from transformers.models.llama.modeling_llama import (
-     LlamaAttention,
-     LlamaDecoderLayer,
-     LlamaForCausalLM,
-     LlamaModel,
-     LlamaRotaryEmbedding,
- )
+ from ...models.decoderonly.decoderonly_architecture import DecoderOnlyWrapper


- class LlamaWrapper(torch.nn.Module):
-     def __init__(self, model):
-         super().__init__()
-         self.model = model
-
-     def forward(self, input_ids, attention_mask, cache_position, *past_key_values):
-         past_kv_list = []
-         for i in range(self.model.config.num_hidden_layers):
-             cur_kv_layer = []
-             for j in range(2):
-                 cur_kv_layer.append(past_key_values[2 * i + j])
-             past_kv_list.append(cur_kv_layer)
-
-         model_output = self.model(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             position_ids=cache_position,  # change forward function
-             past_key_values=past_kv_list,
-         )
-
-         return model_output
-
-
- class _LlamaRotaryEmbedding(nn.Module):
-     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-         super(LlamaRotaryEmbedding, self).__init__()
-
-         self.dim = dim
-         self.max_position_embeddings = max_position_embeddings
-         self.base = base
-         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-         self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-         # Build here to make `torch.jit.trace` work.
-         seq_len = max_position_embeddings
-         device = self.inv_freq.device
-         dtype = torch.get_default_dtype()
-         self.max_seq_len_cached = seq_len
-         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
-         freqs = torch.outer(t, self.inv_freq)
-         # Different from paper, but it uses a different permutation in order to obtain the same calculation
-         emb = torch.cat((freqs, freqs), dim=-1)
-         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-     def forward(self, x, seq_len=None):
-         # x: [bs, num_attention_heads, seq_len, head_size]
-         if seq_len > self.max_seq_len_cached:
-             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-         return (
-             self.cos_cached[:seq_len].to(dtype=x.dtype),
-             self.sin_cached[:seq_len].to(dtype=x.dtype),
-         )
-
-
- class _LlamaAttention(LlamaAttention):
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_value: Optional[Cache] = None,
-         output_attentions: bool = False,
-         use_cache: bool = False,
-         **kwargs,
-     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-         bsz, q_len, _ = hidden_states.size()
-
-         if self.config.pretraining_tp > 1:
-             key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
-             query_slices = self.q_proj.weight.split(
-                 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-             )
-             key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-             value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-             query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
-             query_states = torch.cat(query_states, dim=-1)
-
-             key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
-             key_states = torch.cat(key_states, dim=-1)
-
-             value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
-             value_states = torch.cat(value_states, dim=-1)
-
-         else:
-             query_states = self.q_proj(hidden_states)
-             key_states = self.k_proj(hidden_states)
-             value_states = self.v_proj(hidden_states)
-
-         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-         kv_seq_len = key_states.shape[-2]
-         if past_key_value is not None:
-             if self.layer_idx is None:
-                 raise ValueError(
-                     f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                     "with a layer index."
-                 )
-             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-         # change to remove repeat
-         key_states = key_states.unsqueeze(2)
-         value_states = value_states.unsqueeze(2)
-         query_states = query_states.view(
-             bsz, self.num_key_value_heads, self.num_heads // self.num_key_value_heads, q_len, self.head_dim
-         )
-
-         if past_key_value is not None:
-             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-         # change to remove repeat
-         # key_states = repeat_kv(key_states, self.num_key_value_groups)
-         # value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-         # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-         attn_weights = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)
-
-         # change to remove repeat
-         # if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-         #     raise ValueError(
-         #         f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-         #         f" {attn_weights.size()}"
-         #     )
-
-         if attention_mask is not None:
-             if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                 raise ValueError(
-                     f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                 )
-             else:
-                 # change to remove repeat
-                 attention_mask = attention_mask.unsqueeze(2)
-
-             attn_weights = attn_weights + attention_mask
-
-         # upcast attention to fp32
-         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-         attn_output = torch.matmul(attn_weights, value_states)
-
-         # change to remove repeat
-         attn_output = attn_output.view(bsz, self.num_heads, q_len, self.head_dim)
-
-         if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-             raise ValueError(
-                 f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                 f" {attn_output.size()}"
-             )
-
-         attn_output = attn_output.transpose(1, 2).contiguous()
-
-         attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-         if self.config.pretraining_tp > 1:
-             attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-             o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-             attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-         else:
-             attn_output = self.o_proj(attn_output)
-
-         if not output_attentions:
-             attn_weights = None
-
-         return attn_output, attn_weights, past_key_value
-
-
- class _LlamaDecoderLayer(LlamaDecoderLayer):
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-         output_attentions: Optional[bool] = False,
-         use_cache: Optional[bool] = False,
-         **kwargs,
-     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-         residual = hidden_states
-
-         hidden_states = self.input_layernorm(hidden_states)
-
-         # Self Attention
-         hidden_states, self_attn_weights, present_key_value = _LlamaAttention.forward(
-             self.self_attn,
-             hidden_states=hidden_states,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_value=past_key_value,
-             output_attentions=output_attentions,
-             use_cache=use_cache,
-             **kwargs,
-         )
-         hidden_states = residual + hidden_states
-
-         # Fully Connected
-         residual = hidden_states
-         hidden_states = self.post_attention_layernorm(hidden_states)
-         hidden_states = self.mlp(hidden_states)
-         hidden_states = residual + hidden_states
-
-         outputs = (hidden_states,)
-
-         if output_attentions:
-             outputs += (self_attn_weights,)
-
-         if use_cache:
-             outputs += (present_key_value,)
-
-         return outputs
-
-
- class _LlamaModel(LlamaModel):
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[int] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, BaseModelOutputWithPast]:
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         #### cannot change forward args? temporal workaround ####
-         current_step = position_ids
-
-         # retrieve input_ids and inputs_embeds
-         if input_ids is not None and inputs_embeds is not None:
-             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-         elif input_ids is not None:
-             batch_size, seq_length = input_ids.shape[:2]
-         elif inputs_embeds is not None:
-             batch_size, seq_length = inputs_embeds.shape[:2]
-         else:
-             raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-         past_key_values_length = 0
-         if use_cache:
-             use_legacy_cache = not isinstance(past_key_values, Cache)
-             if use_legacy_cache:
-                 past_key_values = RebelDynamicCache.from_legacy_cache(
-                     current_step=current_step,
-                     max_length=self.config.max_position_embeddings,
-                     past_key_values=past_key_values,
-                 )
-
-             # not used, get_usable_length will be changed
-             # past_key_values_length = past_key_values.get_usable_length(seq_length)
-
-         #### position embedding indice ####
-         position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.int32).unsqueeze(0) + current_step
-
-         if position_ids is None:
-             device = input_ids.device if input_ids is not None else inputs_embeds.device
-             position_ids = torch.arange(
-                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-             )
-             position_ids = position_ids.unsqueeze(0)
-
-         if inputs_embeds is None:
-             inputs_embeds = self.embed_tokens(input_ids)
-
-         # ##### original condition for generating causal attention mask
-         # if getattr(self.config, "_flash_attn_2_enabled", False):
-         #     # 2d mask is passed through the layers
-         #     attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-
-         # else:
-         #     # 4d mask is passed through the layers
-         #     attention_mask = _prepare_4d_causal_attention_mask(
-         #         attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-         #     )
-         # ########################################################
-
-         # yhboo changed for valid graph generation
-         if getattr(self.config, "_flash_attn_2_enabled", False):
-             # 2d mask is passed through the layers
-             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-             raise NotImplementedError
-
-         elif attention_mask is not None and attention_mask.ndim == 4:
-             # assuming attention mask is generated as input
-             # assumed dim = [batch_size, 1, inp_seq, max_seq]
-             # only make [1, 0] mask to [0, -inf]
-             attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
-
-         else:
-             # 4d mask is passed through the layers
-             attention_mask = _prepare_4d_causal_attention_mask(
-                 attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-             )
-
-         # embed positions
-         hidden_states = inputs_embeds
-
-         # decoder layers
-         all_hidden_states = () if output_hidden_states else None
-         all_self_attns = () if output_attentions else None
-         next_decoder_cache = () if use_cache else None
-
-         for idx, decoder_layer in enumerate(self.layers):
-             if output_hidden_states:
-                 all_hidden_states += (hidden_states,)
-
-             layer_outputs = _LlamaDecoderLayer.forward(
-                 decoder_layer,
-                 hidden_states,
-                 attention_mask=attention_mask,
-                 position_ids=position_ids,
-                 past_key_value=past_key_values,
-                 output_attentions=output_attentions,
-                 use_cache=use_cache,
-             )
-
-             hidden_states = layer_outputs[0]
-
-             if use_cache:
-                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-             if output_attentions:
-                 all_self_attns += (layer_outputs[1],)
-
-         hidden_states = self.norm(hidden_states)
-
-         # add hidden states from the last decoder layer
-         if output_hidden_states:
-             all_hidden_states += (hidden_states,)
-
-         next_cache = None
-         if use_cache:
-             next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
-
-         return BaseModelOutputWithPast(
-             last_hidden_state=hidden_states,
-             past_key_values=next_cache,
-             hidden_states=all_hidden_states,
-             attentions=all_self_attns,
-         )
-
-
- class _LlamaForCausalLM(LlamaForCausalLM):
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, CausalLMOutputWithPast]:
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-         outputs = self.model(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-         )
-
-         hidden_states = outputs[0]
-         logits = self.lm_head(hidden_states)
-         logits = logits.float()
-
-         if not return_dict:
-             output = (logits,) + outputs[1:]
-             return output
-
-         return CausalLMOutputWithPast(
-             logits=logits,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-         )
-
-
- class RebelDynamicCache(DynamicCache):
-     """
-     A cache that grows dynamically as more tokens are generated. This is the default for generative models.
-
-     It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
-     `[batch_size, num_heads, seq_len, head_dim]`.
-     """
-
-     def __init__(self, current_step, max_length) -> None:
-         super().__init__()
-         self.current_step = current_step
-         self.max_length = max_length
-
-     def copy(
-         self,
-         key_states: torch.Tensor,
-         value_states: torch.Tensor,
-         layer_idx: int,
-         cache_kwargs: Optional[Dict[str, Any]] = None,
-     ) -> Tuple[torch.Tensor, torch.Tensor]:
-         """
-         Copy the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-         just for from_legacy_cache function
-
-         Parameters:
-             key_states (`torch.Tensor`):
-                 The new key states to cache.
-             value_states (`torch.Tensor`):
-                 The new value states to cache.
-             layer_idx (`int`):
-                 The index of the layer to cache the states for.
-             cache_kwargs (`Dict[str, Any]`, `optional`):
-                 Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-         Return:
-             A tuple containing the updated key and value states.
-         """
-         # Update the number of seen tokens : deprecated
-         # if layer_idx == 0:
-         #     self.seen_tokens += key_states.shape[-2]
-
-         # Update the cache
-         if len(self.key_cache) <= layer_idx:
-             self.key_cache.append(key_states)
-             self.value_cache.append(value_states)
-         else:
-             self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-             self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
-         return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
-     def update(
-         self,
-         key_states: torch.Tensor,
-         value_states: torch.Tensor,
-         layer_idx: int,
-         cache_kwargs: Optional[Dict[str, Any]] = None,
-     ) -> Tuple[torch.Tensor, torch.Tensor]:
-         """
-         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
-         based on self.current_step,
-
-         Parameters:
-             key_states (`torch.Tensor`):
-                 The new key states to cache.
-             value_states (`torch.Tensor`):
-                 The new value states to cache.
-             layer_idx (`int`):
-                 The index of the layer to cache the states for.
-             cache_kwargs (`Dict[str, Any]`, `optional`):
-                 Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-         Return:
-             A tuple containing the updated key and value states.
-         """
-         # # Update the number of seen tokens : deprecated
-         # if layer_idx == 0:
-         #     self.seen_tokens += key_states.shape[-2]
-
-         # Update the cache
-         if len(self.key_cache) <= layer_idx:
-             self.key_cache.append(key_states)
-             self.value_cache.append(value_states)
-             return self.key_cache[layer_idx], self.value_cache[layer_idx]
-         else:
-             # change to remove repeat
-             # self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
-             #     key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
-             # )
-             # self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
-             #     value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
-             # )
-             updated_key = (
-                 self.key_cache[layer_idx]
-                 .unsqueeze(2)
-                 .slice_scatter(
-                     key_states, dim=-2, start=self.current_step, end=self.current_step + key_states.shape[-2]
-                 )
-             )
-             updated_value = (
-                 self.value_cache[layer_idx]
-                 .unsqueeze(2)
-                 .slice_scatter(
-                     value_states, dim=-2, start=self.current_step, end=self.current_step + value_states.shape[-2]
-                 )
-             )
-             self.key_cache[layer_idx] = updated_key.squeeze(2)
-             self.value_cache[layer_idx] = updated_value.squeeze(2)
-             return updated_key, updated_value
-
-     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-         if len(self.key_cache) <= layer_idx:
-             return 0
-         return self.key_cache[layer_idx].shape[-2]
-
-     def get_max_length(self) -> Optional[int]:
-         return self.max_length
-
-     @classmethod
-     def from_legacy_cache(
-         cls, current_step, max_length, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-     ) -> "DynamicCache":
-         """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
-         cache = cls(current_step, max_length)
-         if past_key_values is not None:
-             for layer_idx in range(len(past_key_values)):
-                 key_states, value_states = past_key_values[layer_idx]
-                 cache.copy(key_states, value_states, layer_idx)
-         return cache
-
-
- def rotate_half(x):
-     """Rotates half the hidden dims of the input."""
-     x1 = x[..., : x.shape[-1] // 2]
-     x2 = x[..., x.shape[-1] // 2 :]
-     return torch.cat((-x2, x1), dim=-1)
-
-
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
-     """Applies Rotary Position Embedding to the query and key tensors.
-
-     Args:
-         q (`torch.Tensor`): The query tensor.
-         k (`torch.Tensor`): The key tensor.
-         cos (`torch.Tensor`): The cosine part of the rotary embedding.
-         sin (`torch.Tensor`): The sine part of the rotary embedding.
-         position_ids (`torch.Tensor`):
-             The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-             used to pass offsetted position ids when working with a KV-cache.
-         unsqueeze_dim (`int`, *optional*, defaults to 1):
-             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-     Returns:
-         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-     """
-     cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-     sin = sin[position_ids].unsqueeze(unsqueeze_dim)
-     q_embed = (q * cos) + (rotate_half(q) * sin)
-     k_embed = (k * cos) + (rotate_half(k) * sin)
-     return q_embed, k_embed
-
-
- origin_methods = {}
- origin_methods["LlamaRotaryEmbedding_INIT"] = LlamaRotaryEmbedding.__init__
- origin_methods["LlamaRotaryEmbedding_forward"] = LlamaRotaryEmbedding.forward
- origin_methods["LlamaModel_forward"] = LlamaModel.forward
- origin_methods["LlamaForCausalLM_forward"] = LlamaForCausalLM.forward
-
-
- def wrap_llama():
-     LlamaRotaryEmbedding.__init__ = _LlamaRotaryEmbedding.__init__
-     LlamaRotaryEmbedding.forward = _LlamaRotaryEmbedding.forward
-     LlamaModel.forward = _LlamaModel.forward
-     LlamaForCausalLM.forward = _LlamaForCausalLM.forward
-
-
- def unwrap_llama():
-     global origin_methods
-     LlamaRotaryEmbedding.__init__ = origin_methods["LlamaRotaryEmbedding_INIT"]
-     LlamaRotaryEmbedding.forward = origin_methods["LlamaRotaryEmbedding_forward"]
-     LlamaModel.forward = origin_methods["LlamaModel_forward"]
-     LlamaForCausalLM.forward = origin_methods["LlamaForCausalLM_forward"]
+ class LlamaWrapper(DecoderOnlyWrapper):
+     pass
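The RebelDynamicCache removed above is not dropped from the package: the new shared optimum/rbln/transformers/cache_utils.py (+111 lines in the file list) presumably takes over that role. The technique the removed update() implements is writing new key/value states into a preallocated, fixed-length cache at current_step with slice_scatter, instead of growing the cache with torch.cat, which keeps tensor shapes static for compilation. A toy illustration with simplified shapes (the removed code additionally unsqueezes and squeezes a grouped-KV dimension); names here are illustrative, not the package's API:

import torch


def scatter_update(cache: torch.Tensor, new_states: torch.Tensor, current_step: int) -> torch.Tensor:
    # cache:      [batch, num_heads, max_seq_len, head_dim], preallocated once
    # new_states: [batch, num_heads, new_tokens, head_dim]
    n_new = new_states.shape[-2]
    # Overwrite the slot(s) starting at current_step; the cache's shape never changes.
    return cache.slice_scatter(new_states, dim=-2, start=current_step, end=current_step + n_new)


key_cache = torch.zeros(1, 32, 128, 64)   # 128 token slots reserved up front
new_key = torch.randn(1, 32, 1, 64)       # key states for one newly decoded token
key_cache = scatter_update(key_cache, new_key, current_step=5)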