optimum-rbln 0.1.4__py3-none-any.whl → 0.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +7 -0
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +16 -98
- optimum/rbln/diffusers/models/unet_2d_condition.py +1 -1
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +9 -11
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +8 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +8 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +9 -0
- optimum/rbln/modeling_base.py +172 -100
- optimum/rbln/modeling_seq2seq.py +58 -132
- optimum/rbln/transformers/__init__.py +2 -0
- optimum/rbln/transformers/models/__init__.py +1 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
- optimum/rbln/transformers/models/dpt/__init__.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +89 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +24 -33
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +52 -124
- optimum/rbln/transformers/models/llama/llama_architecture.py +13 -16
- optimum/rbln/transformers/models/llama/llama_architecture_cb.py +41 -36
- optimum/rbln/transformers/models/llama/modeling_llama.py +94 -120
- optimum/rbln/transformers/models/midm/modeling_midm.py +85 -121
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +53 -123
- optimum/rbln/utils/__init__.py +1 -1
- optimum/rbln/utils/import_utils.py +46 -0
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/METADATA +17 -51
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/RECORD +31 -29
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/WHEEL +1 -1
- {optimum_rbln-0.1.4.dist-info → optimum_rbln-0.1.7.dist-info}/licenses/LICENSE +0 -0
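
Most of the behavioral change in this release is concentrated in optimum/rbln/modeling_base.py: runtime creation moves out of RBLNBaseModel.__init__ into _from_pretrained, save_pretrained now copies the compiled artifacts from model_save_dir/subfolder, and the rbln_device, rbln_device_map, rbln_create_runtimes and rbln_optimize_host_memory kwargs are consumed during loading. A minimal usage sketch, assuming the task class is exported from optimum.rbln and that from_pretrained(..., export=True) routes through the new _export path shown in the diff below (the model id is a placeholder):

    from optimum.rbln import RBLNModelForQuestionAnswering

    # Compile a vanilla Transformers checkpoint for the RBLN NPU, but skip
    # runtime creation, e.g. when compiling on a host without an NPU attached.
    model = RBLNModelForQuestionAnswering.from_pretrained(
        "some-org/some-qa-model",    # placeholder model id
        export=True,
        rbln_create_runtimes=False,  # defaults to rebel.npu_is_available()
    )

    # save_pretrained copies <model_save_dir>/<subfolder> into the target directory.
    model.save_pretrained("qa_model_rbln")

    # Reload the already-compiled model later; rbln_device selects the NPU index.
    model = RBLNModelForQuestionAnswering.from_pretrained(
        "qa_model_rbln", export=False, rbln_device=0
    )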
optimum/rbln/modeling_base.py
CHANGED
@@ -22,10 +22,12 @@
 # from Rebellions Inc.
 
 import logging
+import os
+import shutil
 from abc import ABC, abstractmethod
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
 
 import rebel
 import torch
@@ -50,16 +52,7 @@ from .utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
 logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer,
-
-
-def listify(var: Any):
-    if isinstance(var, list):
-        return var
-    elif var is not None:
-        return [var]
-    else:
-        return None
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
 
 
 class RBLNBaseModel(OptimizedModel, ABC):
@@ -103,23 +96,22 @@ class RBLNBaseModel(OptimizedModel, ABC):
 
     def __init__(
         self,
-        models: List[rebel.
+        models: List[rebel.Runtime],
         config: "PretrainedConfig",
+        rbln_config: RBLNConfig,
         preprocessors: Optional[List],
-        rbln_config: Optional[RBLNConfig],
-        rbln_device: Optional[List[int]] = None,
-        rbln_device_map: Optional[Dict[str, int]] = None,
-        rbln_create_runtimes: Optional[bool] = None,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        subfolder: str = "",
+        rbln_compiled_models: Optional[rebel.RBLNCompiledModel] = None,
         **kwargs,
     ):
         super().__init__(models, config)
         if not isinstance(self.config, PretrainedConfig):  # if diffusers config
             self.config = PretrainedConfig(**self.config)
 
-        self.
-
+        self.rbln_config = rbln_config
         self.preprocessors = [] if preprocessors is None else preprocessors
+        self.compiled_models = rbln_compiled_models
 
         # Registers the RBLNBaseModelForXXX classes into the transformers AutoModel classes to avoid warnings when creating
         # a pipeline https://github.com/huggingface/transformers/blob/3d3204c025b6b5de013e07dd364208e28b4d9589/src/transformers/pipelines/base.py#L940
@@ -127,18 +119,6 @@ class RBLNBaseModel(OptimizedModel, ABC):
         if hasattr(self.auto_model_class, "register"):
             self.auto_model_class.register(AutoConfig, self.__class__)
 
-        self.rbln_config = rbln_config
-        self.compiled_models: List[rebel.RBLNCompiledModel] = models
-
-        if rbln_device_map is None:
-            self.rbln_device_map = {}
-            device_val = 0 if rbln_device is None else rbln_device
-            for key in self.rbln_config:
-                self.rbln_device_map[key] = device_val
-
-        else:
-            self.rbln_device_map = rbln_device_map
-
         # copied from tranformers PreTrainedModel __init__
         self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
         if self.generation_config is not None:
@@ -146,15 +126,9 @@
 
         self.device = torch.device("cpu")
 
-        if rbln_create_runtimes is None:
-            rbln_create_runtimes = rebel.npu_is_available()
-
-        # create runtimes only if `rbln_create_runtimes` is enabled
-        self.runtimes = self._create_runtimes(self.rbln_device_map) if rbln_create_runtimes else UnavailableRuntime()
-
         # FIXME :: model_save_dir is not used after initialized. (This can be used when save/load)
         # This attribute is needed to keep one reference on the temporary directory, since garbage collecting it
-        # would end-up removing the directory containing the underlying
+        # would end-up removing the directory containing the underlying RBLN model.
         self._model_save_dir_tempdirectory_instance = None
         if isinstance(model_save_dir, TemporaryDirectory):
             self._model_save_dir_tempdirectory_instance = model_save_dir
@@ -163,6 +137,7 @@
             self.model_save_dir = Path(model_save_dir)
         else:
             self.model_save_dir = model_save_dir
+        self.subfolder = subfolder
 
         self.__post_init__(**kwargs)
 
@@ -178,11 +153,14 @@
             save_directory (`Union[str, Path]`):
                 Directory where to save the model file.
         """
-
-
-
-
-
+        real_save_dir = self.model_save_dir / self.subfolder
+        if os.path.exists(real_save_dir) and os.path.isdir(real_save_dir):
+            shutil.copytree(real_save_dir, save_directory, dirs_exist_ok=True)
+            self.config.save_pretrained(save_directory)
+            if self.generation_config is not None:
+                self.generation_config.save_pretrained(save_directory)
+        else:
+            raise FileNotFoundError(f"Saving compiled model failed.({real_save_dir}).")
 
     @classmethod
     def _from_pretrained(
@@ -196,6 +174,14 @@
         subfolder: str = "",
         local_files_only: bool = False,
         model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        # Runtime - related kwargs
+        rbln_device: Optional[List[int]] = None,
+        rbln_device_map: Optional[Dict[str, int]] = None,
+        rbln_create_runtimes: Optional[bool] = None,
+        # passed from compile function
+        rbln_config: Optional[RBLNConfig] = None,
+        rbln_compiled_models: Optional[List[rebel.RBLNCompiledModel]] = None,
+        rbln_optimize_host_memory: Optional[bool] = None,
         **kwargs,
     ) -> "RBLNBaseModel":
         model_path = Path(model_id)
@@ -228,12 +214,15 @@
         )
 
         if model_path.is_dir():
-
-
-
-
-
-
+            if rbln_compiled_models is None:
+                rbln_config = RBLNConfig.load(str(model_path))
+                rbln_compiled_models = [
+                    rebel.RBLNCompiledModel(model_path / f"{compiled_model_name}.rbln")
+                    for compiled_model_name in rbln_config
+                ]
+                new_model_save_dir = model_path
+            else:
+                pass
 
         else:
             rbln_config_filename = rbln_config_filenames[0]
@@ -248,7 +237,7 @@
                 local_files_only=local_files_only,
             )
             rbln_config = RBLNConfig.load(Path(rbln_config_cache_path).parent)
-
+            rbln_compiled_models = []
             for compiled_model_name in rbln_config:
                 model_cache_path = hf_hub_download(
                     repo_id=model_id,
@@ -260,7 +249,7 @@
                     force_download=force_download,
                     local_files_only=local_files_only,
                 )
-
+                rbln_compiled_models.append(rebel.RBLNCompiledModel(model_cache_path))
             new_model_save_dir = Path(rbln_config_cache_path).parent
 
         preprocessors = maybe_load_preprocessors(model_id, subfolder=subfolder)
@@ -268,17 +257,40 @@
         if model_save_dir is None:
             model_save_dir = new_model_save_dir
 
+        # Create runtimes
+        if rbln_create_runtimes is None:
+            rbln_create_runtimes = rebel.npu_is_available()
+        if rbln_device_map is None:
+            rbln_device_map = {}
+            device_val = 0 if rbln_device is None else rbln_device
+            for key in rbln_config:
+                rbln_device_map[key] = device_val
+        else:
+            rbln_device_map = rbln_device_map
+
+        # create runtimes only if `rbln_create_runtimes` is enabled
+        models = (
+            cls._create_runtimes(rbln_compiled_models, rbln_device_map)
+            if rbln_create_runtimes
+            else UnavailableRuntime()
+        )
+
+        if rbln_optimize_host_memory is None:
+            rbln_optimize_host_memory = True
+
         return cls(
             models,
             config,
+            rbln_config,
             preprocessors,
-            rbln_config=rbln_config,
             model_save_dir=model_save_dir,
+            subfolder=subfolder,
+            rbln_compiled_models=None if rbln_optimize_host_memory else rbln_compiled_models,
             **kwargs,
         )
 
     def __repr__(self):
-        return repr(self.
+        return repr(self.model)
 
     @classmethod
     def compile(cls, model, rbln_runtime_config: Optional[RBLNRuntimeConfig] = None):
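
The loading path above now owns device placement: when no explicit rbln_device_map is given, every compiled model named in rbln_config is assigned the single rbln_device index (0 when unset), and runtimes are only created when rbln_create_runtimes is true, defaulting to rebel.npu_is_available(). In isolation, the device-map fallback behaves like this small sketch (function and variable names are illustrative, not part of the package):

    from typing import Dict, List, Optional

    def default_device_map(
        compiled_model_names: List[str], rbln_device: Optional[int] = None
    ) -> Dict[str, int]:
        # Mirror of the fallback above: one shared device index for every compiled model.
        device_val = 0 if rbln_device is None else rbln_device
        return {name: device_val for name in compiled_model_names}

    assert default_device_map(["compiled_model"]) == {"compiled_model": 0}
    assert default_device_map(["encoder", "decoder"], rbln_device=1) == {"encoder": 1, "decoder": 1}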
@@ -338,7 +350,15 @@
     def pop_rbln_kwargs_from_kwargs(kwargs: dict):
         keys = list(kwargs.keys())
         rbln_constructor_kwargs = {
-            key: kwargs.pop(key)
+            key: kwargs.pop(key)
+            for key in keys
+            if key
+            in [
+                "rbln_device",
+                "rbln_device_map",
+                "rbln_create_runtimes",
+                "rbln_optimize_host_memory",
+            ]
         }
 
         keys = list(kwargs.keys())
@@ -375,9 +395,12 @@
     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
         pass
 
+    @classmethod
     @abstractmethod
-    def _create_runtimes(
-
+    def _create_runtimes(
+        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+    ) -> List[rebel.Runtime]:
+        # compiled_models -> runtimes
         pass
 
     @classmethod
@@ -417,14 +440,26 @@ class RBLNModel(RBLNBaseModel):
         ```
         """
 
-
-
+    @classmethod
+    def update_kwargs(cls, kwargs):
+        """
+        Update user-given kwargs to get proper pytorch model.
+
+        For example, `torchscript`=True should be set because torch.jit
+        does not support `transformers` output instances as module output;
+        """
+        kwargs.update(
+            {
+                "torchscript": True,
+                "return_dict": False,
+            }
+        )
+        return kwargs
 
     @classmethod
-    def 
+    def get_pytorch_model(
         cls,
-        model_id:
-        config: "PretrainedConfig",
+        model_id: str,
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
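
update_kwargs above forces torchscript=True and return_dict=False before the PyTorch model is loaded, because torch.jit tracing does not support transformers output instances as module outputs. Since the remaining kwargs are forwarded to the model loader (see the continuation of get_pytorch_model in the next hunk), a subclass could extend this hook to inject extra loading options; a hypothetical sketch, where the class name and the extra kwarg are illustrative only:

    from optimum.rbln.modeling_base import RBLNModel

    class RBLNLowMemoryModel(RBLNModel):  # hypothetical subclass, for illustration only
        @classmethod
        def update_kwargs(cls, kwargs):
            kwargs = super().update_kwargs(kwargs)      # keeps torchscript=True, return_dict=False
            kwargs.update({"low_cpu_mem_usage": True})  # example of an additional loading kwarg
            return kwargs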
@@ -432,16 +467,62 @@ class RBLNModel(RBLNBaseModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
-
+        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
+        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
-    ) -> "
-        """
-        Exports a vanilla Transformers model into a rbln-compiled Module.
-        """
+    ) -> "PreTrainedModel":
         task = kwargs.pop("task", None)
         if task is None:
             task = TasksManager.infer_task_from_model(cls.auto_model_class)
 
+        kwargs = cls.update_kwargs(kwargs)
+
+        model = TasksManager.get_model_from_task(
+            task=task,
+            model_name_or_path=model_id,
+            subfolder=subfolder,
+            revision=revision,
+            framework="pt",
+            cache_dir=cache_dir,
+            use_auth_token=use_auth_token,
+            local_files_only=local_files_only,
+            force_download=force_download,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )
+
+        return model
+
+    @classmethod
+    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
+        model = cls.wrap_model_if_needed(model)
+        rbln_runtime_configs = list(rbln_config.values())
+        if len(rbln_runtime_configs) != 1:
+            raise ValueError
+        rbln_runtime_config = rbln_runtime_configs[0]
+        if len(rbln_runtime_config) != 1:
+            raise ValueError
+        rbln_runtime_config = rbln_runtime_config[0]
+
+        compiled_model = cls.compile(model, rbln_runtime_config=rbln_runtime_config)
+        return compiled_model
+
+    @classmethod
+    @torch.no_grad()
+    def _export(
+        cls,
+        model_id: str,
+        config: "PretrainedConfig",
+        use_auth_token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+        trust_remote_code: bool = False,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        **kwargs,
+    ) -> "RBLNModel":
         if model_save_dir is None:
             save_dir = TemporaryDirectory()
             save_dir_path = Path(save_dir.name)
@@ -453,35 +534,24 @@ class RBLNModel(RBLNBaseModel):
             save_dir_path = Path(model_save_dir)
             save_dir_path.mkdir(exist_ok=True)
 
-        kwargs.update(
-            {
-                "torchscript": True,
-                "return_dict": False,
-            }
-        )
-
         rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
 
-        model =
-
-            model_name_or_path=model_id,
+        model: "PreTrainedModel" = cls.get_pytorch_model(
+            model_id=model_id,
             subfolder=subfolder,
             revision=revision,
-            framework="pt",
             cache_dir=cache_dir,
             use_auth_token=use_auth_token,
             local_files_only=local_files_only,
             force_download=force_download,
             trust_remote_code=trust_remote_code,
+            rbln_config_kwargs=rbln_config_kwargs,
+            rbln_constructor_kwargs=rbln_constructor_kwargs,
             **kwargs,
         )
 
-        #
-
-        model.eval()
-
-        if config is None:
-            config = model.config
+        # FIXME :: optimum passes AutoConfig.
+        config = model.config
 
         if not isinstance(config, PretrainedConfig):  # diffusers config
             config = PretrainedConfig(**config)
@@ -492,20 +562,22 @@ class RBLNModel(RBLNBaseModel):
         # Get compilation arguments
         if (rbln_config := rbln_config_kwargs.pop("rbln_config", None)) is None:
             rbln_config = cls.get_rbln_config(preprocessors=preprocessors, model_config=config, **rbln_config_kwargs)
+        compiled_model = cls.get_compiled_model(model, rbln_config=rbln_config)
 
-
-
-
-
-
-
-
+        # Save compiled models
+        (save_dir_path / subfolder).mkdir(exist_ok=True)
+        if isinstance(compiled_model, Iterable):
+            # compiled_model is an Iterable instance
+            for single_compiled_model, compiled_model_name in zip(compiled_model, rbln_config):
+                single_compiled_model.save(save_dir_path / subfolder / f"{compiled_model_name}.rbln")
+            compiled_models = compiled_model
 
-
-
-
+        else:
+            compiled_model.save(save_dir_path / subfolder / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
+            compiled_models = [compiled_model]
         rbln_config.save(save_dir_path / subfolder)
 
+        # Instantiate
         return cls._from_pretrained(
             model_id=save_dir_path,
             config=config,
@@ -516,23 +588,23 @@ class RBLNModel(RBLNBaseModel):
             cache_dir=cache_dir,
             subfolder=subfolder,
             local_files_only=local_files_only,
+            rbln_config=rbln_config,
+            rbln_compiled_models=compiled_models,
             **rbln_constructor_kwargs,
             **kwargs,
         )
 
-
+    @classmethod
+    def _create_runtimes(
+        cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
+    ) -> List[rebel.Runtime]:
         device = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
-        return [
-            compiled_model.create_runtime(tensor_type="pt", device=device) for compiled_model in self.compiled_models
-        ]
+        return [compiled_model.create_runtime(tensor_type="pt", device=device) for compiled_model in compiled_models]
 
     def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
-        output = self.
+        output = self.model[0](*args, **kwargs)
         return output
 
-    def __repr__(self):
-        return repr(self.runtimes[0])
-
 
 class RBLNModelForQuestionAnswering(RBLNModel):
     model_type = "rbln_model"
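
RBLNModel._create_runtimes above assumes a single compiled graph keyed by DEFAULT_COMPILED_MODEL_NAME. Model classes that compile several graphs (the seq2seq, Whisper and Midm models in the file list) override the same classmethod; a hypothetical multi-graph override might look like the sketch below, which is illustrative only. The real implementations live in the respective modeling_* files.

    from typing import Dict, List

    import rebel

    from optimum.rbln.modeling_base import RBLNModel


    class RBLNMultiGraphModel(RBLNModel):  # hypothetical subclass, for illustration only
        @classmethod
        def _create_runtimes(
            cls, compiled_models: List[rebel.RBLNCompiledModel], rbln_device_map: Dict[str, int]
        ) -> List[rebel.Runtime]:
            # rbln_device_map is keyed by the same compiled-model names as rbln_config,
            # so each graph can be pinned to its own NPU device.
            return [
                compiled_model.create_runtime(tensor_type="pt", device=device)
                for compiled_model, device in zip(compiled_models, rbln_device_map.values())
            ]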