optimum-rbln 0.1.7__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. optimum/rbln/__init__.py +17 -0
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/__init__.py +0 -1
  4. optimum/rbln/diffusers/models/autoencoder_kl.py +3 -3
  5. optimum/rbln/diffusers/models/controlnet.py +7 -3
  6. optimum/rbln/diffusers/models/unet_2d_condition.py +5 -5
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +23 -146
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +107 -59
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +106 -54
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +130 -71
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +131 -72
  12. optimum/rbln/modeling_alias.py +19 -1
  13. optimum/rbln/modeling_base.py +162 -18
  14. optimum/rbln/transformers/__init__.py +8 -0
  15. optimum/rbln/transformers/cache_utils.py +111 -0
  16. optimum/rbln/transformers/generation/utils.py +0 -2
  17. optimum/rbln/transformers/models/__init__.py +3 -0
  18. optimum/rbln/transformers/models/bart/bart_architecture.py +0 -5
  19. optimum/rbln/transformers/models/clip/modeling_clip.py +1 -1
  20. optimum/rbln/transformers/models/decoderonly/__init__.py +36 -0
  21. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +516 -0
  22. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +464 -0
  23. optimum/rbln/transformers/models/gemma/__init__.py +24 -0
  24. optimum/rbln/transformers/models/gemma/gemma_architecture.py +123 -0
  25. optimum/rbln/transformers/models/gemma/modeling_gemma.py +67 -0
  26. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +201 -166
  27. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +10 -257
  28. optimum/rbln/transformers/models/llama/llama_architecture.py +3 -610
  29. optimum/rbln/transformers/models/llama/modeling_llama.py +12 -440
  30. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +2 -1
  31. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -4
  32. optimum/rbln/transformers/models/midm/midm_architecture.py +160 -357
  33. optimum/rbln/transformers/models/midm/modeling_midm.py +10 -325
  34. optimum/rbln/transformers/models/mistral/__init__.py +24 -0
  35. optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
  36. optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
  37. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  38. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -6
  39. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  40. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +131 -0
  41. optimum/rbln/transformers/utils/__init__.py +0 -0
  42. optimum/rbln/transformers/utils/rbln_quantization.py +109 -0
  43. optimum/rbln/utils/import_utils.py +1 -4
  44. optimum/rbln/utils/runtime_utils.py +2 -1
  45. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/METADATA +11 -5
  46. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/RECORD +48 -35
  47. optimum/rbln/transformers/models/llama/llama_architecture_cb.py +0 -764
  48. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/WHEEL +0 -0
  49. {optimum_rbln-0.1.7.dist-info → optimum_rbln-0.1.9.dist-info}/licenses/LICENSE +0 -0
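Reading the hunks below together: 0.1.9 appears to factor the per-model Llama/GPT2/Midm graph code into a shared `decoderonly` base, add Gemma, Mistral, and XLM-RoBERTa model classes plus a `RebelDynamicCache`, introduce weight-quantization helpers (`rbln_quantization.py`), and replace the temporary-directory export flow of the ControlNet pipelines with an explicit `model_save_dir` argument. A minimal sketch of one new entry point (the checkpoint id and kwargs are illustrative, not taken from this diff):

from optimum.rbln import RBLNGemmaForCausalLM  # new in 0.1.9, alongside RBLNMistralForCausalLM

model = RBLNGemmaForCausalLM.from_pretrained(
    "google/gemma-2b",  # hypothetical checkpoint id
    export=True,        # compile the PyTorch checkpoint for RBLN NPUs
)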
optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

@@ -22,17 +22,17 @@
 # from Rebellions Inc.
 """RBLNStableDiffusionXLPipeline class for inference of diffusion models on rbln devices."""
 
-from pathlib import Path
-from tempfile import TemporaryDirectory
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
-from diffusers import StableDiffusionXLControlNetImg2ImgPipeline
+from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline
 from diffusers.image_processor import PipelineImageInput
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.utils import deprecate, logging
 from diffusers.utils.torch_utils import is_compiled_module
+from transformers import CLIPTextModel
 
 from ....modeling_base import RBLNBaseModel
 from ....transformers import RBLNCLIPTextModel, RBLNCLIPTextModelWithProjection
@@ -63,103 +63,152 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(StableDiffusionXLControlNet
                 - A path to a *directory* containing a model saved using [`~OptimizedModel.save_pretrained`],
         """
         export = kwargs.pop("export", None)
-        text_encoder = kwargs.pop("text_encoder", None)
-        controlnets = kwargs.pop("controlnet", None)
         vae = kwargs.pop("vae", None)
+        unet = kwargs.pop("unet", None)
+        text_encoder = kwargs.pop("text_encoder", None)
+        text_encoder_2 = kwargs.pop("text_encoder_2", None)
+        controlnet = kwargs.pop("controlnet", None)
+        model_save_dir = kwargs.pop("model_save_dir", None)
 
         rbln_config_kwargs, rbln_constructor_kwargs = RBLNBaseModel.pop_rbln_kwargs_from_kwargs(kwargs)
+
         kwargs_dict = {
             "pretrained_model_name_or_path": model_id,
-            "vae": vae,
-            "controlnet": controlnets,
-            "text_encoder": text_encoder,
             **kwargs,
         }
 
+        kwargs_dict.update(
+            {
+                **({"vae": vae} if vae is not None and isinstance(vae, AutoencoderKL) else {}),
+                **({"unet": unet} if unet is not None and isinstance(unet, UNet2DConditionModel) else {}),
+                **(
+                    {"text_encoder": text_encoder}
+                    if text_encoder is not None and isinstance(text_encoder, CLIPTextModel)
+                    else {}
+                ),
+                **(
+                    {"controlnet": controlnet}
+                    if controlnet is not None
+                    and (
+                        isinstance(controlnet, ControlNetModel)
+                        or all(isinstance(c, ControlNetModel) for c in controlnet)
+                    )
+                    else {}
+                ),
+            }
+        )
+
         model = super().from_pretrained(**{k: v for k, v in kwargs_dict.items() if v is not None})
 
         if export is None or export is False:
             return model
 
-        save_dir = TemporaryDirectory()
-        save_dir_path = Path(save_dir.name)
-
-        model.save_pretrained(save_directory=save_dir_path, **kwargs)
-
         do_classifier_free_guidance = (
             rbln_config_kwargs.pop("rbln_guidance_scale", 5.0) > 1.0 and model.unet.config.time_cond_proj_dim is None
         )
 
-        vae = RBLNAutoencoderKL.from_pretrained(
-            model_id=model_id,
-            subfolder="vae",
-            export=True,
-            rbln_unet_sample_size=model.unet.config.sample_size,
-            rbln_use_encode=True,
-            rbln_vae_scale_factor=model.vae_scale_factor,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-        text_encoder = RBLNCLIPTextModel.from_pretrained(
-            model_id=model_id,
-            subfolder="text_encoder",
-            export=True,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-        text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
-            model_id=model_id,
-            subfolder="text_encoder_2",
-            export=True,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
-
-        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
-        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+        if not isinstance(vae, RBLNAutoencoderKL):
+            vae = RBLNAutoencoderKL.from_pretrained(
+                model_id=model_id,
+                subfolder="vae",
+                export=True,
+                model_save_dir=model_save_dir,
+                rbln_unet_sample_size=model.unet.config.sample_size,
+                rbln_use_encode=True,
+                rbln_vae_scale_factor=model.vae_scale_factor,
+                **rbln_config_kwargs,
+                **rbln_constructor_kwargs,
+            )
 
-        unet = RBLNUNet2DConditionModel.from_pretrained(
-            model_id=model_id,
-            subfolder="unet",
-            export=True,
-            rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
-            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
-            rbln_batch_size=unet_batch_size,
-            rbln_use_encode=True,
-            rbln_vae_scale_factor=model.vae_scale_factor,
-            rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
-            **rbln_config_kwargs,
-            **rbln_constructor_kwargs,
-        )
+        if not isinstance(text_encoder, RBLNCLIPTextModel):
+            text_encoder = RBLNCLIPTextModel.from_pretrained(
+                model_id=model_id,
+                subfolder="text_encoder",
+                export=True,
+                model_save_dir=model_save_dir,
+                **rbln_config_kwargs,
+                **rbln_constructor_kwargs,
+            )
 
-        if isinstance(controlnets, (list, tuple)):
-            controlnet = RBLNMultiControlNetModel.from_pretrained(
-                model_id=str(save_dir_path / "controlnet"),
+        if not isinstance(text_encoder_2, RBLNCLIPTextModel):
+            text_encoder_2 = RBLNCLIPTextModelWithProjection.from_pretrained(
+                model_id=model_id,
+                subfolder="text_encoder_2",
                 export=True,
-                rbln_batch_size=unet_batch_size,
-                rbln_vae_scale_factor=model.vae_scale_factor,
+                model_save_dir=model_save_dir,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
            )
-            controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
-        else:
-            controlnet = RBLNControlNetModel.from_pretrained(
-                model_id=save_dir_path / "controlnet",
+
+        batch_size = rbln_config_kwargs.pop("rbln_batch_size", 1)
+        unet_batch_size = batch_size * 2 if do_classifier_free_guidance else batch_size
+
+        if not isinstance(unet, RBLNUNet2DConditionModel):
+            unet = RBLNUNet2DConditionModel.from_pretrained(
+                model_id=model_id,
+                subfolder="unet",
                 export=True,
-                rbln_batch_size=unet_batch_size,
+                model_save_dir=model_save_dir,
+                rbln_max_seq_len=model.text_encoder.config.max_position_embeddings,
                 rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                rbln_batch_size=unet_batch_size,
+                rbln_use_encode=True,
                 rbln_vae_scale_factor=model.vae_scale_factor,
+                rbln_is_controlnet=True if "controlnet" in model.config.keys() else False,
                 **rbln_config_kwargs,
                 **rbln_constructor_kwargs,
             )
-            controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
 
+        if not isinstance(controlnet, (RBLNControlNetModel, RBLNMultiControlNetModel)):
+            if isinstance(controlnet, (list, tuple)):
+                multicontrolnet = []
+                for i, cid in enumerate(controlnet):
+                    subfolder_name = "controlnet" if i == 0 else f"controlnet_{i}"
+                    multicontrolnet.append(
+                        RBLNControlNetModel.from_pretrained(
+                            model_id=cid.config._name_or_path,
+                            subfolder=subfolder_name,
+                            export=True,
+                            model_save_dir=model_save_dir,
+                            rbln_batch_size=unet_batch_size,
+                            rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                            rbln_vae_scale_factor=model.vae_scale_factor,
+                            **rbln_config_kwargs,
+                            **rbln_constructor_kwargs,
+                        )
+                    )
+                controlnet = RBLNMultiControlNetModel(multicontrolnet, config=controlnet[0].config)
+                controlnet_dict = ("optimum.rbln", "RBLNMultiControlNetModel")
+            else:
+                controlnet = RBLNControlNetModel.from_pretrained(
+                    model_id=controlnet.config._name_or_path,
+                    subfolder="controlnet",
+                    export=True,
+                    model_save_dir=model_save_dir,
+                    rbln_batch_size=unet_batch_size,
+                    rbln_text_model_hidden_size=model.text_encoder_2.config.hidden_size,
+                    rbln_vae_scale_factor=model.vae_scale_factor,
+                    **rbln_config_kwargs,
+                    **rbln_constructor_kwargs,
+                )
+                controlnet_dict = ("optimum.rbln", "RBLNControlNetModel")
+
+        if model_save_dir is not None:
+            # To skip saving original pytorch modules
+            del (model.vae, model.text_encoder, model.unet, model.controlnet)
+
+            # Direct calling of `save_pretrained` causes config.unet = (None, None).
+            # So config must be saved again, later.
+            model.save_pretrained(model_save_dir)
+
+        # replace modules
         model.vae = vae
         model.text_encoder = text_encoder
         model.unet = unet
         model.text_encoder_2 = text_encoder_2
         model.controlnet = controlnet
 
+        # update config to be able to load from file
         update_dict = {
             "vae": ("optimum.rbln", "RBLNAutoencoderKL"),
             "text_encoder": ("optimum.rbln", "RBLNCLIPTextModel"),
@@ -169,14 +218,24 @@ class RBLNStableDiffusionXLControlNetImg2ImgPipeline(StableDiffusionXLControlNet
         }
         model.register_to_config(**update_dict)
 
-        model.models = [
-            vae.model[0],
-            vae.model[1],
-            unet.model[0],
-            text_encoder.model[0],
-            text_encoder_2.model[0],
-            controlnet.model[0],
-        ]
+        if model_save_dir is not None:
+            # overwrite to replace incorrect config
+            model.save_config(model_save_dir)
+
+        # use for CI to access each compiled model
+        if rbln_constructor_kwargs.pop("rbln_optimize_host_memory", None) is False:
+            model.compiled_models = [
+                vae.compiled_models[0],
+                vae.compiled_models[1],
+                text_encoder.compiled_models[0],
+                text_encoder_2.compiled_models[0],
+                unet.compiled_models[0],
+            ]
+            if isinstance(controlnet, RBLNMultiControlNetModel):
+                for c_model in controlnet.nets:
+                    model.compiled_models.append(c_model.compiled_models[0])
+            else:
+                model.compiled_models.append(controlnet.compiled_models[0])
 
         return model
 
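The rewritten `from_pretrained` above accepts already-converted RBLN submodules (they skip the `isinstance` checks), compiles the rest directly from `model_id`, and persists everything under `model_save_dir` rather than the removed `TemporaryDirectory` round-trip. A hedged usage sketch; the checkpoint ids and the top-level import path are assumptions:

from diffusers import ControlNetModel
from optimum.rbln import RBLNStableDiffusionXLControlNetImg2ImgPipeline

controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")  # hypothetical id
pipe = RBLNStableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # hypothetical id
    controlnet=controlnet,
    export=True,
    model_save_dir="sdxl_controlnet_rbln",  # compiled submodules and config land here
)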
optimum/rbln/modeling_alias.py

@@ -24,7 +24,9 @@
 from .modeling_base import (
     RBLNModelForAudioClassification,
     RBLNModelForImageClassification,
+    RBLNModelForMaskedLM,
     RBLNModelForQuestionAnswering,
+    RBLNModelForSequenceClassification,
 )
 from .modeling_seq2seq import RBLNModelForSeq2SeqLM
 
@@ -34,7 +36,11 @@ class RBLNASTForAudioClassification(RBLNModelForAudioClassification):
 
 
 class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
-    pass
+    rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+
+
+class RBLNDistilBertForQuestionAnswering(RBLNModelForQuestionAnswering):
+    rbln_model_input_names = ["input_ids", "attention_mask"]
 
 
 class RBLNResNetForImageClassification(RBLNModelForImageClassification):
@@ -47,3 +53,15 @@ class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
 
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
     pass
+
+
+class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification):
+    pass
+
+
+class RBLNRobertaForSequenceClassification(RBLNModelForSequenceClassification):
+    pass
+
+
+class RBLNRobertaForMaskedLM(RBLNModelForMaskedLM):
+    pass
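With the input names now class attributes, the QA config builder later in this diff can read `cls.rbln_model_input_names` instead of hard-coding BERT's three inputs. Assuming the new aliases are re-exported at the package root (the `optimum/rbln/__init__.py` hunk is +17 lines but not shown here), a minimal check of what each alias traces:

from optimum.rbln import RBLNBertForQuestionAnswering, RBLNDistilBertForQuestionAnswering

# BERT-style checkpoints take token_type_ids; DistilBERT does not.
assert RBLNBertForQuestionAnswering.rbln_model_input_names == [
    "input_ids", "attention_mask", "token_type_ids",
]
assert RBLNDistilBertForQuestionAnswering.rbln_model_input_names == [
    "input_ids", "attention_mask",
]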
optimum/rbln/modeling_base.py

@@ -39,7 +39,9 @@ from transformers import (
     AutoModel,
     AutoModelForAudioClassification,
     AutoModelForImageClassification,
+    AutoModelForMaskedLM,
     AutoModelForQuestionAnswering,
+    AutoModelForSequenceClassification,
     GenerationConfig,
     PretrainedConfig,
 )
@@ -49,10 +51,15 @@ from .utils.runtime_utils import UnavailableRuntime
 from .utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
 
 
-logger = logging.getLogger(__name__)
-
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        PreTrainedModel,
+    )
+
+logger = logging.getLogger(__name__)
 
 
 class RBLNBaseModel(OptimizedModel, ABC):
@@ -154,13 +161,23 @@ class RBLNBaseModel(OptimizedModel, ABC):
                 Directory where to save the model file.
         """
         real_save_dir = self.model_save_dir / self.subfolder
+        save_directory_path = Path(save_directory)
         if os.path.exists(real_save_dir) and os.path.isdir(real_save_dir):
+            if save_directory_path.absolute() == real_save_dir.absolute():
+                raise FileExistsError(
+                    f"Cannot save model to '{save_directory}'. "
+                    f"This directory already exists and contains the model files."
+                )
             shutil.copytree(real_save_dir, save_directory, dirs_exist_ok=True)
             self.config.save_pretrained(save_directory)
             if self.generation_config is not None:
                 self.generation_config.save_pretrained(save_directory)
         else:
-            raise FileNotFoundError(f"Saving compiled model failed.({real_save_dir}).")
+            raise FileNotFoundError(
+                f"Unable to save the model. The model directory '{real_save_dir}' does not exist or is not accessible. "
+                f"Cannot save to the specified destination '{save_directory}'. "
+                f"Please ensure the model directory exists and you have the necessary permissions to access it."
+            )
 
     @classmethod
     def _from_pretrained(
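Two failure modes are now distinguished: saving onto the directory that already holds the compiled artifacts raises `FileExistsError`, and a missing source directory raises a descriptive `FileNotFoundError`. Illustrative only, with a hypothetical compiled model loaded from a local `compiled_dir`:

model = RBLNResNetForImageClassification.from_pretrained("compiled_dir")  # hypothetical local dir
model.save_pretrained("exported_copy")  # ok: copies the *.rbln artifacts plus config
model.save_pretrained("compiled_dir")   # raises FileExistsError under the new guard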
@@ -194,7 +211,12 @@ class RBLNBaseModel(OptimizedModel, ABC):
                 token = HfFolder().get_token()
             else:
                 token = use_auth_token
-            repo_files = list(map(Path, HfApi().list_repo_files(model_id, revision=revision, token=token)))
+            repo_files = list(
+                map(
+                    Path,
+                    HfApi().list_repo_files(model_id, revision=revision, token=token),
+                )
+            )
 
             pattern = "*.rbln" if subfolder == "" else f"{subfolder}/*.rbln"
             rbln_files = [p for p in repo_files if p.match(pattern)]
@@ -285,7 +307,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
             preprocessors,
             model_save_dir=model_save_dir,
             subfolder=subfolder,
-            rbln_compiled_models=None if rbln_optimize_host_memory else rbln_compiled_models,
+            rbln_compiled_models=(None if rbln_optimize_host_memory else rbln_compiled_models),
             **kwargs,
         )
 
@@ -375,7 +397,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
         return self.forward(*args, **kwargs)
 
     @classmethod
-    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
+    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
         # Wrap the model if needed.
         return model
 
@@ -398,7 +420,9 @@
     @classmethod
     @abstractmethod
     def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
     ) -> List[rebel.Runtime]:
         # compiled_models -> runtimes
         pass
@@ -495,7 +519,7 @@ class RBLNModel(RBLNBaseModel):
 
     @classmethod
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
-        model = cls.wrap_model_if_needed(model)
+        model = cls.wrap_model_if_needed(model, rbln_config)
         rbln_runtime_configs = list(rbln_config.values())
         if len(rbln_runtime_configs) != 1:
             raise ValueError
@@ -596,7 +620,9 @@ class RBLNModel(RBLNBaseModel):
 
     @classmethod
     def _create_runtimes(
-        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
     ) -> List[rebel.Runtime]:
         device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
         return [compiled_model.create_runtime(tensor_type="pt", device=device) for compiled_model in compiled_models]
@@ -616,8 +642,8 @@ class RBLNModelForQuestionAnswering(RBLNModel):
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
         model_config: Optional["PretrainedConfig"] = None,
         rbln_max_seq_len: Optional[int] = None,
-        rbln_model_input_names: Optional[List[str]] = None,
         rbln_batch_size: Optional[int] = None,
+        rbln_model_input_names: Optional[List[str]] = None,
     ) -> RBLNConfig:
         if rbln_max_seq_len is None:
             for tokenizer in preprocessors:
@@ -627,15 +653,15 @@ class RBLNModelForQuestionAnswering(RBLNModel):
         if rbln_max_seq_len is None:
             raise ValueError("`rbln_max_seq_len` should be specified!")
 
-        if rbln_model_input_names is None:
-            # These are BERT's inputs
-            rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
-
         if rbln_batch_size is None:
             rbln_batch_size = 1
+
+        if rbln_model_input_names is not None:
+            cls.rbln_model_input_names = rbln_model_input_names
+
         input_info = [
             (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
-            for model_input_name in rbln_model_input_names
+            for model_input_name in cls.rbln_model_input_names
         ]
 
         rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
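The default input set therefore moves from this method into the subclass attribute (see `modeling_alias.py` above), and an explicit `rbln_model_input_names=...` keyword now overrides it via the `cls` assignment. For a batch-1, length-384 BERT-style QA model, the resulting `input_info` is:

rbln_batch_size, rbln_max_seq_len = 1, 384
model_input_names = ["input_ids", "attention_mask", "token_type_ids"]  # class default for BERT

input_info = [
    (name, [rbln_batch_size, rbln_max_seq_len], "int64") for name in model_input_names
]
assert input_info[0] == ("input_ids", [1, 384], "int64")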
@@ -672,7 +698,13 @@ class RBLNModelForImageClassification(RBLNModel):
         if rbln_batch_size is None:
             rbln_batch_size = 1
 
-        input_info = [("pixel_values", [rbln_batch_size, 3, rbln_image_size, rbln_image_size], "float32")]
+        input_info = [
+            (
+                "pixel_values",
+                [rbln_batch_size, 3, rbln_image_size, rbln_image_size],
+                "float32",
+            )
+        ]
 
         rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
         rbln_runtime_config.batch_size = rbln_batch_size
@@ -737,7 +769,11 @@ class RBLNModelForAudioClassification(RBLNModel):
         meta["rbln_num_mel_bins"] = rbln_num_mel_bins
 
         model_input_info = [
-            ("input_values", [rbln_batch_size, rbln_max_length, rbln_num_mel_bins], "float32"),
+            (
+                "input_values",
+                [rbln_batch_size, rbln_max_length, rbln_num_mel_bins],
+                "float32",
+            ),
         ]
 
         rbln_runtime_config = RBLNRuntimeConfig(input_info=model_input_info, batch_size=rbln_batch_size)
@@ -748,3 +784,111 @@ class RBLNModelForAudioClassification(RBLNModel):
         )
 
         return rbln_config
+
+
+class RBLNModelForSequenceClassification(RBLNModel):
+    """
+    This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence classification head) when created with the from_pretrained() class method
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based SequenceClassification models on RBLN devices.
+    It implements the methods to convert a pre-trained transformers SequenceClassification model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    Currently, this model class supports the 'XLMRoberta' and 'Roberta' model from the transformers library. Future updates may include support for additional model types.
+    """
+
+    model_type = "rbln_model"
+    auto_model_class = AutoModelForSequenceClassification
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_max_seq_len: Optional[int] = None,
+        rbln_model_input_names: Optional[List[str]] = None,
+        rbln_batch_size: Optional[int] = None,
+    ) -> RBLNConfig:
+        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+            model_config, "max_position_embeddings", None
+        )
+
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = max_position_embeddings
+            if rbln_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_max_seq_len is None:
+                    raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_model_input_names is None:
+            # These are BERT's inputs
+            rbln_model_input_names = ["input_ids", "attention_mask"]
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+        input_info = [
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
+        ]
+
+        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
+        rbln_runtime_config.batch_size = rbln_batch_size
+        meta = {"rbln_max_seq_len": rbln_max_seq_len}
+
+        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
+
+
+class RBLNModelForMaskedLM(RBLNModel):
+    model_type = "rbln_model"
+    auto_model_class = AutoModelForMaskedLM
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_max_seq_len: Optional[int] = None,
+        rbln_model_input_names: Optional[List[str]] = None,
+        rbln_batch_size: Optional[int] = None,
+    ) -> RBLNConfig:
+        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+            model_config, "max_position_embeddings", None
+        )
+
+        if rbln_max_seq_len is None:
+            rbln_max_seq_len = max_position_embeddings
+            if rbln_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_max_seq_len is None:
+                    raise ValueError("`rbln_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_model_input_names is None:
+            # These are BERT's inputs
+            rbln_model_input_names = ["input_ids", "attention_mask"]
+
+        if rbln_batch_size is None:
+            rbln_batch_size = 1
+        input_info = [
+            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
+            for model_input_name in rbln_model_input_names
+        ]
+
+        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
+        rbln_runtime_config.batch_size = rbln_batch_size
+        meta = {"rbln_max_seq_len": rbln_max_seq_len}
+
+        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
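Both new heads share the same sequence-length resolution: an explicit `rbln_max_seq_len` wins, then the config's `n_positions` or `max_position_embeddings`, then the tokenizer's `model_max_length`, and anything above the positional limit is rejected. A worked sketch of that chain with illustrative values:

class FakeConfig:  # stand-in for a PretrainedConfig lacking n_positions
    max_position_embeddings = 512

cfg = FakeConfig()
max_position_embeddings = getattr(cfg, "n_positions", None) or getattr(cfg, "max_position_embeddings", None)

rbln_max_seq_len = None  # nothing passed by the user
if rbln_max_seq_len is None:
    rbln_max_seq_len = max_position_embeddings  # -> 512; tokenizer fallback not needed
assert not (max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings)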
optimum/rbln/transformers/__init__.py

@@ -27,30 +27,38 @@ from transformers.utils import _LazyModule
 
 
 _import_structure = {
+    "cache_utils": ["RebelDynamicCache"],
     "generation": ["BatchTextIteratorStreamer"],
     "models": [
         "RBLNCLIPTextModel",
         "RBLNCLIPTextModelWithProjection",
         "RBLNDPTForDepthEstimation",
+        "RBLNGemmaForCausalLM",
         "RBLNGPT2LMHeadModel",
         "RBLNWav2Vec2ForCTC",
         "RBLNWhisperForConditionalGeneration",
         "RBLNLlamaForCausalLM",
         "RBLNMidmLMHeadModel",
+        "RBLNMistralForCausalLM",
+        "RBLNXLMRobertaModel",
     ],
 }
 
 if TYPE_CHECKING:
+    from .cache_utils import RebelDynamicCache
     from .generation import BatchTextIteratorStreamer
     from .models import (
         RBLNCLIPTextModel,
         RBLNCLIPTextModelWithProjection,
         RBLNDPTForDepthEstimation,
+        RBLNGemmaForCausalLM,
         RBLNGPT2LMHeadModel,
         RBLNLlamaForCausalLM,
         RBLNMidmLMHeadModel,
+        RBLNMistralForCausalLM,
         RBLNWav2Vec2ForCTC,
         RBLNWhisperForConditionalGeneration,
+        RBLNXLMRobertaModel,
     )
 else:
     import sys
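Because the module is lazy, the `_import_structure` table and the `TYPE_CHECKING` block must list identical names; a small runtime smoke test (illustrative) that the additions resolve through `_LazyModule`:

import optimum.rbln.transformers as rbln_transformers

for name in ("RebelDynamicCache", "RBLNGemmaForCausalLM", "RBLNMistralForCausalLM", "RBLNXLMRobertaModel"):
    assert hasattr(rbln_transformers, name)  # each resolves on first attribute access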