optimum-rbln 0.1.0 (optimum_rbln-0.1.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +115 -0
- optimum/rbln/__version__.py +1 -0
- optimum/rbln/diffusers/__init__.py +64 -0
- optimum/rbln/diffusers/models/__init__.py +26 -0
- optimum/rbln/diffusers/models/autoencoder_kl.py +313 -0
- optimum/rbln/diffusers/models/controlnet.py +180 -0
- optimum/rbln/diffusers/models/unet_2d_condition.py +352 -0
- optimum/rbln/diffusers/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/pipelines/controlnet/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +266 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_controlnet_img2img.py +731 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +106 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +116 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +109 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +111 -0
- optimum/rbln/modeling.py +0 -0
- optimum/rbln/modeling_alias.py +49 -0
- optimum/rbln/modeling_base.py +645 -0
- optimum/rbln/modeling_config.py +169 -0
- optimum/rbln/modeling_seq2seq.py +469 -0
- optimum/rbln/transformers/__init__.py +59 -0
- optimum/rbln/transformers/generation/__init__.py +24 -0
- optimum/rbln/transformers/generation/streamers.py +122 -0
- optimum/rbln/transformers/models/__init__.py +28 -0
- optimum/rbln/transformers/models/bart/__init__.py +24 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +377 -0
- optimum/rbln/transformers/models/clip/__init__.py +24 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +116 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +24 -0
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +253 -0
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +700 -0
- optimum/rbln/transformers/models/llama/__init__.py +24 -0
- optimum/rbln/transformers/models/llama/llama_architecture.py +607 -0
- optimum/rbln/transformers/models/llama/modeling_llama.py +409 -0
- optimum/rbln/transformers/models/t5/__init__.py +24 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +439 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +24 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +121 -0
- optimum/rbln/transformers/models/whisper/__init__.py +24 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +374 -0
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +406 -0
- optimum/rbln/utils/__init__.py +25 -0
- optimum/rbln/utils/import_utils.py +28 -0
- optimum/rbln/utils/runtime_utils.py +71 -0
- optimum/rbln/utils/save_utils.py +92 -0
- optimum_rbln-0.1.0.dist-info/METADATA +144 -0
- optimum_rbln-0.1.0.dist-info/RECORD +51 -0
- optimum_rbln-0.1.0.dist-info/WHEEL +4 -0
- optimum_rbln-0.1.0.dist-info/licenses/LICENSE +201 -0
optimum/rbln/modeling_config.py
@@ -0,0 +1,169 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import copy
+import json
+from collections import UserDict
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+
+DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
+DEFAULT_MOD_NAME = "default"
+
+
+@dataclass
+class RBLNRuntimeConfig:
+    compiled_model_name: str = DEFAULT_COMPILED_MODEL_NAME
+    rbln_mod_name: str = DEFAULT_MOD_NAME
+    input_info: List[Tuple[str, Tuple[int], Optional[str]]] = None
+    batch_size: Optional[int] = None
+    fusion: Optional[bool] = None
+    npu: Optional[str] = None
+    tensor_parallel_size: Optional[int] = None
+
+    @staticmethod
+    def normalize_dtype(dtype):
+        """
+        framework's dtype to string.
+        i.e. torch.float32 -> "float32"
+        """
+        if isinstance(dtype, str):
+            return dtype
+        else:
+            dtype: str = repr(dtype).split(".")[-1]
+            if dtype.endswith("'>"):  # numpy
+                dtype = dtype[:-2]
+            return dtype
+
+    def __post_init__(self):
+        self.input_info = [(i[0], i[1], RBLNRuntimeConfig.normalize_dtype(i[2]) or "float32") for i in self.input_info]
+
+    def update(self, **kwargs):
+        self.compiled_model_name = kwargs.get("compiled_model_name", self.compiled_model_name)
+        self.rbln_mod_name = kwargs.get("rbln_mod_name", self.rbln_mod_name)
+        self.input_info = kwargs.get("input_info", self.input_info)
+        self.batch_size = kwargs.get("batch_size", self.batch_size)
+        self.fusion = kwargs.get("fusion", self.fusion)
+        self.npu = kwargs.get("npu", self.npu)
+        self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
+        return self
+
+    def get_dummy_inputs(self, fill=0):
+        dummy = []
+        for name, shape, dtype in self.input_info:
+            dummy.append(
+                torch.fill(torch.zeros(*shape, dtype=getattr(torch, dtype)), fill)
+                if len(shape) > 0
+                else torch.tensor(fill, dtype=getattr(torch, dtype))
+            )
+        return tuple(dummy)
+
+    def asdict(self):
+        return asdict(self)
+
+
+class RBLNConfig(UserDict):
+    def __init__(self, runtime_cfgs: Dict[str, List[RBLNRuntimeConfig]], _rbln_meta: Dict[str, Any] = None):
+        """Configurations for RBLN model compilation and inference.
+
+        Args:
+            _rbln_meta (Dict[str, Any], optional):
+                Any rbln-specific configurations.
+                (i.e. max_seq_len for language models, image_size for image models).
+                Defaults to None.
+        """
+        super().__init__(runtime_cfgs)
+        if _rbln_meta:
+            self.meta = _rbln_meta
+        else:
+            self.meta: Dict[str, Any] = {}
+
+    @staticmethod
+    def from_rbln_configs(rbln_configs: List["RBLNConfig"], names: Optional[List[str]] = None) -> "RBLNConfig":
+        # assume each rbln_config has exact one rbln_runtime_config
+        names = [None] * len(rbln_configs) if names is None else names
+        runtime_cfgs = []
+        for name, cfg in zip(names, rbln_configs):
+            if len(cfg) > 1:
+                msg = (
+                    "`from_rbln_configs` requires exact one `RBLNRuntimeConfig` for each `RBLNConfig`."
+                    f"But got {len(cfg)} `RBLNRuntimeConfig`."
+                )
+                raise RuntimeError(msg)
+
+            runtime_cfg = cfg[list(cfg.keys())[0]][0]
+            runtime_cfg = copy.deepcopy(runtime_cfg)
+            if name is not None:
+                runtime_cfg.compiled_model_name = name
+            runtime_cfgs.append(runtime_cfg)
+
+        metas = [cfg.meta for cfg in rbln_configs]
+        merged_meta = {k: v for meta in metas for k, v in meta.items()}
+
+        return RBLNConfig.from_rbln_runtime_configs(runtime_cfgs, _rbln_meta=merged_meta)
+
+    @staticmethod
+    def from_rbln_runtime_configs(
+        rbln_runtime_configs: List[RBLNRuntimeConfig],
+        _rbln_meta: Dict[str, Any] = None,
+    ) -> "RBLNConfig":
+        cfgs: Dict[str, List[RBLNRuntimeConfig]] = {}
+        for rbln_runtime_config in rbln_runtime_configs:
+            if rbln_runtime_config.compiled_model_name in cfgs:
+                cfgs[rbln_runtime_config.compiled_model_name].append(rbln_runtime_config)
+            else:
+                cfgs[rbln_runtime_config.compiled_model_name] = [rbln_runtime_config]
+        return RBLNConfig(cfgs, _rbln_meta=_rbln_meta)
+
+    def save(self, dir_path: str):
+        dir_path = Path(dir_path)
+        data = self.asdict()
+        data.update({"rbln_config_meta": self.meta})
+        with open(dir_path / "rbln_config.json", "w") as jsonf:
+            json.dump(data, jsonf, indent=2)
+
+    @staticmethod
+    def load(dir_path: str) -> "RBLNConfig":
+        dir_path = Path(dir_path)
+        with open(dir_path / "rbln_config.json", "r") as jsonf:
+            config_file = json.load(jsonf)
+        return RBLNConfig.fromdict(config_file)
+
+    def asdict(self):
+        dic = {k: [asdict(cfg) for cfg in cfgs] for k, cfgs in self.data.items()}
+        return dic
+
+    @staticmethod
+    def fromdict(dic: dict):
+        runtime_cfgs = {
+            k: [RBLNRuntimeConfig(**cfg) for cfg in cfgs] for k, cfgs in dic.items() if k != "rbln_config_meta"
+        }
+        if "rbln_config_meta" in dic:
+            meta = dic["rbln_config_meta"]
+        else:
+            meta = None
+        return RBLNConfig(runtime_cfgs, _rbln_meta=meta)
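For orientation, a minimal sketch of how these two classes compose, using only what the code above defines; the module path follows the file list, while the input shapes, names, and meta key are illustrative:

import tempfile

import torch

from optimum.rbln.modeling_config import RBLNConfig, RBLNRuntimeConfig

# One entry per model input: (name, shape, dtype-string).
runtime_cfg = RBLNRuntimeConfig(
    rbln_mod_name="encoder",
    input_info=[("input_ids", [1, 512], "int64")],
)

# get_dummy_inputs() materializes placeholder tensors from input_info.
dummy_inputs = runtime_cfg.get_dummy_inputs(fill=0)
assert dummy_inputs[0].shape == (1, 512) and dummy_inputs[0].dtype == torch.int64

# Runtime configs are grouped by compiled_model_name; meta carries
# free-form model-level settings (e.g. a max sequence length).
rbln_config = RBLNConfig.from_rbln_runtime_configs(
    [runtime_cfg], _rbln_meta={"rbln_max_seq_len": 512}
)

# save()/load() round-trip through <dir>/rbln_config.json.
with tempfile.TemporaryDirectory() as tmp_dir:
    rbln_config.save(tmp_dir)
    reloaded = RBLNConfig.load(tmp_dir)
    assert reloaded.meta["rbln_max_seq_len"] == 512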
optimum/rbln/modeling_seq2seq.py
@@ -0,0 +1,469 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+import inspect
+import logging
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+
+import rebel
+import torch
+from optimum.exporters import TasksManager
+from transformers import (
+    AutoModelForSeq2SeqLM,
+    BartConfig,
+    BartForConditionalGeneration,
+    PretrainedConfig,
+    T5ForConditionalGeneration,
+)
+from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+
+from .modeling_base import RBLNBaseModel
+from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNConfig, RBLNRuntimeConfig
+from .transformers.models.bart import BartDecoderWrapper, BartEncoderWrapper
+from .transformers.models.t5 import T5DecoderWrapper, T5EncoderWrapper
+from .utils.runtime_utils import RBLNPytorchRuntime
+from .utils.save_utils import maybe_save_preprocessors
+
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from transformers import (
+        AutoFeatureExtractor,
+        AutoProcessor,
+        AutoTokenizer,
+        PretrainedConfig,
+    )
+
+
+class RBLNRuntimeEncoder(RBLNPytorchRuntime):
+    mandatory_members = ["main_input_name"]
+
+    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+        _ = super().forward(*args, **kwargs)
+        # Just indicates that it is not None
+        return BaseModelOutput(last_hidden_state=torch.tensor([1.0]))
+
+
+class RBLNRuntimeDecoder(RBLNPytorchRuntime):
+    mandatory_members = ["main_input_name"]
+
+    def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
+        outputs = super().forward(*args, **kwargs)
+        return Seq2SeqLMOutput(logits=outputs)
+
+
+class RBLNModelForSeq2SeqLM(RBLNBaseModel):
+    """
+    This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
+    This model inherits from [`RBLNBaseModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformers based Seq2SeqLM models on RBLN devices.
+    It implements the methods to convert a pre-trained transformers Seq2SeqLM model into a RBLN transformer model by:
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN compiler.
+
+    Currently, this model class only supports the 'bart' and 't5' models from the transformers library. Future updates may include support for additional model types.
+    """
+
+    model_type = "rbln_model"
+    auto_model_class = AutoModelForSeq2SeqLM
+
+    def __post_init__(self, **kwargs):
+        self.model_dim = self.config.d_model
+        self.batch_size = self.rbln_config[DEFAULT_COMPILED_MODEL_NAME][0].batch_size
+        self.enc_max_seq_len = self.rbln_config.meta["rbln_enc_max_seq_len"]
+        self.dec_max_seq_len = self.rbln_config.meta["rbln_dec_max_seq_len"]
+        self.pad_token_id = self.rbln_config.meta["rbln_pad_token_id"]
+        self.encoder = RBLNRuntimeEncoder(runtime=self.runtimes[0], main_input_name="input_ids")
+        self.decoder = RBLNRuntimeDecoder(runtime=self.runtimes[1], main_input_name="input_ids")
+
+    def can_generate(self):
+        return True
+
+    def get_encoder(self):
+        return self.encoder
+
+    def get_decoder(self):
+        return self.decoder
+
+    def __getattr__(self, __name: str) -> Any:
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        if "T5ForConditionalGeneration" == self.config.architectures:
+            val = getattr(T5ForConditionalGeneration, __name)
+        else:
+            val = getattr(BartForConditionalGeneration, __name)
+
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+        return val
+
+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        decoder_attention_mask=None,
+        **kwargs,
+    ):
+        max_seq_len = self.dec_max_seq_len
+        cur_seq_len = input_ids.shape[-1]
+        decoder_batch_size = input_ids.shape[0]
+        input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
+
+        # In greedy decoding
+        decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.int64)
+        decoder_attention_mask[:, :cur_seq_len] = 1
+        cache_position = torch.tensor(cur_seq_len - 1, dtype=torch.int32)
+
+        return {
+            "decoder_input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "attention_mask": attention_mask,
+            "decoder_attention_mask": decoder_attention_mask,
+            "cache_position": cache_position,
+        }
+
+    @classmethod
+    def _export(
+        cls,
+        model_id: str,
+        config: "PretrainedConfig",
+        use_auth_token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+        trust_remote_code: bool = False,
+        **kwargs,
+    ) -> "AutoModelForSeq2SeqLM":
+        """
+        Exports a vanilla Transformers model into a rbln-compiled Module.
+        """
+        task = kwargs.pop("task", None)
+        if task is None:
+            task = TasksManager.infer_task_from_model(cls.auto_model_class)
+
+        save_dir = TemporaryDirectory()
+        save_dir_path = Path(save_dir.name)
+
+        kwargs.update(
+            {
+                "torchscript": True,
+                "return_dict": False,
+                "use_cache": False,
+            }
+        )
+
+        rbln_config_kwargs, rbln_constructor_kwargs = cls.pop_rbln_kwargs_from_kwargs(kwargs)
+
+        model: AutoModelForSeq2SeqLM = TasksManager.get_model_from_task(
+            task=task,
+            model_name_or_path=model_id,
+            subfolder=subfolder,
+            revision=revision,
+            framework="pt",
+            cache_dir=cache_dir,
+            use_auth_token=use_auth_token,
+            local_files_only=local_files_only,
+            force_download=force_download,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )
+
+        if config is None:
+            config = model.config
+
+        config.save_pretrained(save_dir_path)
+        preprocessors = maybe_save_preprocessors(model_id, save_dir_path, src_subfolder=subfolder)
+
+        # Get compilation arguments
+        if rbln_config_kwargs.get("rbln_config", None) is None:
+            rbln_config = cls.get_rbln_config(
+                preprocessors=preprocessors, model_config=model.config, **rbln_config_kwargs
+            )
+
+        def optimized_models(model):
+            if isinstance(model, T5ForConditionalGeneration):
+                encoder_model = T5EncoderWrapper(model).eval()
+                decoder_model = T5DecoderWrapper(model).eval()
+            elif isinstance(model, BartForConditionalGeneration):
+                encoder_model = BartEncoderWrapper(model).eval()
+                decoder_model = BartDecoderWrapper(model).eval()
+            else:
+                raise ValueError(f"{model.__class__.__name__} is not supported yet.")
+
+            return encoder_model, decoder_model
+
+        def compile():
+            wrapped_encoder, wrapped_decoder = optimized_models(model)
+
+            wrapped_encoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
+            wrapped_encoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
+            wrapped_encoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
+
+            wrapped_decoder.encoder_max_length = rbln_config.meta["rbln_enc_max_seq_len"]
+            wrapped_decoder.decoder_max_length = rbln_config.meta["rbln_dec_max_seq_len"]
+            wrapped_decoder.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
+
+            enc_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][0]
+            dec_rbln_runtime_config = rbln_config[DEFAULT_COMPILED_MODEL_NAME][1]
+
+            if isinstance(model, T5ForConditionalGeneration):
+                enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=1)
+                dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=1)
+            else:
+                enc_example_inputs = enc_rbln_runtime_config.get_dummy_inputs(fill=0)
+                dec_example_inputs = dec_rbln_runtime_config.get_dummy_inputs(fill=0)
+
+            enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs)
+            dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs)
+
+            enc_ir = rebel.torchscript_to_ir(
+                enc_scripted_model,
+                input_names=[v[0] for v in enc_rbln_runtime_config.input_info],
+                name=enc_rbln_runtime_config.rbln_mod_name,
+            )
+            dec_ir = rebel.torchscript_to_ir(
+                dec_scripted_model,
+                input_names=[v[0] for v in dec_rbln_runtime_config.input_info],
+                name=dec_rbln_runtime_config.rbln_mod_name,
+            )
+            dec_ir.decoder_batch_size = rbln_config.meta["rbln_batch_size"]
+
+            connections = [
+                (enc_ir.outputs[0], dec_ir.inputs[5]),
+                (dec_ir.outputs[1], dec_ir.inputs[4]),
+            ]
+            compiled_model = rebel.compile(
+                enc_ir,
+                dec_ir,
+                connections=connections,
+                fusion=enc_rbln_runtime_config.fusion,
+                npu=enc_rbln_runtime_config.npu,
+                tensor_parallel_size=enc_rbln_runtime_config.tensor_parallel_size,
+            )
+            compiled_model.save(save_dir_path / f"{DEFAULT_COMPILED_MODEL_NAME}.rbln")
+
+        compile()
+
+        rbln_config.save(save_dir_path)
+
+        return cls._from_pretrained(
+            model_id=save_dir_path,
+            config=config,
+            model_save_dir=save_dir,
+            **rbln_constructor_kwargs,
+            **kwargs,
+        )
+
+    @classmethod
+    def _get_rbln_config(
+        cls,
+        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+        model_config: "PretrainedConfig",
+        rbln_enc_max_seq_len: Optional[int] = None,
+        rbln_dec_max_seq_len: Optional[int] = None,
+        rbln_batch_size: Optional[int] = 1,
+    ) -> RBLNConfig:
+        meta = {}
+
+        if isinstance(model_config, BartConfig):
+            n_layer = model_config.decoder_layers
+            n_head = model_config.decoder_attention_heads
+            d_kv = model_config.d_model // model_config.encoder_attention_heads
+        else:
+            n_layer = model_config.num_layers
+            n_head = model_config.num_heads
+            d_kv = model_config.d_kv
+
+        max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+            model_config, "max_position_embeddings", None
+        )
+
+        rbln_pad_token_id = getattr(model_config, "pad_token_id", None)
+        if rbln_pad_token_id is None:
+            rbln_pad_token_id = getattr(model_config, "bos_token_id", None)
+            if rbln_pad_token_id is None:
+                rbln_pad_token_id = getattr(model_config, "eos_token_id", None)
+                if rbln_pad_token_id is None:
+                    rbln_pad_token_id = -1
+
+        if rbln_enc_max_seq_len is None:
+            rbln_enc_max_seq_len = max_position_embeddings
+            if rbln_enc_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_enc_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_enc_max_seq_len is None:
+                    raise ValueError("`rbln_enc_max_seq_len` should be specified!")
+        if max_position_embeddings is not None and rbln_enc_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_enc_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        if rbln_dec_max_seq_len is None:
+            rbln_dec_max_seq_len = max_position_embeddings
+            if rbln_dec_max_seq_len is None:
+                for tokenizer in preprocessors:
+                    if hasattr(tokenizer, "model_max_length"):
+                        rbln_dec_max_seq_len = tokenizer.model_max_length
+                        break
+                if rbln_dec_max_seq_len is None:
+                    raise ValueError("`rbln_dec_max_seq_len` should be specified!")
+
+        if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
+            raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")
+
+        meta["rbln_enc_max_seq_len"] = rbln_enc_max_seq_len
+        meta["rbln_dec_max_seq_len"] = rbln_dec_max_seq_len
+        meta["rbln_batch_size"] = rbln_batch_size
+        meta["rbln_pad_token_id"] = rbln_pad_token_id
+
+        # model input info
+        enc_input_info = [
+            ("input_ids", [rbln_batch_size, rbln_enc_max_seq_len], "int64"),
+            ("attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "int64"),
+        ]
+
+        dec_input_info = [
+            ("input_ids", [rbln_batch_size, 1], "int64"),
+            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "int64"),
+            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "int64"),
+            (
+                "cache_position",
+                [],
+                "int32",
+            ),
+        ]
+        dec_input_info.extend(
+            [
+                (
+                    "self_key_value_states",
+                    [
+                        n_layer * 2,
+                        rbln_batch_size,
+                        n_head,
+                        rbln_dec_max_seq_len,
+                        d_kv,
+                    ],
+                    "float32",
+                )
+            ]
+        )
+        dec_input_info.extend(
+            [
+                (
+                    "cross_key_value_states",
+                    [
+                        n_layer * 2,
+                        rbln_batch_size,
+                        n_head,
+                        rbln_enc_max_seq_len,
+                        d_kv,
+                    ],
+                    "float32",
+                )
+            ]
+        )
+        enc_rbln_runtime_config = RBLNRuntimeConfig(rbln_mod_name="encoder", input_info=enc_input_info)
+        dec_rbln_runtime_config = RBLNRuntimeConfig(rbln_mod_name="decoder", input_info=dec_input_info)
+
+        rbln_config = RBLNConfig.from_rbln_runtime_configs(
+            [enc_rbln_runtime_config, dec_rbln_runtime_config],
+            _rbln_meta=meta,
+        )
+
+        return rbln_config
+
+    def _create_runtimes(self, rbln_device_map: Dict[str, int]) -> List[rebel.Runtime]:
+        device_val = rbln_device_map[DEFAULT_COMPILED_MODEL_NAME]
+        return [
+            self.compiled_models[0].create_runtime("encoder", tensor_type="pt", device=device_val),
+            self.compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
+        ]
+
+    def forward(
+        self,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        decoder_output = self.decoder(
+            input_ids=decoder_input_ids,
+            attention_mask=decoder_attention_mask,
+            encoder_attention_mask=attention_mask,
+            cache_position=cache_position,
+        )
+        lm_logits = decoder_output.logits
+
+        return Seq2SeqLMOutput(logits=lm_logits)
+
+    def __repr__(self):
+        return repr(self.runtimes[0]) + "\n" + repr(self.runtimes[1])
+
+    def _prepare_encoder_decoder_kwargs_for_generation(
+        self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
+    ) -> Dict[str, Any]:
+
+        ########## thkim change start ###################
+        # padding input_ids & attention_mask regardless of user's tokenizer usage
+        batch_size, input_len = inputs_tensor.shape
+        inputs_tensor = torch.nn.functional.pad(
+            inputs_tensor, (0, self.enc_max_seq_len - input_len), value=self.pad_token_id
+        )
+        model_kwargs["attention_mask"] = torch.nn.functional.pad(
+            model_kwargs["attention_mask"], (0, self.enc_max_seq_len - input_len), value=0
+        )
+        ########## thkim change end ###################
+
+        # 1. get encoder
+        encoder = self.get_encoder()
+
+        # 2. Prepare encoder args and encoder kwargs from model kwargs.
+        irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+        encoder_kwargs = {
+            argument: value
+            for argument, value in model_kwargs.items()
+            if not any(argument.startswith(p) for p in irrelevant_prefix)
+        }
+        encoder_signature = set(inspect.signature(encoder.forward).parameters)
+        encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+        if not encoder_accepts_wildcard:
+            encoder_kwargs = {
+                argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
+            }
+
+        # 3. make sure that encoder returns `ModelOutput`
+        model_input_name = model_input_name if model_input_name is not None else self.main_input_name
+        encoder_kwargs["return_dict"] = True
+        encoder_kwargs[model_input_name] = inputs_tensor
+        model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs)

        return model_kwargs
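As a usage sketch (not taken from the package): assuming `RBLNBaseModel` exposes the usual optimum-style `from_pretrained(..., export=True)` entry point, compiling and running a T5 checkpoint would look roughly like this; the checkpoint name and sequence lengths are illustrative:

from transformers import AutoTokenizer

from optimum.rbln.modeling_seq2seq import RBLNModelForSeq2SeqLM

# The rbln_* kwargs are split off by pop_rbln_kwargs_from_kwargs() and feed
# _get_rbln_config() above; export=True is assumed to route through _export(),
# which traces, compiles, and saves the encoder/decoder pair.
model = RBLNModelForSeq2SeqLM.from_pretrained(
    "t5-small",
    export=True,  # assumed optimum-style export flag
    rbln_enc_max_seq_len=512,
    rbln_dec_max_seq_len=512,
    rbln_batch_size=1,
)

tokenizer = AutoTokenizer.from_pretrained("t5-small")
inputs = tokenizer("translate English to German: Hello, world!", return_tensors="pt")

# generate() is reached through the __getattr__ redirection to
# T5ForConditionalGeneration; _prepare_encoder_decoder_kwargs_for_generation()
# pads input_ids and attention_mask up to rbln_enc_max_seq_len beforehand.
output_ids = model.generate(**inputs, max_length=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))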
optimum/rbln/transformers/__init__.py
@@ -0,0 +1,59 @@
+# Copyright 2024 Rebellions Inc.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Portions of this software are licensed under the Apache License,
+# Version 2.0. See the NOTICE file distributed with this work for
+# additional information regarding copyright ownership.
+
+# All other portions of this software, including proprietary code,
+# are the intellectual property of Rebellions Inc. and may not be
+# copied, modified, or distributed without prior written permission
+# from Rebellions Inc.
+
+from typing import TYPE_CHECKING
+
+from transformers.utils import _LazyModule
+
+
+_import_structure = {
+    "generation": ["BatchTextIteratorStreamer"],
+    "models": [
+        "RBLNCLIPTextModel",
+        "RBLNCLIPTextModelWithProjection",
+        "RBLNGPT2LMHeadModel",
+        "RBLNWav2Vec2ForCTC",
+        "RBLNWhisperForConditionalGeneration",
+        "RBLNLlamaForCausalLM",
+    ],
+}
+
+if TYPE_CHECKING:
+    from .generation import BatchTextIteratorStreamer
+    from .models import (
+        RBLNCLIPTextModel,
+        RBLNCLIPTextModelWithProjection,
+        RBLNGPT2LMHeadModel,
+        RBLNLlamaForCausalLM,
+        RBLNWav2Vec2ForCTC,
+        RBLNWhisperForConditionalGeneration,
+    )
+else:
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
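A brief note on the `_LazyModule` pattern used here: importing the package itself stays lightweight, and each submodule named in `_import_structure` is loaded only on first attribute access. A small sketch:

# Cheap: neither ".generation" nor ".models" is imported yet.
import optimum.rbln.transformers as rbln_transformers

# First attribute access triggers the actual ".models" import.
llama_cls = rbln_transformers.RBLNLlamaForCausalLM

# Same lazy resolution for ".generation".
streamer_cls = rbln_transformers.BatchTextIteratorStreamer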