optimum-rbln 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +17 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/models/controlnet.py +7 -3
- optimum/rbln/diffusers/models/unet_2d_condition.py +5 -5
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +23 -146
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/modeling_alias.py +19 -1
- optimum/rbln/modeling_base.py +162 -18
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +3 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +516 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +464 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +123 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +67 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +10 -257
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
- optimum/rbln/transformers/models/llama/modeling_llama.py +12 -440
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +10 -325
- optimum/rbln/transformers/models/mistral/__init__.py +24 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +131 -0
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +109 -0
- optimum/rbln/utils/import_utils.py +1 -4
- optimum/rbln/utils/runtime_utils.py +2 -1
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA +11 -5
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/RECORD +48 -35
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/licenses/LICENSE +0 -0
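
The diff below covers `optimum/rbln/transformers/models/midm/midm_architecture.py`, which drops Midm's bespoke rotary-embedding and cache code in favor of the shared helpers from the new `decoderonly` module and `cache_utils`. The following is a rough, self-contained sketch of what the new `apply_rotary_to_tensor` and `apply_rotary_pos_emb` helpers compute; `rotate_half` is reimplemented locally as a stand-in for the one imported from `decoderonly_architecture`, and all tensor shapes are hypothetical:

```python
import torch

def rotate_half(x):
    # Local stand-in for decoderonly_architecture.rotate_half:
    # swap and negate the two halves of the last dimension.
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_to_tensor(tensor, cos, sin, rot_dim):
    # Only the first rot_dim channels are rotated; the rest pass through unchanged.
    tensor_, tensor_pass = tensor[..., :rot_dim], tensor[..., rot_dim:]
    tensor_embed = (tensor_ * cos) + (rotate_half(tensor_) * sin)
    return torch.cat((tensor_embed, tensor_pass), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    # The same rotation is applied to queries and keys, as in the new Midm code.
    rot_dim = cos.shape[-1]
    return (
        apply_rotary_to_tensor(q, cos, sin, rot_dim),
        apply_rotary_to_tensor(k, cos, sin, rot_dim),
    )

# Hypothetical shapes: batch=1, heads=4, seq_len=8, head_dim=16, rotary_dim=8.
q = torch.randn(1, 4, 8, 16)
k = torch.randn(1, 4, 8, 16)
inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
freqs = torch.outer(torch.arange(8, dtype=torch.float32), inv_freq)  # (8, 4)
emb = torch.cat((freqs, freqs), dim=-1)                              # (8, 8)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]              # broadcast over batch/heads
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
print(q_embed.shape, k_embed.shape)  # torch.Size([1, 4, 8, 16]) twice
```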
optimum/rbln/transformers/models/midm/midm_architecture.py

```diff
@@ -21,155 +21,108 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
-from transformers.
-from transformers.modeling_outputs import (
-    BaseModelOutputWithPastAndCrossAttentions,
-)
+from transformers.modeling_outputs import BaseModelOutputWithPast
 
-from .
-
-
-
+from ....transformers.models.decoderonly.decoderonly_architecture import (
+    RotaryEmbedding,
+    rotate_half,
+    slice_and_unsqueeze_cos_sin,
 )
+from ...cache_utils import RebelDynamicCache_4D
 
 
-
-    """
-
-
-
-    def __init__(
-        self, dim: int, seq_len_interpolation_factor: int = None, pretrained_max_position_embeddings: int = None
-    ):
-        """
-        Args:
-
-            dim (int): rotary embedding dimension
-            seq_len_interpolation_factor (int): if not None, discrete positions will be interpolated
-                by this factor via the trick in https://arxiv.org/abs/2306.15595.
-            pretrained_max_position_embeddings (int): pre-trained max_position_embeddings before position interpolation.
-        """
-        super().__init__()
-        self.seq_len_interpolation_factor = seq_len_interpolation_factor
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
-
-        seq_len = pretrained_max_position_embeddings
-        device = self.inv_freq.device
-        dtype = torch.get_default_dtype()
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
-        freqs = torch.outer(t, self.inv_freq)
-
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("emb_cached", emb.to(dtype), persistent=False)
-
-    def forward(self, max_seq_len, offset=0):
-
-        if max_seq_len > self.max_seq_len_cached:
-            self._set_emb_cache(seq_len=max_seq_len)
-
-        return self.emb_cached[:max_seq_len]
+def apply_rotary_to_tensor(tensor, cos, sin, rot_dim):
+    """Applies rotary position embedding to the specified dimension of the tensor."""
+    tensor_, tensor_pass = tensor[..., :rot_dim], tensor[..., rot_dim:]
+    tensor_embed = (tensor_ * cos) + (rotate_half(tensor_) * sin)
+    return torch.cat((tensor_embed, tensor_pass), dim=-1)
 
 
-def
-    """
-
-
-
-
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(t: torch.Tensor, cache_kwargs: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
-    """
-    input tensor t is of shape [seq_length, ..., dim]
-    rotary positional embeding tensor freqs is of shape [seq_length, ..., dim]
-    check https://kexue.fm/archives/8265 for detailed formulas
-    """
-
-    freqs = cache_kwargs["rotary_pos_emb"]
-    position_ids = cache_kwargs["position_ids"]
-    unsqueeze_dim = 1
-
-    rot_dim = freqs.shape[-1]
-
-    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
-    cos = freqs.cos()[position_ids].unsqueeze(unsqueeze_dim)
-    sin = freqs.sin()[position_ids].unsqueeze(unsqueeze_dim)
-
-    embed = (t * cos) + (_rotate_half(t) * sin)
-    embed = torch.cat((embed, t_pass), dim=-1)
-
-    return embed
+def apply_rotary_pos_emb(q, k, cos, sin):
+    """Applies Rotary Position Embedding to the query and key tensors."""
+    rot_dim = cos.shape[-1]
+    q_embed = apply_rotary_to_tensor(q, cos, sin, rot_dim)
+    k_embed = apply_rotary_to_tensor(k, cos, sin, rot_dim)
+    return q_embed, k_embed
 
 
 class MidmLMHeadModelWrapper(torch.nn.Module):
-
+    """A wrapper class for the Midm model with a language modeling head."""
+
+    def __init__(self, model, max_seq_len):
         super().__init__()
-        self.model = model
-        self.
-
-        self.
-
-
-
-
-
-
-
-
-
+        self.model = model.transformer
+        self.lm_head = model.lm_head
+        self.config = model.config
+        self.head_dim = self.config.n_embd // self.config.n_head
+        self.max_position_embeddings = (
+            self.config.max_position_embeddings if max_seq_len > self.config.max_position_embeddings else max_seq_len
+        )
+        self.max_seq_len = max_seq_len
+        self.rotary_dim = int(
+            model.config.hidden_size // model.config.num_attention_heads * model.config.rotary_percentage
+        )
+        self.rotary_emb = self._init_rope()
+
+    def _init_rope(self):
+        """Initializes the Rotary Position Embeddings."""
+        rotary_emb = RotaryEmbedding(
+            self.rotary_dim,
+            max_position_embeddings=self.max_position_embeddings,
+        )
+        return rotary_emb
 
     def forward(
         self,
         input_ids: torch.Tensor,
         attention_mask: torch.Tensor,
         cache_position: torch.LongTensor,
+        batch_position: int,
         *past_key_values,
     ):
-
-
-
-
-
-
-
-
-        self.
+        """Defines the forward pass for the wrapper model."""
+        if input_ids.shape[1] == 1:
+            rbln_batch_position = None
+        else:
+            rbln_batch_position = batch_position
+
+        past_key_values = RebelDynamicCache_4D.from_input_format(
+            cache_position,
+            self.config.num_hidden_layers,
+            *past_key_values,
+        )
+
+        outputs = _MidmModel.forward(
+            self.model,
             input_ids=input_ids,
-            past_key_values=
+            past_key_values=past_key_values,
             attention_mask=attention_mask,
             position_ids=cache_position,
-            rotary_pos_emb=self.
+            rotary_pos_emb=self.rotary_emb,
+            batch_ids=rbln_batch_position,
         )
 
-        hidden_states =
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+        output = (logits,) + outputs[1:]
 
-
-        # This assumption allows us to bypass dynamic indexing.
-        hidden_states = hidden_states[:, -1:]
-        lm_logits = self.model.lm_head(hidden_states)
-        kv_cache = transformer_outputs[1]
-
-        return lm_logits, kv_cache
+        return output, batch_position
 
 
 def layernorm1p(module, input):
+    """Applies Layer Normalization with a slight modification on the weights."""
     return torch.nn.functional.layer_norm(input, module.normalized_shape, module.weight + 1, module.bias, module.eps)
 
 
-class _MidmAttention
+class _MidmAttention:
+    """Custom implementation of the MidmAttention class with specific modifications."""
+
     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+        """Computes the attention weights and output."""
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
         if self.scale_attn_weights:
```

```diff
@@ -187,320 +140,170 @@ class _MidmAttention(MidmAttention):
             attn_weights = attn_weights * float(self.layer_idx + 1)
 
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
         attn_weights = attn_weights.type(value.dtype)
 
         if head_mask is not None:
             attn_weights = attn_weights * head_mask
 
         attn_output = torch.matmul(attn_weights, value)
-
         return attn_output, attn_weights
 
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
         attention_mask: Optional[torch.FloatTensor] = None,
-
-
-
-
-        rotary_pos_emb=None,
+        past_key_value: Optional[RebelDynamicCache_4D] = None,
+        batch_index: Optional[int] = None,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-
-
-
-
-
-
-
-
-
-
-
-
+        """Defines the forward pass for the attention mechanism."""
+        bsz, q_len, _ = hidden_states.size()
+
+        querys, keys, values = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+        querys = self._split_heads(querys, self.num_heads, self.head_dim).contiguous()
+        keys = self._split_heads(keys, self.num_heads, self.head_dim).contiguous()
+        values = self._split_heads(values, self.num_heads, self.head_dim).contiguous()
+
+        querys, keys = apply_rotary_pos_emb(querys, keys, cos, sin)
+
+        # Decoder
+        if (batch_index is None or batch_index == -1) and bsz > 1:
+            all_key_states = []
+            all_value_states = []
+            all_attn_output = []
+
+            for b in range(bsz):
+                query = querys[b].unsqueeze(0)
+                attn_mask = attention_mask[b].unsqueeze(0)
+                key = keys[b].unsqueeze(0)
+                value = values[b].unsqueeze(0)
+
+                key, value = past_key_value.update(
+                    key,
+                    value,
+                    self.layer_idx,
+                    b,
                 )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
-
-
-        else:
-            present = None
+                attn_output, _ = _MidmAttention._attn(self, query, key, value, attn_mask)
+                attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
 
-
-
-
+                all_key_states.append(key)
+                all_value_states.append(value)
+                all_attn_output.append(attn_output)
 
-
-
+            keys = torch.cat(all_key_states, dim=0)
+            values = torch.cat(all_value_states, dim=0)
+            attn_output = torch.cat(all_attn_output, dim=0)
 
-
+        else:
+            if batch_index is None or batch_index == -1:
+                batch_index = 0
+
+            keys, values = past_key_value.update(
+                keys,
+                values,
+                self.layer_idx,
+                batch_index,
+                read_first_step=True,
+            )
 
-
+            attn_output, _ = _MidmAttention._attn(self, querys, keys, values, attention_mask)
+            attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
 
         attn_output = self.c_proj(attn_output)
-        attn_output
-
-        outputs = (attn_output, present)
+        return attn_output, keys, values
 
-        return outputs, past_key_value
 
+class _MidmBlock:
+    """Custom implementation of the MidmBlock class with specific modifications."""
 
-class _MidmBlock(MidmBlock):
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
+        layer_idx: int,
         attention_mask: Optional[torch.FloatTensor] = None,
-
-
-
-
-        rotary_pos_emb=None,
+        past_key_value: Optional[RebelDynamicCache_4D] = None,
+        batch_ids: Optional[torch.LongTensor] = None,
+        cos: Optional[torch.Tensor] = None,
+        sin: Optional[torch.Tensor] = None,
     ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
+        """Defines the forward pass for the block."""
         residual = hidden_states
         if self.use_layernorm1p:
             hidden_states = layernorm1p(self.ln_1, hidden_states)
         else:
             hidden_states = self.ln_1(hidden_states)
 
-
+        hidden_states, k, v = _MidmAttention.forward(
             self.attn,
             hidden_states,
             attention_mask=attention_mask,
-            position_ids=position_ids,
             past_key_value=past_key_value,
-
-
-
+            cos=cos,
+            sin=sin,
+            batch_index=batch_ids,
         )
+        past_key_value.assign(k, v, layer_idx)
 
-
-        outputs = attn_outputs[1:]
-
-        hidden_states = attn_output + residual
+        hidden_states = hidden_states + residual
 
         residual = hidden_states
         if self.use_layernorm1p:
             hidden_states = layernorm1p(self.ln_2, hidden_states)
         else:
             hidden_states = self.ln_2(hidden_states)
-        feed_forward_hidden_states = self.mlp(hidden_states)
 
+        feed_forward_hidden_states = self.mlp(hidden_states)
         hidden_states = residual + feed_forward_hidden_states
 
-
-            outputs = (hidden_states,) + outputs
-        else:
-            outputs = (hidden_states,) + outputs[1:]
-
-        if use_cache:
-            outputs += (present_key_value,)
+        return hidden_states, past_key_value
 
-        return outputs
 
+class _MidmModel:
+    """Custom implementation of the MidmModel class with specific modifications."""
 
-class _MidmModel(MidmModel):
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[
+        past_key_values: Optional[RebelDynamicCache_4D] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
-        token_type_ids=None,
         position_ids: Optional[torch.LongTensor] = None,
         rotary_pos_emb=None,
-
-
-
-    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        batch_ids: Optional[torch.LongTensor] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        """Defines the forward pass for the model."""
         input_shape = input_ids.size()
 
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-
-        current_step = position_ids
-
-        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif input_ids is not None:
-            batch_size, seq_length = input_ids.shape[:2]
-        elif inputs_embeds is not None:
-            batch_size, seq_length = inputs_embeds.shape[:2]
-        else:
-            raise ValueError("You have to specify either input_ids or inputs_embeds")
-
-        if token_type_ids is not None:
-            token_type_ids = token_type_ids.view(-1, input_shape[-1])
-
-        past_key_values_length = 0
-
-        if use_cache:
-            use_legacy_cache = not isinstance(past_key_values, Cache)
-            if use_legacy_cache:
-                past_key_values = RebelDynamicCache.from_legacy_cache(
-                    current_step=current_step,
-                    max_length=self.config.max_position_embeddings,
-                    past_key_values=past_key_values,
-                )
-
-        position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.int32).unsqueeze(0) + current_step
-
-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
-            )
-            position_ids = position_ids.unsqueeze(0)
-
         attention_mask = (1.0 - attention_mask) * -10000.0
 
-
-        rotary_pos_emb = rotary_pos_emb(self.config.max_position_embeddings)
+        inputs_embeds = self.wte(input_ids)
 
-
-
-        if inputs_embeds is None:
-            inputs_embeds = self.wte(input_ids)
+        cos, sin = rotary_pos_emb(inputs_embeds, attention_mask.shape[-1])
+        cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
         hidden_states = inputs_embeds
 
-
-
-            hidden_states = hidden_states + token_type_embeds
-
-        hidden_states = self.drop(hidden_states)
-
-        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
-
-        next_decoder_cache = () if use_cache else None
-
-        for i, (block, _) in enumerate(zip(self.h, past_key_values)):
-            outputs = _MidmBlock.forward(
+        for layer_idx, (block, _) in enumerate(zip(self.h, past_key_values)):
+            hidden_states, updated_cache = _MidmBlock.forward(
                 block,
                 hidden_states,
+                layer_idx,
                 attention_mask=attention_mask,
-                position_ids=position_ids,
                 past_key_value=past_key_values,
-
-
-
+                batch_ids=batch_ids,
+                cos=cos,
+                sin=sin,
             )
-            hidden_states = outputs[0]
-
-            if use_cache:
-                next_decoder_cache = outputs[2]
 
-
-
-        else:
-            hidden_states = self.ln_f(hidden_states)
+        hidden_states = layernorm1p(self.ln_f, hidden_states)
+        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
         hidden_states = hidden_states.view(output_shape)
 
-        next_cache =
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
-
-        # return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
-        return hidden_states, next_cache
-
-
-class RebelDynamicCache(DynamicCache):
-    """
-    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
+        next_cache = updated_cache.to_legacy_cache()
 
-
-
-
-
-    def __init__(self, current_step, max_length) -> None:
-        super().__init__()
-        self.current_step = current_step
-        self.max_length = max_length
-
-    def copy(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Copy the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
-        just for from_legacy_cache function
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-
-        if len(self.key_cache) <= layer_idx:
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
-        else:
-            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
-            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
-    def update(
-        self,
-        key_states: torch.Tensor,
-        value_states: torch.Tensor,
-        layer_idx: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """
-        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
-        based on self.current_step,
-
-        Parameters:
-            key_states (`torch.Tensor`):
-                The new key states to cache.
-            value_states (`torch.Tensor`):
-                The new value states to cache.
-            layer_idx (`int`):
-                The index of the layer to cache the states for.
-            cache_kwargs (`Dict[str, Any]`, `optional`):
-                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
-
-        Return:
-            A tuple containing the updated key and value states.
-        """
-
-        if len(self.key_cache) <= layer_idx:
-            self.key_cache.append(key_states)
-            self.value_cache.append(value_states)
-        else:
-            self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
-                key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
-            )
-            self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
-                value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
-            )
-
-        return self.key_cache[layer_idx], self.value_cache[layer_idx]
-
-    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
-        if len(self.key_cache) <= layer_idx:
-            return 0
-        return self.key_cache[layer_idx].shape[-2]
-
-    def get_max_length(self) -> Optional[int]:
-        return self.max_length
-
-    @classmethod
-    def from_legacy_cache(
-        cls, current_step, max_length, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
-    ) -> "DynamicCache":
-        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
-        cache = cls(current_step, max_length)
-        if past_key_values is not None:
-            for layer_idx in range(len(past_key_values)):
-                key_states, value_states = past_key_values[layer_idx]
-                cache.copy(key_states, value_states, layer_idx)
-        return cache
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+        )
```