optimum-rbln 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +115 -0
- optimum/rbln/__version__.py +1 -0
- optimum/rbln/diffusers/__init__.py +64 -0
- optimum/rbln/diffusers/models/__init__.py +26 -0
- optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
- optimum/rbln/diffusers/models/controlnet.py +180 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
- optimum/rbln/diffusers/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
- optimum/rbln/modeling.py +0 -0
- optimum/rbln/modeling_alias.py +49 -0
- optimum/rbln/modeling_base.py +645 -0
- optimum/rbln/modeling_config.py +169 -0
- optimum/rbln/modeling_seq2seq.py +469 -0
- optimum/rbln/transformers/__init__.py +59 -0
- optimum/rbln/transformers/generation/__init__.py +24 -0
- optimum/rbln/transformers/generation/streamers.py +122 -0
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/bart/__init__.py +24 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
- optimum/rbln/transformers/models/clip/__init__.py +24 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
- optimum/rbln/transformers/models/llama/__init__.py +24 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
- optimum/rbln/transformers/models/t5/__init__.py +24 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
- optimum/rbln/transformers/models/whisper/__init__.py +24 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
- optimum/rbln/utils/__init__.py +25 -0
- optimum/rbln/utils/import_utils.py +28 -0
- optimum/rbln/utils/runtime_utils.py +71 -0
- optimum/rbln/utils/save_utils.py +92 -0
- optimum_rbln-0.1.0.dist-info/METADATA +144 -0
- optimum_rbln-0.1.0.dist-info/RECORD +51 -0
- optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
- optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/whisper/whisper_architecture.py
@@ -0,0 +1,406 @@
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

from typing import Optional, Tuple, Union

import torch
from torch import nn
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
)
from transformers.models.whisper.modeling_whisper import (
    WhisperAttention,
    WhisperDecoder,
    WhisperDecoderLayer,
    WhisperPositionalEmbedding,
    WhisperSdpaAttention,
)
from transformers.utils import logging


logger = logging.get_logger(__name__)


class _WhisperAttention(WhisperAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        bsz, tgt_len, _ = hidden_states.size()
        is_cross_attention = key_value_states is not None

        query_states = self.q_proj(hidden_states) * self.scaling

        if is_cross_attention:
            is_dummy_decoder = len(key_value_states.shape) > 1
            if is_dummy_decoder:
                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
            else:
                key_states = past_key_value[0]
                value_states = past_key_value[1]
        else:
            if self.is_decoder:
                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
                key_states = past_key_value[0].slice_scatter(
                    key_states, dim=2, start=cache_position, end=cache_position + 1
                )
                value_states = past_key_value[1].slice_scatter(
                    value_states, dim=2, start=cache_position, end=cache_position + 1
                )
            else:
                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            present_key_value = (key_states, value_states)
        else:
            present_key_value = None

        proj_shape = (bsz * self.num_heads, -1, self.head_dim)
        query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
        key_states = key_states.reshape(*proj_shape)
        value_states = value_states.reshape(*proj_shape)

        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
        src_len = key_states.size(1)
        if attention_mask is not None:
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        attn_output = torch.bmm(attn_weights, value_states)

        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2)

        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
        attn_output = self.out_proj(attn_output)

        return attn_output, None, present_key_value


class _WhisperSdpaAttention(WhisperSdpaAttention):
    def forward(
        self,
        hidden_states: torch.Tensor,
        key_value_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        bsz, tgt_len, _ = hidden_states.size()

        is_cross_attention = key_value_states is not None

        query_states = self.q_proj(hidden_states)

        if is_cross_attention:
            is_dummy_decoder = len(key_value_states.shape) > 1
            if is_dummy_decoder:
                key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
                value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
            else:
                key_states = past_key_value[0]
                value_states = past_key_value[1]
        else:
            if self.is_decoder:
                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
                key_states = past_key_value[0].slice_scatter(
                    key_states, dim=2, start=cache_position, end=cache_position + 1
                )
                value_states = past_key_value[1].slice_scatter(
                    value_states, dim=2, start=cache_position, end=cache_position + 1
                )
            else:
                key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
                value_states = self._shape(self.v_proj(hidden_states), -1, bsz)

        if self.is_decoder:
            present_key_value = (key_states, value_states)
        else:
            present_key_value = None

        query_states = self._shape(query_states, tgt_len, bsz)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=self.is_causal and attention_mask is None and tgt_len > 1,
        )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)

        attn_output = self.out_proj(attn_output)

        return attn_output, None, present_key_value


ATTN_FORWARD_MAP = {"eager": _WhisperAttention.forward, "sdpa": _WhisperSdpaAttention.forward}


class _WhisperDecoderLayer(WhisperDecoderLayer):
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        cache_position: Optional[torch.Tensor] = None,
        attn_impl: str = "eager",
    ) -> torch.Tensor:

        # Self Attention Block
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None

        hidden_states, _, present_key_value = ATTN_FORWARD_MAP[attn_impl](
            self.self_attn,
            hidden_states=hidden_states,
            past_key_value=self_attn_past_key_value,
            attention_mask=attention_mask,
            cache_position=cache_position,
        )
        hidden_states = residual + hidden_states

        # Cross-Attention Block
        residual = hidden_states
        hidden_states = self.encoder_attn_layer_norm(hidden_states)
        cross_attn_past_key_value = past_key_value[2:] if past_key_value is not None else None

        hidden_states, _, cross_attn_present_key_value = ATTN_FORWARD_MAP[attn_impl](
            self.encoder_attn,
            hidden_states=hidden_states,
            key_value_states=encoder_hidden_states,
            past_key_value=cross_attn_past_key_value,
            cache_position=cache_position,
        )
        hidden_states = residual + hidden_states
        present_key_value = present_key_value + cross_attn_present_key_value

        # Fully Connected Block
        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = self.fc2(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, present_key_value


class _WhisperPositionalEmbedding(WhisperPositionalEmbedding):
    def forward(self, input_ids, past_key_values_length=0, position_ids=None):
        if position_ids is None:
            return self.weight[past_key_values_length : past_key_values_length + input_ids.shape[1]]
        else:
            return self.weight[position_ids]


class _WhisperDecoder(WhisperDecoder):
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        attn_impl: str = "eager",
        **kwargs,
    ):

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        # positional embedding
        inputs_embeds = self.embed_tokens(input_ids)
        positions = _WhisperPositionalEmbedding.forward(
            self.embed_positions, input_ids, cache_position, cache_position
        )
        hidden_states = inputs_embeds + positions

        # prepare causal_attn_mask
        if self._use_sdpa:
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask, input_shape, inputs_embeds, cache_position
            )
        else:
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, input_shape, inputs_embeds, cache_position
            )

        next_decoder_cache = ()
        # iterate over decoder layers
        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx] if past_key_values is not None else None
            layer_outputs = _WhisperDecoderLayer.forward(
                decoder_layer,
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                past_key_value=past_key_value,
                cache_position=cache_position,
                attn_impl=attn_impl,
            )
            hidden_states = layer_outputs[0]

            next_decoder_cache += (layer_outputs[1],)

        # layer_norm
        hidden_states = self.layer_norm(hidden_states)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
        )


class _WhisperDecoderWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.proj_out = model.proj_out
        self.config = model.config
        self.decoder = model.get_decoder()
        self.num_layers = self.config.decoder_layers
        self.attn_impl = self.config._attn_implementation

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        decoder_attention_mask: torch.Tensor,
        cache_position: torch.Tensor,
        self_kv_cache: torch.Tensor,
        cross_kv_cache: torch.Tensor,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:

        # prepare past_key_values
        kv_cache = ()
        for i in range(0, self.num_layers * 2, 2):
            kv_cache = kv_cache + (
                (
                    self_kv_cache[i],
                    self_kv_cache[i + 1],
                    cross_kv_cache[i],
                    cross_kv_cache[i + 1],
                ),
            )

        # Decode
        decoder_outputs = _WhisperDecoder.forward(
            self.decoder,
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            cache_position=cache_position,
            past_key_values=kv_cache,
            encoder_hidden_states=torch.tensor([1]),
            attn_impl=self.attn_impl,
        )
        sequence_output = decoder_outputs[0]
        lm_logits = self.proj_out(sequence_output)

        # get self_kv_cache from outputs
        past_key_values = decoder_outputs[1]
        self_kv_cache = []
        for i in range(self.config.decoder_layers):
            self_kv_cache.append(past_key_values[i][0])
            self_kv_cache.append(past_key_values[i][1])
        self_kv_cache = torch.stack(self_kv_cache, dim=0)

        return lm_logits, self_kv_cache


class _WhisperEncoderWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.config = model.config
        self.decoder = model.get_decoder()
        self.encoder = model.get_encoder()
        self.num_layers = self.config.decoder_layers
        self.decoder_max_length = self.config.max_target_positions
        self.encoder_max_length = self.config.max_source_positions
        self.num_heads = self.config.decoder_attention_heads
        self.d_kv = self.config.d_model // self.num_heads
        self.attn_impl = self.config._attn_implementation

    def forward(
        self,
        input_features: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:

        encoder_outputs = self.encoder(input_features=input_features)
        last_hidden_states = encoder_outputs[0]

        encoder_batch_size = input_features.shape[0]
        decoder_batch_size = encoder_batch_size  # TODO fix in future

        dummy_past_key_value = []
        for _ in range(self.num_layers):
            pkv_self_attn_key = torch.zeros(decoder_batch_size, self.num_heads, self.decoder_max_length, self.d_kv)
            pkv_self_attn_value = torch.zeros(decoder_batch_size, self.num_heads, self.decoder_max_length, self.d_kv)
            pkv_cross_attn_key = torch.zeros(encoder_batch_size, self.num_heads, self.encoder_max_length, self.d_kv)
            pkv_cross_attn_value = torch.zeros(encoder_batch_size, self.num_heads, self.encoder_max_length, self.d_kv)
            layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
            dummy_past_key_value.append(layer_pkv)

        decoder_attention_mask = torch.zeros(decoder_batch_size, self.decoder_max_length, dtype=torch.int64)
        decoder_attention_mask[:, :1] = 1

        decoder_outputs = _WhisperDecoder.forward(
            self.decoder,
            input_ids=torch.zeros((decoder_batch_size, 1), dtype=torch.int64),
            attention_mask=decoder_attention_mask,
            cache_position=torch.tensor(0, dtype=torch.int32),
            encoder_hidden_states=last_hidden_states,
            past_key_values=dummy_past_key_value,
            attn_impl=self.attn_impl,
        )

        first_past_kv = decoder_outputs[1]

        encoder_kv = []
        for layer_out in first_past_kv:  # per layer
            encoder_kv.append(torch.stack(layer_out[2:], dim=0))
        encoder_kv = torch.stack(encoder_kv, dim=0)

        return encoder_kv
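Taken together, _WhisperEncoderWrapper and _WhisperDecoderWrapper repackage the model into tensor-in/tensor-out modules with fixed shapes, which is what an ahead-of-time NPU compiler needs. A minimal sketch of that capture step, not from the package: the checkpoint name and the use of torch.jit.trace are illustrative assumptions, and it presumes a transformers version that still ships the private mask helpers imported above.

import torch
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny").eval()
wrapper = _WhisperEncoderWrapper(model)

# Whisper consumes log-mel features shaped (batch, num_mel_bins, 2 * max_source_positions).
dummy_features = torch.zeros(1, model.config.num_mel_bins, 2 * model.config.max_source_positions)

with torch.no_grad():
    # Capture a fixed-shape graph that an AOT compiler can consume.
    traced_encoder = torch.jit.trace(wrapper, dummy_features)

Note that the traced encoder returns only the stacked cross-attention key/value cache, which is exactly what the decoder wrapper consumes as cross_kv_cache on every generation step.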
optimum/rbln/utils/__init__.py
@@ -0,0 +1,25 @@
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

from .import_utils import is_rbln_available
from .runtime_utils import RBLNPytorchRuntime
optimum/rbln/utils/import_utils.py
@@ -0,0 +1,28 @@
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

import importlib.util


def is_rbln_available() -> bool:
    return importlib.util.find_spec("rebel-compiler") is not None
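One caveat in is_rbln_available: importlib.util.find_spec expects an importable module name, and rebel-compiler is the PyPI distribution name rather than the import name (runtime_utils.py below imports the SDK as import rebel), so this check can report False even when the SDK is installed. A sketch of the more conventional probe, assuming rebel is the module actually importable:

import importlib.util

def is_rbln_sdk_importable() -> bool:
    # Probe the import name of the SDK, not the distribution name on PyPI.
    return importlib.util.find_spec("rebel") is not None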
optimum/rbln/utils/runtime_utils.py
@@ -0,0 +1,71 @@
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

from typing import Any, Dict, List

import rebel
import torch


class RBLNPytorchRuntime:
    mandatory_members = []

    def __init__(self, runtime: rebel.Runtime, **kwargs) -> None:
        self.runtime = runtime
        for key, value in kwargs.items():
            setattr(self, key, value)
        for mandatory_member in __class__.mandatory_members:
            if mandatory_member not in kwargs:
                raise AttributeError(f"`{mandatory_member}` should be assigned to {__class__.__name__} objects.")

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return self.forward(*args, **kwds)

    def forward(self, *args: List["torch.Tensor"], **kwargs: Dict[str, "torch.Tensor"]):
        args = list(filter(lambda arg: isinstance(arg, torch.Tensor), args))
        kwargs = dict(filter(lambda kwarg: isinstance(kwarg[1], torch.Tensor), kwargs.items()))
        output = self.runtime(*args, **kwargs)
        return output

    def __repr__(self) -> str:
        return repr(self.runtime)


class UnavailableRuntime:
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        raise self.forward(*args, **kwargs)

    def __len__(self) -> int:
        return 0

    def __getitem__(self, idx: int) -> Any:
        return self

    def __iter__(self):
        return iter([self])

    def forward(self, *args: List["torch.Tensor"], **kwargs: Dict[str, "torch.Tensor"]):
        raise RuntimeError("RBLN-Runtime is not created, So it is not available.")

    def __repr__(self) -> str:
        return "UnavailableRuntime"
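RBLNPytorchRuntime is a thin adapter that drops any non-tensor arguments before forwarding a call to the compiled rebel.Runtime, and UnavailableRuntime is its null-object stand-in for runtimes that were never created. One subtlety: __class__ inside __init__ is bound to the class that defines the method, so the validation loop always reads RBLNPytorchRuntime's empty mandatory_members; type(self).mandatory_members would see a subclass override. A minimal sketch of the intended subclassing pattern (the class name and member below are illustrative, not from the package):

class RBLNWhisperEncoderRuntime(RBLNPytorchRuntime):
    # Attributes every instance must receive as keyword arguments.
    mandatory_members = ["main_input_name"]


# runtime = rebel.Runtime(...)  # wraps a compiled *.rbln artifact; constructor arguments omitted here
# encoder = RBLNWhisperEncoderRuntime(runtime, main_input_name="input_features")
# hidden_states = encoder(features)  # __call__ -> forward -> self.runtime(...)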
optimum/rbln/utils/save_utils.py
@@ -0,0 +1,92 @@
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

import logging
from pathlib import Path
from typing import List, Union

from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer


logger = logging.getLogger(__name__)


def maybe_load_preprocessors(
    src_name_or_path: Union[str, Path], subfolder: str = "", trust_remote_code: bool = False
) -> List:
    preprocessors = []
    try:
        preprocessors.append(
            AutoTokenizer.from_pretrained(src_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code)
        )
    except Exception:
        pass

    try:
        preprocessors.append(
            AutoProcessor.from_pretrained(src_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code)
        )
    except Exception:
        pass

    try:
        preprocessors.append(
            AutoFeatureExtractor.from_pretrained(
                src_name_or_path, subfolder=subfolder, trust_remote_code=trust_remote_code
            )
        )
    except Exception:
        pass
    return preprocessors


def maybe_save_preprocessors(
    src_name_or_path: Union[str, Path],
    dest_dir: Union[str, Path],
    src_subfolder: str = "",
    trust_remote_code: bool = False,
):
    """
    Saves the tokenizer, the processor and the feature extractor when found in `src_name_or_path` in `dest_dir`.

    Args:
        src_name_or_path (`Union[str, Path]`):
            The source directory or model id from which to copy the files.
        dest_dir (`Union[str, Path]`):
            The destination directory to copy the files to.
        src_subfolder (`str`, defaults to `""`):
            In case the preprocessor files are located inside a subfolder of the model directory / repo on the Hugging
            Face Hub, you can specify the subfolder name here.
        trust_remote_code (`bool`, defaults to `False`):
            Whether to allow saving preprocessors that may run arbitrary code. Use this option at your own risk.
    """
    if not isinstance(dest_dir, Path):
        dest_dir = Path(dest_dir)

    dest_dir.mkdir(exist_ok=True)
    preprocessors = maybe_load_preprocessors(
        src_name_or_path, subfolder=src_subfolder, trust_remote_code=trust_remote_code
    )
    for preprocessor in preprocessors:
        preprocessor.save_pretrained(dest_dir)
    return preprocessors
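These helpers keep an exported model directory self-contained: whichever of the tokenizer, processor, and feature extractor can be loaded from the source repo is re-saved next to the compiled artifacts. Illustrative usage (the model id and output directory are assumptions, not from the package):

from optimum.rbln.utils.save_utils import maybe_save_preprocessors

saved = maybe_save_preprocessors("openai/whisper-tiny", "whisper-tiny-rbln")
print([type(p).__name__ for p in saved])  # whichever of tokenizer / processor / feature extractor loaded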