optimum-rbln 0.1.9-py3-none-any.whl → 0.1.12-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +47 -9
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +36 -31
- optimum/rbln/diffusers/models/controlnet.py +53 -43
- optimum/rbln/diffusers/models/unet_2d_condition.py +40 -31
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +28 -23
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +28 -23
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +28 -37
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +30 -39
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +24 -15
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +26 -17
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -17
- optimum/rbln/modeling_alias.py +6 -11
- optimum/rbln/modeling_base.py +467 -261
- optimum/rbln/modeling_config.py +199 -73
- optimum/rbln/transformers/__init__.py +43 -1
- optimum/rbln/transformers/models/__init__.py +23 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +95 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +203 -58
- optimum/rbln/transformers/models/bart/modeling_bart.py +125 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +101 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +127 -26
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +409 -150
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -8
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +662 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +6 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +198 -168
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +122 -47
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -12
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +172 -111
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +18 -16
- optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
- optimum/rbln/utils/import_utils.py +50 -1
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +33 -0
- optimum/rbln/utils/timer_utils.py +43 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +9 -7
- optimum_rbln-0.1.12.dist-info/RECORD +103 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.12.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.9.dist-info/RECORD +0 -78
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
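Beyond the new model families (BERT, EXAONE, LLaVA-Next, Phi, Qwen2) and the new auto classes, the recurring theme in the per-file diffs below is new config plumbing: per-model `rbln_*` keyword arguments are collected into a single `rbln_kwargs` dict and stored on `RBLNConfig.model_cfg`. A minimal, hypothetical usage sketch; the `export=True` entry point follows the usual optimum convention and the keyword names mirror the keys read from `rbln_kwargs` in the Whisper diff below, so treat both as assumptions rather than documented API:

```python
# Hypothetical sketch (not from the diff): compile Whisper for an RBLN NPU with the
# token-timestamp support added in 0.1.12. The rbln_* keywords mirror the keys that
# _get_rbln_config reads from rbln_kwargs ("batch_size", "token_timestamps").
from optimum.rbln import RBLNWhisperForConditionalGeneration

model = RBLNWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",       # example checkpoint only
    export=True,                 # trace + compile instead of loading a prebuilt artifact
    rbln_batch_size=1,           # ends up in rbln_config.model_cfg["batch_size"]
    rbln_token_timestamps=True,  # ends up in rbln_config.model_cfg["token_timestamps"]
)
model.save_pretrained("whisper-tiny-rbln")
```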
optimum/rbln/transformers/models/whisper/modeling_whisper.py

```diff
@@ -30,7 +30,6 @@ import torch
 from transformers import (
     AutoModelForSpeechSeq2Seq,
     AutoProcessor,
-    GenerationMixin,
     PretrainedConfig,
     WhisperForConditionalGeneration,
     WhisperModel,
@@ -38,8 +37,9 @@ from transformers import (
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
 from ....modeling_base import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME,
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from .generation_whisper import RBLNWhisperGenerationMixin
 from .whisper_architecture import (
     _WhisperDecoderWrapper,
     _WhisperEncoderWrapper,
@@ -59,20 +59,47 @@ if TYPE_CHECKING:
 class RBLNRuntimeEncoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
-    def forward(self,
-
-
+    def forward(self, input_features: torch.Tensor = None):
+        # backward compatibility transformers==4.40.2
+        # https://github.com/huggingface/transformers/blob/4fdf58afb72b0754da30037fc800b6044e7d9c99/src/transformers/pipelines/automatic_speech_recognition.py#L494
+
+        n_pad_to_batch = self.batch_size - input_features.shape[0]
+        if n_pad_to_batch > 0:
+            input_features = torch.nn.functional.pad(input_features, (0, 0, 0, 0, 0, n_pad_to_batch))
+
+        _ = super().forward(input_features=input_features)
+
+        # dummy output for generation
+        return BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
 
 
 class RBLNRuntimeDecoder(RBLNPytorchRuntime):
     mandatory_members = ["main_input_name"]
 
-    def forward(
-
-
+    def forward(
+        self,
+        decoder_input_ids: torch.Tensor = None,
+        decoder_attention_mask: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+    ):
+        inputs_bsz = decoder_input_ids.shape[0]
+        padded_bsz = self.batch_size - inputs_bsz
+        if padded_bsz > 0:
+            decoder_input_ids = torch.nn.functional.pad(decoder_input_ids, (0, 0, 0, padded_bsz))
+
+        outputs = super().forward(
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            cache_position=cache_position,
+        )
+
+        if isinstance(outputs, torch.Tensor):
+            return Seq2SeqLMOutput(logits=outputs[:inputs_bsz], cross_attentions=None)
+        else:
+            return Seq2SeqLMOutput(logits=outputs[0][:inputs_bsz], cross_attentions=outputs[1][:, :inputs_bsz])
 
 
-class RBLNWhisperForConditionalGeneration(RBLNModel,
+class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
     """
     The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
```
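The new runtime wrappers above pad a smaller incoming batch up to the fixed batch size the graph was compiled for, then slice the outputs back down to the real batch. A standalone sketch of that padding arithmetic, using plain PyTorch and illustrative shapes only:

```python
import torch

compiled_batch_size = 4                     # batch size baked into the compiled graph

input_features = torch.randn(2, 80, 3000)   # e.g. two log-mel spectrograms
n_pad_to_batch = compiled_batch_size - input_features.shape[0]
if n_pad_to_batch > 0:
    # pad pairs run from the last dim backwards: (frames 0,0) (mel bins 0,0) (batch 0,n)
    input_features = torch.nn.functional.pad(input_features, (0, 0, 0, 0, 0, n_pad_to_batch))
print(input_features.shape)                 # torch.Size([4, 80, 3000])

decoder_input_ids = torch.zeros(2, 1, dtype=torch.int64)
decoder_input_ids = torch.nn.functional.pad(decoder_input_ids, (0, 0, 0, compiled_batch_size - 2))
print(decoder_input_ids.shape)              # torch.Size([4, 1])
```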
```diff
@@ -83,20 +110,30 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
     - compiling the resulting graph using the RBLN compiler.
     """
 
-    model_type = "rbln_model"
     auto_model_class = AutoModelForSpeechSeq2Seq
     main_input_name = "input_ids"
 
     def __post_init__(self, **kwargs):
-
-        self.enc_max_seq_len = self.rbln_config.meta["input_max_length"]
-        self.dec_max_seq_len = self.rbln_config.meta["rbln_dec_max_seq_len"]
+        super().__post_init__(**kwargs)
 
-        self.
-        self.
-        self.
+        self.batch_size = self.rbln_config.model_cfg["batch_size"]
+        self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+        self.rbln_token_timestamps = self.rbln_config.model_cfg["token_timestamps"]
+
+        self.encoder = RBLNRuntimeEncoder(
+            runtime=self.model[0], main_input_name="input_features", batch_size=self.batch_size
+        )
+        self.decoder = RBLNRuntimeDecoder(
+            runtime=self.model[1], main_input_name="input_ids", batch_size=self.batch_size
+        )
 
-        #
+        # skip encoder & first decoder when language detected
+        self.is_language_detected = False
+        self.language_cross = None
+
+        # Used in GenerationMixin.generate()
+        # transformers/models/whisper/generation_whisper.py, line 505, in generate
+        # input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
         self.model = WhisperModel(self.config)
         self.pad_token_id = self.config.pad_token_id
 
@@ -127,63 +164,32 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
         # TODO(jongho): implement
         raise NotImplementedError
 
-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        decoder_attention_mask=None,
-        input_features=None,  # Must be explicit
-        **kwargs,
-    ):
-        max_seq_len = self.dec_max_seq_len
-        cur_seq_len = input_ids.shape[-1]
-        input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
-        decoder_attention_mask = torch.zeros(self.batch_size, max_seq_len, dtype=torch.int64)
-        decoder_attention_mask[:, :cur_seq_len] = 1
-        cache_position = torch.tensor(cur_seq_len - 1, dtype=torch.int32)
-
-        return {
-            "decoder_input_ids": input_ids,
-            "decoder_attention_mask": decoder_attention_mask,
-            "cache_position": cache_position,
-        }
-
-    @classmethod
-    def update_kwargs(cls, kwargs):
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-                "use_cache": True,
-            }
-        )
-        return kwargs
-
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model, rbln_config: RBLNConfig):
+        rbln_token_timestamps = rbln_config.model_cfg["token_timestamps"]
         wrapped_encoder = _WhisperEncoderWrapper(model).eval()
-        wrapped_decoder = _WhisperDecoderWrapper(model).eval()
+        wrapped_decoder = _WhisperDecoderWrapper(model, output_attentions=rbln_token_timestamps).eval()
 
-
-
+        enc_rbln_compile_config = rbln_config.compile_cfgs[0]
+        dec_rbln_compile_config = rbln_config.compile_cfgs[1]
 
-        enc_example_inputs =
-        dec_example_inputs =
+        enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=1)
+        dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=1)
 
-        enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs
+        enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
         dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
 
         enc_ir = rebel.torchscript_to_ir(
             enc_scripted_model,
-            input_names=[v[0] for v in
-            name=
+            input_names=[v[0] for v in enc_rbln_compile_config.input_info],
+            name=enc_rbln_compile_config.mod_name,
         )
         dec_ir = rebel.torchscript_to_ir(
             dec_scripted_model,
-            input_names=[v[0] for v in
-            name=
+            input_names=[v[0] for v in dec_rbln_compile_config.input_info],
+            name=dec_rbln_compile_config.mod_name,
         )
-        dec_ir.batch_size = dec_rbln_runtime_config.batch_size
 
         # Caching encoder/decoder I/O
         connections = [
@@ -194,9 +200,9 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
             enc_ir,
             dec_ir,
             connections=connections,
-            fusion=
-            npu=
-            tensor_parallel_size=
+            fusion=enc_rbln_compile_config.fusion,
+            npu=enc_rbln_compile_config.npu,
+            tensor_parallel_size=enc_rbln_compile_config.tensor_parallel_size,
         )
         return compiled_model
 
@@ -205,42 +211,26 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor"],
         model_config: "PretrainedConfig",
-
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
-
-        input_max_length = 3000
-        rbln_enc_num_mel_bins = getattr(model_config, "num_mel_bins", None)
-        if rbln_enc_num_mel_bins is None:
-            for feature_extractor in preprocessors:
-                if hasattr(feature_extractor, "feature_size"):
-                    rbln_enc_num_mel_bins = feature_extractor.feature_size
-                    break
-            raise ValueError("`rbln_enc_num_mel_bins` should be specified!")
-
-        rbln_enc_max_seq_len = getattr(model_config, "max_source_positions", None)
-        if rbln_enc_max_seq_len is None:
-            raise ValueError("`rbln_enc_max_seq_len` should be specified!")
-
-        rbln_dec_max_seq_len = getattr(model_config, "max_length", None)
-        if rbln_dec_max_seq_len is None:
-            raise ValueError("`rbln_dec_max_seq_len` should be specified!")
-
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_token_timestamps = rbln_kwargs.get("token_timestamps", False)
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-        decoder_batch_size = rbln_batch_size
 
-
-
-
-
-
-
+        expected_seq_len = model_config.max_source_positions * 2
+        num_mel_bins = model_config.num_mel_bins
+        enc_max_seq_len = model_config.max_source_positions
+
+        # 'whisper-large-v3-turbo' doesn't have 'max_length', but PretrainedConfig have default value for the key 'max_length'
+        rbln_dec_max_seq_len = getattr(model_config, "max_target_positions", None)
+        if rbln_dec_max_seq_len is None:
+            rbln_dec_max_seq_len = model_config.max_length
 
         # model input info
-        enc_input_info = [("input_features", [rbln_batch_size,
+        enc_input_info = [("input_features", [rbln_batch_size, num_mel_bins, expected_seq_len], "float32")]
         dec_input_info = [
-            ("decoder_input_ids", [
-            ("decoder_attention_mask", [
+            ("decoder_input_ids", [rbln_batch_size, 1], "int64"),
+            ("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
             ("cache_position", [], "int32"),
         ]
         dec_input_info.extend(
@@ -249,7 +239,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
                     "self_key_value_states",
                     [
                         model_config.decoder_layers * 2,
-
+                        rbln_batch_size,
                         model_config.decoder_attention_heads,
                         rbln_dec_max_seq_len,
                         model_config.d_model // model_config.encoder_attention_heads,
@@ -266,7 +256,7 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
                         model_config.decoder_layers * 2,
                         rbln_batch_size,
                         model_config.decoder_attention_heads,
-
+                        enc_max_seq_len,
                         model_config.d_model // model_config.encoder_attention_heads,
                     ],
                     "float32",
@@ -274,15 +264,21 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
             ]
         )
 
-
-
+        enc_rbln_compile_config = RBLNCompileConfig(mod_name="encoder", input_info=enc_input_info)
+        dec_rbln_compile_config = RBLNCompileConfig(mod_name="decoder", input_info=dec_input_info)
 
-
-
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[enc_rbln_compile_config, dec_rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
 
-        rbln_config
-
-
+        rbln_config.model_cfg.update(
+            {
+                "batch_size": rbln_batch_size,
+                "dec_max_seq_len": rbln_dec_max_seq_len,
+                "token_timestamps": rbln_token_timestamps,
+            }
         )
 
         return rbln_config
```
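`_get_rbln_config` above fixes every tensor shape at compile time. As an illustration of the shape arithmetic, with whisper-tiny-sized hyperparameters (assumed here: `num_mel_bins=80`, `max_source_positions=1500`, `max_target_positions=448`) the static encoder/decoder I/O works out as follows:

```python
# Illustrative shape arithmetic only; the numbers are whisper-tiny-like assumptions.
rbln_batch_size = 1
num_mel_bins = 80
max_source_positions = 1500
expected_seq_len = max_source_positions * 2      # 3000 mel frames (30 s of audio)
rbln_dec_max_seq_len = 448                       # max_target_positions

enc_input_info = [("input_features", [rbln_batch_size, num_mel_bins, expected_seq_len], "float32")]
dec_input_info = [
    ("decoder_input_ids", [rbln_batch_size, 1], "int64"),
    ("decoder_attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
    ("cache_position", [], "int32"),
]
print(enc_input_info[0][1])   # [1, 80, 3000]
```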
```diff
@@ -297,18 +293,83 @@ class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
             compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
         ]
 
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        cache_position: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,  # need for support transformers>=4.45.0
+        **kwargs,
+    ):
+        """
+        whisper don't use attention_mask,
+        attention_mask (`torch.Tensor`)`, *optional*):
+            Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
+            but it is not used. By default the silence in the input log mel spectrogram are ignored.
+        """
+        return {
+            "input_ids": input_ids,
+            "cache_position": cache_position,
+        }
+
+    # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L512
+    def _prepare_encoder_decoder_kwargs_for_generation(
+        self, inputs_tensor: torch.Tensor, model_kwargs, *args, **kwargs
+    ) -> Dict[str, Any]:
+        if not self.is_language_detected:
+            model_kwargs["encoder_outputs"] = self.encoder(input_features=inputs_tensor)
+            self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.int64)
+        else:
+            model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
+
+        return model_kwargs
+
     def forward(
         self,
-
-        decoder_attention_mask: Optional[torch.LongTensor] = None,
+        input_ids: Optional[torch.LongTensor] = None,
         cache_position: Optional[torch.Tensor] = None,
+        input_features: Optional[torch.Tensor] = None,
+        decoder_input_ids: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Seq2SeqLMOutput] = None,
         **kwargs,
     ) -> Seq2SeqLMOutput:
-
-
-
-        cache_position
-
-
-
-
+        # default decoder pass
+        if input_features is None and encoder_outputs is None:
+            cross_attentions = []
+            for step in cache_position:
+                # skip step 0 if language_detection has been processed
+                if step == 0 and self.is_language_detected:
+                    cross_attentions.append(self.language_cross)
+                    self.is_language_detected = False
+                else:
+                    self.decoder_attention_mask[:, step] = 1
+                    decoder_output = self.decoder(
+                        decoder_input_ids=input_ids[:, step : step + 1].contiguous(),
+                        decoder_attention_mask=self.decoder_attention_mask,
+                        cache_position=step.to(torch.int32),
+                    )
+                    cross_attentions.append(decoder_output.cross_attentions)
+                    lm_logits = decoder_output.logits
+
+            if self.rbln_token_timestamps:
+                cross_attentions = torch.cat(cross_attentions, dim=-2)
+            else:
+                cross_attentions = None
+
+            return Seq2SeqLMOutput(logits=lm_logits, cross_attentions=cross_attentions)
+
+        # detect language pass
+        # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/models/whisper/generation_whisper.py#L1442
+        else:
+            if encoder_outputs is None:
+                self.encoder(input_features=input_features.contiguous())
+                self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.int64)
+            self.is_language_detected = True
+            self.decoder_attention_mask[:, 0] = 1
+            decoder_output = self.decoder(
+                decoder_input_ids=decoder_input_ids.contiguous(),
+                decoder_attention_mask=self.decoder_attention_mask,
+                cache_position=torch.zeros([], dtype=torch.int32),
+            )
+            lm_logits = decoder_output.logits
+            self.language_cross = decoder_output.cross_attentions
+            return Seq2SeqLMOutput(logits=lm_logits)
```
optimum/rbln/transformers/models/whisper/whisper_architecture.py

```diff
@@ -55,7 +55,6 @@ class _WhisperAttention(WhisperAttention):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
-        **kwargs,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         bsz, tgt_len, _ = hidden_states.size()
         is_cross_attention = key_value_states is not None
@@ -99,6 +98,7 @@ class _WhisperAttention(WhisperAttention):
         if attention_mask is not None:
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
+
         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
 
         attn_output = torch.bmm(attn_weights, value_states)
@@ -109,7 +109,9 @@ class _WhisperAttention(WhisperAttention):
         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
         attn_output = self.out_proj(attn_output)
 
-
+        attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+
+        return attn_output, attn_weights, present_key_value
 
 
 class _WhisperSdpaAttention(WhisperSdpaAttention):
@@ -186,6 +188,7 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         cache_position: Optional[torch.Tensor] = None,
         attn_impl: str = "eager",
+        output_attentions: bool = False,
     ) -> torch.Tensor:
         # Self Attention Block
         residual = hidden_states
@@ -205,14 +208,22 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
         residual = hidden_states
         hidden_states = self.encoder_attn_layer_norm(hidden_states)
         cross_attn_past_key_value = past_key_value[2:] if past_key_value is not None else None
-
-
-
-
-
-
-
-
+        if output_attentions:
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = _WhisperAttention.forward(
+                self.encoder_attn,
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                past_key_value=cross_attn_past_key_value,
+                cache_position=cache_position,
+            )
+        else:
+            hidden_states, cross_attn_weights, cross_attn_present_key_value = ATTN_FORWARD_MAP[attn_impl](
+                self.encoder_attn,
+                hidden_states=hidden_states,
+                key_value_states=encoder_hidden_states,
+                past_key_value=cross_attn_past_key_value,
+                cache_position=cache_position,
+            )
         hidden_states = residual + hidden_states
         present_key_value = present_key_value + cross_attn_present_key_value
 
@@ -223,7 +234,7 @@ class _WhisperDecoderLayer(WhisperDecoderLayer):
         hidden_states = self.fc2(hidden_states)
         hidden_states = residual + hidden_states
 
-        return hidden_states, present_key_value
+        return hidden_states, present_key_value, cross_attn_weights
 
 
 class _WhisperPositionalEmbedding(WhisperPositionalEmbedding):
@@ -243,6 +254,7 @@ class _WhisperDecoder(WhisperDecoder):
         past_key_values: Optional[torch.Tensor] = None,
         cache_position: Optional[torch.Tensor] = None,
         attn_impl: str = "eager",
+        output_attentions: bool = False,
         **kwargs,
     ):
         input_shape = input_ids.size()
@@ -266,6 +278,7 @@ class _WhisperDecoder(WhisperDecoder):
         )
 
         next_decoder_cache = ()
+        all_cross_attentions = () if output_attentions else None
         # iterate decoder_layer
         for idx, decoder_layer in enumerate(self.layers):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
@@ -277,10 +290,13 @@ class _WhisperDecoder(WhisperDecoder):
                 past_key_value=past_key_value,
                 cache_position=cache_position,
                 attn_impl=attn_impl,
+                output_attentions=output_attentions,
             )
             hidden_states = layer_outputs[0]
 
             next_decoder_cache += (layer_outputs[1],)
+            if output_attentions:
+                all_cross_attentions += (layer_outputs[2],)
 
         # layer_norm
         hidden_states = self.layer_norm(hidden_states)
@@ -288,17 +304,19 @@ class _WhisperDecoder(WhisperDecoder):
         return BaseModelOutputWithPastAndCrossAttentions(
             last_hidden_state=hidden_states,
             past_key_values=next_decoder_cache,
+            cross_attentions=all_cross_attentions,
         )
 
 
 class _WhisperDecoderWrapper(torch.nn.Module):
-    def __init__(self, model):
+    def __init__(self, model, output_attentions: bool = False):
         super().__init__()
         self.proj_out = model.proj_out
         self.config = model.config
         self.decoder = model.get_decoder()
         self.num_layers = self.config.decoder_layers
         self.attn_impl = self.config._attn_implementation
+        self.output_attentions = output_attentions
 
     def forward(
         self,
@@ -329,6 +347,7 @@ class _WhisperDecoderWrapper(torch.nn.Module):
             past_key_values=kv_cache,
             encoder_hidden_states=torch.tensor([1]),
             attn_impl=self.attn_impl,
+            output_attentions=self.output_attentions,
         )
         sequence_output = decoder_outputs[0]
         lm_logits = self.proj_out(sequence_output)
@@ -341,7 +360,12 @@ class _WhisperDecoderWrapper(torch.nn.Module):
             self_kv_cache.append(past_key_values[i][1])
         self_kv_cache = torch.stack(self_kv_cache, dim=0)
 
-
+        if self.output_attentions:
+            # deocder's cross attention is used for token_timestamps
+            cross_attention = torch.stack(decoder_outputs[2], dim=0)
+            return lm_logits, self_kv_cache, cross_attention
+        else:
+            return lm_logits, self_kv_cache
 
 
 class _WhisperEncoderWrapper(torch.nn.Module):
@@ -363,6 +387,7 @@ class _WhisperEncoderWrapper(torch.nn.Module):
         input_features: Optional[torch.LongTensor] = None,
     ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
         encoder_outputs = self.encoder(input_features=input_features)
+
         last_hidden_states = encoder_outputs[0]
 
         encoder_batch_size = input_features.shape[0]
@@ -388,13 +413,15 @@ class _WhisperEncoderWrapper(torch.nn.Module):
             encoder_hidden_states=last_hidden_states,
             past_key_values=dummy_past_key_value,
             attn_impl=self.attn_impl,
+            output_attentions=False,
         )
 
         first_past_kv = decoder_outputs[1]
 
-
+        cross_kv = []
         for layer_out in first_past_kv:  # for layer
-
-
+            cross_kv.append(layer_out[2])
+            cross_kv.append(layer_out[3])
+        cross_kv = torch.stack(cross_kv, dim=0)
 
-        return
+        return cross_kv
```
optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py

```diff
@@ -22,13 +22,13 @@
 # from Rebellions Inc.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict,
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
 import torch
-from transformers import
+from transformers import PretrainedConfig, PreTrainedModel, XLMRobertaConfig, XLMRobertaModel
 
 from ....modeling_base import RBLNModel
-from ....modeling_config import
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
 
 
 logger = logging.getLogger(__name__)
@@ -38,7 +38,6 @@ if TYPE_CHECKING:
 
 
 class RBLNXLMRobertaModel(RBLNModel):
-    auto_model_class = AutoModel  # feature extraction
     original_model_class = XLMRobertaModel
     original_config_class = XLMRobertaConfig
 
@@ -53,8 +52,7 @@ class RBLNXLMRobertaModel(RBLNModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
-
-        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
+        rbln_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> "PreTrainedModel":
         model: "PreTrainedModel" = super().get_pytorch_model(
@@ -66,8 +64,7 @@ class RBLNXLMRobertaModel(RBLNModel):
             subfolder=subfolder,
             local_files_only=local_files_only,
             trust_remote_code=trust_remote_code,
-
-            rbln_constructor_kwargs=rbln_constructor_kwargs,
+            rbln_kwargs=rbln_kwargs,
             library_name="transformers",
         )
 
@@ -78,10 +75,12 @@ class RBLNXLMRobertaModel(RBLNModel):
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
         model_config: Optional["PretrainedConfig"] = None,
-
-        rbln_model_input_names: Optional[List[str]] = None,
-        rbln_batch_size: Optional[int] = None,
+        rbln_kwargs={},
     ) -> RBLNConfig:
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
         )
@@ -111,12 +110,15 @@ class RBLNXLMRobertaModel(RBLNModel):
             for model_input_name in rbln_model_input_names
         ]
 
-
-        rbln_runtime_config.batch_size = rbln_batch_size
-
-        meta = {"rbln_max_seq_len": rbln_max_seq_len}
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
 
-
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        return rbln_config
 
     def forward(
         self,
```