optimum_rbln-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +115 -0
- optimum/rbln/__version__.py +1 -0
- optimum/rbln/diffusers/__init__.py +64 -0
- optimum/rbln/diffusers/models/__init__.py +26 -0
- optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
- optimum/rbln/diffusers/models/controlnet.py +180 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
- optimum/rbln/diffusers/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
- optimum/rbln/modeling.py +0 -0
- optimum/rbln/modeling_alias.py +49 -0
- optimum/rbln/modeling_base.py +645 -0
- optimum/rbln/modeling_config.py +169 -0
- optimum/rbln/modeling_seq2seq.py +469 -0
- optimum/rbln/transformers/__init__.py +59 -0
- optimum/rbln/transformers/generation/__init__.py +24 -0
- optimum/rbln/transformers/generation/streamers.py +122 -0
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/bart/__init__.py +24 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
- optimum/rbln/transformers/models/clip/__init__.py +24 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
- optimum/rbln/transformers/models/llama/__init__.py +24 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
- optimum/rbln/transformers/models/t5/__init__.py +24 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
- optimum/rbln/transformers/models/whisper/__init__.py +24 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
- optimum/rbln/utils/__init__.py +25 -0
- optimum/rbln/utils/import_utils.py +28 -0
- optimum/rbln/utils/runtime_utils.py +71 -0
- optimum/rbln/utils/save_utils.py +92 -0
- optimum_rbln-0.1.0.dist-info/METADATA +144 -0
- optimum_rbln-0.1.0.dist-info/RECORD +51 -0
- optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
- optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/t5/t5_architecture.py
@@ -0,0 +1,439 @@
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

from typing import TYPE_CHECKING, Optional, Tuple

import torch
from torch import nn
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions
from transformers.models.t5.configuration_t5 import T5Config
from transformers.models.t5.modeling_t5 import (
    T5Attention,
    T5Block,
    T5LayerCrossAttention,
    T5LayerSelfAttention,
    T5Stack,
)
from transformers.utils import logging


logger = logging.get_logger(__name__)

if TYPE_CHECKING:
    from transformers import T5ForConditionalGeneration


class T5Encoder(T5Stack):
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        position_bias: torch.Tensor,
    ) -> BaseModelOutput:
        hidden_states = self.embed_tokens(input_ids)
        extended_attention_mask = self.invert_attention_mask(attention_mask)
        position_bias = position_bias + extended_attention_mask
        for i, layer_module in enumerate(self.block):
            layer_outputs = _T5Block.forward(
                layer_module,
                hidden_states,
                position_bias=position_bias,
            )
            hidden_states = layer_outputs[0]
        hidden_states = self.final_layer_norm(hidden_states)
        return BaseModelOutput(last_hidden_state=hidden_states)


class T5Decoder(T5Stack):
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        past_key_values: torch.Tensor,
        position_bias: torch.Tensor,
        encoder_decoder_position_bias: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> BaseModelOutputWithPastAndCrossAttentions:
        hidden_states = self.embed_tokens(input_ids)
        extended_attention_mask = self.invert_attention_mask(attention_mask)
        encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        position_bias = position_bias + extended_attention_mask
        encoder_decoder_position_bias = encoder_decoder_position_bias + encoder_extended_attention_mask

        present_key_value_states = ()
        for layer_module, past_key_value in zip(self.block, past_key_values):
            layer_outputs = _T5Block.forward(
                layer_module,
                hidden_states,
                position_bias=position_bias,
                encoder_hidden_states=encoder_hidden_states,
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                past_key_value=past_key_value,
                cache_position=cache_position,
            )
            hidden_states, present_key_value_state = layer_outputs[:2]
            present_key_value_states = present_key_value_states + (present_key_value_state,)

        hidden_states = self.final_layer_norm(hidden_states)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=present_key_value_states,
        )


class T5EncoderWrapper(torch.nn.Module):
    def __init__(self, model: "T5ForConditionalGeneration"):
        super().__init__()
        self.config = model.config
        self.model = model
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.default_max_length = getattr(self.config, "n_positions", None) or getattr(
            self.config, "max_position_embeddings", None
        )
        self.encoder_max_length = None
        self.decoder_max_length = None
        self.decoder_batch_size = 1

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        encoder_batch_size = input_ids.shape[0]
        decoder_batch_size = self.decoder_batch_size
        decoder_max_length = self.decoder_max_length or self.default_max_length
        encoder_max_length = self.encoder_max_length or self.default_max_length

        attn_layer = self.encoder.block[0].layer[0].SelfAttention
        encoder_position_bias = T5Attention.compute_bias(attn_layer, encoder_max_length, encoder_max_length)
        encoder_outputs = T5Encoder.forward(self.encoder, input_ids, attention_mask, encoder_position_bias)

        attn_layer = self.decoder.block[0].layer[0].SelfAttention
        decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
        decoder_position_bias = decoder_position_bias[:, :, :1]

        attn_layer = self.decoder.block[0].layer[1].EncDecAttention
        encoder_decoder_position_bias = torch.zeros(1, attn_layer.n_heads, 1, encoder_max_length)

        dummy_past_key_value = []
        for i in range(self.config.num_layers):
            pkv_self_attn_key = torch.zeros(
                decoder_batch_size, self.config.num_heads, decoder_max_length, self.config.d_kv
            )
            pkv_self_attn_value = torch.zeros(
                decoder_batch_size, self.config.num_heads, decoder_max_length, self.config.d_kv
            )
            pkv_cross_attn_key = torch.zeros(
                encoder_batch_size, self.config.num_heads, encoder_max_length, self.config.d_kv
            )
            pkv_cross_attn_value = torch.zeros(
                encoder_batch_size, self.config.num_heads, encoder_max_length, self.config.d_kv
            )
            layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
            dummy_past_key_value.append(layer_pkv)

        decoder_attention_mask = torch.zeros(decoder_batch_size, decoder_max_length, dtype=torch.int64)
        decoder_attention_mask[:, :1] = 1

        # Since the first decoder step has a different graph from the following steps,
        # the first decoder step is merged into its corresponding encoder here.
        # TODO(jongho): Separate first-step-decoder.
        decoder_outputs = T5Decoder.forward(
            self.decoder,
            input_ids=torch.zeros(decoder_batch_size, 1, dtype=torch.int64),
            attention_mask=decoder_attention_mask,
            position_bias=decoder_position_bias,
            encoder_decoder_position_bias=encoder_decoder_position_bias,
            encoder_hidden_states=encoder_outputs.last_hidden_state,
            encoder_attention_mask=attention_mask,
            past_key_values=dummy_past_key_value,
            cache_position=torch.tensor(0, dtype=torch.int32),
        )

        past_key_values = decoder_outputs.past_key_values

        cross_kv_cache = []
        for i in range(self.model.config.num_layers):
            cross_kv_cache.append(past_key_values[i][2])
            cross_kv_cache.append(past_key_values[i][3])
        cross_kv_cache = torch.stack(cross_kv_cache, dim=0)

        return cross_kv_cache


class T5DecoderWrapper(torch.nn.Module):
    def __init__(self, model: "T5ForConditionalGeneration"):
        super().__init__()
        self.config = model.config
        self.model = model
        self.encoder = model.encoder
        self.decoder = model.decoder
        self.default_max_length = getattr(self.config, "n_positions", None) or getattr(
            self.config, "max_position_embeddings", None
        )
        self.encoder_max_length = None
        self.decoder_max_length = None

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        cache_position: torch.Tensor,
        self_kv_cache: torch.Tensor,
        cross_kv_cache: torch.Tensor,
    ) -> Tuple[torch.Tensor]:
        # cache_position : starts from step 0
        # attention_mask : starts with a single position filled in ([0:cache_position+1])
        num_layers = self.model.config.num_layers
        encoder_max_length = self.encoder_max_length or self.default_max_length
        decoder_max_length = self.decoder_max_length or self.default_max_length

        kv_cache = ()
        for i in range(0, num_layers * 2, 2):
            kv_cache = kv_cache + (
                (
                    self_kv_cache[i],
                    self_kv_cache[i + 1],
                    cross_kv_cache[i],
                    cross_kv_cache[i + 1],
                ),
            )

        attn_layer = self.model.decoder.block[0].layer[0].SelfAttention
        _decoder_position_bias = T5Attention.compute_bias(attn_layer, decoder_max_length, decoder_max_length)
        decoder_position_bias = _decoder_position_bias[:, :, cache_position].unsqueeze(2)

        attn_layer = self.model.decoder.block[0].layer[1].EncDecAttention
        encoder_decoder_position_bias = torch.zeros(1, attn_layer.n_heads, 1, encoder_max_length)

        decoder_outputs = T5Decoder.forward(
            self.model.decoder,
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=1,
            encoder_attention_mask=encoder_attention_mask,
            position_bias=decoder_position_bias,
            encoder_decoder_position_bias=encoder_decoder_position_bias,
            past_key_values=kv_cache,
            cache_position=cache_position,
        )

        past_key_values = decoder_outputs.past_key_values
        sequence_output = decoder_outputs[0]
        if self.model.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model.model_dim**-0.5)
        lm_logits = self.model.lm_head(sequence_output)

        self_kv_cache = []
        for i in range(self.model.config.num_layers):
            self_kv_cache.append(past_key_values[i][0])
            self_kv_cache.append(past_key_values[i][1])

        self_kv_cache = torch.stack(self_kv_cache, dim=0)

        return lm_logits, self_kv_cache


class _T5Attention(T5Attention):
    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__(config, has_relative_attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Tuple[torch.Tensor] = None,
        position_bias: torch.Tensor = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        cache_position: Optional[torch.Tensor] = None,  # current length of the cached sequence
        is_self_attn: Optional[bool] = None,
    ) -> Tuple[torch.Tensor]:
        batch_size = hidden_states.shape[0]
        cross_batch_size = key_value_states.shape[0] if not is_self_attn and cache_position == 0 else None

        def shape(states, batch_size):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states, batch_size):
            """reshape"""
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        query_states = shape(self.q(hidden_states), batch_size)  # (batch_size, n_heads, seq_length, dim_per_head)

        # projection
        if is_self_attn:
            key_states = shape(self.k(hidden_states), batch_size)
            value_states = shape(self.v(hidden_states), batch_size)
            if past_key_value is not None:
                # decoder self attn
                cache_k = past_key_value[0].slice_scatter(
                    key_states, dim=2, start=cache_position, end=cache_position + 1
                )
                cache_v = past_key_value[1].slice_scatter(
                    value_states, dim=2, start=cache_position, end=cache_position + 1
                )
                past_key_value = (cache_k, cache_v)
                key_states, value_states = past_key_value

        else:
            # cross-attn
            if cache_position == 0:
                key_states = shape(self.k(key_value_states), cross_batch_size)
                value_states = shape(self.v(key_value_states), cross_batch_size)
                past_key_value = key_states, value_states
            else:
                key_states = past_key_value[0]
                value_states = past_key_value[1]

        # compute scores
        scores = torch.matmul(query_states, key_states.transpose(3, 2))
        scores += position_bias

        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)

        attn_output = unshape(torch.matmul(attn_weights, value_states), batch_size)  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)

        outputs = (attn_output,) + (past_key_value,)
        return outputs


class _T5LayerSelfAttention(T5LayerSelfAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_bias: torch.Tensor = None,
        past_key_value: Tuple[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = _T5Attention.forward(
            self.SelfAttention,
            hidden_states=normed_hidden_states,
            position_bias=position_bias,
            past_key_value=past_key_value,
            cache_position=cache_position,
            is_self_attn=True,
        )

        # Residual Connection
        hidden_states = hidden_states + self.dropout(attention_output[0])
        outputs = (hidden_states,) + attention_output[1:]  # add attentions if we output them
        return outputs


class _T5LayerCrossAttention(T5LayerCrossAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: torch.Tensor,
        position_bias: torch.Tensor = None,
        past_key_value: Tuple[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
    ):
        normed_hidden_states = self.layer_norm(hidden_states)
        attention_output = _T5Attention.forward(
            self.EncDecAttention,
            hidden_states=normed_hidden_states,
            key_value_states=key_value_states,
            position_bias=position_bias,
            past_key_value=past_key_value,
            cache_position=cache_position,
            is_self_attn=False,
        )

        # Residual connection
        layer_output = hidden_states + self.dropout(attention_output[0])
        outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
        return outputs


class _T5Block(T5Block):
    def forward(
        self,
        hidden_states,
        position_bias=None,
        encoder_hidden_states=None,
        encoder_decoder_position_bias=None,
        past_key_value=None,
        cache_position=None,
    ):
        if past_key_value is not None:
            if not self.is_decoder:
                logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
            expected_num_past_key_values = 2 if encoder_hidden_states is None else 4

            if len(past_key_value) != expected_num_past_key_values:
                raise ValueError(
                    f"There should be {expected_num_past_key_values} past states. "
                    f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
                    f"Got {len(past_key_value)} past key / value states"
                )

            self_attn_past_key_value = past_key_value[:2]
            if self_attn_past_key_value == (None, None):
                self_attn_past_key_value = None

            cross_attn_past_key_value = past_key_value[2:]
        else:
            self_attn_past_key_value, cross_attn_past_key_value = None, None

        self_attention_outputs = _T5LayerSelfAttention.forward(
            self.layer[0],
            hidden_states=hidden_states,
            position_bias=position_bias,
            past_key_value=self_attn_past_key_value,
            cache_position=cache_position,
        )

        hidden_states, present_key_value_state = self_attention_outputs[:2]

        do_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if do_cross_attention:
            cross_attention_outputs = _T5LayerCrossAttention.forward(
                self.layer[1],
                hidden_states,
                key_value_states=encoder_hidden_states,
                position_bias=encoder_decoder_position_bias,
                past_key_value=cross_attn_past_key_value,
                cache_position=cache_position,
            )
            hidden_states = cross_attention_outputs[0]
            # Combine self attn and cross attn key value states
            if present_key_value_state is not None:
                present_key_value_state = present_key_value_state + cross_attention_outputs[1]

        # Apply Feed Forward layer
        hidden_states = self.layer[-1](hidden_states)

        outputs = (hidden_states,)
        outputs = outputs + (present_key_value_state,)

        return outputs
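The two wrapper modules above flatten T5's encoder and its first decoder step into fixed-shape graphs and hand back the stacked cross-attention KV cache that the decoder wrapper later consumes. The sketch below is a hypothetical, standalone exercise of T5EncoderWrapper, not code from this package: it assumes the fixed-shape eager pass used for tracing also runs on its own, and "t5-small" plus the length 64 are placeholder choices; the actual export path presumably lives in optimum/rbln/modeling_seq2seq.py.

import torch
from transformers import T5ForConditionalGeneration

# Load a stock checkpoint and wrap it with the module defined above.
model = T5ForConditionalGeneration.from_pretrained("t5-small")

enc_wrapper = T5EncoderWrapper(model)
enc_wrapper.encoder_max_length = 64  # otherwise falls back to config.n_positions
enc_wrapper.decoder_max_length = 64

input_ids = torch.zeros(1, 64, dtype=torch.int64)  # dummy padded prompt
attention_mask = torch.ones(1, 64, dtype=torch.int64)

with torch.no_grad():
    cross_kv = enc_wrapper(input_ids, attention_mask)

# One key and one value tensor per decoder layer, stacked along dim 0:
# (2 * num_layers, batch, num_heads, encoder_max_length, d_kv)
print(cross_kv.shape)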
optimum/rbln/transformers/models/wav2vec2/__init__.py
@@ -0,0 +1,24 @@
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

from .modeling_wav2vec2 import RBLNWav2Vec2ForCTC
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -0,0 +1,121 @@
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

import logging
from typing import TYPE_CHECKING, Optional, Union

import torch
from transformers import AutoModelForMaskedLM, PretrainedConfig, Wav2Vec2ForCTC
from transformers.modeling_outputs import CausalLMOutput

from ....modeling_base import RBLNModel
from ....modeling_config import RBLNConfig, RBLNRuntimeConfig


logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from transformers import (
        AutoFeatureExtractor,
        AutoProcessor,
        AutoTokenizer,
        PretrainedConfig,
    )


class _Wav2Vec2(torch.nn.Module):
    def __init__(self, model: "Wav2Vec2ForCTC"):
        super().__init__()
        self.model = model

    def forward(self, input_values):
        output = self.model.wav2vec2(input_values=input_values)
        return self.model.lm_head(output[0])


class RBLNWav2Vec2ForCTC(RBLNModel):
    """
    Wav2Vec2 Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).

    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models.

    It implements the methods to convert a pre-trained Wav2Vec2 model into an RBLN Wav2Vec2 model by:
    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.
    """

    model_type = "rbln_model"
    main_input_name = "input_values"
    auto_model_class = AutoModelForMaskedLM

    @classmethod
    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
        return _Wav2Vec2(model).eval()

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model_config: "PretrainedConfig",
        rbln_max_seq_len: Optional[int] = None,
        rbln_batch_size: Optional[int] = None,
    ) -> RBLNConfig:
        meta = {}

        if rbln_max_seq_len is None:
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_max_length"):
                    rbln_max_seq_len = tokenizer.model_max_length
                    break
            if rbln_max_seq_len is None:
                raise ValueError("`rbln_max_seq_len` should be specified!")

        meta["rbln_max_seq_len"] = rbln_max_seq_len

        if rbln_batch_size is None:
            rbln_batch_size = 1

        input_info = [
            (
                "input_values",
                [
                    rbln_batch_size,
                    rbln_max_seq_len,
                ],
                "float32",
            ),
        ]

        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info, batch_size=rbln_batch_size)

        rbln_config = RBLNConfig.from_rbln_runtime_configs(
            [rbln_runtime_config],
            _rbln_meta=meta,
        )

        return rbln_config

    def forward(self, input_values: "torch.Tensor", **kwargs):
        outputs = super().forward(input_values, **kwargs)
        return CausalLMOutput(logits=outputs)
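RBLNWav2Vec2ForCTC only declares the traced wrapper and the static `input_values` shape; loading and compilation are handled by RBLNModel in modeling_base.py. The snippet below is a hedged usage sketch, not code from this diff: it assumes the usual optimum-style `from_pretrained(..., export=True)` entry point, that the `rbln_*` keyword arguments are forwarded to `_get_rbln_config` above, and that the class is re-exported from `optimum.rbln`; the checkpoint name and the 160000-sample window are illustrative placeholders.

import torch
from optimum.rbln import RBLNWav2Vec2ForCTC  # assumed re-export from optimum/rbln/__init__.py

# Hypothetical: compile a PyTorch checkpoint into a fixed-shape RBLN graph on load.
model = RBLNWav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",
    export=True,              # assumed optimum-style "compile on load" switch
    rbln_max_seq_len=160000,  # fixed waveform length in samples; the caller pads/truncates
    rbln_batch_size=1,
)

# The compiled graph expects exactly (rbln_batch_size, rbln_max_seq_len) float32 samples.
input_values = torch.zeros(1, 160000, dtype=torch.float32)
logits = model(input_values).logits  # CausalLMOutput, as wrapped by forward() above
predicted_ids = logits.argmax(dim=-1)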
optimum/rbln/transformers/models/whisper/__init__.py
@@ -0,0 +1,24 @@
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

from .modeling_whisper import RBLNWhisperForConditionalGeneration