optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (52)
  1. optimum/rbln/__init__.py +21 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  5. optimum/rbln/diffusers/models/controlnet.py +3 -0
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  16. optimum/rbln/modeling_alias.py +14 -0
  17. optimum/rbln/modeling_base.py +282 -100
  18. optimum/rbln/modeling_seq2seq.py +58 -132
  19. optimum/rbln/transformers/__init__.py +8 -0
  20. optimum/rbln/transformers/cache_utils.py +111 -0
  21. optimum/rbln/transformers/generation/utils.py +0 -2
  22. optimum/rbln/transformers/models/__init__.py +3 -0
  23. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  24. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  25. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  26. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  27. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  28. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  29. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  30. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  31. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  32. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  33. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
  34. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
  37. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  38. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  42. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  43. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  44. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  45. optimum/rbln/utils/__init__.py +1 -1
  46. optimum/rbln/utils/import_utils.py +46 -0
  47. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
  48. optimum_rbln-0.1.8.dist-info/RECORD +73 -0
  49. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
  51. optimum_rbln-0.1.4.dist-info/RECORD +0 -63
  52. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/llama/llama_architecture_cb.py (file removed in 0.1.8)
@@ -1,759 +0,0 @@
-# Copyright 2024 Rebellions Inc.
-
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at:
-
-# http://www.apache.org/licenses/LICENSE-2.0
-
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Portions of this software are licensed under the Apache License,
-# Version 2.0. See the NOTICE file distributed with this work for
-# additional information regarding copyright ownership.
-
-# All other portions of this software, including proprietary code,
-# are the intellectual property of Rebellions Inc. and may not be
-# copied, modified, or distributed without prior written permission
-# from Rebellions Inc.
-
-import math
-from typing import Any, Dict, List, Optional, Tuple, Union
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-from transformers.cache_utils import Cache, DynamicCache
-from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from transformers.models.llama.modeling_llama import (
-    LlamaAttention,
-    LlamaDecoderLayer,
-    LlamaForCausalLM,
-    LlamaModel,
-    LlamaRotaryEmbedding,
-)
-
-
-"""
-define new class to put batch_position as a forward args
-_LlamaForCausalLM receives batch_ids (default=None)
-"""
-
-
-class LlamaDynamicBatchWrapper(torch.nn.Module):
-    def __init__(self, model):
-        super().__init__()
-        self.model = model
-
-    def forward(self, input_ids, attention_mask, cache_position, batch_position, *past_key_values):
-        if input_ids.shape[1] == 1:
-            rbln_batch_position = None
-        else:
-            rbln_batch_position = batch_position
-
-        past_kv_list = []
-        for i in range(self.model.config.num_hidden_layers):
-            cur_kv_layer = []
-            for j in range(2):
-                cur_kv_layer.append(past_key_values[2 * i + j])
-            past_kv_list.append(cur_kv_layer)
-        model_output = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=cache_position,
-            past_key_values=past_kv_list,
-            batch_ids=rbln_batch_position,
-        )
-
-        return model_output, batch_position
-
-
-class _LlamaRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-        super(LlamaRotaryEmbedding, self).__init__()
-
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # Build here to make `torch.jit.trace` work.
-        seq_len = max_position_embeddings
-        device = self.inv_freq.device
-        dtype = torch.get_default_dtype()
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-    def forward(self, x, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
-        )
-
-
-class _LlamaAttention(LlamaAttention):
-    # single batch llama attention
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        batch_index: Optional[int] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-        bsz, q_len, _ = hidden_states.size()
-        if self.config.pretraining_tp > 1:
-            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
-            query_slices = self.q_proj.weight.split(
-                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-            )
-            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
-            query_states = torch.cat(query_states, dim=-1)
-
-            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
-            key_states = torch.cat(key_states, dim=-1)
-
-            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
-            value_states = torch.cat(value_states, dim=-1)
-
-        else:
-            query_states = self.q_proj(hidden_states)
-            key_states = self.k_proj(hidden_states)
-            value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            if self.layer_idx is None:
-                raise ValueError(
-                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                    "with a layer index."
-                )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
-        if past_key_value is not None:
-            cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
-        if (batch_index is None or batch_index == -1) and bsz > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-            for b in range(bsz):
-                batch_query_states = query_states[b].unsqueeze(0)
-                batch_attention_mask = attention_mask[b].unsqueeze(0)
-                batch_key_states = key_states[b].unsqueeze(0)
-                batch_value_states = value_states[b].unsqueeze(0)
-
-                # reshape for removing repeat_kv
-                batch_key_states = batch_key_states.unsqueeze(2)
-                batch_value_states = batch_value_states.unsqueeze(2)
-                batch_attention_mask = batch_attention_mask.unsqueeze(2)
-                batch_query_states = batch_query_states.view(
-                    1, self.num_key_value_heads, self.num_heads // self.num_key_value_heads, q_len, self.head_dim
-                )
-
-                batch_key_states, batch_value_states = past_key_value.update(
-                    batch_key_states, batch_value_states, self.layer_idx, b, cache_kwargs
-                )
-
-                # batch_key_states = repeat_kv(
-                # batch_key_states,
-                # self.num_key_value_groups
-                # )
-                # batch_value_states = repeat_kv(
-                # batch_value_states,
-                # self.num_key_value_groups
-                # )
-
-                # attn_weights = torch.matmul(batch_query_states, batch_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-                # reshape for removing repeat_kv
-                attn_weights = torch.matmul(batch_query_states, batch_key_states.transpose(3, 4)) / math.sqrt(
-                    self.head_dim
-                )
-
-                attn_weights = attn_weights + batch_attention_mask
-
-                # upcast attention to fp32
-                attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-                attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-                attn_output = torch.matmul(attn_weights, batch_value_states)
-
-                # reshape for removing repeat_kv
-                attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
-
-                attn_output = attn_output.transpose(1, 2).contiguous()
-                attn_output = attn_output.reshape(1, q_len, self.hidden_size)
-                all_key_states.append(batch_key_states)
-                all_value_states.append(batch_value_states)
-                all_attn_output.append(attn_output)
-            key_states = torch.cat(all_key_states, dim=0)
-            value_states = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
-
-        else:
-            assert bsz == 1, "dynamic batch update only support input batch 1"
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            # reshape for removing repeat_kv
-            key_states = key_states.unsqueeze(2)
-            value_states = value_states.unsqueeze(2)
-            attention_mask = attention_mask.unsqueeze(2)
-            query_states = query_states.view(
-                1, self.num_key_value_heads, self.num_heads // self.num_key_value_heads, q_len, self.head_dim
-            )
-
-            key_states, value_states = past_key_value.update(
-                key_states, value_states, self.layer_idx, batch_index, cache_kwargs, read_first_step=True
-            )
-
-            # key_states = repeat_kv(key_states, self.num_key_value_groups)
-            # value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-            attn_weights = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)
-
-            attn_weights = attn_weights + attention_mask
-
-            # upcast attention to fp32
-            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-            attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-            attn_output = torch.matmul(attn_weights, value_states)
-
-            # reshape for removing repeat_kv
-            attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
-
-            attn_output = attn_output.transpose(1, 2).contiguous()
-            attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-        if self.config.pretraining_tp > 1:
-            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-        else:
-            attn_output = self.o_proj(attn_output)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, key_states, value_states
-
-
-class _LlamaDecoderLayer(LlamaDecoderLayer):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        layer_idx: int,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: Optional[bool] = False,
-        use_cache: Optional[bool] = False,
-        batch_ids: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-        residual = hidden_states
-
-        hidden_states = self.input_layernorm(hidden_states)
-        bsz, _, _ = hidden_states.size()
-
-        hidden_states, self_attn_weights, k, v = _LlamaAttention.forward(
-            self.self_attn,
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-            output_attentions=output_attentions,
-            batch_index=batch_ids,
-            use_cache=use_cache,
-            **kwargs,
-        )
-        past_key_value.assign(k, v, layer_idx)
-
-        present_key_value = past_key_value
-
-        hidden_states = residual + hidden_states
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-
-        outputs = (hidden_states,)
-
-        if output_attentions:
-            outputs += (self_attn_weights,)
-
-        if use_cache:
-            outputs += (present_key_value,)
-
-        return outputs
-
-
-class _LlamaModel(LlamaModel):
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, BaseModelOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        #### cannot change forward args? temporal workaround ####
-        # current_step = position_ids
-
-        # retrieve input_ids and inputs_embeds
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            batch_size, seq_length = input_ids.shape[:2]
-        elif inputs_embeds is not None:
-            batch_size, seq_length = inputs_embeds.shape[:2]
-        else:
-            raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-        past_key_values_length = 0
-        if use_cache:
-            use_legacy_cache = not isinstance(past_key_values, Cache)
-            if use_legacy_cache:
-                past_key_values = RebelDynamicCache.from_legacy_cache(
-                    position_ids=position_ids,
-                    max_length=self.config.max_position_embeddings,
-                    past_key_values=past_key_values,
-                )
-
-            # not used, get_usable_length will be changed
-            # past_key_values_length = past_key_values.get_usable_length(seq_length)
-
-        #### position embedding indice ####
-        # position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.int32).unsqueeze(0) + current_step
-
-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-            )
-            position_ids = position_ids.unsqueeze(0)
-
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-
-        # ##### original condition for generating causal attention mask
-        # if getattr(self.config, "_flash_attn_2_enabled", False):
-        # # 2d mask is passed through the layers
-        # attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-
-        # else:
-        # # 4d mask is passed through the layers
-        # attention_mask = _prepare_4d_causal_attention_mask(
-        # attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-        # )
-        # ########################################################
-
-        # yhboo changed for valid graph generation
-        if getattr(self.config, "_flash_attn_2_enabled", False):
-            # 2d mask is passed through the layers
-            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-            raise NotImplementedError
-
-        elif attention_mask is not None and attention_mask.ndim == 4:
-            # assuming attention mask is generated as input
-            # assumed dim = [batch_size, 1, inp_seq, max_seq]
-            # only make [1, 0] mask to [0, -inf]
-            attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
-
-        else:
-            # 4d mask is passed through the layers
-            attention_mask = _prepare_4d_causal_attention_mask(
-                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-            )
-
-        # embed positions
-        hidden_states = inputs_embeds
-
-        # decoder layers
-        all_hidden_states = () if output_hidden_states else None
-        all_self_attns = () if output_attentions else None
-        next_decoder_cache = () if use_cache else None
-
-        for layer_idx, decoder_layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states += (hidden_states,)
-
-            layer_outputs = _LlamaDecoderLayer.forward(
-                decoder_layer,
-                hidden_states,
-                layer_idx,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                output_attentions=output_attentions,
-                use_cache=use_cache,
-                batch_ids=batch_ids,
-            )
-
-            hidden_states = layer_outputs[0]
-
-            if use_cache:
-                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-            if output_attentions:
-                all_self_attns += (layer_outputs[1],)
-
-        hidden_states = self.norm(hidden_states)
-
-        # add hidden states from the last decoder layer
-        if output_hidden_states:
-            all_hidden_states += (hidden_states,)
-
-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
-
-        return BaseModelOutputWithPast(
-            last_hidden_state=hidden_states,
-            past_key_values=next_cache,
-            hidden_states=all_hidden_states,
-            attentions=all_self_attns,
-        )
-
-
-class _LlamaForCausalLM(LlamaForCausalLM):
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        batch_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        labels: Optional[torch.LongTensor] = None,
-        use_cache: Optional[bool] = None,
-        output_attentions: Optional[bool] = None,
-        output_hidden_states: Optional[bool] = None,
-        return_dict: Optional[bool] = None,
-    ) -> Union[Tuple, CausalLMOutputWithPast]:
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            inputs_embeds=inputs_embeds,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            batch_ids=batch_ids,
-        )
-
-        hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
-        logits = logits.float()
-
-        if not return_dict:
-            output = (logits,) + outputs[1:]
-            return output
-
-        return CausalLMOutputWithPast(
-            logits=logits,
-            past_key_values=outputs.past_key_values,
-            hidden_states=outputs.hidden_states,
-            attentions=outputs.attentions,
-        )
-
-
-class RebelDynamicCache(DynamicCache):
-    """
-    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
-
-    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
-    `[batch_size, num_heads, seq_len, head_dim]`.
-    """
-
-    def __init__(self, current_steps, max_length) -> None:
-        super().__init__()
-        self.current_steps = current_steps
-        self.max_length = max_length
-
-    def assign(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-    ) -> None:
-        self.key_cache[layer_idx] = key_states.squeeze(2)
-        self.value_cache[layer_idx] = value_states.squeeze(2)
-
-    def batch_select_assign(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        batch_ids: int,
-        layer_idx: int,
-    ) -> None:
-        past_key = self.key_cache[layer_idx]
-        past_value = self.value_cache[layer_idx]
-
-        ## (ISSUE): relay scatter_element index have same shape as cache.. can remove?
-        # update_key = past_key.slice_scatter(key_states, dim = 0, start=batch_ids, end=batch_ids+1)
-        # update_value = past_value.slice_scatter(value_states, dim = 0, start=batch_ids, end=batch_ids+1)
-
-        ## (ISSUE): torch select_scatter fits to the purpose (always replace single index), but not implmeneted to TVM yet..
-        # update_key = past_key.select_scatter(key_states.squeeze(0), dim = 0, index=batch_ids)
-        # update_value = past_value.select_scatter(value_states.squeeze(0), dim = 0, index=batch_ids)
-        cache_batch_size = past_key.shape[0]
-        if cache_batch_size == 1:
-            self.key_cache[layer_idx] = key_states
-            self.value_cache[layer_idx] = value_states
-        else:
-            update_key = [key_states]
-            update_value = [value_states]
-            for i in range(1, cache_batch_size):
-                update_key.append(past_key[i : i + 1])
-                update_value.append(past_value[i : i + 1])
-            update_key = torch.cat(update_key, dim=0)
-            update_value = torch.cat(update_value, dim=0)
-            self.key_cache[layer_idx] = update_key
-            self.value_cache[layer_idx] = update_value
-
-        ## (ISSUE): tvm copy issue
-        # past_key[batch_ids] = key_states
-        # past_value[batch_ids] = value_states
-        # self.key_cache[layer_idx] = past_key
-        # self.value_cache[layer_idx] = past_value
-
-    def copy(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Copy the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-        just for from_legacy_cache function
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-        # Update the number of seen tokens : deprecated
-        # if layer_idx == 0:
-        # self.seen_tokens += key_states.shape[-2]
-
-        # Update the cache
-        if len(self.key_cache) <= layer_idx:
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
-        else:
-            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
-    def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-        batch_index: int,
-        cache_kwargs: Optional[Dict[str, Any]] = None,
-        read_first_step: Optional[bool] = False,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
-        based on self.current_step,
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-        # # Update the number of seen tokens : deprecated
-        # if layer_idx == 0:
-        # self.seen_tokens += key_states.shape[-2]
-
-        # Update the cache
-        if len(self.key_cache) <= layer_idx:
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
-        else:
-            # [B,H,M,D]
-            # kv cache = [B, H, 4096, D]
-            # states = [1, H, 128, D]
-            # want to update states into kv_cache[batch_index][current_step]
-            # import pdb; pdb.set_trace()
-            current_step = self.current_steps[0 if read_first_step else batch_index]
-            kend = current_step + key_states.shape[-2]
-            vend = current_step + value_states.shape[-2]
-            update_key_states = (
-                self.key_cache[layer_idx][batch_index]
-                .unsqueeze(0)
-                .unsqueeze(2)
-                .slice_scatter(key_states, dim=-2, start=current_step, end=kend)
-            )
-            update_value_states = (
-                self.value_cache[layer_idx][batch_index]
-                .unsqueeze(0)
-                .unsqueeze(2)
-                .slice_scatter(value_states, dim=-2, start=current_step, end=vend)
-            )
-            return update_key_states, update_value_states
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        return self.key_cache[layer_idx].shape[-2]
-
-    def get_max_length(self) -> Optional[int]:
-        return self.max_length
-
-    @classmethod
-    def from_legacy_cache(
-        cls, position_ids, max_length, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    ) -> "DynamicCache":
-        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
-        batch, seq_len = position_ids.shape
-        # make current_steps lists from position_ids
-        # position_ids[b][0] is equal to cache position of each batch
-        current_steps = [position_ids[b][0] for b in range(batch)]
-        assert len(current_steps) == batch
-        cache = cls(current_steps, max_length)
-        if past_key_values is not None:
-            for layer_idx in range(len(past_key_values)):
-                key_states, value_states = past_key_values[layer_idx]
-                cache.copy(key_states, value_states, layer_idx)
-        return cache
-
-
-def rotate_half(x):
-    """Rotates half the hidden dims of the input."""
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
-    """Applies Rotary Position Embedding to the query and key tensors.
-
-    Args:
-        q (`torch.Tensor`): The query tensor.
-        k (`torch.Tensor`): The key tensor.
-        cos (`torch.Tensor`): The cosine part of the rotary embedding.
-        sin (`torch.Tensor`): The sine part of the rotary embedding.
-        position_ids (`torch.Tensor`):
-            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-            used to pass offsetted position ids when working with a KV-cache.
-        unsqueeze_dim (`int`, *optional*, defaults to 1):
-            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-    Returns:
-        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-    """
-    if position_ids.shape[0] > 1:
-        cos_all = []
-        sin_all = []
-        for i in range(position_ids.shape[0]):
-            cos_all.append(cos[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
-            sin_all.append(sin[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
-        cos = torch.cat(cos_all, dim=0)
-        sin = torch.cat(sin_all, dim=0)
-    else:
-        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
-    # cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-    # sin = sin[position_ids].unsqueeze(unsqueeze_dim)
-
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-def wrap_llama():
-    origin_mehtods = {}
-    origin_mehtods["LlamaRotaryEmbedding_INIT"] = LlamaRotaryEmbedding.__init__
-    origin_mehtods["LlamaRotaryEmbedding_forward"] = LlamaRotaryEmbedding.forward
-    origin_mehtods["LlamaModel_forward"] = LlamaModel.forward
-    origin_mehtods["LlamaForCausalLM_forward"] = LlamaForCausalLM.forward
-
-    LlamaRotaryEmbedding.__init__ = _LlamaRotaryEmbedding.__init__
-    LlamaRotaryEmbedding.forward = _LlamaRotaryEmbedding.forward
-    LlamaModel.forward = _LlamaModel.forward
-    LlamaForCausalLM.forward = _LlamaForCausalLM.forward
-
-    return origin_mehtods
-
-
-def unwrap_llama(origin_mehtods):
-    LlamaRotaryEmbedding.__init__ = origin_mehtods["LlamaRotaryEmbedding_INIT"]
-    LlamaRotaryEmbedding.forward = origin_mehtods["LlamaRotaryEmbedding_forward"]
-    LlamaModel.forward = origin_mehtods["LlamaModel_forward"]
-    LlamaForCausalLM.forward = origin_mehtods["LlamaForCausalLM_forward"]
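
Note on the removed cache logic: RebelDynamicCache.update writes each freshly computed key/value chunk into a preallocated per-batch cache row with Tensor.slice_scatter, starting at that row's current_step. The following is a minimal, self-contained sketch of that pattern, not code from the package; the shapes and variable names are illustrative only.

import torch

# Preallocated cache layout: [batch, heads, max_seq, head_dim]
batch, heads, max_seq, head_dim = 2, 4, 16, 8
key_cache = torch.zeros(batch, heads, max_seq, head_dim)

new_keys = torch.randn(1, heads, 3, head_dim)  # 3 newly computed key states for one sequence
batch_index, current_step = 1, 5               # which cache row, and where its sequence currently ends

# Select the row, keep a leading batch dim, and scatter the new chunk into
# [current_step, current_step + 3) along the sequence axis (dim=-2).
# slice_scatter returns an updated copy rather than mutating the cache in place.
row = key_cache[batch_index].unsqueeze(0)
updated_row = row.slice_scatter(
    new_keys, dim=-2, start=current_step, end=current_step + new_keys.shape[-2]
)

print(updated_row.shape)  # torch.Size([1, 4, 16, 8]); positions 5..7 now hold new_keys

The removed code additionally inserts a key/value-head grouping axis with unsqueeze(2) before scattering so that repeat_kv can be avoided; that detail is omitted from the sketch.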
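
The wrap_llama()/unwrap_llama() helpers at the end of the removed file are a plain save/patch/restore monkeypatch: the original LlamaRotaryEmbedding, LlamaModel, and LlamaForCausalLM methods are stashed in a dict, swapped for the batch-aware variants, and restored afterwards. A minimal stand-in sketch of that pattern follows; the Greeter class and function names are hypothetical so the example runs without transformers.

class Greeter:
    def hello(self):
        return "original"

def _patched_hello(self):
    return "patched"

def wrap_greeter():
    # keep references to the original methods so they can be restored later
    origin_methods = {"Greeter_hello": Greeter.hello}
    Greeter.hello = _patched_hello
    return origin_methods

def unwrap_greeter(origin_methods):
    Greeter.hello = origin_methods["Greeter_hello"]

origin = wrap_greeter()
print(Greeter().hello())  # "patched" while the wrapper is active (e.g. during tracing)
unwrap_greeter(origin)
print(Greeter().hello())  # "original" after restoring

Returning the dict of originals keeps the patch reversible, which matters because the patched classes are shared globals from the transformers package.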