optimum-rbln 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +115 -0
- optimum/rbln/__version__.py +1 -0
- optimum/rbln/diffusers/__init__.py +64 -0
- optimum/rbln/diffusers/models/__init__.py +26 -0
- optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
- optimum/rbln/diffusers/models/controlnet.py +180 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
- optimum/rbln/diffusers/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
- optimum/rbln/modeling.py +0 -0
- optimum/rbln/modeling_alias.py +49 -0
- optimum/rbln/modeling_base.py +645 -0
- optimum/rbln/modeling_config.py +169 -0
- optimum/rbln/modeling_seq2seq.py +469 -0
- optimum/rbln/transformers/__init__.py +59 -0
- optimum/rbln/transformers/generation/__init__.py +24 -0
- optimum/rbln/transformers/generation/streamers.py +122 -0
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/bart/__init__.py +24 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
- optimum/rbln/transformers/models/clip/__init__.py +24 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
- optimum/rbln/transformers/models/llama/__init__.py +24 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
- optimum/rbln/transformers/models/t5/__init__.py +24 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
- optimum/rbln/transformers/models/whisper/__init__.py +24 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
- optimum/rbln/utils/__init__.py +25 -0
- optimum/rbln/utils/import_utils.py +28 -0
- optimum/rbln/utils/runtime_utils.py +71 -0
- optimum/rbln/utils/save_utils.py +92 -0
- optimum_rbln-0.1.0.dist-info/METADATA +144 -0
- optimum_rbln-0.1.0.dist-info/RECORD +51 -0
- optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
- optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
optimum/rbln/modeling_base.py
@@ -0,0 +1,645 @@
# Copyright 2024 Rebellions Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Portions of this software are licensed under the Apache License,
# Version 2.0. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership.

# All other portions of this software, including proprietary code,
# are the intellectual property of Rebellions Inc. and may not be
# copied, modified, or distributed without prior written permission
# from Rebellions Inc.

import logging
from abc import ABC, abstractmethod
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import rebel
import torch
from huggingface_hub import HfApi, HfFolder, hf_hub_download
from optimum.exporters import TasksManager
from optimum.modeling_base import OptimizedModel
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForAudioClassification,
    AutoModelForImageClassification,
    AutoModelForQuestionAnswering,
    GenerationConfig,
    PretrainedConfig,
)

from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
from .utils.runtime_utils import UnavailableRuntime
from .utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors


logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig


def listify(var: Any):
    if isinstance(var, list):
        return var
    elif var is not None:
        return [var]
    else:
        return None

class RBLNBaseModel(OptimizedModel, ABC):
    """
    An abstract base class for compiling, loading, and saving neural network models from the huggingface
    transformers and diffusers libraries to run on RBLN NPU devices.

    This class supports loading and saving models using the `from_pretrained` and `save_pretrained` methods,
    similar to the huggingface libraries.

    The `from_pretrained` method loads a model corresponding to the given `model_id` from a local repository
    or the huggingface hub onto the NPU. If the model is a PyTorch model and `export=True` is passed as a
    kwarg, it compiles the PyTorch model corresponding to the given `model_id` before loading. If `model_id`
    is an already rbln-compiled model, it can be directly loaded onto the NPU with `export=False`.

    `rbln_npu` is a kwarg required for compilation, specifying the name of the NPU to be used. If this
    keyword is not specified, the NPU installed on the host machine is used. If no NPU is installed on the
    host machine, an error occurs.

    `rbln_device` specifies the device to be used at runtime. If not specified, device 0 is used.

    `rbln_create_runtimes` indicates whether to create runtime objects. If False, the runtime does not load
    the model onto the NPU. This option is particularly useful when you want to perform compilation only on a
    host machine without an NPU.

    `RBLNModel`, `RBLNModelFor*`, etc. are all child classes of RBLNBaseModel.

    Models compiled in this way can be saved to a local repository using `save_pretrained` or uploaded to
    the huggingface hub.

    It also supports generation through `generate` (for transformers models that support generation).

    RBLNBaseModel is a class for models consisting of an arbitrary number of `torch.nn.Module`s, and
    therefore is an abstract class without explicit implementations of `forward` or `export` functions.
    To inherit from this class, `forward`, `export`, etc. must be implemented.
    """

    model_type = "rbln_model"
    auto_model_class = AutoModel  # feature extraction
    config_name = "model_index.json"

    def __init__(
        self,
        models: List[rebel.RBLNCompiledModel],
        config: "PretrainedConfig",
        preprocessors: Optional[List],
        rbln_config: Optional[RBLNConfig],
        rbln_device: Optional[List[int]] = None,
        rbln_device_map: Optional[Dict[str, int]] = None,
        rbln_create_runtimes: Optional[bool] = True,
        **kwargs,
    ):
        super().__init__(models, config)
        if not isinstance(self.config, PretrainedConfig):  # if diffusers config
            self.config = PretrainedConfig(**self.config)

        self.models = listify(self.model)

        self.preprocessors = [] if preprocessors is None else preprocessors

        # Registers the RBLNBaseModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
        # a pipeline https://github.com/huggingface/transformers/blob/3d3204c025b6b5de013e07dd364208e28b4d9589/src/transformers/pipelines/base.py#L940
        AutoConfig.register(self.model_type, AutoConfig)
        if hasattr(self.auto_model_class, "register"):
            self.auto_model_class.register(AutoConfig, self.__class__)

        self.rbln_config = rbln_config
        self.compiled_models: List[rebel.RBLNCompiledModel] = models

        if rbln_device_map is None:
            self.rbln_device_map = {}
            device_val = 0 if rbln_device is None else rbln_device
            for key in self.rbln_config:
                self.rbln_device_map[key] = device_val

        else:
            self.rbln_device_map = rbln_device_map

        # copied from transformers PreTrainedModel __init__
        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
        if self.generation_config is not None:
            self.generation_config.use_cache = True

        self.device = torch.device("cpu")

        # create runtimes only if `rbln_create_runtimes` is enabled
        self.runtimes = self._create_runtimes(self.rbln_device_map) if rbln_create_runtimes else UnavailableRuntime()

        self.__post_init__(**kwargs)

    def __post_init__(self, **kwargs):
        pass

    def _save_pretrained(self, save_directory: Union[str, Path]):
        """
        Saves a model and its configuration file to a directory, so that it can be re-loaded using the
        [`~optimum.rbln.modeling_base.RBLNBaseModel.from_pretrained`] class method.

        Args:
            save_directory (`Union[str, Path]`):
                Directory where to save the model file.
        """

        for compiled_model, compiled_model_name in zip(self.compiled_models, self.rbln_config):
            dst_path = Path(save_directory) / f"{compiled_model_name}.rbln"
            compiled_model.save(dst_path)
        self.rbln_config.save(save_directory)

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        config: "PretrainedConfig",
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        subfolder: str = "",
        local_files_only: bool = False,
        **kwargs,
    ) -> "RBLNBaseModel":
        model_path = Path(model_id)
        if model_path.is_dir():
            model_path = model_path / subfolder
            rbln_files = list(model_path.glob("*.rbln"))
            rbln_config_filenames = list(model_path.glob("rbln_config.json"))
        else:
            if isinstance(use_auth_token, bool):
                token = HfFolder().get_token()
            else:
                token = use_auth_token
            repo_files = list(map(Path, HfApi().list_repo_files(model_id, revision=revision, token=token)))

            pattern = "*.rbln" if subfolder == "" else f"{subfolder}/*.rbln"
            rbln_files = [p for p in repo_files if p.match(pattern)]

            pattern = "rbln_config.json" if subfolder == "" else f"{subfolder}/rbln_config.json"
            rbln_config_filenames = [p for p in repo_files if p.match(pattern)]

        if len(rbln_files) == 0:
            raise FileNotFoundError(f"Could not find any rbln model file in {model_path}")

        if len(rbln_config_filenames) == 0:
            raise FileNotFoundError(f"Could not find `rbln_config.json` file in {model_path}")

        if len(rbln_config_filenames) > 1:
            raise FileExistsError(
                f"Multiple rbln_config.json files are not expected, but {len(rbln_config_filenames)} were found."
            )

        if model_path.is_dir():
            rbln_config = RBLNConfig.load(str(model_path))
            models = [
                rebel.RBLNCompiledModel(model_path / f"{compiled_model_name}.rbln")
                for compiled_model_name in rbln_config
            ]

        else:
            rbln_config_filename = rbln_config_filenames[0]
            rbln_config_cache_path = hf_hub_download(
                repo_id=model_id,
                filename=str(rbln_config_filename),
                subfolder=subfolder,
                use_auth_token=use_auth_token,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
            )
            rbln_config = RBLNConfig.load(Path(rbln_config_cache_path).parent)
            models = []
            for compiled_model_name in rbln_config:
                model_cache_path = hf_hub_download(
                    repo_id=model_id,
                    filename=f"{compiled_model_name}.rbln",
                    subfolder=subfolder,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    local_files_only=local_files_only,
                )
                models.append(rebel.RBLNCompiledModel(model_cache_path))

        preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)

        return cls(
            models,
            config,
            preprocessors,
            rbln_config=rbln_config,
            **kwargs,
        )

    def __repr__(self):
        return repr(self.runtimes)

    @classmethod
    def compile(cls, model, rbln_runtime_config: Optional[RBLNRuntimeConfig] = None):
        compiled_model = rebel.compile_from_torch(
            model,
            input_info=rbln_runtime_config.input_info,
            batch_size=rbln_runtime_config.batch_size,
            fusion=rbln_runtime_config.fusion,
            npu=rbln_runtime_config.npu,
            tensor_parallel_size=rbln_runtime_config.tensor_parallel_size,
        )
        return compiled_model

    @classmethod
    def get_rbln_config(
        cls,
        **rbln_config_kwargs,
    ) -> RBLNConfig:
        """
        Make the default rbln-config for the model.

        If `input_info` is specified, all kwargs other than `input_info`, `batch_size`, `fusion`,
        `npu`, and `tensor_parallel_size` are ignored.

        Kwargs for overriding the model's config can be accepted.

        Note that `batch_size` should be specified together with a matching `input_info`.
        """

        input_info = rbln_config_kwargs.pop("rbln_input_info", None)
        batch_size = rbln_config_kwargs.pop("rbln_batch_size", None)
        fusion = rbln_config_kwargs.pop("rbln_fusion", None)
        npu = rbln_config_kwargs.pop("rbln_npu", None)
        tensor_parallel_size = rbln_config_kwargs.pop("rbln_tensor_parallel_size", None)

        if input_info is not None:
            rbln_runtime_config = RBLNRuntimeConfig(
                input_info=input_info,
                batch_size=batch_size,
                fusion=fusion,
                npu=npu,
                tensor_parallel_size=tensor_parallel_size,
            )
            rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config])
        else:
            rbln_config = cls._get_rbln_config(rbln_batch_size=batch_size, **rbln_config_kwargs)
            for k, rcfgs in rbln_config.items():
                for rcfg in rcfgs:
                    rcfg: RBLNRuntimeConfig
                    rcfg.fusion = fusion
                    rcfg.npu = npu
                    rcfg.tensor_parallel_size = tensor_parallel_size

        return rbln_config

    @staticmethod
    def pop_rbln_kwargs_from_kwargs(kwargs: dict):
        keys = list(kwargs.keys())
        rbln_constructor_kwargs = {
            key: kwargs.pop(key) for key in keys if key in ["rbln_device", "rbln_device_map", "rbln_create_runtimes"]
        }

        keys = list(kwargs.keys())
        rbln_config_kwargs = {key: kwargs.pop(key) for key in keys if key.startswith("rbln_")}
        return rbln_config_kwargs, rbln_constructor_kwargs

    def can_generate(self):
        return False

    def to(self, *args, **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

    @classmethod
    def wrap_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
        # Wrap the model if needed.
        return model

    @classmethod
    def _from_transformers(cls, *args, **kwargs) -> "RBLNBaseModel":
        """
        Exports a vanilla Transformers model into an RBLN-compiled Module.
        This will be deprecated after optimum 2.0.
        """
        return cls._export(*args, **kwargs)

    @classmethod
    def _get_rbln_config(cls, **rbln_config_kwargs) -> RBLNConfig:
        raise NotImplementedError

    @abstractmethod
    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
        pass

    @abstractmethod
    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
        # self.compiled_models -> self.runtimes
        pass

    @classmethod
    @abstractmethod
    def _export(
        cls,
        model_id: Union[str, Path],
        config: "PretrainedConfig",
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        subfolder: str = "",
        local_files_only: bool = False,
        trust_remote_code: bool = False,
        **kwargs,
    ):
        """
        Exports a vanilla Transformers model into an RBLN-compiled Module.
        """
        pass

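The compile-load-save flow described in `RBLNBaseModel`'s docstring can be summarized with a short sketch. This block is an illustration added to this diff summary, not part of the packaged file; the model id and NPU name are placeholder assumptions.

```python
from optimum.rbln import RBLNModel  # concrete subclass defined below

# Compile a PyTorch checkpoint for the NPU; the model id and NPU name are placeholders.
model = RBLNModel.from_pretrained(
    "model_id",
    export=True,                  # compile the PyTorch model before loading
    rbln_npu="RBLN-CA02",         # hypothetical NPU name; defaults to the installed NPU
    rbln_device=0,                # runtime device index (default 0)
    rbln_create_runtimes=False,   # compile on a host without an NPU attached
)

# Persist the compiled artifacts (*.rbln files plus rbln_config.json).
model.save_pretrained("./model-rbln")

# Reload the already-compiled model directly onto the NPU.
model = RBLNModel.from_pretrained("./model-rbln", export=False, rbln_create_runtimes=True)
```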
class RBLNModel(RBLNBaseModel):
    """
    A class that inherits from RBLNBaseModel for models consisting of a single `torch.nn.Module`.

    This class supports all the functionality of RBLNBaseModel, including loading and saving models using
    the `from_pretrained` and `save_pretrained` methods, compiling PyTorch models for execution on RBLN NPU
    devices.

    Example:
    ```python
    model = RBLNModel.from_pretrained("model_id", export=True, rbln_npu="npu_name")
    outputs = model(**inputs)
    ```
    """

    model_type = "rbln_model"
    auto_model_class = AutoModel  # feature extraction

    @classmethod
    def _export(
        cls,
        model_id: Union[str, Path],
        config: "PretrainedConfig",
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        subfolder: str = "",
        local_files_only: bool = False,
        trust_remote_code: bool = False,
        **kwargs,
    ) -> "RBLNModel":
        """
        Exports a vanilla Transformers model into an RBLN-compiled Module.
        """
        task = kwargs.pop("task", None)
        if task is None:
            task = TasksManager.infer_task_from_model(cls.auto_model_class)

        save_dir = TemporaryDirectory()
        save_dir_path = Path(save_dir.name)

        kwargs.update(
            {
                "torchscript": True,
                "return_dict": False,
            }
        )

        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)

        model = TasksManager.get_model_from_task(
            task=task,
            model_name_or_path=model_id,
            subfolder=subfolder,
            revision=revision,
            framework="pt",
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
            local_files_only=local_files_only,
            force_download=force_download,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

        # TODO : do we need this?
        if isinstance(model, torch.nn.Module):
            model.eval()

        if config is None:
            config = model.config

        if not isinstance(config, PretrainedConfig):  # diffusers config
            config = PretrainedConfig(**config)

        config.save_pretrained(save_dir_path / subfolder)
        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)

        # Get compilation arguments
        if rbln_config_kwargs.get("rbln_config", None) is None:
            rbln_config = cls.get_rbln_config(preprocessors=preprocessors, model_config=config, **rbln_config_kwargs)

        rbln_runtime_configs = list(rbln_config.values())
        if len(rbln_runtime_configs) != 1:
            raise ValueError
        rbln_runtime_config = rbln_runtime_configs[0]
        if len(rbln_runtime_config) != 1:
            raise ValueError
        rbln_runtime_config = rbln_runtime_config[0]

        model = cls.wrap_model_if_needed(model)
        compiled_model = cls.compile(model, rbln_runtime_config=rbln_runtime_config)
        compiled_model.save(save_dir_path / subfolder / f"{rbln_runtime_config.compiled_model_name}.rbln")
        rbln_config.save(save_dir_path / subfolder)

        return cls._from_pretrained(
            model_id=save_dir_path,
            config=config,
            model_save_dir=save_dir,
            use_auth_token=use_auth_token,
            revision=revision,
            force_download=force_download,
            cache_dir=cache_dir,
            subfolder=subfolder,
            local_files_only=local_files_only,
            **rbln_constructor_kwargs,
            **kwargs,
        )

    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
        device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
        return [
            compiled_model.create_runtime(tensor_type="pt", device=device) for compiled_model in self.compiled_models
        ]

    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
        output = self.runtimes[0](*args, **kwargs)
        return output

    def __repr__(self):
        return repr(self.runtimes[0])

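A sketch of overriding the default compilation config with an explicit `rbln_input_info`, as described in `get_rbln_config` above: when `rbln_input_info` is given, the `RBLNRuntimeConfig` is built directly from it instead of calling the subclass's `_get_rbln_config`. This example is illustrative and not part of the packaged file; the input names, shapes, and model id are placeholder assumptions.

```python
from optimum.rbln import RBLNModel

# Hypothetical explicit input specification: (name, shape, dtype) triples,
# mirroring the input_info tuples the RBLNModelFor* subclasses construct below.
input_info = [
    ("input_ids", [1, 128], "int64"),
    ("attention_mask", [1, 128], "int64"),
]

model = RBLNModel.from_pretrained(
    "model_id",                  # placeholder
    export=True,
    rbln_input_info=input_info,
    rbln_batch_size=1,           # should match the batch dimension in input_info
)
```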
class RBLNModelForQuestionAnswering(RBLNModel):
    model_type = "rbln_model"
    auto_model_class = AutoModelForQuestionAnswering

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_max_seq_len: Optional[int] = None,
        rbln_model_input_names: Optional[List[str]] = None,
        rbln_batch_size: Optional[int] = None,
    ) -> RBLNConfig:
        if rbln_max_seq_len is None:
            for tokenizer in preprocessors:
                if hasattr(tokenizer, "model_max_length"):
                    rbln_max_seq_len = tokenizer.model_max_length
                    break
            if rbln_max_seq_len is None:
                raise ValueError("`rbln_max_seq_len` should be specified!")

        if rbln_model_input_names is None:
            # These are BERT's inputs
            rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]

        if rbln_batch_size is None:
            rbln_batch_size = 1
        input_info = [
            (model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
            for model_input_name in rbln_model_input_names
        ]

        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
        rbln_runtime_config.batch_size = rbln_batch_size
        meta = {"rbln_max_seq_len": rbln_max_seq_len}

        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)

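For concreteness, a minimal end-to-end sketch of the question-answering class above, assuming the package's top-level exports; the model id, sequence length, and question/context strings are illustrative placeholders, not part of the packaged file.

```python
from transformers import AutoTokenizer
from optimum.rbln import RBLNModelForQuestionAnswering

tokenizer = AutoTokenizer.from_pretrained("model_id")  # placeholder id
model = RBLNModelForQuestionAnswering.from_pretrained(
    "model_id",
    export=True,
    rbln_max_seq_len=384,  # falls back to tokenizer.model_max_length when omitted
    rbln_batch_size=1,
)

# Inputs must be padded to the compiled static shape [batch, max_seq_len].
inputs = tokenizer(
    "What does RBLN stand for?",
    "Placeholder context paragraph.",
    padding="max_length",
    max_length=384,
    return_tensors="pt",
)
outputs = model(**inputs)
```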
class RBLNModelForImageClassification(RBLNModel):
    """
    This is a generic model class that will be instantiated as one of the model classes of the library (with an image classification head) when created with the from_pretrained() class method.
    """

    model_type = "rbln_model"
    auto_model_class = AutoModelForImageClassification

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
        model_config: Optional["PretrainedConfig"] = None,
        rbln_image_size: Optional[int] = None,
        rbln_batch_size: Optional[int] = None,
    ) -> RBLNConfig:
        if rbln_image_size is None:
            for processor in preprocessors:
                if hasattr(processor, "size"):
                    rbln_image_size = processor.size["shortest_edge"]
                    break
            if rbln_image_size is None:
                raise ValueError("`rbln_image_size` should be specified!")

        if rbln_batch_size is None:
            rbln_batch_size = 1

        input_info = [("pixel_values", [rbln_batch_size, 3, rbln_image_size, rbln_image_size], "float32")]

        rbln_runtime_config = RBLNRuntimeConfig(input_info=input_info)
        rbln_runtime_config.batch_size = rbln_batch_size
        meta = {"rbln_image_size": rbln_image_size}

        return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)

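Likewise, a minimal sketch for the image-classification class above, assuming `rbln_image_size` is passed explicitly (when omitted, it is read from a preprocessor's `size["shortest_edge"]`); the model id is a placeholder and the random tensor stands in for real processor output.

```python
import torch
from optimum.rbln import RBLNModelForImageClassification

model = RBLNModelForImageClassification.from_pretrained(
    "model_id",               # placeholder
    export=True,
    rbln_image_size=224,      # compiled input shape: [batch, 3, 224, 224] float32
    rbln_batch_size=1,
)

pixel_values = torch.randn(1, 3, 224, 224)  # stand-in for an image processor's output
logits = model(pixel_values)
```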
class RBLNModelForAudioClassification(RBLNModel):
    """
    This is a generic model class that will be instantiated as one of the model classes of the library (with an audio classification head) when created with the from_pretrained() class method.
    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.

    A class to convert and run pre-trained transformers based AudioClassification models on RBLN devices.
    It implements the methods to convert a pre-trained transformers AudioClassification model into an RBLN transformer model by:
    - transferring the checkpoint weights of the original into an optimized RBLN graph,
    - compiling the resulting graph using the RBLN compiler.

    Currently, this model class only supports the 'AST' model from the transformers library. Future updates may include support for additional model types.
    """

    model_type = "rbln_model"
    auto_model_class = AutoModelForAudioClassification

    @classmethod
    def _get_rbln_config(
        cls,
        preprocessors: "AutoFeatureExtractor",
        model_config: "PretrainedConfig",
        rbln_batch_size: Optional[int] = None,
        rbln_max_length: Optional[int] = None,
        rbln_num_mel_bins: Optional[int] = None,
    ) -> RBLNConfig:
        meta = {}

        if rbln_batch_size is None:
            rbln_batch_size = 1

        if rbln_num_mel_bins is None:
            rbln_num_mel_bins = getattr(model_config, "num_mel_bins", None)
            if rbln_num_mel_bins is None:
                for feature_extractor in preprocessors:
                    if hasattr(feature_extractor, "num_mel_bins"):
                        rbln_num_mel_bins = feature_extractor.num_mel_bins
                        break

        if rbln_num_mel_bins is None:
            raise ValueError("`rbln_num_mel_bins` should be specified!")

        if rbln_max_length is None:
            rbln_max_length = getattr(model_config, "max_length", None)
            for feature_extractor in preprocessors:
                if hasattr(feature_extractor, "max_length"):
                    rbln_max_length = feature_extractor.max_length
                    break

        if rbln_max_length is None:
            raise ValueError("`rbln_max_length` should be specified!")

        meta["rbln_batch_size"] = rbln_batch_size
        meta["rbln_max_length"] = rbln_max_length
        meta["rbln_num_mel_bins"] = rbln_num_mel_bins

        model_input_info = [
            ("input_values", [rbln_batch_size, rbln_max_length, rbln_num_mel_bins], "float32"),
        ]

        rbln_runtime_config = RBLNRuntimeConfig(input_info=model_input_info, batch_size=rbln_batch_size)

        rbln_config = RBLNConfig.from_rbln_runtime_configs(
            [rbln_runtime_config],
            _rbln_meta=meta,
        )

        return rbln_config
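Finally, a hedged sketch for the audio-classification class, following its docstring's note that AST is the supported architecture. The checkpoint id and shapes are illustrative assumptions; in practice `rbln_max_length` and `rbln_num_mel_bins` are taken from the feature extractor or model config when omitted.

```python
import torch
from optimum.rbln import RBLNModelForAudioClassification

model = RBLNModelForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593",  # an AST checkpoint; illustrative choice
    export=True,
    rbln_max_length=1024,
    rbln_num_mel_bins=128,
)

# The compiled graph expects [batch, max_length, num_mel_bins] float32 features.
input_values = torch.randn(1, 1024, 128)  # stand-in for a feature extractor's output
logits = model(input_values)
```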