optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. optimum/rbln/__init__.py +21 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  5. optimum/rbln/diffusers/models/controlnet.py +3 -0
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  16. optimum/rbln/modeling_alias.py +14 -0
  17. optimum/rbln/modeling_base.py +282 -100
  18. optimum/rbln/modeling_seq2seq.py +58 -132
  19. optimum/rbln/transformers/__init__.py +8 -0
  20. optimum/rbln/transformers/cache_utils.py +111 -0
  21. optimum/rbln/transformers/generation/utils.py +0 -2
  22. optimum/rbln/transformers/models/__init__.py +3 -0
  23. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  24. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  25. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  26. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  27. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  28. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  29. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  30. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  31. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  32. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  33. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
  34. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
  37. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  38. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  42. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  43. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  44. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  45. optimum/rbln/utils/__init__.py +1 -1
  46. optimum/rbln/utils/import_utils.py +46 -0
  47. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
  48. optimum_rbln-0.1.8.dist-info/RECORD +73 -0
  49. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
  51. optimum_rbln-0.1.4.dist-info/RECORD +0 -63
  52. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
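The headline change in this list is structural: the per-model LLaMA graph code (and, judging by the stats above, much of the GPT-2 and Midm equivalents) is consolidated into the new shared optimum/rbln/transformers/models/decoderonly/ package, with KV-cache handling moved into transformers/cache_utils.py, while new model families (DPT, Gemma, XLM-RoBERTa) are added on top of that shared base. The hunk below shows llama_architecture.py shrinking to a thin subclass of the shared wrapper. As a usage-level illustration only, here is a minimal sketch of how a decoder-only checkpoint is typically compiled with this package; the class name RBLNLlamaForCausalLM matches the modules listed above, but export=True and the rbln_-prefixed keyword arguments are assumptions about the optimum-style API, not something shown in this diff:

# Hypothetical usage sketch; keyword arguments are assumptions, not taken from this diff.
from optimum.rbln import RBLNLlamaForCausalLM

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    export=True,            # compile through the shared decoder-only path (assumed flag)
    rbln_max_seq_len=4096,  # assumed compile-time option
    rbln_batch_size=1,      # assumed compile-time option
)
model.save_pretrained("llama-2-7b-rbln")  # save the compiled artifact for later reloading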
@@ -21,619 +21,9 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-import math
-from typing import Any, Dict, List, Optional, Tuple, Union
 
-import torch
-import torch.nn.functional as F
-from torch import nn
-from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaForCausalLM,
-    LlamaModel,
-    LlamaRotaryEmbedding,
-)
+from ...models.decoderonly.decoderonly_architecture import DecoderOnlyWrapper
 
 
-class LlamaWrapper(torch.nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.model = model
-
-    def forward(self, input_ids, attention_mask, cache_position, *past_key_values):
-        past_kv_list = []
-        for i in range(self.model.config.num_hidden_layers):
-            cur_kv_layer = []
-            for j in range(2):
-                cur_kv_layer.append(past_key_values[2 * i + j])
-            past_kv_list.append(cur_kv_layer)
-
-        model_output = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=cache_position,  # change forward function
-            past_key_values=past_kv_list,
-        )
-
-        return model_output
-
-
-class _LlamaRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-        super(LlamaRotaryEmbedding, self).__init__()
-
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # Build here to make `torch.jit.trace` work.
-        seq_len = max_position_embeddings
-        device = self.inv_freq.device
-        dtype = torch.get_default_dtype()
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
-        )
-
-
-class _LlamaAttention(LlamaAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-
-        bsz, q_len, _ = hidden_states.size()
-
-        if self.config.pretraining_tp > 1:
-            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
-            query_slices = self.q_proj.weight.split(
-                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-            )
-            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
-            query_states = torch.cat(query_states, dim=-1)
-
-            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
-            key_states = torch.cat(key_states, dim=-1)
-
-            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
-            value_states = torch.cat(value_states, dim=-1)
-
-        else:
-            query_states = self.q_proj(hidden_states)
-            key_states = self.k_proj(hidden_states)
-            value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-
-        # change to remove repeat
-        key_states = key_states.unsqueeze(2)
-        value_states = value_states.unsqueeze(2)
-        query_states = query_states.view(
-            bsz, self.num_key_value_heads, self.num_heads // self.num_key_value_heads, q_len, self.head_dim
-        )
-
-        if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-        # change to remove repeat
-        # key_states = repeat_kv(key_states, self.num_key_value_groups)
-        # value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        # attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-        attn_weights = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)
-
-        # change to remove repeat
-        # if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
-        #     raise ValueError(
-        #         f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
-        #         f" {attn_weights.size()}"
-        #     )
-
-        if attention_mask is not None:
-            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-                raise ValueError(
-                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-                )
-            else:
-                # change to remove repeat
-                attention_mask = attention_mask.unsqueeze(2)
-
-            attn_weights = attn_weights + attention_mask
-
-        # upcast attention to fp32
-        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-        attn_output = torch.matmul(attn_weights, value_states)
-
-        # change to remove repeat
-        attn_output = attn_output.view(bsz, self.num_heads, q_len, self.head_dim)
-
-        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
-            raise ValueError(
-                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
-                f" {attn_output.size()}"
-            )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-        if self.config.pretraining_tp > 1:
-            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-        else:
-            attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-
-class _LlamaDecoderLayer(LlamaDecoderLayer):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-
-        residual = hidden_states
-
-        hidden_states = self.input_layernorm(hidden_states)
-
-        # Self Attention
-        hidden_states, self_attn_weights, present_key_value = _LlamaAttention.forward(
-            self.self_attn,
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            use_cache=use_cache,
-            **kwargs,
-        )
-        hidden_states = residual + hidden_states
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (present_key_value,)
-
-        return outputs
-
-
-class _LlamaModel(LlamaModel):
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[int] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        #### cannot change forward args? temporal workaround ####
-        current_step = position_ids
-
-        # retrieve input_ids and inputs_embeds
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            batch_size, seq_length = input_ids.shape[:2]
-        elif inputs_embeds is not None:
-            batch_size, seq_length = inputs_embeds.shape[:2]
-        else:
-            raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-        past_key_values_length = 0
-        if use_cache:
-            use_legacy_cache = not isinstance(past_key_values, Cache)
-            if use_legacy_cache:
-                past_key_values = RebelDynamicCache.from_legacy_cache(
-                    current_step=current_step,
-                    max_length=self.config.max_position_embeddings,
-                    past_key_values=past_key_values,
-                )
-
-            # not used, get_usable_length will be changed
-            # past_key_values_length = past_key_values.get_usable_length(seq_length)
-
-        #### position embedding indice ####
-        position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.int32).unsqueeze(0) + current_step
-
-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-            )
-            position_ids = position_ids.unsqueeze(0)
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        # ##### original condition for generating causal attention mask
-        # if getattr(self.config, "_flash_attn_2_enabled", False):
-        #     # 2d mask is passed through the layers
-        #     attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-
-        # else:
-        #     # 4d mask is passed through the layers
-        #     attention_mask = _prepare_4d_causal_attention_mask(
-        #         attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-        #     )
-        # ########################################################
-
-        # yhboo changed for valid graph generation
-        if getattr(self.config, "_flash_attn_2_enabled", False):
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-            raise NotImplementedError
-
-        elif attention_mask is not None and attention_mask.ndim == 4:
-            # assuming attention mask is generated as input
-            # assumed dim = [batch_size, 1, inp_seq, max_seq]
-            # only make [1, 0] mask to [0, -inf]
-            attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
-
-        else:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-            )
-
-        # embed positions
-        hidden_states = inputs_embeds
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        next_decoder_cache = () if use_cache else None
-
-        for idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = _LlamaDecoderLayer.forward(
-                decoder_layer,
-                hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-            )
-
-            hidden_states = layer_outputs[0]
-
-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
-
-
-class _LlamaForCausalLM(LlamaForCausalLM):
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
-        logits = logits.float()
-
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return output
-
-        return CausalLMOutputWithPast(
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-
-class RebelDynamicCache(DynamicCache):
-    """
-    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
-
-    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
-    `[batch_size, num_heads, seq_len, head_dim]`.
-    """
-
-    def __init__(self, current_step, max_length) -> None:
-        super().__init__()
-        self.current_step = current_step
-        self.max_length = max_length
-
-    def copy(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Copy the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-        just for from_legacy_cache function
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-        # Update the number of seen tokens : deprecated
-        # if layer_idx == 0:
-        #     self.seen_tokens += key_states.shape[-2]
-
-        # Update the cache
-        if len(self.key_cache) <= layer_idx:
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
-        else:
-            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
-    def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
-        based on self.current_step,
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-        # # Update the number of seen tokens : deprecated
-        # if layer_idx == 0:
-        #     self.seen_tokens += key_states.shape[-2]
-
-        # Update the cache
-        if len(self.key_cache) <= layer_idx:
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
-            return self.key_cache[layer_idx], self.value_cache[layer_idx]
-        else:
-            # change to remove repeat
-            # self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
-            #     key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
-            # )
-            # self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
-            #     value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
-            # )
-            updated_key = (
-                self.key_cache[layer_idx]
-                .unsqueeze(2)
-                .slice_scatter(
-                    key_states, dim=-2, start=self.current_step, end=self.current_step + key_states.shape[-2]
-                )
-            )
-            updated_value = (
-                self.value_cache[layer_idx]
-                .unsqueeze(2)
-                .slice_scatter(
-                    value_states, dim=-2, start=self.current_step, end=self.current_step + value_states.shape[-2]
-                )
-            )
-            self.key_cache[layer_idx] = updated_key.squeeze(2)
-            self.value_cache[layer_idx] = updated_value.squeeze(2)
-            return updated_key, updated_value
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        return self.key_cache[layer_idx].shape[-2]
-
-    def get_max_length(self) -> Optional[int]:
-        return self.max_length
-
-    @classmethod
-    def from_legacy_cache(
-        cls, current_step, max_length, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    ) -> "DynamicCache":
-        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
-        cache = cls(current_step, max_length)
-        if past_key_values is not None:
-            for layer_idx in range(len(past_key_values)):
-                key_states, value_states = past_key_values[layer_idx]
-                cache.copy(key_states, value_states, layer_idx)
-        return cache
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def wrap_llama():
-    origin_mehtods = {}
-    origin_mehtods["LlamaRotaryEmbedding_INIT"] = LlamaRotaryEmbedding.__init__
-    origin_mehtods["LlamaRotaryEmbedding_forward"] = LlamaRotaryEmbedding.forward
-    origin_mehtods["LlamaModel_forward"] = LlamaModel.forward
-    origin_mehtods["LlamaForCausalLM_forward"] = LlamaForCausalLM.forward
-
-    LlamaRotaryEmbedding.__init__ = _LlamaRotaryEmbedding.__init__
-    LlamaRotaryEmbedding.forward = _LlamaRotaryEmbedding.forward
-    LlamaModel.forward = _LlamaModel.forward
-    LlamaForCausalLM.forward = _LlamaForCausalLM.forward
-
-    return origin_mehtods
-
-
-def unwrap_llama(origin_mehtods):
-    LlamaRotaryEmbedding.__init__ = origin_mehtods["LlamaRotaryEmbedding_INIT"]
-    LlamaRotaryEmbedding.forward = origin_mehtods["LlamaRotaryEmbedding_forward"]
-    LlamaModel.forward = origin_mehtods["LlamaModel_forward"]
-    LlamaForCausalLM.forward = origin_mehtods["LlamaForCausalLM_forward"]
+class LlamaWrapper(DecoderOnlyWrapper):
+    pass
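The removed RebelDynamicCache above is the piece worth understanding before reading the new transformers/cache_utils.py: instead of growing the cache with torch.cat, its update() writes the new key/value states into a preallocated buffer at self.current_step via Tensor.slice_scatter, which keeps every tensor shape static and the graph traceable for ahead-of-time compilation. A minimal, self-contained sketch of that update pattern follows; the function name and shapes are illustrative, not the API of the new cache_utils module:

# Sketch of the slice_scatter-style static cache update used by the removed code.
import torch

def scatter_update(cache: torch.Tensor, new_states: torch.Tensor, step: int) -> torch.Tensor:
    # cache:      [batch, num_kv_heads, max_seq_len, head_dim], preallocated once
    # new_states: [batch, num_kv_heads, new_len, head_dim], states produced this step
    # slice_scatter writes new_states into cache[..., step : step + new_len, :]
    # without changing the cache's shape, unlike a torch.cat-based dynamic cache.
    return cache.slice_scatter(new_states, dim=-2, start=step, end=step + new_states.shape[-2])

# Toy decode step: write one new token's key (or value) states at position 5.
kv_cache = torch.zeros(1, 8, 128, 64)
new_kv = torch.randn(1, 8, 1, 64)
kv_cache = scatter_update(kv_cache, new_kv, step=5)

The removed code additionally unsqueezes an extra grouped-query axis around the scatter (the "change to remove repeat" comments) so that repeat_kv can be dropped from attention; that detail is omitted from the sketch.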