optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +21 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/__init__.py +0 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
- optimum/rbln/diffusers/models/controlnet.py +3 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
- optimum/rbln/modeling_alias.py +14 -0
- optimum/rbln/modeling_base.py +282 -100
- optimum/rbln/modeling_seq2seq.py +58 -132
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/cache_utils.py +111 -0
- optimum/rbln/transformers/generation/utils.py +0 -2
- optimum/rbln/transformers/models/__init__.py +3 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
- optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
- optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
- optimum/rbln/transformers/models/dpt/__init__.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
- optimum/rbln/transformers/models/gemma/__init__.py +24 -0
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
- optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
- optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
- optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
- optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
- optimum/rbln/utils/__init__.py +1 -1
- optimum/rbln/utils/import_utils.py +46 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
- optimum_rbln-0.1.8.dist-info/RECORD +73 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
- optimum_rbln-0.1.4.dist-info/RECORD +0 -63
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
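
The headline change in this release shows up in the file list above: the per-model causal-LM implementations (GPT-2, Llama, Gemma, Midm) each shed hundreds of lines and are rebased onto the new shared `decoderonly` package, which now owns compilation, runtime creation, and generation. What remains per model is a thin subclass. Below is a minimal sketch of that pattern, inferred from the modeling_midm.py diff that follows; `MyModelWrapper` and `RBLNMyModelForCausalLM` are hypothetical names, and only the pieces visible in the diff are assumed to be the real API:

```python
import torch
from transformers import PreTrainedModel

from optimum.rbln.transformers.models.decoderonly import RBLNDecoderOnlyModelForCausalLM


class MyModelWrapper(torch.nn.Module):
    """Hypothetical tracing wrapper; the real ones (e.g. MidmLMHeadModelWrapper)
    live in each model's *_architecture.py module."""

    def __init__(self, model: PreTrainedModel, max_seq_len: int):
        super().__init__()
        self.model = model
        self.max_seq_len = max_seq_len

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        # Return raw logits so the traced graph has tensor-only outputs.
        return self.model(input_ids=input_ids, attention_mask=attention_mask).logits


class RBLNMyModelForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    @classmethod
    def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
        # The single per-architecture hook left after the refactor: hand the
        # base class the torch module it should trace and compile.
        return MyModelWrapper(model, rbln_max_seq_len).eval()
```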
optimum/rbln/transformers/models/midm/modeling_midm.py

```diff
@@ -23,21 +23,12 @@
 
 import inspect
 import logging
-from pathlib import Path
-from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
-import rebel
-import torch
-from optimum.exporters import TasksManager
-from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
-from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+from transformers import PretrainedConfig, PreTrainedModel
 
-from ....modeling_base import RBLNBaseModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
-from ....utils.runtime_utils import RBLNPytorchRuntime
-from ....utils.save_utils import maybe_save_preprocessors
-from ...generation.utils import RBLNGenerationMixin
+from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .hf_hub_cached.modeling_midm import MidmLMHeadModel
 from .midm_architecture import (
     MidmLMHeadModelWrapper,
@@ -45,7 +36,6 @@ from .midm_architecture import (
 
 
 logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import (
         AutoFeatureExtractor,
@@ -55,31 +45,12 @@ if TYPE_CHECKING:
     )
 
 
-class RBLNRuntimeDecoder(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
-
-    # RBLN_Runtimemodule
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.LongTensor = None,
-        cache_position: torch.Tensor = None,
-        **kwargs: Dict[str, Any],
-    ):
-        logits = super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-        )
-        return logits
-
-
-class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
+class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     """
     The Midm Model transformer with a language modeling head on top (linear layer with weights tied to the input
     embeddings).
 
-    This model inherits from [`
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
     library implements for all its model.
 
     It implements the methods to convert a pre-trained transformers Midm model into a RBLN transformer model by:
@@ -88,46 +59,9 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
 
     """
 
-
-
-
-
-    def __init__(
-        self,
-        models: List[Union[PreTrainedModel, rebel.RBLNCompiledModel]],
-        config: PretrainedConfig = None,
-        preprocessors: Optional[List] = None,
-        rbln_config: Optional[RBLNConfig] = None,
-        rbln_device: Optional[List[int]] = None,
-        rbln_device_map: Optional[Dict[str, int]] = None,
-        **kwargs,
-    ):
-        super().__init__(
-            models,
-            config,
-            preprocessors,
-            rbln_config,
-            rbln_device=rbln_device,
-            rbln_device_map=rbln_device_map,
-            **kwargs,
-        )
-        self.batch_size = self.rbln_config.meta["rbln_batch_size"]
-        self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
-        self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
-
-        self.prefill_attention_mask = torch.zeros(
-            self.batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
-        )
-        self.causal_mask = 1 - torch.triu(
-            torch.ones(self.batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
-        )
-
-        self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.runtimes[0], main_input_name="input_ids")
-        self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1], main_input_name="input_ids")
-        self.past_cached_length = 0
-
-    def can_generate(self):
-        return True
+    @classmethod
+    def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
+        return MidmLMHeadModelWrapper(model, rbln_max_seq_len).eval()
 
     def __getattr__(self, __name: str) -> Any:
         """This is the key method to implement RBLN-Midm.
@@ -144,174 +78,46 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
             return redirect(val)
         return val
 
-    def _reorder_cache(self, past_key_values, beam_idx):
-        # TODO(jongho): implement
-        raise NotImplementedError
-
-    @classmethod
-    def _export(
-        cls,
-        model_id: str,
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        **kwargs,
-    ) -> "RBLNMidmLMHeadModel":
-
-        task = kwargs.pop("task", None)
-        if task is None:
-            task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-        if model_save_dir is None:
-            save_dir = TemporaryDirectory()
-            save_dir_path = Path(save_dir.name)
-        else:
-            save_dir = model_save_dir
-            if isinstance(save_dir, TemporaryDirectory):
-                save_dir_path = Path(model_save_dir.name)
-            else:
-                save_dir_path = Path(model_save_dir)
-                save_dir_path.mkdir(exist_ok=True)
-
-        def update_configs(kwargs):
-            max_seq_len = kwargs.get("rbln_max_seq_len", None)
-            if max_seq_len is not None:
-                kwargs.update({"max_position_embeddings": max_seq_len})
-
-            kwargs.update(
-                {
-                    "torchscript": True,
-                    "return_dict": False,
-                    "use_cache": True,
-                    "torch_dtype": torch.float32,
-                    "_attn_implementation": "eager",
-                }
-            )
-
-            return kwargs
-
-        kwargs = update_configs(kwargs)
-
-        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-        model: MidmLMHeadModel = TasksManager.get_model_from_task(
-            task=task,
-            model_name_or_path=model_id,
-            subfolder=subfolder,
-            revision=revision,
-            framework="pt",
-            cache_dir=cache_dir,
-            use_auth_token=use_auth_token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            ignore_mismatched_sizes=True,
-            **kwargs,
-        )
-
-        if config is None:
-            config = model.config
-
-        config.save_pretrained(save_dir_path)
-        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-        # Get compilation arguments
-        if rbln_config_kwargs.get("rbln_config", None) is None:
-            rbln_config = cls.get_rbln_config(
-                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-            )
-
-        def compile_midm():
-            wrapped_decoder = MidmLMHeadModelWrapper(model).eval()
-            prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-            dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-            prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
-
-            prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs)
-            dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs)
-
-            prefill_ir = rebel.torchscript_to_ir(
-                prefill_scripted_model,
-                input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-            )
-            dec_ir = rebel.torchscript_to_ir(
-                dec_scripted_model,
-                input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-            )
-
-            connections = [
-                (prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)
-            ]
-
-            compiled_model = rebel.compile(
-                prefill_ir,
-                dec_ir,
-                connections=connections,
-                fusion=prefill_rbln_runtime_config.fusion,
-                npu=prefill_rbln_runtime_config.npu,
-                tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-                use_weight_sharing=True,
-            )
-            compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-
-        compile_midm()
-
-        rbln_config.save(save_dir_path)
-
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            model_save_dir=save_dir,
-            **rbln_constructor_kwargs,
-            **kwargs,
-        )
-
     @classmethod
     def _get_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_prefill_chunk_size: Optional[int] = 128,
         rbln_max_seq_len: Optional[int] = None,
         rbln_batch_size: Optional[int] = None,
+        **kwargs,
     ) -> RBLNConfig:
         meta = {}
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
 
+        prefill_chunk_size = 128
         if rbln_max_seq_len is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_max_length"):
-                    rbln_max_seq_len = tokenizer.model_max_length
-                    break
-        if rbln_max_seq_len is None:
-            raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
+            rbln_max_seq_len = getattr(model_config, "n_positions", None)
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
-        meta["rbln_prefill_chunk_size"] = rbln_prefill_chunk_size
         meta["rbln_max_seq_len"] = rbln_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size
-
-
+        meta["rbln_batch_size"] = rbln_batch_size
+        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
+
+        def get_input_info(
+            batch_size,
+            query_length,
+        ):
+            head_dim = (
+                model_config.head_dim
+                if hasattr(model_config, "head_dim")
+                else model_config.hidden_size // model_config.n_head
+            )
             input_info = [
-                ("input_ids", [
-                ("attention_mask", [
+                ("input_ids", [batch_size, query_length], "int64"),
+                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
                 (
                     "cache_position",
-                    [],
+                    [batch_size, query_length],
                     "int32",
                 ),
+                ("batch_position", [], "int16"),
             ]
+
             input_info.extend(
                 [
                     (
@@ -320,18 +126,24 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
                             rbln_batch_size,
                             model_config.n_head,
                             rbln_max_seq_len,
-
+                            head_dim,
                         ],
                         "float32",
                     )
                     for i in range(model_config.n_layer * 2)
                 ]
             )
+
             return input_info
 
-
-
-
+        prefill_input_info = get_input_info(
+            batch_size=1,
+            query_length=prefill_chunk_size,
+        )
+        dec_input_info = get_input_info(
+            batch_size=rbln_batch_size,
+            query_length=1,
+        )
 
         prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
         dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
@@ -344,83 +156,3 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
         )
 
         return rbln_config
-
-    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
-        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-        return [
-            self.compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-            self.compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
-        ]
-
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-        batch_size, cur_len = input_ids.shape
-        past_cached_length = past_key_values
-
-        if past_cached_length == 0:
-            mod_len = cur_len % self.prefill_chunk_size
-            self.pad_len = self.prefill_chunk_size - mod_len if mod_len > 0 else 0
-
-            prompt_attn_mask = torch.nn.functional.pad(attention_mask, (self.pad_len, 0), value=0)
-            self.prompt_attn_mask = prompt_attn_mask.reshape(batch_size, 1, 1, -1).contiguous()
-
-            input_ids = torch.nn.functional.pad(input_ids, (self.pad_len, 0), value=0)
-            attention_mask = self.prefill_attention_mask.clone()
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-
-            query_length = cur_len + self.pad_len
-        else:
-            attention_mask = torch.nn.functional.pad(
-                attention_mask, (self.pad_len, self.max_seq_len - cur_len - self.pad_len)
-            )
-            attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-            input_ids = input_ids[:, -1:].contiguous()
-            query_length = 1
-
-        model_inputs = {
-            "input_ids": input_ids,
-            "past_key_values": past_cached_length,
-            "attention_mask": attention_mask,
-            "cache_position": cache_position,
-            "query_length": query_length,
-        }
-
-        return model_inputs
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: int = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        query_length: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
-        past_cached_length = past_key_values
-
-        if past_cached_length is not None:
-            past_cached_length += query_length
-
-        if cache_position == 0:
-            for step in range(0, query_length, self.prefill_chunk_size):
-                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                attention_mask[:, :, :, :step] = 1
-                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-                attention_mask[:, :, :, :query_length] *= self.prompt_attn_mask
-
-                output = self.prefill_decoder(
-                    input_ids=sliced_input_ids.contiguous(),
-                    attention_mask=attention_mask,
-                    cache_position=cache_position + step,
-                )
-                cache_position += self.prefill_chunk_size
-        else:
-            output = self.decoder(
-                input_ids=input_ids.contiguous(),
-                attention_mask=attention_mask,
-                cache_position=cache_position,
-            )
-        return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_cached_length)
-
-    def __repr__(self):
-        return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
```
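
The largest removal in the modeling_midm.py diff above is the hand-rolled chunked prefill: the 0.1.4 class kept a zeroed prefill attention mask and a causal chunk mask, then walked the prompt in 128-token slices, widening the visible prefix at each step. That logic now lives in the shared decoder-only base. Below is a standalone sketch of just the mask arithmetic from the removed `__init__` and `forward`, runnable with plain PyTorch; sizes are illustrative:

```python
import torch

batch_size, prefill_chunk_size, max_seq_len = 1, 128, 512
query_length = 256  # prompt length after left-padding to a multiple of the chunk size

# Buffers kept by the removed __init__:
prefill_attention_mask = torch.zeros(
    batch_size, 1, prefill_chunk_size, max_seq_len, dtype=torch.int64
)
causal_mask = 1 - torch.triu(
    torch.ones(batch_size, 1, prefill_chunk_size, prefill_chunk_size), diagonal=1
)

# Per-chunk update from the removed forward(): everything before the current
# chunk becomes fully visible, the chunk itself is causally masked, and the
# whole prompt window is gated by the (padded) prompt mask.
prompt_attn_mask = torch.ones(batch_size, 1, 1, query_length, dtype=torch.int64)
attention_mask = prefill_attention_mask.clone()
for step in range(0, query_length, prefill_chunk_size):
    attention_mask[:, :, :, :step] = 1
    attention_mask[:, :, :, step : step + prefill_chunk_size] = causal_mask
    attention_mask[:, :, :, :query_length] *= prompt_attn_mask
    # ...here the original code invoked the prefill runtime on the
    # 128-token slice input_ids[:, step : step + prefill_chunk_size]
```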
optimum/rbln/transformers/models/whisper/modeling_whisper.py

```diff
@@ -23,13 +23,10 @@
 
 import inspect
 import logging
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
 import rebel
 import torch
-from optimum.exporters import TasksManager
 from transformers import (
     AutoModelForSpeechSeq2Seq,
     AutoProcessor,
@@ -40,10 +37,9 @@ from transformers import (
 )
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from ....modeling_base import RBLNBaseModel
+from ....modeling_base import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ....utils.save_utils import maybe_save_preprocessors
 from .whisper_architecture import (
     _WhisperDecoderWrapper,
     _WhisperEncoderWrapper,
@@ -76,10 +72,10 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=outputs)
 
 
-class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
+class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
     """
     The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
-    This model inherits from [`
+    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -96,8 +92,8 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         self.enc_max_seq_len = self.rbln_config.meta["input_max_length"]
         self.dec_max_seq_len = self.rbln_config.meta["rbln_dec_max_seq_len"]
 
-        self.encoder = RBLNRuntimeEncoder(runtime=self.runtimes[0], main_input_name="input_features")
-        self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1], main_input_name="input_ids")
+        self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
+        self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
         self.forced_decoder_ids = self.config.forced_decoder_ids
 
         # used in GenerationMixin.generate()
@@ -152,123 +148,57 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         }
 
     @classmethod
-    def _export(
-        cls,
-        model_id: str,
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        **kwargs,
-    ) -> "RBLNWhisperForConditionalGeneration":
-        """
-        Exports a vanilla Transformers model into a rbln-compiled Module.
-        """
-        task = kwargs.pop("task", None)
-        if task is None:
-            task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-        if model_save_dir is None:
-            save_dir = TemporaryDirectory()
-            save_dir_path = Path(save_dir.name)
-        else:
-            save_dir = model_save_dir
-            if isinstance(save_dir, TemporaryDirectory):
-                save_dir_path = Path(model_save_dir.name)
-            else:
-                save_dir_path = Path(model_save_dir)
-                save_dir_path.mkdir(exist_ok=True)
-
+    def update_kwargs(cls, kwargs):
         kwargs.update(
             {
                 "torchscript": True,
                 "return_dict": False,
-                "use_cache":
+                "use_cache": True,
             }
         )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        return kwargs
+
+    @classmethod
+    @torch.inference_mode()
+    def get_compiled_model(cls, model, rbln_config: RBLNConfig):
+        wrapped_encoder = _WhisperEncoderWrapper(model).eval()
+        wrapped_decoder = _WhisperDecoderWrapper(model).eval()
+
+        enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+
+        enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
+        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
+
+        enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs[0], check_trace=False)
+        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
+
+        enc_ir = rebel.torchscript_to_ir(
+            enc_scripted_model,
+            input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
+            name=enc_rbln_runtime_config.rbln_mod_name,
+        )
+        dec_ir = rebel.torchscript_to_ir(
+            dec_scripted_model,
+            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+            name=dec_rbln_runtime_config.rbln_mod_name,
         )
+        dec_ir.batch_size = dec_rbln_runtime_config.batch_size
 
-
-
-
-
-
-
-
-
-
-
-
-
-        def compile_whisper():
-            wrapped_encoder = _WhisperEncoderWrapper(model).eval()
-            wrapped_decoder = _WhisperDecoderWrapper(model).eval()
-
-            enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-            dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-            enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
-
-            enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs[0]).eval()
-            dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs).eval()
-
-            enc_ir = rebel.torchscript_to_ir(
-                enc_scripted_model,
-                input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
-                name=enc_rbln_runtime_config.rbln_mod_name,
-            )
-            dec_ir = rebel.torchscript_to_ir(
-                dec_scripted_model,
-                input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-                name=dec_rbln_runtime_config.rbln_mod_name,
-            )
-            dec_ir.batch_size = dec_rbln_runtime_config.batch_size
-
-            # Caching encoder/decoder I/O
-            connections = [
-                (enc_ir.outputs[0], dec_ir.inputs[4]),
-                (dec_ir.outputs[1], dec_ir.inputs[3]),
-            ]
-            compiled_model = rebel.compile(
-                enc_ir,
-                dec_ir,
-                connections=connections,
-                fusion=enc_rbln_runtime_config.fusion,
-                npu=enc_rbln_runtime_config.npu,
-                tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
-            )
-            compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-
-        compile_whisper()
-        rbln_config.save(save_dir_path)
-
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            model_save_dir=save_dir,
-            **rbln_constructor_kwargs,
-            **kwargs,
+        # Caching encoder/decoder I/O
+        connections = [
+            (enc_ir.outputs[0], dec_ir.inputs[4]),
+            (dec_ir.outputs[1], dec_ir.inputs[3]),
+        ]
+        compiled_model = rebel.compile(
+            enc_ir,
+            dec_ir,
+            connections=connections,
+            fusion=enc_rbln_runtime_config.fusion,
+            npu=enc_rbln_runtime_config.npu,
+            tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
         )
+        return compiled_model
 
     @classmethod
     def _get_rbln_config(
@@ -357,11 +287,14 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
 
         return rbln_config
 
-    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
+    @classmethod
+    def _create_runtimes(
+        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+    ) -> List[rebel.Runtime]:
         device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [
-            self.compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
-            self.compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
         ]
 
     def forward(
@@ -379,6 +312,3 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         lm_logits = decoder_output.logits
 
         return Seq2SeqLMOutput(logits=lm_logits)
-
-    def __repr__(self):
-        return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
```
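
With the per-model `_export` gone, compilation for Whisper is reached through the generic `RBLNModel` export path, which calls the `get_compiled_model` and `_create_runtimes` hooks shown above. A hedged usage sketch follows; the `export=True` keyword and top-level import reflect optimum's usual `from_pretrained` contract and the conventions in this diff, not a verified 0.1.8 API:

```python
from transformers import AutoProcessor
from optimum.rbln import RBLNWhisperForConditionalGeneration

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = RBLNWhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-tiny",
    export=True,  # trace, run get_compiled_model(), and save the .rbln artifact
)
# _create_runtimes() then binds the "encoder" and "decoder" runtimes from the
# single compiled model; these back model.encoder / model.decoder above.
```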