optimum-rbln 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff compares the contents of two publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- optimum/rbln/__init__.py +37 -2
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
- optimum/rbln/diffusers/models/controlnet.py +56 -40
- optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
- optimum/rbln/modeling_alias.py +3 -3
- optimum/rbln/modeling_base.py +471 -231
- optimum/rbln/modeling_config.py +152 -77
- optimum/rbln/modeling_seq2seq.py +166 -77
- optimum/rbln/transformers/__init__.py +35 -1
- optimum/rbln/transformers/models/__init__.py +20 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
- optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
- optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
- optimum/rbln/utils/import_utils.py +36 -1
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +33 -0
- optimum/rbln/utils/timer_utils.py +19 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
- optimum_rbln-0.1.11.dist-info/RECORD +93 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.9.dist-info/RECORD +0 -78
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/modeling_config.py
CHANGED
@@ -23,24 +23,38 @@

 import copy
 import json
-from collections import UserDict
 from dataclasses import asdict, dataclass
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple

+import rebel
 import torch

+from .__version__ import __version__
+from .utils.runtime_utils import ContextRblnConfig
+

 DEFAULT_COMPILED_MODEL_NAME = "compiled_model"
 DEFAULT_MOD_NAME = "default"


 @dataclass
-class RBLNRuntimeConfig:
+class RBLNCompileConfig:
+    """
+    Configuration for RBLN compilation.
+
+    Attributes:
+        compiled_model_name (str): Name of the compiled model.
+        mod_name (str): Name of the RBLN module.
+        input_info (List[Tuple[str, Tuple[int], Optional[str]]]): Information about input tensors.
+        fusion (Optional[bool]): Whether to use fusion optimization.
+        npu (Optional[str]): NPU configuration.
+        tensor_parallel_size (Optional[int]): Size for tensor parallelism.
+    """
+
     compiled_model_name: str = DEFAULT_COMPILED_MODEL_NAME
-
+    mod_name: str = DEFAULT_MOD_NAME
     input_info: List[Tuple[str, Tuple[int], Optional[str]]] = None
-    batch_size: Optional[int] = None
     fusion: Optional[bool] = None
     npu: Optional[str] = None
     tensor_parallel_size: Optional[int] = None
@@ -48,8 +62,14 @@ class RBLNRuntimeConfig:
     @staticmethod
     def normalize_dtype(dtype):
         """
-        framework
+        Convert framework-specific dtype to string representation.
         i.e. torch.float32 -> "float32"
+
+        Args:
+            dtype: The input dtype (can be string, torch dtype, or numpy dtype).
+
+        Returns:
+            str: The normalized string representation of the dtype.
         """
         if isinstance(dtype, str):
             return dtype
@@ -60,13 +80,12 @@ class RBLNRuntimeConfig:
             return dtype

     def __post_init__(self):
-        self.input_info = [(i[0], i[1],
+        self.input_info = [(i[0], i[1], RBLNCompileConfig.normalize_dtype(i[2]) or "float32") for i in self.input_info]

-    def update(self,
+    def update(self, kwargs: Dict[str, Any]):
         self.compiled_model_name = kwargs.get("compiled_model_name", self.compiled_model_name)
-        self.
+        self.mod_name = kwargs.get("mod_name", self.mod_name)
         self.input_info = kwargs.get("input_info", self.input_info)
-        self.batch_size = kwargs.get("batch_size", self.batch_size)
         self.fusion = kwargs.get("fusion", self.fusion)
         self.npu = kwargs.get("npu", self.npu)
         self.tensor_parallel_size = kwargs.get("tensor_parallel_size", self.tensor_parallel_size)
@@ -86,84 +105,140 @@ class RBLNRuntimeConfig:
         return asdict(self)


-
-
-"""Configurations for RBLN model compilation and inference.
+RUNTIME_KEYWORDS = ["create_runtimes", "optimize_host_memory", "device", "device_map"]
+COMPILE_KEYWORDS = ["compiled_model_name", "mod_name", "input_info", "fusion", "npu", "tensor_parallel_size"]

-
-
-
-
-
-
-
-
-
+
+class RBLNConfig:
+    """
+    Configuration for single RBLN OptimizedModel, representing multiple compiled models.
+
+    Attributes:
+        compile_cfgs (List[RBLNCompileConfig]): Compilation configurations.
+        meta (dict): Metadata including version and class information.
+        runtime_cfg (dict): Runtime-specific configuration.
+    """
+
+    # It represents multiple compiled model, one of each can have multiple runtimes.
+    def __init__(
+        self,
+        rbln_cls,
+        compile_cfgs: List[RBLNCompileConfig],
+        rbln_kwargs=None,
+        meta=None,
+    ) -> None:
+        if rbln_kwargs is None:
+            rbln_kwargs = {}
         else:
-
+            rbln_kwargs = copy.deepcopy(rbln_kwargs)

-
-
-
-
-
-        for name, cfg in zip(names, rbln_configs):
-            if len(cfg) > 1:
-                msg = (
-                    "`from_rbln_configs` requires exact one `RBLNRuntimeConfig` for each `RBLNConfig`."
-                    f"But got {len(cfg)} `RBLNRuntimeConfig`."
-                )
-                raise RuntimeError(msg)
-
-            runtime_cfg = cfg[list(cfg.keys())[0]][0]
-            runtime_cfg = copy.deepcopy(runtime_cfg)
-            if name is not None:
-                runtime_cfg.compiled_model_name = name
-            runtime_cfgs.append(runtime_cfg)
-
-        metas = [cfg.meta for cfg in rbln_configs]
-        merged_meta = {k: v for meta in metas for k, v in meta.items()}
-
-        return RBLNConfig.from_rbln_runtime_configs(runtime_cfgs, _rbln_meta=merged_meta)
+        # meta : class, version and other informations.
+        if meta is None:
+            self.meta = {"version": __version__, "cls": rbln_cls}
+        else:
+            self.meta = meta

-
-
-
-
-
-
-
-
-
-
-
-
+        # compile_cfgs : compile args for each runtime
+        self.compile_cfgs = compile_cfgs
+        for compile_cfg in self.compile_cfgs:
+            compile_cfg.update(rbln_kwargs)
+        for K in COMPILE_KEYWORDS:
+            rbln_kwargs.pop(K, None)
+
+        # runtime_cfg : Values that don't be saved / loaded.
+        self.runtime_cfg = {}
+        for runtime_key in RUNTIME_KEYWORDS:
+            if runtime_key in rbln_kwargs:
+                self.runtime_cfg[runtime_key] = rbln_kwargs.pop(runtime_key)
+
+        # model_cfg : All user-provided values such as "max_seq_len".
+        self.model_cfg: Dict[str, Any] = rbln_kwargs

     def save(self, dir_path: str):
         dir_path = Path(dir_path)
-
-
+
+        s_json = {}
+        compile_cfgs = [asdict(cfg) for cfg in self.compile_cfgs]
+        s_json["_compile_cfgs"] = compile_cfgs
+        s_json["_meta"] = self.meta
+        s_json.update(self.model_cfg)
+
         with open(dir_path / "rbln_config.json", "w") as jsonf:
-            json.dump(
+            json.dump(s_json, jsonf, indent=2)

-    @staticmethod
-    def load(dir_path: str) -> "RBLNConfig":
+    @classmethod
+    def load(cls, dir_path: str) -> "RBLNConfig":
         dir_path = Path(dir_path)
         with open(dir_path / "rbln_config.json", "r") as jsonf:
             config_file = json.load(jsonf)
-        return RBLNConfig.fromdict(config_file)
-
-    def asdict(self):
-        dic = {k: [asdict(cfg) for cfg in cfgs] for k, cfgs in self.data.items()}
-        return dic

-
-
-
-
-
-
-
-
-
-
+        return cls.fromdict(config_file)
+
+    @classmethod
+    def fromdict(cls, dic: dict):
+        compile_cfgs = dic.pop("_compile_cfgs")
+        compile_cfgs = [RBLNCompileConfig(**cfg) for cfg in compile_cfgs]
+
+        meta = dic.pop("_meta")
+        rbln_cls = meta["cls"]
+
+        rbln_kwargs = dic
+        return cls(rbln_cls=rbln_cls, compile_cfgs=compile_cfgs, rbln_kwargs=rbln_kwargs, meta=meta)
+
+    def update_runtime_cfg(self, rbln_kwargs: Dict[str, Any]):
+        keys = list(rbln_kwargs.keys())
+        for key in keys:
+            if key in RUNTIME_KEYWORDS:
+                self.runtime_cfg[key] = rbln_kwargs[key]
+
+    def __repr__(self):
+        compile_cfgs_repr = [f"\n {cfg!r}" for cfg in self.compile_cfgs]
+        return (
+            f"RBLNConfig(\n"
+            f" rbln_cls={self.meta['cls']},\n"
+            f" version='{self.meta['version']}',\n"
+            f" compile_cfgs=[{''.join(compile_cfgs_repr)}\n ],\n"
+            f" model_cfg={self.model_cfg},\n"
+            f" runtime_cfg={self.runtime_cfg}\n"
+            f")"
+        )
+
+    @property
+    def create_runtimes(self):
+        context = ContextRblnConfig.get_current_context()["create_runtimes"]
+        if context is not None:
+            return context
+        elif self.runtime_cfg.get("create_runtimes", None) is None:
+            return rebel.npu_is_available()
+        return self.runtime_cfg["create_runtimes"]
+
+    @property
+    def optimize_host_memory(self):
+        context = ContextRblnConfig.get_current_context()["optimize_host_memory"]
+        if context is not None:
+            return context
+        elif self.runtime_cfg.get("optimize_host_memory", None) is None:
+            return True
+        return self.runtime_cfg["optimize_host_memory"]
+
+    @property
+    def device(self):
+        context = ContextRblnConfig.get_current_context()["device"]
+        if context:
+            return context
+        elif self.runtime_cfg.get("device", None) is None:
+            return 0
+        return self.runtime_cfg["device"]
+
+    @property
+    def device_map(self):
+        context = ContextRblnConfig.get_current_context()["device_map"]
+        if context:
+            return context
+        elif self.runtime_cfg.get("device_map", None) is None:
+            rbln_device_map = {}
+            device_val = self.device
+            for cfg in self.compile_cfgs:
+                rbln_device_map[cfg.compiled_model_name] = device_val
+            return rbln_device_map
+        return self.runtime_cfg["device_map"]
optimum/rbln/modeling_seq2seq.py
CHANGED
@@ -37,7 +37,7 @@ from transformers import (
 from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput

 from .modeling_base import RBLNModel
-from .modeling_config import DEFAULT_COMPILED_MODEL_NAME,
+from .modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from .transformers.models.bart import BartDecoderWrapper, BartEncoderWrapper
 from .transformers.models.t5 import T5DecoderWrapper, T5EncoderWrapper
 from .utils.runtime_utils import RBLNPytorchRuntime
@@ -88,12 +88,14 @@ class RBLNModelForSeq2SeqLM(RBLNModel):

     def __post_init__(self, **kwargs):
         self.model_dim = self.config.d_model
-        self.batch_size = self.rbln_config[
-        self.enc_max_seq_len = self.rbln_config.
-        self.dec_max_seq_len = self.rbln_config.
-        self.pad_token_id = self.rbln_config.
+        self.batch_size = self.rbln_config.model_cfg["batch_size"]
+        self.enc_max_seq_len = self.rbln_config.model_cfg["enc_max_seq_len"]
+        self.dec_max_seq_len = self.rbln_config.model_cfg["dec_max_seq_len"]
+        self.pad_token_id = self.rbln_config.model_cfg["pad_token_id"]
         self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_ids")
         self.decoder = RBLNRuntimeDecoder(runtime=self.model[1], main_input_name="input_ids")
+        self.enc_attention_mask = torch.zeros(1, self.enc_max_seq_len, dtype=torch.float32)
+        self.dec_enc_attention_mask = torch.zeros(self.batch_size, self.enc_max_seq_len, dtype=torch.float32)

     def can_generate(self):
         return True
@@ -117,32 +119,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
             return redirect(val)
         return val

-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        past_key_values=None,
-        attention_mask=None,
-        decoder_attention_mask=None,
-        **kwargs,
-    ):
-        max_seq_len = self.dec_max_seq_len
-        cur_seq_len = input_ids.shape[-1]
-        decoder_batch_size = input_ids.shape[0]
-        input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
-
-        # In greedy decoding
-        decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.int64)
-        decoder_attention_mask[:, :cur_seq_len] = 1
-        cache_position = torch.tensor(cur_seq_len - 1, dtype=torch.int32)
-
-        return {
-            "decoder_input_ids": input_ids,
-            "past_key_values": past_key_values,
-            "attention_mask": attention_mask,
-            "decoder_attention_mask": decoder_attention_mask,
-            "cache_position": cache_position,
-        }
-
     @classmethod
     def update_kwargs(cls, kwargs):
         kwargs.update(
@@ -170,50 +146,54 @@ class RBLNModelForSeq2SeqLM(RBLNModel):

         wrapped_encoder, wrapped_decoder = optimized_models(model)

-        wrapped_encoder.encoder_max_length = rbln_config.
-        wrapped_encoder.decoder_max_length = rbln_config.
-        wrapped_encoder.decoder_batch_size = rbln_config.
+        wrapped_encoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
+        wrapped_encoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
+        wrapped_encoder.decoder_batch_size = rbln_config.model_cfg["batch_size"]

-        wrapped_decoder.encoder_max_length = rbln_config.
-        wrapped_decoder.decoder_max_length = rbln_config.
-        wrapped_decoder.decoder_batch_size = rbln_config.
+        wrapped_decoder.encoder_max_length = rbln_config.model_cfg["enc_max_seq_len"]
+        wrapped_decoder.decoder_max_length = rbln_config.model_cfg["dec_max_seq_len"]
+        wrapped_decoder.decoder_batch_size = rbln_config.model_cfg["batch_size"]

-
-
+        enc_rbln_compile_config = rbln_config.compile_cfgs[0]
+        dec_rbln_compile_config = rbln_config.compile_cfgs[1]

         if isinstance(model, T5ForConditionalGeneration):
-            enc_example_inputs =
-            dec_example_inputs =
+            enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=1)
+            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=1)
         else:
-            enc_example_inputs =
-            dec_example_inputs =
+            enc_example_inputs = enc_rbln_compile_config.get_dummy_inputs(fill=0)
+            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
+
+        enc_example_inputs[3].fill_(0)
+        dec_example_inputs[4].fill_(-1)

         enc_scripted_model = torch.jit.trace(wrapped_encoder, enc_example_inputs, check_trace=False)
         dec_scripted_model = torch.jit.trace(wrapped_decoder, dec_example_inputs, check_trace=False)

         enc_ir = rebel.torchscript_to_ir(
             enc_scripted_model,
-            input_names=[v[0] for v in
-            name=
+            input_names=[v[0] for v in enc_rbln_compile_config.input_info],
+            name=enc_rbln_compile_config.mod_name,
         )
         dec_ir = rebel.torchscript_to_ir(
             dec_scripted_model,
-            input_names=[v[0] for v in
-            name=
+            input_names=[v[0] for v in dec_rbln_compile_config.input_info],
+            name=dec_rbln_compile_config.mod_name,
         )
-        dec_ir.decoder_batch_size = rbln_config.
+        dec_ir.decoder_batch_size = rbln_config.model_cfg["batch_size"]

         connections = [
-            (enc_ir.outputs[0], dec_ir.inputs[
-            (
+            (enc_ir.outputs[0], enc_ir.inputs[2], dec_ir.inputs[6]),
+            # (enc_ir.outputs[0], enc_ir.inputs[2]),
+            (dec_ir.outputs[1], dec_ir.inputs[5]),
         ]
         compiled_model = rebel.compile(
             enc_ir,
             dec_ir,
             connections=connections,
-            fusion=
-            npu=
-            tensor_parallel_size=
+            fusion=enc_rbln_compile_config.fusion,
+            npu=enc_rbln_compile_config.npu,
+            tensor_parallel_size=enc_rbln_compile_config.tensor_parallel_size,
         )
         return compiled_model

@@ -222,11 +202,12 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-
-        rbln_dec_max_seq_len: Optional[int] = None,
-        rbln_batch_size: Optional[int] = 1,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
+        rbln_enc_max_seq_len = rbln_kwargs.get("enc_max_seq_len", None)
+        rbln_dec_max_seq_len = rbln_kwargs.get("dec_max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size

         if isinstance(model_config, BartConfig):
             n_layer = model_config.decoder_layers
@@ -274,28 +255,36 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         if max_position_embeddings is not None and rbln_dec_max_seq_len > max_position_embeddings:
             raise ValueError("`rbln_dec_max_seq_len` should be less or equal than max_position_embeddings!")

-        rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-        meta["rbln_enc_max_seq_len"] = rbln_enc_max_seq_len
-        meta["rbln_dec_max_seq_len"] = rbln_dec_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size
-        meta["rbln_pad_token_id"] = rbln_pad_token_id
-
         # model input info
         enc_input_info = [
-            ("input_ids", [
-            ("attention_mask", [
+            ("input_ids", [1, rbln_enc_max_seq_len], "int64"),
+            ("attention_mask", [1, rbln_enc_max_seq_len], "float32"),
+            (
+                "cross_key_value_states",
+                [
+                    n_layer * 2,
+                    rbln_batch_size,
+                    n_head,
+                    rbln_enc_max_seq_len,
+                    d_kv,
+                ],
+                "float32",
+            ),
+            # int16 available?
+            ("batch_idx", [], "int32"),
         ]

         dec_input_info = [
             ("input_ids", [rbln_batch_size, 1], "int64"),
-            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "
-            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "
+            ("attention_mask", [rbln_batch_size, rbln_dec_max_seq_len], "float32"),
+            ("encoder_attention_mask", [rbln_batch_size, rbln_enc_max_seq_len], "float32"),
             (
                 "cache_position",
-                [],
+                [rbln_batch_size, 1],
+                # [],
                 "int32",
             ),
+            ("batch_position", [], "int32"),
         ]
         dec_input_info.extend(
             [
@@ -327,12 +316,22 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
                 )
             ]
         )
-
-
+        enc_rbln_compile_config = RBLNCompileConfig(mod_name="encoder", input_info=enc_input_info)
+        dec_rbln_compile_config = RBLNCompileConfig(mod_name="decoder", input_info=dec_input_info)

-        rbln_config = RBLNConfig
-
-
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[enc_rbln_compile_config, dec_rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )
+
+        rbln_config.model_cfg.update(
+            {
+                "enc_max_seq_len": rbln_enc_max_seq_len,
+                "dec_max_seq_len": rbln_dec_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "pad_token_id": rbln_pad_token_id,
+            }
         )

         return rbln_config
@@ -347,7 +346,84 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
             compiled_models[0].create_runtime("decoder", tensor_type="pt", device=device_val),
         ]

+    def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        decoder_attention_mask=None,
+        **kwargs,
+    ):
+        past_cache_length = past_key_values
+        if past_cache_length == 0:
+            cache_pos = []
+            for i in range(input_ids.shape[0]):
+                cache_pos.append([0])
+            cache_position = torch.tensor(cache_pos, dtype=torch.int32)
+
+        max_seq_len = self.dec_max_seq_len
+        cur_seq_len = input_ids.shape[-1]
+        decoder_batch_size = input_ids.shape[0]
+        input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
+        # In greedy decoding
+        decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
+        decoder_attention_mask[:, :cur_seq_len] = 1
+        cache_pos = []
+        for i in range(input_ids.shape[0]):
+            cache_pos.append([cur_seq_len - 1])
+        cache_position = torch.tensor(cache_pos, dtype=torch.int32)
+        return {
+            "decoder_input_ids": input_ids,
+            "past_key_values": past_key_values,
+            "attention_mask": attention_mask.to(torch.float32),
+            "decoder_attention_mask": decoder_attention_mask,
+            "cache_position": cache_position,
+        }
+
     def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
+        batch_idx: Optional[torch.LongTensor] = None,
+        enc_lengths: List[int] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        # common decoder
+        if enc_lengths is None:
+            output = self._forward_decoder(input_ids=input_ids, cache_position=cache_position, **kwargs)
+            return output
+
+        # vllm & encoder
+        if batch_idx is not None:
+            enc_attention_mask = self.enc_attention_mask.clone()
+            enc_attention_mask[0][: enc_lengths[batch_idx] + 1] = 1
+            padding_need = self.enc_max_seq_len - input_ids.shape[-1]
+            input_ids = torch.nn.functional.pad(input_ids, (0, padding_need))
+            _ = self.encoder(input_ids, enc_attention_mask, batch_idx=batch_idx.to(torch.int32))
+            logits = torch.zeros(1, 1, self.config.vocab_size + 100)
+            logits[0][0][-1] = 1
+        # vllm & decoder
+        else:
+            input_ids[input_ids == (self.config.vocab_size + 99)] = self.config.decoder_start_token_id
+            cache_position[cache_position != 0] = cache_position[cache_position != 0] - 2
+
+            enc_attention_mask = self.dec_enc_attention_mask.clone()
+            dec_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
+            for batch_idx in range(self.batch_size):
+                enc_attention_mask[batch_idx, : enc_lengths[batch_idx] + 1] = 1
+
+            logits = self._forward_decoder(
+                attention_mask=enc_attention_mask,
+                decoder_input_ids=input_ids,
+                decoder_attention_mask=dec_attention_mask,
+                cache_position=cache_position,
+            ).logits
+
+        return Seq2SeqLMOutput(
+            logits=logits,
+        )
+
+    def _forward_decoder(
         self,
         attention_mask: Optional[torch.FloatTensor] = None,
         decoder_input_ids: Optional[torch.LongTensor] = None,
@@ -355,13 +431,18 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         cache_position: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
+        dec_attention_mask = decoder_attention_mask.clone()
+        for b_idx in range(self.rbln_config.model_cfg["batch_size"]):
+            dec_attention_mask[b_idx, : cache_position[b_idx] + 1] = 1
+
         decoder_output = self.decoder(
             input_ids=decoder_input_ids,
-            attention_mask=
+            attention_mask=dec_attention_mask,
             encoder_attention_mask=attention_mask,
             cache_position=cache_position,
+            batch_position=torch.tensor(0, dtype=torch.int32),
         )
-        lm_logits = decoder_output.logits
+        lm_logits = decoder_output.logits[0]

         return Seq2SeqLMOutput(logits=lm_logits)

@@ -405,6 +486,14 @@ class RBLNModelForSeq2SeqLM(RBLNModel):
         model_input_name = model_input_name if model_input_name is not None else self.main_input_name
         encoder_kwargs["return_dict"] = True
         encoder_kwargs[model_input_name] = inputs_tensor
-
+        for b in range(batch_size):
+            batch_idx = torch.tensor(b, dtype=torch.int32)
+            cb_inputs = {}
+            cb_inputs["return_dict"] = True
+            cb_inputs["output_hidden_states"] = False
+            cb_inputs["output_attentions"] = False
+            cb_inputs["input_ids"] = encoder_kwargs["input_ids"][b].unsqueeze(0)
+            cb_inputs["attention_mask"] = encoder_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
+            model_kwargs["encoder_outputs"] = encoder(**cb_inputs, batch_idx=batch_idx)

         return model_kwargs