optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. optimum/rbln/__init__.py +21 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  5. optimum/rbln/diffusers/models/controlnet.py +3 -0
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  16. optimum/rbln/modeling_alias.py +14 -0
  17. optimum/rbln/modeling_base.py +282 -100
  18. optimum/rbln/modeling_seq2seq.py +58 -132
  19. optimum/rbln/transformers/__init__.py +8 -0
  20. optimum/rbln/transformers/cache_utils.py +111 -0
  21. optimum/rbln/transformers/generation/utils.py +0 -2
  22. optimum/rbln/transformers/models/__init__.py +3 -0
  23. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  24. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  25. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  26. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  27. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  28. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  29. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  30. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  31. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  32. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  33. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
  34. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
  37. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  38. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  42. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  43. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  44. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  45. optimum/rbln/utils/__init__.py +1 -1
  46. optimum/rbln/utils/import_utils.py +46 -0
  47. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
  48. optimum_rbln-0.1.8.dist-info/RECORD +73 -0
  49. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
  51. optimum_rbln-0.1.4.dist-info/RECORD +0 -63
  52. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
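The dominant change in this release is a refactor of the decoder-only language models: a new shared base package, optimum/rbln/transformers/models/decoderonly (+515 lines of architecture code and +349 lines of modeling code), absorbs the compilation, runtime, and generation plumbing that llama, gpt2, and midm each duplicated in 0.1.4, which is why those files shrink by several hundred lines apiece. The release also adds DPT, Gemma, and XLM-RoBERTa support. As a hedged sketch of the user-facing flow (checkpoint id and kwarg values are illustrative, not taken from this diff), loading one of the new classes follows the usual optimum from_pretrained(..., export=True) convention, with rbln_-prefixed kwargs routed to _get_rbln_config:

    from optimum.rbln import RBLNGemmaForCausalLM  # new in 0.1.8 per the file list

    # export=True compiles the PyTorch checkpoint into a .rbln artifact on load.
    model = RBLNGemmaForCausalLM.from_pretrained(
        "google/gemma-2b",      # illustrative checkpoint id
        export=True,
        rbln_max_seq_len=8192,  # illustrative; defaults come from the model config
        rbln_batch_size=1,
    )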
optimum/rbln/transformers/models/midm/modeling_midm.py

@@ -23,21 +23,12 @@
 
 import inspect
 import logging
-from pathlib import Path
-from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
-import rebel
-import torch
-from optimum.exporters import TasksManager
-from transformers import AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
-from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+from transformers import PretrainedConfig, PreTrainedModel
 
-from ....modeling_base import RBLNBaseModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
-from ....utils.runtime_utils import RBLNPytorchRuntime
-from ....utils.save_utils import maybe_save_preprocessors
-from ...generation.utils import RBLNGenerationMixin
+from ....modeling_config import RBLNConfig, RBLNRuntimeConfig
+from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
 from .hf_hub_cached.modeling_midm import MidmLMHeadModel
 from .midm_architecture import (
     MidmLMHeadModelWrapper,
@@ -45,7 +36,6 @@ from .midm_architecture import (
 
 
 logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
     from transformers import (
         AutoFeatureExtractor,
@@ -55,31 +45,12 @@ if TYPE_CHECKING:
     )
 
 
-class RBLNRuntimeDecoder(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
-
-    # RBLN_Runtimemodule
-    def forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        attention_mask: torch.LongTensor = None,
-        cache_position: torch.Tensor = None,
-        **kwargs: Dict[str, Any],
-    ):
-        logits = super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            cache_position=cache_position,
-        )
-        return logits
-
-
-class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
+class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
     """
     The Midm Model transformer with a language modeling head on top (linear layer with weights tied to the input
     embeddings).
 
-    This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the
+    This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
     library implements for all its model.
 
     It implements the methods to convert a pre-trained transformers Midm model into a RBLN transformer model by:
@@ -88,46 +59,9 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
 
     """
 
-    model_type = "rbln_model"
-    auto_model_class = AutoModelForCausalLM
-    main_input_name = "input_ids"
-
-    def __init__(
-        self,
-        models: List[Union[PreTrainedModel, rebel.RBLNCompiledModel]],
-        config: PretrainedConfig = None,
-        preprocessors: Optional[List] = None,
-        rbln_config: Optional[RBLNConfig] = None,
-        rbln_device: Optional[List[int]] = None,
-        rbln_device_map: Optional[Dict[str, int]] = None,
-        **kwargs,
-    ):
-        super().__init__(
-            models,
-            config,
-            preprocessors,
-            rbln_config,
-            rbln_device=rbln_device,
-            rbln_device_map=rbln_device_map,
-            **kwargs,
-        )
-        self.batch_size = self.rbln_config.meta["rbln_batch_size"]
-        self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
-        self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
-
-        self.prefill_attention_mask = torch.zeros(
-            self.batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
-        )
-        self.causal_mask = 1 - torch.triu(
-            torch.ones(self.batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
-        )
-
-        self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.runtimes[0], main_input_name="input_ids")
-        self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1], main_input_name="input_ids")
-        self.past_cached_length = 0
-
-    def can_generate(self):
-        return True
+    @classmethod
+    def wrapping_torch_model(self, model: "PreTrainedModel", rbln_max_seq_len: int):
+        return MidmLMHeadModelWrapper(model, rbln_max_seq_len).eval()
 
     def __getattr__(self, __name: str) -> Any:
         """This is the key method to implement RBLN-Midm.
@@ -144,174 +78,46 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
             return redirect(val)
         return val
 
-    def _reorder_cache(self, past_key_values, beam_idx):
-        # TODO(jongho): implement
-        raise NotImplementedError
-
-    @classmethod
-    def _export(
-        cls,
-        model_id: str,
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        **kwargs,
-    ) -> "RBLNMidmLMHeadModel":
-
-        task = kwargs.pop("task", None)
-        if task is None:
-            task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-        if model_save_dir is None:
-            save_dir = TemporaryDirectory()
-            save_dir_path = Path(save_dir.name)
-        else:
-            save_dir = model_save_dir
-            if isinstance(save_dir, TemporaryDirectory):
-                save_dir_path = Path(model_save_dir.name)
-            else:
-                save_dir_path = Path(model_save_dir)
-                save_dir_path.mkdir(exist_ok=True)
-
-        def update_configs(kwargs):
-            max_seq_len = kwargs.get("rbln_max_seq_len", None)
-            if max_seq_len is not None:
-                kwargs.update({"max_position_embeddings": max_seq_len})
-
-            kwargs.update(
-                {
-                    "torchscript": True,
-                    "return_dict": False,
-                    "use_cache": True,
-                    "torch_dtype": torch.float32,
-                    "_attn_implementation": "eager",
-                }
-            )
-
-            return kwargs
-
-        kwargs = update_configs(kwargs)
-
-        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-        model: MidmLMHeadModel = TasksManager.get_model_from_task(
-            task=task,
-            model_name_or_path=model_id,
-            subfolder=subfolder,
-            revision=revision,
-            framework="pt",
-            cache_dir=cache_dir,
-            use_auth_token=use_auth_token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            ignore_mismatched_sizes=True,
-            **kwargs,
-        )
-
-        if config is None:
-            config = model.config
-
-        config.save_pretrained(save_dir_path)
-        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-        # Get compilation arguments
-        if rbln_config_kwargs.get("rbln_config", None) is None:
-            rbln_config = cls.get_rbln_config(
-                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-            )
-
-        def compile_midm():
-            wrapped_decoder = MidmLMHeadModelWrapper(model).eval()
-            prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-            dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-            prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
-
-            prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs)
-            dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs)
-
-            prefill_ir = rebel.torchscript_to_ir(
-                prefill_scripted_model,
-                input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
-            )
-            dec_ir = rebel.torchscript_to_ir(
-                dec_scripted_model,
-                input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-            )
-
-            connections = [
-                (prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)
-            ]
-
-            compiled_model = rebel.compile(
-                prefill_ir,
-                dec_ir,
-                connections=connections,
-                fusion=prefill_rbln_runtime_config.fusion,
-                npu=prefill_rbln_runtime_config.npu,
-                tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
-                use_weight_sharing=True,
-            )
-            compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-
-        compile_midm()
-
-        rbln_config.save(save_dir_path)
-
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            model_save_dir=save_dir,
-            **rbln_constructor_kwargs,
-            **kwargs,
-        )
-
     @classmethod
     def _get_rbln_config(
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-        rbln_prefill_chunk_size: Optional[int] = 128,
         rbln_max_seq_len: Optional[int] = None,
         rbln_batch_size: Optional[int] = None,
+        **kwargs,
     ) -> RBLNConfig:
         meta = {}
-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
 
+        prefill_chunk_size = 128
         if rbln_max_seq_len is None:
-            for tokenizer in preprocessors:
-                if hasattr(tokenizer, "model_max_length"):
-                    rbln_max_seq_len = tokenizer.model_max_length
-                    break
-            if rbln_max_seq_len is None:
-                raise ValueError("`rbln_max_seq_len` should be specified!")
-
-        if rbln_batch_size is None:
-            rbln_batch_size = 1
+            rbln_max_seq_len = getattr(model_config, "n_positions", None)
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
 
-        meta["rbln_prefill_chunk_size"] = rbln_prefill_chunk_size
         meta["rbln_max_seq_len"] = rbln_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size if rbln_batch_size is not None else 1
-
-        def get_input_info(query_length):
+        meta["rbln_batch_size"] = rbln_batch_size
+        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
+
+        def get_input_info(
+            batch_size,
+            query_length,
+        ):
+            head_dim = (
+                model_config.head_dim
+                if hasattr(model_config, "head_dim")
+                else model_config.hidden_size // model_config.n_head
+            )
             input_info = [
-                ("input_ids", [rbln_batch_size, query_length], "int64"),
-                ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+                ("input_ids", [batch_size, query_length], "int64"),
+                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "int64"),
                 (
                     "cache_position",
-                    [],
+                    [batch_size, query_length],
                     "int32",
                 ),
+                ("batch_position", [], "int16"),
             ]
+
             input_info.extend(
                 [
                     (
@@ -320,18 +126,24 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
                             rbln_batch_size,
                             model_config.n_head,
                             rbln_max_seq_len,
-                            model_config.hidden_size // model_config.n_head,
+                            head_dim,
                         ],
                         "float32",
                     )
                     for i in range(model_config.n_layer * 2)
                 ]
            )
+
            return input_info
 
-        # model input info
-        prefill_input_info = get_input_info(query_length=rbln_prefill_chunk_size)
-        dec_input_info = get_input_info(query_length=1)
+        prefill_input_info = get_input_info(
+            batch_size=1,
+            query_length=prefill_chunk_size,
+        )
+        dec_input_info = get_input_info(
+            batch_size=rbln_batch_size,
+            query_length=1,
+        )
 
         prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
         dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
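The reworked input specs encode the prefill/decode split: the prefill graph is compiled for batch size 1 with a 128-token query window, while the decode graph takes the full batch one token at a time; cache_position grows from a scalar to a [batch, query] tensor, and a new scalar int16 batch_position selects which KV-cache slot a prefill writes into (the per-layer past_key_values entries keep the full rbln_batch_size in both graphs). For illustration, with rbln_max_seq_len=1024 and rbln_batch_size=2, and omitting the past_key_values entries, get_input_info above would produce:

    prefill_input_info = [
        ("input_ids",      [1, 128],          "int64"),
        ("attention_mask", [1, 1, 128, 1024], "int64"),
        ("cache_position", [1, 128],          "int32"),
        ("batch_position", [],                "int16"),  # KV-cache slot index
    ]
    dec_input_info = [
        ("input_ids",      [2, 1],            "int64"),
        ("attention_mask", [2, 1, 1, 1024],   "int64"),
        ("cache_position", [2, 1],            "int32"),
        ("batch_position", [],                "int16"),
    ]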
@@ -344,83 +156,3 @@ class RBLNMidmLMHeadModel(RBLNBaseModel, RBLNGenerationMixin):
         )
 
         return rbln_config
-
-    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
-        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-        return [
-            self.compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
-            self.compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
-        ]
-
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
-        batch_size, cur_len = input_ids.shape
-        past_cached_length = past_key_values
-
-        if past_cached_length == 0:
-            mod_len = cur_len % self.prefill_chunk_size
-            self.pad_len = self.prefill_chunk_size - mod_len if mod_len > 0 else 0
-
-            prompt_attn_mask = torch.nn.functional.pad(attention_mask, (self.pad_len, 0), value=0)
-            self.prompt_attn_mask = prompt_attn_mask.reshape(batch_size, 1, 1, -1).contiguous()
-
-            input_ids = torch.nn.functional.pad(input_ids, (self.pad_len, 0), value=0)
-            attention_mask = self.prefill_attention_mask.clone()
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-
-            query_length = cur_len + self.pad_len
-        else:
-            attention_mask = torch.nn.functional.pad(
-                attention_mask, (self.pad_len, self.max_seq_len - cur_len - self.pad_len)
-            )
-            attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
-            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
-            input_ids = input_ids[:, -1:].contiguous()
-            query_length = 1
-
-        model_inputs = {
-            "input_ids": input_ids,
-            "past_key_values": past_cached_length,
-            "attention_mask": attention_mask,
-            "cache_position": cache_position,
-            "query_length": query_length,
-        }
-
-        return model_inputs
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        past_key_values: int = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.Tensor] = None,
-        query_length: Optional[torch.Tensor] = None,
-        **kwargs,
-    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
-        past_cached_length = past_key_values
-
-        if past_cached_length is not None:
-            past_cached_length += query_length
-
-        if cache_position == 0:
-            for step in range(0, query_length, self.prefill_chunk_size):
-                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
-                attention_mask[:, :, :, :step] = 1
-                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
-                attention_mask[:, :, :, :query_length] *= self.prompt_attn_mask
-
-                output = self.prefill_decoder(
-                    input_ids=sliced_input_ids.contiguous(),
-                    attention_mask=attention_mask,
-                    cache_position=cache_position + step,
-                )
-                cache_position += self.prefill_chunk_size
-        else:
-            output = self.decoder(
-                input_ids=input_ids.contiguous(),
-                attention_mask=attention_mask,
-                cache_position=cache_position,
-            )
-        return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_cached_length)
-
-    def __repr__(self):
-        return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
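Everything deleted here (runtime creation, prepare_inputs_for_generation, the chunked-prefill forward, __repr__) now lives in the shared decoder-only base. For reference, the core of the removed forward, reduced to an illustrative sketch (prefill_fn stands in for the prefill runtime call):

    def chunked_prefill(input_ids, prefill_fn, chunk_size=128):
        # The prompt is padded to a multiple of chunk_size and fed through the
        # prefill graph one window at a time; each call appends that window's
        # keys/values to the device-resident KV cache.
        logits = None
        for step in range(0, input_ids.shape[1], chunk_size):
            logits = prefill_fn(input_ids[:, step : step + chunk_size], cache_position=step)
        return logits  # logits of the last window seed the first decode step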
optimum/rbln/transformers/models/whisper/modeling_whisper.py

@@ -23,13 +23,10 @@
 
 import inspect
 import logging
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
 import rebel
 import torch
-from optimum.exporters import TasksManager
 from transformers import (
     AutoModelForSpeechSeq2Seq,
     AutoProcessor,
@@ -40,10 +37,9 @@ from transformers import (
 )
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from ....modeling_base import RBLNBaseModel
+from ....modeling_base import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ....utils.save_utils import maybe_save_preprocessors
 from .whisper_architecture import (
     _WhisperDecoderWrapper,
     _WhisperEncoderWrapper,
@@ -76,10 +72,10 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=outputs)
 
 
-class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
+class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
     """
     The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
-    This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -96,8 +92,8 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         self.enc_max_seq_len = self.rbln_config.meta["input_max_length"]
         self.dec_max_seq_len = self.rbln_config.meta["rbln_dec_max_seq_len"]
 
-        self.encoder = RBLNRuntimeEncoder(runtime=self.runtimes[0], main_input_name="input_features")
-        self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1], main_input_name="input_ids")
+        self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
+        self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
         self.forced_decoder_ids = self.config.forced_decoder_ids
 
         # used in GenerationMixin.generate()
@@ -152,123 +148,57 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         }
 
     @classmethod
-    def _export(
-        cls,
-        model_id: str,
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        **kwargs,
-    ) -> "RBLNWhisperForConditionalGeneration":
-        """
-        Exports a vanilla Transformers model into a rbln-compiled Module.
-        """
-        task = kwargs.pop("task", None)
-        if task is None:
-            task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-        if model_save_dir is None:
-            save_dir = TemporaryDirectory()
-            save_dir_path = Path(save_dir.name)
-        else:
-            save_dir = model_save_dir
-            if isinstance(save_dir, TemporaryDirectory):
-                save_dir_path = Path(model_save_dir.name)
-            else:
-                save_dir_path = Path(model_save_dir)
-                save_dir_path.mkdir(exist_ok=True)
-
+    def update_kwargs(cls, kwargs):
         kwargs.update(
             {
                 "torchscript": True,
                 "return_dict": False,
-                "use_cache": False,
+                "use_cache": True,
             }
         )
-        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-        model: WhisperForConditionalGeneration = TasksManager.get_model_from_task(
-            task=task,
-            model_name_or_path=model_id,
-            subfolder=subfolder,
-            revision=revision,
-            framework="pt",
-            cache_dir=cache_dir,
-            use_auth_token=use_auth_token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
+        return kwargs
+
+    @classmethod
+    @torch.inference_mode()
+    def get_compiled_model(cls, model, rbln_config: RBLNConfig):
+        wrapped_encoder = _WhisperEncoderWrapper(model).eval()
+        wrapped_decoder = _WhisperDecoderWrapper(model).eval()
+
+        enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+
+        enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
+        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
+
+        enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs[0], check_trace=False)
+        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
+
+        enc_ir = rebel.torchscript_to_ir(
+            enc_scripted_model,
+            input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
+            name=enc_rbln_runtime_config.rbln_mod_name,
+        )
+        dec_ir = rebel.torchscript_to_ir(
+            dec_scripted_model,
+            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+            name=dec_rbln_runtime_config.rbln_mod_name,
         )
+        dec_ir.batch_size = dec_rbln_runtime_config.batch_size
 
-        if config is None:
-            config = model.config
-
-        config.save_pretrained(save_dir_path)
-        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-        # Get compilation arguments
-        if rbln_config_kwargs.get("rbln_config", None) is None:
-            rbln_config = cls.get_rbln_config(
-                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-            )
-
-        def compile_whisper():
-            wrapped_encoder = _WhisperEncoderWrapper(model).eval()
-            wrapped_decoder = _WhisperDecoderWrapper(model).eval()
-
-            enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-            dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-            enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
-
-            enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs[0]).eval()
-            dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs).eval()
-
-            enc_ir = rebel.torchscript_to_ir(
-                enc_scripted_model,
-                input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
-                name=enc_rbln_runtime_config.rbln_mod_name,
-            )
-            dec_ir = rebel.torchscript_to_ir(
-                dec_scripted_model,
-                input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-                name=dec_rbln_runtime_config.rbln_mod_name,
-            )
-            dec_ir.batch_size = dec_rbln_runtime_config.batch_size
-
-            # Caching encoder/decoder I/O
-            connections = [
-                (enc_ir.outputs[0], dec_ir.inputs[4]),
-                (dec_ir.outputs[1], dec_ir.inputs[3]),
-            ]
-            compiled_model = rebel.compile(
-                enc_ir,
-                dec_ir,
-                connections=connections,
-                fusion=enc_rbln_runtime_config.fusion,
-                npu=enc_rbln_runtime_config.npu,
-                tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
-            )
-            compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-
-        compile_whisper()
-        rbln_config.save(save_dir_path)
-
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            model_save_dir=save_dir,
-            **rbln_constructor_kwargs,
-            **kwargs,
+        # Caching encoder/decoder I/O
+        connections = [
+            (enc_ir.outputs[0], dec_ir.inputs[4]),
+            (dec_ir.outputs[1], dec_ir.inputs[3]),
+        ]
+        compiled_model = rebel.compile(
+            enc_ir,
+            dec_ir,
+            connections=connections,
+            fusion=enc_rbln_runtime_config.fusion,
+            npu=enc_rbln_runtime_config.npu,
+            tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
        )
+        return compiled_model
 
     @classmethod
     def _get_rbln_config(
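The _export monolith is gone: RBLNModel now owns the download/save/compile orchestration, and Whisper only customizes the HF config kwargs (update_kwargs; note use_cache flips to True) and the compilation itself (get_compiled_model). The connections list keeps the cross-attention cache (encoder output 0 → decoder input 4) and the decoder's self-attention cache (decoder output 1 → decoder input 3) resident on the device rather than round-tripping through the host. Both sub-graphs land in one artifact and are addressed by name at load time, roughly as in this hedged sketch (the loader constructor is an assumption about the rebel SDK; the create_runtime calls are verbatim from _create_runtimes below):

    import rebel

    compiled = rebel.RBLNCompiledModel("model.rbln")  # assumed loader API; path illustrative
    encoder_runtime = compiled.create_runtime("encoder", tensor_type="pt", device=0)
    decoder_runtime = compiled.create_runtime("decoder", tensor_type="pt", device=0)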
@@ -357,11 +287,14 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
 
         return rbln_config
 
-    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
+    @classmethod
+    def _create_runtimes(
+        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+    ) -> List[rebel.Runtime]:
         device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [
-            self.compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
-            self.compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
         ]
 
     def forward(
@@ -379,6 +312,3 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         lm_logits = decoder_output.logits
 
         return Seq2SeqLMOutput(logits=lm_logits)
-
-    def __repr__(self):
-        return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
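End to end, the class still speaks the GenerationMixin interface, so transcription is unchanged for users. A hedged usage sketch (checkpoint id is illustrative; audio is a 1-D float waveform you supply):

    from transformers import AutoProcessor
    from optimum.rbln import RBLNWhisperForConditionalGeneration

    processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
    model = RBLNWhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-tiny",
        export=True,  # trace, compile, and save the .rbln artifact on load
    )
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt")
    generated_ids = model.generate(inputs.input_features)
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]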