optimum-rbln 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. optimum/rbln/__init__.py +115 -0
  2. optimum/rbln/__version__.py +1 -0
  3. optimum/rbln/diffusers/__init__.py +64 -0
  4. optimum/rbln/diffusers/models/__init__.py +26 -0
  5. optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
  6. optimum/rbln/diffusers/models/controlnet.py +180 -0
  7. optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
  8. optimum/rbln/diffusers/pipelines/__init__.py +30 -0
  9. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
  10. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
  16. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
  17. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
  18. optimum/rbln/modeling.py +0 -0
  19. optimum/rbln/modeling_alias.py +49 -0
  20. optimum/rbln/modeling_base.py +645 -0
  21. optimum/rbln/modeling_config.py +169 -0
  22. optimum/rbln/modeling_seq2seq.py +469 -0
  23. optimum/rbln/transformers/__init__.py +59 -0
  24. optimum/rbln/transformers/generation/__init__.py +24 -0
  25. optimum/rbln/transformers/generation/streamers.py +122 -0
  26. optimum/rbln/transformers/models/__init__.py +28 -0
  27. optimum/rbln/transformers/models/bart/__init__.py +24 -0
  28. optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
  29. optimum/rbln/transformers/models/clip/__init__.py +24 -0
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
  31. optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
  32. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
  33. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
  34. optimum/rbln/transformers/models/llama/__init__.py +24 -0
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
  37. optimum/rbln/transformers/models/t5/__init__.py +24 -0
  38. optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
  39. optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
  40. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
  41. optimum/rbln/transformers/models/whisper/__init__.py +24 -0
  42. optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
  43. optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
  44. optimum/rbln/utils/__init__.py +25 -0
  45. optimum/rbln/utils/import_utils.py +28 -0
  46. optimum/rbln/utils/runtime_utils.py +71 -0
  47. optimum/rbln/utils/save_utils.py +92 -0
  48. optimum_rbln-0.1.0.dist-info/METADATA +144 -0
  49. optimum_rbln-0.1.0.dist-info/RECORD +51 -0
  50. optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
  51. optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,645 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ import logging
+ from abc import ABC, abstractmethod
+ from pathlib import Path
+ from tempfile import TemporaryDirectory
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+ import rebel
+ import torch
+ from huggingface_hub import HfApi, HfFolder, hf_hub_download
+ from optimum.exporters import TasksManager
+ from optimum.modeling_base import OptimizedModel
+ from transformers import (
+     AutoConfig,
+     AutoModel,
+     AutoModelForAudioClassification,
+     AutoModelForImageClassification,
+     AutoModelForQuestionAnswering,
+     GenerationConfig,
+     PretrainedConfig,
+ )
+
+ from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+ from .utils.runtime_utils import UnavailableRuntime
+ from .utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
+
+
+ logger = logging.getLogger(__name__)
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
+
+
+ def listify(var: Any):
+     if isinstance(var, list):
+         return var
+     elif var is not None:
+         return [var]
+     else:
+         return None
+
+
+ class RBLNBaseModel(OptimizedModel, ABC):
+     """
+     An abstract base class for compiling, loading, and saving neural network models from the huggingface
+     transformers and diffusers libraries so that they run on RBLN NPU devices.
+
+     This class supports loading and saving models using the `from_pretrained` and `save_pretrained` methods,
+     similar to the huggingface libraries.
+
+     The `from_pretrained` method loads a model corresponding to the given `model_id` from a local repository
+     or the huggingface hub onto the NPU. If the model is a PyTorch model and `export=True` is passed as a
+     kwarg, the PyTorch model corresponding to the given `model_id` is compiled before loading. If `model_id`
+     refers to an already rbln-compiled model, it can be loaded onto the NPU directly with `export=False`.
+
+     `rbln_npu` is a kwarg used for compilation; it specifies the name of the NPU to compile for. If this
+     keyword is not given, the NPU installed on the host machine is used. If no NPU is installed on the
+     host machine, an error is raised.
+
+     `rbln_device` specifies the device to be used at runtime. If not specified, device 0 is used.
+
+     `rbln_create_runtimes` indicates whether to create runtime objects. If False, the model is not loaded
+     onto the NPU. This option is particularly useful when you want to perform compilation only, on a
+     host machine without an NPU.
+
+     `RBLNModel`, `RBLNModelFor*`, etc. are all child classes of RBLNBaseModel.
+
+     Models compiled in this way can be saved to a local repository using `save_pretrained` or uploaded to
+     the huggingface hub.
+
+     It also supports generation through `generate` (for transformers models that support generation).
+
+     RBLNBaseModel is a class for models consisting of an arbitrary number of `torch.nn.Module`s, and is
+     therefore an abstract class without explicit implementations of `forward` or `export`. Subclasses must
+     implement `forward`, `export`, and the other abstract methods.
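+
+     A minimal usage sketch, using the `RBLNModel` subclass defined below (the model id and NPU name are
+     illustrative placeholders, not values shipped with this package):
+
+     ```python
+     # Compile a PyTorch checkpoint without loading it onto an NPU (e.g. on a host with no NPU installed).
+     model = RBLNModel.from_pretrained(
+         "some-org/some-model",       # hypothetical model id
+         export=True,                 # compile the PyTorch model before loading
+         rbln_npu="RBLN-CA02",        # assumed NPU name; omit to use the NPU installed on the host
+         rbln_create_runtimes=False,  # compile only, do not create NPU runtimes
+     )
+     model.save_pretrained("./compiled-model")
+
+     # Later, load the already compiled artifacts directly onto NPU device 0.
+     model = RBLNModel.from_pretrained("./compiled-model", export=False, rbln_device=0)
+     ```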
+     """
+
+     model_type = "rbln_model"
+     auto_model_class = AutoModel  # feature extraction
+     config_name = "model_index.json"
+
+     def __init__(
+         self,
+         models: List[rebel.RBLNCompiledModel],
+         config: "PretrainedConfig",
+         preprocessors: Optional[List],
+         rbln_config: Optional[RBLNConfig],
+         rbln_device: Optional[List[int]] = None,
+         rbln_device_map: Optional[Dict[str, int]] = None,
+         rbln_create_runtimes: Optional[bool] = True,
+         **kwargs,
+     ):
+         super().__init__(models, config)
+         if not isinstance(self.config, PretrainedConfig):  # if diffusers config
+             self.config = PretrainedConfig(**self.config)
+
+         self.models = listify(self.model)
+
+         self.preprocessors = [] if preprocessors is None else preprocessors
+
+         # Register the RBLNModelForXXX classes with the transformers AutoModel classes to avoid warnings when creating
+         # a pipeline https://github.com/huggingface/transformers/blob/3d3204c025b6b5de013e07dd364208e28b4d9589/src/transformers/pipelines/base.py#L940
+         AutoConfig.register(self.model_type, AutoConfig)
+         if hasattr(self.auto_model_class, "register"):
+             self.auto_model_class.register(AutoConfig, self.__class__)
+
+         self.rbln_config = rbln_config
+         self.compiled_models: List[rebel.RBLNCompiledModel] = models
+
+         if rbln_device_map is None:
+             self.rbln_device_map = {}
+             device_val = 0 if rbln_device is None else rbln_device
+             for key in self.rbln_config:
+                 self.rbln_device_map[key] = device_val
+
+         else:
+             self.rbln_device_map = rbln_device_map
+
+         # copied from transformers PreTrainedModel __init__
+         self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
+         if self.generation_config is not None:
+             self.generation_config.use_cache = True
+
+         self.device = torch.device("cpu")
+
+         # create runtimes only if `rbln_create_runtimes` is enabled
+         self.runtimes = self._create_runtimes(self.rbln_device_map) if rbln_create_runtimes else UnavailableRuntime()
+
+         self.__post_init__(**kwargs)
+
+     def __post_init__(self, **kwargs):
+         pass
+
+     def _save_pretrained(self, save_directory: Union[str, Path]):
+         """
+         Saves a model and its configuration file to a directory, so that it can be re-loaded using the
+         [`~optimum.rbln.modeling_base.RBLNBaseModel.from_pretrained`] class method.
+
+         Args:
+             save_directory (`Union[str, Path]`):
+                 Directory where the model files will be saved.
+         """
+
+         for compiled_model, compiled_model_name in zip(self.compiled_models, self.rbln_config):
+             dst_path = Path(save_directory) / f"{compiled_model_name}.rbln"
+             compiled_model.save(dst_path)
+         self.rbln_config.save(save_directory)
+
+     @classmethod
+     def _from_pretrained(
+         cls,
+         model_id: Union[str, Path],
+         config: "PretrainedConfig",
+         use_auth_token: Optional[Union[bool, str]] = None,
+         revision: Optional[str] = None,
+         force_download: bool = False,
+         cache_dir: Optional[str] = None,
+         subfolder: str = "",
+         local_files_only: bool = False,
+         **kwargs,
+     ) -> "RBLNBaseModel":
+         model_path = Path(model_id)
+         if model_path.is_dir():
+             model_path = model_path / subfolder
+             rbln_files = list(model_path.glob("*.rbln"))
+             rbln_config_filenames = list(model_path.glob("rbln_config.json"))
+         else:
+             if isinstance(use_auth_token, bool):
+                 token = HfFolder().get_token()
+             else:
+                 token = use_auth_token
+             repo_files = list(map(Path, HfApi().list_repo_files(model_id, revision=revision, token=token)))
+
+             pattern = "*.rbln" if subfolder == "" else f"{subfolder}/*.rbln"
+             rbln_files = [p for p in repo_files if p.match(pattern)]
+
+             pattern = "rbln_config.json" if subfolder == "" else f"{subfolder}/rbln_config.json"
+             rbln_config_filenames = [p for p in repo_files if p.match(pattern)]
+
+         if len(rbln_files) == 0:
+             raise FileNotFoundError(f"Could not find any rbln model file in {model_path}")
+
+         if len(rbln_config_filenames) == 0:
+             raise FileNotFoundError(f"Could not find `rbln_config.json` file in {model_path}")
+
+         if len(rbln_config_filenames) > 1:
+             raise FileExistsError(
+                 f"Expected a single rbln_config.json, but {len(rbln_config_filenames)} were found."
+             )
+
+         if model_path.is_dir():
+             rbln_config = RBLNConfig.load(str(model_path))
+             models = [
+                 rebel.RBLNCompiledModel(model_path / f"{compiled_model_name}.rbln")
+                 for compiled_model_name in rbln_config
+             ]
+
+         else:
+             rbln_config_filename = rbln_config_filenames[0]
+             rbln_config_cache_path = hf_hub_download(
+                 repo_id=model_id,
+                 filename=str(rbln_config_filename),
+                 subfolder=subfolder,
+                 use_auth_token=use_auth_token,
+                 revision=revision,
+                 cache_dir=cache_dir,
+                 force_download=force_download,
+                 local_files_only=local_files_only,
+             )
+             rbln_config = RBLNConfig.load(Path(rbln_config_cache_path).parent)
+             models = []
+             for compiled_model_name in rbln_config:
+                 model_cache_path = hf_hub_download(
+                     repo_id=model_id,
+                     filename=f"{compiled_model_name}.rbln",
+                     subfolder=subfolder,
+                     use_auth_token=use_auth_token,
+                     revision=revision,
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     local_files_only=local_files_only,
+                 )
+                 models.append(rebel.RBLNCompiledModel(model_cache_path))
+
+         preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
+
+         return cls(
+             models,
+             config,
+             preprocessors,
+             rbln_config=rbln_config,
+             **kwargs,
+         )
+
+     def __repr__(self):
+         return repr(self.runtimes)
+
+     @classmethod
+     def compile(cls, model, rbln_runtime_config: Optional[RBLNRuntimeConfig] = None):
+         compiled_model = rebel.compile_from_torch(
+             model,
+             input_info=rbln_runtime_config.input_info,
+             batch_size=rbln_runtime_config.batch_size,
+             fusion=rbln_runtime_config.fusion,
+             npu=rbln_runtime_config.npu,
+             tensor_parallel_size=rbln_runtime_config.tensor_parallel_size,
+         )
+         return compiled_model
+
+     @classmethod
+     def get_rbln_config(
+         cls,
+         **rbln_config_kwargs,
+     ) -> RBLNConfig:
+         """
+         Build the default RBLN config for the model.
+
+         If `input_info` is specified, kwargs other than `input_info`, `batch_size`, and `fusion` are
+         ignored.
+
+         Kwargs that override the model's config are also accepted.
+
+         Note that `batch_size` should be specified together with a matching `input_info`.
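+
+         An illustrative call (the input name, shape, and dtype below are placeholders, not defaults of
+         this method):
+
+         ```python
+         rbln_config = RBLNModel.get_rbln_config(
+             rbln_input_info=[("input_ids", [1, 512], "int64")],  # list of (name, shape, dtype) triples
+             rbln_batch_size=1,
+         )
+         ```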
+         """
+
+         input_info = rbln_config_kwargs.pop("rbln_input_info", None)
+         batch_size = rbln_config_kwargs.pop("rbln_batch_size", None)
+         fusion = rbln_config_kwargs.pop("rbln_fusion", None)
+         npu = rbln_config_kwargs.pop("rbln_npu", None)
+         tensor_parallel_size = rbln_config_kwargs.pop("rbln_tensor_parallel_size", None)
+
+         if input_info is not None:
+             rbln_runtime_config = RBLNRuntimeConfig(
+                 input_info=input_info,
+                 batch_size=batch_size,
+                 fusion=fusion,
+                 npu=npu,
+                 tensor_parallel_size=tensor_parallel_size,
+             )
+             rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config])
+         else:
+             rbln_config = cls._get_rbln_config(rbln_batch_size=batch_size, **rbln_config_kwargs)
+             for k, rcfgs in rbln_config.items():
+                 for rcfg in rcfgs:
+                     rcfg: RBLNRuntimeConfig
+                     rcfg.fusion = fusion
+                     rcfg.npu = npu
+                     rcfg.tensor_parallel_size = tensor_parallel_size
+
+         return rbln_config
+
+     @staticmethod
+     def pop_rbln_kwargs_from_kwargs(kwargs: dict):
+         keys = list(kwargs.keys())
+         rbln_constructor_kwargs = {
+             key: kwargs.pop(key) for key in keys if key in ["rbln_device", "rbln_device_map", "rbln_create_runtimes"]
+         }
+
+         keys = list(kwargs.keys())
+         rbln_config_kwargs = {key: kwargs.pop(key) for key in keys if key.startswith("rbln_")}
+         return rbln_config_kwargs, rbln_constructor_kwargs
+
+     def can_generate(self):
+         return False
+
+     def to(self, *args, **kwargs):
+         pass
+
+     def __call__(self, *args, **kwargs):
+         return self.forward(*args, **kwargs)
+
+     @classmethod
+     def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+         # Wrap the model if needed.
+         return model
+
+     @classmethod
+     def _from_transformers(cls, *args, **kwargs) -> "RBLNBaseModel":
+         """
+         Exports a vanilla Transformers model into an rbln-compiled Module.
+         This will be deprecated after optimum 2.0.
+         """
+         return cls._export(*args, **kwargs)
+
+     @classmethod
+     def _get_rbln_config(cls, **rbln_config_kwargs) -> RBLNConfig:
+         raise NotImplementedError
+
+     @abstractmethod
+     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+         pass
+
+     @abstractmethod
+     def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
+         # self.compiled_models -> self.runtimes
+         pass
+
+     @classmethod
+     @abstractmethod
+     def _export(
+         cls,
+         model_id: Union[str, Path],
+         config: "PretrainedConfig",
+         use_auth_token: Optional[Union[bool, str]] = None,
+         revision: Optional[str] = None,
+         force_download: bool = False,
+         cache_dir: Optional[str] = None,
+         subfolder: str = "",
+         local_files_only: bool = False,
+         trust_remote_code: bool = False,
+         **kwargs,
+     ):
+         """
+         Exports a vanilla Transformers model into an rbln-compiled Module.
+         """
+         pass
+
+
+ class RBLNModel(RBLNBaseModel):
+     """
+     A class that inherits from RBLNBaseModel for models consisting of a single `torch.nn.Module`.
+
+     This class supports all the functionality of RBLNBaseModel, including loading and saving models using
+     the `from_pretrained` and `save_pretrained` methods, and compiling PyTorch models for execution on
+     RBLN NPU devices.
+
+     Example:
+         ```python
+         model = RBLNModel.from_pretrained("model_id", export=True, rbln_npu="npu_name")
+         outputs = model(**inputs)
+         ```
+     """
+
+     model_type = "rbln_model"
+     auto_model_class = AutoModel  # feature extraction
+
+     @classmethod
+     def _export(
+         cls,
+         model_id: Union[str, Path],
+         config: "PretrainedConfig",
+         use_auth_token: Optional[Union[bool, str]] = None,
+         revision: Optional[str] = None,
+         force_download: bool = False,
+         cache_dir: Optional[str] = None,
+         subfolder: str = "",
+         local_files_only: bool = False,
+         trust_remote_code: bool = False,
+         **kwargs,
+     ) -> "RBLNModel":
+         """
+         Exports a vanilla Transformers model into an rbln-compiled Module.
+         """
+         task = kwargs.pop("task", None)
+         if task is None:
+             task = TasksManager.infer_task_from_model(cls.auto_model_class)
+
+         save_dir = TemporaryDirectory()
+         save_dir_path = Path(save_dir.name)
+
+         kwargs.update(
+             {
+                 "torchscript": True,
+                 "return_dict": False,
+             }
+         )
+
+         rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
+
+         model = TasksManager.get_model_from_task(
+             task=task,
+             model_name_or_path=model_id,
+             subfolder=subfolder,
+             revision=revision,
+             framework="pt",
+             cache_dir=cache_dir,
+             use_auth_token=use_auth_token,
+             local_files_only=local_files_only,
+             force_download=force_download,
+             trust_remote_code=trust_remote_code,
+             **kwargs,
+         )
+
+         # TODO : do we need this?
+         if isinstance(model, torch.nn.Module):
+             model.eval()
+
+         if config is None:
+             config = model.config
+
+         if not isinstance(config, PretrainedConfig):  # diffusers config
+             config = PretrainedConfig(**config)
+
+         config.save_pretrained(save_dir_path / subfolder)
+         preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
+
+         # Get compilation arguments
+         if rbln_config_kwargs.get("rbln_config", None) is None:
+             rbln_config = cls.get_rbln_config(preprocessors=preprocessors, model_config=config, **rbln_config_kwargs)
+
+         rbln_runtime_configs = list(rbln_config.values())
+         if len(rbln_runtime_configs) != 1:
+             raise ValueError
+         rbln_runtime_config = rbln_runtime_configs[0]
+         if len(rbln_runtime_config) != 1:
+             raise ValueError
+         rbln_runtime_config = rbln_runtime_config[0]
+
+         model = cls.wrap_model_if_needed(model)
+         compiled_model = cls.compile(model, rbln_runtime_config=rbln_runtime_config)
+         compiled_model.save(save_dir_path / subfolder / f"{rbln_runtime_config.compiled_model_name}.rbln")
+         rbln_config.save(save_dir_path / subfolder)
+
+         return cls._from_pretrained(
+             model_id=save_dir_path,
+             config=config,
+             model_save_dir=save_dir,
+             use_auth_token=use_auth_token,
+             revision=revision,
+             force_download=force_download,
+             cache_dir=cache_dir,
+             subfolder=subfolder,
+             local_files_only=local_files_only,
+             **rbln_constructor_kwargs,
+             **kwargs,
+         )
+
+     def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
+         device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
+         return [
+             compiled_model.create_runtime(tensor_type="pt", device=device) for compiled_model in self.compiled_models
+         ]
+
+     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+         output = self.runtimes[0](*args, **kwargs)
+         return output
+
+     def __repr__(self):
+         return repr(self.runtimes[0])
+
+
+ class RBLNModelForQuestionAnswering(RBLNModel):
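+     """
+     This is a generic model class that will be instantiated as one of the model classes of the library
+     (with a question answering head) when created with the from_pretrained() class method.
+
+     A minimal usage sketch (the checkpoint name below is an illustrative placeholder):
+
+     ```python
+     model = RBLNModelForQuestionAnswering.from_pretrained(
+         "some-org/some-qa-model",  # hypothetical model id
+         export=True,
+         rbln_max_seq_len=384,  # falls back to the tokenizer's model_max_length when omitted
+     )
+     ```
+     """
+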
+     model_type = "rbln_model"
+     auto_model_class = AutoModelForQuestionAnswering
+
+     @classmethod
+     def _get_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_max_seq_len: Optional[int] = None,
+         rbln_model_input_names: Optional[List[str]] = None,
+         rbln_batch_size: Optional[int] = None,
+     ) -> RBLNConfig:
+         if rbln_max_seq_len is None:
+             for tokenizer in preprocessors:
+                 if hasattr(tokenizer, "model_max_length"):
+                     rbln_max_seq_len = tokenizer.model_max_length
+                     break
+             if rbln_max_seq_len is None:
+                 raise ValueError("`rbln_max_seq_len` should be specified!")
+
+         if rbln_model_input_names is None:
+             # These are BERT's inputs
+             rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+
+         if rbln_batch_size is None:
+             rbln_batch_size = 1
+         input_info = [
+             (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+             for model_input_name in rbln_model_input_names
+         ]
+
+         rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
+         rbln_runtime_config.batch_size = rbln_batch_size
+         meta = {"rbln_max_seq_len": rbln_max_seq_len}
+
+         return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+
+
+ class RBLNModelForImageClassification(RBLNModel):
+     """
+     This is a generic model class that will be instantiated as one of the model classes of the library
+     (with an image classification head) when created with the from_pretrained() class method.
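+
+     A minimal usage sketch (the checkpoint name below is an illustrative placeholder):
+
+     ```python
+     model = RBLNModelForImageClassification.from_pretrained(
+         "some-org/some-image-model",  # hypothetical model id
+         export=True,
+         rbln_image_size=224,  # falls back to the preprocessor's size["shortest_edge"] when omitted
+     )
+     ```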
+     """
+
+     model_type = "rbln_model"
+     auto_model_class = AutoModelForImageClassification
+
+     @classmethod
+     def _get_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_image_size: Optional[int] = None,
+         rbln_batch_size: Optional[int] = None,
+     ) -> RBLNConfig:
+         if rbln_image_size is None:
+             for processor in preprocessors:
+                 if hasattr(processor, "size"):
+                     rbln_image_size = processor.size["shortest_edge"]
+                     break
+             if rbln_image_size is None:
+                 raise ValueError("`rbln_image_size` should be specified!")
+
+         if rbln_batch_size is None:
+             rbln_batch_size = 1
+
+         input_info = [("pixel_values", [rbln_batch_size, 3, rbln_image_size, rbln_image_size], "float32")]
+
+         rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
+         rbln_runtime_config.batch_size = rbln_batch_size
+         meta = {"rbln_image_size": rbln_image_size}
+
+         return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+
+
+ class RBLNModelForAudioClassification(RBLNModel):
+     """
+     This is a generic model class that will be instantiated as one of the model classes of the library
+     (with an audio classification head) when created with the from_pretrained() class method.
+     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the
+     library implements for all its models.
+
+     A class to convert and run pre-trained, transformers-based AudioClassification models on RBLN devices.
+     It implements the methods to convert a pre-trained transformers AudioClassification model into an RBLN
+     transformer model by:
+     - transferring the checkpoint weights of the original into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
+
+     Currently, this model class only supports the 'AST' model from the transformers library. Future updates
+     may include support for additional model types.
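+
+     A minimal usage sketch (the checkpoint name below is an illustrative placeholder):
+
+     ```python
+     model = RBLNModelForAudioClassification.from_pretrained(
+         "some-org/some-ast-model",  # hypothetical AST checkpoint
+         export=True,
+         rbln_max_length=1024,   # falls back to the model config / feature extractor when omitted
+         rbln_num_mel_bins=128,  # falls back to the model config / feature extractor when omitted
+     )
+     ```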
+     """
+
+     model_type = "rbln_model"
+     auto_model_class = AutoModelForAudioClassification
+
+     @classmethod
+     def _get_rbln_config(
+         cls,
+         preprocessors: "AutoFeatureExtractor",
+         model_config: "PretrainedConfig",
+         rbln_batch_size: Optional[int] = None,
+         rbln_max_length: Optional[int] = None,
+         rbln_num_mel_bins: Optional[int] = None,
+     ) -> RBLNConfig:
+         meta = {}
+
+         if rbln_batch_size is None:
+             rbln_batch_size = 1
+
+         if rbln_num_mel_bins is None:
+             rbln_num_mel_bins = getattr(model_config, "num_mel_bins", None)
+             if rbln_num_mel_bins is None:
+                 for feature_extractor in preprocessors:
+                     if hasattr(feature_extractor, "num_mel_bins"):
+                         rbln_num_mel_bins = feature_extractor.num_mel_bins
+                         break
+
+             if rbln_num_mel_bins is None:
+                 raise ValueError("`rbln_num_mel_bins` should be specified!")
+
+         if rbln_max_length is None:
+             rbln_max_length = getattr(model_config, "max_length", None)
+             for feature_extractor in preprocessors:
+                 if hasattr(feature_extractor, "max_length"):
+                     rbln_max_length = feature_extractor.max_length
+                     break
+
+             if rbln_max_length is None:
+                 raise ValueError("`rbln_max_length` should be specified!")
+
+         meta["rbln_batch_size"] = rbln_batch_size
+         meta["rbln_max_length"] = rbln_max_length
+         meta["rbln_num_mel_bins"] = rbln_num_mel_bins
+
+         model_input_info = [
+             ("input_values", [rbln_batch_size, rbln_max_length, rbln_num_mel_bins], "float32"),
+         ]
+
+         rbln_runtime_config = RBLNRuntimeConfig(input_info=model_input_info, batch_size=rbln_batch_size)
+
+         rbln_config = RBLNConfig.from_rbln_runtime_configs(
+             [rbln_runtime_config],
+             _rbln_meta=meta,
+         )
+
+         return rbln_config