optimum-rbln 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +37 -2
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
- optimum/rbln/diffusers/models/controlnet.py +56 -40
- optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
- optimum/rbln/modeling_alias.py +3 -3
- optimum/rbln/modeling_base.py +471 -231
- optimum/rbln/modeling_config.py +152 -77
- optimum/rbln/modeling_seq2seq.py +166 -77
- optimum/rbln/transformers/__init__.py +35 -1
- optimum/rbln/transformers/models/__init__.py +20 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
- optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
- optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
- optimum/rbln/utils/import_utils.py +36 -1
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +33 -0
- optimum/rbln/utils/timer_utils.py +19 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
- optimum_rbln-0.1.11.dist-info/RECORD +93 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.9.dist-info/RECORD +0 -78
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_base.py
CHANGED
@@ -21,16 +21,20 @@
|
|
21
21
|
# copied, modified, or distributed without prior written permission
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
|
+
import copy
|
25
|
+
import importlib
|
26
|
+
import inspect
|
24
27
|
import logging
|
25
28
|
import os
|
26
29
|
import shutil
|
27
30
|
from abc import ABC, abstractmethod
|
28
31
|
from pathlib import Path
|
29
32
|
from tempfile import TemporaryDirectory
|
30
|
-
from typing import TYPE_CHECKING, Any, Dict,
|
33
|
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
31
34
|
|
32
35
|
import rebel
|
33
36
|
import torch
|
37
|
+
import transformers
|
34
38
|
from huggingface_hub import HfApi, HfFolder, hf_hub_download
|
35
39
|
from optimum.exporters import TasksManager
|
36
40
|
from optimum.modeling_base import OptimizedModel
|
@@ -46,7 +50,7 @@ from transformers import (
|
|
46
50
|
PretrainedConfig,
|
47
51
|
)
|
48
52
|
|
49
|
-
from .modeling_config import DEFAULT_COMPILED_MODEL_NAME,
|
53
|
+
from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
|
50
54
|
from .utils.runtime_utils import UnavailableRuntime
|
51
55
|
from .utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
|
52
56
|
|
@@ -62,7 +66,116 @@ if TYPE_CHECKING:
|
|
62
66
|
logger = logging.getLogger(__name__)
|
63
67
|
|
64
68
|
|
65
|
-
class
|
69
|
+
class SubModulesMixin:
|
70
|
+
"""
|
71
|
+
_rbln_submodules = [
|
72
|
+
{"name": "vision_tower"},
|
73
|
+
{"name": "language_model"},
|
74
|
+
]
|
75
|
+
"""
|
76
|
+
|
77
|
+
_rbln_submodules: List[Dict[str, Any]] = []
|
78
|
+
|
79
|
+
def __init__(
|
80
|
+
self,
|
81
|
+
*,
|
82
|
+
rbln_submodules: List["RBLNBaseModel"] = [],
|
83
|
+
**kwargs,
|
84
|
+
) -> None:
|
85
|
+
for submodule_meta, submodule in zip(self._rbln_submodules, rbln_submodules):
|
86
|
+
setattr(self, submodule_meta["name"], submodule)
|
87
|
+
|
88
|
+
@classmethod
|
89
|
+
def _from_model(
|
90
|
+
cls,
|
91
|
+
model: "PreTrainedModel",
|
92
|
+
model_save_dir: str,
|
93
|
+
rbln_sub_configs_dict: Dict[str, Any],
|
94
|
+
rbln_kwargs: Dict[str, Any],
|
95
|
+
subfolder=None, # warning: will be ignored
|
96
|
+
**kwargs,
|
97
|
+
) -> List["RBLNBaseModel"]:
|
98
|
+
rbln_submodules = []
|
99
|
+
for submodule in cls._rbln_submodules:
|
100
|
+
submodule_name = submodule["name"]
|
101
|
+
torch_submodule: "PreTrainedModel" = getattr(model, submodule["name"])
|
102
|
+
cls_name = torch_submodule.__class__.__name__
|
103
|
+
submodule_cls: "RBLNBaseModel" = getattr(importlib.import_module("optimum.rbln"), f"RBLN{cls_name}")
|
104
|
+
|
105
|
+
if submodule_name in rbln_sub_configs_dict:
|
106
|
+
kwargs["rbln_config"] = rbln_sub_configs_dict[submodule_name]
|
107
|
+
|
108
|
+
rbln_submodule = submodule_cls._export(
|
109
|
+
model_id=None,
|
110
|
+
config=torch_submodule.config,
|
111
|
+
subfolder=submodule_name,
|
112
|
+
model_save_dir=model_save_dir,
|
113
|
+
model=torch_submodule,
|
114
|
+
**rbln_kwargs,
|
115
|
+
**kwargs,
|
116
|
+
)
|
117
|
+
|
118
|
+
rbln_submodules.append(rbln_submodule)
|
119
|
+
|
120
|
+
return rbln_submodules
|
121
|
+
|
122
|
+
@classmethod
|
123
|
+
def _submodule_from_compiled_model(
|
124
|
+
cls, model_save_dir: str, rbln_sub_configs_dict: Dict[str, Any], rbln_kwargs: Dict[str, Any], **kwargs
|
125
|
+
):
|
126
|
+
rbln_submodules = []
|
127
|
+
for submodule in cls._rbln_submodules:
|
128
|
+
submodule_name = submodule["name"]
|
129
|
+
rbln_submodule_config_dict = rbln_sub_configs_dict.get(submodule_name, None)
|
130
|
+
|
131
|
+
# Get cls name for call the constructor of the rbln class
|
132
|
+
submodule_rbln_config = RBLNConfig.load(Path(model_save_dir) / submodule_name)
|
133
|
+
submodule_cls_name = submodule_rbln_config.meta["cls"]
|
134
|
+
submodule_cls: "RBLNBaseModel" = getattr(importlib.import_module("optimum.rbln"), submodule_cls_name)
|
135
|
+
|
136
|
+
config = OptimizedModel._load_config(Path(model_save_dir) / submodule_name, **kwargs)
|
137
|
+
rbln_submodule = submodule_cls._from_pretrained(
|
138
|
+
model_id=model_save_dir,
|
139
|
+
config=config,
|
140
|
+
subfolder=submodule_name,
|
141
|
+
rbln_config=rbln_submodule_config_dict,
|
142
|
+
**rbln_kwargs,
|
143
|
+
**kwargs,
|
144
|
+
)
|
145
|
+
rbln_submodules.append(rbln_submodule)
|
146
|
+
return rbln_submodules
|
147
|
+
|
148
|
+
@classmethod
|
149
|
+
def _load_submodules(
|
150
|
+
cls,
|
151
|
+
model_save_dir,
|
152
|
+
rbln_sub_configs_dict,
|
153
|
+
rbln_kwargs,
|
154
|
+
model=None,
|
155
|
+
**kwargs,
|
156
|
+
):
|
157
|
+
# Two way :
|
158
|
+
# 1. Compile from pytorch object
|
159
|
+
# 2. Load from compiled file
|
160
|
+
if model is not None:
|
161
|
+
return cls._from_model(
|
162
|
+
model=model,
|
163
|
+
model_save_dir=model_save_dir,
|
164
|
+
rbln_sub_configs_dict=rbln_sub_configs_dict,
|
165
|
+
rbln_kwargs=rbln_kwargs,
|
166
|
+
**kwargs,
|
167
|
+
)
|
168
|
+
|
169
|
+
else:
|
170
|
+
return cls._submodule_from_compiled_model(
|
171
|
+
model_save_dir=model_save_dir,
|
172
|
+
rbln_sub_configs_dict=rbln_sub_configs_dict,
|
173
|
+
rbln_kwargs=rbln_kwargs,
|
174
|
+
**kwargs,
|
175
|
+
)
|
176
|
+
|
177
|
+
|
178
|
+
class RBLNBaseModel(OptimizedModel, ABC, SubModulesMixin):
|
66
179
|
"""
|
67
180
|
An abstract base class for compiling, loading, and saving neural network models from the huggingface
|
68
181
|
transformers and diffusers libraries to run on RBLN NPU devices.
|
@@ -110,6 +223,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
|
|
110
223
|
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
|
111
224
|
subfolder: str = "",
|
112
225
|
rbln_compiled_models: Optional[rebel.RBLNCompiledModel] = None,
|
226
|
+
rbln_submodules: List["RBLNBaseModel"] = [],
|
113
227
|
**kwargs,
|
114
228
|
):
|
115
229
|
super().__init__(models, config)
|
@@ -127,11 +241,18 @@ class RBLNBaseModel(OptimizedModel, ABC):
|
|
127
241
|
self.auto_model_class.register(AutoConfig, self.__class__)
|
128
242
|
|
129
243
|
# copied from tranformers PreTrainedModel __init__
|
130
|
-
|
244
|
+
if self.can_generate():
|
245
|
+
gen_config_dir = model_save_dir.name if isinstance(model_save_dir, TemporaryDirectory) else model_save_dir
|
246
|
+
self.generation_config = GenerationConfig.from_pretrained(gen_config_dir, trust_remote_code=True)
|
247
|
+
else:
|
248
|
+
self.generation_config = None
|
249
|
+
|
250
|
+
# self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
|
131
251
|
if self.generation_config is not None:
|
132
252
|
self.generation_config.use_cache = True
|
133
253
|
|
134
254
|
self.device = torch.device("cpu")
|
255
|
+
self.training = False
|
135
256
|
|
136
257
|
# FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
|
137
258
|
# This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
|
@@ -146,11 +267,9 @@ class RBLNBaseModel(OptimizedModel, ABC):
|
|
146
267
|
self.model_save_dir = model_save_dir
|
147
268
|
self.subfolder = subfolder
|
148
269
|
|
270
|
+
self.rbln_submodules = rbln_submodules
|
149
271
|
self.__post_init__(**kwargs)
|
150
272
|
|
151
|
-
def __post_init__(self, **kwargs):
|
152
|
-
pass
|
153
|
-
|
154
273
|
def _save_pretrained(self, save_directory: Union[str, Path]):
|
155
274
|
"""
|
156
275
|
Saves a model and its configuration file to a directory, so that it can be re-loaded using the
|
@@ -180,27 +299,18 @@ class RBLNBaseModel(OptimizedModel, ABC):
|
|
180
299
|
)
|
181
300
|
|
182
301
|
@classmethod
|
183
|
-
def
|
302
|
+
def _load_compiled_model_dir(
|
184
303
|
cls,
|
185
304
|
model_id: Union[str, Path],
|
186
|
-
config: "PretrainedConfig",
|
187
305
|
use_auth_token: Optional[Union[bool, str]] = None,
|
188
306
|
revision: Optional[str] = None,
|
189
307
|
force_download: bool = False,
|
190
308
|
cache_dir: Optional[str] = None,
|
191
309
|
subfolder: str = "",
|
192
310
|
local_files_only: bool = False,
|
193
|
-
|
194
|
-
#
|
195
|
-
|
196
|
-
rbln_device_map: Optional[Dict[str, int]] = None,
|
197
|
-
rbln_create_runtimes: Optional[bool] = None,
|
198
|
-
# passed from compile function
|
199
|
-
rbln_config: Optional[RBLNConfig] = None,
|
200
|
-
rbln_compiled_models: Optional[List[rebel.RBLNCompiledModel]] = None,
|
201
|
-
rbln_optimize_host_memory: Optional[bool] = None,
|
202
|
-
**kwargs,
|
203
|
-
) -> "RBLNBaseModel":
|
311
|
+
):
|
312
|
+
# Find compiled model
|
313
|
+
# And prepare or download cache folder from HF Hub if needed.
|
204
314
|
model_path = Path(model_id)
|
205
315
|
if model_path.is_dir():
|
206
316
|
model_path = model_path / subfolder
|
@@ -236,16 +346,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
|
|
236
346
|
)
|
237
347
|
|
238
348
|
if model_path.is_dir():
|
239
|
-
|
240
|
-
rbln_config = RBLNConfig.load(str(model_path))
|
241
|
-
rbln_compiled_models = [
|
242
|
-
rebel.RBLNCompiledModel(model_path / f"{compiled_model_name}.rbln")
|
243
|
-
for compiled_model_name in rbln_config
|
244
|
-
]
|
245
|
-
new_model_save_dir = model_path
|
246
|
-
else:
|
247
|
-
pass
|
248
|
-
|
349
|
+
model_path = str(model_path)
|
249
350
|
else:
|
250
351
|
rbln_config_filename = rbln_config_filenames[0]
|
251
352
|
rbln_config_cache_path = hf_hub_download(
|
@@ -258,48 +359,145 @@ class RBLNBaseModel(OptimizedModel, ABC):
|
|
258
359
|
force_download=force_download,
|
259
360
|
local_files_only=local_files_only,
|
260
361
|
)
|
261
|
-
|
262
|
-
rbln_compiled_models = []
|
263
|
-
for compiled_model_name in rbln_config:
|
264
|
-
model_cache_path = hf_hub_download(
|
265
|
-
repo_id=model_id,
|
266
|
-
filename=f"{compiled_model_name}.rbln",
|
267
|
-
subfolder=subfolder,
|
268
|
-
use_auth_token=use_auth_token,
|
269
|
-
revision=revision,
|
270
|
-
cache_dir=cache_dir,
|
271
|
-
force_download=force_download,
|
272
|
-
local_files_only=local_files_only,
|
273
|
-
)
|
274
|
-
rbln_compiled_models.append(rebel.RBLNCompiledModel(model_cache_path))
|
275
|
-
new_model_save_dir = Path(rbln_config_cache_path).parent
|
362
|
+
model_path = Path(rbln_config_cache_path).parent
|
276
363
|
|
277
|
-
|
364
|
+
return model_path
|
365
|
+
|
366
|
+
@classmethod
|
367
|
+
def _load_compiled_models(cls, model_path: str):
|
368
|
+
compiled_models = Path(model_path).glob("*.rbln")
|
369
|
+
rbln_compiled_models = {cm.stem: rebel.RBLNCompiledModel(cm) for cm in compiled_models}
|
370
|
+
return rbln_compiled_models
|
371
|
+
|
372
|
+
@classmethod
|
373
|
+
def _split_submodule_config(cls, rbln_config_dict: Dict[str, Any] = {}) -> Dict[str, Any]:
|
374
|
+
# {"language_model" : {"rbln_tensor_parallel_size":4}}
|
375
|
+
rbln_sub_configs_dict: Dict[str, Dict[str, Any]] = {}
|
376
|
+
|
377
|
+
# Remove submodule-configs from rbln_config
|
378
|
+
if len(cls._rbln_submodules) > 0:
|
379
|
+
keys = list(rbln_config_dict.keys())
|
380
|
+
submodule_names = [m["name"] for m in cls._rbln_submodules]
|
381
|
+
for key in keys:
|
382
|
+
if key in submodule_names:
|
383
|
+
rbln_sub_configs_dict[key] = rbln_config_dict.pop(key)
|
384
|
+
|
385
|
+
return rbln_sub_configs_dict
|
386
|
+
|
387
|
+
@classmethod
|
388
|
+
def resolve_rbln_config(cls, rbln_config: Union[RBLNConfig, Dict[str, Any]], kwargs):
|
389
|
+
if isinstance(rbln_config, RBLNConfig):
|
390
|
+
# Already resolved
|
391
|
+
return rbln_config, None
|
278
392
|
|
279
|
-
if model_save_dir is None:
|
280
|
-
model_save_dir = new_model_save_dir
|
281
|
-
|
282
|
-
# Create runtimes
|
283
|
-
if rbln_create_runtimes is None:
|
284
|
-
rbln_create_runtimes = rebel.npu_is_available()
|
285
|
-
if rbln_device_map is None:
|
286
|
-
rbln_device_map = {}
|
287
|
-
device_val = 0 if rbln_device is None else rbln_device
|
288
|
-
for key in rbln_config:
|
289
|
-
rbln_device_map[key] = device_val
|
290
393
|
else:
|
291
|
-
|
394
|
+
if rbln_config is None:
|
395
|
+
rbln_config_dict = {}
|
396
|
+
else:
|
397
|
+
rbln_config_dict = rbln_config
|
398
|
+
|
399
|
+
rbln_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
|
400
|
+
rbln_sub_configs_dict = cls._split_submodule_config(rbln_config_dict)
|
401
|
+
|
402
|
+
for key in rbln_config_dict:
|
403
|
+
if key in rbln_kwargs:
|
404
|
+
raise KeyError(f"duplicate key in both `rbln_config` and {key}")
|
405
|
+
|
406
|
+
merged_rbln_kwargs = copy.deepcopy(rbln_kwargs)
|
407
|
+
merged_rbln_kwargs.update(rbln_config_dict)
|
408
|
+
|
409
|
+
return (merged_rbln_kwargs, rbln_sub_configs_dict)
|
410
|
+
|
411
|
+
@classmethod
|
412
|
+
def _from_pretrained(
|
413
|
+
cls,
|
414
|
+
model_id: Union[str, Path],
|
415
|
+
config: "PretrainedConfig",
|
416
|
+
use_auth_token: Optional[Union[bool, str]] = None,
|
417
|
+
revision: Optional[str] = None,
|
418
|
+
force_download: bool = False,
|
419
|
+
cache_dir: Optional[str] = None,
|
420
|
+
subfolder: str = "",
|
421
|
+
local_files_only: bool = False,
|
422
|
+
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
|
423
|
+
# passed from compile function
|
424
|
+
rbln_config: Optional[RBLNConfig] = None,
|
425
|
+
rbln_compiled_models: Optional[Dict[str, rebel.RBLNCompiledModel]] = None,
|
426
|
+
rbln_submodules: List["RBLNBaseModel"] = [],
|
427
|
+
**kwargs,
|
428
|
+
) -> "RBLNBaseModel":
|
429
|
+
from_export_method = isinstance(rbln_config, RBLNConfig) and rbln_compiled_models is not None
|
430
|
+
|
431
|
+
if not from_export_method:
|
432
|
+
# from compiled dir
|
433
|
+
rbln_kwargs, rbln_sub_configs_dict = cls.resolve_rbln_config(rbln_config, kwargs)
|
434
|
+
|
435
|
+
model_path_subfolder = cls._load_compiled_model_dir(
|
436
|
+
model_id=model_id,
|
437
|
+
use_auth_token=use_auth_token,
|
438
|
+
revision=revision,
|
439
|
+
force_download=force_download,
|
440
|
+
cache_dir=cache_dir,
|
441
|
+
subfolder=subfolder,
|
442
|
+
local_files_only=local_files_only,
|
443
|
+
)
|
444
|
+
|
445
|
+
rbln_config = RBLNConfig.load(model_path_subfolder)
|
446
|
+
rbln_config.update_runtime_cfg(rbln_kwargs)
|
447
|
+
|
448
|
+
rbln_compiled_models = cls._load_compiled_models(model_path_subfolder)
|
449
|
+
|
450
|
+
if len(cls._rbln_submodules) > 0:
|
451
|
+
rbln_submodules = cls._load_submodules(
|
452
|
+
model_save_dir=model_id,
|
453
|
+
rbln_sub_configs_dict=rbln_sub_configs_dict,
|
454
|
+
rbln_kwargs=rbln_kwargs,
|
455
|
+
**kwargs,
|
456
|
+
)
|
457
|
+
else:
|
458
|
+
rbln_submodules = []
|
459
|
+
|
460
|
+
if subfolder != "":
|
461
|
+
model_save_dir = Path(model_path_subfolder).absolute().parent
|
462
|
+
else:
|
463
|
+
model_save_dir = Path(model_path_subfolder).absolute()
|
464
|
+
|
465
|
+
return cls._from_compiled_models(
|
466
|
+
rbln_compiled_models=rbln_compiled_models,
|
467
|
+
rbln_config=rbln_config,
|
468
|
+
config=config,
|
469
|
+
model_save_dir=model_save_dir,
|
470
|
+
subfolder=subfolder,
|
471
|
+
rbln_submodules=rbln_submodules,
|
472
|
+
**kwargs,
|
473
|
+
)
|
474
|
+
|
475
|
+
@classmethod
|
476
|
+
def _from_compiled_models(
|
477
|
+
cls,
|
478
|
+
rbln_compiled_models: Dict[str, rebel.RBLNCompiledModel],
|
479
|
+
rbln_config: RBLNConfig,
|
480
|
+
config,
|
481
|
+
model_save_dir: str,
|
482
|
+
subfolder: str,
|
483
|
+
rbln_submodules: List["RBLNBaseModel"] = [],
|
484
|
+
**kwargs,
|
485
|
+
):
|
486
|
+
if isinstance(model_save_dir, str):
|
487
|
+
model_save_dir = Path(model_save_dir)
|
488
|
+
preprocessors = maybe_load_preprocessors(model_save_dir.name, subfolder=subfolder)
|
489
|
+
|
490
|
+
# FIXME:: Should we convert it?
|
491
|
+
compiled_model_names = [cfg.compiled_model_name for cfg in rbln_config.compile_cfgs]
|
492
|
+
rbln_compiled_models = [rbln_compiled_models[cm_name] for cm_name in compiled_model_names]
|
292
493
|
|
293
494
|
# create runtimes only if `rbln_create_runtimes` is enabled
|
294
495
|
models = (
|
295
|
-
cls._create_runtimes(rbln_compiled_models,
|
296
|
-
if
|
496
|
+
cls._create_runtimes(rbln_compiled_models, rbln_config.device_map)
|
497
|
+
if rbln_config.create_runtimes
|
297
498
|
else UnavailableRuntime()
|
298
499
|
)
|
299
500
|
|
300
|
-
if rbln_optimize_host_memory is None:
|
301
|
-
rbln_optimize_host_memory = True
|
302
|
-
|
303
501
|
return cls(
|
304
502
|
models,
|
305
503
|
config,
|
@@ -307,99 +505,65 @@ class RBLNBaseModel(OptimizedModel, ABC):
|
|
307
505
|
preprocessors,
|
308
506
|
model_save_dir=model_save_dir,
|
309
507
|
subfolder=subfolder,
|
310
|
-
rbln_compiled_models=(None if
|
508
|
+
rbln_compiled_models=(None if rbln_config.optimize_host_memory else rbln_compiled_models),
|
509
|
+
rbln_submodules=rbln_submodules,
|
311
510
|
**kwargs,
|
312
511
|
)
|
313
512
|
|
314
513
|
def __repr__(self):
|
315
|
-
return repr(self.model)
|
514
|
+
return repr(self.model) + repr(self.rbln_submodules)
|
316
515
|
|
317
516
|
@classmethod
|
318
|
-
def compile(cls, model,
|
517
|
+
def compile(cls, model, rbln_compile_config: Optional[RBLNCompileConfig] = None):
|
319
518
|
compiled_model = rebel.compile_from_torch(
|
320
519
|
model,
|
321
|
-
input_info=
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
tensor_parallel_size=rbln_runtime_config.tensor_parallel_size,
|
520
|
+
input_info=rbln_compile_config.input_info,
|
521
|
+
fusion=rbln_compile_config.fusion,
|
522
|
+
npu=rbln_compile_config.npu,
|
523
|
+
tensor_parallel_size=rbln_compile_config.tensor_parallel_size,
|
326
524
|
)
|
327
525
|
return compiled_model
|
328
526
|
|
329
527
|
@classmethod
|
330
528
|
def get_rbln_config(
|
331
529
|
cls,
|
332
|
-
|
530
|
+
rbln_kwargs: Dict[str, Any],
|
531
|
+
**others,
|
333
532
|
) -> RBLNConfig:
|
334
533
|
"""
|
335
534
|
Make default rbln-config for the model.
|
336
|
-
|
337
|
-
if `input_info` specified,
|
338
|
-
other kwargs but `input_info`, `batch_size` and `fusion` are ignored.
|
339
|
-
|
340
535
|
kwargs for overriding model's config can be accepted.
|
341
|
-
|
342
536
|
Note that batch_size should be specified with proper input_info.
|
343
537
|
"""
|
344
|
-
|
345
|
-
input_info = rbln_config_kwargs.pop("rbln_input_info", None)
|
346
|
-
batch_size = rbln_config_kwargs.pop("rbln_batch_size", None)
|
347
|
-
fusion = rbln_config_kwargs.pop("rbln_fusion", None)
|
348
|
-
npu = rbln_config_kwargs.pop("rbln_npu", None)
|
349
|
-
tensor_parallel_size = rbln_config_kwargs.pop("rbln_tensor_parallel_size", None)
|
350
|
-
|
351
|
-
if input_info is not None:
|
352
|
-
rbln_runtime_config = RBLNRuntimeConfig(
|
353
|
-
input_info=input_info,
|
354
|
-
batch_size=batch_size,
|
355
|
-
fusion=fusion,
|
356
|
-
npu=npu,
|
357
|
-
tensor_parallel_size=tensor_parallel_size,
|
358
|
-
)
|
359
|
-
rbln_config = RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config])
|
360
|
-
else:
|
361
|
-
rbln_config = cls._get_rbln_config(rbln_batch_size=batch_size, **rbln_config_kwargs)
|
362
|
-
for k, rcfgs in rbln_config.items():
|
363
|
-
for rcfg in rcfgs:
|
364
|
-
rcfg: RBLNRuntimeConfig
|
365
|
-
rcfg.fusion = fusion
|
366
|
-
rcfg.npu = npu
|
367
|
-
rcfg.tensor_parallel_size = tensor_parallel_size
|
368
|
-
|
538
|
+
rbln_config = cls._get_rbln_config(**others, rbln_kwargs=rbln_kwargs)
|
369
539
|
return rbln_config
|
370
540
|
|
371
541
|
@staticmethod
|
372
|
-
def pop_rbln_kwargs_from_kwargs(kwargs:
|
542
|
+
def pop_rbln_kwargs_from_kwargs(kwargs: Dict[str, Any], runtime_only=False):
|
373
543
|
keys = list(kwargs.keys())
|
374
|
-
|
375
|
-
key: kwargs.pop(key)
|
376
|
-
for key in keys
|
377
|
-
if key
|
378
|
-
in [
|
379
|
-
"rbln_device",
|
380
|
-
"rbln_device_map",
|
381
|
-
"rbln_create_runtimes",
|
382
|
-
"rbln_optimize_host_memory",
|
383
|
-
]
|
384
|
-
}
|
544
|
+
rbln_kwargs = {key[5:]: kwargs.pop(key) for key in keys if key.startswith("rbln_")}
|
385
545
|
|
386
|
-
|
387
|
-
|
388
|
-
|
546
|
+
if runtime_only:
|
547
|
+
rbln_kwargs = {
|
548
|
+
key: value
|
549
|
+
for key, value in rbln_kwargs.items()
|
550
|
+
if key in {"create_runtimes", "optimize_host_memory", "device", "device_map"}
|
551
|
+
}
|
552
|
+
|
553
|
+
return rbln_kwargs
|
389
554
|
|
390
555
|
def can_generate(self):
|
391
556
|
return False
|
392
557
|
|
393
558
|
def to(self, *args, **kwargs):
|
394
|
-
|
559
|
+
# Do nothing
|
560
|
+
return self
|
395
561
|
|
396
562
|
def __call__(self, *args, **kwargs):
|
397
563
|
return self.forward(*args, **kwargs)
|
398
564
|
|
399
|
-
|
400
|
-
|
401
|
-
# Wrap the model if needed.
|
402
|
-
return model
|
565
|
+
def __post_init__(self, **kwargs):
|
566
|
+
self.dtype = torch.float32
|
403
567
|
|
404
568
|
@classmethod
|
405
569
|
def _from_transformers(cls, *args, **kwargs) -> "RBLNBaseModel":
|
@@ -410,8 +574,14 @@ class RBLNBaseModel(OptimizedModel, ABC):
|
|
410
574
|
return cls._export(*args, **kwargs)
|
411
575
|
|
412
576
|
@classmethod
|
577
|
+
def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNConfig) -> torch.nn.Module:
|
578
|
+
# Wrap the model if needed.
|
579
|
+
return model
|
580
|
+
|
581
|
+
@classmethod
|
582
|
+
@abstractmethod
|
413
583
|
def _get_rbln_config(cls, **rbln_config_kwargs) -> RBLNConfig:
|
414
|
-
|
584
|
+
pass
|
415
585
|
|
416
586
|
@abstractmethod
|
417
587
|
def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
|
@@ -429,20 +599,7 @@ class RBLNBaseModel(OptimizedModel, ABC):
|
|
429
599
|
|
430
600
|
@classmethod
|
431
601
|
@abstractmethod
|
432
|
-
def _export(
|
433
|
-
cls,
|
434
|
-
model_id: Union[str, Path],
|
435
|
-
config: "PretrainedConfig",
|
436
|
-
use_auth_token: Optional[Union[bool, str]] = None,
|
437
|
-
revision: Optional[str] = None,
|
438
|
-
force_download: bool = False,
|
439
|
-
cache_dir: Optional[str] = None,
|
440
|
-
subfolder: str = "",
|
441
|
-
local_files_only: bool = False,
|
442
|
-
trust_remote_code: bool = False,
|
443
|
-
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
|
444
|
-
**kwargs,
|
445
|
-
):
|
602
|
+
def _export(cls, *args, **kwargs):
|
446
603
|
"""
|
447
604
|
Exports a vanilla Transformers model into a rbln-compiled Module.
|
448
605
|
"""
|
@@ -491,8 +648,8 @@ class RBLNModel(RBLNBaseModel):
|
|
491
648
|
subfolder: str = "",
|
492
649
|
local_files_only: bool = False,
|
493
650
|
trust_remote_code: bool = False,
|
494
|
-
|
495
|
-
|
651
|
+
# Some rbln-kwargs should be applied before loading torch module (i.e. quantized llm)
|
652
|
+
rbln_kwargs: Optional[Dict[str, Any]] = None,
|
496
653
|
**kwargs,
|
497
654
|
) -> "PreTrainedModel":
|
498
655
|
task = kwargs.pop("task", None)
|
@@ -517,25 +674,31 @@ class RBLNModel(RBLNBaseModel):
|
|
517
674
|
|
518
675
|
return model
|
519
676
|
|
677
|
+
@classmethod
|
678
|
+
def save_torch_artifacts(
|
679
|
+
cls,
|
680
|
+
model: "PreTrainedModel",
|
681
|
+
save_dir_path: Path,
|
682
|
+
subfolder: str,
|
683
|
+
rbln_config: RBLNConfig,
|
684
|
+
):
|
685
|
+
"""
|
686
|
+
If you are unavoidably running on a CPU rather than an RBLN device,
|
687
|
+
store the torch tensor, weight, etc. in this function.
|
688
|
+
"""
|
689
|
+
|
520
690
|
@classmethod
|
521
691
|
def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
|
522
692
|
model = cls.wrap_model_if_needed(model, rbln_config)
|
523
|
-
|
524
|
-
|
525
|
-
raise ValueError
|
526
|
-
rbln_runtime_config = rbln_runtime_configs[0]
|
527
|
-
if len(rbln_runtime_config) != 1:
|
528
|
-
raise ValueError
|
529
|
-
rbln_runtime_config = rbln_runtime_config[0]
|
530
|
-
|
531
|
-
compiled_model = cls.compile(model, rbln_runtime_config=rbln_runtime_config)
|
693
|
+
rbln_compile_config = rbln_config.compile_cfgs[0]
|
694
|
+
compiled_model = cls.compile(model, rbln_compile_config=rbln_compile_config)
|
532
695
|
return compiled_model
|
533
696
|
|
534
697
|
@classmethod
|
535
698
|
@torch.no_grad()
|
536
699
|
def _export(
|
537
700
|
cls,
|
538
|
-
model_id: str,
|
701
|
+
model_id: Union[str, Path],
|
539
702
|
config: "PretrainedConfig",
|
540
703
|
use_auth_token: Optional[Union[bool, str]] = None,
|
541
704
|
revision: Optional[str] = None,
|
@@ -545,8 +708,12 @@ class RBLNModel(RBLNBaseModel):
|
|
545
708
|
local_files_only: bool = False,
|
546
709
|
trust_remote_code: bool = False,
|
547
710
|
model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
|
711
|
+
model: "PreTrainedModel" = None,
|
712
|
+
rbln_config: Optional[Dict[str, Any]] = None,
|
548
713
|
**kwargs,
|
549
714
|
) -> "RBLNModel":
|
715
|
+
rbln_kwargs, rbln_sub_configs_dict = cls.resolve_rbln_config(rbln_config, kwargs)
|
716
|
+
|
550
717
|
if model_save_dir is None:
|
551
718
|
save_dir = TemporaryDirectory()
|
552
719
|
save_dir_path = Path(save_dir.name)
|
@@ -558,48 +725,65 @@ class RBLNModel(RBLNBaseModel):
|
|
558
725
|
save_dir_path = Path(model_save_dir)
|
559
726
|
save_dir_path.mkdir(exist_ok=True)
|
560
727
|
|
561
|
-
|
562
|
-
|
563
|
-
|
564
|
-
|
565
|
-
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
575
|
-
|
728
|
+
# Load pytorch model if needed.
|
729
|
+
if model is None:
|
730
|
+
model: "PreTrainedModel" = cls.get_pytorch_model(
|
731
|
+
model_id=model_id,
|
732
|
+
subfolder=subfolder,
|
733
|
+
revision=revision,
|
734
|
+
cache_dir=cache_dir,
|
735
|
+
use_auth_token=use_auth_token,
|
736
|
+
local_files_only=local_files_only,
|
737
|
+
force_download=force_download,
|
738
|
+
trust_remote_code=trust_remote_code,
|
739
|
+
rbln_kwargs=rbln_kwargs,
|
740
|
+
**kwargs,
|
741
|
+
)
|
742
|
+
preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
|
743
|
+
else:
|
744
|
+
preprocessors = []
|
576
745
|
|
577
746
|
# FIXME :: optimum passes AutoConfig.
|
578
747
|
config = model.config
|
748
|
+
if hasattr(model, "can_generate") and model.can_generate():
|
749
|
+
generation_config = model.generation_config
|
750
|
+
generation_config.save_pretrained(save_dir_path / subfolder)
|
579
751
|
|
580
752
|
if not isinstance(config, PretrainedConfig): # diffusers config
|
581
753
|
config = PretrainedConfig(**config)
|
582
|
-
|
583
754
|
config.save_pretrained(save_dir_path / subfolder)
|
584
|
-
preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
|
585
755
|
|
586
756
|
# Get compilation arguments
|
587
|
-
|
588
|
-
|
589
|
-
|
757
|
+
rbln_config: RBLNConfig = cls.get_rbln_config(
|
758
|
+
preprocessors=preprocessors, model_config=config, rbln_kwargs=rbln_kwargs
|
759
|
+
)
|
760
|
+
compiled_model: Union[rebel.RBLNCompiledModel, Dict[str, rebel.RBLNCompiledModel]] = cls.get_compiled_model(
|
761
|
+
model, rbln_config=rbln_config
|
762
|
+
)
|
590
763
|
|
591
764
|
# Save compiled models
|
592
765
|
(save_dir_path / subfolder).mkdir(exist_ok=True)
|
593
|
-
if isinstance(compiled_model,
|
594
|
-
|
595
|
-
|
596
|
-
single_compiled_model.save(save_dir_path / subfolder / f"{compiled_model_name}.rbln")
|
766
|
+
if not isinstance(compiled_model, dict):
|
767
|
+
compiled_models = {DEFAULT_COMPILED_MODEL_NAME: compiled_model}
|
768
|
+
else:
|
597
769
|
compiled_models = compiled_model
|
770
|
+
for compiled_model_name, cm in compiled_models.items():
|
771
|
+
cm.save(save_dir_path / subfolder / f"{compiled_model_name}.rbln")
|
772
|
+
rbln_config.save(save_dir_path / subfolder)
|
773
|
+
|
774
|
+
cls.save_torch_artifacts(model, save_dir_path=save_dir_path, subfolder=subfolder, rbln_config=rbln_config)
|
598
775
|
|
776
|
+
# Load submodules
|
777
|
+
if len(cls._rbln_submodules) > 0:
|
778
|
+
rbln_submodules = cls._load_submodules(
|
779
|
+
model=model,
|
780
|
+
model_save_dir=save_dir,
|
781
|
+
rbln_sub_configs_dict=rbln_sub_configs_dict,
|
782
|
+
rbln_kwargs=rbln_kwargs,
|
783
|
+
**kwargs,
|
784
|
+
)
|
599
785
|
else:
|
600
|
-
|
601
|
-
compiled_models = [compiled_model]
|
602
|
-
rbln_config.save(save_dir_path / subfolder)
|
786
|
+
rbln_submodules = []
|
603
787
|
|
604
788
|
# Instantiate
|
605
789
|
return cls._from_pretrained(
|
@@ -614,7 +798,7 @@ class RBLNModel(RBLNBaseModel):
|
|
614
798
|
local_files_only=local_files_only,
|
615
799
|
rbln_config=rbln_config,
|
616
800
|
rbln_compiled_models=compiled_models,
|
617
|
-
|
801
|
+
rbln_submodules=rbln_submodules,
|
618
802
|
**kwargs,
|
619
803
|
)
|
620
804
|
|
@@ -635,16 +819,19 @@ class RBLNModel(RBLNBaseModel):
|
|
635
819
|
class RBLNModelForQuestionAnswering(RBLNModel):
|
636
820
|
model_type = "rbln_model"
|
637
821
|
auto_model_class = AutoModelForQuestionAnswering
|
822
|
+
rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
|
638
823
|
|
639
824
|
@classmethod
|
640
825
|
def _get_rbln_config(
|
641
826
|
cls,
|
642
827
|
preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
|
643
828
|
model_config: Optional["PretrainedConfig"] = None,
|
644
|
-
|
645
|
-
rbln_batch_size: Optional[int] = None,
|
646
|
-
rbln_model_input_names: Optional[List[str]] = None,
|
829
|
+
rbln_kwargs: Dict[str, Any] = {},
|
647
830
|
) -> RBLNConfig:
|
831
|
+
rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
|
832
|
+
rbln_batch_size = rbln_kwargs.get("batch_size", None)
|
833
|
+
rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
|
834
|
+
|
648
835
|
if rbln_max_seq_len is None:
|
649
836
|
for tokenizer in preprocessors:
|
650
837
|
if hasattr(tokenizer, "model_max_length"):
|
@@ -656,19 +843,34 @@ class RBLNModelForQuestionAnswering(RBLNModel):
|
|
656
843
|
if rbln_batch_size is None:
|
657
844
|
rbln_batch_size = 1
|
658
845
|
|
659
|
-
if rbln_model_input_names is
|
660
|
-
|
846
|
+
if rbln_model_input_names is None:
|
847
|
+
for tokenizer in preprocessors:
|
848
|
+
if hasattr(tokenizer, "model_input_names"):
|
849
|
+
rbln_model_input_names = tokenizer.model_input_names
|
850
|
+
break
|
851
|
+
if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
|
852
|
+
rbln_model_input_names = cls.rbln_model_input_names
|
853
|
+
elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
|
854
|
+
original_model_class = getattr(transformers, model_config.architectures[0])
|
855
|
+
input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
|
856
|
+
raise ValueError(
|
857
|
+
"Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
|
858
|
+
f"and be sure to make the order of the inputs same as QuestionAnswering forward() arguments like ({list(input_names_order)})"
|
859
|
+
)
|
661
860
|
|
662
861
|
input_info = [
|
663
862
|
(model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
|
664
|
-
for model_input_name in
|
863
|
+
for model_input_name in rbln_model_input_names
|
665
864
|
]
|
666
865
|
|
667
|
-
|
668
|
-
|
669
|
-
|
670
|
-
|
671
|
-
|
866
|
+
rbln_compile_config = RBLNCompileConfig(input_info=input_info)
|
867
|
+
rbln_config = RBLNConfig(
|
868
|
+
rbln_cls=cls.__name__,
|
869
|
+
compile_cfgs=[rbln_compile_config],
|
870
|
+
rbln_kwargs=rbln_kwargs,
|
871
|
+
)
|
872
|
+
rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
|
873
|
+
return rbln_config
|
672
874
|
|
673
875
|
|
674
876
|
class RBLNModelForImageClassification(RBLNModel):
|
@@ -684,9 +886,11 @@ class RBLNModelForImageClassification(RBLNModel):
|
|
684
886
|
cls,
|
685
887
|
preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
|
686
888
|
model_config: Optional["PretrainedConfig"] = None,
|
687
|
-
|
688
|
-
rbln_batch_size: Optional[int] = None,
|
889
|
+
rbln_kwargs: Dict[str, Any] = {},
|
689
890
|
) -> RBLNConfig:
|
891
|
+
rbln_image_size = rbln_kwargs.get("image_size", None)
|
892
|
+
rbln_batch_size = rbln_kwargs.get("batch_size", None)
|
893
|
+
|
690
894
|
if rbln_image_size is None:
|
691
895
|
for processor in preprocessors:
|
692
896
|
if hasattr(processor, "size"):
|
@@ -698,19 +902,19 @@ class RBLNModelForImageClassification(RBLNModel):
|
|
698
902
|
if rbln_batch_size is None:
|
699
903
|
rbln_batch_size = 1
|
700
904
|
|
905
|
+
if isinstance(rbln_image_size, int):
|
906
|
+
rbln_image_size = rbln_image_size, rbln_image_size
|
907
|
+
|
701
908
|
input_info = [
|
702
909
|
(
|
703
910
|
"pixel_values",
|
704
|
-
[rbln_batch_size, 3, rbln_image_size, rbln_image_size],
|
911
|
+
[rbln_batch_size, 3, rbln_image_size[0], rbln_image_size[1]],
|
705
912
|
"float32",
|
706
913
|
)
|
707
914
|
]
|
708
915
|
|
709
|
-
|
710
|
-
|
711
|
-
meta = {"rbln_image_size": rbln_image_size}
|
712
|
-
|
713
|
-
return RBLNConfig.from_rbln_runtime_configs([rbln_runtime_config], _rbln_meta=meta)
|
916
|
+
rbln_compile_config = RBLNCompileConfig(input_info=input_info)
|
917
|
+
return RBLNConfig(rbln_cls=cls.__name__, compile_cfgs=[rbln_compile_config], rbln_kwargs=rbln_kwargs)
|
714
918
|
|
715
919
|
|
716
920
|
class RBLNModelForAudioClassification(RBLNModel):
|
@@ -734,11 +938,11 @@ class RBLNModelForAudioClassification(RBLNModel):
|
|
734
938
|
cls,
|
735
939
|
preprocessors: "AutoFeatureExtractor",
|
736
940
|
model_config: "PretrainedConfig",
|
737
|
-
|
738
|
-
rbln_max_length: Optional[int] = None,
|
739
|
-
rbln_num_mel_bins: Optional[int] = None,
|
941
|
+
rbln_kwargs: Dict[str, Any] = {},
|
740
942
|
) -> RBLNConfig:
|
741
|
-
|
943
|
+
rbln_batch_size = rbln_kwargs.get("batch_size", None)
|
944
|
+
rbln_max_length = rbln_kwargs.get("max_length", None)
|
945
|
+
rbln_num_mel_bins = rbln_kwargs.get("num_mel_bins", None)
|
742
946
|
|
743
947
|
if rbln_batch_size is None:
|
744
948
|
rbln_batch_size = 1
|
@@ -764,11 +968,7 @@ class RBLNModelForAudioClassification(RBLNModel):
|
|
764
968
|
if rbln_max_length is None:
|
765
969
|
raise ValueError("`rbln_max_length` should be specified!")
|
766
970
|
|
767
|
-
|
768
|
-
meta["rbln_max_length"] = rbln_max_length
|
769
|
-
meta["rbln_num_mel_bins"] = rbln_num_mel_bins
|
770
|
-
|
771
|
-
model_input_info = [
|
971
|
+
input_info = [
|
772
972
|
(
|
773
973
|
"input_values",
|
774
974
|
[rbln_batch_size, rbln_max_length, rbln_num_mel_bins],
|
@@ -776,13 +976,19 @@ class RBLNModelForAudioClassification(RBLNModel):
|
|
776
976
|
),
|
777
977
|
]
|
778
978
|
|
779
|
-
|
780
|
-
|
781
|
-
|
782
|
-
[
|
783
|
-
|
979
|
+
rbln_compile_config = RBLNCompileConfig(input_info=input_info)
|
980
|
+
rbln_config = RBLNConfig(
|
981
|
+
rbln_cls=cls.__name__,
|
982
|
+
compile_cfgs=[rbln_compile_config],
|
983
|
+
rbln_kwargs=rbln_kwargs,
|
984
|
+
)
|
985
|
+
rbln_config.model_cfg.update(
|
986
|
+
{
|
987
|
+
"batch_size": rbln_batch_size,
|
988
|
+
"max_length": rbln_max_length,
|
989
|
+
"num_mel_bins": rbln_num_mel_bins,
|
990
|
+
}
|
784
991
|
)
|
785
|
-
|
786
992
|
return rbln_config
|
787
993
|
|
788
994
|
|
@@ -807,10 +1013,12 @@ class RBLNModelForSequenceClassification(RBLNModel):
|
|
807
1013
|
cls,
|
808
1014
|
preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
|
809
1015
|
model_config: Optional["PretrainedConfig"] = None,
|
810
|
-
|
811
|
-
rbln_model_input_names: Optional[List[str]] = None,
|
812
|
-
rbln_batch_size: Optional[int] = None,
|
1016
|
+
rbln_kwargs: Dict[str, Any] = {},
|
813
1017
|
) -> RBLNConfig:
|
1018
|
+
rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
|
1019
|
+
rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
|
1020
|
+
rbln_batch_size = rbln_kwargs.get("batch_size", None)
|
1021
|
+
|
814
1022
|
max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
|
815
1023
|
model_config, "max_position_embeddings", None
|
816
1024
|
)
|
@@ -829,21 +1037,36 @@ class RBLNModelForSequenceClassification(RBLNModel):
|
|
829
1037
|
raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
|
830
1038
|
|
831
1039
|
if rbln_model_input_names is None:
|
832
|
-
|
833
|
-
|
1040
|
+
for tokenizer in preprocessors:
|
1041
|
+
if hasattr(tokenizer, "model_input_names"):
|
1042
|
+
rbln_model_input_names = tokenizer.model_input_names
|
1043
|
+
break
|
1044
|
+
if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
|
1045
|
+
rbln_model_input_names = cls.rbln_model_input_names
|
1046
|
+
elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
|
1047
|
+
original_model_class = getattr(transformers, model_config.architectures[0])
|
1048
|
+
input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
|
1049
|
+
raise ValueError(
|
1050
|
+
"Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
|
1051
|
+
f"and be sure to make the order of the inputs same as SequenceClassification forward() arguments like ({list(input_names_order)})"
|
1052
|
+
)
|
834
1053
|
|
835
1054
|
if rbln_batch_size is None:
|
836
1055
|
rbln_batch_size = 1
|
1056
|
+
|
837
1057
|
input_info = [
|
838
1058
|
(model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
|
839
1059
|
for model_input_name in rbln_model_input_names
|
840
1060
|
]
|
841
1061
|
|
842
|
-
|
843
|
-
|
844
|
-
|
845
|
-
|
846
|
-
|
1062
|
+
rbln_compile_config = RBLNCompileConfig(input_info=input_info)
|
1063
|
+
rbln_config = RBLNConfig(
|
1064
|
+
rbln_cls=cls.__name__,
|
1065
|
+
compile_cfgs=[rbln_compile_config],
|
1066
|
+
rbln_kwargs=rbln_kwargs,
|
1067
|
+
)
|
1068
|
+
rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
|
1069
|
+
return rbln_config
|
847
1070
|
|
848
1071
|
|
849
1072
|
class RBLNModelForMaskedLM(RBLNModel):
|
@@ -855,10 +1078,12 @@ class RBLNModelForMaskedLM(RBLNModel):
|
|
855
1078
|
cls,
|
856
1079
|
preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
|
857
1080
|
model_config: Optional["PretrainedConfig"] = None,
|
858
|
-
|
859
|
-
rbln_model_input_names: Optional[List[str]] = None,
|
860
|
-
rbln_batch_size: Optional[int] = None,
|
1081
|
+
rbln_kwargs: Dict[str, Any] = {},
|
861
1082
|
) -> RBLNConfig:
|
1083
|
+
rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
|
1084
|
+
rbln_model_input_names = rbln_kwargs.get("model_input_names", None)
|
1085
|
+
rbln_batch_size = rbln_kwargs.get("batch_size", None)
|
1086
|
+
|
862
1087
|
max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
|
863
1088
|
model_config, "max_position_embeddings", None
|
864
1089
|
)
|
@@ -877,18 +1102,33 @@ class RBLNModelForMaskedLM(RBLNModel):
|
|
877
1102
|
raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
|
878
1103
|
|
879
1104
|
if rbln_model_input_names is None:
|
880
|
-
|
881
|
-
|
1105
|
+
for tokenizer in preprocessors:
|
1106
|
+
if hasattr(tokenizer, "model_input_names"):
|
1107
|
+
rbln_model_input_names = tokenizer.model_input_names
|
1108
|
+
break
|
1109
|
+
if rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names"):
|
1110
|
+
rbln_model_input_names = cls.rbln_model_input_names
|
1111
|
+
elif rbln_model_input_names is None and hasattr(cls, "rbln_model_input_names") is False:
|
1112
|
+
original_model_class = getattr(transformers, model_config.architectures[0])
|
1113
|
+
input_names_order = inspect.signature(original_model_class.forward).parameters.keys()
|
1114
|
+
raise ValueError(
|
1115
|
+
"Specify the model input names obtained by the tokenizer via `rbln_model_input_names`, "
|
1116
|
+
f"and be sure to make the order of the inputs same as MaskedLM forward() arguments like ({list(input_names_order)})"
|
1117
|
+
)
|
882
1118
|
|
883
1119
|
if rbln_batch_size is None:
|
884
1120
|
rbln_batch_size = 1
|
1121
|
+
|
885
1122
|
input_info = [
|
886
1123
|
(model_input_name, [rbln_batch_size, rbln_max_seq_len], "int64")
|
887
1124
|
for model_input_name in rbln_model_input_names
|
888
1125
|
]
|
889
1126
|
|
890
|
-
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
1127
|
+
rbln_compile_config = RBLNCompileConfig(input_info=input_info)
|
1128
|
+
rbln_config = RBLNConfig(
|
1129
|
+
rbln_cls=cls.__name__,
|
1130
|
+
compile_cfgs=[rbln_compile_config],
|
1131
|
+
rbln_kwargs=rbln_kwargs,
|
1132
|
+
)
|
1133
|
+
rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
|
1134
|
+
return rbln_config
|