optimum-rbln 0.1.1__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff compares the contents of two package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as published.
Files changed (37)
  1. optimum/rbln/__init__.py +9 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  4. optimum/rbln/diffusers/models/unet_2d_condition.py +1 -1
  5. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +9 -11
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +8 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -0
  8. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  9. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  10. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  12. optimum/rbln/modeling_base.py +175 -103
  13. optimum/rbln/modeling_seq2seq.py +58 -132
  14. optimum/rbln/transformers/__init__.py +4 -0
  15. optimum/rbln/transformers/models/__init__.py +2 -0
  16. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  17. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  18. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  19. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +24 -33
  20. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +52 -124
  21. optimum/rbln/transformers/models/llama/llama_architecture.py +62 -33
  22. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +764 -0
  23. optimum/rbln/transformers/models/llama/modeling_llama.py +208 -140
  24. optimum/rbln/transformers/models/midm/__init__.py +32 -0
  25. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +22 -0
  26. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +303 -0
  27. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +1473 -0
  28. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +98 -0
  29. optimum/rbln/transformers/models/midm/midm_architecture.py +506 -0
  30. optimum/rbln/transformers/models/midm/modeling_midm.py +390 -0
  31. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  32. optimum/rbln/utils/__init__.py +1 -1
  33. optimum/rbln/utils/import_utils.py +46 -0
  34. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/METADATA +17 -50
  35. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/RECORD +37 -27
  36. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/WHEEL +1 -1
  37. {optimum_rbln-0.1.1.dist-info → optimum_rbln-0.1.7.dist-info}/licenses/LICENSE +0 -0
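
Note on the release: the headline changes are new model families (DPT, KT's Mi:dm with its cached remote code under `hf_hub_cached/`, and what appears to be continuous-batching Llama support in `llama_architecture_cb.py`), plus a refactor of `modeling_base.py` that replaces per-model `_export` methods with shared `update_kwargs`/`get_compiled_model`/`_create_runtimes` hooks. For orientation, a rough usage sketch of a class this release adds; the checkpoint id and the `export=True` flag follow the usual optimum conventions and are assumptions, not taken from this diff:

    # Hypothetical usage sketch (not part of the diff). The rbln_* kwargs
    # mirror the _get_rbln_config() signature shown below; export=True is
    # assumed to be the optimum-style "compile from PyTorch weights" switch.
    from optimum.rbln import RBLNMidmLMHeadModel

    model = RBLNMidmLMHeadModel.from_pretrained(
        "KT-AI/midm-bitext-S-7B-inst-v1",  # example checkpoint id (assumption)
        export=True,             # trace, lower to IR, and compile to a .rbln artifact
        trust_remote_code=True,  # Mi:dm ships custom modeling code
        rbln_max_seq_len=4096,   # static sequence length to compile for
        rbln_batch_size=1,       # static batch size
    )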
optimum/rbln/transformers/models/midm/modeling_midm.py (new file)
@@ -0,0 +1,390 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
+import logging
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+
+import rebel
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+
+from ....modeling_base import RBLNModel
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+from ....utils.runtime_utils import RBLNPytorchRuntime
+from ...generation.utils import RBLNGenerationMixin
+from .hf_hub_cached.modeling_midm import MidmLMHeadModel
+from .midm_architecture import (
+    MidmLMHeadModelWrapper,
+)
+
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        PretrainedConfig,
+    )
+
+
+class RBLNRuntimeDecoder(RBLNPytorchRuntime):
+    mandatory_members = ["main_input_name"]
+
+    # RBLN Runtime module
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: torch.LongTensor = None,
+        cache_position: torch.Tensor = None,
+        **kwargs: Dict[str, Any],
+    ):
+        logits = super().forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            cache_position=cache_position,
+        )
+        return logits
+
+
+class RBLNMidmLMHeadModel(RBLNModel, RBLNGenerationMixin):
+    """
+    The Midm Model transformer with a language modeling head on top (linear layer with weights tied to the input
+    embeddings).
+
+    This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its models.
+
+    It implements the methods to convert a pre-trained transformers Midm model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+    """
+
+    model_type = "rbln_model"
+    auto_model_class = AutoModelForCausalLM
+    main_input_name = "input_ids"
+
+    def __init__(
+        self,
+        models: List[Union[PreTrainedModel, rebel.RBLNCompiledModel]],
+        config: PretrainedConfig = None,
+        preprocessors: Optional[List] = None,
+        rbln_config: Optional[RBLNConfig] = None,
+        rbln_device: Optional[List[int]] = None,
+        rbln_device_map: Optional[Dict[str, int]] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            models,
+            config,
+            preprocessors,
+            rbln_config,
+            rbln_device=rbln_device,
+            rbln_device_map=rbln_device_map,
+            **kwargs,
+        )
+        self.batch_size = self.rbln_config.meta["rbln_batch_size"]
+        self.prefill_chunk_size = self.rbln_config.meta["rbln_prefill_chunk_size"]
+        self.max_seq_len = self.rbln_config.meta["rbln_max_seq_len"]
+
+        self.prefill_attention_mask = torch.zeros(
+            self.batch_size, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.int64
+        )
+        self.causal_mask = 1 - torch.triu(
+            torch.ones(self.batch_size, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
+        )
+
+        self.prefill_decoder = RBLNRuntimeDecoder(runtime=self.model[0], main_input_name="input_ids")
+        self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
+        self.past_cached_length = 0
+
+    def can_generate(self):
+        return True
+
+    def __getattr__(self, __name: str) -> Any:
+        """This is the key method to implement RBLN-Midm.
+
+        Returns:
+            Any: Midm's corresponding method
+        """
+
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(MidmLMHeadModel, __name)
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+        return val
+
+    def _reorder_cache(self, past_key_values, beam_idx):
+        # TODO(jongho): implement
+        raise NotImplementedError
+
+    @classmethod
+    @torch.inference_mode()
+    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
+        wrapped_decoder = MidmLMHeadModelWrapper(model).eval()
+        prefill_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+
+        prefill_example_inputs = prefill_rbln_runtime_config.get_dummy_inputs(fill=0)
+        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
+
+        prefill_scripted_model = torch.jit.trace(wrapped_decoder, prefill_example_inputs, check_trace=False)
+        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
+
+        prefill_ir = rebel.torchscript_to_ir(
+            prefill_scripted_model,
+            input_names=[v[0] for v in prefill_rbln_runtime_config.input_info],
+        )
+        dec_ir = rebel.torchscript_to_ir(
+            dec_scripted_model,
+            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+        )
+
+        connections = [(prefill_ir.outputs[1 + i], prefill_ir.inputs[3 + i]) for i in range(model.config.n_layer * 2)]
+
+        compiled_model = rebel.compile(
+            prefill_ir,
+            dec_ir,
+            connections=connections,
+            fusion=prefill_rbln_runtime_config.fusion,
+            npu=prefill_rbln_runtime_config.npu,
+            tensor_parallel_size=prefill_rbln_runtime_config.tensor_parallel_size,
+            use_weight_sharing=True,
+        )
+        return compiled_model
+
+    @classmethod
+    def update_kwargs(cls, kwargs):
+        """
+        Update user-given kwargs to get a proper pytorch model.
+
+        For example, `torchscript=True` should be set because torch.jit
+        does not support `transformers` output instances as module output.
+        """
+        kwargs.update(
+            {
+                "torchscript": True,
+                "return_dict": False,
+                "use_cache": True,
+                "torch_dtype": torch.float32,
+                "_attn_implementation": "eager",
+            }
+        )
+        return kwargs
+
+    @classmethod
+    def get_pytorch_model(
+        cls,
+        model_id: str,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+        trust_remote_code: bool = False,
+        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
+        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> PreTrainedModel:
+        if rbln_max_seq_len := rbln_config_kwargs.get("rbln_max_seq_len", None):
+            config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
+            if hf_position_embedding := getattr(config, "max_position_embeddings", None):
+                if hf_position_embedding < rbln_max_seq_len:
+                    logger.warning(
+                        f"`rbln_max_seq_len` is larger than the original config ({hf_position_embedding}). "
+                        "This may lead to incorrect inferences of the model."
+                    )
+                    kwargs.update({"max_position_embeddings": rbln_max_seq_len})
+
+        return super().get_pytorch_model(
+            model_id=model_id,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            force_download=force_download,
+            cache_dir=cache_dir,
+            subfolder=subfolder,
+            local_files_only=local_files_only,
+            trust_remote_code=trust_remote_code,
+            rbln_config_kwargs=rbln_config_kwargs,
+            rbln_constructor_kwargs=rbln_constructor_kwargs,
+            ignore_mismatched_sizes=True,
+            **kwargs,
+        )
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model_config: "PretrainedConfig",
+        rbln_prefill_chunk_size: Optional[int] = 128,
+        rbln_max_seq_len: Optional[int] = None,
+        rbln_batch_size: Optional[int] = None,
+    ) -> RBLNConfig:
+        meta = {}
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
+
+        if rbln_max_seq_len is None:
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_max_length"):
+                    rbln_max_seq_len = tokenizer.model_max_length
+                    break
+            if rbln_max_seq_len is None:
+                raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+
+        meta["rbln_prefill_chunk_size"] = rbln_prefill_chunk_size
+        meta["rbln_max_seq_len"] = rbln_max_seq_len
+        meta["rbln_batch_size"] = rbln_batch_size if rbln_batch_size is not None else 1
+
+        def get_input_info(query_length):
+            input_info = [
+                ("input_ids", [rbln_batch_size, query_length], "int64"),
+                ("attention_mask", [rbln_batch_size, 1, query_length, rbln_max_seq_len], "int64"),
+                (
+                    "cache_position",
+                    [],
+                    "int32",
+                ),
+            ]
+            input_info.extend(
+                [
+                    (
+                        f"past_key_values_{i}",
+                        [
+                            rbln_batch_size,
+                            model_config.n_head,
+                            rbln_max_seq_len,
+                            model_config.hidden_size // model_config.n_head,
+                        ],
+                        "float32",
+                    )
+                    for i in range(model_config.n_layer * 2)
+                ]
+            )
+            return input_info
+
+        # model input info
+        prefill_input_info = get_input_info(query_length=rbln_prefill_chunk_size)
+        dec_input_info = get_input_info(query_length=1)
+
+        prefill_rbln_runtime_config = RBLNRuntimeConfig(input_info=prefill_input_info)
+        dec_rbln_runtime_config = RBLNRuntimeConfig(input_info=dec_input_info)
+
+        dec_rbln_runtime_config.batch_size = rbln_batch_size
+
+        rbln_config = RBLNConfig.from_rbln_runtime_configs(
+            [prefill_rbln_runtime_config, dec_rbln_runtime_config],
+            _rbln_meta=meta,
+        )
+
+        return rbln_config
+
+    @classmethod
+    def _create_runtimes(
+        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+    ) -> List[rebel.Runtime]:
+        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
+        return [
+            compiled_models[0].create_runtime(input_info_index=0, tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime(input_info_index=1, tensor_type="pt", device=device_val),
+        ]
+
+    def prepare_inputs_for_generation(self, input_ids, past_key_values=0, attention_mask=None, **kwargs):
+        batch_size, cur_len = input_ids.shape
+        past_cached_length = past_key_values
+
+        if past_cached_length == 0:
+            mod_len = cur_len % self.prefill_chunk_size
+            self.pad_len = self.prefill_chunk_size - mod_len if mod_len > 0 else 0
+
+            prompt_attn_mask = torch.nn.functional.pad(attention_mask, (self.pad_len, 0), value=0)
+            self.prompt_attn_mask = prompt_attn_mask.reshape(batch_size, 1, 1, -1).contiguous()
+
+            input_ids = torch.nn.functional.pad(input_ids, (self.pad_len, 0), value=0)
+            attention_mask = self.prefill_attention_mask.clone()
+            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
+
+            query_length = cur_len + self.pad_len
+        else:
+            attention_mask = torch.nn.functional.pad(
+                attention_mask, (self.pad_len, self.max_seq_len - cur_len - self.pad_len)
+            )
+            attention_mask = attention_mask.reshape(batch_size, 1, 1, -1).contiguous()
+            cache_position = torch.tensor(past_cached_length, dtype=torch.int32)
+            input_ids = input_ids[:, -1:].contiguous()
+            query_length = 1
+
+        model_inputs = {
+            "input_ids": input_ids,
+            "past_key_values": past_cached_length,
+            "attention_mask": attention_mask,
+            "cache_position": cache_position,
+            "query_length": query_length,
+        }
+
+        return model_inputs
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        past_key_values: int = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        query_length: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+        past_cached_length = past_key_values
+
+        if past_cached_length is not None:
+            past_cached_length += query_length
+
+        if cache_position == 0:
+            for step in range(0, query_length, self.prefill_chunk_size):
+                sliced_input_ids = input_ids[:, step : step + self.prefill_chunk_size]
+                attention_mask[:, :, :, :step] = 1
+                attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+                attention_mask[:, :, :, :query_length] *= self.prompt_attn_mask
+
+                output = self.prefill_decoder(
+                    input_ids=sliced_input_ids.contiguous(),
+                    attention_mask=attention_mask,
+                    cache_position=cache_position + step,
+                )
+                cache_position += self.prefill_chunk_size
+        else:
+            output = self.decoder(
+                input_ids=input_ids.contiguous(),
+                attention_mask=attention_mask,
+                cache_position=cache_position,
+            )
+        return CausalLMOutputWithCrossAttentions(logits=output, past_key_values=past_cached_length)
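
The `forward` above runs against two compiled runtimes that share one KV cache: a prefill phase that left-pads the prompt to a multiple of `rbln_prefill_chunk_size` (default 128) and consumes it chunk by chunk, and a decode phase that feeds one token at a time. A minimal CPU-only sketch of the mask bookkeeping inside the prefill loop, with toy sizes chosen for readability (illustration only, not the shipped code):

    import torch

    # Toy sizes standing in for batch_size/prefill_chunk_size/max_seq_len above.
    batch, chunk, max_seq_len = 1, 4, 12
    mask = torch.zeros(batch, 1, chunk, max_seq_len, dtype=torch.int64)
    causal = 1 - torch.triu(torch.ones(batch, 1, chunk, chunk), diagonal=1)

    query_length = 8  # prompt length after left-padding to a chunk multiple
    for step in range(0, query_length, chunk):
        mask[:, :, :, :step] = 1                     # earlier chunks: fully visible (already cached)
        mask[:, :, :, step : step + chunk] = causal  # current chunk: causal lower triangle
        # the real loop also multiplies in prompt_attn_mask to hide the left-padding,
        # then calls the prefill runtime with cache_position = step

Each prefill call writes its chunk of keys and values into the cache slots wired up by `connections` in `get_compiled_model`, so the single-token decoder can attend to the full prompt afterwards.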
optimum/rbln/transformers/models/whisper/modeling_whisper.py
@@ -23,13 +23,10 @@
 
 import inspect
 import logging
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
 import rebel
 import torch
-from optimum.exporters import TasksManager
 from transformers import (
     AutoModelForSpeechSeq2Seq,
     AutoProcessor,
@@ -40,10 +37,9 @@ from transformers import (
 )
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
 
-from ....modeling_base import RBLNBaseModel
+from ....modeling_base import RBLNModel
 from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ....utils.save_utils import maybe_save_preprocessors
 from .whisper_architecture import (
     _WhisperDecoderWrapper,
     _WhisperEncoderWrapper,
@@ -76,10 +72,10 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
         return Seq2SeqLMOutput(logits=outputs)
 
 
-class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
+class RBLNWhisperForConditionalGeneration(RBLNModel, GenerationMixin):
     """
     The Whisper Model with a language modeling head. Can be used for automatic speech recognition.
-    This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+    This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
     A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
     It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -96,8 +92,8 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         self.enc_max_seq_len = self.rbln_config.meta["input_max_length"]
         self.dec_max_seq_len = self.rbln_config.meta["rbln_dec_max_seq_len"]
 
-        self.encoder = RBLNRuntimeEncoder(runtime=self.runtimes[0], main_input_name="input_features")
-        self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1], main_input_name="input_ids")
+        self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
+        self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
         self.forced_decoder_ids = self.config.forced_decoder_ids
 
         # used in GenerationMixin.generate()
@@ -152,123 +148,57 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         }
 
     @classmethod
-    def _export(
-        cls,
-        model_id: str,
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        **kwargs,
-    ) -> "RBLNWhisperForConditionalGeneration":
-        """
-        Exports a vanilla Transformers model into a rbln-compiled Module.
-        """
-        task = kwargs.pop("task", None)
-        if task is None:
-            task = TasksManager.infer_task_from_model(cls.auto_model_class)
-
-        if model_save_dir is None:
-            save_dir = TemporaryDirectory()
-            save_dir_path = Path(save_dir.name)
-        else:
-            save_dir = model_save_dir
-            if isinstance(save_dir, TemporaryDirectory):
-                save_dir_path = Path(model_save_dir.name)
-            else:
-                save_dir_path = Path(model_save_dir)
-                save_dir_path.mkdir(exist_ok=True)
-
+    def update_kwargs(cls, kwargs):
         kwargs.update(
             {
                 "torchscript": True,
                 "return_dict": False,
-                "use_cache": False,
+                "use_cache": True,
             }
         )
-        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-        model: WhisperForConditionalGeneration = TasksManager.get_model_from_task(
-            task=task,
-            model_name_or_path=model_id,
-            subfolder=subfolder,
-            revision=revision,
-            framework="pt",
-            cache_dir=cache_dir,
-            use_auth_token=use_auth_token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
+        return kwargs
+
+    @classmethod
+    @torch.inference_mode()
+    def get_compiled_model(cls, model, rbln_config: RBLNConfig):
+        wrapped_encoder = _WhisperEncoderWrapper(model).eval()
+        wrapped_decoder = _WhisperDecoderWrapper(model).eval()
+
+        enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+        dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+
+        enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
+        dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
+
+        enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs[0], check_trace=False)
+        dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)
+
+        enc_ir = rebel.torchscript_to_ir(
+            enc_scripted_model,
+            input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
+            name=enc_rbln_runtime_config.rbln_mod_name,
+        )
+        dec_ir = rebel.torchscript_to_ir(
+            dec_scripted_model,
+            input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+            name=dec_rbln_runtime_config.rbln_mod_name,
         )
+        dec_ir.batch_size = dec_rbln_runtime_config.batch_size
 
-        if config is None:
-            config = model.config
-
-        config.save_pretrained(save_dir_path)
-        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
-
-        # Get compilation arguments
-        if rbln_config_kwargs.get("rbln_config", None) is None:
-            rbln_config = cls.get_rbln_config(
-                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
-            )
-
-        def compile_whisper():
-            wrapped_encoder = _WhisperEncoderWrapper(model).eval()
-            wrapped_decoder = _WhisperDecoderWrapper(model).eval()
-
-            enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
-            dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
-
-            enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
-            dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
-
-            enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs[0]).eval()
-            dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs).eval()
-
-            enc_ir = rebel.torchscript_to_ir(
-                enc_scripted_model,
-                input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
-                name=enc_rbln_runtime_config.rbln_mod_name,
-            )
-            dec_ir = rebel.torchscript_to_ir(
-                dec_scripted_model,
-                input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
-                name=dec_rbln_runtime_config.rbln_mod_name,
-            )
-            dec_ir.batch_size = dec_rbln_runtime_config.batch_size
-
-            # Caching encoder/decoder I/O
-            connections = [
-                (enc_ir.outputs[0], dec_ir.inputs[4]),
-                (dec_ir.outputs[1], dec_ir.inputs[3]),
-            ]
-            compiled_model = rebel.compile(
-                enc_ir,
-                dec_ir,
-                connections=connections,
-                fusion=enc_rbln_runtime_config.fusion,
-                npu=enc_rbln_runtime_config.npu,
-                tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
-            )
-            compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-
-        compile_whisper()
-        rbln_config.save(save_dir_path)
-
-        return cls._from_pretrained(
-            model_id=save_dir_path,
-            config=config,
-            model_save_dir=save_dir,
-            **rbln_constructor_kwargs,
-            **kwargs,
+        # Caching encoder/decoder I/O
+        connections = [
+            (enc_ir.outputs[0], dec_ir.inputs[4]),
+            (dec_ir.outputs[1], dec_ir.inputs[3]),
+        ]
+        compiled_model = rebel.compile(
+            enc_ir,
+            dec_ir,
+            connections=connections,
+            fusion=enc_rbln_runtime_config.fusion,
+            npu=enc_rbln_runtime_config.npu,
+            tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
        )
+        return compiled_model
 
     @classmethod
     def _get_rbln_config(
@@ -357,11 +287,14 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
 
         return rbln_config
 
-    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
+    @classmethod
+    def _create_runtimes(
+        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+    ) -> List[rebel.Runtime]:
         device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [
-            self.compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
-            self.compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
+            compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
         ]
 
     def forward(
@@ -379,6 +312,3 @@ class RBLNWhisperForConditionalGeneration(RBLNBaseModel, GenerationMixin):
         lm_logits = decoder_output.logits
 
         return Seq2SeqLMOutput(logits=lm_logits)
-
-    def __repr__(self):
-        return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
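
With `_export` gone, Whisper follows the same lifecycle as the other models: `update_kwargs` patches the config for tracing, `get_compiled_model` traces the encoder/decoder wrappers and compiles them with the cross-attention cache wired via `connections`, and `_create_runtimes` (now a classmethod taking `compiled_models`) instantiates the named "encoder" and "decoder" runtimes. A hedged end-to-end sketch; the checkpoint id is just an example and `export=True` is the assumed optimum-style compile switch, not taken from this diff:

    # Hypothetical usage sketch (not part of the diff).
    from transformers import AutoProcessor
    from optimum.rbln import RBLNWhisperForConditionalGeneration

    processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
    model = RBLNWhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-tiny",
        export=True,  # assumed: trace, torchscript_to_ir, rebel.compile as above
    )
    # raw_speech: a 1-D float waveform sampled at 16 kHz (placeholder)
    inputs = processor(raw_speech, sampling_rate=16000, return_tensors="pt")
    generated_ids = model.generate(inputs.input_features)
    text = processor.batch_decode(generated_ids, skip_special_tokens=True)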
optimum/rbln/utils/__init__.py
@@ -21,5 +21,5 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
-from .import_utils import is_rbln_available
+from .import_utils import check_version_compats, is_rbln_available
 from .runtime_utils import RBLNPytorchRuntime
optimum/rbln/utils/import_utils.py
@@ -21,8 +21,54 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
+import importlib.metadata
 import importlib.util
+import warnings
+from dataclasses import dataclass
+
+from packaging.version import Version
+
+
+@dataclass
+class VersionCompat:
+    package_name: str
+    min_version: str
+    max_version: str
+
+
+RBLN_VERSION_COMPATS = {
+    "0.1.5": [
+        VersionCompat(
+            package_name="rebel-compiler",
+            min_version="0.5.7",
+            max_version="0.5.8",
+        ),
+    ],
+    "0.0.0": [],
+}
 
 
 def is_rbln_available() -> bool:
     return importlib.util.find_spec("rebel-compiler") is not None
+
+
+def check_version_compats() -> None:
+    warnings.filterwarnings(action="always", category=ImportWarning)
+
+    my_version = importlib.metadata.version("optimum-rbln")
+    target_version = list(filter(lambda v: Version(my_version) > Version(v), RBLN_VERSION_COMPATS.keys()))[0]
+    for compat in RBLN_VERSION_COMPATS[target_version]:
+        try:
+            dep_version = importlib.metadata.version(compat.package_name)
+        except importlib.metadata.PackageNotFoundError:
+            warnings.warn(f"optimum-rbln requires {compat.package_name} to be installed.", ImportWarning)
+            continue
+
+        if not Version(compat.min_version) <= Version(dep_version) < Version(compat.max_version):
+            warnings.warn(
+                f"optimum-rbln v{my_version} is compatible with {compat.package_name} v{compat.min_version} to v{compat.max_version}. (You are currently using v{dep_version}.)\n"
+                "Please refer to our SDK release notes at https://docs.rbln.ai/about_atom/release_note.html",
+                ImportWarning,
+            )
+
+    warnings.resetwarnings()
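
`check_version_compats` picks the newest compatibility entry strictly older than the installed `optimum-rbln` (relying on the dict's insertion order, with "0.0.0" as an empty fallback) and enforces a half-open window, `min_version <= installed < max_version`, on each dependency. A tiny self-contained illustration of that window test:

    from packaging.version import Version

    def in_window(installed: str, min_version: str, max_version: str) -> bool:
        # Same half-open comparison check_version_compats applies above.
        return Version(min_version) <= Version(installed) < Version(max_version)

    # With the "0.1.5" entry: rebel-compiler must satisfy 0.5.7 <= v < 0.5.8.
    assert in_window("0.5.7", "0.5.7", "0.5.8")
    assert not in_window("0.5.8", "0.5.7", "0.5.8")

One caveat worth flagging: `is_rbln_available` probes `importlib.util.find_spec("rebel-compiler")`, but a hyphenated name can never match an importable module (the distribution is named `rebel-compiler`; the module, as the imports elsewhere suggest, is presumably `rebel`), so that check looks like it always returns False.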