optimum-rbln 0.1.9__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. optimum/rbln/__init__.py +47 -9
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +36 -31
  4. optimum/rbln/diffusers/models/controlnet.py +53 -43
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +40 -31
  6. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
  7. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +28 -23
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +28 -23
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +28 -37
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +30 -39
  11. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +24 -14
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +24 -15
  13. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +26 -17
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -17
  15. optimum/rbln/modeling_alias.py +6 -11
  16. optimum/rbln/modeling_base.py +467 -261
  17. optimum/rbln/modeling_config.py +199 -73
  18. optimum/rbln/transformers/__init__.py +43 -1
  19. optimum/rbln/transformers/models/__init__.py +23 -1
  20. optimum/rbln/transformers/models/auto/__init__.py +14 -0
  21. optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
  22. optimum/rbln/transformers/models/auto/modeling_auto.py +95 -0
  23. optimum/rbln/transformers/models/bart/__init__.py +1 -0
  24. optimum/rbln/transformers/models/bart/bart_architecture.py +203 -58
  25. optimum/rbln/transformers/models/bart/modeling_bart.py +125 -0
  26. optimum/rbln/transformers/models/bert/__init__.py +24 -0
  27. optimum/rbln/transformers/models/bert/modeling_bert.py +101 -0
  28. optimum/rbln/transformers/models/clip/__init__.py +1 -1
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +127 -26
  30. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
  31. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +409 -150
  32. optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -8
  33. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  34. optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
  35. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  37. optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
  38. optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
  39. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  40. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
  41. optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
  42. optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +662 -0
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -1
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
  47. optimum/rbln/transformers/models/phi/__init__.py +24 -0
  48. optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
  49. optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
  50. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  51. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
  52. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  53. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  54. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +198 -168
  55. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  56. optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
  57. optimum/rbln/transformers/models/t5/t5_architecture.py +122 -47
  58. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -12
  59. optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
  60. optimum/rbln/transformers/models/whisper/modeling_whisper.py +172 -111
  61. optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
  62. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +18 -16
  63. optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
  64. optimum/rbln/utils/import_utils.py +50 -1
  65. optimum/rbln/utils/logging.py +82 -0
  66. optimum/rbln/utils/runtime_utils.py +33 -0
  67. optimum/rbln/utils/timer_utils.py +43 -0
  68. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +9 -7
  69. optimum_rbln-0.1.12.dist-info/RECORD +103 -0
  70. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
  71. optimum_rbln-0.1.12.dist-info/entry_points.txt +4 -0
  72. optimum_rbln-0.1.9.dist-info/RECORD +0 -78
  73. {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
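The hunks below are from optimum/rbln/modeling_base.py (+467 -261), where the bulk of this release's refactoring lands: the scattered `rbln_*` keyword arguments (`rbln_max_seq_len`, `rbln_batch_size`, `rbln_device`, ...) are consolidated into a single `rbln_config` dictionary handled by the new `use_rbln_config` decorator, `RBLNRuntimeConfig` is replaced by `RBLNCompileConfig`, and a new `SubModulesMixin` supports models composed of several compiled submodules. A minimal sketch of the new-style entry point, assuming a question-answering checkpoint (the model id and output path are illustrative, not taken from this diff):

    from optimum.rbln import RBLNModelForQuestionAnswering

    # Unprefixed keys ("max_seq_len", "batch_size") are read back via
    # rbln_kwargs.get(...) inside _get_rbln_config, as the hunks below show.
    model = RBLNModelForQuestionAnswering.from_pretrained(
        "deepset/roberta-base-squad2",  # illustrative checkpoint
        export=True,                    # compile from the original torch weights
        rbln_config={"max_seq_len": 512, "batch_size": 1},
    )
    model.save_pretrained("./roberta-qa-rbln")  # writes the compiled *.rbln artifacts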
@@ -21,16 +21,19 @@
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
 
+import importlib
+import inspect
 import logging
 import os
 import shutil
 from abc import ABC, abstractmethod
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
 
 import rebel
 import torch
+import transformers
 from huggingface_hub import HfApi, HfFolder, hf_hub_download
 from optimum.exporters import TasksManager
 from optimum.modeling_base import OptimizedModel
@@ -46,9 +49,9 @@ from transformers import (
     PretrainedConfig,
 )
 
-from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig, use_rbln_config
 from .utils.runtime_utils import UnavailableRuntime
-from .utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
+from .utils.save_utils import maybe_load_preprocessors
 
 
 if TYPE_CHECKING:
@@ -62,7 +65,111 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class RBLNBaseModel(OptimizedModel, ABC):
+class SubModulesMixin:
+    """
+    _rbln_submodules = [
+        {"name": "vision_tower"},
+        {"name": "language_model"},
+    ]
+    """
+
+    _rbln_submodules: List[Dict[str, Any]] = []
+
+    def __init__(
+        self,
+        *,
+        rbln_submodules: List["RBLNBaseModel"] = [],
+        **kwargs,
+    ) -> None:
+        for submodule_meta, submodule in zip(self._rbln_submodules, rbln_submodules):
+            setattr(self, submodule_meta["name"], submodule)
+
+    @classmethod
+    def _export_submodules_from_model(
+        cls,
+        model: "PreTrainedModel",
+        model_save_dir: str,
+        rbln_kwargs: Dict[str, Any],
+        **kwargs,
+    ) -> List["RBLNBaseModel"]:
+        rbln_submodules = []
+        for submodule in cls._rbln_submodules:
+            submodule_name = submodule["name"]
+            torch_submodule: "PreTrainedModel" = getattr(model, submodule["name"])
+            cls_name = torch_submodule.__class__.__name__
+            submodule_cls: "RBLNModel" = getattr(importlib.import_module("optimum.rbln"), f"RBLN{cls_name}")
+
+            if submodule_name in rbln_kwargs:
+                kwargs["rbln_config"] = rbln_kwargs[submodule_name]
+
+            rbln_submodule = submodule_cls.from_model(
+                model=torch_submodule,
+                subfolder=submodule_name,
+                model_save_dir=model_save_dir,
+                **kwargs,
+            )
+
+            rbln_submodules.append(rbln_submodule)
+
+        return rbln_submodules
+
+    @classmethod
+    def _load_submodules_from_compiled_models(
+        cls,
+        model_save_dir: str,
+        rbln_kwargs: Dict[str, Any],
+        **kwargs,
+    ):
+        rbln_submodules = []
+        for submodule in cls._rbln_submodules:
+            submodule_name = submodule["name"]
+
+            if submodule_name in rbln_kwargs:
+                kwargs["rbln_config"] = rbln_kwargs[submodule_name]
+
+            # Get the cls name for calling the constructor of the rbln class
+            submodule_rbln_config = RBLNConfig.load(Path(model_save_dir) / submodule_name)
+            submodule_cls_name = submodule_rbln_config.meta["cls"]
+            submodule_cls: "RBLNBaseModel" = getattr(importlib.import_module("optimum.rbln"), submodule_cls_name)
+
+            config = OptimizedModel._load_config(Path(model_save_dir) / submodule_name)
+            rbln_submodule = submodule_cls._from_pretrained(
+                model_id=model_save_dir,
+                config=config,
+                subfolder=submodule_name,
+                **kwargs,
+            )
+            rbln_submodules.append(rbln_submodule)
+        return rbln_submodules
+
+    @classmethod
+    def _load_submodules(
+        cls,
+        model_save_dir,
+        rbln_kwargs,
+        model=None,
+        **kwargs,
+    ):
+        # Two ways:
+        # 1. Compile from a PyTorch object
+        # 2. Load from compiled files
+        if model is not None:
+            return cls._export_submodules_from_model(
+                model=model,
+                model_save_dir=model_save_dir,
+                rbln_kwargs=rbln_kwargs,
+                **kwargs,
+            )
+
+        else:
+            return cls._load_submodules_from_compiled_models(
+                model_save_dir=model_save_dir,
+                rbln_kwargs=rbln_kwargs,
+                **kwargs,
+            )
+
+
+class RBLNBaseModel(OptimizedModel, ABC, SubModulesMixin):
     """
     An abstract base class for compiling, loading, and saving neural network models from the huggingface
     transformers and diffusers libraries to run on RBLN NPU devices.
@@ -110,6 +217,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
         subfolder: str = "",
         rbln_compiled_models: Optional[rebel.RBLNCompiledModel] = None,
+        rbln_submodules: List["RBLNBaseModel"] = [],
         **kwargs,
     ):
         super().__init__(models, config)
@@ -127,11 +235,18 @@
         self.auto_model_class.register(AutoConfig, self.__class__)
 
         # copied from transformers PreTrainedModel __init__
-        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
+        if self.can_generate():
+            gen_config_dir = model_save_dir.name if isinstance(model_save_dir, TemporaryDirectory) else model_save_dir
+            self.generation_config = GenerationConfig.from_pretrained(gen_config_dir, trust_remote_code=True)
+        else:
+            self.generation_config = None
+
+        # self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
         if self.generation_config is not None:
             self.generation_config.use_cache = True
 
         self.device = torch.device("cpu")
+        self.training = False
 
         # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
         # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
@@ -146,11 +261,9 @@ class RBLNBaseModel(OptimizedModel, ABC):
         self.model_save_dir = model_save_dir
         self.subfolder = subfolder
 
+        self.rbln_submodules = rbln_submodules
         self.__post_init__(**kwargs)
 
-    def __post_init__(self, **kwargs):
-        pass
-
     def _save_pretrained(self, save_directory: Union[str, Path]):
         """
         Saves a model and its configuration file to a directory, so that it can be re-loaded using the
@@ -180,27 +293,18 @@
         )
 
     @classmethod
-    def _from_pretrained(
+    def _load_compiled_model_dir(
         cls,
         model_id: Union[str, Path],
-        config: "PretrainedConfig",
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
         cache_dir: Optional[str] = None,
         subfolder: str = "",
         local_files_only: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
-        # Runtime-related kwargs
-        rbln_device: Optional[List[int]] = None,
-        rbln_device_map: Optional[Dict[str, int]] = None,
-        rbln_create_runtimes: Optional[bool] = None,
-        # passed from compile function
-        rbln_config: Optional[RBLNConfig] = None,
-        rbln_compiled_models: Optional[List[rebel.RBLNCompiledModel]] = None,
-        rbln_optimize_host_memory: Optional[bool] = None,
-        **kwargs,
-    ) -> "RBLNBaseModel":
+    ):
+        # Find the compiled model,
+        # and prepare or download the cache folder from the HF Hub if needed.
         model_path = Path(model_id)
         if model_path.is_dir():
             model_path = model_path / subfolder
@@ -236,16 +340,7 @@
             )
 
         if model_path.is_dir():
-            if rbln_compiled_models is None:
-                rbln_config = RBLNConfig.load(str(model_path))
-                rbln_compiled_models = [
-                    rebel.RBLNCompiledModel(model_path / f"{compiled_model_name}.rbln")
-                    for compiled_model_name in rbln_config
-                ]
-                new_model_save_dir = model_path
-            else:
-                pass
-
+            model_path = str(model_path)
         else:
             rbln_config_filename = rbln_config_filenames[0]
             rbln_config_cache_path = hf_hub_download(
@@ -258,48 +353,106 @@
                 force_download=force_download,
                 local_files_only=local_files_only,
             )
-            rbln_config = RBLNConfig.load(Path(rbln_config_cache_path).parent)
-            rbln_compiled_models = []
-            for compiled_model_name in rbln_config:
-                model_cache_path = hf_hub_download(
-                    repo_id=model_id,
-                    filename=f"{compiled_model_name}.rbln",
-                    subfolder=subfolder,
-                    use_auth_token=use_auth_token,
-                    revision=revision,
-                    cache_dir=cache_dir,
-                    force_download=force_download,
-                    local_files_only=local_files_only,
+            model_path = Path(rbln_config_cache_path).parent
+
+        return model_path
+
+    @classmethod
+    def _load_compiled_models(cls, model_path: str):
+        compiled_models = Path(model_path).glob("*.rbln")
+        rbln_compiled_models = {cm.stem: rebel.RBLNCompiledModel(cm) for cm in compiled_models}
+        return rbln_compiled_models
+
+    @classmethod
+    @use_rbln_config
+    def _from_pretrained(
+        cls,
+        model_id: Union[str, Path],
+        config: "PretrainedConfig",
+        use_auth_token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        # passed from compile function
+        rbln_config: Optional[RBLNConfig] = None,
+        rbln_compiled_models: Optional[Dict[str, rebel.RBLNCompiledModel]] = None,
+        rbln_submodules: List["RBLNBaseModel"] = [],
+        **kwargs,
+    ) -> "RBLNBaseModel":
+        from_export_method = isinstance(rbln_config, RBLNConfig) and rbln_compiled_models is not None
+
+        if not from_export_method:
+            # from compiled dir
+            rbln_kwargs = rbln_config or {}
+
+            model_path_subfolder = cls._load_compiled_model_dir(
+                model_id=model_id,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                force_download=force_download,
+                cache_dir=cache_dir,
+                subfolder=subfolder,
+                local_files_only=local_files_only,
+            )
+
+            rbln_config = RBLNConfig.load(model_path_subfolder)
+            rbln_config.update_runtime_cfg(rbln_kwargs)
+
+            rbln_compiled_models = cls._load_compiled_models(model_path_subfolder)
+
+            if len(cls._rbln_submodules) > 0:
+                rbln_submodules = cls._load_submodules(
+                    model_save_dir=model_id,
+                    rbln_kwargs=rbln_kwargs,
+                    **kwargs,
                 )
-                rbln_compiled_models.append(rebel.RBLNCompiledModel(model_cache_path))
-            new_model_save_dir = Path(rbln_config_cache_path).parent
+            else:
+                rbln_submodules = []
 
-        preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
+            if subfolder != "":
+                model_save_dir = Path(model_path_subfolder).absolute().parent
+            else:
+                model_save_dir = Path(model_path_subfolder).absolute()
 
-        if model_save_dir is None:
-            model_save_dir = new_model_save_dir
-
-        # Create runtimes
-        if rbln_create_runtimes is None:
-            rbln_create_runtimes = rebel.npu_is_available()
-        if rbln_device_map is None:
-            rbln_device_map = {}
-            device_val = 0 if rbln_device is None else rbln_device
-            for key in rbln_config:
-                rbln_device_map[key] = device_val
-        else:
-            rbln_device_map = rbln_device_map
+        return cls._from_compiled_models(
+            rbln_compiled_models=rbln_compiled_models,
+            rbln_config=rbln_config,
+            config=config,
+            model_save_dir=model_save_dir,
+            subfolder=subfolder,
+            rbln_submodules=rbln_submodules,
+            **kwargs,
+        )
+
+    @classmethod
+    def _from_compiled_models(
+        cls,
+        rbln_compiled_models: Dict[str, rebel.RBLNCompiledModel],
+        rbln_config: RBLNConfig,
+        config: "PretrainedConfig",
+        model_save_dir: Union[Path, str],
+        subfolder: Union[Path, str],
+        rbln_submodules: List["RBLNBaseModel"] = [],
+        **kwargs,
+    ):
+        if isinstance(model_save_dir, str):
+            model_save_dir = Path(model_save_dir)
+        preprocessors = maybe_load_preprocessors(model_save_dir.name, subfolder=subfolder)
+
+        # FIXME:: Should we convert it?
+        compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
+        rbln_compiled_models = [rbln_compiled_models[cm_name] for cm_name in compiled_model_names]
 
         # create runtimes only if `rbln_create_runtimes` is enabled
         models = (
-            cls._create_runtimes(rbln_compiled_models, rbln_device_map)
-            if rbln_create_runtimes
+            cls._create_runtimes(rbln_compiled_models, rbln_config.device_map)
+            if rbln_config.create_runtimes
             else UnavailableRuntime()
         )
 
-        if rbln_optimize_host_memory is None:
-            rbln_optimize_host_memory = True
-
         return cls(
             models,
             config,
@@ -307,99 +460,50 @@
             preprocessors,
             model_save_dir=model_save_dir,
             subfolder=subfolder,
-            rbln_compiled_models=(None if rbln_optimize_host_memory else rbln_compiled_models),
+            rbln_compiled_models=(None if rbln_config.optimize_host_memory else rbln_compiled_models),
+            rbln_submodules=rbln_submodules,
             **kwargs,
         )
 
     def __repr__(self):
-        return repr(self.model)
+        return repr(self.model) + repr(self.rbln_submodules)
 
     @classmethod
-    def compile(cls, model, rbln_runtime_config: Optional[RBLNRuntimeConfig] = None):
+    def compile(cls, model, rbln_compile_config: Optional[RBLNCompileConfig] = None):
         compiled_model = rebel.compile_from_torch(
             model,
-            input_info=rbln_runtime_config.input_info,
-            batch_size=rbln_runtime_config.batch_size,
-            fusion=rbln_runtime_config.fusion,
-            npu=rbln_runtime_config.npu,
-            tensor_parallel_size=rbln_runtime_config.tensor_parallel_size,
+            input_info=rbln_compile_config.input_info,
+            fusion=rbln_compile_config.fusion,
+            npu=rbln_compile_config.npu,
+            tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
         )
         return compiled_model
 
     @classmethod
     def get_rbln_config(
         cls,
-        **rbln_config_kwargs,
+        rbln_kwargs: Dict[str, Any],
+        **others,
     ) -> RBLNConfig:
         """
         Make default rbln-config for the model.
-
-        if `input_info` is specified,
-        all kwargs except `input_info`, `batch_size` and `fusion` are ignored.
-
         kwargs for overriding model's config can be accepted.
-
         Note that batch_size should be specified with proper input_info.
         """
-
-        input_info = rbln_config_kwargs.pop("rbln_input_info", None)
-        batch_size = rbln_config_kwargs.pop("rbln_batch_size", None)
-        fusion = rbln_config_kwargs.pop("rbln_fusion", None)
-        npu = rbln_config_kwargs.pop("rbln_npu", None)
-        tensor_parallel_size = rbln_config_kwargs.pop("rbln_tensor_parallel_size", None)
-
-        if input_info is not None:
-            rbln_runtime_config = RBLNRuntimeConfig(
-                input_info=input_info,
-                batch_size=batch_size,
-                fusion=fusion,
-                npu=npu,
-                tensor_parallel_size=tensor_parallel_size,
-            )
-            rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config])
-        else:
-            rbln_config = cls._get_rbln_config(rbln_batch_size=batch_size, **rbln_config_kwargs)
-            for k, rcfgs in rbln_config.items():
-                for rcfg in rcfgs:
-                    rcfg: RBLNRuntimeConfig
-                    rcfg.fusion = fusion
-                    rcfg.npu = npu
-                    rcfg.tensor_parallel_size = tensor_parallel_size
-
+        rbln_config = cls._get_rbln_config(**others, rbln_kwargs=rbln_kwargs)
         return rbln_config
 
-    @staticmethod
-    def pop_rbln_kwargs_from_kwargs(kwargs: dict):
-        keys = list(kwargs.keys())
-        rbln_constructor_kwargs = {
-            key: kwargs.pop(key)
-            for key in keys
-            if key
-            in [
-                "rbln_device",
-                "rbln_device_map",
-                "rbln_create_runtimes",
-                "rbln_optimize_host_memory",
-            ]
-        }
-
-        keys = list(kwargs.keys())
-        rbln_config_kwargs = {key: kwargs.pop(key) for key in keys if key.startswith("rbln_")}
-        return rbln_config_kwargs, rbln_constructor_kwargs
-
     def can_generate(self):
         return False
 
     def to(self, *args, **kwargs):
-        pass
+        return self
 
     def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
 
-    @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
-        # Wrap the model if needed.
-        return model
+    def __post_init__(self, **kwargs):
+        self.dtype = torch.float32
 
     @classmethod
     def _from_transformers(cls, *args, **kwargs) -> "RBLNBaseModel":
@@ -410,8 +514,14 @@ class RBLNBaseModel(OptimizedModel, ABC):
         return cls._export(*args, **kwargs)
 
     @classmethod
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
+        # Wrap the model if needed.
+        return model
+
+    @classmethod
+    @abstractmethod
     def _get_rbln_config(cls, **rbln_config_kwargs) -> RBLNConfig:
-        raise NotImplementedError
+        pass
 
     @abstractmethod
     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
@@ -429,25 +539,49 @@
 
     @classmethod
     @abstractmethod
-    def _export(
+    def get_pytorch_model(cls, *args, **kwargs):
+        pass
+
+    @classmethod
+    @abstractmethod
+    @use_rbln_config
+    def from_model(
         cls,
-        model_id: Union[str, Path],
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
+        model: "PreTrainedModel",
+        rbln_config: Dict[str, Any] = {},
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        subfolder: str = "",
         **kwargs,
     ):
-        """
-        Exports a vanilla Transformers model into a rbln-compiled Module.
-        """
        pass
 
+    @classmethod
+    @use_rbln_config
+    def _export(
+        cls,
+        model_id: Union[str, Path],
+        config: "PretrainedConfig",  # FIXME : optimum passes config, but we ignore it.
+        rbln_config: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ) -> "RBLNModel":
+        subfolder = kwargs.get("subfolder", "")
+        model_save_dir = kwargs.pop("model_save_dir", None)
+
+        rbln_kwargs = rbln_config
+        model: "PreTrainedModel" = cls.get_pytorch_model(
+            model_id=model_id,
+            rbln_kwargs=rbln_kwargs,
+            **kwargs,
+        )
+        preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
+        return cls.from_model(
+            model,
+            rbln_config=rbln_config,
+            preprocessors=preprocessors,
+            model_save_dir=model_save_dir,
+            **kwargs,
+        )
+
 
 class RBLNModel(RBLNBaseModel):
     """
@@ -491,8 +625,8 @@
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
-        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
-        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
+        # Some rbln-kwargs should be applied before loading the torch module (e.g. a quantized llm)
+        rbln_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> "PreTrainedModel":
         task = kwargs.pop("task", None)
@@ -517,36 +651,40 @@
 
         return model
 
+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: "PreTrainedModel",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNConfig,
+    ):
+        """
+        If you are unavoidably running on a CPU rather than an RBLN device,
+        store the torch tensor, weight, etc. in this function.
+        """
+
     @classmethod
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
         model = cls.wrap_model_if_needed(model, rbln_config)
-        rbln_runtime_configs = list(rbln_config.values())
-        if len(rbln_runtime_configs) != 1:
-            raise ValueError
-        rbln_runtime_config = rbln_runtime_configs[0]
-        if len(rbln_runtime_config) != 1:
-            raise ValueError
-        rbln_runtime_config = rbln_runtime_config[0]
-
-        compiled_model = cls.compile(model, rbln_runtime_config=rbln_runtime_config)
+        rbln_compile_config = rbln_config.compile_cfgs[0]
+        compiled_model = cls.compile(model, rbln_compile_config=rbln_compile_config)
         return compiled_model
 
     @classmethod
-    @torch.no_grad()
-    def _export(
+    @use_rbln_config
+    def from_model(
         cls,
-        model_id: str,
-        config: "PretrainedConfig",
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
+        model: "PreTrainedModel",
+        rbln_config: Dict[str, Any] = {},
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        subfolder: str = "",
         **kwargs,
-    ) -> "RBLNModel":
+    ):
+        preprocessors = kwargs.pop("preprocessors", [])
+        rbln_kwargs = rbln_config
+
+        # Directory to save compile artifacts (.rbln) and original configs
         if model_save_dir is None:
             save_dir = TemporaryDirectory()
             save_dir_path = Path(save_dir.name)
@@ -558,63 +696,63 @@
             save_dir_path = Path(model_save_dir)
         save_dir_path.mkdir(exist_ok=True)
 
-        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
-
-        model: "PreTrainedModel" = cls.get_pytorch_model(
-            model_id=model_id,
-            subfolder=subfolder,
-            revision=revision,
-            cache_dir=cache_dir,
-            use_auth_token=use_auth_token,
-            local_files_only=local_files_only,
-            force_download=force_download,
-            trust_remote_code=trust_remote_code,
-            rbln_config_kwargs=rbln_config_kwargs,
-            rbln_constructor_kwargs=rbln_constructor_kwargs,
-            **kwargs,
-        )
+        # (Optional) Save preprocessors (tokenizer, image preprocessors, etc.)
+        for preprocessor in preprocessors:
+            preprocessor.save_pretrained(save_dir_path)
 
-        # FIXME :: optimum passes AutoConfig.
+        # Save configs
+        # FIXME :: optimum passes AutoConfig. But here we ignore it.
         config = model.config
-
+        if hasattr(model, "can_generate") and model.can_generate():
+            generation_config = model.generation_config
+            generation_config.save_pretrained(save_dir_path / subfolder)
         if not isinstance(config, PretrainedConfig):  # diffusers config
             config = PretrainedConfig(**config)
-
         config.save_pretrained(save_dir_path / subfolder)
-        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
 
-        # Get compilation arguments
-        if (rbln_config := rbln_config_kwargs.pop("rbln_config", None)) is None:
-            rbln_config = cls.get_rbln_config(preprocessors=preprocessors, model_config=config, **rbln_config_kwargs)
-        compiled_model = cls.get_compiled_model(model, rbln_config=rbln_config)
+        # Get compilation arguments (e.g. input_info)
+        rbln_config: RBLNConfig = cls.get_rbln_config(
+            preprocessors=preprocessors, model_config=config, rbln_kwargs=rbln_kwargs
+        )
+        # rbln_config.update_runtime_cfg(rbln_kwargs)  # This is done in get_rbln_config
+
+        compiled_model: Union[rebel.RBLNCompiledModel, Dict[str, rebel.RBLNCompiledModel]] = cls.get_compiled_model(
+            model, rbln_config=rbln_config
+        )
 
-        # Save compiled models
+        # Save compiled models (.rbln)
         (save_dir_path / subfolder).mkdir(exist_ok=True)
-        if isinstance(compiled_model, Iterable):
-            # compiled_model is an Iterable instance
-            for single_compiled_model, compiled_model_name in zip(compiled_model, rbln_config):
-                single_compiled_model.save(save_dir_path / subfolder / f"{compiled_model_name}.rbln")
+        if not isinstance(compiled_model, dict):
+            compiled_models = {DEFAULT_COMPILED_MODEL_NAME: compiled_model}
+        else:
             compiled_models = compiled_model
+        for compiled_model_name, cm in compiled_models.items():
+            cm.save(save_dir_path / subfolder / f"{compiled_model_name}.rbln")
+        rbln_config.save(save_dir_path / subfolder)
+
+        # Save torch artifacts (e.g. the embedding matrix, if needed)
+        cls.save_torch_artifacts(model, save_dir_path=save_dir_path, subfolder=subfolder, rbln_config=rbln_config)
 
+        # Load submodules
+        if len(cls._rbln_submodules) > 0:
+            rbln_submodules = cls._load_submodules(
+                model=model,
+                model_save_dir=save_dir,
+                rbln_kwargs=rbln_kwargs,
+                **kwargs,
+            )
         else:
-            compiled_model.save(save_dir_path / subfolder / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
-            compiled_models = [compiled_model]
-            rbln_config.save(save_dir_path / subfolder)
+            rbln_submodules = []
 
         # Instantiate
         return cls._from_pretrained(
             model_id=save_dir_path,
             config=config,
             model_save_dir=save_dir,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
             subfolder=subfolder,
-            local_files_only=local_files_only,
             rbln_config=rbln_config,
             rbln_compiled_models=compiled_models,
-            **rbln_constructor_kwargs,
+            rbln_submodules=rbln_submodules,
             **kwargs,
         )
 
@@ -633,18 +771,20 @@
 
 
 class RBLNModelForQuestionAnswering(RBLNModel):
-    model_type = "rbln_model"
     auto_model_class = AutoModelForQuestionAnswering
+    rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
 
     @classmethod
     def _get_rbln_config(
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
-        rbln_model_input_names: Optional[List[str]] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+
         if rbln_max_seq_len is None:
             for tokenizer in preprocessors:
                 if hasattr(tokenizer, "model_max_length"):
@@ -656,19 +796,34 @@ class RBLNModelForQuestionAnswering(RBLNModel):
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
-        if rbln_model_input_names is not None:
-            cls.rbln_model_input_names = rbln_model_input_names
+        if rbln_model_input_names is None:
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_input_names"):
+                    rbln_model_input_names = tokenizer.model_input_names
+                    break
+            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                rbln_model_input_names = cls.rbln_model_input_names
+            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                original_model_class = getattr(transformers, model_config.architectures[0])
+                input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
+                raise ValueError(
+                    "Specify the model input names obtained from the tokenizer via `rbln_model_input_names`, "
+                    f"and make sure their order matches the QuestionAnswering forward() arguments, e.g. ({list(input_names_order)})"
+                )
 
         input_info = [
             (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in cls.rbln_model_input_names
+            for model_input_name in rbln_model_input_names
         ]
 
-        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
-        rbln_runtime_config.batch_size = rbln_batch_size
-        meta = {"rbln_max_seq_len": rbln_max_seq_len}
-
-        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        return rbln_config
 
 
@@ -676,7 +831,6 @@ class RBLNModelForImageClassification(RBLNModel):
     This is a generic model class that will be instantiated as one of the model classes of the library (with an image classification head) when created with the from_pretrained() class method.
     """
 
-    model_type = "rbln_model"
     auto_model_class = AutoModelForImageClassification
 
     @classmethod
@@ -684,33 +838,52 @@
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_image_size: Optional[int] = None,
-        rbln_batch_size: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
+        rbln_image_size = rbln_kwargs.get("image_size", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
         if rbln_image_size is None:
             for processor in preprocessors:
                 if hasattr(processor, "size"):
-                    rbln_image_size = processor.size["shortest_edge"]
+                    if all(required_key in processor.size.keys() for required_key in ["height", "width"]):
+                        rbln_image_size = (processor.size["height"], processor.size["width"])
+                    elif "shortest_edge" in processor.size.keys():
+                        rbln_image_size = (processor.size["shortest_edge"], processor.size["shortest_edge"])
+                    elif "longest_edge" in processor.size.keys():
+                        rbln_image_size = (processor.size["longest_edge"], processor.size["longest_edge"])
                     break
+
+            if rbln_image_size is None:
+                rbln_image_size = model_config.image_size
+
         if rbln_image_size is None:
             raise ValueError("`rbln_image_size` should be specified!")
 
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
+        if isinstance(rbln_image_size, int):
+            rbln_image_height, rbln_image_width = rbln_image_size, rbln_image_size
+        elif isinstance(rbln_image_size, (list, tuple)):
+            rbln_image_height, rbln_image_width = rbln_image_size[0], rbln_image_size[1]
+        elif isinstance(rbln_image_size, dict):
+            rbln_image_height, rbln_image_width = rbln_image_size["height"], rbln_image_size["width"]
+        else:
+            raise ValueError(
+                "`rbln_image_size` should be an `int` (e.g. 224), a `tuple` (e.g. (224, 224)), or a `dict` (e.g. {'height': 224, 'width': 224})"
+            )
+
         input_info = [
             (
                 "pixel_values",
-                [rbln_batch_size, 3, rbln_image_size, rbln_image_size],
+                [rbln_batch_size, 3, rbln_image_height, rbln_image_width],
                 "float32",
             )
         ]
 
-        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
-        rbln_runtime_config.batch_size = rbln_batch_size
-        meta = {"rbln_image_size": rbln_image_size}
-
-        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        return RBLNConfig(rbln_cls=cls.__name__, compile_cfgs=[rbln_compile_config], rbln_kwargs=rbln_kwargs)
 
 
 class RBLNModelForAudioClassification(RBLNModel):
@@ -726,7 +899,6 @@ class RBLNModelForAudioClassification(RBLNModel):
     Currently, this model class only supports the 'AST' model from the transformers library. Future updates may include support for additional model types.
     """
 
-    model_type = "rbln_model"
     auto_model_class = AutoModelForAudioClassification
 
     @classmethod
@@ -734,11 +906,11 @@
         cls,
         preprocessors: "AutoFeatureExtractor",
         model_config: "PretrainedConfig",
-        rbln_batch_size: Optional[int] = None,
-        rbln_max_length: Optional[int] = None,
-        rbln_num_mel_bins: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-        meta = {}
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_max_length = rbln_kwargs.get("max_length", None)
+        rbln_num_mel_bins = rbln_kwargs.get("num_mel_bins", None)
 
         if rbln_batch_size is None:
             rbln_batch_size = 1
@@ -764,11 +936,7 @@
         if rbln_max_length is None:
             raise ValueError("`rbln_max_length` should be specified!")
 
-        meta["rbln_batch_size"] = rbln_batch_size
-        meta["rbln_max_length"] = rbln_max_length
-        meta["rbln_num_mel_bins"] = rbln_num_mel_bins
-
-        model_input_info = [
+        input_info = [
             (
                 "input_values",
                 [rbln_batch_size, rbln_max_length, rbln_num_mel_bins],
@@ -776,13 +944,19 @@
             ),
         ]
 
-        rbln_runtime_config = RBLNRuntimeConfig(input_info=model_input_info, batch_size=rbln_batch_size)
-
-        rbln_config = RBLNConfig.from_rbln_runtime_configs(
-            [rbln_runtime_config],
-            _rbln_meta=meta,
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update(
+            {
+                "batch_size": rbln_batch_size,
+                "max_length": rbln_max_length,
+                "num_mel_bins": rbln_num_mel_bins,
+            }
         )
-
         return rbln_config
 
 
@@ -799,7 +973,6 @@ class RBLNModelForSequenceClassification(RBLNModel):
     Currently, this model class supports the 'XLMRoberta' and 'Roberta' models from the transformers library. Future updates may include support for additional model types.
     """
 
-    model_type = "rbln_model"
     auto_model_class = AutoModelForSequenceClassification
 
     @classmethod
@@ -807,10 +980,12 @@
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_model_input_names: Optional[List[str]] = None,
-        rbln_batch_size: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
    ) -> RBLNConfig:
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
         )
@@ -829,25 +1004,39 @@
             raise ValueError("`rbln_enc_max_seq_len` should be less than or equal to max_position_embeddings!")
 
         if rbln_model_input_names is None:
-            # These are BERT's inputs
-            rbln_model_input_names = ["input_ids", "attention_mask"]
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_input_names"):
+                    rbln_model_input_names = tokenizer.model_input_names
+                    break
+            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                rbln_model_input_names = cls.rbln_model_input_names
+            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                original_model_class = getattr(transformers, model_config.architectures[0])
+                input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
+                raise ValueError(
+                    "Specify the model input names obtained from the tokenizer via `rbln_model_input_names`, "
+                    f"and make sure their order matches the SequenceClassification forward() arguments, e.g. ({list(input_names_order)})"
+                )
 
         if rbln_batch_size is None:
             rbln_batch_size = 1
+
         input_info = [
             (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
             for model_input_name in rbln_model_input_names
         ]
 
-        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
-        rbln_runtime_config.batch_size = rbln_batch_size
-        meta = {"rbln_max_seq_len": rbln_max_seq_len}
-
-        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        return rbln_config
 
 
 class RBLNModelForMaskedLM(RBLNModel):
-    model_type = "rbln_model"
     auto_model_class = AutoModelForMaskedLM
 
     @classmethod
@@ -855,10 +1044,12 @@
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
         model_config: Optional["PretrainedConfig"] = None,
-        rbln_max_seq_len: Optional[int] = None,
-        rbln_model_input_names: Optional[List[str]] = None,
-        rbln_batch_size: Optional[int] = None,
+        rbln_kwargs: Dict[str, Any] = {},
    ) -> RBLNConfig:
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+
         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
             model_config, "max_position_embeddings", None
         )
@@ -877,18 +1068,33 @@
             raise ValueError("`rbln_enc_max_seq_len` should be less than or equal to max_position_embeddings!")
 
         if rbln_model_input_names is None:
-            # These are BERT's inputs
-            rbln_model_input_names = ["input_ids", "attention_mask"]
+            for tokenizer in preprocessors:
+                if hasattr(tokenizer, "model_input_names"):
+                    rbln_model_input_names = tokenizer.model_input_names
+                    break
+            if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
+                rbln_model_input_names = cls.rbln_model_input_names
+            elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
+                original_model_class = getattr(transformers, model_config.architectures[0])
+                input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
+                raise ValueError(
+                    "Specify the model input names obtained from the tokenizer via `rbln_model_input_names`, "
+                    f"and make sure their order matches the MaskedLM forward() arguments, e.g. ({list(input_names_order)})"
+                )
 
         if rbln_batch_size is None:
             rbln_batch_size = 1
+
         input_info = [
             (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
             for model_input_name in rbln_model_input_names
         ]
 
-        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
-        rbln_runtime_config.batch_size = rbln_batch_size
-        meta = {"rbln_max_seq_len": rbln_max_seq_len}
-
-        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
+        return rbln_config
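
For context on how the `SubModulesMixin` introduced above is meant to be consumed: a composite model class lists its torch submodule attributes in `_rbln_submodules`, and `from_model` / `_from_pretrained` then compile or load each one through the `RBLN{cls_name}` lookup shown in `_export_submodules_from_model`. A minimal sketch with a hypothetical class (in this release the real consumer is presumably the new `RBLNLlavaNextForConditionalGeneration` in `modeling_llava_next.py`):

    from typing import Any, Dict, List

    from optimum.rbln import RBLNModel


    class RBLNMyVisionLanguageModel(RBLNModel):  # hypothetical composite model
        # Each entry names an attribute of the torch model. On export,
        # _export_submodules_from_model resolves RBLN{submodule_class_name}
        # in optimum.rbln and compiles it into its own subfolder; per-submodule
        # overrides ride along as rbln_config={"vision_tower": {...}, ...}.
        _rbln_submodules: List[Dict[str, Any]] = [
            {"name": "vision_tower"},
            {"name": "language_model"},
        ]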