optimum-rbln 0.1.13__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +41 -38
- optimum/rbln/__version__.py +16 -1
- optimum/rbln/diffusers/__init__.py +26 -2
- optimum/rbln/{modeling_diffusers.py → diffusers/modeling_diffusers.py} +97 -126
- optimum/rbln/diffusers/models/__init__.py +36 -3
- optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
- optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +73 -61
- optimum/rbln/diffusers/models/autoencoders/vae.py +83 -0
- optimum/rbln/diffusers/models/controlnet.py +54 -14
- optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
- optimum/rbln/diffusers/models/unets/__init__.py +24 -0
- optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +82 -22
- optimum/rbln/diffusers/pipelines/__init__.py +23 -2
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +13 -33
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +17 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -2
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +18 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -13
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +24 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +15 -8
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
- optimum/rbln/modeling.py +238 -0
- optimum/rbln/modeling_base.py +186 -760
- optimum/rbln/modeling_config.py +31 -7
- optimum/rbln/ops/__init__.py +26 -0
- optimum/rbln/ops/attn.py +221 -0
- optimum/rbln/ops/flash_attn.py +70 -0
- optimum/rbln/ops/kv_cache_update.py +69 -0
- optimum/rbln/transformers/__init__.py +20 -2
- optimum/rbln/{modeling_alias.py → transformers/modeling_alias.py} +5 -1
- optimum/rbln/transformers/modeling_generic.py +385 -0
- optimum/rbln/transformers/models/auto/__init__.py +23 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
- optimum/rbln/transformers/models/auto/modeling_auto.py +36 -12
- optimum/rbln/transformers/models/bart/__init__.py +0 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +107 -464
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -9
- optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
- optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
- optimum/rbln/transformers/models/decoderonly/__init__.py +0 -10
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +775 -514
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +128 -260
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +60 -45
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
- optimum/rbln/transformers/models/llama/llama_architecture.py +0 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -75
- optimum/rbln/transformers/models/midm/midm_architecture.py +84 -238
- optimum/rbln/transformers/models/midm/modeling_midm.py +5 -6
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +0 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +60 -261
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +0 -1
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -103
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +498 -0
- optimum/rbln/transformers/models/t5/__init__.py +0 -1
- optimum/rbln/transformers/models/t5/modeling_t5.py +106 -5
- optimum/rbln/transformers/models/t5/t5_architecture.py +106 -448
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
- optimum/rbln/transformers/models/whisper/generation_whisper.py +42 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +78 -55
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +219 -312
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
- optimum/rbln/transformers/utils/rbln_quantization.py +120 -4
- optimum/rbln/utils/decorator_utils.py +51 -11
- optimum/rbln/utils/hub.py +131 -0
- optimum/rbln/utils/import_utils.py +22 -1
- optimum/rbln/utils/logging.py +37 -0
- optimum/rbln/utils/model_utils.py +52 -0
- optimum/rbln/utils/runtime_utils.py +10 -4
- optimum/rbln/utils/save_utils.py +17 -0
- optimum/rbln/utils/submodule.py +137 -0
- optimum_rbln-0.2.0.dist-info/METADATA +117 -0
- optimum_rbln-0.2.0.dist-info/RECORD +114 -0
- {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.2.0.dist-info}/WHEEL +1 -1
- optimum_rbln-0.2.0.dist-info/licenses/LICENSE +288 -0
- optimum/rbln/transformers/cache_utils.py +0 -107
- optimum/rbln/transformers/generation/streamers.py +0 -139
- optimum/rbln/transformers/generation/utils.py +0 -397
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
- optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
- optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
- optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
- optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
- optimum/rbln/utils/context.py +0 -58
- optimum/rbln/utils/timer_utils.py +0 -43
- optimum_rbln-0.1.13.dist-info/METADATA +0 -120
- optimum_rbln-0.1.13.dist-info/RECORD +0 -107
- optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
- optimum_rbln-0.1.13.dist-info/licenses/LICENSE +0 -201
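For downstream code, the most consequential move above is that RBLNModel left optimum/rbln/modeling_base.py for the new optimum/rbln/modeling.py, and the custom attention/KV-cache kernels now live in a dedicated optimum/rbln/ops/ package. A minimal migration sketch, assuming a consumer that imported the class by its internal path (imports via the top-level optimum.rbln namespace are unaffected):

# Before (0.1.13), hypothetical deep import:
# from optimum.rbln.modeling_base import RBLNModel

# After (0.2.0), the class is defined in the new modeling module:
from optimum.rbln.modeling import RBLNModel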
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -20,45 +20,34 @@
 # are the intellectual property of Rebellions Inc. and may not be
 # copied, modified, or distributed without prior written permission
 # from Rebellions Inc.
-
-import glob
+
 import inspect
-import os
 from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import rebel
 import torch
-import transformers
-from safetensors.torch import load_file
+from rebel.compile_context import CompileContext
 from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_utils import no_init_weights
 from transformers.utils import ModelOutput

-from ....modeling_base import RBLNModel
-from ....modeling_config import
+from ....modeling import RBLNModel
+from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ....utils.timer_utils import rbln_timer
-from .decoderonly_architecture import DecoderOnlyWrapper
+from ...utils.rbln_quantization import QuantizationManager
+from .decoderonly_architecture import (
+    DecoderOnlyWrapper,
+    validate_attention_method,
+)


 logger = get_logger()

 if TYPE_CHECKING:
-    from transformers import (
-        AutoFeatureExtractor,
-        AutoProcessor,
-        AutoTokenizer,
-        PretrainedConfig,
-    )
-
-SUPPORTED_QUANTIZATIONS = {
-    "rbln": [
-        "w4a16",
-    ],
-}
+    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig


 class RBLNRuntimeModel(RBLNPytorchRuntime):
@@ -70,32 +59,21 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         inputs_embeds: torch.Tensor,
         attention_mask: torch.Tensor,
         cache_position: torch.Tensor,
-        batch_position: torch.Tensor,
-        query_idx: torch.Tensor,
         **kwargs,
     ):
         if inputs_embeds is None:
             inp = input_ids
             if self.embed_tokens is not None:
                 inp = self.embed_tokens(inp)
-
-            return super().forward(
-                inp,
-                attention_mask,
-                cache_position,
-                batch_position,
-                query_idx,
-                **kwargs,
-            )
         else:
-            return super().forward(
-                inputs_embeds,
-                attention_mask,
-                cache_position,
-                batch_position,
-                query_idx,
-                **kwargs,
-            )
+            inp = inputs_embeds
+
+        return super().forward(
+            inp,
+            attention_mask,
+            cache_position,
+            **kwargs,
+        )


 @dataclass
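This hunk removes the duplicated super().forward(...) calls: the runtime first normalizes input_ids/inputs_embeds into a single inp tensor, and batch_position/query_idx are no longer threaded through every runtime call. A standalone sketch of the same normalization pattern (names hypothetical):

import torch

def normalize_inputs(input_ids, inputs_embeds, embed_tokens=None):
    # Mirror of the simplified RBLNRuntimeModel.forward: exactly one of
    # input_ids / inputs_embeds is expected to be provided.
    if inputs_embeds is None:
        inp = input_ids
        if embed_tokens is not None:
            inp = embed_tokens(inp)  # host-side embedding lookup
    else:
        inp = inputs_embeds
    return inp

emb = torch.nn.Embedding(100, 8)
print(normalize_inputs(torch.tensor([[1, 2, 3]]), None, emb).shape)  # torch.Size([1, 3, 8])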
@@ -127,24 +105,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     main_input_name = "input_ids"
     auto_model_class = AutoModelForCausalLM
     _decoder_wrapper_cls = DecoderOnlyWrapper
-
-
-    @classmethod
-    @property
-    def original_cls(cls):
-        """
-        Lazily loads and caches the corresponding Hugging Face model class.
-        Removes 'RBLN' prefix from the class name to get the original class name
-        (e.g., RBLNLlamaForCausalLM -> LlamaForCausalLM) and imports it from
-        the transformers module.
-
-        Returns:
-            type: The original Hugging Face model class
-        """
-        if cls._original_cls is None:
-            hf_original_cls_name = cls.__name__[4:]
-            cls._original_cls = getattr(transformers, hf_original_cls_name)
-        return cls._original_cls
+    _use_rotary_emb = True

     def __post_init__(self, **kwargs):
         self.batch_size = self.rbln_config.model_cfg["batch_size"]
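With the lazy original_cls property gone, per-model behavior is now declared with plain class attributes such as _use_rotary_emb. A hedged sketch of how concrete subclasses are expected to opt in or out (the GPT-2 override is an assumption, based on GPT-2 using learned absolute positions rather than RoPE):

class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    _use_rotary_emb = True   # RoPE models keep the default

class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
    _use_rotary_emb = False  # assumption: learned positions, so no rotary tables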
@@ -203,6 +164,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def get_quantized_model(
         cls,
         model_id: str,
+        config: Optional[PretrainedConfig] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
         revision: Optional[str] = None,
         force_download: bool = False,
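The new optional config argument lets a caller hand in a preloaded (and possibly edited) PretrainedConfig so the method does not resolve it again; when omitted, the AutoConfig.from_pretrained fallback in the next hunk runs exactly as before. A hedged usage sketch (model id hypothetical):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("my-org/quantized-llama")  # hypothetical id
config.num_hidden_layers = 2  # e.g. truncate the stack for a smoke test

model = RBLNLlamaForCausalLM.get_quantized_model(
    "my-org/quantized-llama",
    config=config,  # skips the AutoConfig.from_pretrained fallback
)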
@@ -212,57 +174,28 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         trust_remote_code: bool = False,
         **kwargs,
     ):
-        from ...utils.rbln_quantization import
+        from ...utils.rbln_quantization import prepare_model_for_quantization

         kwargs = cls.update_kwargs(kwargs)

-        config = AutoConfig.from_pretrained(
-            model_id,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
-        )
+        if config is None:
+            config = AutoConfig.from_pretrained(
+                model_id,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                force_download=force_download,
+                cache_dir=cache_dir,
+                trust_remote_code=trust_remote_code,
+                **kwargs,
+            )

         with no_init_weights():
             model = AutoModelForCausalLM.from_config(config)

-
-
-        n_layer = kwargs.get("num_hidden_layers", None)
-        cls._load_weights_directly_to_model(model, model_id, n_layer)
+        prepare_model_for_quantization(model, model_id, kwargs.get("num_hidden_layers"))

         return model

-    def _load_weights_directly_to_model(model, model_id, n_layer=None):
-        """
-        Load safetensor file data directly into the model, filtering by layer if n_layer is provided.
-        """
-
-        model_params = dict(model.named_parameters(recurse=True))
-        model_buffers = dict(model.named_buffers(recurse=True))
-        safetensor_files = glob.glob(f"{model_id}/*.safetensors")
-
-        target_layers = list(range(n_layer)) if n_layer is not None else None
-
-        for safetensor_file in safetensor_files:
-            file_data = load_file(safetensor_file)
-            for key, value in file_data.items():
-                if target_layers is not None:
-                    parts = key.split(".")
-
-                    if len(parts) > 2 and parts[2].isdigit() and (int(parts[2]) not in target_layers):
-                        continue
-
-                if key in model_params:
-                    model_params[key].data.copy_(value)
-                elif key in model_buffers:
-                    model_buffers[key].data.copy_(value)
-
-        return 0
-
     def __getattr__(self, __name: str) -> Any:
         """
         Special method to delegate attribute access to the original Huggingface LM class.
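The deleted _load_weights_directly_to_model helper (superseded by prepare_model_for_quantization in transformers/utils/rbln_quantization.py) filtered safetensors keys by layer index with key.split(".")[2]. That check relies on HF-style parameter naming, where the third dot-separated field of a layer parameter is its layer number:

key = "model.layers.17.self_attn.q_proj.weight"
parts = key.split(".")            # ['model', 'layers', '17', 'self_attn', ...]
assert parts[2].isdigit() and int(parts[2]) == 17
# With n_layer=2 (target_layers=[0, 1]) this key is skipped, while keys
# without a numeric third field, e.g. "model.embed_tokens.weight", always load.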
@@ -278,7 +211,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         def redirect(func):
             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

-        val = getattr(self.original_cls, __name)
+        val = getattr(self.hf_class, __name, None) or getattr(PreTrainedModel, __name)
         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
             return redirect(val)
         return val
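redirect exists because attributes looked up on a class are plain functions, not bound methods; wrapping them re-binds the proxy object as self, and only callables whose signature actually declares self get wrapped. The same trick in isolation:

import inspect

class Original:
    def greet(self):
        return f"hello from {type(self).__name__}"

class Proxy:
    def __getattr__(self, name):
        val = getattr(Original, name)  # unbound function off the class
        if callable(val) and "self" in inspect.signature(val).parameters:
            return lambda *a, **kw: val(self, *a, **kw)  # re-bind self
        return val

print(Proxy().greet())  # "hello from Proxy"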
@@ -295,61 +228,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

         return model

-    def validate_quantization_config(quantize_config):
-        if quantize_config is not None:
-            q_format = quantize_config.get("format")
-            q_precision = quantize_config.get("precision")
-
-            if q_format not in SUPPORTED_QUANTIZATIONS:
-                raise ValueError(
-                    f"Invalid quantization format: {q_format}. "
-                    f"Supported formats are: {list(SUPPORTED_QUANTIZATIONS.keys())}"
-                )
-
-            if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
-                raise ValueError(
-                    f"Invalid precision: {q_precision} for format: {q_format}. "
-                    f"Supported precisions are: {SUPPORTED_QUANTIZATIONS[q_format]}"
-                )
-
-        return quantize_config
-
-    @classmethod
-    def set_quantize_env(cls, quantize_config):
-        RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
-        quantize_config = cls.validate_quantization_config(quantize_config)
-        if quantize_config is not None:
-            q_precision = quantize_config.get("precision")
-            quant_bits = q_precision.split("w")[1].split("a")[0]
-            os.environ[RBLN_QUANT_BITS_ENV] = quant_bits
-            return RBLN_QUANT_BITS_ENV
-        return None
-
-    @classmethod
-    def reset_quantize_env(cls, env_var_name):
-        if env_var_name is not None and env_var_name in os.environ:
-            del os.environ[env_var_name]
-
-    @classmethod
-    def manage_quantize_env(cls, func):
-        @functools.wraps(func)
-        def wrapper(*args, **kwargs):
-            quantize_config = kwargs.get("quantize_config")
-            quantize_env_var = cls.set_quantize_env(quantize_config)
-            try:
-                return func(*args, **kwargs)
-            finally:
-                cls.reset_quantize_env(quantize_env_var)
-
-        return wrapper
-
     @classmethod
     def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
         wrapper_cfg = {"max_seq_len": rbln_config.model_cfg["max_seq_len"]}
-
-
-
-        wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
+        wrapper_cfg["attn_impl"] = rbln_config.model_cfg.get("attn_impl")
+        wrapper_cfg["kvcache_partition_len"] = rbln_config.model_cfg.get("kvcache_partition_len")
+        wrapper_cfg["use_rotary_emb"] = cls._use_rotary_emb

         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

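The deleted helpers also document the accepted quantization config: {"format": "rbln", "precision": "w4a16"} was the only combination in SUPPORTED_QUANTIZATIONS, and set_quantize_env derived the weight bit-width for the RBLN_QUANT_BITS environment variable by slicing the precision string; QuantizationManager.with_quantization_env now owns that env round-trip. The parsing step as a worked example:

q_precision = "w4a16"  # "w<weight bits>a<activation bits>"
quant_bits = q_precision.split("w")[1].split("a")[0]
print(quant_bits)  # "4", exported as RBLN_QUANT_BITS in the removed helper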
@@ -359,69 +243,46 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

         rbln_compile_configs = rbln_config.compile_cfgs
-        prefill_rbln_compile_config = rbln_compile_configs[0]
-        dec_rbln_compile_config = rbln_compile_configs[1]
-
-        @rbln_timer("JIT trace")
-        def get_scripted_model():
-            # This function is nested to dealloc the example inputs before compilation.
-            # FIXME: 3rd dummy_input(batch_idx) should be fill zero to compile flash_attn.
-            prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
-            dec_example_inputs = dec_rbln_compile_config.get_dummy_inputs(fill=0)
-
-            prefill_scripted_model = torch.jit.trace(
-                wrapped_model, prefill_example_inputs, check_trace=False, _store_inputs=False
-            )
-            dec_scripted_model = torch.jit.trace(
-                wrapped_model, dec_example_inputs, check_trace=False, _store_inputs=False
-            )
-            return prefill_scripted_model, dec_scripted_model
+        prefill_compile_config = rbln_compile_configs[0]
+        dec_compile_config = rbln_compile_configs[1]

-        prefill_scripted_model, dec_scripted_model = get_scripted_model()
+        context = CompileContext(use_weight_sharing=True)

-
-
-
-
-
-
-
-
-
-
-
-
-        prefill_ir, dec_ir = scripted_model_to_ir()
-        # Caching prefill_decoder/decoder I/O
-        cache_index_offset = 5
-        connections = [
-            (prefill_ir.outputs[1 + i], prefill_ir.inputs[cache_index_offset + i])
-            for i in range(model.config.num_hidden_layers * 2)
-        ]
+        # Here we use meta tensor, for the memory efficiency.
+        meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
+        prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+
+        # Mark static tensors (self kv states)
+        static_tensors = {}
+        for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
+            if "past_key_values" in name:
+                static_tensors[name] = tensor
+                context.mark_static_address(tensor)
+
+        dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)

-        # Extract quantize_config from rbln_config
         quantize_config = rbln_config.model_cfg.get("quantization", None)

-        @cls.manage_quantize_env
+        @QuantizationManager.with_quantization_env
         def compile_model(*args, **kwargs):
-
-
-
-
-
-
-
-            prefill_ir,
-            dec_ir,
-            connections=connections,
-            fusion=prefill_rbln_compile_config.fusion,
-            npu=prefill_rbln_compile_config.npu,
-            tensor_parallel_size=prefill_rbln_compile_config.tensor_parallel_size,
-            use_weight_sharing=True,
-            quantize_config=quantize_config,
-        )
+            wrapped_model.phase = "prefill"
+            compiled_prefill = RBLNModel.compile(
+                wrapped_model,
+                prefill_compile_config,
+                example_inputs=prefill_example_inputs,
+                compile_context=context,
+            )

-
+            wrapped_model.phase = "decode"
+            compiled_decoder = RBLNModel.compile(
+                wrapped_model,
+                dec_compile_config,
+                example_inputs=dec_example_inputs,
+                compile_context=context,
+            )
+            return {"prefill": compiled_prefill, "decoder": compiled_decoder}
+
+        return compile_model(quantize_config=quantize_config)

     @classmethod
     def _get_rbln_config(
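The new flow traces the same wrapped module twice, toggling wrapped_model.phase, while CompileContext(use_weight_sharing=True) plus mark_static_address on every past_key_values buffer pins the KV cache to fixed device addresses shared by both graphs. A condensed restatement of the control flow above (anything beyond the calls shown in the hunk is an assumption):

def compile_both_phases(wrapped_model, prefill_cfg, dec_cfg):
    context = CompileContext(use_weight_sharing=True)

    # KV-cache dummies are meta tensors, so the host never materializes them.
    kv_names = [n for n, _, _ in prefill_cfg.input_info if "past_key_values" in n]
    prefill_inputs = prefill_cfg.get_dummy_inputs(fill=0, meta_tensor_names=kv_names)

    # Pin each KV tensor to a static device address shared by both graphs.
    static_tensors = {}
    for (name, _, _), tensor in zip(prefill_cfg.input_info, prefill_inputs):
        if "past_key_values" in name:
            static_tensors[name] = tensor
            context.mark_static_address(tensor)

    dec_inputs = dec_cfg.get_dummy_inputs(fill=0, static_tensors=static_tensors)

    wrapped_model.phase = "prefill"
    prefill = RBLNModel.compile(
        wrapped_model, prefill_cfg, example_inputs=prefill_inputs, compile_context=context
    )
    wrapped_model.phase = "decode"
    decoder = RBLNModel.compile(
        wrapped_model, dec_cfg, example_inputs=dec_inputs, compile_context=context
    )
    return {"prefill": prefill, "decoder": decoder}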
@@ -432,10 +293,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     ) -> RBLNConfig:
         rbln_max_seq_len = rbln_kwargs.get("max_seq_len", None)
         rbln_batch_size = rbln_kwargs.get("batch_size", None)
-        rbln_quantization = rbln_kwargs.get("quantization", None)
         rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
-
-
+        rbln_attn_impl = rbln_kwargs.get("attn_impl", None)
+        rbln_kvcache_partition_len = rbln_kwargs.get("kvcache_partition_len", None)
+        rbln_quantization = QuantizationManager.validate_quantization_config(rbln_kwargs.get("quantization", None))

         prefill_chunk_size = 128
         if rbln_max_seq_len is None:
@@ -444,9 +305,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         )
         if rbln_max_seq_len is None:
             raise ValueError("`rbln_max_seq_len` should be specified.")
+
         rbln_batch_size = 1 if rbln_batch_size is None else rbln_batch_size
         rbln_use_inputs_embeds = False if rbln_use_inputs_embeds is None else rbln_use_inputs_embeds

+        rbln_attn_impl, rbln_kvcache_partition_len = validate_attention_method(
+            rbln_attn_impl=rbln_attn_impl,
+            rbln_kvcache_partition_len=rbln_kvcache_partition_len,
+            rbln_max_seq_len=rbln_max_seq_len,
+        )
+
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
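validate_attention_method reconciles the user-facing attention knobs against rbln_max_seq_len before they are baked into the compile config and model_cfg. A hedged export sketch; the "flash_attn" value and the partition length are assumptions suggested by the new ops/flash_attn.py module, not documented values:

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",       # hypothetical model id
    export=True,
    rbln_max_seq_len=8192,
    rbln_attn_impl="flash_attn",      # assumption: checked by validate_attention_method
    rbln_kvcache_partition_len=4096,  # assumption: partitions the KV cache per chunk
)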
@@ -472,9 +340,14 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 [batch_size, query_length],
                 "int32",
             ),
-            ("batch_position", [], "int16"),
-            ("query_idx", [], "int16"),
         ]
+        if query_length > 1:
+            input_info.extend(
+                [
+                    ("batch_position", [], "int16"),
+                    ("query_position", [], "int16"),
+                ]
+            )

         input_info.extend(
             [
@@ -507,12 +380,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             hidden_size=hidden_size,
         )

-
-
+        prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+        dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

         rbln_config = RBLNConfig(
             rbln_cls=cls.__name__,
-            compile_cfgs=[
+            compile_cfgs=[prefill_compile_config, dec_compile_config],
             rbln_kwargs=rbln_kwargs,
         )

@@ -522,6 +395,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 "batch_size": rbln_batch_size,
                 "prefill_chunk_size": prefill_chunk_size,
                 "use_inputs_embeds": rbln_use_inputs_embeds,
+                "kvcache_partition_len": rbln_kvcache_partition_len,
+                "attn_impl": rbln_attn_impl,
             }
         )

@@ -532,12 +407,21 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

     @classmethod
     def _create_runtimes(
-        cls,
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_device_map: Dict[str, int],
+        activate_profiler: Optional[bool] = None,
     ) -> List[rebel.Runtime]:
-
+        if any(model_name not in rbln_device_map for model_name in ["prefill", "decoder"]):
+            cls._raise_missing_compiled_file_error(["prefill", "decoder"])
+
         return [
-            compiled_models[0].create_runtime(
-
+            compiled_models[0].create_runtime(
+                tensor_type="pt", device=rbln_device_map["prefill"], activate_profiler=activate_profiler
+            ),
+            compiled_models[1].create_runtime(
+                tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+            ),
         ]

     def get_decoder(self):
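_create_runtimes now resolves devices per compiled model through rbln_device_map, keyed by the compiled_model_name values ("prefill", "decoder") set in _get_rbln_config, and fails fast via _raise_missing_compiled_file_error when either artifact is absent. The expected map shape, with illustrative device ids:

rbln_device_map = {"prefill": 0, "decoder": 0}  # both runtimes on NPU 0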
@@ -610,8 +494,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         cache_position: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
         generate_idx: Optional[torch.Tensor] = None,
-        # from llava_next forward args
-        batch_idx: Optional[int] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
         # prefll
@@ -633,7 +515,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 input_ids=input_tensor if inputs_embeds is None else None,
                 inputs_embeds=input_tensor if inputs_embeds is not None else None,
                 cache_position=cache_position,
-                batch_idx=b_idx
+                batch_idx=b_idx,
             )
             logits.append(logit)
         logits = torch.cat(logits, dim=0)
@@ -671,12 +553,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             ],
             dtype=torch.float32,
             device="cpu",
-            ),
-            torch.empty(size=[], dtype=torch.int16, device="cpu"),
+            )
         ]

         input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
         query_length = input_tensors.shape[1]
+        if query_length > self.max_seq_len:
+            raise ValueError(
+                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.max_seq_len})."
+            )
+
         _attention_mask = self.prefill_attention_mask.clone()

         for step in range(0, query_length, self.prefill_chunk_size):
@@ -709,15 +595,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 _attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
             _attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask

-
+            query_position = (query_length - 1) % self.prefill_chunk_size

-            logits, _ = self.prefill_decoder(
+            logits = self.prefill_decoder(
                 input_ids=_input_tensors.contiguous() if inputs_embeds is None else None,
                 inputs_embeds=_input_tensors.contiguous() if inputs_embeds is not None else None,
                 attention_mask=_attention_mask.contiguous(),
                 cache_position=_cache_position.contiguous(),
                 batch_position=torch.tensor(batch_idx, dtype=torch.int16),
-
+                query_position=torch.tensor(query_position, dtype=torch.int16),
                 out=out_buffers,
             )

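query_position is the offset of the last real token inside the final prefill chunk, so the runtime can read exactly that logit. A worked example with the fixed prefill_chunk_size = 128:

prefill_chunk_size = 128
query_length = 300  # prompt tokens; chunks cover [0,128), [128,256), [256,384)
query_position = (query_length - 1) % prefill_chunk_size
print(query_position)  # 43: token index 299 sits at offset 43 of the last chunk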
@@ -734,48 +620,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         cache_position: torch.Tensor = None,
     ) -> torch.FloatTensor:
         input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
+        if input_tensors is None:
+            raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")

         batch_size = input_tensors.shape[0]
+        if batch_size != self.batch_size:
+            raise RuntimeError(
+                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+            )
+
+        if batch_size != cache_position.shape[0]:
+            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")

         for b_idx in range(batch_size):
             decoding_step = cache_position[b_idx].item()
+            if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
+                raise ValueError(
+                    f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
+                )
             self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
-
-        logits, _ = self.decoder(
+        logits = self.decoder(
             input_ids=input_tensors.contiguous() if inputs_embeds is None else None,
             inputs_embeds=input_tensors.contiguous() if inputs_embeds is not None else None,
             attention_mask=self.dec_attn_mask.contiguous(),
             cache_position=cache_position.contiguous(),
-            batch_position=torch.tensor(0, dtype=torch.int16),
-            query_idx=torch.tensor(0, dtype=torch.int16),
         )

         return logits
-
-    def vllm_forward(
-        self,
-        input_ids: torch.LongTensor = None,
-        inputs_embeds: torch.Tensor = None,
-        cache_position: torch.Tensor = None,
-        batch_idx: Optional[int] = None,
-        **kwargs,
-    ) -> Tuple[torch.FloatTensor]:
-        # prefll
-        if cache_position.shape[-1] > 1:
-            logits = self._forward_prefill(
-                input_ids=input_ids,
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-                batch_idx=batch_idx,
-            )
-        # decoder
-        else:
-            logits = self._forward_decoder(
-                input_ids=input_ids,
-                inputs_embeds=inputs_embeds,
-                cache_position=cache_position,
-            )
-
-        return RBLNDecoderOnlyOutput(
-            logits=logits,
-        )
optimum/rbln/transformers/models/dpt/modeling_dpt.py

@@ -27,7 +27,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
 from transformers import AutoModelForDepthEstimation
 from transformers.modeling_outputs import DepthEstimatorOutput

-from ....modeling_base import RBLNModel
+from ....modeling import RBLNModel
 from ....modeling_config import RBLNCompileConfig, RBLNConfig


|