optimum-rbln 0.1.8__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (66)
  1. optimum/rbln/__init__.py +40 -2
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +39 -32
  4. optimum/rbln/diffusers/models/controlnet.py +60 -43
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +43 -31
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +2 -3
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
  15. optimum/rbln/modeling_alias.py +8 -4
  16. optimum/rbln/modeling_base.py +512 -238
  17. optimum/rbln/modeling_config.py +152 -77
  18. optimum/rbln/modeling_seq2seq.py +166 -77
  19. optimum/rbln/transformers/__init__.py +37 -1
  20. optimum/rbln/transformers/models/__init__.py +21 -1
  21. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  22. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  23. optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
  24. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  25. optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
  26. optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
  27. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  28. optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
  29. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  30. optimum/rbln/transformers/models/clip/modeling_clip.py +128 -26
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +32 -7
  32. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +406 -104
  33. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
  34. optimum/rbln/transformers/models/gemma/gemma_architecture.py +10 -3
  35. optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -3
  36. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  37. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -89
  38. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -3
  39. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  40. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
  41. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  42. optimum/rbln/transformers/models/midm/modeling_midm.py +5 -88
  43. optimum/rbln/transformers/models/mistral/__init__.py +24 -0
  44. optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
  45. optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
  46. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  47. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  48. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  49. optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
  50. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +18 -12
  51. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  52. optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
  53. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  54. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +25 -16
  55. optimum/rbln/transformers/utils/__init__.py +0 -0
  56. optimum/rbln/transformers/utils/rbln_quantization.py +97 -0
  57. optimum/rbln/utils/import_utils.py +37 -5
  58. optimum/rbln/utils/logging.py +82 -0
  59. optimum/rbln/utils/runtime_utils.py +35 -1
  60. optimum/rbln/utils/timer_utils.py +19 -0
  61. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +15 -7
  62. optimum_rbln-0.1.11.dist-info/RECORD +93 -0
  63. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
  64. optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
  65. optimum_rbln-0.1.8.dist-info/RECORD +0 -73
  66. {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
@@ -21,16 +21,20 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.

+import copy
+import importlib
+import inspect
 import logging
 import os
 import shutil
 from abc import ABC, abstractmethod
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 import rebel
 import torch
+import transformers
 from huggingface_hub import HfApi, HfFolder, hf_hub_download
 from optimum.exporters import TasksManager
 from optimum.modeling_base import OptimizedModel
@@ -46,18 +50,132 @@ from transformers import (
     PretrainedConfig,
 )

-from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from .utils.runtime_utils import UnavailableRuntime
 from .utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors


+if TYPE_CHECKING:
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        PreTrainedModel,
+    )
+
 logger = logging.getLogger(__name__)

-if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel

+
+class SubModulesMixin:
+    """
+    _rbln_submodules = [
+        {"name": "vision_tower"},
+        {"name": "language_model"},
+    ]
+    """
+
+    _rbln_submodules: List[Dict[str, Any]] = []
+
+    def __init__(
+        self,
+        *,
+        rbln_submodules: List["RBLNBaseModel"] = [],
+        **kwargs,
+    ) -> None:
+        for submodule_meta, submodule in zip(self._rbln_submodules, rbln_submodules):
+            setattr(self, submodule_meta["name"], submodule)
+
+    @classmethod
+    def _from_model(
+        cls,
+        model: "PreTrainedModel",
+        model_save_dir: str,
+        rbln_sub_configs_dict: Dict[str, Any],
+        rbln_kwargs: Dict[str, Any],
+        subfolder=None,  # warning: will be ignored
+        **kwargs,
+    ) -> List["RBLNBaseModel"]:
+        rbln_submodules = []
+        for submodule in cls._rbln_submodules:
+            submodule_name = submodule["name"]
+            torch_submodule: "PreTrainedModel" = getattr(model, submodule["name"])
+            cls_name = torch_submodule.__class__.__name__
+            submodule_cls: "RBLNBaseModel" = getattr(importlib.import_module("optimum.rbln"), f"RBLN{cls_name}")
+
+            if submodule_name in rbln_sub_configs_dict:
+                kwargs["rbln_config"] = rbln_sub_configs_dict[submodule_name]
+
+            rbln_submodule = submodule_cls._export(
+                model_id=None,
+                config=torch_submodule.config,
+                subfolder=submodule_name,
+                model_save_dir=model_save_dir,
+                model=torch_submodule,
+                **rbln_kwargs,
+                **kwargs,
+            )
+
+            rbln_submodules.append(rbln_submodule)
+
+        return rbln_submodules
+
+    @classmethod
+    def _submodule_from_compiled_model(
+        cls, model_save_dir: str, rbln_sub_configs_dict: Dict[str, Any], rbln_kwargs: Dict[str, Any], **kwargs
+    ):
+        rbln_submodules = []
+        for submodule in cls._rbln_submodules:
+            submodule_name = submodule["name"]
+            rbln_submodule_config_dict = rbln_sub_configs_dict.get(submodule_name, None)
+
+            # Get the class name to call the constructor of the RBLN class
+            submodule_rbln_config = RBLNConfig.load(Path(model_save_dir) / submodule_name)
+            submodule_cls_name = submodule_rbln_config.meta["cls"]
+            submodule_cls: "RBLNBaseModel" = getattr(importlib.import_module("optimum.rbln"), submodule_cls_name)
+
+            config = OptimizedModel._load_config(Path(model_save_dir) / submodule_name, **kwargs)
+            rbln_submodule = submodule_cls._from_pretrained(
+                model_id=model_save_dir,
+                config=config,
+                subfolder=submodule_name,
+                rbln_config=rbln_submodule_config_dict,
+                **rbln_kwargs,
+                **kwargs,
+            )
+            rbln_submodules.append(rbln_submodule)
+        return rbln_submodules
+
+    @classmethod
+    def _load_submodules(
+        cls,
+        model_save_dir,
+        rbln_sub_configs_dict,
+        rbln_kwargs,
+        model=None,
+        **kwargs,
+    ):
+        # Two ways:
+        # 1. Compile from a PyTorch object
+        # 2. Load from a compiled file
+        if model is not None:
+            return cls._from_model(
+                model=model,
+                model_save_dir=model_save_dir,
+                rbln_sub_configs_dict=rbln_sub_configs_dict,
+                rbln_kwargs=rbln_kwargs,
+                **kwargs,
+            )
+
+        else:
+            return cls._submodule_from_compiled_model(
+                model_save_dir=model_save_dir,
+                rbln_sub_configs_dict=rbln_sub_configs_dict,
+                rbln_kwargs=rbln_kwargs,
+                **kwargs,
+            )


-class RBLNBaseModel(OptimizedModel, ABC):
+
+class RBLNBaseModel(OptimizedModel, ABC, SubModulesMixin):
     """
     An abstract base class for compiling, loading, and saving neural network models from the huggingface
     transformers and diffusers libraries to run on RBLN NPU devices.
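The new SubModulesMixin lets a composite model declare named submodules that are compiled and reloaded as independent RBLN models (the LLaVA-Next support added in this release builds on it). A minimal sketch of how a subclass would declare its submodules; the class below is hypothetical and only illustrates the shape of `_rbln_submodules`:

    # Hypothetical subclass: each entry names an attribute of the underlying
    # PreTrainedModel that is exported as a standalone RBLN model and
    # re-attached as an attribute by SubModulesMixin.__init__.
    class RBLNMyVisionLanguageModel(RBLNModel):
        _rbln_submodules = [
            {"name": "vision_tower"},
            {"name": "language_model"},
        ]

After loading, `model.vision_tower` and `model.language_model` would themselves be RBLN models, each with its own compiled artifacts under a subfolder of the save directory.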
@@ -105,6 +223,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         subfolder: str = "",
         rbln_compiled_models: Optional[rebel.RBLNCompiledModel] = None,
+        rbln_submodules: List["RBLNBaseModel"] = [],
         **kwargs,
     ):
         super().__init__(models, config)
@@ -122,11 +241,18 @@ class RBLNBaseModel(OptimizedModel, ABC):
            self.auto_model_class.register(AutoConfig, self.__class__)

         # copied from transformers PreTrainedModel __init__
-        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
+        if self.can_generate():
+            gen_config_dir = model_save_dir.name if isinstance(model_save_dir, TemporaryDirectory) else model_save_dir
+            self.generation_config = GenerationConfig.from_pretrained(gen_config_dir, trust_remote_code=True)
+        else:
+            self.generation_config = None
+
+        # self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
         if self.generation_config is not None:
             self.generation_config.use_cache = True

         self.device = torch.device("cpu")
+        self.training = False

         # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
         # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
@@ -141,11 +267,9 @@ class RBLNBaseModel(OptimizedModel, ABC):
         self.model_save_dir = model_save_dir
         self.subfolder = subfolder

+        self.rbln_submodules = rbln_submodules
         self.__post_init__(**kwargs)

-    def __post_init__(self, **kwargs):
-        pass
-
     def _save_pretrained(self, save_directory: Union[str, Path]):
         """
         Saves a model and its configuration file to a directory, so that it can be re-loaded using the
@@ -156,36 +280,37 @@ class RBLNBaseModel(OptimizedModel, ABC):
             Directory where to save the model file.
         """
         real_save_dir = self.model_save_dir / self.subfolder
+        save_directory_path = Path(save_directory)
         if os.path.exists(real_save_dir) and os.path.isdir(real_save_dir):
+            if save_directory_path.absolute() == real_save_dir.absolute():
+                raise FileExistsError(
+                    f"Cannot save model to '{save_directory}'. "
+                    f"This directory already exists and contains the model files."
+                )
             shutil.copytree(real_save_dir, save_directory, dirs_exist_ok=True)
             self.config.save_pretrained(save_directory)
             if self.generation_config is not None:
                 self.generation_config.save_pretrained(save_directory)
         else:
-            raise FileNotFoundError(f"Saving compiled model failed.({real_save_dir}).")
+            raise FileNotFoundError(
+                f"Unable to save the model. The model directory '{real_save_dir}' does not exist or is not accessible. "
+                f"Cannot save to the specified destination '{save_directory}'. "
+                f"Please ensure the model directory exists and you have the necessary permissions to access it."
+            )

     @classmethod
-    def _from_pretrained(
+    def _load_compiled_model_dir(
         cls,
         model_id: Union[str, Path],
-        config: "PretrainedConfig",
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
         cache_dir: Optional[str] = None,
         subfolder: str = "",
         local_files_only: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        # Runtime - related kwargs
-        rbln_device: Optional[List[int]] = None,
-        rbln_device_map: Optional[Dict[str, int]] = None,
-        rbln_create_runtimes: Optional[bool] = None,
-        # passed from compile function
-        rbln_config: Optional[RBLNConfig] = None,
-        rbln_compiled_models: Optional[List[rebel.RBLNCompiledModel]] = None,
-        rbln_optimize_host_memory: Optional[bool] = None,
-        **kwargs,
-    ) -> "RBLNBaseModel":
+    ):
+        # Find compiled model
+        # And prepare or download cache folder from HF Hub if needed.
         model_path = Path(model_id)
         if model_path.is_dir():
             model_path = model_path / subfolder
@@ -196,7 +321,12 @@ class RBLNBaseModel(OptimizedModel, ABC):
                token = HfFolder().get_token()
            else:
                token = use_auth_token
-            repo_files = list(map(Path, HfApi().list_repo_files(model_id, revision=revision, token=token)))
+            repo_files = list(
+                map(
+                    Path,
+                    HfApi().list_repo_files(model_id, revision=revision, token=token),
+                )
+            )

            pattern = "*.rbln" if subfolder == "" else f"{subfolder}/*.rbln"
            rbln_files = [p for p in repo_files if p.match(pattern)]
@@ -216,16 +346,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
            )

        if model_path.is_dir():
-            if rbln_compiled_models is None:
-                rbln_config = RBLNConfig.load(str(model_path))
-                rbln_compiled_models = [
-                    rebel.RBLNCompiledModel(model_path / f"{compiled_model_name}.rbln")
-                    for compiled_model_name in rbln_config
-                ]
-                new_model_save_dir = model_path
-            else:
-                pass
-
+            model_path = str(model_path)
        else:
            rbln_config_filename = rbln_config_filenames[0]
            rbln_config_cache_path = hf_hub_download(
@@ -238,48 +359,145 @@ class RBLNBaseModel(OptimizedModel, ABC):
                force_download=force_download,
                local_files_only=local_files_only,
            )
-            rbln_config = RBLNConfig.load(Path(rbln_config_cache_path).parent)
-            rbln_compiled_models = []
-            for compiled_model_name in rbln_config:
-                model_cache_path = hf_hub_download(
-                    repo_id=model_id,
-                    filename=f"{compiled_model_name}.rbln",
-                    subfolder=subfolder,
-                    use_auth_token=use_auth_token,
-                    revision=revision,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    local_files_only=local_files_only,
-                )
-                rbln_compiled_models.append(rebel.RBLNCompiledModel(model_cache_path))
-            new_model_save_dir = Path(rbln_config_cache_path).parent
+            model_path = Path(rbln_config_cache_path).parent

-        preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
+        return model_path
+
+    @classmethod
+    def _load_compiled_models(cls, model_path: str):
+        compiled_models = Path(model_path).glob("*.rbln")
+        rbln_compiled_models = {cm.stem: rebel.RBLNCompiledModel(cm) for cm in compiled_models}
+        return rbln_compiled_models
+
+    @classmethod
+    def _split_submodule_config(cls, rbln_config_dict: Dict[str, Any] = {}) -> Dict[str, Any]:
+        # {"language_model": {"rbln_tensor_parallel_size": 4}}
+        rbln_sub_configs_dict: Dict[str, Dict[str, Any]] = {}
+
+        # Remove submodule configs from rbln_config
+        if len(cls._rbln_submodules) > 0:
+            keys = list(rbln_config_dict.keys())
+            submodule_names = [m["name"] for m in cls._rbln_submodules]
+            for key in keys:
+                if key in submodule_names:
+                    rbln_sub_configs_dict[key] = rbln_config_dict.pop(key)
+
+        return rbln_sub_configs_dict
+
+    @classmethod
+    def resolve_rbln_config(cls, rbln_config: Union[RBLNConfig, Dict[str, Any]], kwargs):
+        if isinstance(rbln_config, RBLNConfig):
+            # Already resolved
+            return rbln_config, None

-        if model_save_dir is None:
-            model_save_dir = new_model_save_dir
-
-        # Create runtimes
-        if rbln_create_runtimes is None:
-            rbln_create_runtimes = rebel.npu_is_available()
-        if rbln_device_map is None:
-            rbln_device_map = {}
-            device_val = 0 if rbln_device is None else rbln_device
-            for key in rbln_config:
-                rbln_device_map[key] = device_val
        else:
-            rbln_device_map = rbln_device_map
+            if rbln_config is None:
+                rbln_config_dict = {}
+            else:
+                rbln_config_dict = rbln_config
+
+            rbln_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
+            rbln_sub_configs_dict = cls._split_submodule_config(rbln_config_dict)
+
+            for key in rbln_config_dict:
+                if key in rbln_kwargs:
+                    raise KeyError(f"duplicate key in both `rbln_config` and {key}")
+
+            merged_rbln_kwargs = copy.deepcopy(rbln_kwargs)
+            merged_rbln_kwargs.update(rbln_config_dict)
+
+            return (merged_rbln_kwargs, rbln_sub_configs_dict)
+
+    @classmethod
+    def _from_pretrained(
+        cls,
+        model_id: Union[str, Path],
+        config: "PretrainedConfig",
+        use_auth_token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        # passed from compile function
+        rbln_config: Optional[RBLNConfig] = None,
+        rbln_compiled_models: Optional[Dict[str, rebel.RBLNCompiledModel]] = None,
+        rbln_submodules: List["RBLNBaseModel"] = [],
+        **kwargs,
+    ) -> "RBLNBaseModel":
+        from_export_method = isinstance(rbln_config, RBLNConfig) and rbln_compiled_models is not None
+
+        if not from_export_method:
+            # from compiled dir
+            rbln_kwargs, rbln_sub_configs_dict = cls.resolve_rbln_config(rbln_config, kwargs)
+
+            model_path_subfolder = cls._load_compiled_model_dir(
+                model_id=model_id,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                force_download=force_download,
+                cache_dir=cache_dir,
+                subfolder=subfolder,
+                local_files_only=local_files_only,
+            )
+
+            rbln_config = RBLNConfig.load(model_path_subfolder)
+            rbln_config.update_runtime_cfg(rbln_kwargs)
+
+            rbln_compiled_models = cls._load_compiled_models(model_path_subfolder)
+
+            if len(cls._rbln_submodules) > 0:
+                rbln_submodules = cls._load_submodules(
+                    model_save_dir=model_id,
+                    rbln_sub_configs_dict=rbln_sub_configs_dict,
+                    rbln_kwargs=rbln_kwargs,
+                    **kwargs,
+                )
+            else:
+                rbln_submodules = []
+
+            if subfolder != "":
+                model_save_dir = Path(model_path_subfolder).absolute().parent
+            else:
+                model_save_dir = Path(model_path_subfolder).absolute()
+
+        return cls._from_compiled_models(
+            rbln_compiled_models=rbln_compiled_models,
+            rbln_config=rbln_config,
+            config=config,
+            model_save_dir=model_save_dir,
+            subfolder=subfolder,
+            rbln_submodules=rbln_submodules,
+            **kwargs,
+        )
+
+    @classmethod
+    def _from_compiled_models(
+        cls,
+        rbln_compiled_models: Dict[str, rebel.RBLNCompiledModel],
+        rbln_config: RBLNConfig,
+        config,
+        model_save_dir: str,
+        subfolder: str,
+        rbln_submodules: List["RBLNBaseModel"] = [],
+        **kwargs,
+    ):
+        if isinstance(model_save_dir, str):
+            model_save_dir = Path(model_save_dir)
+        preprocessors = maybe_load_preprocessors(model_save_dir.name, subfolder=subfolder)
+
+        # FIXME:: Should we convert it?
+        compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
+        rbln_compiled_models = [rbln_compiled_models[cm_name] for cm_name in compiled_model_names]

        # create runtimes only if `rbln_create_runtimes` is enabled
        models = (
-            cls._create_runtimes(rbln_compiled_models, rbln_device_map)
-            if rbln_create_runtimes
+            cls._create_runtimes(rbln_compiled_models, rbln_config.device_map)
+            if rbln_config.create_runtimes
            else UnavailableRuntime()
        )

-        if rbln_optimize_host_memory is None:
-            rbln_optimize_host_memory = True
-
        return cls(
            models,
            config,
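`_from_pretrained` is now split into helpers: `resolve_rbln_config` merges the `rbln_config` dict with `rbln_*` kwargs and splits out per-submodule configs, `_load_compiled_model_dir` finds or downloads the artifact directory, and `_load_compiled_models` indexes every `*.rbln` file by stem. Runtime options (`create_runtimes`, `device`, `device_map`, `optimize_host_memory`) now travel inside `RBLNConfig` via `update_runtime_cfg` rather than as separate constructor kwargs. A sketch of loading precompiled artifacts under these changes; the path and option values are illustrative:

    from optimum.rbln import RBLNModelForQuestionAnswering

    model = RBLNModelForQuestionAnswering.from_pretrained(
        "./compiled-qa",              # directory produced by a previous export
        rbln_create_runtimes=False,   # runtime-only option, folded into RBLNConfig
    )

Per-submodule options would be nested under the submodule name in `rbln_config`, e.g. `{"language_model": {"rbln_tensor_parallel_size": 4}}` as in the comment above.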
@@ -287,99 +505,65 @@ class RBLNBaseModel(OptimizedModel, ABC):
            preprocessors,
            model_save_dir=model_save_dir,
            subfolder=subfolder,
-            rbln_compiled_models=None if rbln_optimize_host_memory else rbln_compiled_models,
+            rbln_compiled_models=(None if rbln_config.optimize_host_memory else rbln_compiled_models),
+            rbln_submodules=rbln_submodules,
            **kwargs,
        )

    def __repr__(self):
-        return repr(self.model)
+        return repr(self.model) + repr(self.rbln_submodules)

    @classmethod
-    def compile(cls, model, rbln_runtime_config: Optional[RBLNRuntimeConfig] = None):
+    def compile(cls, model, rbln_compile_config: Optional[RBLNCompileConfig] = None):
        compiled_model = rebel.compile_from_torch(
            model,
-            input_info=rbln_runtime_config.input_info,
-            batch_size=rbln_runtime_config.batch_size,
-            fusion=rbln_runtime_config.fusion,
-            npu=rbln_runtime_config.npu,
-            tensor_parallel_size=rbln_runtime_config.tensor_parallel_size,
+            input_info=rbln_compile_config.input_info,
+            fusion=rbln_compile_config.fusion,
+            npu=rbln_compile_config.npu,
+            tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
        )
        return compiled_model

    @classmethod
    def get_rbln_config(
        cls,
-        **rbln_config_kwargs,
+        rbln_kwargs: Dict[str, Any],
+        **others,
    ) -> RBLNConfig:
        """
        Make default rbln-config for the model.
-
-        if `input_info` specified,
-        other kwargs but `input_info`, `batch_size` and `fusion` are ignored.
-
        kwargs for overriding model's config can be accepted.
-
        Note that batch_size should be specified with proper input_info.
        """
-
-        input_info = rbln_config_kwargs.pop("rbln_input_info", None)
-        batch_size = rbln_config_kwargs.pop("rbln_batch_size", None)
-        fusion = rbln_config_kwargs.pop("rbln_fusion", None)
-        npu = rbln_config_kwargs.pop("rbln_npu", None)
-        tensor_parallel_size = rbln_config_kwargs.pop("rbln_tensor_parallel_size", None)
-
-        if input_info is not None:
-            rbln_runtime_config = RBLNRuntimeConfig(
-                input_info=input_info,
-                batch_size=batch_size,
-                fusion=fusion,
-                npu=npu,
-                tensor_parallel_size=tensor_parallel_size,
-            )
-            rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config])
-        else:
-            rbln_config = cls._get_rbln_config(rbln_batch_size=batch_size, **rbln_config_kwargs)
-            for k, rcfgs in rbln_config.items():
-                for rcfg in rcfgs:
-                    rcfg: RBLNRuntimeConfig
-                    rcfg.fusion = fusion
-                    rcfg.npu = npu
-                    rcfg.tensor_parallel_size = tensor_parallel_size
-
+        rbln_config = cls._get_rbln_config(**others, rbln_kwargs=rbln_kwargs)
        return rbln_config

    @staticmethod
-    def pop_rbln_kwargs_from_kwargs(kwargs: dict):
+    def pop_rbln_kwargs_from_kwargs(kwargs: Dict[str, Any], runtime_only=False):
        keys = list(kwargs.keys())
-        rbln_constructor_kwargs = {
-            key: kwargs.pop(key)
-            for key in keys
-            if key
-            in [
-                "rbln_device",
-                "rbln_device_map",
-                "rbln_create_runtimes",
-                "rbln_optimize_host_memory",
-            ]
-        }
+        rbln_kwargs = {key[5:]: kwargs.pop(key) for key in keys if key.startswith("rbln_")}

-        keys = list(kwargs.keys())
-        rbln_config_kwargs = {key: kwargs.pop(key) for key in keys if key.startswith("rbln_")}
-        return rbln_config_kwargs, rbln_constructor_kwargs
+        if runtime_only:
+            rbln_kwargs = {
+                key: value
+                for key, value in rbln_kwargs.items()
+                if key in {"create_runtimes", "optimize_host_memory", "device", "device_map"}
+            }
+
+        return rbln_kwargs

    def can_generate(self):
        return False

    def to(self, *args, **kwargs):
-        pass
+        # Do nothing
+        return self

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

-    @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
-        # Wrap the model if needed.
-        return model
+    def __post_init__(self, **kwargs):
+        self.dtype = torch.float32

    @classmethod
    def _from_transformers(cls, *args, **kwargs) -> "RBLNBaseModel":
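`pop_rbln_kwargs_from_kwargs` no longer returns separate config/constructor dicts: it pops every `rbln_`-prefixed key from `kwargs`, strips the prefix, and with `runtime_only=True` keeps only the runtime keys. A quick sketch of the expected behavior under this change:

    kwargs = {"rbln_batch_size": 2, "rbln_device": 0, "trust_remote_code": True}
    rbln_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
    # rbln_kwargs == {"batch_size": 2, "device": 0}
    # kwargs == {"trust_remote_code": True}   (matching keys are popped in place)

    runtime_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(
        {"rbln_device": 1, "rbln_max_seq_len": 512}, runtime_only=True
    )
    # runtime_kwargs == {"device": 1}   ("max_seq_len" is not a runtime key)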
@@ -390,8 +574,14 @@ class RBLNBaseModel(OptimizedModel, ABC):
        return cls._export(*args, **kwargs)

    @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+        # Wrap the model if needed.
+        return model
+
+    @classmethod
+    @abstractmethod
    def _get_rbln_config(cls, **rbln_config_kwargs) -> RBLNConfig:
-        raise NotImplementedError
+        pass

    @abstractmethod
    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
@@ -400,27 +590,16 @@ class RBLNBaseModel(OptimizedModel, ABC):
    @classmethod
    @abstractmethod
    def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
    ) -> List[rebel.Runtime]:
        # compiled_models -> runtimes
        pass

    @classmethod
    @abstractmethod
-    def _export(
-        cls,
-        model_id: Union[str, Path],
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        **kwargs,
-    ):
+    def _export(cls, *args, **kwargs):
        """
        Exports a vanilla Transformers model into a rbln-compiled Module.
        """
@@ -469,8 +648,8 @@ class RBLNModel(RBLNBaseModel):
        subfolder: str = "",
        local_files_only: bool = False,
        trust_remote_code: bool = False,
-        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
-        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
+        # Some rbln-kwargs should be applied before loading the torch module (e.g. a quantized LLM)
+        rbln_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> "PreTrainedModel":
        task = kwargs.pop("task", None)
@@ -495,25 +674,31 @@ class RBLNModel(RBLNBaseModel):

        return model

+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: "PreTrainedModel",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNConfig,
+    ):
+        """
+        If you are unavoidably running on a CPU rather than an RBLN device,
+        store the torch tensor, weight, etc. in this function.
+        """
+
    @classmethod
    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
-        model = cls.wrap_model_if_needed(model)
-        rbln_runtime_configs = list(rbln_config.values())
-        if len(rbln_runtime_configs) != 1:
-            raise ValueError
-        rbln_runtime_config = rbln_runtime_configs[0]
-        if len(rbln_runtime_config) != 1:
-            raise ValueError
-        rbln_runtime_config = rbln_runtime_config[0]
-
-        compiled_model = cls.compile(model, rbln_runtime_config=rbln_runtime_config)
+        model = cls.wrap_model_if_needed(model, rbln_config)
+        rbln_compile_config = rbln_config.compile_cfgs[0]
+        compiled_model = cls.compile(model, rbln_compile_config=rbln_compile_config)
        return compiled_model

    @classmethod
    @torch.no_grad()
    def _export(
        cls,
-        model_id: str,
+        model_id: Union[str, Path],
        config: "PretrainedConfig",
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
@@ -523,8 +708,12 @@ class RBLNModel(RBLNBaseModel):
        local_files_only: bool = False,
        trust_remote_code: bool = False,
        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        model: "PreTrainedModel" = None,
+        rbln_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> "RBLNModel":
+        rbln_kwargs, rbln_sub_configs_dict = cls.resolve_rbln_config(rbln_config, kwargs)
+
        if model_save_dir is None:
            save_dir = TemporaryDirectory()
            save_dir_path = Path(save_dir.name)
@@ -536,48 +725,65 @@ class RBLNModel(RBLNBaseModel):
            save_dir_path = Path(model_save_dir)
            save_dir_path.mkdir(exist_ok=True)

-        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-        model: "PreTrainedModel" = cls.get_pytorch_model(
-            model_id=model_id,
-            subfolder=subfolder,
-            revision=revision,
-            cache_dir=cache_dir,
-            use_auth_token=use_auth_token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            rbln_config_kwargs=rbln_config_kwargs,
-            rbln_constructor_kwargs=rbln_constructor_kwargs,
-            **kwargs,
-        )
+        # Load pytorch model if needed.
+        if model is None:
+            model: "PreTrainedModel" = cls.get_pytorch_model(
+                model_id=model_id,
+                subfolder=subfolder,
+                revision=revision,
+                cache_dir=cache_dir,
+                use_auth_token=use_auth_token,
+                local_files_only=local_files_only,
+                force_download=force_download,
+                trust_remote_code=trust_remote_code,
+                rbln_kwargs=rbln_kwargs,
+                **kwargs,
+            )
+            preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
+        else:
+            preprocessors = []

        # FIXME :: optimum passes AutoConfig.
        config = model.config
+        if hasattr(model, "can_generate") and model.can_generate():
+            generation_config = model.generation_config
+            generation_config.save_pretrained(save_dir_path / subfolder)

        if not isinstance(config, PretrainedConfig):  # diffusers config
            config = PretrainedConfig(**config)
-
        config.save_pretrained(save_dir_path / subfolder)
-        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)

        # Get compilation arguments
-        if (rbln_config := rbln_config_kwargs.pop("rbln_config", None)) is None:
-            rbln_config = cls.get_rbln_config(preprocessors=preprocessors, model_config=config, **rbln_config_kwargs)
-        compiled_model = cls.get_compiled_model(model, rbln_config=rbln_config)
+        rbln_config: RBLNConfig = cls.get_rbln_config(
+            preprocessors=preprocessors, model_config=config, rbln_kwargs=rbln_kwargs
+        )
+        compiled_model: Union[rebel.RBLNCompiledModel, Dict[str, rebel.RBLNCompiledModel]] = cls.get_compiled_model(
+            model, rbln_config=rbln_config
+        )

        # Save compiled models
        (save_dir_path / subfolder).mkdir(exist_ok=True)
-        if isinstance(compiled_model, Iterable):
-            # compiled_model is an Iterable instance
-            for single_compiled_model, compiled_model_name in zip(compiled_model, rbln_config):
-                single_compiled_model.save(save_dir_path / subfolder / f"{compiled_model_name}.rbln")
+        if not isinstance(compiled_model, dict):
+            compiled_models = {DEFAULT_COMPILED_MODEL_NAME: compiled_model}
+        else:
            compiled_models = compiled_model
+        for compiled_model_name, cm in compiled_models.items():
+            cm.save(save_dir_path / subfolder / f"{compiled_model_name}.rbln")
+        rbln_config.save(save_dir_path / subfolder)

+        cls.save_torch_artifacts(model, save_dir_path=save_dir_path, subfolder=subfolder, rbln_config=rbln_config)
+
+        # Load submodules
+        if len(cls._rbln_submodules) > 0:
+            rbln_submodules = cls._load_submodules(
+                model=model,
+                model_save_dir=save_dir,
+                rbln_sub_configs_dict=rbln_sub_configs_dict,
+                rbln_kwargs=rbln_kwargs,
+                **kwargs,
+            )
        else:
-            compiled_model.save(save_dir_path / subfolder / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-            compiled_models = [compiled_model]
-        rbln_config.save(save_dir_path / subfolder)
+            rbln_submodules = []

        # Instantiate
        return cls._from_pretrained(
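`RBLNModel._export` can now start from an already-instantiated `model` (this is how submodules are compiled), saves the generation config and the `RBLNConfig` next to the compiled artifacts, and recurses into `_load_submodules`. The user-facing entry point is unchanged; a minimal sketch, with an illustrative checkpoint:

    from optimum.rbln import RBLNModelForQuestionAnswering

    model = RBLNModelForQuestionAnswering.from_pretrained(
        "deepset/roberta-base-squad2",   # illustrative checkpoint
        export=True,                     # compile from the PyTorch weights
        rbln_max_seq_len=512,            # becomes {"max_seq_len": 512} after prefix stripping
        rbln_batch_size=1,
    )
    model.save_pretrained("./compiled-qa")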
@@ -592,13 +798,15 @@ class RBLNModel(RBLNBaseModel):
            local_files_only=local_files_only,
            rbln_config=rbln_config,
            rbln_compiled_models=compiled_models,
-            **rbln_constructor_kwargs,
+            rbln_submodules=rbln_submodules,
            **kwargs,
        )

    @classmethod
    def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
    ) -> List[rebel.Runtime]:
        device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
        return [compiled_model.create_runtime(tensor_type="pt", device=device) for compiled_model in compiled_models]
@@ -611,16 +819,19 @@
 class RBLNModelForQuestionAnswering(RBLNModel):
     model_type = "rbln_model"
     auto_model_class = AutoModelForQuestionAnswering
+    rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]

     @classmethod
     def _get_rbln_config(
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_model_input_names: Optional[List[str]] = None,
-        rbln_batch_size: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+
         if rbln_max_seq_len is None:
             for tokenizer in preprocessors:
                 if hasattr(tokenizer, "model_max_length"):
@@ -629,22 +840,37 @@ class RBLNModelForQuestionAnswering(RBLNModel):
         if rbln_max_seq_len is None:
             raise ValueError("`rbln_max_seq_len` should be specified!")

-        if rbln_model_input_names is None:
-            # These are BERT's inputs
-            rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
-
         if rbln_batch_size is None:
             rbln_batch_size = 1
+
+        if rbln_model_input_names is None:
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_input_names"):
+                    rbln_model_input_names = tokenizer.model_input_names
+                    break
+            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                rbln_model_input_names = cls.rbln_model_input_names
+            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                original_model_class = getattr(transformers, model_config.architectures[0])
+                input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
+                raise ValueError(
+                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
+                    f"and be sure to make the order of the inputs same as QuestionAnswering forward() arguments like ({list(input_names_order)})"
+                )
+
         input_info = [
             (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
             for model_input_name in rbln_model_input_names
         ]

-        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
-        rbln_runtime_config.batch_size = rbln_batch_size
-        meta = {"rbln_max_seq_len": rbln_max_seq_len}
-
-        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        return rbln_config


 class RBLNModelForImageClassification(RBLNModel):
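In the QuestionAnswering hunk above, rather than hardcoding BERT's inputs, `_get_rbln_config` now resolves the input names from the `rbln_model_input_names` kwarg, then the tokenizer's `model_input_names`, then the class-level default, and otherwise raises an error that spells out the expected argument order. A sketch of the explicit override; the model id is illustrative:

    model = RBLNModelForQuestionAnswering.from_pretrained(
        "my-org/custom-qa-model",   # illustrative
        export=True,
        rbln_max_seq_len=384,
        # Order must match the original model's forward() signature:
        rbln_model_input_names=["input_ids", "attention_mask", "token_type_ids"],
    )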
@@ -660,9 +886,11 @@ class RBLNModelForImageClassification(RBLNModel):
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_image_size: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
+        rbln_image_size = rbln_kwargs.get("image_size", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
         if rbln_image_size is None:
             for processor in preprocessors:
                 if hasattr(processor, "size"):
@@ -674,13 +902,19 @@
         if rbln_batch_size is None:
             rbln_batch_size = 1

-        input_info = [("pixel_values", [rbln_batch_size, 3, rbln_image_size, rbln_image_size], "float32")]
+        if isinstance(rbln_image_size, int):
+            rbln_image_size = rbln_image_size, rbln_image_size

-        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
-        rbln_runtime_config.batch_size = rbln_batch_size
-        meta = {"rbln_image_size": rbln_image_size}
+        input_info = [
+            (
+                "pixel_values",
+                [rbln_batch_size, 3, rbln_image_size[0], rbln_image_size[1]],
+                "float32",
+            )
+        ]

-        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        return RBLNConfig(rbln_cls=cls.__name__, compile_cfgs=[rbln_compile_config], rbln_kwargs=rbln_kwargs)


 class RBLNModelForAudioClassification(RBLNModel):
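In the image-classification hunk above, `rbln_image_size` may now be a single int (square inputs) or a (height, width) pair; an int is normalized to a tuple before `input_info` is built. For example, with an illustrative checkpoint:

    from optimum.rbln import RBLNModelForImageClassification

    model = RBLNModelForImageClassification.from_pretrained(
        "google/vit-base-patch16-224",   # illustrative
        export=True,
        rbln_image_size=(224, 224),      # equivalent to rbln_image_size=224
    )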
@@ -704,11 +938,11 @@
         cls,
         preprocessors: "AutoFeatureExtractor",
         model_config: "PretrainedConfig",
-        rbln_batch_size: Optional[int] = None,
-        rbln_max_length: Optional[int] = None,
-        rbln_num_mel_bins: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        meta = {}
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_max_length = rbln_kwargs.get("max_length", None)
+        rbln_num_mel_bins = rbln_kwargs.get("num_mel_bins", None)

         if rbln_batch_size is None:
             rbln_batch_size = 1
@@ -734,21 +968,27 @@
         if rbln_max_length is None:
             raise ValueError("`rbln_max_length` should be specified!")

-        meta["rbln_batch_size"] = rbln_batch_size
-        meta["rbln_max_length"] = rbln_max_length
-        meta["rbln_num_mel_bins"] = rbln_num_mel_bins
-
-        model_input_info = [
-            ("input_values", [rbln_batch_size, rbln_max_length, rbln_num_mel_bins], "float32"),
+        input_info = [
+            (
+                "input_values",
+                [rbln_batch_size, rbln_max_length, rbln_num_mel_bins],
+                "float32",
+            ),
         ]

-        rbln_runtime_config = RBLNRuntimeConfig(input_info=model_input_info, batch_size=rbln_batch_size)
-
-        rbln_config = RBLNConfig.from_rbln_runtime_configs(
-            [rbln_runtime_config],
-            _rbln_meta=meta,
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update(
+            {
+                "batch_size": rbln_batch_size,
+                "max_length": rbln_max_length,
+                "num_mel_bins": rbln_num_mel_bins,
+            }
         )
-
         return rbln_config

@@ -773,10 +1013,11 @@ class RBLNModelForSequenceClassification(RBLNModel):
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_model_input_names: Optional[List[str]] = None,
-        rbln_batch_size: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)

         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
@@ -796,21 +1037,37 @@
             raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")

         if rbln_model_input_names is None:
-            # These are BERT's inputs
-            rbln_model_input_names = ["input_ids", "attention_mask"]
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_input_names"):
+                    rbln_model_input_names = tokenizer.model_input_names
+                    break
+            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                rbln_model_input_names = cls.rbln_model_input_names
+            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                original_model_class = getattr(transformers, model_config.architectures[0])
+                input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
+                raise ValueError(
+                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
+                    f"and be sure to make the order of the inputs same as SequenceClassification forward() arguments like ({list(input_names_order)})"
+                )

         if rbln_batch_size is None:
             rbln_batch_size = 1
+
         input_info = [
             (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
             for model_input_name in rbln_model_input_names
         ]

-        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
-        rbln_runtime_config.batch_size = rbln_batch_size
-        meta = {"rbln_max_seq_len": rbln_max_seq_len}
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        return rbln_config

-        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)

 class RBLNModelForMaskedLM(RBLNModel):
     model_type = "rbln_model"
@@ -821,10 +1078,12 @@ class RBLNModelForMaskedLM(RBLNModel):
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_model_input_names: Optional[List[str]] = None,
-        rbln_batch_size: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
         )
@@ -843,18 +1102,33 @@
             raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")

         if rbln_model_input_names is None:
-            # These are BERT's inputs
-            rbln_model_input_names = ["input_ids", "attention_mask"]
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_input_names"):
+                    rbln_model_input_names = tokenizer.model_input_names
+                    break
+            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                rbln_model_input_names = cls.rbln_model_input_names
+            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                original_model_class = getattr(transformers, model_config.architectures[0])
+                input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
+                raise ValueError(
+                    "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
+                    f"and be sure to make the order of the inputs same as MaskedLM forward() arguments like ({list(input_names_order)})"
+                )

         if rbln_batch_size is None:
             rbln_batch_size = 1
+
         input_info = [
             (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
             for model_input_name in rbln_model_input_names
         ]

-        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
-        rbln_runtime_config.batch_size = rbln_batch_size
-        meta = {"rbln_max_seq_len": rbln_max_seq_len}
-
-        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        return rbln_config