optimum-rbln 0.1.9__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff shows the content of publicly released package versions as published to the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in those public registries.
- optimum/rbln/__init__.py +47 -9
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +36 -31
- optimum/rbln/diffusers/models/controlnet.py +53 -43
- optimum/rbln/diffusers/models/unet_2d_condition.py +40 -31
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +28 -23
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +28 -23
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +28 -37
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +30 -39
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +24 -15
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +26 -17
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -17
- optimum/rbln/modeling_alias.py +6 -11
- optimum/rbln/modeling_base.py +467 -261
- optimum/rbln/modeling_config.py +199 -73
- optimum/rbln/transformers/__init__.py +43 -1
- optimum/rbln/transformers/models/__init__.py +23 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +95 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +203 -58
- optimum/rbln/transformers/models/bart/modeling_bart.py +125 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +101 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +127 -26
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +409 -150
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -8
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +662 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +6 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +198 -168
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +122 -47
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -12
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +172 -111
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +18 -16
- optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
- optimum/rbln/utils/import_utils.py +50 -1
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +33 -0
- optimum/rbln/utils/timer_utils.py +43 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +9 -7
- optimum_rbln-0.1.12.dist-info/RECORD +103 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.12.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.9.dist-info/RECORD +0 -78
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py
@@ -20,25 +20,29 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
+import functools
 import glob
-import
+import os
 from abc import ABC
+from dataclasses import dataclass
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import rebel  # noqa: F401
 import torch  # noqa: F401
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
-from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import no_init_weights
+from transformers.utils import ModelOutput

 from ....modeling_base import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME,
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
+from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from
+from ....utils.timer_utils import rbln_timer


-logger =
+logger = get_logger()

 if TYPE_CHECKING:
     from transformers import (
@@ -56,7 +60,46 @@ SUPPORTED_QUANTIZATIONS = {


 class RBLNRuntimeModel(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
+    mandatory_members = ["main_input_name", "embed_tokens"]
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        inputs_embeds: torch.Tensor,
+        attention_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        batch_position: torch.Tensor,
+        query_idx: torch.Tensor,
+        **kwargs,
+    ):
+        if inputs_embeds is None:
+            inp = input_ids
+            if self.embed_tokens is not None:
+                inp = self.embed_tokens(inp)
+
+            return super().forward(
+                inp,
+                attention_mask,
+                cache_position,
+                batch_position,
+                query_idx,
+                **kwargs,
+            )
+        else:
+            return super().forward(
+                inputs_embeds,
+                attention_mask,
+                cache_position,
+                batch_position,
+                query_idx,
+                **kwargs,
+            )
+
+
+@dataclass
+class RBLNDecoderOnlyOutput(ModelOutput):
+    logits: torch.FloatTensor = None
+    generate_idx: torch.Tensor = None


 class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
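The rewritten `RBLNRuntimeModel.forward` dispatches on whether `inputs_embeds` is supplied: token ids are embedded on the host via `embed_tokens` before reaching the compiled graph, while precomputed embeddings (e.g. a multimodal prefix from the new LLaVA-Next support) bypass the lookup, and `RBLNDecoderOnlyOutput` threads `generate_idx` through generation as a `ModelOutput`. A minimal standalone sketch of the same pattern; the `Runtime` stub and names below are illustrative, not part of the package:

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.utils import ModelOutput


@dataclass
class DecoderOnlyOutput(ModelOutput):
    # Mirrors RBLNDecoderOnlyOutput: generation state rides along with logits.
    logits: torch.FloatTensor = None
    generate_idx: torch.Tensor = None


class Runtime:
    """Stand-in for RBLNRuntimeModel: embeds on the host only when given token ids."""

    def __init__(self, embed_tokens: Optional[torch.nn.Embedding]):
        self.embed_tokens = embed_tokens

    def forward(self, input_ids=None, inputs_embeds=None):
        if inputs_embeds is None:
            inp = input_ids
            if self.embed_tokens is not None:
                inp = self.embed_tokens(inp)  # token ids -> embeddings on the host
        else:
            inp = inputs_embeds  # caller already embedded (e.g. a multimodal prefix)
        return inp  # a real runtime would hand `inp` to the compiled NPU graph here


runtime = Runtime(torch.nn.Embedding(32, 8))
out = DecoderOnlyOutput(
    logits=runtime.forward(input_ids=torch.tensor([[1, 2, 3]])),
    generate_idx=torch.tensor([[3]]),
)
print(out.logits.shape, out.generate_idx)
```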
@@ -74,18 +117,57 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
     auto_model_class = AutoModelForCausalLM

     def __post_init__(self, **kwargs):
-        self.batch_size = self.rbln_config.
-        self.max_seq_len = self.rbln_config.
-        self.prefill_chunk_size = self.rbln_config.
+        self.batch_size = self.rbln_config.model_cfg["batch_size"]
+        self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
+        self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]

-        self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.
+        self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
         self.causal_mask = 1 - torch.triu(
             torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
         )
-        self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.
-        self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.
-
-
+        self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.float32)
+        self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+
+        main_input_name = self.main_input_name
+        if self.rbln_config.model_cfg["use_inputs_embeds"]:
+            main_input_name = "inputs_embeds"
+            artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+            with no_init_weights():
+                self.embed_tokens = torch.nn.Embedding(
+                    self.config.vocab_size,
+                    self.config.hidden_size,
+                    self.config.pad_token_id,
+                )
+            self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+        else:
+            self.embed_tokens = None
+
+        self.prefill_decoder = RBLNRuntimeModel(
+            runtime=self.model[0], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+        )
+        self.decoder = RBLNRuntimeModel(
+            runtime=self.model[1], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+        )
+
+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: "PreTrainedModel",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNConfig,
+    ):
+        """
+        If you are unavoidably running on a CPU rather than an RBLN device,
+        store the torch tensor, weight, etc. in this function.
+        """
+        if rbln_config.model_cfg["use_inputs_embeds"]:
+            save_dict = {}
+            save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
+            torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+    def get_input_embeddings(self):
+        return self.embed_tokens

     @classmethod
     def get_quantized_model(
@@ -98,10 +180,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
-        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
-        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
+        from ...utils.rbln_quantization import update_layers_to_quantized
+
         kwargs = cls.update_kwargs(kwargs)

         config = AutoConfig.from_pretrained(
@@ -116,37 +198,45 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

         with no_init_weights():
             model = AutoModelForCausalLM.from_config(config)
-            replace_quantized_linear_layers(model)

-
-        for safetensor_file in glob.glob(f"{model_id}/*.safetensors"):
-            partial_state_dict = load_file(safetensor_file)
-            state_dict.update(partial_state_dict)
+        update_layers_to_quantized(model)

         n_layer = kwargs.get("num_hidden_layers", None)
-
-
-        for key in state_dict.keys():
-            parts = key.split(".")
-            if len(parts) > 2 and parts[2].isdigit():
-                layer_num = int(parts[2])
-                if layer_num >= n_layer:
-                    keys_to_delete.append(key)
-
-        for key in keys_to_delete:
-            del state_dict[key]
-
-        model.load_state_dict(state_dict)
+        cls._load_weights_directly_to_model(model, model_id, n_layer)
+
         return model

+    def _load_weights_directly_to_model(model, model_id, n_layer=None):
+        """
+        Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
+        """
+        model_params = dict(model.named_parameters(recurse=True))
+        model_buffers = dict(model.named_buffers(recurse=True))
+        safetensor_files = glob.glob(f"{model_id}/*.safetensors")
+
+        target_layers = list(range(n_layer)) if n_layer is not None else None
+
+        for safetensor_file in safetensor_files:
+            file_data = load_file(safetensor_file)
+            for key, value in file_data.items():
+                if target_layers is not None:
+                    parts = key.split(".")
+
+                    if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
+                        continue
+
+                if key in model_params:
+                    model_params[key].data.copy_(value)
+                elif key in model_buffers:
+                    model_buffers[key].data.copy_(value)
+
+        return 0
+
     @classmethod
-    def get_pytorch_model(
-
-
-        **kwargs,
-    ) -> "PreTrainedModel":
-        rbln_config_kwargs = kwargs.get("rbln_config_kwargs", {})
-        rbln_quantization = rbln_config_kwargs.get("rbln_quantization", None)
+    def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
+        rbln_kwargs = kwargs.get("rbln_kwargs", {})
+        rbln_quantization = rbln_kwargs.get("quantization", None)

         if rbln_quantization is not None and rbln_quantization["format"] == "rbln":
             model = cls.get_quantized_model(*args, **kwargs)
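`_load_weights_directly_to_model` replaces the old accumulate-then-`load_state_dict` flow: each safetensors shard is copied in place into already-allocated parameters and buffers, and keys whose layer index (`parts[2]` in names like `model.layers.<idx>.…`) falls outside `n_layer` are skipped. A standalone sketch of the same technique; the function name and signature here are mine, not the package's:

```python
import glob
from typing import Optional

import torch
from safetensors.torch import load_file


def load_shards_into(model: torch.nn.Module, model_dir: str, n_layer: Optional[int] = None) -> None:
    """Copy safetensors shards directly into existing params/buffers.

    Avoids materializing one big state_dict; when the model was built with
    fewer hidden layers than the checkpoint, weights for layers >= n_layer
    are simply skipped.
    """
    params = dict(model.named_parameters(recurse=True))
    buffers = dict(model.named_buffers(recurse=True))

    for shard in glob.glob(f"{model_dir}/*.safetensors"):
        for key, value in load_file(shard).items():
            parts = key.split(".")
            if n_layer is not None and len(parts) > 2 and parts[2].isdigit() and int(parts[2]) >= n_layer:
                continue  # weight belongs to a dropped layer
            if key in params:
                params[key].data.copy_(value)  # in-place, no intermediate state_dict
            elif key in buffers:
                buffers[key].data.copy_(value)
```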
@@ -155,18 +245,68 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

         return model

+    def validate_quantization_config(quantize_config):
+        if quantize_config is not None:
+            q_format = quantize_config.get("format")
+            q_precision = quantize_config.get("precision")
+
+            if q_format not in SUPPORTED_QUANTIZATIONS:
+                raise ValueError(
+                    f"Invalid quantization format: {q_format}. "
+                    f"Supported formats are: {list(SUPPORTED_QUANTIZATIONS.keys())}"
+                )
+
+            if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
+                raise ValueError(
+                    f"Invalid precision: {q_precision} for format: {q_format}. "
+                    f"Supported precisions are: {SUPPORTED_QUANTIZATIONS[q_format]}"
+                )
+
+        return quantize_config
+
+    @classmethod
+    def set_quantize_env(cls, quantize_config):
+        RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
+        quantize_config = cls.validate_quantization_config(quantize_config)
+        if quantize_config is not None:
+            q_precision = quantize_config.get("precision")
+            quant_bits = q_precision.split("w")[1].split("a")[0]
+            os.environ[RBLN_QUANT_BITS_ENV] = quant_bits
+            return RBLN_QUANT_BITS_ENV
+        return None
+
+    @classmethod
+    def reset_quantize_env(cls, env_var_name):
+        if env_var_name is not None and env_var_name in os.environ:
+            del os.environ[env_var_name]
+
+    @classmethod
+    def manage_quantize_env(cls, func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            quantize_config = kwargs.get("quantize_config")
+            quantize_env_var = cls.set_quantize_env(quantize_config)
+            try:
+                return func(*args, **kwargs)
+            finally:
+                cls.reset_quantize_env(quantize_env_var)
+
+        return wrapper
+
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

-
-
+        rbln_compile_configs = rbln_config.compile_cfgs
+        prefill_rbln_compile_config = rbln_compile_configs[0]
+        dec_rbln_compile_config = rbln_compile_configs[1]

+        @rbln_timer("JIT trace")
         def get_scripted_model():
             # This function is nested to dealloc the example inputs before compilation.
-            prefill_example_inputs =
-            dec_example_inputs =
+            prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
+            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=4)

             batch_index = 3
             dec_example_inputs[batch_index].fill_(-1)  # fill batch_position -1 to indicate it is decoder.
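`manage_quantize_env` brackets a call with the `RBLN_QUANT_BITS` environment variable, parsed out of the precision string (`"w4a16".split("w")[1].split("a")[0]` gives `"4"`), and the `try`/`finally` guarantees cleanup even if compilation fails. The same pattern in isolation, with generic names of my choosing:

```python
import functools
import os


def with_env(name: str, value: str):
    """Set an env var for the duration of a call, then clean up -- the same
    shape as manage_quantize_env around the compile step."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            os.environ[name] = value
            try:
                return func(*args, **kwargs)
            finally:
                os.environ.pop(name, None)  # always reset, even on failure

        return wrapper

    return decorator


@with_env("RBLN_QUANT_BITS", "4")  # precision "w4a16" -> bits string "4"
def compile_step():
    return os.environ["RBLN_QUANT_BITS"]


print(compile_step())                      # "4" while the call runs
print(os.environ.get("RBLN_QUANT_BITS"))   # None afterwards
```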
@@ -181,31 +321,48 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

         prefill_scripted_model, dec_scripted_model = get_scripted_model()

-
-
-
-
-
-
-
-
+        @rbln_timer("Model conversion")
+        def scripted_model_to_ir():
+            prefill_ir = rebel.torchscript_to_ir(
+                prefill_scripted_model,
+                input_names=[v[0] for v in prefill_rbln_compile_config.input_info],
+            )
+            dec_ir = rebel.torchscript_to_ir(
+                dec_scripted_model,
+                input_names=[v[0] for v in dec_rbln_compile_config.input_info],
+            )
+            return prefill_ir, dec_ir

+        prefill_ir, dec_ir = scripted_model_to_ir()
         # Caching prefill_decoder/decoder I/O
-        cache_index_offset =
+        cache_index_offset = 5
         connections = [
             (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
             for i in range(model.config.num_hidden_layers * 2)
         ]

-
+        # Extract quantize_config from rbln_config
+        quantize_config = rbln_config.model_cfg.get("quantization", None)
+
+        @cls.manage_quantize_env
+        def compile_model(*args, **kwargs):
+            # Remove quantize_config from kwargs
+            kwargs.pop("quantize_config", None)
+
+            # Call rebel.compile with the updated kwargs
+            return rebel.compile(*args, **kwargs)
+
+        compiled_model = compile_model(
             prefill_ir,
             dec_ir,
             connections=connections,
-            fusion=
-            npu=
-            tensor_parallel_size=
+            fusion=prefill_rbln_compile_config.fusion,
+            npu=prefill_rbln_compile_config.npu,
+            tensor_parallel_size=prefill_rbln_compile_config.tensor_parallel_size,
             use_weight_sharing=True,
+            quantize_config=quantize_config,
         )
+
         return compiled_model

     @classmethod
@@ -213,12 +370,14 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-
-        rbln_batch_size: Optional[int] = None,
-        rbln_quantization: Optional[Dict[str, str]] = None,
-        **kwargs,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_quantization = rbln_kwargs.get("quantization", None)
+        rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
+
+        rbln_quantization = cls.validate_quantization_config(rbln_quantization)

         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
@@ -228,40 +387,35 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         if rbln_max_seq_len is None:
             raise ValueError("`rbln_max_seq_len` should be specified.")
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-        meta["rbln_max_seq_len"] = rbln_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size
-        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
+        rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds

         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
         head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
-
-        if rbln_quantization is not None:
-            q_format = rbln_quantization.get("format", None)
-            q_precision = rbln_quantization.get("precision", None)
-
-            if q_format not in SUPPORTED_QUANTIZATIONS.keys() or q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
-                raise ValueError(
-                    f'rbln_quantization="{rbln_quantization}" is not a supported quantization format or precesion, '
-                    f"Possible: {SUPPORTED_QUANTIZATIONS}"
-                )
-            meta["rbln_quantization"] = rbln_quantization
+        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")

         def get_input_info(
             batch_size,
             query_length,
+            use_inputs_embeds,
+            hidden_size,
         ):
+            if use_inputs_embeds:
+                main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
+            else:
+                main_input = ("input_ids", [batch_size, query_length], "int64")
+
             input_info = [
-
-                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "
+                main_input,
+                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
                 (
                     "cache_position",
                     [batch_size, query_length],
                     "int32",
                 ),
                 ("batch_position", [], "int16"),
+                ("query_idx", [], "int16"),
             ]

             input_info.extend(
@@ -285,22 +439,37 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         prefill_input_info = get_input_info(
             batch_size=1,
             query_length=prefill_chunk_size,
+            use_inputs_embeds=rbln_use_inputs_embeds,
+            hidden_size=hidden_size,
         )
         dec_input_info = get_input_info(
             batch_size=rbln_batch_size,
             query_length=1,
+            use_inputs_embeds=rbln_use_inputs_embeds,
+            hidden_size=hidden_size,
         )

-
-
+        prefill_rbln_compile_config = RBLNCompileConfig(input_info=prefill_input_info)
+        dec_rbln_compile_config = RBLNCompileConfig(input_info=dec_input_info)

-
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[prefill_rbln_compile_config, dec_rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )

-        rbln_config
-
-
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "prefill_chunk_size": prefill_chunk_size,
+                "use_inputs_embeds": rbln_use_inputs_embeds,
+            }
         )

+        if rbln_quantization is not None:
+            rbln_config.model_cfg.update({"quantization": rbln_quantization})
+
         return rbln_config

     @classmethod
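With the old `meta` dict gone, user options arrive through `rbln_kwargs` (populated from `rbln_*` arguments such as `rbln_max_seq_len`, `rbln_batch_size`, and `rbln_use_inputs_embeds`) and are persisted in `rbln_config.model_cfg`. A worked example of the main-input switch inside `get_input_info`, assuming an illustrative `hidden_size` of 4096:

```python
def main_input(use_inputs_embeds: bool, batch_size: int, query_length: int, hidden_size: int):
    # Mirrors the branch added to get_input_info above.
    if use_inputs_embeds:
        return ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
    return ("input_ids", [batch_size, query_length], "int64")

print(main_input(False, 1, 128, 4096))  # prefill graph: ('input_ids', [1, 128], 'int64')
print(main_input(True, 1, 128, 4096))   # prefill graph: ('inputs_embeds', [1, 128, 4096], 'float32')
print(main_input(True, 4, 1, 4096))     # decode graph, batch 4: ('inputs_embeds', [4, 1, 4096], 'float32')
```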
@@ -322,71 +491,112 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
     def _reorder_cache(self, past_key_values, beam_idx):
         raise NotImplementedError

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            valid_len = input_id.shape[-1]
-            cache_position = torch.arange(0, valid_len, dtype=torch.int32)
-            past_cached_length[i] = valid_len
-            l_input_ids.append(input_id.unsqueeze(0))
-            cache_positions.append(cache_position.unsqueeze(0))
-
-        input_ids = l_input_ids
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        generate_idx: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        model_inputs = {}
+        is_prefill_phase = generate_idx is None
+
+        if is_prefill_phase:
+            generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+            cache_position = None
         else:
-
-
-            past_cached_length = past_cached_length + 1
+            if inputs_embeds is not None:
+                raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")

-
-
-
-            "
-
+            input_ids = input_ids[:, -1:]
+            cache_position = generate_idx
+            generate_idx = generate_idx + 1
+            model_inputs.update({"input_ids": input_ids})
+
+        if inputs_embeds is not None:
+            if self.rbln_config.model_cfg["use_inputs_embeds"]:
+                model_inputs.update({"inputs_embeds": inputs_embeds})
+            else:
+                raise ValueError(
+                    "The specifying inputs_embedst is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
+                )
+        else:
+            model_inputs.update({"input_ids": input_ids})
+
+        model_inputs.update(
+            {
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "generate_idx": generate_idx,
+            }
+        )

         return model_inputs

+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: RBLNDecoderOnlyOutput,
+        model_kwargs: Dict[str, Any],
+        **kwargs,
+    ) -> Dict[str, Any]:
+        # update generate_idx
+        model_kwargs["generate_idx"] = outputs.generate_idx
+
+        return model_kwargs
+
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
-
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        generate_idx: Optional[torch.Tensor] = None,
+        # from llava_next forward args
         batch_idx: Optional[int] = None,
-        past_cached_length: Optional[torch.Tensor] = None,  # past_cached_length
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
-        # prefll
-        if
+        # prefll
+        if cache_position is None:
             logits = []
-
-
+            input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
+            batch_size = input_tensors.shape[0]
+
+            for b_idx in range(batch_size):
+                # Transform inputs as vllm format
+                if attention_mask is not None:
+                    input_tensor = input_tensors[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                else:
+                    input_tensor = input_tensors[b_idx : b_idx + 1]
+
+                cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+
+                logit = self._forward_prefill(
+                    input_ids=input_tensor if inputs_embeds is None else None,
+                    inputs_embeds=input_tensor if inputs_embeds is not None else None,
+                    cache_position=cache_position,
+                    batch_idx=b_idx if batch_idx is None else batch_idx,  # Llava-next prefill
+                )
                 logits.append(logit)
             logits = torch.cat(logits, dim=0)
-        #
-        elif cache_position.shape[-1] > 1:
-            logits = self._forward_prefill(input_ids=input_ids, cache_position=cache_position, batch_idx=batch_idx)
-        # common decoder
+        # decoder
         else:
-            logits = self._forward_decoder(
+            logits = self._forward_decoder(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+            )

-        return
+        return RBLNDecoderOnlyOutput(
             logits=logits,
-
+            generate_idx=generate_idx,
         )

     def _forward_prefill(
         self,
         input_ids: torch.LongTensor = None,
-
+        inputs_embeds: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
         batch_idx: int = None,
     ) -> torch.FloatTensor:
         if batch_idx is None or batch_idx >= self.batch_size:
@@ -398,7 +608,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             torch.empty(
                 size=[
                     1,
-
+                    1,
                     self.config.vocab_size,
                 ],
                 dtype=torch.float32,
@@ -407,11 +617,19 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             torch.empty(size=[], dtype=torch.int16, device="cpu"),
         ]

-
-
+        input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
+        query_length = input_tensors.shape[1]
+        _attention_mask = self.prefill_attention_mask.clone()
+
         for step in range(0, query_length, self.prefill_chunk_size):
-
-
+            # pad input_tensors & cache_position for prefill_chunk
+            if (step + self.prefill_chunk_size) > query_length:
+                pad_to_chunk = step + self.prefill_chunk_size - query_length
+                if inputs_embeds is not None:
+                    input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, pad_to_chunk))
+                else:
+                    input_tensors = torch.nn.functional.pad(input_tensors, (0, pad_to_chunk))
+
             cache_position = torch.cat(
                 [
                     cache_position,
@@ -424,41 +642,82 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                 dim=-1,
             )

-
-
+            # slice input_tensor & cache_position with prefill_chunk_size
+            _input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
+            _cache_position = cache_position[:, step : step + self.prefill_chunk_size]

+            # update attention_mask
             if step >= self.prefill_chunk_size:
-
-
+                _attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
+            _attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
+
+            query_idx = (query_length - 1) % self.prefill_chunk_size

             logits, _ = self.prefill_decoder(
-
-
-
-
+                input_ids=_input_tensors.contiguous() if inputs_embeds is None else None,
+                inputs_embeds=_input_tensors.contiguous() if inputs_embeds is not None else None,
+                attention_mask=_attention_mask.contiguous(),
+                cache_position=_cache_position.contiguous(),
+                batch_position=torch.tensor(batch_idx, dtype=torch.int16),
+                query_idx=torch.tensor(query_idx, dtype=torch.int16),
                 out=out_buffers,
             )
-        logits = logits[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)

+        # update decoder_attn_mask with preprocessed kv-cache length in prefill phase
         self.dec_attn_mask[batch_idx] = self.dec_attn_mask_init.clone()
         self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

         return logits

     def _forward_decoder(
-        self,
+        self,
+        input_ids: torch.LongTensor = None,
+        inputs_embeds: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
     ) -> torch.FloatTensor:
-
+        input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
+
+        batch_size = input_tensors.shape[0]

         for b_idx in range(batch_size):
             decoding_step = cache_position[b_idx].item()
             self.dec_attn_mask[b_idx, :, :, decoding_step] = 1

         logits, _ = self.decoder(
-            input_ids.contiguous(),
-
-
-
+            input_ids=input_tensors.contiguous() if inputs_embeds is None else None,
+            inputs_embeds=input_tensors.contiguous() if inputs_embeds is not None else None,
+            attention_mask=self.dec_attn_mask.contiguous(),
+            cache_position=cache_position.contiguous(),
+            batch_position=torch.tensor(0, dtype=torch.int16),
+            query_idx=torch.tensor(0, dtype=torch.int16),
         )

         return logits
+
+    def vllm_forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        inputs_embeds: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
+        batch_idx: Optional[int] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        # prefll
+        if cache_position.shape[-1] > 1:
+            logits = self._forward_prefill(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+                batch_idx=batch_idx,
+            )
+        # decoder
+        else:
+            logits = self._forward_decoder(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+            )
+
+        return RBLNDecoderOnlyOutput(
+            logits=logits,
+        )