optimum_rbln-0.1.15-py3-none-any.whl → optimum_rbln-0.2.1a0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +26 -33
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/diffusers/__init__.py +4 -0
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +66 -24
- optimum/rbln/diffusers/models/__init__.py +2 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +38 -12
- optimum/rbln/diffusers/models/autoencoders/vae.py +0 -1
- optimum/rbln/diffusers/models/controlnet.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +5 -7
- optimum/rbln/diffusers/pipelines/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +8 -7
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +17 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +23 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +1 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +1 -2
- optimum/rbln/modeling.py +13 -347
- optimum/rbln/modeling_base.py +24 -4
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -0
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +0 -1
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +8 -4
- optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -7
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +329 -328
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +92 -107
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +2 -3
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -10
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +1 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +11 -11
- optimum/rbln/transformers/models/midm/modeling_midm.py +0 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +2 -3
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +57 -57
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +5 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +77 -54
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/utils/rbln_quantization.py +1 -2
- optimum/rbln/utils/decorator_utils.py +51 -15
- optimum/rbln/utils/import_utils.py +8 -1
- optimum/rbln/utils/logging.py +38 -1
- optimum/rbln/utils/model_utils.py +0 -1
- optimum/rbln/utils/runtime_utils.py +9 -3
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +23 -0
- optimum_rbln-0.2.1a0.dist-info/METADATA +121 -0
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/RECORD +76 -72
- optimum_rbln-0.2.1a0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.15.dist-info/METADATA +0 -106
- optimum_rbln-0.1.15.dist-info/licenses/LICENSE +0 -201
- {optimum_rbln-0.1.15.dist-info → optimum_rbln-0.2.1a0.dist-info}/WHEEL +0 -0
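The bulk of this release is a refactor rather than new features: the per-model seq2seq implementations (BART, T5, Whisper) move onto shared base classes in `optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py` (+498 lines), and the RBLN custom kernels (attention, flash attention, KV-cache update) are collected into a new `optimum/rbln/ops` package. The sketch below illustrates the resulting wrapper pattern; the hook names are taken from the T5 diff shown further down, but the base class here is a simplified hypothetical stand-in, not the package's actual `Seq2SeqSelfAttention`.

```python
# Hypothetical sketch of the post-refactor wrapper pattern: a shared base class
# owns construction and calls __post_init__, and each model maps its own
# submodules (here T5's q/k/v/o projections) onto generic attribute names.
# Only the hook names mirror the real diff; the base-class body is invented.
import torch
from torch import nn


class SelfAttentionWrapperSketch(nn.Module):
    """Simplified stand-in for the shared Seq2SeqSelfAttention base class."""

    def __init__(self, original_attention: nn.Module):
        super().__init__()
        self._original_mod = original_attention
        self.__post_init__()  # model-specific wiring happens in the subclass

    def __post_init__(self):
        raise NotImplementedError

    def projection(self, hidden_states: torch.Tensor):
        raise NotImplementedError


class T5SelfAttentionSketch(SelfAttentionWrapperSketch):
    def __post_init__(self):
        # T5 names its projections q/k/v/o; expose them under generic names.
        self.q_proj = self._original_mod.q
        self.k_proj = self._original_mod.k
        self.v_proj = self._original_mod.v
        self.out_proj = self._original_mod.o

    def projection(self, hidden_states: torch.Tensor):
        return (
            self.q_proj(hidden_states),
            self.k_proj(hidden_states),
            self.v_proj(hidden_states),
        )
```

In the real package, `T5LayerSelfAttention.__post_init__` (visible in the diff below) additionally binds the RBLN decode kernel `torch.ops.rbln_custom_ops.attn_decode_add_softmax`, which is registered through the new `optimum/rbln/ops` package (`register_rbln_custom_attention_add_softmax`).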
optimum/rbln/transformers/models/t5/t5_architecture.py
@@ -21,494 +21,152 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

-from typing import
+from typing import Tuple

 import torch
 from torch import nn
-from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions
-from transformers.models.t5.configuration_t5 import T5Config
-from transformers.models.t5.modeling_t5 import (
-    T5Attention,
-    T5Block,
-    T5LayerCrossAttention,
-    T5LayerSelfAttention,
-    T5Stack,
-)
 from transformers.utils import logging

+from ....ops import register_rbln_custom_attention_add_softmax
+from ..seq2seq.seq2seq_architecture import (
+    Seq2SeqDecoder,
+    Seq2SeqDecoderLayer,
+    Seq2SeqDecoderWrapper,
+    Seq2SeqEncoderWrapper,
+    Seq2SeqForConditionalGeneration,
+    Seq2SeqSelfAttention,
+)

-logger = logging.get_logger(__name__)

-
-from transformers import T5ForConditionalGeneration
+logger = logging.get_logger(__name__)


 class T5Wrapper:
-    def __init__(self, model):
-        self.encoder = T5EncoderWrapper(model)
-        self.decoder = T5DecoderWrapper(model)
+    def __init__(self, model: nn.Module, enc_max_seq_len: int, dec_max_seq_len: int = None):
+        self.encoder = T5EncoderWrapper(model, enc_max_seq_len)
+        self.decoder = T5DecoderWrapper(model, dec_max_seq_len=dec_max_seq_len)
+
+
+class T5EncoderWrapper(Seq2SeqEncoderWrapper):
+    def __post_init__(self, model: nn.Module):
+        self.n_layer = getattr(self.config, "num_layers")
+        self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().block)
+        self.num_heads = self.config.num_heads
+        self.d_kv = self.config.d_kv
+
+    def _extract_cross_kv_projects(self, t5_block: nn.Module):
+        return (
+            # different from bart
+            nn.ModuleList(t5_block[i].layer[1].EncDecAttention.k for i in range(self.n_layer)),
+            nn.ModuleList(t5_block[i].layer[1].EncDecAttention.v for i in range(self.n_layer)),
+        )


-class
-    def
-
-
-
-        position_bias: torch.Tensor,
-        batch_ids: torch.Tensor = None,
-    ) -> BaseModelOutput:
-        hidden_states = self.embed_tokens(input_ids)
-        extended_attention_mask = self.invert_attention_mask(attention_mask)
-        position_bias = position_bias + extended_attention_mask
-        for i, layer_module in enumerate(self.block):
-            layer_outputs = _T5Block.forward(
-                layer_module,
-                hidden_states,
-                position_bias=position_bias,
-                batch_ids=batch_ids,
-            )
-            hidden_states = layer_outputs[0]
-        hidden_states = self.final_layer_norm(hidden_states)
-        return BaseModelOutput(last_hidden_state=hidden_states)
-
-
-class T5Decoder(T5Stack):
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        past_key_values: torch.Tensor,
-        position_bias: torch.Tensor,
-        encoder_decoder_position_bias: torch.Tensor,
-        cache_position: torch.Tensor,
-        batch_ids: torch.Tensor,
-    ) -> BaseModelOutputWithPastAndCrossAttentions:
-        hidden_states = self.embed_tokens(input_ids)
-        extended_attention_mask = self.invert_attention_mask(attention_mask)
-        encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
-
-        position_bias = position_bias + extended_attention_mask
-        encoder_decoder_position_bias = encoder_decoder_position_bias + encoder_extended_attention_mask
-
-        present_key_value_states = ()
-
-        for layer_module, past_key_value in zip(self.block, past_key_values):
-            layer_outputs = _T5Block.forward(
-                layer_module,
-                hidden_states,
-                position_bias=position_bias,
-                encoder_hidden_states=encoder_hidden_states,
-                encoder_decoder_position_bias=encoder_decoder_position_bias,
-                past_key_value=past_key_value,
-                cache_position=cache_position,
-                batch_ids=batch_ids,
-            )
-            hidden_states, present_key_value_state = layer_outputs[:2]
-            present_key_value_states = present_key_value_states + (present_key_value_state,)
-
-        hidden_states = self.final_layer_norm(hidden_states)
-
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=present_key_value_states,
-        )
+class T5DecoderWrapper(Seq2SeqDecoderWrapper):
+    def __post_init__(self, model, dec_max_seq_len: int = None):
+        register_rbln_custom_attention_add_softmax()
+        self.num_layers = self.config.num_layers
+        self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)

+    def convert_to_rbln_conditional_generation(self, model: nn.Module, dec_max_seq_len: int):
+        new_blocks = []
+        for block in model.get_decoder().block:
+            self_attn = T5LayerSelfAttention(block.layer[0].SelfAttention)
+            block = T5Block(block, self_attn)
+            new_blocks.append(block)

-
-
-        super().__init__()
-        self.config = model.config
-        self.model = model
-        self.encoder = model.encoder
-        self.decoder = model.decoder
-        self.default_max_length = getattr(self.config, "n_positions", None) or getattr(
-            self.config, "max_position_embeddings", None
-        )
-        self.encoder_max_length = None
-        self.decoder_max_length = None
+        decoder_model = T5Decoder(model.get_decoder(), new_blocks, dec_max_seq_len=dec_max_seq_len)
+        new_model = T5ForConditionalGeneration(model, decoder_model)

-
-        self,
-        input_ids: torch.Tensor,
-        attention_mask: torch.Tensor,
-        cross_key_value: torch.Tensor = None,
-        batch_idx: torch.Tensor = None,
-    ) -> torch.Tensor:
-        decoder_max_length = self.decoder_max_length or self.default_max_length
-        encoder_max_length = self.encoder_max_length or self.default_max_length
-
-        attn_layer = self.encoder.block[0].layer[0].SelfAttention
-        encoder_position_bias = T5Attention.compute_bias(attn_layer, encoder_max_length, encoder_max_length)
-        encoder_outputs = T5Encoder.forward(
-            self.encoder,
-            input_ids,
-            attention_mask,
-            encoder_position_bias,
-            batch_ids=torch.tensor(0, dtype=torch.int32),
-        )
+        return new_model

-        attn_layer = self.decoder.block[0].layer[0].SelfAttention
-        decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
-        decoder_position_bias = decoder_position_bias[:, :, :1]
-
-        attn_layer = self.decoder.block[0].layer[1].EncDecAttention
-        encoder_decoder_position_bias = torch.zeros(1, attn_layer.n_heads, 1, encoder_max_length)
-
-        dummy_past_key_value = []
-        for i in range(self.config.num_layers):
-            pkv_self_attn_key = torch.zeros(1, self.config.num_heads, decoder_max_length, self.config.d_kv)
-            pkv_self_attn_value = torch.zeros(1, self.config.num_heads, decoder_max_length, self.config.d_kv)
-            pkv_cross_attn_key = torch.zeros(1, self.config.num_heads, encoder_max_length, self.config.d_kv)
-            pkv_cross_attn_value = torch.zeros(1, self.config.num_heads, encoder_max_length, self.config.d_kv)
-            layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
-            dummy_past_key_value.append(layer_pkv)
-
-        decoder_attention_mask = torch.zeros(1, decoder_max_length, dtype=torch.float32)
-        decoder_attention_mask[:, :1] = 1
-
-        # Since first step of decoder has different graph to further step of it,
-        # here we merges decoder into its corresponding encoder.
-        # TODO(jongho): Separate first-step-decoder.
-        decoder_outputs = T5Decoder.forward(
-            self.decoder,
-            input_ids=torch.zeros(1, 1, dtype=torch.int64),
-            attention_mask=decoder_attention_mask,
-            position_bias=decoder_position_bias,
-            encoder_decoder_position_bias=encoder_decoder_position_bias,
-            encoder_hidden_states=encoder_outputs.last_hidden_state,
-            encoder_attention_mask=attention_mask,
-            past_key_values=dummy_past_key_value,
-            cache_position=torch.tensor(0, dtype=torch.int32),
-            batch_ids=torch.tensor(0, dtype=torch.int32),
-        )

-
+class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
+    has_rescaling = True

-
-
-            cross_kv_cache.append(past_key_values[i][2])
-            cross_kv_cache.append(past_key_values[i][3])
-        cross_kv_cache = torch.stack(cross_kv_cache, dim=0)
+    def __post_init__(self):
+        self.scaling = self.config.d_model**-0.5

-        cross_key_value = cross_key_value.slice_scatter(cross_kv_cache, dim=1, start=batch_idx, end=batch_idx + 1)

-
+class T5Decoder(Seq2SeqDecoder):
+    has_pos_emb = False

+    def __post_init__(self, dec_max_seq_len: int = None):
+        self.invert_attention_mask = self._original_mod.invert_attention_mask
+        self._dec_position_bias = self.precompute_dec_position_bias(self._original_mod, dec_max_seq_len)

-
-
-
-        self.config = model.config
-        self.model = model
-        self.encoder = model.encoder
-        self.decoder = model.decoder
-        self.default_max_length = getattr(self.config, "n_positions", None) or getattr(
-            self.config, "max_position_embeddings", None
-        )
-        self.encoder_max_length = None
-        self.decoder_max_length = None
+    def precompute_dec_position_bias(self, model, dec_max_length):
+        attn_layer = model.block[0].layer[0].SelfAttention
+        return attn_layer.compute_bias(dec_max_length, dec_max_length)

-    def
-        self
-
-
-
-        cache_position: torch.Tensor,
-        batch_position: torch.Tensor,
-        self_kv_cache: torch.Tensor,
-        cross_kv_cache: torch.Tensor,
-    ) -> Tuple[torch.Tensor]:
-        # cache_position : starts from step 0
-        # attention_mask : starts with one position filled in ([0:cache_position+1])
-        num_layers = self.model.config.num_layers
-        encoder_max_length = self.encoder_max_length or self.default_max_length
-        decoder_max_length = self.decoder_max_length or self.default_max_length
-
-        if input_ids.shape[1] == 1:
-            rbln_batch_position = None
-        else:
-            rbln_batch_position = batch_position
-
-        kv_cache = ()
-        for i in range(0, num_layers * 2, 2):
-            kv_cache = kv_cache + (
-                (
-                    self_kv_cache[i],
-                    self_kv_cache[i + 1],
-                    cross_kv_cache[i],
-                    cross_kv_cache[i + 1],
-                ),
-            )
-
-        attn_layer = self.model.decoder.block[0].layer[0].SelfAttention
-        _decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
-
-        # position_bias need to compute with batch (for cb)
+    def prepare_attn_mask(self, attention_mask, encoder_attention_mask, cache_position):
+        attention_mask = self.invert_attention_mask(attention_mask)
+        encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+
+        b_size = attention_mask.shape[0]
         batch_decoder_position_bias = []
-        for i in range(
-            batch_position_bias =
+        for i in range(b_size):
+            batch_position_bias = self._dec_position_bias[:, :, cache_position[i][0]].unsqueeze(2)
             batch_decoder_position_bias.append(batch_position_bias)
-
-
-        attn_layer = self.model.decoder.block[0].layer[1].EncDecAttention
-        encoder_decoder_position_bias = torch.zeros(1, attn_layer.n_heads, 1, encoder_max_length)
-
-        decoder_outputs = T5Decoder.forward(
-            self.model.decoder,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            encoder_hidden_states=1,
-            encoder_attention_mask=encoder_attention_mask,
-            position_bias=decoder_position_bias,
-            encoder_decoder_position_bias=encoder_decoder_position_bias,
-            past_key_values=kv_cache,
-            cache_position=cache_position,
-            batch_ids=rbln_batch_position,
-        )
+        position_bias = torch.cat(batch_decoder_position_bias, dim=0)

-
-        sequence_output = decoder_outputs[0]
-        if self.model.config.tie_word_embeddings:
-            # Rescale output before projecting on vocab
-            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
-            sequence_output = sequence_output * (self.model.model_dim**-0.5)
-        lm_logits = self.model.lm_head(sequence_output)
+        attention_mask = position_bias + attention_mask

-
-        for i in range(self.model.config.num_layers):
-            self_kv_cache.append(past_key_values[i][0])
-            self_kv_cache.append(past_key_values[i][1])
+        return attention_mask, encoder_attention_mask

-        self_kv_cache = torch.stack(self_kv_cache, dim=0)

-
+class T5Block(Seq2SeqDecoderLayer):
+    def __post_init__(self):
+        self.self_attn_layer_norm = self._original_mod.layer[0].layer_norm
+        self.encoder_attn_layer_norm = self._original_mod.layer[1].layer_norm
+        self.encoder_attn = T5CrossAttention(self._original_mod.layer[1].EncDecAttention)
+        self.ff_layer = self._original_mod.layer[2]

+    def pre_self_attn_layer_norm(self, hidden_states):
+        return self.self_attn_layer_norm(hidden_states)

-
-
-        super().__init__(config, has_relative_attention_bias)
+    def post_self_attn_layer_norm(self, hidden_states):
+        return hidden_states

-    def
-        self
-        hidden_states: torch.Tensor,
-        key_value_states: Tuple[torch.Tensor] = None,
-        position_bias: torch.Tensor = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        cache_position: Optional[torch.Tensor] = None,  # current cache sequence length
-        batch_index: torch.Tensor = None,
-        is_self_attn: Optional[bool] = None,
-    ) -> Tuple[torch.Tensor]:
-        batch_size = hidden_states.shape[0]
-
-        def shape(states, batch_size):
-            """projection"""
-            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
-
-        def unshape(states, batch_size):
-            """reshape"""
-            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
-
-        query_states = shape(self.q(hidden_states), batch_size)  # (batch_size, n_heads, seq_length, dim_per_head)
-
-        # projection
-        if is_self_attn:
-            key_states = shape(self.k(hidden_states), batch_size)
-            value_states = shape(self.v(hidden_states), batch_size)
-        else:
-            # cross-attn
-            if cache_position.dim() == 0:
-                key_states = shape(self.k(key_value_states), key_value_states.shape[0])
-                value_states = shape(self.v(key_value_states), key_value_states.shape[0])
-                past_key_value = key_states, value_states
-            else:
-                key_states = past_key_value[0]
-                value_states = past_key_value[1]
-
-        if (batch_index is None or batch_index == -1) and batch_size > 1:
-            all_key_states = []
-            all_value_states = []
-            all_attn_output = []
-
-            for b in range(batch_size):
-                batch_query_states = query_states[b].unsqueeze(0)
-                batch_key_states = key_states[b].unsqueeze(0)
-                batch_value_states = value_states[b].unsqueeze(0)
-
-                if is_self_attn and past_key_value is not None:
-                    batch_key_states = (
-                        past_key_value[0][b]
-                        .unsqueeze(0)
-                        .slice_scatter(
-                            batch_key_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-                    batch_value_states = (
-                        past_key_value[1][b]
-                        .unsqueeze(0)
-                        .slice_scatter(
-                            batch_value_states, dim=-2, start=cache_position[b][0], end=cache_position[b][0] + 1
-                        )
-                    )
-
-                scores = torch.matmul(batch_query_states, batch_key_states.transpose(3, 2))
-                scores += position_bias[b]
-                attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
-                attn_output = unshape(torch.matmul(attn_weights, batch_value_states), 1)
-                all_key_states.append(batch_key_states)
-                all_value_states.append(batch_value_states)
-                all_attn_output.append(attn_output)
-
-            key_states = torch.cat(all_key_states, dim=0)
-            value_states = torch.cat(all_value_states, dim=0)
-            attn_output = torch.cat(all_attn_output, dim=0)
-
-        else:
-            if batch_index is None or batch_index == -1:
-                batch_index = 0
-
-            if is_self_attn and past_key_value is not None:
-                key_states = past_key_value[0].slice_scatter(
-                    key_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-                value_states = past_key_value[1].slice_scatter(
-                    value_states, dim=2, start=cache_position, end=cache_position + 1
-                )
-            # compute scores
-            scores = torch.matmul(query_states, key_states.transpose(3, 2))
-            scores += position_bias
-
-            attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
-                scores
-            )  # (batch_size, n_heads, seq_length, key_length)
-
-            attn_output = unshape(
-                torch.matmul(attn_weights, value_states), batch_size
-            )  # (batch_size, seq_length, dim)
-
-        attn_output = self.o(attn_output)
-        present_key_value = (key_states, value_states)
-        outputs = (attn_output,) + (present_key_value,)
-        return outputs
-
-
-class _T5LayerSelfAttention(T5LayerSelfAttention):
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        position_bias: torch.Tensor = None,
-        past_key_value: Tuple[torch.Tensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        batch_index: torch.Tensor = None,
-    ):
-        normed_hidden_states = self.layer_norm(hidden_states)
-        attention_output = _T5Attention.forward(
-            self.SelfAttention,
-            hidden_states=normed_hidden_states,
-            position_bias=position_bias,
-            past_key_value=past_key_value,
-            cache_position=cache_position,
-            batch_index=batch_index,
-            is_self_attn=True,
-        )
+    def pre_cross_attn_layer_norm(self, hidden_states):
+        return self.encoder_attn_layer_norm(hidden_states)

-
-
-        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
-        return outputs
+    def post_cross_attn_layer_norm(self, hidden_states):
+        return hidden_states


-class
-    def
-        self
-
-
-
-
-
-
-    ):
-        normed_hidden_states = self.layer_norm(hidden_states)
-        attention_output = _T5Attention.forward(
-            self.EncDecAttention,
-            hidden_states=normed_hidden_states,
-            key_value_states=key_value_states,
-            position_bias=position_bias,
-            past_key_value=past_key_value,
-            cache_position=cache_position,
-            batch_index=batch_index,
-            is_self_attn=False,
-        )
+class T5LayerSelfAttention(Seq2SeqSelfAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q
+        self.k_proj = self._original_mod.k
+        self.v_proj = self._original_mod.v
+        self.out_proj = self._original_mod.o
+        self.num_heads = self._original_mod.n_heads
+        self.head_dim = self._original_mod.key_value_proj_dim
+        self.attn_decode = torch.ops.rbln_custom_ops.attn_decode_add_softmax

-
-
-
-
+    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+        return query_states, key_states, value_states


-class
+class T5CrossAttention(nn.Module):
+    def __init__(self, attn):
+        super().__init__()
+        self.attn = attn
+
     def forward(
         self,
-        hidden_states,
-
-
-
-        past_key_value=None,
-        cache_position=None,
-        batch_ids=None,
+        hidden_states: torch.Tensor = None,
+        past_key_value: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
+        key_value_states: torch.Tensor = None,
     ):
-
-            if not self.is_decoder:
-                logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
-            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
-
-            if len(past_key_value) != expected_num_past_key_values:
-                raise ValueError(
-                    f"There should be {expected_num_past_key_values} past states. "
-                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
-                    f"Got {len(past_key_value)} past key / value states"
-                )
-
-            self_attn_past_key_value = past_key_value[:2]
-            if self_attn_past_key_value == (None, None):
-                self_attn_past_key_value = None
-
-            cross_attn_past_key_value = past_key_value[2:]
-        else:
-            self_attn_past_key_value, cross_attn_past_key_value = None, None
-        self_attention_outputs = _T5LayerSelfAttention.forward(
-            self.layer[0],
+        return self.attn(
             hidden_states=hidden_states,
-
-
-
-            batch_index=batch_ids,
+            past_key_value=past_key_value,
+            position_bias=attention_mask,
+            key_value_states=key_value_states,
         )
-
-        hidden_states, present_key_value_state = self_attention_outputs[:2]
-
-        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
-        if do_cross_attention:
-            cross_attention_outputs = _T5LayerCrossAttention.forward(
-                self.layer[1],
-                hidden_states,
-                key_value_states=encoder_hidden_states,
-                position_bias=encoder_decoder_position_bias,
-                past_key_value=cross_attn_past_key_value,
-                cache_position=cache_position,
-                batch_index=batch_ids,
-            )
-            hidden_states = cross_attention_outputs[0]
-            # Combine self attn and cross attn key value states
-            if present_key_value_state is not None:
-                # print(present_key_value_state.shape)
-                present_key_value_state = present_key_value_state + cross_attention_outputs[1]
-
-        # Apply Feed Forward layer
-        hidden_states = self.layer[-1](hidden_states)
-
-        outputs = (hidden_states,)
-        outputs = outputs + (present_key_value_state,)
-
-        return outputs
optimum/rbln/transformers/models/whisper/generation_whisper.py
@@ -1,3 +1,45 @@
+# Copyright 2024 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+"""
+Generation utilities for Whisper.
+Modified from `transformers.models.whisper.generation_whisper.py`
+"""
+
 import torch
 from transformers import GenerationMixin
 from transformers.models.whisper.generation_whisper import WhisperGenerationMixin