optimum-rbln 0.1.9__py3-none-any.whl → 0.1.11__py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- optimum/rbln/__init__.py +37 -2
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +36 -29
- optimum/rbln/diffusers/models/controlnet.py +56 -40
- optimum/rbln/diffusers/models/unet_2d_condition.py +40 -28
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +22 -15
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +23 -17
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +24 -18
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +22 -11
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +24 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +24 -14
- optimum/rbln/modeling_alias.py +3 -3
- optimum/rbln/modeling_base.py +471 -231
- optimum/rbln/modeling_config.py +152 -77
- optimum/rbln/modeling_seq2seq.py +166 -77
- optimum/rbln/transformers/__init__.py +35 -1
- optimum/rbln/transformers/models/__init__.py +20 -1
- optimum/rbln/transformers/models/auto/__init__.py +14 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +84 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +94 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +189 -50
- optimum/rbln/transformers/models/bart/modeling_bart.py +106 -0
- optimum/rbln/transformers/models/bert/__init__.py +24 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +102 -0
- optimum/rbln/transformers/models/clip/__init__.py +1 -1
- optimum/rbln/transformers/models/clip/modeling_clip.py +127 -25
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +28 -4
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +302 -115
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +21 -7
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +1 -1
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +1 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +1 -1
- optimum/rbln/transformers/models/llava_next/__init__.py +24 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +666 -0
- optimum/rbln/transformers/models/midm/midm_architecture.py +5 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +1 -1
- optimum/rbln/transformers/models/phi/__init__.py +24 -0
- optimum/rbln/transformers/models/phi/modeling_phi.py +69 -0
- optimum/rbln/transformers/models/phi/phi_architecture.py +406 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +92 -31
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +17 -11
- optimum/rbln/transformers/models/whisper/generation_whisper.py +68 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +141 -105
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +44 -17
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +17 -14
- optimum/rbln/transformers/utils/rbln_quantization.py +48 -60
- optimum/rbln/utils/import_utils.py +36 -1
- optimum/rbln/utils/logging.py +82 -0
- optimum/rbln/utils/runtime_utils.py +33 -0
- optimum/rbln/utils/timer_utils.py +19 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/METADATA +8 -7
- optimum_rbln-0.1.11.dist-info/RECORD +93 -0
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/WHEEL +1 -1
- optimum_rbln-0.1.11.dist-info/entry_points.txt +4 -0
- optimum_rbln-0.1.9.dist-info/RECORD +0 -78
- {optimum_rbln-0.1.9.dist-info → optimum_rbln-0.1.11.dist-info}/licenses/LICENSE +0 -0
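The hunks reproduced below appear to come from `optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py` (judging by the class names and the +302 -115 entry above); the headline changes are an optional `inputs_embeds` input path, the new `RBLNCompileConfig`/`RBLNConfig` plumbing, and a `query_idx` input for the compiled prefill/decode graphs. As a minimal usage sketch only, not taken from this diff: it assumes the optimum-style `from_pretrained(..., export=True)` entry point and that per-model `rbln_*` keyword arguments are collected into the `rbln_kwargs` dict that `_get_rbln_config` reads in the hunks below.

```python
# Hedged sketch, not part of the package diff. Assumes RBLNLlamaForCausalLM is exported at
# the top level of optimum.rbln and that rbln_* kwargs map onto the rbln_kwargs keys read
# by _get_rbln_config ("max_seq_len", "batch_size", "use_inputs_embeds").
from optimum.rbln import RBLNLlamaForCausalLM

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    export=True,                  # compile for the RBLN NPU instead of loading a prebuilt artifact
    rbln_max_seq_len=4096,        # -> rbln_kwargs["max_seq_len"]
    rbln_batch_size=1,            # -> rbln_kwargs["batch_size"]
    rbln_use_inputs_embeds=True,  # -> rbln_kwargs["use_inputs_embeds"], new in 0.1.11
)
model.save_pretrained("llama-2-7b-rbln")
```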
@@ -23,19 +23,21 @@
 import glob
 import logging
 from abc import ABC
+from dataclasses import dataclass
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import rebel # noqa: F401
 import torch # noqa: F401
 from safetensors.torch import load_file
 from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
-from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.modeling_utils import no_init_weights
+from transformers.utils import ModelOutput

 from ....modeling_base import RBLNModel
-from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME,
+from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from
+from ....utils.timer_utils import rbln_timer


 logger = logging.getLogger(__name__)
@@ -56,7 +58,46 @@ SUPPORTED_QUANTIZATIONS = {


 class RBLNRuntimeModel(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name"]
+    mandatory_members = ["main_input_name", "embed_tokens"]
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        inputs_embeds: torch.Tensor,
+        attention_mask: torch.Tensor,
+        cache_position: torch.Tensor,
+        batch_position: torch.Tensor,
+        query_idx: torch.Tensor,
+        **kwargs,
+    ):
+        if inputs_embeds is None:
+            inp = input_ids
+            if self.embed_tokens is not None:
+                inp = self.embed_tokens(inp)
+
+            return super().forward(
+                inp,
+                attention_mask,
+                cache_position,
+                batch_position,
+                query_idx,
+                **kwargs,
+            )
+        else:
+            return super().forward(
+                inputs_embeds,
+                attention_mask,
+                cache_position,
+                batch_position,
+                query_idx,
+                **kwargs,
+            )
+
+
+@dataclass
+class RBLNDecoderOnlyOutput(ModelOutput):
+    logits: torch.FloatTensor = None
+    past_cached_length: Union[int, torch.Tensor] = None


 class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
@@ -74,18 +115,57 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
     auto_model_class = AutoModelForCausalLM

     def __post_init__(self, **kwargs):
-        self.batch_size = self.rbln_config.
-        self.max_seq_len = self.rbln_config.
-        self.prefill_chunk_size = self.rbln_config.
+        self.batch_size = self.rbln_config.model_cfg["batch_size"]
+        self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
+        self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]

-        self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.
+        self.prefill_attention_mask = torch.zeros(1, 1, self.prefill_chunk_size, self.max_seq_len, dtype=torch.float32)
         self.causal_mask = 1 - torch.triu(
             torch.ones(1, 1, self.prefill_chunk_size, self.prefill_chunk_size), diagonal=1
         )
-        self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.
-        self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.
-
-
+        self.dec_attn_mask_init = torch.zeros(1, 1, 1, self.max_seq_len, dtype=torch.float32)
+        self.dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+
+        main_input_name = self.main_input_name
+        if self.rbln_config.model_cfg["use_inputs_embeds"]:
+            main_input_name = "inputs_embeds"
+            artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+            with no_init_weights():
+                self.embed_tokens = torch.nn.Embedding(
+                    self.config.vocab_size,
+                    self.config.hidden_size,
+                    self.config.pad_token_id,
+                )
+            self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+        else:
+            self.embed_tokens = None
+
+        self.prefill_decoder = RBLNRuntimeModel(
+            runtime=self.model[0], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+        )
+        self.decoder = RBLNRuntimeModel(
+            runtime=self.model[1], main_input_name=main_input_name, embed_tokens=self.embed_tokens
+        )
+
+    @classmethod
+    def save_torch_artifacts(
+        cls,
+        model: "PreTrainedModel",
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNConfig,
+    ):
+        """
+        If you are unavoidably running on a CPU rather than an RBLN device,
+        store the torch tensor, weight, etc. in this function.
+        """
+        if rbln_config.model_cfg["use_inputs_embeds"]:
+            save_dict = {}
+            save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
+            torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+    def get_input_embeddings(self):
+        return self.embed_tokens

     @classmethod
     def get_quantized_model(
@@ -98,10 +178,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
-        rbln_config_kwargs: Optional[Dict[str, Any]] = None,
-        rbln_constructor_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ):
+        from ...utils.rbln_quantization import update_layers_to_quantized
+
         kwargs = cls.update_kwargs(kwargs)

         config = AutoConfig.from_pretrained(
@@ -116,37 +196,45 @@
 
         with no_init_weights():
             model = AutoModelForCausalLM.from_config(config)
-        replace_quantized_linear_layers(model)

-
-        for safetensor_file in glob.glob(f"{model_id}/*.safetensors"):
-            partial_state_dict = load_file(safetensor_file)
-            state_dict.update(partial_state_dict)
+        update_layers_to_quantized(model)

         n_layer = kwargs.get("num_hidden_layers", None)
-
-
-        for key in state_dict.keys():
-            parts = key.split(".")
-            if len(parts) > 2 and parts[2].isdigit():
-                layer_num = int(parts[2])
-                if layer_num >= n_layer:
-                    keys_to_delete.append(key)
-
-        for key in keys_to_delete:
-            del state_dict[key]
-
-        model.load_state_dict(state_dict)
+        cls._load_weights_directly_to_model(model, model_id, n_layer)
+
         return model

+    def _load_weights_directly_to_model(model, model_id, n_layer=None):
+        """
+        Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
+        """
+
+        model_params = dict(model.named_parameters(recurse=True))
+        model_buffers = dict(model.named_buffers(recurse=True))
+        safetensor_files = glob.glob(f"{model_id}/*.safetensors")
+
+        target_layers = list(range(n_layer)) if n_layer is not None else None
+
+        for safetensor_file in safetensor_files:
+            file_data = load_file(safetensor_file)
+            for key, value in file_data.items():
+                if target_layers is not None:
+                    parts = key.split(".")
+
+                    if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
+                        continue
+
+                if key in model_params:
+                    model_params[key].data.copy_(value)
+                elif key in model_buffers:
+                    model_buffers[key].data.copy_(value)
+
+        return 0
+
     @classmethod
-    def get_pytorch_model(
-
-
-        **kwargs,
-    ) -> "PreTrainedModel":
-        rbln_config_kwargs = kwargs.get("rbln_config_kwargs", {})
-        rbln_quantization = rbln_config_kwargs.get("rbln_quantization", None)
+    def get_pytorch_model(cls, *args, **kwargs) -> "PreTrainedModel":
+        rbln_kwargs = kwargs.get("rbln_kwargs", {})
+        rbln_quantization = rbln_kwargs.get("quantization", None)

         if rbln_quantization is not None and rbln_quantization["format"] == "rbln":
             model = cls.get_quantized_model(*args, **kwargs)
@@ -160,13 +248,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
     def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

-
-
+        rbln_compile_configs = rbln_config.compile_cfgs
+        prefill_rbln_compile_config = rbln_compile_configs[0]
+        dec_rbln_compile_config = rbln_compile_configs[1]

+        @rbln_timer("Jit Trace")
         def get_scripted_model():
             # This function is nested to dealloc the example inputs before compilation.
-            prefill_example_inputs =
-            dec_example_inputs =
+            prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
+            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=4)

             batch_index = 3
             dec_example_inputs[batch_index].fill_(-1) # fill batch_position -1 to indicate it is decoder.
@@ -181,17 +271,21 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):

         prefill_scripted_model, dec_scripted_model = get_scripted_model()

-
-
-
-
-
-
-
-
+        @rbln_timer("TorchScript to IR")
+        def scripted_model_to_ir():
+            prefill_ir = rebel.torchscript_to_ir(
+                prefill_scripted_model,
+                input_names=[v[0] for v in prefill_rbln_compile_config.input_info],
+            )
+            dec_ir = rebel.torchscript_to_ir(
+                dec_scripted_model,
+                input_names=[v[0] for v in dec_rbln_compile_config.input_info],
+            )
+            return prefill_ir, dec_ir

+        prefill_ir, dec_ir = scripted_model_to_ir()
         # Caching prefill_decoder/decoder I/O
-        cache_index_offset =
+        cache_index_offset = 5
         connections = [
             (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
             for i in range(model.config.num_hidden_layers * 2)
@@ -201,9 +295,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             prefill_ir,
             dec_ir,
             connections=connections,
-            fusion=
-            npu=
-            tensor_parallel_size=
+            fusion=prefill_rbln_compile_config.fusion,
+            npu=prefill_rbln_compile_config.npu,
+            tensor_parallel_size=prefill_rbln_compile_config.tensor_parallel_size,
             use_weight_sharing=True,
         )
         return compiled_model
@@ -213,12 +307,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         cls,
         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
         model_config: "PretrainedConfig",
-
-        rbln_batch_size: Optional[int] = None,
-        rbln_quantization: Optional[Dict[str, str]] = None,
-        **kwargs,
+        rbln_kwargs: Dict[str, Any] = {},
     ) -> RBLNConfig:
-
+        rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
+        rbln_batch_size = rbln_kwargs.get("batch_size", None)
+        rbln_quantization = rbln_kwargs.get("quantization", None)
+        rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)

         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
@@ -228,15 +322,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         if rbln_max_seq_len is None:
             raise ValueError("`rbln_max_seq_len` should be specified.")
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
-
-        meta["rbln_max_seq_len"] = rbln_max_seq_len
-        meta["rbln_batch_size"] = rbln_batch_size
-        meta["rbln_prefill_chunk_size"] = prefill_chunk_size
+        rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds

         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
         head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
+        hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")

         if rbln_quantization is not None:
             q_format = rbln_quantization.get("format", None)
@@ -247,21 +339,28 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                    f'rbln_quantization="{rbln_quantization}" is not a supported quantization format or precesion, '
                    f"Possible: {SUPPORTED_QUANTIZATIONS}"
                )
-            meta["rbln_quantization"] = rbln_quantization

         def get_input_info(
             batch_size,
             query_length,
+            use_inputs_embeds,
+            hidden_size,
         ):
+            if use_inputs_embeds:
+                main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
+            else:
+                main_input = ("input_ids", [batch_size, query_length], "int64")
+
             input_info = [
-
-                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "
+                main_input,
+                ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
                 (
                     "cache_position",
                     [batch_size, query_length],
                     "int32",
                 ),
                 ("batch_position", [], "int16"),
+                ("query_idx", [], "int16"),
             ]

             input_info.extend(
@@ -285,22 +384,37 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         prefill_input_info = get_input_info(
             batch_size=1,
             query_length=prefill_chunk_size,
+            use_inputs_embeds=rbln_use_inputs_embeds,
+            hidden_size=hidden_size,
         )
         dec_input_info = get_input_info(
             batch_size=rbln_batch_size,
             query_length=1,
+            use_inputs_embeds=rbln_use_inputs_embeds,
+            hidden_size=hidden_size,
         )

-
-
+        prefill_rbln_compile_config = RBLNCompileConfig(input_info=prefill_input_info)
+        dec_rbln_compile_config = RBLNCompileConfig(input_info=dec_input_info)

-
+        rbln_config = RBLNConfig(
+            rbln_cls=cls.__name__,
+            compile_cfgs=[prefill_rbln_compile_config, dec_rbln_compile_config],
+            rbln_kwargs=rbln_kwargs,
+        )

-        rbln_config
-
-
+        rbln_config.model_cfg.update(
+            {
+                "max_seq_len": rbln_max_seq_len,
+                "batch_size": rbln_batch_size,
+                "prefill_chunk_size": prefill_chunk_size,
+                "use_inputs_embeds": rbln_use_inputs_embeds,
+            }
         )

+        if rbln_quantization is not None:
+            rbln_config.model_cfg.update({"quantization": rbln_quantization})
+
         return rbln_config

     @classmethod
@@ -322,71 +436,117 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
     def _reorder_cache(self, past_key_values, beam_idx):
         raise NotImplementedError

-
-
-
-
-
-
-
+    def prepare_inputs_for_generation(
+        self,
+        input_ids: torch.LongTensor,
+        past_cached_length: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        model_inputs = {}
+        # prefill phase
         if past_cached_length is None:
-
+            # huggingface make dummy_input_ids if model_input_name is "input_embeds"
+            # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L469
+            if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+                input_tensors = inputs_embeds
+            else:
+                input_tensors = input_ids
+
+            batch_size = input_tensors.shape[0]
+            l_input_tensors = []
             cache_positions = []
             past_cached_length = torch.zeros((batch_size, 1), dtype=torch.int32)
             for i in range(batch_size):
-
-
-                valid_len =
+                input_tensor = input_tensors[i]
+                input_tensor = input_tensor[attention_mask[i] == 1]
+                valid_len = input_tensor.shape[0]
                 cache_position = torch.arange(0, valid_len, dtype=torch.int32)
                 past_cached_length[i] = valid_len
-
+                l_input_tensors.append(input_tensor.unsqueeze(0))
                 cache_positions.append(cache_position.unsqueeze(0))

-
+            input_tensors = l_input_tensors
+            if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+                model_inputs.update({"inputs_embeds": input_tensors, "input_ids": input_ids})
+            else:
+                model_inputs.update({"input_ids": input_tensors, "inputs_embeds": inputs_embeds})
+        # decoder phase
         else:
             input_ids = input_ids[:, -1:]
             cache_positions = past_cached_length
             past_cached_length = past_cached_length + 1
+            model_inputs.update({"input_ids": input_ids})

-        model_inputs
-
-
-
+        model_inputs.update(
+            {
+                "cache_position": cache_positions,
+                "past_cached_length": past_cached_length,
+            }
+        )

         return model_inputs

+    def _update_model_kwargs_for_generation(
+        self,
+        outputs: RBLNDecoderOnlyOutput,
+        model_kwargs: Dict[str, Any],
+        **kwargs,
+    ) -> Dict[str, Any]:
+        # update past_cached_length
+        model_kwargs["past_cached_length"] = outputs.past_cached_length
+
+        return model_kwargs
+
     def forward(
         self,
-        input_ids: torch.LongTensor = None,
+        input_ids: Optional[Union[List[torch.LongTensor], torch.LongTensor]] = None,
+        inputs_embeds: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
         cache_position: Union[List[torch.Tensor], torch.Tensor] = None, # vllm keyword argument
         batch_idx: Optional[int] = None,
-        past_cached_length: Optional[torch.Tensor] = None,
+        past_cached_length: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # prefll & hf generate
         if isinstance(cache_position, list):
             logits = []
-
-
+            input_tensors = input_ids if inputs_embeds is None else inputs_embeds
+            for batch_idx, (input_tensor, cache_pos) in enumerate(zip(input_tensors, cache_position)):
+                logit = self._forward_prefill(
+                    input_ids=input_tensor if inputs_embeds is None else None,
+                    inputs_embeds=input_tensor if inputs_embeds is not None else None,
+                    cache_position=cache_pos,
+                    batch_idx=batch_idx,
+                )
                 logits.append(logit)
             logits = torch.cat(logits, dim=0)
         # prefill & vllm step
         elif cache_position.shape[-1] > 1:
-            logits = self._forward_prefill(
+            logits = self._forward_prefill(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+                batch_idx=batch_idx,
+            )
         # common decoder
         else:
-            logits = self._forward_decoder(
+            logits = self._forward_decoder(
+                input_ids=input_ids,
+                inputs_embeds=inputs_embeds,
+                cache_position=cache_position,
+            )

-        return
+        return RBLNDecoderOnlyOutput(
             logits=logits,
-
+            past_cached_length=past_cached_length,
         )

     def _forward_prefill(
         self,
         input_ids: torch.LongTensor = None,
-
+        inputs_embeds: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
         batch_idx: int = None,
     ) -> torch.FloatTensor:
         if batch_idx is None or batch_idx >= self.batch_size:
@@ -398,7 +558,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             torch.empty(
                 size=[
                     1,
-
+                    1,
                     self.config.vocab_size,
                 ],
                 dtype=torch.float32,
@@ -407,11 +567,24 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
             torch.empty(size=[], dtype=torch.int16, device="cpu"),
         ]

-
+        if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+            model_input_name = "inputs_embeds"
+        else:
+            model_input_name = "input_ids"
+
+        input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
+
+        query_length = input_tensors.shape[1]
         attention_mask = self.prefill_attention_mask.clone()
         for step in range(0, query_length, self.prefill_chunk_size):
             if step + self.prefill_chunk_size > query_length:
-
+                # input_tensors = torch.nn.functional.pad(input_tensors, (0, step + self.prefill_chunk_size - query_length))
+                padding_needed = step + self.prefill_chunk_size - query_length
+                if model_input_name == "input_ids":
+                    input_tensors = torch.nn.functional.pad(input_tensors, (0, padding_needed))
+                else:
+                    input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, padding_needed))
+
                 cache_position = torch.cat(
                     [
                         cache_position,
@@ -424,21 +597,24 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
                     dim=-1,
                 )

-
+            sliced_input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
             sliced_cache_positions = cache_position[:, step : step + self.prefill_chunk_size]

             if step >= self.prefill_chunk_size:
                 attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
             attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask

+            query_idx = query_length % self.prefill_chunk_size - 1
+
             logits, _ = self.prefill_decoder(
-
-
-
-
+                input_ids=sliced_input_tensors.contiguous() if model_input_name == "input_ids" else None,
+                inputs_embeds=sliced_input_tensors.contiguous() if model_input_name == "inputs_embeds" else None,
+                attention_mask=attention_mask.contiguous(),
+                cache_position=sliced_cache_positions.contiguous(),
+                batch_position=torch.tensor(batch_idx, dtype=torch.int16),
+                query_idx=torch.tensor(query_idx, dtype=torch.int16),
                 out=out_buffers,
             )
-            logits = logits[:, query_length % self.prefill_chunk_size - 1].unsqueeze(1)

         self.dec_attn_mask[batch_idx] = self.dec_attn_mask_init.clone()
         self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
@@ -446,19 +622,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
         return logits

     def _forward_decoder(
-        self,
+        self,
+        input_ids: torch.LongTensor = None,
+        inputs_embeds: torch.Tensor = None,
+        cache_position: torch.Tensor = None,
     ) -> torch.FloatTensor:
-
+        if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
+            model_input_name = "inputs_embeds"
+        else:
+            model_input_name = "input_ids"
+        input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
+
+        batch_size = input_tensors.shape[0]

         for b_idx in range(batch_size):
             decoding_step = cache_position[b_idx].item()
             self.dec_attn_mask[b_idx, :, :, decoding_step] = 1

         logits, _ = self.decoder(
-            input_ids.contiguous(),
-
-
-
+            input_ids=input_tensors.contiguous() if model_input_name == "input_ids" else None,
+            inputs_embeds=input_tensors.contiguous() if model_input_name == "inputs_embeds" else None,
+            attention_mask=self.dec_attn_mask.contiguous(),
+            cache_position=cache_position.contiguous(),
+            batch_position=torch.tensor(0, dtype=torch.int16),
+            query_idx=torch.tensor(0, dtype=torch.int16),
         )

         return logits
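A standalone illustration, not part of the package: the prefill path above walks the prompt in fixed 128-token chunks, right-pads the last chunk, and keeps only the logit at `query_idx`, the last real token of the final chunk. The toy loop below re-states that arithmetic on the CPU; the call into the compiled RBLN prefill decoder is replaced by a comment.

```python
# CPU-side re-statement of the chunking in _forward_prefill; tensor names mirror the diff.
import torch

prefill_chunk_size = 128
input_ids = torch.randint(0, 32000, (1, 300))  # toy prompt of 300 tokens
query_length = input_ids.shape[1]

for step in range(0, query_length, prefill_chunk_size):
    if step + prefill_chunk_size > query_length:
        # last chunk: right-pad so every chunk fed to the NPU is exactly 128 wide
        padding_needed = step + prefill_chunk_size - query_length
        input_ids = torch.nn.functional.pad(input_ids, (0, padding_needed))
    chunk = input_ids[:, step : step + prefill_chunk_size]
    # the compiled prefill decoder would consume `chunk` here

# only the logit at the last real token of the final chunk is kept
query_idx = query_length % prefill_chunk_size - 1
print(query_idx)  # 300 % 128 - 1 == 43
```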