optimum-rbln 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. optimum/rbln/__init__.py +37 -2
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
  4. optimum/rbln/diffusers/models/controlnet.py +56 -40
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
  6. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
  10. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
  12. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
  14. optimum/rbln/modeling_alias.py +3 -3
  15. optimum/rbln/modeling_base.py +471 -231
  16. optimum/rbln/modeling_config.py +152 -77
  17. optimum/rbln/modeling_seq2seq.py +166 -77
  18. optimum/rbln/transformers/__init__.py +35 -1
  19. optimum/rbln/transformers/models/__init__.py +20 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
  33. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  34. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  35. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  37. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  38. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
  41. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  42. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  43. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  44. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  45. optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
  46. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
  47. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  48. optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
  49. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  50. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
  51. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  52. optimum/rbln/utils/import_utils.py +36 -1
  53. optimum/rbln/utils/logging.py +82 -0
  54. optimum/rbln/utils/runtime_utils.py +33 -0
  55. optimum/rbln/utils/timer_utils.py +19 -0
  56. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
  57. optimum_rbln-0.1.11.dist-info/RECORD +93 -0
  58. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
  59. optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
  60. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  61. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_base.py
@@ -21,16 +21,20 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.
 
+ import copy
+ import importlib
+ import inspect
  import logging
  import os
  import shutil
  from abc import ABC, abstractmethod
  from pathlib import Path
  from tempfile import TemporaryDirectory
- from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
  import rebel
  import torch
+ import transformers
  from huggingface_hub import HfApi, HfFolder, hf_hub_download
  from optimum.exporters import TasksManager
  from optimum.modeling_base import OptimizedModel
@@ -46,7 +50,7 @@ from transformers import (
      PretrainedConfig,
  )
 
- from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+ from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
  from .utils.runtime_utils import UnavailableRuntime
  from .utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
 
@@ -62,7 +66,116 @@ if TYPE_CHECKING:
  logger = logging.getLogger(__name__)
 
 
- class RBLNBaseModel(OptimizedModel, ABC):
+ class SubModulesMixin:
+     """
+     _rbln_submodules = [
+         {"name": "vision_tower"},
+         {"name": "language_model"},
+     ]
+     """
+
+     _rbln_submodules: List[Dict[str, Any]] = []
+
+     def __init__(
+         self,
+         *,
+         rbln_submodules: List["RBLNBaseModel"] = [],
+         **kwargs,
+     ) -> None:
+         for submodule_meta, submodule in zip(self._rbln_submodules, rbln_submodules):
+             setattr(self, submodule_meta["name"], submodule)
+
+     @classmethod
+     def _from_model(
+         cls,
+         model: "PreTrainedModel",
+         model_save_dir: str,
+         rbln_sub_configs_dict: Dict[str, Any],
+         rbln_kwargs: Dict[str, Any],
+         subfolder=None,  # warning: will be ignored
+         **kwargs,
+     ) -> List["RBLNBaseModel"]:
+         rbln_submodules = []
+         for submodule in cls._rbln_submodules:
+             submodule_name = submodule["name"]
+             torch_submodule: "PreTrainedModel" = getattr(model, submodule["name"])
+             cls_name = torch_submodule.__class__.__name__
+             submodule_cls: "RBLNBaseModel" = getattr(importlib.import_module("optimum.rbln"), f"RBLN{cls_name}")
+
+             if submodule_name in rbln_sub_configs_dict:
+                 kwargs["rbln_config"] = rbln_sub_configs_dict[submodule_name]
+
+             rbln_submodule = submodule_cls._export(
+                 model_id=None,
+                 config=torch_submodule.config,
+                 subfolder=submodule_name,
+                 model_save_dir=model_save_dir,
+                 model=torch_submodule,
+                 **rbln_kwargs,
+                 **kwargs,
+             )
+
+             rbln_submodules.append(rbln_submodule)
+
+         return rbln_submodules
+
+     @classmethod
+     def _submodule_from_compiled_model(
+         cls, model_save_dir: str, rbln_sub_configs_dict: Dict[str, Any], rbln_kwargs: Dict[str, Any], **kwargs
+     ):
+         rbln_submodules = []
+         for submodule in cls._rbln_submodules:
+             submodule_name = submodule["name"]
+             rbln_submodule_config_dict = rbln_sub_configs_dict.get(submodule_name, None)
+
+             # Get the class name to call the constructor of the rbln class
+             submodule_rbln_config = RBLNConfig.load(Path(model_save_dir) / submodule_name)
+             submodule_cls_name = submodule_rbln_config.meta["cls"]
+             submodule_cls: "RBLNBaseModel" = getattr(importlib.import_module("optimum.rbln"), submodule_cls_name)
+
+             config = OptimizedModel._load_config(Path(model_save_dir) / submodule_name, **kwargs)
+             rbln_submodule = submodule_cls._from_pretrained(
+                 model_id=model_save_dir,
+                 config=config,
+                 subfolder=submodule_name,
+                 rbln_config=rbln_submodule_config_dict,
+                 **rbln_kwargs,
+                 **kwargs,
+             )
+             rbln_submodules.append(rbln_submodule)
+         return rbln_submodules
+
+     @classmethod
+     def _load_submodules(
+         cls,
+         model_save_dir,
+         rbln_sub_configs_dict,
+         rbln_kwargs,
+         model=None,
+         **kwargs,
+     ):
+         # Two ways:
+         # 1. Compile from a PyTorch object
+         # 2. Load from a compiled file
+         if model is not None:
+             return cls._from_model(
+                 model=model,
+                 model_save_dir=model_save_dir,
+                 rbln_sub_configs_dict=rbln_sub_configs_dict,
+                 rbln_kwargs=rbln_kwargs,
+                 **kwargs,
+             )
+
+         else:
+             return cls._submodule_from_compiled_model(
+                 model_save_dir=model_save_dir,
+                 rbln_sub_configs_dict=rbln_sub_configs_dict,
+                 rbln_kwargs=rbln_kwargs,
+                 **kwargs,
+             )
+
+
+ class RBLNBaseModel(OptimizedModel, ABC, SubModulesMixin):
      """
      An abstract base class for compiling, loading, and saving neural network models from the huggingface
      transformers and diffusers libraries to run on RBLN NPU devices.
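
A minimal sketch of how a composite model would opt in to the new mixin (the class below is hypothetical; the submodule names mirror the vision_tower/language_model example in the docstring above). Declaring _rbln_submodules is all that is required: _from_model compiles each named torch submodule with the matching RBLN-prefixed class, and SubModulesMixin.__init__ re-attaches each compiled result under the same attribute name.

    class RBLNMyVisionLanguageModel(RBLNModel):  # hypothetical composite model
        # Each entry names a torch submodule on the loaded PreTrainedModel.
        _rbln_submodules = [
            {"name": "vision_tower"},
            {"name": "language_model"},
        ]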
@@ -110,6 +223,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
          model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
          subfolder: str = "",
          rbln_compiled_models: Optional[rebel.RBLNCompiledModel] = None,
+         rbln_submodules: List["RBLNBaseModel"] = [],
          **kwargs,
      ):
          super().__init__(models, config)
@@ -127,11 +241,18 @@ class RBLNBaseModel(OptimizedModel, ABC):
          self.auto_model_class.register(AutoConfig, self.__class__)
 
          # copied from transformers PreTrainedModel __init__
-         self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
+         if self.can_generate():
+             gen_config_dir = model_save_dir.name if isinstance(model_save_dir, TemporaryDirectory) else model_save_dir
+             self.generation_config = GenerationConfig.from_pretrained(gen_config_dir, trust_remote_code=True)
+         else:
+             self.generation_config = None
+
+         # self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
          if self.generation_config is not None:
              self.generation_config.use_cache = True
 
          self.device = torch.device("cpu")
+         self.training = False
 
          # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
          # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
@@ -146,11 +267,9 @@ class RBLNBaseModel(OptimizedModel, ABC):
          self.model_save_dir = model_save_dir
          self.subfolder = subfolder
 
+         self.rbln_submodules = rbln_submodules
          self.__post_init__(**kwargs)
 
-     def __post_init__(self, **kwargs):
-         pass
-
      def _save_pretrained(self, save_directory: Union[str, Path]):
          """
          Saves a model and its configuration file to a directory, so that it can be re-loaded using the
@@ -180,27 +299,18 @@ class RBLNBaseModel(OptimizedModel, ABC):
          )
 
      @classmethod
-     def _from_pretrained(
+     def _load_compiled_model_dir(
          cls,
          model_id: Union[str, Path],
-         config: "PretrainedConfig",
          use_auth_token: Optional[Union[bool, str]] = None,
          revision: Optional[str] = None,
          force_download: bool = False,
          cache_dir: Optional[str] = None,
          subfolder: str = "",
          local_files_only: bool = False,
-         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-         # Runtime - related kwargs
-         rbln_device: Optional[List[int]] = None,
-         rbln_device_map: Optional[Dict[str, int]] = None,
-         rbln_create_runtimes: Optional[bool] = None,
-         # passed from compile function
-         rbln_config: Optional[RBLNConfig] = None,
-         rbln_compiled_models: Optional[List[rebel.RBLNCompiledModel]] = None,
-         rbln_optimize_host_memory: Optional[bool] = None,
-         **kwargs,
-     ) -> "RBLNBaseModel":
+     ):
+         # Find the compiled model,
+         # and prepare or download the cache folder from the HF Hub if needed.
          model_path = Path(model_id)
          if model_path.is_dir():
              model_path = model_path / subfolder
@@ -236,16 +346,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
          )
 
          if model_path.is_dir():
-             if rbln_compiled_models is None:
-                 rbln_config = RBLNConfig.load(str(model_path))
-                 rbln_compiled_models = [
-                     rebel.RBLNCompiledModel(model_path / f"{compiled_model_name}.rbln")
-                     for compiled_model_name in rbln_config
-                 ]
-                 new_model_save_dir = model_path
-             else:
-                 pass
-
+             model_path = str(model_path)
          else:
              rbln_config_filename = rbln_config_filenames[0]
              rbln_config_cache_path = hf_hub_download(
@@ -258,48 +359,145 @@ class RBLNBaseModel(OptimizedModel, ABC):
                  force_download=force_download,
                  local_files_only=local_files_only,
              )
-             rbln_config = RBLNConfig.load(Path(rbln_config_cache_path).parent)
-             rbln_compiled_models = []
-             for compiled_model_name in rbln_config:
-                 model_cache_path = hf_hub_download(
-                     repo_id=model_id,
-                     filename=f"{compiled_model_name}.rbln",
-                     subfolder=subfolder,
-                     use_auth_token=use_auth_token,
-                     revision=revision,
-                     cache_dir=cache_dir,
-                     force_download=force_download,
-                     local_files_only=local_files_only,
-                 )
-                 rbln_compiled_models.append(rebel.RBLNCompiledModel(model_cache_path))
-             new_model_save_dir = Path(rbln_config_cache_path).parent
+             model_path = Path(rbln_config_cache_path).parent
 
-         preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
+         return model_path
+
+     @classmethod
+     def _load_compiled_models(cls, model_path: str):
+         compiled_models = Path(model_path).glob("*.rbln")
+         rbln_compiled_models = {cm.stem: rebel.RBLNCompiledModel(cm) for cm in compiled_models}
+         return rbln_compiled_models
+
+     @classmethod
+     def _split_submodule_config(cls, rbln_config_dict: Dict[str, Any] = {}) -> Dict[str, Any]:
+         # e.g. {"language_model": {"rbln_tensor_parallel_size": 4}}
+         rbln_sub_configs_dict: Dict[str, Dict[str, Any]] = {}
+
+         # Remove submodule configs from rbln_config
+         if len(cls._rbln_submodules) > 0:
+             keys = list(rbln_config_dict.keys())
+             submodule_names = [m["name"] for m in cls._rbln_submodules]
+             for key in keys:
+                 if key in submodule_names:
+                     rbln_sub_configs_dict[key] = rbln_config_dict.pop(key)
+
+         return rbln_sub_configs_dict
+
+     @classmethod
+     def resolve_rbln_config(cls, rbln_config: Union[RBLNConfig, Dict[str, Any]], kwargs):
+         if isinstance(rbln_config, RBLNConfig):
+             # Already resolved
+             return rbln_config, None
 
-         if model_save_dir is None:
-             model_save_dir = new_model_save_dir
-
-         # Create runtimes
-         if rbln_create_runtimes is None:
-             rbln_create_runtimes = rebel.npu_is_available()
-         if rbln_device_map is None:
-             rbln_device_map = {}
-             device_val = 0 if rbln_device is None else rbln_device
-             for key in rbln_config:
-                 rbln_device_map[key] = device_val
          else:
-             rbln_device_map = rbln_device_map
+             if rbln_config is None:
+                 rbln_config_dict = {}
+             else:
+                 rbln_config_dict = rbln_config
+
+             rbln_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
+             rbln_sub_configs_dict = cls._split_submodule_config(rbln_config_dict)
+
+             for key in rbln_config_dict:
+                 if key in rbln_kwargs:
+                     raise KeyError(f"duplicate key in both `rbln_config` and {key}")
+
+             merged_rbln_kwargs = copy.deepcopy(rbln_kwargs)
+             merged_rbln_kwargs.update(rbln_config_dict)
+
+             return (merged_rbln_kwargs, rbln_sub_configs_dict)
+
+     @classmethod
+     def _from_pretrained(
+         cls,
+         model_id: Union[str, Path],
+         config: "PretrainedConfig",
+         use_auth_token: Optional[Union[bool, str]] = None,
+         revision: Optional[str] = None,
+         force_download: bool = False,
+         cache_dir: Optional[str] = None,
+         subfolder: str = "",
+         local_files_only: bool = False,
+         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+         # passed from compile function
+         rbln_config: Optional[RBLNConfig] = None,
+         rbln_compiled_models: Optional[Dict[str, rebel.RBLNCompiledModel]] = None,
+         rbln_submodules: List["RBLNBaseModel"] = [],
+         **kwargs,
+     ) -> "RBLNBaseModel":
+         from_export_method = isinstance(rbln_config, RBLNConfig) and rbln_compiled_models is not None
+
+         if not from_export_method:
+             # from a compiled dir
+             rbln_kwargs, rbln_sub_configs_dict = cls.resolve_rbln_config(rbln_config, kwargs)
+
+             model_path_subfolder = cls._load_compiled_model_dir(
+                 model_id=model_id,
+                 use_auth_token=use_auth_token,
+                 revision=revision,
+                 force_download=force_download,
+                 cache_dir=cache_dir,
+                 subfolder=subfolder,
+                 local_files_only=local_files_only,
+             )
+
+             rbln_config = RBLNConfig.load(model_path_subfolder)
+             rbln_config.update_runtime_cfg(rbln_kwargs)
+
+             rbln_compiled_models = cls._load_compiled_models(model_path_subfolder)
+
+             if len(cls._rbln_submodules) > 0:
+                 rbln_submodules = cls._load_submodules(
+                     model_save_dir=model_id,
+                     rbln_sub_configs_dict=rbln_sub_configs_dict,
+                     rbln_kwargs=rbln_kwargs,
+                     **kwargs,
+                 )
+             else:
+                 rbln_submodules = []
+
+             if subfolder != "":
+                 model_save_dir = Path(model_path_subfolder).absolute().parent
+             else:
+                 model_save_dir = Path(model_path_subfolder).absolute()
+
+         return cls._from_compiled_models(
+             rbln_compiled_models=rbln_compiled_models,
+             rbln_config=rbln_config,
+             config=config,
+             model_save_dir=model_save_dir,
+             subfolder=subfolder,
+             rbln_submodules=rbln_submodules,
+             **kwargs,
+         )
+
+     @classmethod
+     def _from_compiled_models(
+         cls,
+         rbln_compiled_models: Dict[str, rebel.RBLNCompiledModel],
+         rbln_config: RBLNConfig,
+         config,
+         model_save_dir: str,
+         subfolder: str,
+         rbln_submodules: List["RBLNBaseModel"] = [],
+         **kwargs,
+     ):
+         if isinstance(model_save_dir, str):
+             model_save_dir = Path(model_save_dir)
+         preprocessors = maybe_load_preprocessors(model_save_dir.name, subfolder=subfolder)
+
+         # FIXME:: Should we convert it?
+         compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
+         rbln_compiled_models = [rbln_compiled_models[cm_name] for cm_name in compiled_model_names]
 
          # create runtimes only if `rbln_create_runtimes` is enabled
          models = (
-             cls._create_runtimes(rbln_compiled_models, rbln_device_map)
-             if rbln_create_runtimes
+             cls._create_runtimes(rbln_compiled_models, rbln_config.device_map)
+             if rbln_config.create_runtimes
              else UnavailableRuntime()
          )
 
-         if rbln_optimize_host_memory is None:
-             rbln_optimize_host_memory = True
-
          return cls(
              models,
              config,
@@ -307,99 +505,65 @@ class RBLNBaseModel(OptimizedModel, ABC):
              preprocessors,
              model_save_dir=model_save_dir,
              subfolder=subfolder,
-             rbln_compiled_models=(None if rbln_optimize_host_memory else rbln_compiled_models),
+             rbln_compiled_models=(None if rbln_config.optimize_host_memory else rbln_compiled_models),
+             rbln_submodules=rbln_submodules,
              **kwargs,
          )
 
      def __repr__(self):
-         return repr(self.model)
+         return repr(self.model) + repr(self.rbln_submodules)
 
      @classmethod
-     def compile(cls, model, rbln_runtime_config: Optional[RBLNRuntimeConfig] = None):
+     def compile(cls, model, rbln_compile_config: Optional[RBLNCompileConfig] = None):
          compiled_model = rebel.compile_from_torch(
              model,
-             input_info=rbln_runtime_config.input_info,
-             batch_size=rbln_runtime_config.batch_size,
-             fusion=rbln_runtime_config.fusion,
-             npu=rbln_runtime_config.npu,
-             tensor_parallel_size=rbln_runtime_config.tensor_parallel_size,
+             input_info=rbln_compile_config.input_info,
+             fusion=rbln_compile_config.fusion,
+             npu=rbln_compile_config.npu,
+             tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
          )
          return compiled_model
 
      @classmethod
      def get_rbln_config(
          cls,
-         **rbln_config_kwargs,
+         rbln_kwargs: Dict[str, Any],
+         **others,
      ) -> RBLNConfig:
          """
          Make default rbln-config for the model.
-
-         if `input_info` specified,
-         other kwargs but `input_info`, `batch_size` and `fusion` are ignored.
-
          kwargs for overriding model's config can be accepted.
-
          Note that batch_size should be specified with proper input_info.
          """
-
-         input_info = rbln_config_kwargs.pop("rbln_input_info", None)
-         batch_size = rbln_config_kwargs.pop("rbln_batch_size", None)
-         fusion = rbln_config_kwargs.pop("rbln_fusion", None)
-         npu = rbln_config_kwargs.pop("rbln_npu", None)
-         tensor_parallel_size = rbln_config_kwargs.pop("rbln_tensor_parallel_size", None)
-
-         if input_info is not None:
-             rbln_runtime_config = RBLNRuntimeConfig(
-                 input_info=input_info,
-                 batch_size=batch_size,
-                 fusion=fusion,
-                 npu=npu,
-                 tensor_parallel_size=tensor_parallel_size,
-             )
-             rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config])
-         else:
-             rbln_config = cls._get_rbln_config(rbln_batch_size=batch_size, **rbln_config_kwargs)
-             for k, rcfgs in rbln_config.items():
-                 for rcfg in rcfgs:
-                     rcfg: RBLNRuntimeConfig
-                     rcfg.fusion = fusion
-                     rcfg.npu = npu
-                     rcfg.tensor_parallel_size = tensor_parallel_size
-
+         rbln_config = cls._get_rbln_config(**others, rbln_kwargs=rbln_kwargs)
          return rbln_config
 
      @staticmethod
-     def pop_rbln_kwargs_from_kwargs(kwargs: dict):
+     def pop_rbln_kwargs_from_kwargs(kwargs: Dict[str, Any], runtime_only=False):
          keys = list(kwargs.keys())
-         rbln_constructor_kwargs = {
-             key: kwargs.pop(key)
-             for key in keys
-             if key
-             in [
-                 "rbln_device",
-                 "rbln_device_map",
-                 "rbln_create_runtimes",
-                 "rbln_optimize_host_memory",
-             ]
-         }
+         rbln_kwargs = {key[5:]: kwargs.pop(key) for key in keys if key.startswith("rbln_")}
 
-         keys = list(kwargs.keys())
-         rbln_config_kwargs = {key: kwargs.pop(key) for key in keys if key.startswith("rbln_")}
-         return rbln_config_kwargs, rbln_constructor_kwargs
+         if runtime_only:
+             rbln_kwargs = {
+                 key: value
+                 for key, value in rbln_kwargs.items()
+                 if key in {"create_runtimes", "optimize_host_memory", "device", "device_map"}
+             }
+
+         return rbln_kwargs
 
      def can_generate(self):
          return False
 
      def to(self, *args, **kwargs):
-         pass
+         # Do nothing
+         return self
 
      def __call__(self, *args, **kwargs):
          return self.forward(*args, **kwargs)
 
-     @classmethod
-     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
-         # Wrap the model if needed.
-         return model
+     def __post_init__(self, **kwargs):
+         self.dtype = torch.float32
 
      @classmethod
      def _from_transformers(cls, *args, **kwargs) -> "RBLNBaseModel":
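
The reworked pop_rbln_kwargs_from_kwargs above replaces the old fixed whitelist with a single rule: every keyword starting with rbln_ is popped and stored with the prefix stripped. A standalone sketch of that behavior (values illustrative, mirroring the dict comprehension above):

    kwargs = {"rbln_batch_size": 2, "rbln_device": 0, "torch_dtype": "float16"}
    rbln_kwargs = {key[5:]: kwargs.pop(key) for key in list(kwargs) if key.startswith("rbln_")}
    assert rbln_kwargs == {"batch_size": 2, "device": 0}  # prefix stripped
    assert kwargs == {"torch_dtype": "float16"}  # non-rbln kwargs left in place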
@@ -410,8 +574,14 @@ class RBLNBaseModel(OptimizedModel, ABC):
          return cls._export(*args, **kwargs)
 
      @classmethod
+     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+         # Wrap the model if needed.
+         return model
+
+     @classmethod
+     @abstractmethod
      def _get_rbln_config(cls, **rbln_config_kwargs) -> RBLNConfig:
-         raise NotImplementedError
+         pass
 
      @abstractmethod
      def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
@@ -429,20 +599,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
 
      @classmethod
      @abstractmethod
-     def _export(
-         cls,
-         model_id: Union[str, Path],
-         config: "PretrainedConfig",
-         use_auth_token: Optional[Union[bool, str]] = None,
-         revision: Optional[str] = None,
-         force_download: bool = False,
-         cache_dir: Optional[str] = None,
-         subfolder: str = "",
-         local_files_only: bool = False,
-         trust_remote_code: bool = False,
-         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-         **kwargs,
-     ):
+     def _export(cls, *args, **kwargs):
          """
          Exports a vanilla Transformers model into a rbln-compiled Module.
          """
@@ -491,8 +648,8 @@ class RBLNModel(RBLNBaseModel):
          subfolder: str = "",
          local_files_only: bool = False,
          trust_remote_code: bool = False,
-         rbln_config_kwargs: Optional[Dict[str, Any]] = None,
-         rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
+         # Some rbln-kwargs should be applied before loading the torch module (e.g. a quantized LLM)
+         rbln_kwargs: Optional[Dict[str, Any]] = None,
          **kwargs,
      ) -> "PreTrainedModel":
          task = kwargs.pop("task", None)
@@ -517,25 +674,31 @@ class RBLNModel(RBLNBaseModel):
 
          return model
 
+     @classmethod
+     def save_torch_artifacts(
+         cls,
+         model: "PreTrainedModel",
+         save_dir_path: Path,
+         subfolder: str,
+         rbln_config: RBLNConfig,
+     ):
+         """
+         If you are unavoidably running on a CPU rather than an RBLN device,
+         store the torch tensor, weight, etc. in this function.
+         """
+
      @classmethod
      def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
          model = cls.wrap_model_if_needed(model, rbln_config)
-         rbln_runtime_configs = list(rbln_config.values())
-         if len(rbln_runtime_configs) != 1:
-             raise ValueError
-         rbln_runtime_config = rbln_runtime_configs[0]
-         if len(rbln_runtime_config) != 1:
-             raise ValueError
-         rbln_runtime_config = rbln_runtime_config[0]
-
-         compiled_model = cls.compile(model, rbln_runtime_config=rbln_runtime_config)
+         rbln_compile_config = rbln_config.compile_cfgs[0]
+         compiled_model = cls.compile(model, rbln_compile_config=rbln_compile_config)
          return compiled_model
 
      @classmethod
      @torch.no_grad()
      def _export(
          cls,
-         model_id: str,
+         model_id: Union[str, Path],
          config: "PretrainedConfig",
          use_auth_token: Optional[Union[bool, str]] = None,
          revision: Optional[str] = None,
@@ -545,8 +708,12 @@ class RBLNModel(RBLNBaseModel):
          local_files_only: bool = False,
          trust_remote_code: bool = False,
          model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+         model: "PreTrainedModel" = None,
+         rbln_config: Optional[Dict[str, Any]] = None,
          **kwargs,
      ) -> "RBLNModel":
+         rbln_kwargs, rbln_sub_configs_dict = cls.resolve_rbln_config(rbln_config, kwargs)
+
          if model_save_dir is None:
              save_dir = TemporaryDirectory()
              save_dir_path = Path(save_dir.name)
@@ -558,48 +725,65 @@ class RBLNModel(RBLNBaseModel):
              save_dir_path = Path(model_save_dir)
              save_dir_path.mkdir(exist_ok=True)
 
-         rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-         model: "PreTrainedModel" = cls.get_pytorch_model(
-             model_id=model_id,
-             subfolder=subfolder,
-             revision=revision,
-             cache_dir=cache_dir,
-             use_auth_token=use_auth_token,
-             local_files_only=local_files_only,
-             force_download=force_download,
-             trust_remote_code=trust_remote_code,
-             rbln_config_kwargs=rbln_config_kwargs,
-             rbln_constructor_kwargs=rbln_constructor_kwargs,
-             **kwargs,
-         )
+         # Load pytorch model if needed.
+         if model is None:
+             model: "PreTrainedModel" = cls.get_pytorch_model(
+                 model_id=model_id,
+                 subfolder=subfolder,
+                 revision=revision,
+                 cache_dir=cache_dir,
+                 use_auth_token=use_auth_token,
+                 local_files_only=local_files_only,
+                 force_download=force_download,
+                 trust_remote_code=trust_remote_code,
+                 rbln_kwargs=rbln_kwargs,
+                 **kwargs,
+             )
+             preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
+         else:
+             preprocessors = []
 
          # FIXME :: optimum passes AutoConfig.
          config = model.config
+         if hasattr(model, "can_generate") and model.can_generate():
+             generation_config = model.generation_config
+             generation_config.save_pretrained(save_dir_path / subfolder)
 
          if not isinstance(config, PretrainedConfig):  # diffusers config
              config = PretrainedConfig(**config)
-
          config.save_pretrained(save_dir_path / subfolder)
-         preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
 
          # Get compilation arguments
-         if (rbln_config := rbln_config_kwargs.pop("rbln_config", None)) is None:
-             rbln_config = cls.get_rbln_config(preprocessors=preprocessors, model_config=config, **rbln_config_kwargs)
-         compiled_model = cls.get_compiled_model(model, rbln_config=rbln_config)
+         rbln_config: RBLNConfig = cls.get_rbln_config(
+             preprocessors=preprocessors, model_config=config, rbln_kwargs=rbln_kwargs
+         )
+         compiled_model: Union[rebel.RBLNCompiledModel, Dict[str, rebel.RBLNCompiledModel]] = cls.get_compiled_model(
+             model, rbln_config=rbln_config
+         )
 
          # Save compiled models
          (save_dir_path / subfolder).mkdir(exist_ok=True)
-         if isinstance(compiled_model, Iterable):
-             # compiled_model is an Iterable instance
-             for single_compiled_model, compiled_model_name in zip(compiled_model, rbln_config):
-                 single_compiled_model.save(save_dir_path / subfolder / f"{compiled_model_name}.rbln")
+         if not isinstance(compiled_model, dict):
+             compiled_models = {DEFAULT_COMPILED_MODEL_NAME: compiled_model}
+         else:
              compiled_models = compiled_model
+         for compiled_model_name, cm in compiled_models.items():
+             cm.save(save_dir_path / subfolder / f"{compiled_model_name}.rbln")
+         rbln_config.save(save_dir_path / subfolder)
+
+         cls.save_torch_artifacts(model, save_dir_path=save_dir_path, subfolder=subfolder, rbln_config=rbln_config)
 
+         # Load submodules
+         if len(cls._rbln_submodules) > 0:
+             rbln_submodules = cls._load_submodules(
+                 model=model,
+                 model_save_dir=save_dir,
+                 rbln_sub_configs_dict=rbln_sub_configs_dict,
+                 rbln_kwargs=rbln_kwargs,
+                 **kwargs,
+             )
          else:
-             compiled_model.save(save_dir_path / subfolder / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-             compiled_models = [compiled_model]
-         rbln_config.save(save_dir_path / subfolder)
+             rbln_submodules = []
 
          # Instantiate
          return cls._from_pretrained(
@@ -614,7 +798,7 @@ class RBLNModel(RBLNBaseModel):
              local_files_only=local_files_only,
              rbln_config=rbln_config,
              rbln_compiled_models=compiled_models,
-             **rbln_constructor_kwargs,
+             rbln_submodules=rbln_submodules,
              **kwargs,
          )
 
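The nested rbln_config handled by _split_submodule_config earlier gives per-submodule control during this export flow. A sketch (class and model id hypothetical; the nested dict mirrors the language_model example in that method's comment):

    # Top-level keys that match a submodule name are split out and forwarded
    # to that submodule's own _export call.
    model = RBLNMyVisionLanguageModel.from_pretrained(  # hypothetical class from the earlier sketch
        "some-org/some-vision-language-model",  # illustrative model id
        export=True,
        rbln_config={"language_model": {"rbln_tensor_parallel_size": 4}},
    )
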
@@ -635,16 +819,19 @@ class RBLNModel(RBLNBaseModel):
  class RBLNModelForQuestionAnswering(RBLNModel):
      model_type = "rbln_model"
      auto_model_class = AutoModelForQuestionAnswering
+     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
 
      @classmethod
      def _get_rbln_config(
          cls,
          preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
          model_config: Optional["PretrainedConfig"] = None,
-         rbln_max_seq_len: Optional[int] = None,
-         rbln_batch_size: Optional[int] = None,
-         rbln_model_input_names: Optional[List[str]] = None,
+         rbln_kwargs: Dict[str, Any] = {},
      ) -> RBLNConfig:
+         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+         rbln_batch_size = rbln_kwargs.get("batch_size", None)
+         rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+
          if rbln_max_seq_len is None:
              for tokenizer in preprocessors:
                  if hasattr(tokenizer, "model_max_length"):
@@ -656,19 +843,34 @@ class RBLNModelForQuestionAnswering(RBLNModel):
          if rbln_batch_size is None:
              rbln_batch_size = 1
 
-         if rbln_model_input_names is not None:
-             cls.rbln_model_input_names = rbln_model_input_names
+         if rbln_model_input_names is None:
+             for tokenizer in preprocessors:
+                 if hasattr(tokenizer, "model_input_names"):
+                     rbln_model_input_names = tokenizer.model_input_names
+                     break
+             if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                 rbln_model_input_names = cls.rbln_model_input_names
+             elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                 original_model_class = getattr(transformers, model_config.architectures[0])
+                 input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
+                 raise ValueError(
+                     "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
+                     f"and be sure to make the order of the inputs same as QuestionAnswering forward() arguments like ({list(input_names_order)})"
+                 )
 
          input_info = [
              (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-             for model_input_name in cls.rbln_model_input_names
+             for model_input_name in rbln_model_input_names
          ]
 
-         rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
-         rbln_runtime_config.batch_size = rbln_batch_size
-         meta = {"rbln_max_seq_len": rbln_max_seq_len}
-
-         return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+         rbln_config = RBLNConfig(
+             rbln_cls=cls.__name__,
+             compile_cfgs=[rbln_compile_config],
+             rbln_kwargs=rbln_kwargs,
+         )
+         rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+         return rbln_config
 
 
  class RBLNModelForImageClassification(RBLNModel):
@@ -684,9 +886,11 @@ class RBLNModelForImageClassification(RBLNModel):
          cls,
          preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
          model_config: Optional["PretrainedConfig"] = None,
-         rbln_image_size: Optional[int] = None,
-         rbln_batch_size: Optional[int] = None,
+         rbln_kwargs: Dict[str, Any] = {},
      ) -> RBLNConfig:
+         rbln_image_size = rbln_kwargs.get("image_size", None)
+         rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
          if rbln_image_size is None:
              for processor in preprocessors:
                  if hasattr(processor, "size"):
@@ -698,19 +902,19 @@ class RBLNModelForImageClassification(RBLNModel):
          if rbln_batch_size is None:
              rbln_batch_size = 1
 
+         if isinstance(rbln_image_size, int):
+             rbln_image_size = rbln_image_size, rbln_image_size
+
          input_info = [
              (
                  "pixel_values",
-                 [rbln_batch_size, 3, rbln_image_size, rbln_image_size],
+                 [rbln_batch_size, 3, rbln_image_size[0], rbln_image_size[1]],
                  "float32",
              )
          ]
 
-         rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
-         rbln_runtime_config.batch_size = rbln_batch_size
-         meta = {"rbln_image_size": rbln_image_size}
-
-         return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+         return RBLNConfig(rbln_cls=cls.__name__, compile_cfgs=[rbln_compile_config], rbln_kwargs=rbln_kwargs)
 
 
  class RBLNModelForAudioClassification(RBLNModel):
@@ -734,11 +938,11 @@ class RBLNModelForAudioClassification(RBLNModel):
          cls,
          preprocessors: "AutoFeatureExtractor",
          model_config: "PretrainedConfig",
-         rbln_batch_size: Optional[int] = None,
-         rbln_max_length: Optional[int] = None,
-         rbln_num_mel_bins: Optional[int] = None,
+         rbln_kwargs: Dict[str, Any] = {},
      ) -> RBLNConfig:
-         meta = {}
+         rbln_batch_size = rbln_kwargs.get("batch_size", None)
+         rbln_max_length = rbln_kwargs.get("max_length", None)
+         rbln_num_mel_bins = rbln_kwargs.get("num_mel_bins", None)
 
          if rbln_batch_size is None:
              rbln_batch_size = 1
@@ -764,11 +968,7 @@ class RBLNModelForAudioClassification(RBLNModel):
          if rbln_max_length is None:
              raise ValueError("`rbln_max_length` should be specified!")
 
-         meta["rbln_batch_size"] = rbln_batch_size
-         meta["rbln_max_length"] = rbln_max_length
-         meta["rbln_num_mel_bins"] = rbln_num_mel_bins
-
-         model_input_info = [
+         input_info = [
              (
                  "input_values",
                  [rbln_batch_size, rbln_max_length, rbln_num_mel_bins],
@@ -776,13 +976,19 @@ class RBLNModelForAudioClassification(RBLNModel):
              ),
          ]
 
-         rbln_runtime_config = RBLNRuntimeConfig(input_info=model_input_info, batch_size=rbln_batch_size)
-
-         rbln_config = RBLNConfig.from_rbln_runtime_configs(
-             [rbln_runtime_config],
-             _rbln_meta=meta,
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+         rbln_config = RBLNConfig(
+             rbln_cls=cls.__name__,
+             compile_cfgs=[rbln_compile_config],
+             rbln_kwargs=rbln_kwargs,
+         )
+         rbln_config.model_cfg.update(
+             {
+                 "batch_size": rbln_batch_size,
+                 "max_length": rbln_max_length,
+                 "num_mel_bins": rbln_num_mel_bins,
+             }
          )
-
          return rbln_config
 
 
@@ -807,10 +1013,12 @@ class RBLNModelForSequenceClassification(RBLNModel):
          cls,
          preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
          model_config: Optional["PretrainedConfig"] = None,
-         rbln_max_seq_len: Optional[int] = None,
-         rbln_model_input_names: Optional[List[str]] = None,
-         rbln_batch_size: Optional[int] = None,
+         rbln_kwargs: Dict[str, Any] = {},
      ) -> RBLNConfig:
+         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+         rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+         rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
          max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
              model_config, "max_position_embeddings", None
          )
@@ -829,21 +1037,36 @@ class RBLNModelForSequenceClassification(RBLNModel):
              raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
 
          if rbln_model_input_names is None:
-             # These are BERT's inputs
-             rbln_model_input_names = ["input_ids", "attention_mask"]
+             for tokenizer in preprocessors:
+                 if hasattr(tokenizer, "model_input_names"):
+                     rbln_model_input_names = tokenizer.model_input_names
+                     break
+             if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                 rbln_model_input_names = cls.rbln_model_input_names
+             elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                 original_model_class = getattr(transformers, model_config.architectures[0])
+                 input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
+                 raise ValueError(
+                     "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
+                     f"and be sure to make the order of the inputs same as SequenceClassification forward() arguments like ({list(input_names_order)})"
+                 )
 
          if rbln_batch_size is None:
              rbln_batch_size = 1
+
          input_info = [
              (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
              for model_input_name in rbln_model_input_names
          ]
 
-         rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
-         rbln_runtime_config.batch_size = rbln_batch_size
-         meta = {"rbln_max_seq_len": rbln_max_seq_len}
-
-         return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+         rbln_config = RBLNConfig(
+             rbln_cls=cls.__name__,
+             compile_cfgs=[rbln_compile_config],
+             rbln_kwargs=rbln_kwargs,
+         )
+         rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+         return rbln_config
 
 
  class RBLNModelForMaskedLM(RBLNModel):
@@ -855,10 +1078,12 @@ class RBLNModelForMaskedLM(RBLNModel):
          cls,
          preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
          model_config: Optional["PretrainedConfig"] = None,
-         rbln_max_seq_len: Optional[int] = None,
-         rbln_model_input_names: Optional[List[str]] = None,
-         rbln_batch_size: Optional[int] = None,
+         rbln_kwargs: Dict[str, Any] = {},
      ) -> RBLNConfig:
+         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+         rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+         rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
          max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
              model_config, "max_position_embeddings", None
          )
@@ -877,18 +1102,33 @@ class RBLNModelForMaskedLM(RBLNModel):
              raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
 
          if rbln_model_input_names is None:
-             # These are BERT's inputs
-             rbln_model_input_names = ["input_ids", "attention_mask"]
+             for tokenizer in preprocessors:
+                 if hasattr(tokenizer, "model_input_names"):
+                     rbln_model_input_names = tokenizer.model_input_names
+                     break
+             if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                 rbln_model_input_names = cls.rbln_model_input_names
+             elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                 original_model_class = getattr(transformers, model_config.architectures[0])
+                 input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
+                 raise ValueError(
+                     "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
+                     f"and be sure to make the order of the inputs same as MaskedLM forward() arguments like ({list(input_names_order)})"
+                 )
 
          if rbln_batch_size is None:
              rbln_batch_size = 1
+
          input_info = [
              (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
              for model_input_name in rbln_model_input_names
          ]
 
-         rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
-         rbln_runtime_config.batch_size = rbln_batch_size
-         meta = {"rbln_max_seq_len": rbln_max_seq_len}
-
-         return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+         rbln_config = RBLNConfig(
+             rbln_cls=cls.__name__,
+             compile_cfgs=[rbln_compile_config],
+             rbln_kwargs=rbln_kwargs,
+         )
+         rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+         return rbln_config
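
A closing note on the three task classes above: when rbln_model_input_names is not supplied, it is now inferred from the tokenizer's model_input_names (falling back to a class attribute), and a ValueError otherwise spells out the expected forward() order. Pinning the names explicitly remains possible, as in this sketch (model id illustrative):

    model = RBLNModelForSequenceClassification.from_pretrained(
        "BAAI/bge-reranker-base",
        export=True,
        rbln_max_seq_len=512,
        # Order must match the original model's forward() arguments.
        rbln_model_input_names=["input_ids", "attention_mask"],
    )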