optimum-rbln 0.1.4__py3-none-any.whl → 0.1.8__py3-none-any.whl

This diff compares the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only.
Files changed (52)
  1. optimum/rbln/__init__.py +21 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
  5. optimum/rbln/diffusers/models/controlnet.py +3 -0
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +3 -3
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +22 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +109 -53
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +114 -53
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
  16. optimum/rbln/modeling_alias.py +14 -0
  17. optimum/rbln/modeling_base.py +282 -100
  18. optimum/rbln/modeling_seq2seq.py +58 -132
  19. optimum/rbln/transformers/__init__.py +8 -0
  20. optimum/rbln/transformers/cache_utils.py +111 -0
  21. optimum/rbln/transformers/generation/utils.py +0 -2
  22. optimum/rbln/transformers/models/__init__.py +3 -0
  23. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  24. optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
  25. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  26. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +515 -0
  27. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +349 -0
  28. optimum/rbln/transformers/models/dpt/__init__.py +24 -0
  29. optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
  30. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  31. optimum/rbln/transformers/models/gemma/gemma_architecture.py +116 -0
  32. optimum/rbln/transformers/models/gemma/modeling_gemma.py +61 -0
  33. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +200 -174
  34. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +57 -293
  35. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -613
  36. optimum/rbln/transformers/models/llama/modeling_llama.py +9 -469
  37. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  38. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  39. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  40. optimum/rbln/transformers/models/midm/modeling_midm.py +40 -308
  41. optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
  42. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  43. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  44. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +125 -0
  45. optimum/rbln/utils/__init__.py +1 -1
  46. optimum/rbln/utils/import_utils.py +46 -0
  47. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/METADATA +18 -53
  48. optimum_rbln-0.1.8.dist-info/RECORD +73 -0
  49. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/WHEEL +1 -1
  50. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -759
  51. optimum_rbln-0.1.4.dist-info/RECORD +0 -63
  52. {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.8.dist-info}/licenses/LICENSE +0 -0
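
The unified diff below covers optimum/rbln/modeling_base.py (item 17, +282 -100).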
@@ -22,10 +22,12 @@
 # from Rebellions Inc.
 
 import logging
+import os
+import shutil
 from abc import ABC, abstractmethod
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
 
 import rebel
 import torch
@@ -37,7 +39,9 @@ from transformers import (
     AutoModel,
     AutoModelForAudioClassification,
     AutoModelForImageClassification,
+    AutoModelForMaskedLM,
     AutoModelForQuestionAnswering,
+    AutoModelForSequenceClassification,
     GenerationConfig,
     PretrainedConfig,
 )
@@ -50,16 +54,7 @@ from .utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
-
-
-def listify(var: Any):
-    if isinstance(var, list):
-        return var
-    elif var is not None:
-        return [var]
-    else:
-        return None
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
 
 
 class RBLNBaseModel(OptimizedModel, ABC):
@@ -103,23 +98,22 @@ class RBLNBaseModel(OptimizedModel, ABC):
 
     def __init__(
         self,
-        models: List[rebel.RBLNCompiledModel],
+        models: List[rebel.Runtime],
         config: "PretrainedConfig",
+        rbln_config: RBLNConfig,
         preprocessors: Optional[List],
-        rbln_config: Optional[RBLNConfig],
-        rbln_device: Optional[List[int]] = None,
-        rbln_device_map: Optional[Dict[str, int]] = None,
-        rbln_create_runtimes: Optional[bool] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        subfolder: str = "",
+        rbln_compiled_models: Optional[rebel.RBLNCompiledModel] = None,
         **kwargs,
     ):
         super().__init__(models, config)
         if not isinstance(self.config, PretrainedConfig):  # if diffusers config
             self.config = PretrainedConfig(**self.config)
 
-        self.models = listify(self.model)
-
+        self.rbln_config = rbln_config
         self.preprocessors = [] if preprocessors is None else preprocessors
+        self.compiled_models = rbln_compiled_models
 
         # Registers the RBLNBaseModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
         # a pipeline https://github.com/huggingface/transformers/blob/3d3204c025b6b5de013e07dd364208e28b4d9589/src/transformers/pipelines/base.py#L940
@@ -127,18 +121,6 @@ class RBLNBaseModel(OptimizedModel, ABC):
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)
 
-        self.rbln_config = rbln_config
-        self.compiled_models: List[rebel.RBLNCompiledModel] = models
-
-        if rbln_device_map is None:
-            self.rbln_device_map = {}
-            device_val = 0 if rbln_device is None else rbln_device
-            for key in self.rbln_config:
-                self.rbln_device_map[key] = device_val
-
-        else:
-            self.rbln_device_map = rbln_device_map
-
         # copied from tranformers PreTrainedModel __init__
         self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
         if self.generation_config is not None:
@@ -146,15 +128,9 @@
 
         self.device = torch.device("cpu")
 
-        if rbln_create_runtimes is None:
-            rbln_create_runtimes = rebel.npu_is_available()
-
-        # create runtimes only if `rbln_create_runtimes` is enabled
-        self.runtimes = self._create_runtimes(self.rbln_device_map) if rbln_create_runtimes else UnavailableRuntime()
-
         # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
         # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
-        # would end-up removing the directory containing the underlying ONNX model.
+        # would end-up removing the directory containing the underlying RBLN model.
         self._model_save_dir_tempdirectory_instance = None
         if isinstance(model_save_dir, TemporaryDirectory):
             self._model_save_dir_tempdirectory_instance = model_save_dir
@@ -163,6 +139,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
             self.model_save_dir = Path(model_save_dir)
         else:
             self.model_save_dir = model_save_dir
+        self.subfolder = subfolder
 
         self.__post_init__(**kwargs)
 
@@ -178,11 +155,14 @@ class RBLNBaseModel(OptimizedModel, ABC):
             save_directory (`Union[str, Path]`):
                 Directory where to save the model file.
         """
-
-        for compiled_model, compiled_model_name in zip(self.compiled_models, self.rbln_config):
-            dst_path = Path(save_directory) / f"{compiled_model_name}.rbln"
-            compiled_model.save(dst_path)
-        self.rbln_config.save(save_directory)
+        real_save_dir = self.model_save_dir / self.subfolder
+        if os.path.exists(real_save_dir) and os.path.isdir(real_save_dir):
+            shutil.copytree(real_save_dir, save_directory, dirs_exist_ok=True)
+            self.config.save_pretrained(save_directory)
+            if self.generation_config is not None:
+                self.generation_config.save_pretrained(save_directory)
+        else:
+            raise FileNotFoundError(f"Saving compiled model failed.({real_save_dir}).")
 
     @classmethod
     def _from_pretrained(
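
Note: save_pretrained no longer re-serializes compiled models held in memory; it copies the artifacts that compilation already wrote under model_save_dir/subfolder, then writes the config (and generation config, if any) next to them. A minimal round-trip sketch, with an illustrative target path:

    model.save_pretrained("./compiled-model-dir")                 # copies the *.rbln files plus the rbln config
    reloaded = RBLNModel.from_pretrained("./compiled-model-dir")  # reloads via RBLNConfig.load()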
@@ -196,6 +176,14 @@
         subfolder: str = "",
         local_files_only: bool = False,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        # Runtime - related kwargs
+        rbln_device: Optional[List[int]] = None,
+        rbln_device_map: Optional[Dict[str, int]] = None,
+        rbln_create_runtimes: Optional[bool] = None,
+        # passed from compile function
+        rbln_config: Optional[RBLNConfig] = None,
+        rbln_compiled_models: Optional[List[rebel.RBLNCompiledModel]] = None,
+        rbln_optimize_host_memory: Optional[bool] = None,
         **kwargs,
     ) -> "RBLNBaseModel":
         model_path = Path(model_id)
@@ -228,12 +216,15 @@
             )
 
         if model_path.is_dir():
-            rbln_config = RBLNConfig.load(str(model_path))
-            models = [
-                rebel.RBLNCompiledModel(model_path / f"{compiled_model_name}.rbln")
-                for compiled_model_name in rbln_config
-            ]
-            new_model_save_dir = model_path
+            if rbln_compiled_models is None:
+                rbln_config = RBLNConfig.load(str(model_path))
+                rbln_compiled_models = [
+                    rebel.RBLNCompiledModel(model_path / f"{compiled_model_name}.rbln")
+                    for compiled_model_name in rbln_config
+                ]
+                new_model_save_dir = model_path
+            else:
+                pass
 
         else:
             rbln_config_filename = rbln_config_filenames[0]
@@ -248,7 +239,7 @@
                 local_files_only=local_files_only,
             )
             rbln_config = RBLNConfig.load(Path(rbln_config_cache_path).parent)
-            models = []
+            rbln_compiled_models = []
             for compiled_model_name in rbln_config:
                 model_cache_path = hf_hub_download(
                     repo_id=model_id,
@@ -260,7 +251,7 @@
                     force_download=force_download,
                     local_files_only=local_files_only,
                 )
-                models.append(rebel.RBLNCompiledModel(model_cache_path))
+                rbln_compiled_models.append(rebel.RBLNCompiledModel(model_cache_path))
             new_model_save_dir = Path(rbln_config_cache_path).parent
 
         preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
@@ -268,17 +259,40 @@
         if model_save_dir is None:
             model_save_dir = new_model_save_dir
 
+        # Create runtimes
+        if rbln_create_runtimes is None:
+            rbln_create_runtimes = rebel.npu_is_available()
+        if rbln_device_map is None:
+            rbln_device_map = {}
+            device_val = 0 if rbln_device is None else rbln_device
+            for key in rbln_config:
+                rbln_device_map[key] = device_val
+        else:
+            rbln_device_map = rbln_device_map
+
+        # create runtimes only if `rbln_create_runtimes` is enabled
+        models = (
+            cls._create_runtimes(rbln_compiled_models, rbln_device_map)
+            if rbln_create_runtimes
+            else UnavailableRuntime()
+        )
+
+        if rbln_optimize_host_memory is None:
+            rbln_optimize_host_memory = True
+
         return cls(
             models,
             config,
+            rbln_config,
             preprocessors,
-            rbln_config=rbln_config,
             model_save_dir=model_save_dir,
+            subfolder=subfolder,
+            rbln_compiled_models=None if rbln_optimize_host_memory else rbln_compiled_models,
             **kwargs,
         )
 
     def __repr__(self):
-        return repr(self.runtimes)
+        return repr(self.model)
 
     @classmethod
     def compile(cls, model, rbln_runtime_config: Optional[RBLNRuntimeConfig] = None):
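
Note: runtime construction moved from __init__ into _from_pretrained, so device placement and runtime creation are now load-time options. A sketch of the knobs introduced above (the path and device index are illustrative):

    model = RBLNModel.from_pretrained(
        "./compiled-model-dir",
        rbln_device=0,                   # one device index applied to every compiled model
        rbln_create_runtimes=True,       # defaults to rebel.npu_is_available()
        rbln_optimize_host_memory=True,  # default: drop compiled-model handles once runtimes exist
    )
    # Alternatively, rbln_device_map={"<compiled_model_name>": 0} pins devices per
    # compiled model; the keys are the entry names of the loaded rbln_config.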
@@ -338,7 +352,15 @@
     def pop_rbln_kwargs_from_kwargs(kwargs: dict):
         keys = list(kwargs.keys())
         rbln_constructor_kwargs = {
-            key: kwargs.pop(key) for key in keys if key in ["rbln_device", "rbln_device_map", "rbln_create_runtimes"]
+            key: kwargs.pop(key)
+            for key in keys
+            if key
+            in [
+                "rbln_device",
+                "rbln_device_map",
+                "rbln_create_runtimes",
+                "rbln_optimize_host_memory",
+            ]
         }
 
         keys = list(kwargs.keys())
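
Note: pop_rbln_kwargs_from_kwargs splits user kwargs in two passes: the whitelist above becomes the constructor (runtime) kwargs, and, judging by how _export consumes the result, the remaining rbln_-prefixed keys become compilation kwargs. An illustrative sketch of the split, assuming that second-pass behavior:

    kwargs = {"rbln_device": 0, "rbln_max_seq_len": 512, "trust_remote_code": True}
    rbln_config_kwargs, rbln_constructor_kwargs = RBLNModel.pop_rbln_kwargs_from_kwargs(kwargs)
    # rbln_constructor_kwargs -> {"rbln_device": 0}           (runtime whitelist)
    # rbln_config_kwargs      -> {"rbln_max_seq_len": 512}    (remaining rbln_* keys)
    # kwargs                  -> {"trust_remote_code": True}  (left for the HF loader)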
@@ -375,9 +397,12 @@
     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
         pass
 
+    @classmethod
     @abstractmethod
-    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
-        # self.compiled_models -> self.runtimes
+    def _create_runtimes(
+        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+    ) -> List[rebel.Runtime]:
+        # compiled_models -> runtimes
         pass
 
     @classmethod
@@ -417,14 +442,26 @@ class RBLNModel(RBLNBaseModel):
     ```
     """
 
-    model_type = "rbln_model"
-    auto_model_class = AutoModel  # feature extraction
+    @classmethod
+    def update_kwargs(cls, kwargs):
+        """
+        Update user-given kwargs to get proper pytorch model.
+
+        For example, `torchscript`=True should be set because torch.jit
+        does not support `transformers` output instances as module output;
+        """
+        kwargs.update(
+            {
+                "torchscript": True,
+                "return_dict": False,
+            }
+        )
+        return kwargs
 
     @classmethod
-    def _export(
+    def get_pytorch_model(
         cls,
-        model_id: Union[str, Path],
-        config: "PretrainedConfig",
+        model_id: str,
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
@@ -432,16 +469,62 @@
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
-        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
+        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
-    ) -> "RBLNModel":
-        """
-        Exports a vanilla Transformers model into a rbln-compiled Module.
-        """
+    ) -> "PreTrainedModel":
         task = kwargs.pop("task", None)
         if task is None:
             task = TasksManager.infer_task_from_model(cls.auto_model_class)
 
+        kwargs = cls.update_kwargs(kwargs)
+
+        model = TasksManager.get_model_from_task(
+            task=task,
+            model_name_or_path=model_id,
+            subfolder=subfolder,
+            revision=revision,
+            framework="pt",
+            cache_dir=cache_dir,
+            use_auth_token=use_auth_token,
+            local_files_only=local_files_only,
+            force_download=force_download,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )
+
+        return model
+
+    @classmethod
+    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
+        model = cls.wrap_model_if_needed(model)
+        rbln_runtime_configs = list(rbln_config.values())
+        if len(rbln_runtime_configs) != 1:
+            raise ValueError
+        rbln_runtime_config = rbln_runtime_configs[0]
+        if len(rbln_runtime_config) != 1:
+            raise ValueError
+        rbln_runtime_config = rbln_runtime_config[0]
+
+        compiled_model = cls.compile(model, rbln_runtime_config=rbln_runtime_config)
+        return compiled_model
+
+    @classmethod
+    @torch.no_grad()
+    def _export(
+        cls,
+        model_id: str,
+        config: "PretrainedConfig",
+        use_auth_token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+        trust_remote_code: bool = False,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        **kwargs,
+    ) -> "RBLNModel":
         if model_save_dir is None:
             save_dir = TemporaryDirectory()
             save_dir_path = Path(save_dir.name)
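
Note: _export is now decomposed into overridable stages — get_pytorch_model fetches the checkpoint with update_kwargs applied, get_rbln_config derives the static shapes, and get_compiled_model wraps and compiles. A subclass can customize one stage without re-implementing the rest; a minimal sketch with a hypothetical subclass and an illustrative extra loading kwarg:

    class MyRBLNModel(RBLNModel):
        @classmethod
        def update_kwargs(cls, kwargs):
            # Keep the torchscript/return_dict defaults, then add our own knob.
            kwargs = super().update_kwargs(kwargs)
            kwargs.update({"low_cpu_mem_usage": False})  # illustrative HF loading kwarg
            return kwargs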
@@ -453,35 +536,24 @@
             save_dir_path = Path(model_save_dir)
             save_dir_path.mkdir(exist_ok=True)
 
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-            }
-        )
-
         rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
 
-        model = TasksManager.get_model_from_task(
-            task=task,
-            model_name_or_path=model_id,
+        model: "PreTrainedModel" = cls.get_pytorch_model(
+            model_id=model_id,
             subfolder=subfolder,
             revision=revision,
-            framework="pt",
             cache_dir=cache_dir,
             use_auth_token=use_auth_token,
             local_files_only=local_files_only,
             force_download=force_download,
             trust_remote_code=trust_remote_code,
+            rbln_config_kwargs=rbln_config_kwargs,
+            rbln_constructor_kwargs=rbln_constructor_kwargs,
             **kwargs,
         )
 
-        # TODO : do we need this?
-        if isinstance(model, torch.nn.Module):
-            model.eval()
-
-        if config is None:
-            config = model.config
+        # FIXME :: optimum passes AutoConfig.
+        config = model.config
 
         if not isinstance(config, PretrainedConfig):  # diffusers config
             config = PretrainedConfig(**config)
@@ -492,20 +564,22 @@
         # Get compilation arguments
         if (rbln_config := rbln_config_kwargs.pop("rbln_config", None)) is None:
             rbln_config = cls.get_rbln_config(preprocessors=preprocessors, model_config=config, **rbln_config_kwargs)
+        compiled_model = cls.get_compiled_model(model, rbln_config=rbln_config)
 
-        rbln_runtime_configs = list(rbln_config.values())
-        if len(rbln_runtime_configs) != 1:
-            raise ValueError
-        rbln_runtime_config = rbln_runtime_configs[0]
-        if len(rbln_runtime_config) != 1:
-            raise ValueError
-        rbln_runtime_config = rbln_runtime_config[0]
+        # Save compiled models
+        (save_dir_path / subfolder).mkdir(exist_ok=True)
+        if isinstance(compiled_model, Iterable):
+            # compiled_model is an Iterable instance
+            for single_compiled_model, compiled_model_name in zip(compiled_model, rbln_config):
+                single_compiled_model.save(save_dir_path / subfolder / f"{compiled_model_name}.rbln")
+            compiled_models = compiled_model
 
-        model = cls.wrap_model_if_needed(model)
-        compiled_model = cls.compile(model, rbln_runtime_config=rbln_runtime_config)
-        compiled_model.save(save_dir_path / subfolder / f"{rbln_runtime_config.compiled_model_name}.rbln")
+        else:
+            compiled_model.save(save_dir_path / subfolder / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
+            compiled_models = [compiled_model]
         rbln_config.save(save_dir_path / subfolder)
 
+        # Instantiate
         return cls._from_pretrained(
             model_id=save_dir_path,
             config=config,
@@ -516,23 +590,23 @@
             cache_dir=cache_dir,
             subfolder=subfolder,
             local_files_only=local_files_only,
+            rbln_config=rbln_config,
+            rbln_compiled_models=compiled_models,
             **rbln_constructor_kwargs,
             **kwargs,
         )
 
-    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
+    @classmethod
+    def _create_runtimes(
+        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+    ) -> List[rebel.Runtime]:
         device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-        return [
-            compiled_model.create_runtime(tensor_type="pt", device=device) for compiled_model in self.compiled_models
-        ]
+        return [compiled_model.create_runtime(tensor_type="pt", device=device) for compiled_model in compiled_models]
 
     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-        output = self.runtimes[0](*args, **kwargs)
+        output = self.model[0](*args, **kwargs)
         return output
 
-    def __repr__(self):
-        return repr(self.runtimes[0])
-
 
 class RBLNModelForQuestionAnswering(RBLNModel):
     model_type = "rbln_model"
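
Note: with the runtimes now passed to __init__ as models, OptimizedModel stores the list on self.model, so both __repr__ and forward index self.model instead of the removed self.runtimes attribute.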
@@ -676,3 +750,111 @@ class RBLNModelForAudioClassification(RBLNModel):
         )
 
         return rbln_config
+
+
+class RBLNModelForSequenceClassification(RBLNModel):
+    """
+    This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence classification head) when created with the from_pretrained() class method
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based SequenceClassification models on RBLN devices.
+    It implements the methods to convert a pre-trained transformers SequenceClassification model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    Currently, this model class supports the 'XLMRoberta' and 'Roberta' model from the transformers library. Future updates may include support for additional model types.
+    """
+
+    model_type = "rbln_model"
+    auto_model_class = AutoModelForSequenceClassification
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_max_seq_len: Optional[int] = None,
+        rbln_model_input_names: Optional[List[str]] = None,
+        rbln_batch_size: Optional[int] = None,
+    ) -> RBLNConfig:
+
+        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+            model_config, "max_position_embeddings", None
+        )
+
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = max_position_embeddings
+            if rbln_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_max_seq_len is None:
+                    raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_model_input_names is None:
+            # These are BERT's inputs
+            rbln_model_input_names = ["input_ids", "attention_mask"]
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+        input_info = [
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
+        ]
+
+        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
+        rbln_runtime_config.batch_size = rbln_batch_size
+        meta = {"rbln_max_seq_len": rbln_max_seq_len}
+
+        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+
+
+class RBLNModelForMaskedLM(RBLNModel):
+    model_type = "rbln_model"
+    auto_model_class = AutoModelForMaskedLM
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_max_seq_len: Optional[int] = None,
+        rbln_model_input_names: Optional[List[str]] = None,
+        rbln_batch_size: Optional[int] = None,
+    ) -> RBLNConfig:
+        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+            model_config, "max_position_embeddings", None
+        )
+
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = max_position_embeddings
+            if rbln_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_max_seq_len is None:
+                    raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_model_input_names is None:
+            # These are BERT's inputs
+            rbln_model_input_names = ["input_ids", "attention_mask"]
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+        input_info = [
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
+        ]
+
+        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
+        rbln_runtime_config.batch_size = rbln_batch_size
+        meta = {"rbln_max_seq_len": rbln_max_seq_len}
+
+        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
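
A hedged end-to-end sketch of the new sequence-classification class, assuming optimum.rbln re-exports RBLNModelForSequenceClassification (the checkpoint id is illustrative; the docstring above names XLMRoberta and Roberta as supported):

    from transformers import AutoTokenizer
    from optimum.rbln import RBLNModelForSequenceClassification

    model_id = "FacebookAI/xlm-roberta-base"   # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = RBLNModelForSequenceClassification.from_pretrained(
        model_id,
        export=True,           # compile the PyTorch checkpoint, per the usual optimum convention
        rbln_max_seq_len=512,  # static shape baked into input_info above
        rbln_batch_size=1,
    )
    # Inputs must match the compiled static shape ([batch_size, max_seq_len], int64).
    inputs = tokenizer(
        "Static shapes on RBLN devices.",
        padding="max_length", max_length=512, return_tensors="pt",
    )
    logits = model(inputs.input_ids, inputs.attention_mask)
    model.save_pretrained("./xlm-roberta-rbln")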