optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (39)
  1. optimum/rbln/__init__.py +14 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/controlnet.py +3 -0
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  11. optimum/rbln/modeling_alias.py +14 -0
  12. optimum/rbln/modeling_base.py +110 -0
  13. optimum/rbln/transformers/__init__.py +6 -0
  14. optimum/rbln/transformers/cache_utils.py +111 -0
  15. optimum/rbln/transformers/generation/utils.py +0 -2
  16. optimum/rbln/transformers/models/__init__.py +2 -0
  17. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  18. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  19. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  20. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  21. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  22. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  23. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  24. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
  25. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
  26. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
  27. optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
  28. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  29. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  30. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  31. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
  32. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  33. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  34. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  35. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
  36. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
  37. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
  38. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
  39. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
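Most of the changes above appear to consolidate the per-model LLaMA, GPT-2, and Midm decoder code into the new shared decoderonly modules, add Gemma and XLM-RoBERTa model classes, and delete the standalone continuous-batching LLaMA implementation reproduced below. As rough orientation, a minimal usage sketch follows; the class name RBLNLlamaForCausalLM matches the existing modeling_llama.py module, but the keyword arguments (export, rbln_max_seq_len, rbln_batch_size) are assumptions based on optimum-rbln conventions rather than anything shown in this diff:

from transformers import AutoTokenizer
from optimum.rbln import RBLNLlamaForCausalLM  # assumed export name

# Compile the checkpoint for the RBLN NPU; the rbln_* keyword names are
# assumptions based on optimum-rbln conventions and may differ in 0.1.8.
model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    export=True,             # trace/compile from the PyTorch checkpoint
    rbln_max_seq_len=4096,   # assumed compile-time sequence length option
    rbln_batch_size=1,       # assumed compile-time batch size option
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))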
optimum/rbln/transformers/models/llama/llama_architecture_cb.py (deleted)
@@ -1,764 +0,0 @@
- # Copyright 2024 Rebellions Inc.
-
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at:
-
- # http://www.apache.org/licenses/LICENSE-2.0
-
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- # Portions of this software are licensed under the Apache License,
- # Version 2.0. See the NOTICE file distributed with this work for
- # additional information regarding copyright ownership.
-
- # All other portions of this software, including proprietary code,
- # are the intellectual property of Rebellions Inc. and may not be
- # copied, modified, or distributed without prior written permission
- # from Rebellions Inc.
-
- import math
- from typing import Any, Dict, List, Optional, Tuple, Union
-
- import torch
- import torch.nn.functional as F
- from torch import nn
- from transformers.cache_utils import Cache, DynamicCache
- from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
- from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
- from transformers.models.llama.modeling_llama import (
-     LlamaAttention,
-     LlamaDecoderLayer,
-     LlamaForCausalLM,
-     LlamaModel,
-     LlamaRotaryEmbedding,
- )
-
-
- """
- - define new class to put batch_position as a forward args
- - _LlamaForCausalLM receives batch_ids (default=None)
- """
-
-
- class LlamaDynamicBatchWrapper(torch.nn.Module):
-     def __init__(self, model):
-         super().__init__()
-         self.model = model
-
-     def forward(self, input_ids, attention_mask, cache_position, batch_position, *past_key_values):
-         if input_ids.shape[1] == 1:
-             rbln_batch_position = None
-         else:
-             rbln_batch_position = batch_position
-
-         past_kv_list = []
-         for i in range(self.model.config.num_hidden_layers):
-             cur_kv_layer = []
-             for j in range(2):
-                 cur_kv_layer.append(past_key_values[2 * i + j])
-             past_kv_list.append(cur_kv_layer)
-         model_output = self.model(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             position_ids=cache_position,
-             past_key_values=past_kv_list,
-             batch_ids=rbln_batch_position,
-         )
-
-         return model_output, batch_position
-
-
- class _LlamaRotaryEmbedding(nn.Module):
-     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-         super(LlamaRotaryEmbedding, self).__init__()
-
-         self.dim = dim
-         self.max_position_embeddings = max_position_embeddings
-         self.base = base
-         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-         self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-         # Build here to make `torch.jit.trace` work.
-         seq_len = max_position_embeddings
-         device = self.inv_freq.device
-         dtype = torch.get_default_dtype()
-         self.max_seq_len_cached = seq_len
-         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
-         freqs = torch.outer(t, self.inv_freq)
-         # Different from paper, but it uses a different permutation in order to obtain the same calculation
-         emb = torch.cat((freqs, freqs), dim=-1)
-         self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
-         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
-
-     def forward(self, x, seq_len=None):
-         # x: [bs, num_attention_heads, seq_len, head_size]
-         if seq_len > self.max_seq_len_cached:
-             self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-         return (
-             self.cos_cached[:seq_len].to(dtype=x.dtype),
-             self.sin_cached[:seq_len].to(dtype=x.dtype),
-         )
-
-
- class _LlamaAttention(LlamaAttention):
-     # single batch llama attention
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_value: Optional[Cache] = None,
-         batch_index: Optional[int] = None,
-         output_attentions: bool = False,
-         use_cache: bool = False,
-         cos: Optional[torch.Tensor] = None,
-         sin: Optional[torch.Tensor] = None,
-         layer_id: int = 0,
-         **kwargs,
-     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-         bsz, q_len, _ = hidden_states.size()
-         if self.config.pretraining_tp > 1:
-             key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
-             query_slices = self.q_proj.weight.split(
-                 (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
-             )
-             key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
-             value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
-
-             query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
-             query_states = torch.cat(query_states, dim=-1)
-
-             key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
-             key_states = torch.cat(key_states, dim=-1)
-
-             value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
-             value_states = torch.cat(value_states, dim=-1)
-
-         else:
-             query_states = self.q_proj(hidden_states)
-             key_states = self.k_proj(hidden_states)
-             value_states = self.v_proj(hidden_states)
-
-         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-         kv_seq_len = key_states.shape[-2]
-         if past_key_value is not None:
-             if self.layer_idx is None:
-                 raise ValueError(
-                     f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
-                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
-                     "with a layer index."
-                 )
-             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-         if layer_id == 0:
-             cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-         query_states, key_states, cos, sin = apply_rotary_pos_emb(
-             query_states, key_states, cos, sin, position_ids, layer_id
-         )
-         if past_key_value is not None:
-             cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
-             if (batch_index is None or batch_index == -1) and bsz > 1:
-                 all_key_states = []
-                 all_value_states = []
-                 all_attn_output = []
-                 for b in range(bsz):
-                     batch_query_states = query_states[b].unsqueeze(0)
-                     batch_attention_mask = attention_mask[b].unsqueeze(0)
-                     batch_key_states = key_states[b].unsqueeze(0)
-                     batch_value_states = value_states[b].unsqueeze(0)
-
-                     # reshape for removing repeat_kv
-                     batch_key_states = batch_key_states.unsqueeze(2)
-                     batch_value_states = batch_value_states.unsqueeze(2)
-                     batch_attention_mask = batch_attention_mask.unsqueeze(2)
-                     batch_query_states = batch_query_states.view(
-                         1, self.num_key_value_heads, self.num_heads // self.num_key_value_heads, q_len, self.head_dim
-                     )
-
-                     batch_key_states, batch_value_states = past_key_value.update(
-                         batch_key_states, batch_value_states, self.layer_idx, b, cache_kwargs
-                     )
-
-                     # batch_key_states = repeat_kv(
-                     #     batch_key_states,
-                     #     self.num_key_value_groups
-                     # )
-                     # batch_value_states = repeat_kv(
-                     #     batch_value_states,
-                     #     self.num_key_value_groups
-                     # )
-
-                     # attn_weights = torch.matmul(batch_query_states, batch_key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-                     # reshape for removing repeat_kv
-                     attn_weights = torch.matmul(batch_query_states, batch_key_states.transpose(3, 4)) / math.sqrt(
-                         self.head_dim
-                     )
-
-                     attn_weights = attn_weights + batch_attention_mask
-
-                     # upcast attention to fp32
-                     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-                     attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-                     attn_output = torch.matmul(attn_weights, batch_value_states)
-
-                     # reshape for removing repeat_kv
-                     attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
-
-                     attn_output = attn_output.transpose(1, 2).contiguous()
-                     attn_output = attn_output.reshape(1, q_len, self.hidden_size)
-                     all_key_states.append(batch_key_states)
-                     all_value_states.append(batch_value_states)
-                     all_attn_output.append(attn_output)
-                 key_states = torch.cat(all_key_states, dim=0)
-                 value_states = torch.cat(all_value_states, dim=0)
-                 attn_output = torch.cat(all_attn_output, dim=0)
-
-             else:
-                 assert bsz == 1, "dynamic batch update only support input batch 1"
-                 if batch_index is None or batch_index == -1:
-                     batch_index = 0
-
-                 # reshape for removing repeat_kv
-                 key_states = key_states.unsqueeze(2)
-                 value_states = value_states.unsqueeze(2)
-                 attention_mask = attention_mask.unsqueeze(2)
-                 query_states = query_states.view(
-                     1, self.num_key_value_heads, self.num_heads // self.num_key_value_heads, q_len, self.head_dim
-                 )
-
-                 key_states, value_states = past_key_value.update(
-                     key_states, value_states, self.layer_idx, batch_index, cache_kwargs, read_first_step=True
-                 )
-
-                 # key_states = repeat_kv(key_states, self.num_key_value_groups)
-                 # value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-                 attn_weights = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)
-
-                 attn_weights = attn_weights + attention_mask
-
-                 # upcast attention to fp32
-                 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-                 attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-                 attn_output = torch.matmul(attn_weights, value_states)
-
-                 # reshape for removing repeat_kv
-                 attn_output = attn_output.view(1, self.num_heads, q_len, self.head_dim)
-
-                 attn_output = attn_output.transpose(1, 2).contiguous()
-                 attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-
-         if self.config.pretraining_tp > 1:
-             attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
-             o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
-             attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
-         else:
-             attn_output = self.o_proj(attn_output)
-
-         if not output_attentions:
-             attn_weights = None
-
-         return attn_output, attn_weights, key_states, value_states, cos, sin
-
-
- class _LlamaDecoderLayer(LlamaDecoderLayer):
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         layer_idx: int,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_value: Optional[Tuple[torch.Tensor]] = None,
-         output_attentions: Optional[bool] = False,
-         use_cache: Optional[bool] = False,
-         batch_ids: Optional[torch.LongTensor] = None,
-         cos: Optional[torch.Tensor] = None,
-         sin: Optional[torch.Tensor] = None,
-         layer_id: int = 0,
-         **kwargs,
-     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
-         residual = hidden_states
-
-         hidden_states = self.input_layernorm(hidden_states)
-         bsz, _, _ = hidden_states.size()
-
-         hidden_states, self_attn_weights, k, v, cos, sin = _LlamaAttention.forward(
-             self.self_attn,
-             hidden_states=hidden_states,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_value=past_key_value,
-             output_attentions=output_attentions,
-             batch_index=batch_ids,
-             use_cache=use_cache,
-             cos=cos,
-             sin=sin,
-             layer_id=layer_id,
-             **kwargs,
-         )
-         past_key_value.assign(k, v, layer_idx)
-
-         present_key_value = past_key_value
-
-         hidden_states = residual + hidden_states
-
-         # Fully Connected
-         residual = hidden_states
-         hidden_states = self.post_attention_layernorm(hidden_states)
-         hidden_states = self.mlp(hidden_states)
-         hidden_states = residual + hidden_states
-
-         outputs = (hidden_states,)
-
-         if output_attentions:
-             outputs += (self_attn_weights,)
-
-         if use_cache:
-             outputs += (present_key_value,)
-
-         return outputs, cos, sin
-
-
- class _LlamaModel(LlamaModel):
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         batch_ids: Optional[torch.LongTensor] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, BaseModelOutputWithPast]:
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         #### cannot change forward args? temporal workaround ####
-         # current_step = position_ids
-
-         # retrieve input_ids and inputs_embeds
-         if input_ids is not None and inputs_embeds is not None:
-             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-         elif input_ids is not None:
-             batch_size, seq_length = input_ids.shape[:2]
-         elif inputs_embeds is not None:
-             batch_size, seq_length = inputs_embeds.shape[:2]
-         else:
-             raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-         past_key_values_length = 0
-         if use_cache:
-             use_legacy_cache = not isinstance(past_key_values, Cache)
-             if use_legacy_cache:
-                 past_key_values = RebelDynamicCache.from_legacy_cache(
-                     position_ids=position_ids,
-                     max_length=self.config.max_position_embeddings,
-                     past_key_values=past_key_values,
-                 )
-
-             # not used, get_usable_length will be changed
-             # past_key_values_length = past_key_values.get_usable_length(seq_length)
-
-         #### position embedding indice ####
-         # position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.int32).unsqueeze(0) + current_step
-
-         if position_ids is None:
-             device = input_ids.device if input_ids is not None else inputs_embeds.device
-             position_ids = torch.arange(
-                 past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-             )
-             position_ids = position_ids.unsqueeze(0)
-
-         if inputs_embeds is None:
-             inputs_embeds = self.embed_tokens(input_ids)
-
-         # ##### original condition for generating causal attention mask
-         # if getattr(self.config, "_flash_attn_2_enabled", False):
-         #     # 2d mask is passed through the layers
-         #     attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-
-         # else:
-         #     # 4d mask is passed through the layers
-         #     attention_mask = _prepare_4d_causal_attention_mask(
-         #         attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-         #     )
-         # ########################################################
-
-         # yhboo changed for valid graph generation
-         if getattr(self.config, "_flash_attn_2_enabled", False):
-             # 2d mask is passed through the layers
-             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
-             raise NotImplementedError
-
-         elif attention_mask is not None and attention_mask.ndim == 4:
-             # assuming attention mask is generated as input
-             # assumed dim = [batch_size, 1, inp_seq, max_seq]
-             # only make [1, 0] mask to [0, -inf]
-             attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min
-
-         else:
-             # 4d mask is passed through the layers
-             attention_mask = _prepare_4d_causal_attention_mask(
-                 attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
-             )
-
-         # embed positions
-         hidden_states = inputs_embeds
-
-         # decoder layers
-         all_hidden_states = () if output_hidden_states else None
-         all_self_attns = () if output_attentions else None
-         next_decoder_cache = () if use_cache else None
-
-         cos = None
-         sin = None
-         for layer_idx, decoder_layer in enumerate(self.layers):
-             if output_hidden_states:
-                 all_hidden_states += (hidden_states,)
-             layer_outputs = _LlamaDecoderLayer.forward(
-                 decoder_layer,
-                 hidden_states,
-                 layer_idx,
-                 attention_mask=attention_mask,
-                 position_ids=position_ids,
-                 past_key_value=past_key_values,
-                 output_attentions=output_attentions,
-                 use_cache=use_cache,
-                 batch_ids=batch_ids,
-                 cos=cos,
-                 sin=sin,
-                 layer_id=layer_idx,
-             )
-             cos = layer_outputs[-2]
-             sin = layer_outputs[-1]
-             layer_outputs = layer_outputs[0]
-
-             hidden_states = layer_outputs[0]
-
-             if use_cache:
-                 next_decoder_cache = layer_outputs[2 if output_attentions else 1]
-
-             if output_attentions:
-                 all_self_attns += (layer_outputs[1],)
-
-         hidden_states = self.norm(hidden_states)
-
-         # add hidden states from the last decoder layer
-         if output_hidden_states:
-             all_hidden_states += (hidden_states,)
-
-         next_cache = None
-         if use_cache:
-             next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
-
-         return BaseModelOutputWithPast(
-             last_hidden_state=hidden_states,
-             past_key_values=next_cache,
-             hidden_states=all_hidden_states,
-             attentions=all_self_attns,
-         )
-
-
- class _LlamaForCausalLM(LlamaForCausalLM):
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_values: Optional[List[torch.FloatTensor]] = None,
-         batch_ids: Optional[torch.LongTensor] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         labels: Optional[torch.LongTensor] = None,
-         use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
-     ) -> Union[Tuple, CausalLMOutputWithPast]:
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = (
-             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         )
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-         outputs = self.model(
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_values=past_key_values,
-             inputs_embeds=inputs_embeds,
-             use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-             batch_ids=batch_ids,
-         )
-
-         hidden_states = outputs[0]
-         logits = self.lm_head(hidden_states)
-         logits = logits.float()
-
-         if not return_dict:
-             output = (logits,) + outputs[1:]
-             return output
-
-         return CausalLMOutputWithPast(
-             logits=logits,
-             past_key_values=outputs.past_key_values,
-             hidden_states=outputs.hidden_states,
-             attentions=outputs.attentions,
-         )
-
-
- class RebelDynamicCache(DynamicCache):
-     """
-     A cache that grows dynamically as more tokens are generated. This is the default for generative models.
-
-     It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
-     `[batch_size, num_heads, seq_len, head_dim]`.
-     """
-
-     def __init__(self, current_steps, max_length) -> None:
-         super().__init__()
-         self.current_steps = current_steps
-         self.max_length = max_length
-
-     def assign(
-         self,
-         key_states: torch.Tensor,
-         value_states: torch.Tensor,
-         layer_idx: int,
-     ) -> None:
-         self.key_cache[layer_idx] = key_states.squeeze(2)
-         self.value_cache[layer_idx] = value_states.squeeze(2)
-
-     def batch_select_assign(
-         self,
-         key_states: torch.Tensor,
-         value_states: torch.Tensor,
-         batch_ids: int,
-         layer_idx: int,
-     ) -> None:
-         past_key = self.key_cache[layer_idx]
-         past_value = self.value_cache[layer_idx]
-
-         ## (ISSUE): relay scatter_element index have same shape as cache.. can remove?
-         # update_key = past_key.slice_scatter(key_states, dim = 0, start=batch_ids, end=batch_ids+1)
-         # update_value = past_value.slice_scatter(value_states, dim = 0, start=batch_ids, end=batch_ids+1)
-
-         ## (ISSUE): torch select_scatter fits to the purpose (always replace single index), but not implmeneted to TVM yet..
-         # update_key = past_key.select_scatter(key_states.squeeze(0), dim = 0, index=batch_ids)
-         # update_value = past_value.select_scatter(value_states.squeeze(0), dim = 0, index=batch_ids)
-         cache_batch_size = past_key.shape[0]
-         if cache_batch_size == 1:
-             self.key_cache[layer_idx] = key_states
-             self.value_cache[layer_idx] = value_states
-         else:
-             update_key = [key_states]
-             update_value = [value_states]
-             for i in range(1, cache_batch_size):
-                 update_key.append(past_key[i : i + 1])
-                 update_value.append(past_value[i : i + 1])
-             update_key = torch.cat(update_key, dim=0)
-             update_value = torch.cat(update_value, dim=0)
-             self.key_cache[layer_idx] = update_key
-             self.value_cache[layer_idx] = update_value
-
-         ## (ISSUE): tvm copy issue
-         # past_key[batch_ids] = key_states
-         # past_value[batch_ids] = value_states
-         # self.key_cache[layer_idx] = past_key
-         # self.value_cache[layer_idx] = past_value
-
-     def copy(
-         self,
-         key_states: torch.Tensor,
-         value_states: torch.Tensor,
-         layer_idx: int,
-         cache_kwargs: Optional[Dict[str, Any]] = None,
-     ) -> Tuple[torch.Tensor, torch.Tensor]:
-         """
-         Copy the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-         just for from_legacy_cache function
-
-         Parameters:
-             key_states (`torch.Tensor`):
-                 The new key states to cache.
-             value_states (`torch.Tensor`):
-                 The new value states to cache.
-             layer_idx (`int`):
-                 The index of the layer to cache the states for.
-             cache_kwargs (`Dict[str, Any]`, `optional`):
-                 Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-         Return:
-             A tuple containing the updated key and value states.
-         """
-         # Update the number of seen tokens : deprecated
-         # if layer_idx == 0:
-         #     self.seen_tokens += key_states.shape[-2]
-
-         # Update the cache
-         if len(self.key_cache) <= layer_idx:
-             self.key_cache.append(key_states)
-             self.value_cache.append(value_states)
-         else:
-             self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-             self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
-         return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
-     def update(
-         self,
-         key_states: torch.Tensor,
-         value_states: torch.Tensor,
-         layer_idx: int,
-         batch_index: int,
-         cache_kwargs: Optional[Dict[str, Any]] = None,
-         read_first_step: Optional[bool] = False,
-     ) -> Tuple[torch.Tensor, torch.Tensor]:
-         """
-         Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
-         based on self.current_step,
-
-         Parameters:
-             key_states (`torch.Tensor`):
-                 The new key states to cache.
-             value_states (`torch.Tensor`):
-                 The new value states to cache.
-             layer_idx (`int`):
-                 The index of the layer to cache the states for.
-             cache_kwargs (`Dict[str, Any]`, `optional`):
-                 Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-         Return:
-             A tuple containing the updated key and value states.
-         """
-         # # Update the number of seen tokens : deprecated
-         # if layer_idx == 0:
-         #     self.seen_tokens += key_states.shape[-2]
-
-         # Update the cache
-         if len(self.key_cache) <= layer_idx:
-             self.key_cache.append(key_states)
-             self.value_cache.append(value_states)
-         else:
-             # [B,H,M,D]
-             # kv cache = [B, H, 4096, D]
-             # states = [1, H, 128, D]
-             # want to update states into kv_cache[batch_index][current_step]
-             # import pdb; pdb.set_trace()
-             current_step = self.current_steps[0 if read_first_step else batch_index]
-             kend = current_step + key_states.shape[-2]
-             vend = current_step + value_states.shape[-2]
-             update_key_states = (
-                 self.key_cache[layer_idx][batch_index]
-                 .unsqueeze(0)
-                 .unsqueeze(2)
-                 .slice_scatter(key_states, dim=-2, start=current_step, end=kend)
-             )
-             update_value_states = (
-                 self.value_cache[layer_idx][batch_index]
-                 .unsqueeze(0)
-                 .unsqueeze(2)
-                 .slice_scatter(value_states, dim=-2, start=current_step, end=vend)
-             )
-             return update_key_states, update_value_states
-
-     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-         if len(self.key_cache) <= layer_idx:
-             return 0
-         return self.key_cache[layer_idx].shape[-2]
-
-     def get_max_length(self) -> Optional[int]:
-         return self.max_length
-
-     @classmethod
-     def from_legacy_cache(
-         cls, position_ids, max_length, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-     ) -> "DynamicCache":
-         """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
-         batch, seq_len = position_ids.shape
-         # make current_steps lists from position_ids
-         # position_ids[b][0] is equal to cache position of each batch
-         current_steps = [position_ids[b][0] for b in range(batch)]
-         assert len(current_steps) == batch
-         cache = cls(current_steps, max_length)
-         if past_key_values is not None:
-             for layer_idx in range(len(past_key_values)):
-                 key_states, value_states = past_key_values[layer_idx]
-                 cache.copy(key_states, value_states, layer_idx)
-         return cache
-
-
- def rotate_half(x):
-     """Rotates half the hidden dims of the input."""
-     x1 = x[..., : x.shape[-1] // 2]
-     x2 = x[..., x.shape[-1] // 2 :]
-     return torch.cat((-x2, x1), dim=-1)
-
-
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids, layer_id, unsqueeze_dim=1):
-     """Applies Rotary Position Embedding to the query and key tensors.
-
-     Args:
-         q (`torch.Tensor`): The query tensor.
-         k (`torch.Tensor`): The key tensor.
-         cos (`torch.Tensor`): The cosine part of the rotary embedding.
-         sin (`torch.Tensor`): The sine part of the rotary embedding.
-         position_ids (`torch.Tensor`):
-             The position indices of the tokens corresponding to the query and key tensors. For example, this can be
-             used to pass offsetted position ids when working with a KV-cache.
-         unsqueeze_dim (`int`, *optional*, defaults to 1):
-             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
-             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
-             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
-             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
-             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
-             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
-     Returns:
-         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
-     """
-     if layer_id == 0:
-         if position_ids.shape[0] > 1:
-             cos_all = []
-             sin_all = []
-             for i in range(position_ids.shape[0]):
-                 cos_all.append(cos[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
-                 sin_all.append(sin[position_ids[i : i + 1]].unsqueeze(unsqueeze_dim))
-             cos = torch.cat(cos_all, dim=0)
-             sin = torch.cat(sin_all, dim=0)
-         else:
-             cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-             sin = sin[position_ids].unsqueeze(unsqueeze_dim)
-     # cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-     # sin = sin[position_ids].unsqueeze(unsqueeze_dim)
-
-     q_embed = (q * cos) + (rotate_half(q) * sin)
-     k_embed = (k * cos) + (rotate_half(k) * sin)
-     return q_embed, k_embed, cos, sin
-
-
- def wrap_llama():
-     LlamaRotaryEmbedding.__init__ = _LlamaRotaryEmbedding.__init__
-     LlamaRotaryEmbedding.forward = _LlamaRotaryEmbedding.forward
-     LlamaModel.forward = _LlamaModel.forward
-     LlamaForCausalLM.forward = _LlamaForCausalLM.forward
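
For context on the deleted file: RebelDynamicCache.update above avoids growing the KV cache by concatenation and instead writes each new key/value block into a preallocated cache at that request's current step using slice_scatter. A minimal standalone sketch of that pattern, with illustrative shapes and names that are not taken from the package:

import torch

# Illustrative sizes only; real cache shapes come from the compiled model config.
batch_size, num_kv_heads, max_seq_len, head_dim = 2, 4, 16, 8
new_len, current_step, batch_index = 3, 5, 1

# Preallocated per-layer cache, shape [batch, heads, max_seq_len, head_dim].
key_cache = torch.zeros(batch_size, num_kv_heads, max_seq_len, head_dim)

# Freshly computed key states for one request, shape [1, heads, new_len, head_dim].
new_keys = torch.randn(1, num_kv_heads, new_len, head_dim)

# Write the new block into that request's row at its current step instead of
# torch.cat, mirroring the slice_scatter call in RebelDynamicCache.update
# (the original additionally inserts a grouped-KV axis with unsqueeze(2)).
row = key_cache[batch_index].unsqueeze(0)  # [1, heads, max_seq_len, head_dim]
row = row.slice_scatter(new_keys, dim=-2, start=current_step, end=current_step + new_len)

# Scatter the updated row back into the full cache along the batch dimension.
key_cache = key_cache.slice_scatter(row, dim=0, start=batch_index, end=batch_index + 1)

print(key_cache[batch_index, 0, current_step : current_step + new_len, 0])  # updated slots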