optimum-rbln 0.1.1__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +9 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
- optimum/rbln/diffusers/models/unet_2d_condition.py +1 -1
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +9 -11
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +8 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
- optimum/rbln/modeling_base.py +175 -103
- optimum/rbln/modeling_seq2seq.py +58 -132
- optimum/rbln/transformers/__init__.py +4 -0
- optimum/rbln/transformers/models/__init__.py +2 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
- optimum/rbln/transformers/models/dpt/__init__.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +24 -33
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +52 -124
- optimum/rbln/transformers/models/llama/llama_architecture.py +62 -33
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +764 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +208 -140
- optimum/rbln/transformers/models/midm/__init__.py +32 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +22 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +303 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +1473 -0
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +98 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +506 -0
- optimum/rbln/transformers/models/midm/modeling_midm.py +390 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
- optimum/rbln/utils/__init__.py +1 -1
- optimum/rbln/utils/import_utils.py +46 -0
- {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/METADATA +17 -50
- {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/RECORD +37 -27
- {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/WHEEL +1 -1
- {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py
@@ -0,0 +1,98 @@
+# coding=utf-8
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from einops import rearrange
+from torch import einsum, nn
+
+
+__all__ = ["RotaryEmbedding", "apply_rotary_pos_emb"]
+
+
+class RotaryEmbedding(nn.Module):
+    """
+    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
+    """
+
+    def __init__(
+        self, dim: int, seq_len_interpolation_factor: int = None, pretrained_max_position_embeddings: int = None
+    ):
+        """
+        Args:
+
+            dim (int): rotary embedding dimension
+            seq_len_interpolation_factor (int): if not None, discrete positions will be interpolated
+                by this factor via the trick in https://arxiv.org/abs/2306.15595.
+            pretrained_max_position_embeddings (int): pre-trained max_position_embeddings before position interpolation.
+        """
+        super().__init__()
+        self.seq_len_interpolation_factor = seq_len_interpolation_factor
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
+        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
+
+    def forward(self, max_seq_len, offset=0):
+        seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset
+        seq = seq.type_as(self.inv_freq)
+
+        if self.pretrained_max_position_embeddings is not None and self.seq_len_interpolation_factor is not None:
+            if max_seq_len > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor:
+                # dynamic linear scaling (length > position we have learned)
+                seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings)
+            else:
+                # fixed linear scaling
+                seq *= 1 / self.seq_len_interpolation_factor
+
+        freqs = einsum("i , j -> i j", seq, self.inv_freq)
+        # first part even vector components, second part odd vector components,
+        # 2 * dim in dimension size
+        emb = torch.cat((freqs, freqs), dim=-1)
+        # emb [seq_length, .., dim]
+        return rearrange(emb, "n d -> n 1 1 d")
+
+
+def _rotate_half(x):
+    """
+    change sign so the last dimension
+    [A, B, C, D] -> [-C, -D, A, B]
+    """
+    x = rearrange(x, "... (j d) -> ... j d", j=2)
+    x1, x2 = x.unbind(dim=-2)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(t, freqs):
+    """
+    input tensor t is of shape [seq_length, ..., dim]
+    rotary positional embeding tensor freqs is of shape [seq_length, ..., dim]
+    check https://kexue.fm/archives/8265 for detailed formulas
+    """
+    # Changes from the original RoPE implementation
+    # 1. The original NeMo implementation assumes the input tensor of shape
+    #    [seq_length, ..., dim], but the HF layout is [..., seq_length, dim].
+    #    Thus freqs needs to be viewed as [..., seq_length, dim].
+    freqs = freqs.permute(1, 2, 0, 3)
+    # 2. Support for queries which past tokens are truncated
+    assert freqs.shape[-2] >= t.shape[-2]
+    if freqs.shape[-2] != t.shape[-2]:
+        freqs = freqs[:, :, -t.shape[-2] :, :]
+
+    rot_dim = freqs.shape[-1]
+    # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
+    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
+    # first part is cosine component
+    # second part is sine component, need to change signs with _rotate_half method
+    t = (t * freqs.cos()) + (_rotate_half(t) * freqs.sin())
+    return torch.cat((t, t_pass), dim=-1)
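As a rough orientation for the RoPE helpers this file adds, the sketch below builds the embedding once and applies it to query/key tensors in the Hugging Face `[..., seq_length, dim]` layout. It is illustrative only: the import path is inferred from the file listing above, and the shapes are arbitrary assumptions, not anything the package itself runs this way.

```python
import torch

# Path inferred from the wheel's file listing above (an assumption, not a documented API).
from optimum.rbln.transformers.models.midm.hf_hub_cached.rotary_position_embedding import (
    RotaryEmbedding,
    apply_rotary_pos_emb,
)

# Illustrative shapes only.
batch, heads, seq_len, head_dim = 2, 4, 16, 64

rope = RotaryEmbedding(dim=head_dim)
freqs = rope(seq_len)  # [seq_len, 1, 1, head_dim]

q = torch.randn(batch, heads, seq_len, head_dim)  # HF layout [..., seq_len, dim]
k = torch.randn(batch, heads, seq_len, head_dim)

# apply_rotary_pos_emb permutes freqs to [1, 1, seq_len, head_dim] and rotates q/k in place of
# learned absolute position embeddings.
q_rot = apply_rotary_pos_emb(q, freqs)
k_rot = apply_rotary_pos_emb(k, freqs)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```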
optimum/rbln/transformers/models/midm/midm_architecture.py
@@ -0,0 +1,506 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from transformers.cache_utils import Cache, DynamicCache
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+)
+
+from .hf_hub_cached.modeling_midm import (
+    MidmAttention,
+    MidmBlock,
+    MidmModel,
+)
+
+
+class _MidmRotaryEmbedding(nn.Module):
+    """
+    Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864.
+    """
+
+    def __init__(
+        self, dim: int, seq_len_interpolation_factor: int = None, pretrained_max_position_embeddings: int = None
+    ):
+        """
+        Args:
+
+            dim (int): rotary embedding dimension
+            seq_len_interpolation_factor (int): if not None, discrete positions will be interpolated
+                by this factor via the trick in https://arxiv.org/abs/2306.15595.
+            pretrained_max_position_embeddings (int): pre-trained max_position_embeddings before position interpolation.
+        """
+        super().__init__()
+        self.seq_len_interpolation_factor = seq_len_interpolation_factor
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
+
+        seq_len = pretrained_max_position_embeddings
+        device = self.inv_freq.device
+        dtype = torch.get_default_dtype()
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
+
+        freqs = torch.outer(t, self.inv_freq)
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("emb_cached", emb.to(dtype), persistent=False)
+
+    def forward(self, max_seq_len, offset=0):
+
+        if max_seq_len > self.max_seq_len_cached:
+            self._set_emb_cache(seq_len=max_seq_len)
+
+        return self.emb_cached[:max_seq_len]
+
+
+def _rotate_half(x):
+    """
+    change sign so the last dimension
+    [A, B, C, D] -> [-C, -D, A, B]
+    """
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(t: torch.Tensor, cache_kwargs: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
+    """
+    input tensor t is of shape [seq_length, ..., dim]
+    rotary positional embeding tensor freqs is of shape [seq_length, ..., dim]
+    check https://kexue.fm/archives/8265 for detailed formulas
+    """
+
+    freqs = cache_kwargs["rotary_pos_emb"]
+    position_ids = cache_kwargs["position_ids"]
+    unsqueeze_dim = 1
+
+    rot_dim = freqs.shape[-1]
+
+    t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
+    cos = freqs.cos()[position_ids].unsqueeze(unsqueeze_dim)
+    sin = freqs.sin()[position_ids].unsqueeze(unsqueeze_dim)
+
+    embed = (t * cos) + (_rotate_half(t) * sin)
+    embed = torch.cat((embed, t_pass), dim=-1)
+
+    return embed
+
+
+class MidmLMHeadModelWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+        self.confg = model.config
+
+        self.use_rotary_position_embedding = model.config.use_rotary_position_embedding
+        if self.use_rotary_position_embedding:
+            rotary_dim = model.config.hidden_size // model.config.num_attention_heads
+            assert 0 < model.config.rotary_percentage <= 1
+            if model.config.rotary_percentage < 1:
+                rotary_dim = int(rotary_dim * model.config.rotary_percentage)
+            self._rotary_pos_emb = _MidmRotaryEmbedding(
+                rotary_dim,
+                seq_len_interpolation_factor=None,
+                pretrained_max_position_embeddings=model.config.max_position_embeddings,
+            )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: torch.Tensor,
+        cache_position: torch.LongTensor,
+        *past_key_values,
+    ):
+        past_kv_list = []
+        for i in range(self.model.config.n_layer):
+            cur_kv_layer = []
+            for j in range(2):
+                cur_kv_layer.append(past_key_values[2 * i + j])
+            past_kv_list.append(cur_kv_layer)
+
+        transformer_outputs = _MidmModel.forward(
+            self.model.transformer,
+            input_ids=input_ids,
+            past_key_values=past_kv_list,
+            attention_mask=attention_mask,
+            position_ids=cache_position,
+            rotary_pos_emb=self._rotary_pos_emb,
+        )
+
+        hidden_states = transformer_outputs[0]
+
+        # For the input_ids, we assume right-alignment.
+        # This assumption allows us to bypass dynamic indexing.
+        hidden_states = hidden_states[:, -1:]
+        lm_logits = self.model.lm_head(hidden_states)
+        kv_cache = transformer_outputs[1]
+
+        return lm_logits, kv_cache
+
+
+def layernorm1p(module, input):
+    return torch.nn.functional.layer_norm(input, module.normalized_shape, module.weight + 1, module.bias, module.eps)
+
+
+class _MidmAttention(MidmAttention):
+    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
+        attn_weights = torch.matmul(query, key.transpose(-1, -2))
+
+        if self.scale_attn_weights:
+            attn_weights = attn_weights / torch.full(
+                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+            )
+
+        if self.scale_attn_by_inverse_layer_idx or self.scale_qk_by_inverse_layer_idx:
+            attn_weights = attn_weights / float(self.layer_idx + 1)
+
+        if attention_mask is not None:
+            attn_weights = attn_weights + attention_mask
+
+        if self.scale_qk_by_inverse_layer_idx:
+            attn_weights = attn_weights * float(self.layer_idx + 1)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        attn_weights = attn_weights.type(value.dtype)
+
+        if head_mask is not None:
+            attn_weights = attn_weights * head_mask
+
+        attn_output = torch.matmul(attn_weights, value)
+
+        return attn_output, attn_weights
+
+    def forward(
+        self,
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        rotary_pos_emb=None,
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+        query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+        query = self._split_heads(query, self.num_heads, self.head_dim)
+        key = self._split_heads(key, self.num_heads, self.head_dim)
+        value = self._split_heads(value, self.num_heads, self.head_dim)
+
+        kv_seq_len = key.shape[-2]
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+        if use_cache is True:
+            present = (key, value)
+        else:
+            present = None
+
+        if rotary_pos_emb is not None:
+            query = apply_rotary_pos_emb(query, {"rotary_pos_emb": rotary_pos_emb, "position_ids": position_ids})
+            key = apply_rotary_pos_emb(key, {"rotary_pos_emb": rotary_pos_emb, "position_ids": position_ids})
+
+        if past_key_value is not None:
+            key, value = past_key_value.update(key, value, self.layer_idx)
+
+        attn_output, _ = _MidmAttention._attn(self, query, key, value, attention_mask, head_mask)
+
+        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+
+        attn_output = self.c_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        outputs = (attn_output, present)
+
+        return outputs, past_key_value
+
+
+class _MidmBlock(MidmBlock):
+    def forward(
+        self,
+        hidden_states: Optional[Tuple[torch.FloatTensor]],
+        attention_mask: Optional[torch.FloatTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        rotary_pos_emb=None,
+    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
+        residual = hidden_states
+        if self.use_layernorm1p:
+            hidden_states = layernorm1p(self.ln_1, hidden_states)
+        else:
+            hidden_states = self.ln_1(hidden_states)
+
+        attn_outputs, present_key_value = _MidmAttention.forward(
+            self.attn,
+            hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            head_mask=head_mask,
+            rotary_pos_emb=rotary_pos_emb,
+            use_cache=use_cache,
+        )
+
+        attn_output = attn_outputs[0]
+        outputs = attn_outputs[1:]
+
+        hidden_states = attn_output + residual
+
+        residual = hidden_states
+        if self.use_layernorm1p:
+            hidden_states = layernorm1p(self.ln_2, hidden_states)
+        else:
+            hidden_states = self.ln_2(hidden_states)
+        feed_forward_hidden_states = self.mlp(hidden_states)
+
+        hidden_states = residual + feed_forward_hidden_states
+
+        if use_cache:
+            outputs = (hidden_states,) + outputs
+        else:
+            outputs = (hidden_states,) + outputs[1:]
+
+        if use_cache:
+            outputs += (present_key_value,)
+
+        return outputs
+
+
+class _MidmModel(MidmModel):
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids=None,
+        position_ids: Optional[torch.LongTensor] = None,
+        rotary_pos_emb=None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds=None,
+        use_cache: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+        input_shape = input_ids.size()
+
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        current_step = position_ids
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape[:2]
+        elif inputs_embeds is not None:
+            batch_size, seq_length = inputs_embeds.shape[:2]
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+        past_key_values_length = 0
+
+        if use_cache:
+            use_legacy_cache = not isinstance(past_key_values, Cache)
+            if use_legacy_cache:
+                past_key_values = RebelDynamicCache.from_legacy_cache(
+                    current_step=current_step,
+                    max_length=self.config.max_position_embeddings,
+                    past_key_values=past_key_values,
+                )
+
+        position_ids = torch.arange(0, input_ids.shape[-1], dtype=torch.int32).unsqueeze(0) + current_step
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0)
+
+        attention_mask = (1.0 - attention_mask) * -10000.0
+
+        if self.use_rotary_position_embedding:
+            rotary_pos_emb = rotary_pos_emb(self.config.max_position_embeddings)
+
+        head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.wte(input_ids)
+        hidden_states = inputs_embeds
+
+        if token_type_ids is not None:
+            token_type_embeds = self.wte(token_type_ids)
+            hidden_states = hidden_states + token_type_embeds
+
+        hidden_states = self.drop(hidden_states)
+
+        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
+
+        next_decoder_cache = () if use_cache else None
+
+        for i, (block, _) in enumerate(zip(self.h, past_key_values)):
+            outputs = _MidmBlock.forward(
+                block,
+                hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                head_mask=head_mask[i],
+                rotary_pos_emb=rotary_pos_emb,
+                use_cache=use_cache,
+            )
+            hidden_states = outputs[0]
+
+            if use_cache:
+                next_decoder_cache = outputs[2]
+
+        if self.use_layernorm1p:
+            hidden_states = layernorm1p(self.ln_f, hidden_states)
+        else:
+            hidden_states = self.ln_f(hidden_states)
+        hidden_states = hidden_states.view(output_shape)
+
+        next_cache = None
+        if use_cache:
+            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+
+        # return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
+        return hidden_states, next_cache
+
+
+class RebelDynamicCache(DynamicCache):
+    """
+    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
+
+    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
+    `[batch_size, num_heads, seq_len, head_dim]`.
+    """
+
+    def __init__(self, current_step, max_length) -> None:
+        super().__init__()
+        self.current_step = current_step
+        self.max_length = max_length
+
+    def copy(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Copy the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
+        just for from_legacy_cache function
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+        else:
+            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`
+        based on self.current_step,
+
+        Parameters:
+            key_states (`torch.Tensor`):
+                The new key states to cache.
+            value_states (`torch.Tensor`):
+                The new value states to cache.
+            layer_idx (`int`):
+                The index of the layer to cache the states for.
+            cache_kwargs (`Dict[str, Any]`, `optional`):
+                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
+
+        Return:
+            A tuple containing the updated key and value states.
+        """
+
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+        else:
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].slice_scatter(
+                key_states, dim=2, start=self.current_step, end=self.current_step + key_states.shape[2]
+            )
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].slice_scatter(
+                value_states, dim=2, start=self.current_step, end=self.current_step + value_states.shape[2]
+            )
+
+        return self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        return self.key_cache[layer_idx].shape[-2]
+
+    def get_max_length(self) -> Optional[int]:
+        return self.max_length
+
+    @classmethod
+    def from_legacy_cache(
+        cls, current_step, max_length, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    ) -> "DynamicCache":
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
+        cache = cls(current_step, max_length)
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.copy(key_states, value_states, layer_idx)
+        return cache
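For orientation on the cache added at the end of this file: unlike the stock `DynamicCache`, `RebelDynamicCache.update` writes new key/value states in place at `current_step` with `slice_scatter`, so each cached tensor keeps its preallocated `max_length` instead of growing. The sketch below is illustrative only; the integer `current_step`, the zero-filled seeding via `copy`, and the import path are assumptions, not how the compiled RBLN runtime actually drives the cache.

```python
import torch

# Path inferred from the file listing above (an assumption, not a documented API).
from optimum.rbln.transformers.models.midm.midm_architecture import RebelDynamicCache

# Illustrative shapes; current_step is a plain int here for clarity.
batch, heads, max_len, head_dim = 1, 4, 8, 16
cache = RebelDynamicCache(current_step=3, max_length=max_len)

# Seed layer 0 with a preallocated, fixed-length buffer (copy() appends on first use,
# similar to what from_legacy_cache does with the incoming legacy key/value tensors).
cache.copy(
    torch.zeros(batch, heads, max_len, head_dim),
    torch.zeros(batch, heads, max_len, head_dim),
    layer_idx=0,
)

# One decode step: the new key/value is scattered into position 3 in place,
# so the cache stays [batch, heads, max_len, head_dim] rather than growing by one token.
new_k = torch.randn(batch, heads, 1, head_dim)
new_v = torch.randn(batch, heads, 1, head_dim)
k_cache, v_cache = cache.update(new_k, new_v, layer_idx=0)
assert k_cache.shape[-2] == max_len and v_cache.shape[-2] == max_len
```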