optimum-rbln 0.1.8__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- optimum/rbln/__init__.py +40 -2
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +39 -32
- optimum/rbln/diffusers/models/controlnet.py +60 -43
- optimum/rbln/diffusers/models/unet_2d_condition.py +43 -31
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +2 -3
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
- optimum/rbln/modeling_alias.py +8 -4
- optimum/rbln/modeling_base.py +512 -238
- optimum/rbln/modeling_config.py +152 -77
- optimum/rbln/modeling_seq2seq.py +166 -77
- optimum/rbln/transformers/__init__.py +37 -1
- optimum/rbln/transformers/models/__init__.py +21 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
- optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +128 -26
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +32 -7
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +406 -104
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +10 -3
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +6 -89
- optimum/rbln/transformers/models/llama/modeling_llama.py +9 -3
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +5 -88
- optimum/rbln/transformers/models/mistral/__init__.py +24 -0
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +29 -0
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +68 -0
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +18 -12
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +25 -16
- optimum/rbln/transformers/utils/__init__.py +0 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +97 -0
- optimum/rbln/utils/import_utils.py +37 -5
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +35 -1
- optimum/rbln/utils/timer_utils.py +19 -0
- {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +15 -7
- optimum_rbln-0.1.11.dist-info/RECORD +93 -0
- {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.8.dist-info/RECORD +0 -73
- {optimum_rbln-0.1.8.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
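The hunks that follow are from optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py (+406 -104). They add an optional inputs_embeds input path (with the token-embedding weights saved to torch_artifacts.pth and run on the host), RBLN-format w4a16 weight quantization, a query_idx runtime input, and an RBLNDecoderOnlyOutput that carries past_cached_length across generation steps.

A minimal usage sketch of the new options, assuming the rbln_-prefixed from_pretrained keywords map onto the rbln_kwargs keys read in _get_rbln_config below; the model class and keyword spellings are illustrative and not taken from this diff:

    # Hypothetical sketch: the class name and rbln_* keyword spellings are assumptions
    # based on the rbln_kwargs keys visible in the diff below.
    from optimum.rbln import RBLNLlamaForCausalLM

    model = RBLNLlamaForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",
        export=True,                    # compile the PyTorch checkpoint for RBLN NPUs
        rbln_max_seq_len=4096,          # -> rbln_kwargs["max_seq_len"]
        rbln_batch_size=1,              # -> rbln_kwargs["batch_size"]
        rbln_use_inputs_embeds=True,    # -> rbln_kwargs["use_inputs_embeds"]
        rbln_quantization={"format": "rbln", "precision": "w4a16"},  # checked against SUPPORTED_QUANTIZATIONS
    )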
@@ -20,18 +20,24 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+import glob
 import logging
-from abc import ABC
-from
+from abc import ABC
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import rebel  # noqa: F401
 import torch  # noqa: F401
-from
-from transformers
+from safetensors.torch import load_file
+from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from transformers.modeling_utils import no_init_weights
+from transformers.utils import ModelOutput

 from ....modeling_base import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME,
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
+from ....utils.timer_utils import rbln_timer


 logger = logging.getLogger(__name__)
@@ -44,9 +50,54 @@ if TYPE_CHECKING:
         PretrainedConfig,
     )

+SUPPORTED_QUANTIZATIONS = {
+    "rbln": [
+        "w4a16",
+    ],
+}
+

 class RBLNRuntimeModel(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
+    mandatory_members = ["main_input_name", "embed_tokens"]
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        inputs_embeds: torch.Tensor,
+        attention_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        batch_position: torch.Tensor,
+        query_idx: torch.Tensor,
+        **kwargs,
+    ):
+        if inputs_embeds is None:
+            inp = input_ids
+            if self.embed_tokens is not None:
+                inp = self.embed_tokens(inp)
+
+            return super().forward(
+                inp,
+                attention_mask,
+                cache_position,
+                batch_position,
+                query_idx,
+                **kwargs,
+            )
+        else:
+            return super().forward(
+                inputs_embeds,
+                attention_mask,
+                cache_position,
+                batch_position,
+                query_idx,
+                **kwargs,
+            )
+
+
+@dataclass
+class RBLNDecoderOnlyOutput(ModelOutput):
+    logits: torch.FloatTensor = None
+    past_cached_length: Union[int, torch.Tensor] = None


 class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
@@ -64,52 +115,177 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
     auto_model_class = AutoModelForCausalLM

     def __post_init__(self, **kwargs):
-        self.batch_size = self.rbln_config.
-        self.max_seq_len = self.rbln_config.
-        self.prefill_chunk_size = self.rbln_config.
+        self.batch_size = self.rbln_config.model_cfg["batch_size"]
+        self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
+        self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]

-        self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.
+        self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
         self.causal_mask = 1 - torch.triu(
             torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
         )
-        self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.
-        self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.
-
-
+        self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.float32)
+        self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+
+        main_input_name = self.main_input_name
+        if self.rbln_config.model_cfg["use_inputs_embeds"]:
+            main_input_name = "inputs_embeds"
+            artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+            with no_init_weights():
+                self.embed_tokens = torch.nn.Embedding(
+                    self.config.vocab_size,
+                    self.config.hidden_size,
+                    self.config.pad_token_id,
+                )
+            self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+        else:
+            self.embed_tokens = None
+
+        self.prefill_decoder = RBLNRuntimeModel(
+            runtime=self.model[0], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+        )
+        self.decoder = RBLNRuntimeModel(
+            runtime=self.model[1], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+        )

     @classmethod
-
-
-
+    def save_torch_artifacts(
+        cls,
+        model: "PreTrainedModel",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNConfig,
+    ):
+        """
+        If you are unavoidably running on a CPU rather than an RBLN device,
+        store the torch tensor, weight, etc. in this function.
+        """
+        if rbln_config.model_cfg["use_inputs_embeds"]:
+            save_dict = {}
+            save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
+            torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    @classmethod
+    def get_quantized_model(
+        cls,
+        model_id: str,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+        trust_remote_code: bool = False,
+        **kwargs,
+    ):
+        from ...utils.rbln_quantization import update_layers_to_quantized
+
+        kwargs = cls.update_kwargs(kwargs)
+
+        config = AutoConfig.from_pretrained(
+            model_id,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            force_download=force_download,
+            cache_dir=cache_dir,
+            trust_remote_code=trust_remote_code,
+            **kwargs,
+        )
+
+        with no_init_weights():
+            model = AutoModelForCausalLM.from_config(config)
+
+        update_layers_to_quantized(model)
+
+        n_layer = kwargs.get("num_hidden_layers", None)
+        cls._load_weights_directly_to_model(model, model_id, n_layer)
+
+        return model
+
+    def _load_weights_directly_to_model(model, model_id, n_layer=None):
+        """
+        Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
+        """
+
+        model_params = dict(model.named_parameters(recurse=True))
+        model_buffers = dict(model.named_buffers(recurse=True))
+        safetensor_files = glob.glob(f"{model_id}/*.safetensors")
+
+        target_layers = list(range(n_layer)) if n_layer is not None else None
+
+        for safetensor_file in safetensor_files:
+            file_data = load_file(safetensor_file)
+            for key, value in file_data.items():
+                if target_layers is not None:
+                    parts = key.split(".")
+
+                    if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
+                        continue
+
+                if key in model_params:
+                    model_params[key].data.copy_(value)
+                elif key in model_buffers:
+                    model_buffers[key].data.copy_(value)
+
+        return 0
+
+    @classmethod
+    def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
+        rbln_kwargs = kwargs.get("rbln_kwargs", {})
+        rbln_quantization = rbln_kwargs.get("quantization", None)
+
+        if rbln_quantization is not None and rbln_quantization["format"] == "rbln":
+            model = cls.get_quantized_model(*args, **kwargs)
+        else:
+            model = super().get_pytorch_model(*args, **kwargs)
+
+        return model

     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
-        wrapped_model = cls.
+        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

-
-
+        rbln_compile_configs = rbln_config.compile_cfgs
+        prefill_rbln_compile_config = rbln_compile_configs[0]
+        dec_rbln_compile_config = rbln_compile_configs[1]

-
-
+        @rbln_timer("Jit Trace")
+        def get_scripted_model():
+            # This function is nested to dealloc the example inputs before compilation.
+            prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
+            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=4)

-
-
+            batch_index = 3
+            dec_example_inputs[batch_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
+
+            prefill_scripted_model = torch.jit.trace(
+                wrapped_model, prefill_example_inputs, check_trace=False, _store_inputs=False
+            )
+            dec_scripted_model = torch.jit.trace(
+                wrapped_model, dec_example_inputs, check_trace=False, _store_inputs=False
+            )
+            return prefill_scripted_model, dec_scripted_model

-        prefill_scripted_model =
-        dec_scripted_model = torch.jit.trace(wrapped_model, dec_example_inputs, check_trace=False)
+        prefill_scripted_model, dec_scripted_model = get_scripted_model()

-
-
-
-
-
-
-
-
+        @rbln_timer("TorchScript to IR")
+        def scripted_model_to_ir():
+            prefill_ir = rebel.torchscript_to_ir(
+                prefill_scripted_model,
+                input_names=[v[0] for v in prefill_rbln_compile_config.input_info],
+            )
+            dec_ir = rebel.torchscript_to_ir(
+                dec_scripted_model,
+                input_names=[v[0] for v in dec_rbln_compile_config.input_info],
+            )
+            return prefill_ir, dec_ir

+        prefill_ir, dec_ir = scripted_model_to_ir()
         # Caching prefill_decoder/decoder I/O
-        cache_index_offset =
+        cache_index_offset = 5
         connections = [
             (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
             for i in range(model.config.num_hidden_layers * 2)
@@ -119,9 +295,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             prefill_ir,
             dec_ir,
             connections=connections,
-            fusion=
-            npu=
-            tensor_parallel_size=
+            fusion=prefill_rbln_compile_config.fusion,
+            npu=prefill_rbln_compile_config.npu,
+            tensor_parallel_size=prefill_rbln_compile_config.tensor_parallel_size,
             use_weight_sharing=True,
         )
         return compiled_model
@@ -131,39 +307,60 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-
-        rbln_batch_size: Optional[int] = None,
-        **kwargs,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_quantization = rbln_kwargs.get("quantization", None)
+        rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)

         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None)
+            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
+                model_config, "n_positions", None
+            )
+        if rbln_max_seq_len is None:
+            raise ValueError("`rbln_max_seq_len` should be specified.")
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-
-
-
+        rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds
+
+        num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
+        num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
+        num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
+        head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
+        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+
+        if rbln_quantization is not None:
+            q_format = rbln_quantization.get("format", None)
+            q_precision = rbln_quantization.get("precision", None)
+
+            if q_format not in SUPPORTED_QUANTIZATIONS.keys() or q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
+                raise ValueError(
+                    f'rbln_quantization="{rbln_quantization}" is not a supported quantization format or precesion, '
+                    f"Possible: {SUPPORTED_QUANTIZATIONS}"
+                )

         def get_input_info(
             batch_size,
             query_length,
+            use_inputs_embeds,
+            hidden_size,
         ):
-
-
-
-
-
+            if use_inputs_embeds:
+                main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
+            else:
+                main_input = ("input_ids", [batch_size, query_length], "int64")
+
             input_info = [
-
-                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "
+                main_input,
+                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
                 (
                     "cache_position",
                     [batch_size, query_length],
                     "int32",
                 ),
                 ("batch_position", [], "int16"),
+                ("query_idx", [], "int16"),
             ]

             input_info.extend(
@@ -172,13 +369,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                        f"past_key_values_{i}",
                         [
                             rbln_batch_size,
-
+                            num_key_value_heads,
                             rbln_max_seq_len,
                             head_dim,
                         ],
                         "float32",
                     )
-                    for i in range(
+                    for i in range(num_hidden_layers * 2)
                 ]
             )

@@ -187,22 +384,37 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         prefill_input_info = get_input_info(
             batch_size=1,
             query_length=prefill_chunk_size,
+            use_inputs_embeds=rbln_use_inputs_embeds,
+            hidden_size=hidden_size,
         )
         dec_input_info = get_input_info(
             batch_size=rbln_batch_size,
             query_length=1,
+            use_inputs_embeds=rbln_use_inputs_embeds,
+            hidden_size=hidden_size,
         )

-
-
+        prefill_rbln_compile_config = RBLNCompileConfig(input_info=prefill_input_info)
+        dec_rbln_compile_config = RBLNCompileConfig(input_info=dec_input_info)

-
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[prefill_rbln_compile_config, dec_rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )

-        rbln_config
-
-
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "prefill_chunk_size": prefill_chunk_size,
+                "use_inputs_embeds": rbln_use_inputs_embeds,
+            }
         )

+        if rbln_quantization is not None:
+            rbln_config.model_cfg.update({"quantization": rbln_quantization})
+
         return rbln_config

     @classmethod
@@ -224,82 +436,155 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
     def _reorder_cache(self, past_key_values, beam_idx):
         raise NotImplementedError

-
-
-
-
-
-
-
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_cached_length: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        model_inputs = {}
+        # prefill phase
         if past_cached_length is None:
-
+            # huggingface make dummy_input_ids if model_input_name is "input_embeds"
+            # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L469
+            if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+                input_tensors = inputs_embeds
+            else:
+                input_tensors = input_ids
+
+            batch_size = input_tensors.shape[0]
+            l_input_tensors = []
             cache_positions = []
             past_cached_length = torch.zeros((batch_size, 1), dtype=torch.int32)
             for i in range(batch_size):
-
-
-                valid_len =
+                input_tensor = input_tensors[i]
+                input_tensor = input_tensor[attention_mask[i] == 1]
+                valid_len = input_tensor.shape[0]
                 cache_position = torch.arange(0, valid_len, dtype=torch.int32)
                 past_cached_length[i] = valid_len
-
+                l_input_tensors.append(input_tensor.unsqueeze(0))
                 cache_positions.append(cache_position.unsqueeze(0))

-
+            input_tensors = l_input_tensors
+            if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+                model_inputs.update({"inputs_embeds": input_tensors, "input_ids": input_ids})
+            else:
+                model_inputs.update({"input_ids": input_tensors, "inputs_embeds": inputs_embeds})
+        # decoder phase
         else:
             input_ids = input_ids[:, -1:]
             cache_positions = past_cached_length
             past_cached_length = past_cached_length + 1
+            model_inputs.update({"input_ids": input_ids})

-        model_inputs
-
-
-
-
+        model_inputs.update(
+            {
+                "cache_position": cache_positions,
+                "past_cached_length": past_cached_length,
+            }
+        )

         return model_inputs

+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: RBLNDecoderOnlyOutput,
+        model_kwargs: Dict[str, Any],
+        **kwargs,
+    ) -> Dict[str, Any]:
+        # update past_cached_length
+        model_kwargs["past_cached_length"] = outputs.past_cached_length
+
+        return model_kwargs
+
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        input_ids: Optional[Union[List[torch.LongTensor], torch.LongTensor]] = None,
+        inputs_embeds: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
         cache_position: Union[List[torch.Tensor], torch.Tensor] = None,  # vllm keyword argument
         batch_idx: Optional[int] = None,
-        past_cached_length: Optional[torch.Tensor] = None,
+        past_cached_length: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # prefll & hf generate
         if isinstance(cache_position, list):
             logits = []
-
-
+            input_tensors = input_ids if inputs_embeds is None else inputs_embeds
+            for batch_idx, (input_tensor, cache_pos) in enumerate(zip(input_tensors, cache_position)):
+                logit = self._forward_prefill(
+                    input_ids=input_tensor if inputs_embeds is None else None,
+                    inputs_embeds=input_tensor if inputs_embeds is not None else None,
+                    cache_position=cache_pos,
+                    batch_idx=batch_idx,
+                )
                 logits.append(logit)
             logits = torch.cat(logits, dim=0)
         # prefill & vllm step
         elif cache_position.shape[-1] > 1:
-            logits = self._forward_prefill(
+            logits = self._forward_prefill(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+                batch_idx=batch_idx,
+            )
         # common decoder
         else:
-            logits = self._forward_decoder(
+            logits = self._forward_decoder(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+            )

-        return
+        return RBLNDecoderOnlyOutput(
             logits=logits,
-
+            past_cached_length=past_cached_length,
         )

     def _forward_prefill(
         self,
         input_ids: torch.LongTensor = None,
-
+        inputs_embeds: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
         batch_idx: int = None,
     ) -> torch.FloatTensor:
         if batch_idx is None or batch_idx >= self.batch_size:
             raise RuntimeError(
                 f"Invalid batch_idx ({batch_idx}). It must be a non-null value less than the batch size ({self.batch_size})."
             )
-
+
+        out_buffers = [
+            torch.empty(
+                size=[
+                    1,
+                    1,
+                    self.config.vocab_size,
+                ],
+                dtype=torch.float32,
+                device="cpu",
+            ),
+            torch.empty(size=[], dtype=torch.int16, device="cpu"),
+        ]
+
+        if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+            model_input_name = "inputs_embeds"
+        else:
+            model_input_name = "input_ids"
+
+        input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
+
+        query_length = input_tensors.shape[1]
         attention_mask = self.prefill_attention_mask.clone()
         for step in range(0, query_length, self.prefill_chunk_size):
             if step + self.prefill_chunk_size > query_length:
-
+                # input_tensors = torch.nn.functional.pad(input_tensors, (0, step + self.prefill_chunk_size - query_length))
+                padding_needed = step + self.prefill_chunk_size - query_length
+                if model_input_name == "input_ids":
+                    input_tensors = torch.nn.functional.pad(input_tensors, (0, padding_needed))
+                else:
+                    input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, padding_needed))
+
             cache_position = torch.cat(
                 [
                     cache_position,
@@ -312,18 +597,24 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                 dim=-1,
             )

-
+            sliced_input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
             sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]
-
+
+            if step >= self.prefill_chunk_size:
+                attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
             attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask

+            query_idx = query_length % self.prefill_chunk_size - 1
+
             logits, _ = self.prefill_decoder(
-
-
-
-
+                input_ids=sliced_input_tensors.contiguous() if model_input_name == "input_ids" else None,
+                inputs_embeds=sliced_input_tensors.contiguous() if model_input_name == "inputs_embeds" else None,
+                attention_mask=attention_mask.contiguous(),
+                cache_position=sliced_cache_positions.contiguous(),
+                batch_position=torch.tensor(batch_idx, dtype=torch.int16),
+                query_idx=torch.tensor(query_idx, dtype=torch.int16),
+                out=out_buffers,
             )
-            logits = logits[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)

         self.dec_attn_mask[batch_idx] = self.dec_attn_mask_init.clone()
         self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
@@ -331,19 +622,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         return logits

     def _forward_decoder(
-        self,
+        self,
+        input_ids: torch.LongTensor = None,
+        inputs_embeds: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
     ) -> torch.FloatTensor:
-
+        if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+            model_input_name = "inputs_embeds"
+        else:
+            model_input_name = "input_ids"
+        input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
+
+        batch_size = input_tensors.shape[0]

         for b_idx in range(batch_size):
             decoding_step = cache_position[b_idx].item()
             self.dec_attn_mask[b_idx, :, :, decoding_step] = 1

         logits, _ = self.decoder(
-            input_ids.contiguous(),
-
-
-
+            input_ids=input_tensors.contiguous() if model_input_name == "input_ids" else None,
+            inputs_embeds=input_tensors.contiguous() if model_input_name == "inputs_embeds" else None,
+            attention_mask=self.dec_attn_mask.contiguous(),
+            cache_position=cache_position.contiguous(),
+            batch_position=torch.tensor(0, dtype=torch.int16),
+            query_idx=torch.tensor(0, dtype=torch.int16),
         )

         return logits