optimum-rbln 0.1.12__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. optimum/rbln/__init__.py +27 -13
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +66 -111
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +85 -65
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +129 -163
  13. optimum/rbln/diffusers/pipelines/__init__.py +60 -12
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -25
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -185
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -190
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -191
  18. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -192
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -110
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -118
  22. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +18 -128
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -131
  30. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  31. optimum/rbln/modeling.py +572 -0
  32. optimum/rbln/modeling_alias.py +1 -1
  33. optimum/rbln/modeling_base.py +176 -763
  34. optimum/rbln/modeling_diffusers.py +329 -0
  35. optimum/rbln/transformers/__init__.py +2 -2
  36. optimum/rbln/transformers/cache_utils.py +5 -9
  37. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  38. optimum/rbln/transformers/models/__init__.py +80 -31
  39. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  40. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  41. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  42. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -34
  44. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -5
  45. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +779 -361
  46. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +83 -142
  47. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  48. optimum/rbln/transformers/models/exaone/exaone_architecture.py +64 -39
  49. optimum/rbln/transformers/models/exaone/modeling_exaone.py +6 -29
  50. optimum/rbln/transformers/models/gemma/gemma_architecture.py +31 -92
  51. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  52. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  53. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -31
  54. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  55. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +29 -83
  56. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -253
  57. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -33
  58. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  59. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  60. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -345
  61. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +5 -29
  62. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  63. optimum/rbln/transformers/models/t5/__init__.py +1 -1
  64. optimum/rbln/transformers/models/t5/modeling_t5.py +157 -6
  65. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  66. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  67. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  68. optimum/rbln/transformers/utils/rbln_quantization.py +128 -5
  69. optimum/rbln/utils/decorator_utils.py +59 -0
  70. optimum/rbln/utils/hub.py +131 -0
  71. optimum/rbln/utils/import_utils.py +21 -0
  72. optimum/rbln/utils/model_utils.py +53 -0
  73. optimum/rbln/utils/runtime_utils.py +5 -5
  74. optimum/rbln/utils/submodule.py +114 -0
  75. optimum/rbln/utils/timer_utils.py +2 -2
  76. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  77. optimum_rbln-0.1.15.dist-info/RECORD +110 -0
  78. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  79. optimum/rbln/transformers/generation/streamers.py +0 -139
  80. optimum/rbln/transformers/generation/utils.py +0 -397
  81. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  82. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  83. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  84. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  85. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  86. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  87. optimum_rbln-0.1.12.dist-info/METADATA +0 -119
  88. optimum_rbln-0.1.12.dist-info/RECORD +0 -103
  89. optimum_rbln-0.1.12.dist-info/entry_points.txt +0 -4
  90. {optimum_rbln-0.1.12.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling.py CHANGED
@@ -0,0 +1,572 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+ import inspect
24
+ import logging
25
+ from pathlib import Path
26
+ from tempfile import TemporaryDirectory
27
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
28
+
29
+ import rebel
30
+ import torch
31
+ import transformers
32
+ from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
33
+ from transformers import (
34
+ AutoConfig,
35
+ AutoModelForAudioClassification,
36
+ AutoModelForImageClassification,
37
+ AutoModelForMaskedLM,
38
+ AutoModelForQuestionAnswering,
39
+ AutoModelForSequenceClassification,
40
+ PretrainedConfig,
41
+ )
42
+
43
+ from .modeling_base import RBLNBaseModel
44
+ from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig, use_rbln_config
45
+
46
+
47
+ if TYPE_CHECKING:
48
+ from transformers import (
49
+ AutoFeatureExtractor,
50
+ AutoProcessor,
51
+ AutoTokenizer,
52
+ PreTrainedModel,
53
+ )
54
+
55
+ logger = logging.getLogger(__name__)
56
+
57
+
58
class RBLNModel(RBLNBaseModel):
    """
    A class that inherits from RBLNBaseModel for models consisting of a single `torch.nn.Module`.

    This class supports all the functionality of RBLNBaseModel, including loading and saving models using
    the `from_pretrained` and `save_pretrained` methods, and compiling PyTorch models for execution on RBLN NPU
    devices.

    Example:
        ```python
        model = RBLNModel.from_pretrained("model_id", export=True, rbln_npu="npu_name")
        outputs = model(**inputs)
        ```
    """

    @classmethod
    def update_kwargs(cls, kwargs):
        """
        Update user-given kwargs so that the loaded PyTorch model is traceable.

        `torchscript=True` and `return_dict=False` are forced because torch.jit does not
        support `transformers` output instances as module outputs.

        Args:
            kwargs: Keyword arguments destined for the Hugging Face `from_pretrained`.
                The dict is updated in place.

        Returns:
            The same `kwargs` dict.
        """
        kwargs.update(
            {
                "torchscript": True,
                "return_dict": False,
            }
        )
        return kwargs

    @classmethod
    def save_torch_artifacts(
        cls,
        model: "PreTrainedModel",
        save_dir_path: Path,
        subfolder: str,
        rbln_config: RBLNConfig,
    ):
        """
        Hook for subclasses that unavoidably run part of the model on the CPU rather than an RBLN device.

        Override it to store the torch tensors, weights, etc. that such CPU-side code needs.
        The default implementation does nothing.
        """

    @classmethod
    def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
        """Return the module to compile. Subclasses may wrap `model`; the default returns it unchanged."""
        return model

    @classmethod
    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
        """Wrap `model` if needed and compile it with the first compile config of `rbln_config`."""
        model = cls.wrap_model_if_needed(model, rbln_config)
        rbln_compile_config = rbln_config.compile_cfgs[0]
        compiled_model = cls.compile(model, rbln_compile_config=rbln_compile_config)
        return compiled_model

    @classmethod
    @use_rbln_config
    def from_model(
        cls,
        model: "PreTrainedModel",
        config: Optional[PretrainedConfig] = None,
        rbln_config: Optional[Dict[str, Any]] = None,
        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
        subfolder: str = "",
        **kwargs,
    ):
        """
        Compile an already-loaded PyTorch model, save all artifacts, and return an RBLN model instance.

        Args:
            model: The PyTorch model to compile.
            config: Config to save next to the compiled model. Defaults to `model.config`.
                A non-`PretrainedConfig` mapping (e.g. a diffusers config) is converted.
            rbln_config: User-given RBLN keyword arguments. `None` is treated as `{}`.
            model_save_dir: Directory for configs and compiled `.rbln` files. A
                `TemporaryDirectory` is created when omitted.
            subfolder: Sub-directory of `model_save_dir` that receives the artifacts.
            **kwargs: `preprocessors` (a list of tokenizers/processors) is popped and saved;
                the remaining kwargs are forwarded to submodule loading and `_from_pretrained`.

        Returns:
            The instance built by `cls._from_pretrained`.
        """
        # Accept an explicit `preprocessors=None`; downstream code iterates over it.
        preprocessors = kwargs.pop("preprocessors", None) or []
        # Keep the raw user kwargs: the name `rbln_config` is rebound to an RBLNConfig below.
        # A fresh dict avoids sharing one default object across calls.
        rbln_kwargs = rbln_config if rbln_config is not None else {}

        # Directory to save compile artifacts (.rbln) and original configs
        if model_save_dir is None:
            save_dir = TemporaryDirectory()
            save_dir_path = Path(save_dir.name)
        else:
            save_dir = model_save_dir
            if isinstance(save_dir, TemporaryDirectory):
                save_dir_path = Path(model_save_dir.name)
            else:
                save_dir_path = Path(model_save_dir)
                # `parents=True` so a nested, not-yet-existing output path also works.
                save_dir_path.mkdir(parents=True, exist_ok=True)

        # Save configs
        if config is None:
            config = model.config
            # A config whose `auto_map` names an AutoConfig is re-loaded through AutoConfig.
            if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
                config = AutoConfig.from_pretrained(config._name_or_path, **kwargs)

        if hasattr(model, "can_generate") and model.can_generate():
            generation_config = model.generation_config
            generation_config.save_pretrained(save_dir_path / subfolder)

        if not isinstance(config, PretrainedConfig):  # diffusers config
            config = PretrainedConfig(**config)
        config.save_pretrained(save_dir_path / subfolder)

        # Save preprocessors
        for preprocessor in preprocessors:
            preprocessor.save_pretrained(save_dir_path / subfolder)

        # Get compilation arguments (e.g. input_info). Runtime kwargs are applied inside get_rbln_config.
        rbln_config = cls.get_rbln_config(preprocessors=preprocessors, model_config=config, rbln_kwargs=rbln_kwargs)

        # Either a single compiled model or a dict of name -> compiled model.
        compiled_model = cls.get_compiled_model(model, rbln_config=rbln_config)

        # Save compiled models (.rbln)
        (save_dir_path / subfolder).mkdir(parents=True, exist_ok=True)
        if isinstance(compiled_model, dict):
            compiled_models = compiled_model
        else:
            compiled_models = {DEFAULT_COMPILED_MODEL_NAME: compiled_model}
        for compiled_model_name, cm in compiled_models.items():
            cm.save(save_dir_path / subfolder / f"{compiled_model_name}.rbln")
        rbln_config.save(save_dir_path / subfolder)

        # Save torch artifacts (e.g. embedding matrix if needed.)
        cls.save_torch_artifacts(model, save_dir_path=save_dir_path, subfolder=subfolder, rbln_config=rbln_config)

        # Load submodules
        if len(cls._rbln_submodules) > 0:
            rbln_submodules = cls._load_submodules(
                model=model,
                model_save_dir=save_dir,
                rbln_kwargs=rbln_kwargs,
                **kwargs,
            )
        else:
            rbln_submodules = []

        # Instantiate
        return cls._from_pretrained(
            model_id=save_dir_path,
            config=config,
            model_save_dir=save_dir,
            subfolder=subfolder,
            rbln_config=rbln_config,
            rbln_compiled_models=compiled_models,
            rbln_submodules=rbln_submodules,
            **kwargs,
        )

    @classmethod
    def get_pytorch_model(
        cls,
        model_id: str,
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = HUGGINGFACE_HUB_CACHE,
        subfolder: str = "",
        local_files_only: bool = False,
        trust_remote_code: bool = False,
        # Some rbln-kwargs should be applied before loading torch module (i.e. quantized llm)
        rbln_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> "PreTrainedModel":
        """
        Load the original PyTorch model through `cls.hf_class.from_pretrained`.

        The kwargs are first passed through `update_kwargs`, which forces a traceable configuration.
        `rbln_kwargs` is not used here; it exists so subclasses can apply settings before loading.
        """
        kwargs = cls.update_kwargs(kwargs)
        return cls.hf_class.from_pretrained(
            model_id,
            subfolder=subfolder,
            revision=revision,
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
            force_download=force_download,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

    @classmethod
    def _create_runtimes(
        cls,
        compiled_models: List[rebel.RBLNCompiledModel],
        rbln_device_map: Dict[str, int],
    ) -> List[rebel.Runtime]:
        """Create one PyTorch-tensor runtime per compiled model on the device mapped to the default model name."""
        device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
        return [compiled_model.create_runtime(tensor_type="pt", device=device) for compiled_model in compiled_models]

    def forward(self, *args: torch.Tensor, **kwargs: torch.Tensor):
        """Run the first runtime in `self.model` on the given inputs and return its output unchanged."""
        output = self.model[0](*args, **kwargs)
        return output
245
+
246
+
247
class RBLNModelForQuestionAnswering(RBLNModel):
    """RBLN model class for question-answering heads, backed by `AutoModelForQuestionAnswering`."""

    auto_model_class = AutoModelForQuestionAnswering
    rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[List[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_kwargs: Optional[Dict[str, Any]] = None,
    ) -> RBLNConfig:
        """
        Build the compile configuration for a question-answering model.

        Resolution order:
        - `max_seq_len`: `rbln_kwargs`, then the first preprocessor with `model_max_length`.
        - `batch_size`: `rbln_kwargs`, else 1.
        - `model_input_names`: `rbln_kwargs`, then the first preprocessor with
          `model_input_names`, then the class attribute `rbln_model_input_names`.

        Returns:
            An RBLNConfig with one compile config. Every input is an int64 tensor of shape
            `[batch_size, max_seq_len]`. `max_seq_len` is also recorded in `model_cfg`.

        Raises:
            ValueError: if `max_seq_len` or the input names cannot be resolved.
        """
        # Avoid a shared mutable default; tolerate a missing preprocessor list.
        rbln_kwargs = {} if rbln_kwargs is None else rbln_kwargs
        preprocessors = preprocessors or []

        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
        rbln_batch_size = rbln_kwargs.get("batch_size", None)
        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)

        if rbln_max_seq_len is None:
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_max_length"):
                    rbln_max_seq_len = tokenizer.model_max_length
                    break
            if rbln_max_seq_len is None:
                raise ValueError("`rbln_max_seq_len` should be specified!")

        if rbln_batch_size is None:
            rbln_batch_size = 1

        if rbln_model_input_names is None:
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_input_names"):
                    rbln_model_input_names = tokenizer.model_input_names
                    break
        if rbln_model_input_names is None:
            rbln_model_input_names = getattr(cls, "rbln_model_input_names", None)
        if rbln_model_input_names is None:
            input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
            raise ValueError(
                "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                f"and be sure to make the order of the inputs same as QuestionAnswering forward() arguments like ({list(input_names_order)})"
            )

        input_info = [
            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
            for model_input_name in rbln_model_input_names
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config = RBLNConfig(
            rbln_cls=cls.__name__,
            compile_cfgs=[rbln_compile_config],
            rbln_kwargs=rbln_kwargs,
        )
        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
        return rbln_config
300
+
301
+
302
class RBLNModelForImageClassification(RBLNModel):
    """
    This is a generic model class that will be instantiated as one of the model classes of the library (with an image classification head) when created with the from_pretrained() class method
    """

    auto_model_class = AutoModelForImageClassification

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[List[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_kwargs: Optional[Dict[str, Any]] = None,
    ) -> RBLNConfig:
        """
        Build the compile configuration for an image-classification model.

        The image size is taken from `rbln_kwargs["image_size"]`. If that is absent, it is read from
        the `size` attribute of the first preprocessor that has one. The `height`/`width`,
        `shortest_edge`, and `longest_edge` keys are understood, as is a plain int. The last fallback
        is `model_config.image_size`. `batch_size` defaults to 1.

        Returns:
            An RBLNConfig with a single float32 `pixel_values` input of shape
            `[batch_size, 3, height, width]`.

        Raises:
            ValueError: if no image size can be resolved or it has an unsupported type.
        """
        # Avoid a shared mutable default; tolerate a missing preprocessor list.
        rbln_kwargs = {} if rbln_kwargs is None else rbln_kwargs
        preprocessors = preprocessors or []

        rbln_image_size = rbln_kwargs.get("image_size", None)
        rbln_batch_size = rbln_kwargs.get("batch_size", None)

        if rbln_image_size is None:
            for processor in preprocessors:
                if hasattr(processor, "size"):
                    size = processor.size
                    if isinstance(size, int):
                        # Some processors store `size` as a single int rather than a mapping.
                        rbln_image_size = size
                    elif size is not None:
                        if all(required_key in size.keys() for required_key in ["height", "width"]):
                            rbln_image_size = (size["height"], size["width"])
                        elif "shortest_edge" in size.keys():
                            rbln_image_size = (size["shortest_edge"], size["shortest_edge"])
                        elif "longest_edge" in size.keys():
                            rbln_image_size = (size["longest_edge"], size["longest_edge"])
                    break

        if rbln_image_size is None:
            # getattr keeps the explicit error below reachable for configs without `image_size`.
            rbln_image_size = getattr(model_config, "image_size", None)

        if rbln_image_size is None:
            raise ValueError("`rbln_image_size` should be specified!")

        if rbln_batch_size is None:
            rbln_batch_size = 1

        if isinstance(rbln_image_size, int):
            rbln_image_height, rbln_image_width = rbln_image_size, rbln_image_size
        elif isinstance(rbln_image_size, (list, tuple)):
            rbln_image_height, rbln_image_width = rbln_image_size[0], rbln_image_size[1]
        elif isinstance(rbln_image_size, dict):
            rbln_image_height, rbln_image_width = rbln_image_size["height"], rbln_image_size["width"]
        else:
            raise ValueError(
                "`rbln_image_size` should be `int` (ex. 224), `tuple` (ex. 224, 224), `dict` (ex. {'height': 224, 'width': 224}) format"
            )

        input_info = [
            (
                "pixel_values",
                [rbln_batch_size, 3, rbln_image_height, rbln_image_width],
                "float32",
            )
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
        return RBLNConfig(rbln_cls=cls.__name__, compile_cfgs=[rbln_compile_config], rbln_kwargs=rbln_kwargs)
360
+
361
+
362
class RBLNModelForAudioClassification(RBLNModel):
    """
    This is a generic model class that will be instantiated as one of the model classes of the library (with an audio classification head) when created with the from_pretrained() class method
    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.

    A class to convert and run pre-trained transformers based AudioClassification models on RBLN devices.
    It implements the methods to convert a pre-trained transformers AudioClassification model into a RBLN transformer model by:
    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.

    Currently, this model class only supports the 'AST' model from the transformers library. Future updates may include support for additional model types.
    """

    auto_model_class = AutoModelForAudioClassification

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[List["AutoFeatureExtractor"]],
        model_config: "PretrainedConfig",
        rbln_kwargs: Optional[Dict[str, Any]] = None,
    ) -> RBLNConfig:
        """
        Build the compile configuration for an audio-classification model.

        Resolution order:
        - `batch_size`: `rbln_kwargs`, else 1.
        - `num_mel_bins`: `rbln_kwargs`, then `model_config.num_mel_bins`, then the first
          feature extractor with `num_mel_bins`.
        - `max_length`: `rbln_kwargs`. Otherwise the first feature extractor with
          `max_length` wins, with `model_config.max_length` used only if none has it.

        Returns:
            An RBLNConfig with a single float32 `input_values` input of shape
            `[batch_size, max_length, num_mel_bins]`. The resolved values are also stored in `model_cfg`.

        Raises:
            ValueError: if `num_mel_bins` or `max_length` cannot be resolved.
        """
        # Avoid a shared mutable default; tolerate a missing preprocessor list.
        rbln_kwargs = {} if rbln_kwargs is None else rbln_kwargs
        preprocessors = preprocessors or []

        rbln_batch_size = rbln_kwargs.get("batch_size", None)
        rbln_max_length = rbln_kwargs.get("max_length", None)
        rbln_num_mel_bins = rbln_kwargs.get("num_mel_bins", None)

        if rbln_batch_size is None:
            rbln_batch_size = 1

        if rbln_num_mel_bins is None:
            rbln_num_mel_bins = getattr(model_config, "num_mel_bins", None)
            if rbln_num_mel_bins is None:
                for feature_extractor in preprocessors:
                    if hasattr(feature_extractor, "num_mel_bins"):
                        rbln_num_mel_bins = feature_extractor.num_mel_bins
                        break

        if rbln_num_mel_bins is None:
            raise ValueError("`rbln_num_mel_bins` should be specified!")

        if rbln_max_length is None:
            rbln_max_length = getattr(model_config, "max_length", None)
            # The feature extractor's value deliberately overrides the config's one.
            for feature_extractor in preprocessors:
                if hasattr(feature_extractor, "max_length"):
                    rbln_max_length = feature_extractor.max_length
                    break

        if rbln_max_length is None:
            raise ValueError("`rbln_max_length` should be specified!")

        input_info = [
            (
                "input_values",
                [rbln_batch_size, rbln_max_length, rbln_num_mel_bins],
                "float32",
            ),
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config = RBLNConfig(
            rbln_cls=cls.__name__,
            compile_cfgs=[rbln_compile_config],
            rbln_kwargs=rbln_kwargs,
        )
        rbln_config.model_cfg.update(
            {
                "batch_size": rbln_batch_size,
                "max_length": rbln_max_length,
                "num_mel_bins": rbln_num_mel_bins,
            }
        )
        return rbln_config
434
+
435
+
436
class RBLNModelForSequenceClassification(RBLNModel):
    """
    This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence classification head) when created with the from_pretrained() class method
    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.

    A class to convert and run pre-trained transformers based SequenceClassification models on RBLN devices.
    It implements the methods to convert a pre-trained transformers SequenceClassification model into a RBLN transformer model by:
    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.

    Currently, this model class supports the 'XLMRoberta' and 'Roberta' model from the transformers library. Future updates may include support for additional model types.
    """

    auto_model_class = AutoModelForSequenceClassification

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[List[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_kwargs: Optional[Dict[str, Any]] = None,
    ) -> RBLNConfig:
        """
        Build the compile configuration for a sequence-classification model.

        Resolution order:
        - `max_seq_len`: `rbln_kwargs`, then `n_positions` / `max_position_embeddings` from
          `model_config`, then the first preprocessor with `model_max_length`.
        - `model_input_names`: `rbln_kwargs`, then the first preprocessor with
          `model_input_names`, then the class attribute `rbln_model_input_names`.
        - `batch_size`: `rbln_kwargs`, else 1.

        Returns:
            An RBLNConfig whose inputs are int64 tensors of shape `[batch_size, max_seq_len]`.
            `max_seq_len` is also recorded in `model_cfg`.

        Raises:
            ValueError: if `max_seq_len` is missing or exceeds the model's position limit,
                or if the input names cannot be resolved.
        """
        # Avoid a shared mutable default; tolerate a missing preprocessor list.
        rbln_kwargs = {} if rbln_kwargs is None else rbln_kwargs
        preprocessors = preprocessors or []

        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
        rbln_batch_size = rbln_kwargs.get("batch_size", None)

        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
            model_config, "max_position_embeddings", None
        )

        if rbln_max_seq_len is None:
            rbln_max_seq_len = max_position_embeddings
            if rbln_max_seq_len is None:
                for tokenizer in preprocessors:
                    if hasattr(tokenizer, "model_max_length"):
                        rbln_max_seq_len = tokenizer.model_max_length
                        break
                if rbln_max_seq_len is None:
                    raise ValueError("`rbln_max_seq_len` should be specified!")

        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")

        if rbln_model_input_names is None:
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_input_names"):
                    rbln_model_input_names = tokenizer.model_input_names
                    break
        if rbln_model_input_names is None:
            rbln_model_input_names = getattr(cls, "rbln_model_input_names", None)
        if rbln_model_input_names is None:
            # Best-effort hint about the expected input order; the config may lack `architectures`.
            architectures = getattr(model_config, "architectures", None) or []
            original_model_class = getattr(transformers, architectures[0], None) if architectures else None
            input_names_order = (
                inspect.signature(original_model_class.forward).parameters.keys()
                if original_model_class is not None
                else []
            )
            raise ValueError(
                "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                f"and be sure to make the order of the inputs same as SequenceClassification forward() arguments like ({list(input_names_order)})"
            )

        if rbln_batch_size is None:
            rbln_batch_size = 1

        input_info = [
            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
            for model_input_name in rbln_model_input_names
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config = RBLNConfig(
            rbln_cls=cls.__name__,
            compile_cfgs=[rbln_compile_config],
            rbln_kwargs=rbln_kwargs,
        )
        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
        return rbln_config
510
+
511
+
512
class RBLNModelForMaskedLM(RBLNModel):
    """RBLN model class for masked-language-modeling heads, backed by `AutoModelForMaskedLM`."""

    auto_model_class = AutoModelForMaskedLM

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[List[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_kwargs: Optional[Dict[str, Any]] = None,
    ) -> RBLNConfig:
        """
        Build the compile configuration for a masked-LM model.

        Resolution order:
        - `max_seq_len`: `rbln_kwargs`, then `n_positions` / `max_position_embeddings` from
          `model_config`, then the first preprocessor with `model_max_length`.
        - `model_input_names`: `rbln_kwargs`, then the first preprocessor with
          `model_input_names`, then the class attribute `rbln_model_input_names`.
        - `batch_size`: `rbln_kwargs`, else 1.

        Returns:
            An RBLNConfig whose inputs are int64 tensors of shape `[batch_size, max_seq_len]`.
            `max_seq_len` is also recorded in `model_cfg`.

        Raises:
            ValueError: if `max_seq_len` is missing or exceeds the model's position limit,
                or if the input names cannot be resolved.
        """
        # Avoid a shared mutable default; tolerate a missing preprocessor list.
        rbln_kwargs = {} if rbln_kwargs is None else rbln_kwargs
        preprocessors = preprocessors or []

        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
        rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
        rbln_batch_size = rbln_kwargs.get("batch_size", None)

        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
            model_config, "max_position_embeddings", None
        )

        if rbln_max_seq_len is None:
            rbln_max_seq_len = max_position_embeddings
            if rbln_max_seq_len is None:
                for tokenizer in preprocessors:
                    if hasattr(tokenizer, "model_max_length"):
                        rbln_max_seq_len = tokenizer.model_max_length
                        break
                if rbln_max_seq_len is None:
                    raise ValueError("`rbln_max_seq_len` should be specified!")

        if max_position_embeddings is not None and rbln_max_seq_len > max_position_embeddings:
            raise ValueError("`rbln_max_seq_len` should be less or equal than max_position_embeddings!")

        if rbln_model_input_names is None:
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_input_names"):
                    rbln_model_input_names = tokenizer.model_input_names
                    break
        if rbln_model_input_names is None:
            rbln_model_input_names = getattr(cls, "rbln_model_input_names", None)
        if rbln_model_input_names is None:
            input_names_order = inspect.signature(cls.hf_class.forward).parameters.keys()
            raise ValueError(
                "Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
                f"and be sure to make the order of the inputs same as MaskedLM forward() arguments like ({list(input_names_order)})"
            )

        if rbln_batch_size is None:
            rbln_batch_size = 1

        input_info = [
            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
            for model_input_name in rbln_model_input_names
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config = RBLNConfig(
            rbln_cls=cls.__name__,
            compile_cfgs=[rbln_compile_config],
            rbln_kwargs=rbln_kwargs,
        )
        rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
        return rbln_config
@@ -21,7 +21,7 @@
21
21
  # copied, modified, or distributed without prior written permission
22
22
  # from Rebellions Inc.
23
23
 
24
- from .modeling_base import (
24
+ from .modeling import (
25
25
  RBLNModelForAudioClassification,
26
26
  RBLNModelForImageClassification,
27
27
  RBLNModelForMaskedLM,