optimum-rbln 0.1.7__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +14 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/controlnet.py +3 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -144
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/modeling_alias.py +14 -0
- optimum/rbln/modeling_base.py +110 -0
- optimum/rbln/transformers/__init__.py +6 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +56 -220
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
- optimum/rbln/transformers/models/llama/modeling_llama.py +8 -442
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +40 -272
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +2 -3
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/RECORD +38 -30
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
@@ -21,616 +21,9 @@
|
|
21
21
|
# copied, modified, or distributed without prior written permission
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
|
-
import math
|
25
|
-
from typing import Any, Dict, List, Optional, Tuple, Union
|
26
24
|
|
27
|
-
import
|
28
|
-
import torch.nn.functional as F
|
29
|
-
from torch import nn
|
30
|
-
from transformers.cache_utils import Cache, DynamicCache
|
31
|
-
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
32
|
-
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
33
|
-
from transformers.models.llama.modeling_llama import (
|
34
|
-
LlamaAttention,
|
35
|
-
LlamaDecoderLayer,
|
36
|
-
LlamaForCausalLM,
|
37
|
-
LlamaModel,
|
38
|
-
LlamaRotaryEmbedding,
|
39
|
-
)
|
25
|
+
from ...models.decoderonly.decoderonly_architecture import DecoderOnlyWrapper
|
40
26
|
|
41
27
|
|
42
|
-
# NOTE(review): in this copy of the file the class header's base class and the
# ``__init__`` signature were lost (only ``super().__init__()`` and
# ``self.model = model`` survived). They are reconstructed here as an
# ``nn.Module`` taking ``model`` -- confirm against the original source.
class LlamaWrapper(torch.nn.Module):
    """Expose a Llama model through a flat, trace-friendly positional signature.

    The KV cache arrives as a flat ``*past_key_values`` argument list and is
    regrouped into per-layer ``[key, value]`` lists before calling the model.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, cache_position, *past_key_values):
        """Run the wrapped model with a regrouped KV cache.

        Args:
            input_ids: forwarded unchanged to the model.
            attention_mask: forwarded unchanged to the model.
            cache_position: forwarded as ``position_ids``. When ``wrap_llama`` is
                active, the patched ``LlamaModel.forward`` treats that value as the
                current decoding step.
            *past_key_values: flat ``k0, v0, k1, v1, ...`` sequence, two entries
                per hidden layer (``config.num_hidden_layers`` layers).

        Returns:
            Whatever the wrapped model returns.
        """
        num_layers = self.model.config.num_hidden_layers
        # Regroup (k0, v0, k1, v1, ...) into [[k0, v0], [k1, v1], ...].
        past_kv_list = [[past_key_values[2 * i], past_key_values[2 * i + 1]] for i in range(num_layers)]

        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=cache_position,  # carries the current step for the patched model forward
            past_key_values=past_kv_list,
        )
|
63
|
-
|
64
|
-
|
65
|
-
class _LlamaRotaryEmbedding(nn.Module):
    """Replacement ``__init__`` / ``forward`` for ``LlamaRotaryEmbedding``.

    The cos/sin tables are precomputed for ``max_position_embeddings`` positions at
    construction time, so ``torch.jit.trace`` sees ready-made buffers.

    NOTE: ``wrap_llama`` monkey-patches ``__init__`` and ``forward`` onto
    ``LlamaRotaryEmbedding``. At runtime ``self`` is therefore a
    ``LlamaRotaryEmbedding`` instance that does not carry this class's other methods.
    For that reason ``_set_cos_sin_cache`` is always invoked explicitly through
    ``_LlamaRotaryEmbedding`` rather than via ``self``.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        # This function *replaces* LlamaRotaryEmbedding.__init__, so skip it and
        # initialise nn.Module directly.
        super(LlamaRotaryEmbedding, self).__init__()

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        _LlamaRotaryEmbedding._set_cos_sin_cache(
            self,
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the ``cos_cached`` / ``sin_cached`` tables for positions ``0 .. seq_len - 1``.

        Also records ``seq_len`` in ``self.max_seq_len_cached``.
        """
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables covering the first ``seq_len`` positions.

        Args:
            x: tensor whose dtype the returned tables are cast to. Its device and
                dtype are also used if the cache must grow. Presumably
                ``[bs, num_attention_heads, seq_len, head_size]``; only dtype/device
                are read here.
            seq_len: number of positions needed. ``None`` returns the whole cache;
                a value beyond the cached length rebuilds the tables first.
        """
        if seq_len is None:
            # Previously `None > int` raised TypeError; treat "unspecified" as the full table.
            seq_len = self.max_seq_len_cached
        elif seq_len > self.max_seq_len_cached:
            _LlamaRotaryEmbedding._set_cos_sin_cache(self, seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
|
97
|
-
|
98
|
-
|
99
|
-
class _LlamaAttention(LlamaAttention):
    """Llama self-attention rewritten for graph tracing / compilation.

    Unlike the upstream implementation, K/V heads are not duplicated with
    ``repeat_kv``. Instead, key/value states get a singleton group axis
    (``unsqueeze(2)``), and the query is viewed as
    ``(bsz, num_key_value_heads, num_heads // num_key_value_heads, q_len, head_dim)``.
    A single 5-D matmul then broadcasts each KV head over its query group.
    """

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute grouped-query self-attention, optionally updating ``past_key_value``.

        Args:
            hidden_states: rank-3 input ``(bsz, q_len, hidden)``, as unpacked by
                ``.size()`` below.
            attention_mask: additive mask that must have shape
                ``(bsz, 1, q_len, kv_seq_len)``; any other shape raises ``ValueError``.
            position_ids: indices into the rotary cos/sin tables, forwarded to
                ``apply_rotary_pos_emb``.
            past_key_value: cache object; ``get_usable_length`` and ``update`` are
                called on it for ``self.layer_idx``. NOTE(review): the score matmul
                uses ``transpose(3, 4)``, so ``update`` must hand back 5-D key/value
                states. ``RebelDynamicCache.update`` returns its stored tensors
                unsqueezed at dim 2, presumably for this reason; a stock
                ``DynamicCache`` would not fit.
            output_attentions: when False, the returned attention weights are ``None``.
            use_cache: accepted for signature compatibility; not read here.
            **kwargs: accepted for signature compatibility; not read here.

        Returns:
            ``(attn_output, attn_weights or None, past_key_value)``.

        Raises:
            ValueError: if a cache is given while ``self.layer_idx`` is None, or if
                the attention mask / output shapes do not match.
        """
        bsz, q_len, _ = hidden_states.size()

        # Tensor-parallel-pretrained checkpoints: project with weight slices and concatenate.
        if self.config.pretraining_tp > 1:
            key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
            query_slices = self.q_proj.weight.split(
                (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
            )
            key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
            value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

            query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
            query_states = torch.cat(query_states, dim=-1)

            key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
            key_states = torch.cat(key_states, dim=-1)

            value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
            value_states = torch.cat(value_states, dim=-1)

        else:
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        # -> (bsz, heads, q_len, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # kv_seq_len = new tokens + whatever the cache reports as usable; the mask
        # shape check below depends on it. NOTE(review): with RebelDynamicCache this
        # presumably resolves to the pre-allocated cache length -- confirm.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        # Grouped layout instead of repeat_kv:
        #   key/value: (bsz, kv_heads, 1, seq, head_dim)
        #   query:     (bsz, kv_heads, num_heads // kv_heads, q_len, head_dim)
        key_states = key_states.unsqueeze(2)
        value_states = value_states.unsqueeze(2)
        query_states = query_states.view(
            bsz, self.num_key_value_heads, self.num_heads // self.num_key_value_heads, q_len, self.head_dim
        )

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat_kv is intentionally not called: the singleton group axis on K/V
        # broadcasts against the query groups in this 5-D matmul
        # (upstream used key_states.transpose(2, 3) on 4-D tensors).
        attn_weights = torch.matmul(query_states, key_states.transpose(3, 4)) / math.sqrt(self.head_dim)

        # The upstream 4-D attn_weights shape check no longer applies to the 5-D
        # grouped scores and was dropped.

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            else:
                # (bsz, 1, q_len, kv) -> (bsz, 1, 1, q_len, kv): broadcasts over the
                # KV-head and group axes of the 5-D scores.
                attention_mask = attention_mask.unsqueeze(2)

            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        # Collapse (bsz, kv_heads, groups, q_len, head_dim) back to
        # (bsz, num_heads, q_len, head_dim); group members are contiguous, matching
        # the query view above.
        attn_output = attn_output.view(bsz, self.num_heads, q_len, self.head_dim)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
|
216
|
-
|
217
|
-
|
218
|
-
class _LlamaDecoderLayer(LlamaDecoderLayer):
    """Pre-norm Llama decoder layer routed through ``_LlamaAttention.forward``."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Apply attention and MLP sub-blocks, each as ``x + f(norm(x))``.

        The stock ``self.self_attn`` module is driven through the compile-friendly
        ``_LlamaAttention.forward`` rather than its own ``forward``.

        Returns:
            ``(hidden_states,)``, followed by the attention weights when
            ``output_attentions`` is set, then the cache object when ``use_cache``
            is set.
        """
        # Attention sub-block.
        attn_input = self.input_layernorm(hidden_states)
        attn_out, attn_weights, present_kv = _LlamaAttention.forward(
            self.self_attn,
            hidden_states=attn_input,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        extras = []
        if output_attentions:
            extras.append(attn_weights)
        if use_cache:
            extras.append(present_kv)
        return (hidden_states, *extras)
|
261
|
-
|
262
|
-
|
263
|
-
class _LlamaModel(LlamaModel):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[int] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Compile-oriented replacement for ``LlamaModel.forward`` (installed by ``wrap_llama``).

        Differences from the upstream implementation:

        * ``position_ids`` is repurposed to carry the *current decoding step*. Per-token
          position ids are derived as ``arange(seq_length) + current_step``. If it is
          ``None``, positions simply start at 0.
        * When ``use_cache`` is on, legacy (non-``Cache``) ``past_key_values`` are
          wrapped in a ``RebelDynamicCache`` built for ``current_step``.
        * A 4-D ``attention_mask`` is taken as already materialized (1 = attend,
          0 = masked) and only converted to an additive mask.
        * A ``BaseModelOutputWithPast`` is always returned; ``return_dict`` is
          resolved but not otherwise used.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given.
            NotImplementedError: if flash-attention-2 is enabled on the config.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # This function replaces LlamaModel.forward, so its signature cannot change:
        # the current decoding step is passed in through `position_ids`.
        current_step = position_ids

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Positions come from `current_step`; nothing is inferred from the cache length.
        past_key_values_length = 0
        use_legacy_cache = False
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = RebelDynamicCache.from_legacy_cache(
                    current_step=current_step,
                    max_length=self.config.max_position_embeddings,
                    past_key_values=past_key_values,
                )

        # Position indices for this call's tokens. `seq_length` is used instead of
        # `input_ids.shape[-1]` so the `inputs_embeds`-only path works too.
        if current_step is not None:
            position_ids = torch.arange(0, seq_length, dtype=torch.int32).unsqueeze(0) + current_step
        else:
            # Previously unreachable fallback: without a step, number positions from 0.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            ).unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Mask construction, changed from upstream so the traced graph is valid.
        if getattr(self.config, "_flash_attn_2_enabled", False):
            raise NotImplementedError("flash_attention_2 is not supported by this Llama graph")

        elif attention_mask is not None and attention_mask.ndim == 4:
            # The caller supplied a fully materialized mask, presumably
            # [batch_size, 1, inp_seq, max_seq] with 1 = attend / 0 = masked.
            # Only convert it to additive form: 1 -> 0, 0 -> float16 minimum.
            attention_mask = (1 - attention_mask) * torch.finfo(torch.float16).min

        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = _LlamaDecoderLayer.forward(
                decoder_layer,
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
399
|
-
|
400
|
-
|
401
|
-
class _LlamaForCausalLM(LlamaForCausalLM):
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Inference-only replacement for ``LlamaForCausalLM.forward``.

        Runs ``self.model`` and projects its first output through ``lm_head``,
        yielding float32 logits. ``labels`` is accepted for signature compatibility
        but ignored, so no loss is computed. ``position_ids`` is forwarded to
        ``self.model`` unchanged.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict`` resolves truthy, otherwise
            the tuple ``(logits, *model_outputs[1:])``.
        """
        config = self.config
        if output_attentions is None:
            output_attentions = config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = config.output_hidden_states
        if return_dict is None:
            return_dict = config.use_return_dict

        model_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(model_outputs[0]).float()

        if return_dict:
            return CausalLMOutputWithPast(
                logits=logits,
                past_key_values=model_outputs.past_key_values,
                hidden_states=model_outputs.hidden_states,
                attentions=model_outputs.attentions,
            )
        return (logits,) + model_outputs[1:]
|
448
|
-
|
449
|
-
|
450
|
-
class RebelDynamicCache(DynamicCache):
    """
    KV cache that writes new states at a fixed step instead of appending them.

    Once a layer has an entry, ``update`` does not concatenate. It ``slice_scatter``s
    the incoming key/value states into the stored per-layer tensor at sequence
    positions ``[current_step, current_step + new_len)`` (dim -2).
    ``copy`` keeps the concatenating behaviour and is only used by
    ``from_legacy_cache`` to load the initial per-layer tensors.

    NOTE(review): the stored per-layer tensors are presumably pre-allocated
    ``[batch_size, num_kv_heads, max_len, head_dim]`` buffers supplied by the
    caller -- confirm.
    """

    def __init__(self, current_step, max_length) -> None:
        super().__init__()
        # Write offset used by `update` along the sequence dimension.
        self.current_step = current_step
        # Reported by `get_max_length`.
        self.max_length = max_length

    def copy(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Store `key_states` / `value_states` for layer `layer_idx` (loader used by `from_legacy_cache`).

        A layer without an entry gets the states appended as-is; an existing entry is
        extended by concatenation along dim -2.

        Parameters:
            key_states (`torch.Tensor`):
                The key states to store.
            value_states (`torch.Tensor`):
                The value states to store.
            layer_idx (`int`):
                The index of the layer to store the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Unused; accepted for interface parity with `update`.

        Return:
            The stored key and value tensors for `layer_idx`.
        """
        # `seen_tokens` bookkeeping from DynamicCache is intentionally not maintained.

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Write `key_states` / `value_states` for layer `layer_idx` at `self.current_step`.

        If the layer has no entry yet, the states are appended as-is and returned.
        Otherwise the following happens:
          1. The stored tensor is unsqueezed at dim 2.
          2. The new states are slice-scattered into it along dim -2, starting at
             `current_step`.
          3. The stored entry is replaced by the result squeezed back at dim 2.
          4. The *unsqueezed* result is returned.
        This matches `_LlamaAttention`, which unsqueezes its key/value states at
        dim 2 before calling this method.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Unused here.

        Return:
            A tuple containing the updated key and value states.
        """
        # `seen_tokens` bookkeeping from DynamicCache is intentionally not maintained.

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
            return self.key_cache[layer_idx], self.value_cache[layer_idx]
        else:
            # Scatter into an unsqueezed view of the stored tensor so the incoming
            # states (which carry an extra singleton axis at dim 2) line up. A plain
            # 4-D slice_scatter on dim=2 was the earlier variant.
            updated_key = (
                self.key_cache[layer_idx]
                .unsqueeze(2)
                .slice_scatter(
                    key_states, dim=-2, start=self.current_step, end=self.current_step + key_states.shape[-2]
                )
            )
            updated_value = (
                self.value_cache[layer_idx]
                .unsqueeze(2)
                .slice_scatter(
                    value_states, dim=-2, start=self.current_step, end=self.current_step + value_states.shape[-2]
                )
            )
            # Store without the group axis; return with it.
            self.key_cache[layer_idx] = updated_key.squeeze(2)
            self.value_cache[layer_idx] = updated_value.squeeze(2)
            return updated_key, updated_value

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Return dim -2 of the stored key tensor for `layer_idx` (0 if the layer is empty).

        NOTE(review): for a pre-allocated buffer this is the buffer length, not the
        number of tokens written so far.
        """
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_max_length(self) -> Optional[int]:
        """Return the `max_length` given at construction."""
        return self.max_length

    @classmethod
    def from_legacy_cache(
        cls, current_step, max_length, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    ) -> "DynamicCache":
        """Build a cache for `current_step` / `max_length` from legacy per-layer `(key, value)` pairs."""
        cache = cls(current_step, max_length)
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.copy(key_states, value_states, layer_idx)
        return cache
|
580
|
-
|
581
|
-
|
582
|
-
def rotate_half(x):
    """Split the last dimension into halves ``(a, b)`` and return ``cat(-b, a)``.

    For an odd last dimension the first half is the shorter one (``size // 2``).
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((-back, front), dim=-1)
|
587
|
-
|
588
|
-
|
589
|
-
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Rotate ``q`` and ``k`` with rotary position embeddings.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine table, indexed by ``position_ids``.
        sin (`torch.Tensor`): Sine table, indexed by ``position_ids``.
        position_ids (`torch.Tensor`):
            Positions to gather from ``cos``/``sin``. With a KV cache these are
            typically offset by the current step.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis inserted into the gathered ``cos``/``sin`` so they broadcast against
            ``q``/``k``. Use 1 for ``[batch, heads, seq, head_dim]`` inputs and 2 for
            ``[batch, seq, heads, head_dim]`` inputs.

    Returns:
        `tuple(torch.Tensor)`: ``(q_rotated, k_rotated)``, each computed as
        ``t * cos + rotate_half(t) * sin``.
    """
    cos_at = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_at = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(t):
        return (t * cos_at) + (rotate_half(t) * sin_at)

    return _rotate(q), _rotate(k)
|
615
|
-
|
616
|
-
|
617
|
-
# Unpatched transformers implementations, captured when this module is imported
# so that `unwrap_llama` can restore them after `wrap_llama` replaces them.
origin_methods = {
    "LlamaRotaryEmbedding_INIT": LlamaRotaryEmbedding.__init__,
    "LlamaRotaryEmbedding_forward": LlamaRotaryEmbedding.forward,
    "LlamaModel_forward": LlamaModel.forward,
    "LlamaForCausalLM_forward": LlamaForCausalLM.forward,
}
|
622
|
-
|
623
|
-
|
624
|
-
def wrap_llama():
    """Monkey-patch the transformers Llama classes with the compile-friendly methods.

    Replaces ``LlamaRotaryEmbedding.__init__``/``forward``, ``LlamaModel.forward`` and
    ``LlamaForCausalLM.forward`` on the classes themselves, which affects every
    instance. ``unwrap_llama`` restores the originals.
    """
    patches = (
        (LlamaRotaryEmbedding, "__init__", _LlamaRotaryEmbedding.__init__),
        (LlamaRotaryEmbedding, "forward", _LlamaRotaryEmbedding.forward),
        (LlamaModel, "forward", _LlamaModel.forward),
        (LlamaForCausalLM, "forward", _LlamaForCausalLM.forward),
    )
    for target_cls, attr_name, replacement in patches:
        setattr(target_cls, attr_name, replacement)
|
629
|
-
|
630
|
-
|
631
|
-
def unwrap_llama():
    """Undo ``wrap_llama`` by restoring the methods saved in ``origin_methods``."""
    # `origin_methods` is only read here, so no `global` declaration is needed.
    restorations = (
        (LlamaRotaryEmbedding, "__init__", "LlamaRotaryEmbedding_INIT"),
        (LlamaRotaryEmbedding, "forward", "LlamaRotaryEmbedding_forward"),
        (LlamaModel, "forward", "LlamaModel_forward"),
        (LlamaForCausalLM, "forward", "LlamaForCausalLM_forward"),
    )
    for target_cls, attr_name, key in restorations:
        setattr(target_cls, attr_name, origin_methods[key])
|
28
|
+
class LlamaWrapper(DecoderOnlyWrapper):
    """Llama tracing wrapper; all behaviour is inherited unchanged from ``DecoderOnlyWrapper``."""
|