optimum-rbln 0.1.11__py3-none-any.whl → 0.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +10 -7
- optimum/rbln/__version__.py +1 -1
- optimum/rbln/diffusers/models/autoencoder_kl.py +0 -2
- optimum/rbln/diffusers/models/controlnet.py +0 -6
- optimum/rbln/diffusers/models/unet_2d_condition.py +0 -3
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +4 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +18 -20
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +18 -20
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +19 -34
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +20 -35
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +12 -13
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +13 -14
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +13 -14
- optimum/rbln/modeling_alias.py +4 -9
- optimum/rbln/modeling_base.py +105 -139
- optimum/rbln/modeling_config.py +51 -0
- optimum/rbln/transformers/__init__.py +8 -0
- optimum/rbln/transformers/models/__init__.py +4 -1
- optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
- optimum/rbln/transformers/models/bart/__init__.py +1 -1
- optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
- optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
- optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
- optimum/rbln/transformers/models/clip/modeling_clip.py +0 -1
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +172 -100
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
- optimum/rbln/transformers/models/exaone/__init__.py +32 -0
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +72 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
- optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +78 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +148 -152
- optimum/rbln/transformers/models/midm/modeling_midm.py +5 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -0
- optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
- optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +55 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +37 -12
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
- optimum/rbln/utils/import_utils.py +14 -0
- optimum/rbln/utils/logging.py +1 -1
- optimum/rbln/utils/runtime_utils.py +1 -1
- optimum/rbln/utils/timer_utils.py +26 -2
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/METADATA +4 -3
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/RECORD +54 -44
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/WHEEL +1 -1
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/entry_points.txt +0 -0
- {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.12.dist-info}/licenses/LICENSE +0 -0
@@ -47,6 +47,12 @@ from transformers.utils import logging
|
|
47
47
|
logger = logging.get_logger(__name__)
|
48
48
|
|
49
49
|
|
50
|
+
class BartWrapper:
    """Bundle the encoder-side and decoder-side wrappers built from one model."""

    def __init__(self, model):
        # Both wrappers are constructed from the same model object,
        # encoder wrapper first, then decoder wrapper.
        encoder_wrapper, decoder_wrapper = BartEncoderWrapper(model), BartDecoderWrapper(model)
        self.encoder = encoder_wrapper
        self.decoder = decoder_wrapper
|
54
|
+
|
55
|
+
|
50
56
|
class _BartAttention(BartAttention):
|
51
57
|
def forward(
|
52
58
|
self,
|
@@ -238,6 +244,7 @@ class _BartSdpaAttention(BartSdpaAttention):
|
|
238
244
|
value_states, dim=2, start=cache_position, end=cache_position + 1
|
239
245
|
)
|
240
246
|
|
247
|
+
# need 4d shape (input tensors) for scaled_dot_product_attention
|
241
248
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
242
249
|
query_states,
|
243
250
|
key_states,
|
@@ -324,7 +331,6 @@ class _BartDecoder(BartDecoder):
|
|
324
331
|
attn_impl: str = "eager",
|
325
332
|
):
|
326
333
|
# embedding
|
327
|
-
# thkim fix : transformers == 4.44.2 compile
|
328
334
|
if hasattr(self, "embed_scale"):
|
329
335
|
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
330
336
|
else:
|
@@ -336,13 +342,15 @@ class _BartDecoder(BartDecoder):
|
|
336
342
|
hidden_states = inputs_embeds + positions
|
337
343
|
else:
|
338
344
|
hidden_all = []
|
345
|
+
# compiler pattern base dependency -> take + add
|
339
346
|
for i in range(input_ids.shape[0]):
|
340
347
|
# cache position [N,1]
|
341
348
|
positions_idx = cache_position[i]
|
349
|
+
# offset is set 2 in bart embedding
|
342
350
|
position_weight = self.embed_positions.weight[2:]
|
343
351
|
position = position_weight[positions_idx]
|
344
|
-
|
345
|
-
hidden_all.append(
|
352
|
+
batch_hidden = position + inputs_embeds[i]
|
353
|
+
hidden_all.append(batch_hidden)
|
346
354
|
hidden_states = torch.stack(hidden_all, dim=0)
|
347
355
|
|
348
356
|
hidden_states = self.layernorm_embedding(hidden_states)
|
@@ -444,6 +452,7 @@ class BartDecoderWrapper(torch.nn.Module):
|
|
444
452
|
self_kv_cache.append(past_key_values[i][1])
|
445
453
|
self_kv_cache = torch.stack(self_kv_cache, dim=0)
|
446
454
|
|
455
|
+
# return batch_position to keep it as a variable within the graph
|
447
456
|
return lm_logits, self_kv_cache, batch_position
|
448
457
|
|
449
458
|
|
@@ -467,9 +476,6 @@ class BartEncoderWrapper(torch.nn.Module):
|
|
467
476
|
cross_key_value: torch.Tensor = None,
|
468
477
|
batch_idx: torch.Tensor = None,
|
469
478
|
) -> Tuple[torch.Tensor]:
|
470
|
-
encoder_batch_size = input_ids.shape[0]
|
471
|
-
decoder_batch_size = encoder_batch_size # TODO(taehoon) fix to enable beam-search
|
472
|
-
|
473
479
|
# 1. run encoder
|
474
480
|
encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
|
475
481
|
last_hidden_states = encoder_outputs[0]
|
@@ -477,19 +483,19 @@ class BartEncoderWrapper(torch.nn.Module):
|
|
477
483
|
# 2. run dummy decoder to get pre-calculated cross-key_values for generation
|
478
484
|
dummy_past_key_value = []
|
479
485
|
for _ in range(self.num_layers):
|
480
|
-
pkv_self_attn_key = torch.zeros(
|
481
|
-
pkv_self_attn_value = torch.zeros(
|
482
|
-
pkv_cross_attn_key = torch.zeros(
|
483
|
-
pkv_cross_attn_value = torch.zeros(
|
486
|
+
pkv_self_attn_key = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
|
487
|
+
pkv_self_attn_value = torch.zeros(1, self.num_heads, self.decoder_max_length, self.d_kv)
|
488
|
+
pkv_cross_attn_key = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
|
489
|
+
pkv_cross_attn_value = torch.zeros(1, self.num_heads, self.encoder_max_length, self.d_kv)
|
484
490
|
layer_pkv = (pkv_self_attn_key, pkv_self_attn_value, pkv_cross_attn_key, pkv_cross_attn_value)
|
485
491
|
dummy_past_key_value.append(layer_pkv)
|
486
492
|
|
487
|
-
decoder_attention_mask = torch.zeros(
|
493
|
+
decoder_attention_mask = torch.zeros(1, self.decoder_max_length, dtype=torch.float32)
|
488
494
|
decoder_attention_mask[:, :1] = 1
|
489
495
|
|
490
496
|
decoder_outputs = _BartDecoder.forward(
|
491
497
|
self.decoder,
|
492
|
-
input_ids=torch.zeros((
|
498
|
+
input_ids=torch.zeros((1, 1), dtype=torch.int64),
|
493
499
|
attention_mask=decoder_attention_mask,
|
494
500
|
encoder_attention_mask=attention_mask,
|
495
501
|
cache_position=torch.tensor(0, dtype=torch.int32),
|
@@ -22,23 +22,25 @@
|
|
22
22
|
# from Rebellions Inc.
|
23
23
|
|
24
24
|
import inspect
|
25
|
-
import
|
26
|
-
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
25
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Union
|
27
26
|
|
28
|
-
from transformers import
|
27
|
+
from transformers import BartConfig, BartForConditionalGeneration, BartModel, PretrainedConfig
|
29
28
|
|
30
29
|
from ....modeling_base import RBLNModel
|
31
30
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
31
|
+
from ....utils.logging import get_logger
|
32
|
+
from ...models.seq2seq import RBLNModelForSeq2SeqLM
|
33
|
+
from .bart_architecture import BartWrapper
|
32
34
|
|
33
35
|
|
34
|
-
logger =
|
36
|
+
logger = get_logger()
|
37
|
+
|
35
38
|
|
36
39
|
if TYPE_CHECKING:
|
37
|
-
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
|
40
|
+
from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
|
38
41
|
|
39
42
|
|
40
43
|
class RBLNBartModel(RBLNModel):
|
41
|
-
auto_model_class = AutoModel # feature extraction
|
42
44
|
original_model_class = BartModel
|
43
45
|
original_config_class = BartConfig
|
44
46
|
|
@@ -104,3 +106,20 @@ class RBLNBartModel(RBLNModel):
|
|
104
106
|
|
105
107
|
rbln_config.model_cfg.update({"max_seq_len": rbln_max_seq_len})
|
106
108
|
return rbln_config
|
109
|
+
|
110
|
+
|
111
|
+
class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
    """RBLN seq2seq model class for BART conditional generation.

    Attributes that normal lookup cannot find on this object are resolved
    against ``BartForConditionalGeneration`` (see ``__getattr__``).
    """

    @classmethod
    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
        """Return ``model`` wrapped in a ``BartWrapper``.

        ``rbln_config`` is accepted for signature compatibility and is not
        read by this implementation.
        """
        return BartWrapper(model)

    def __getattr__(self, __name: str) -> Any:
        """Fall back to ``BartForConditionalGeneration`` for missing attributes.

        Python only calls ``__getattr__`` after regular attribute lookup has
        failed. Callables whose signature has a ``self`` parameter are
        returned as a closure that passes this instance as ``self``; any
        other value is returned unchanged.

        Raises:
            AttributeError: if ``BartForConditionalGeneration`` has no such
                attribute either (propagated from ``getattr``).
        """

        def redirect(func):
            # Bind this instance as the first positional argument.
            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

        val = getattr(BartForConditionalGeneration, __name)

        if callable(val) and "self" in inspect.signature(val).parameters:
            return redirect(val)

        return val
|
@@ -25,7 +25,7 @@ import inspect
|
|
25
25
|
import logging
|
26
26
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
27
27
|
|
28
|
-
from transformers import
|
28
|
+
from transformers import BertConfig, BertModel, PretrainedConfig
|
29
29
|
|
30
30
|
from ....modeling_base import RBLNModel
|
31
31
|
from ....modeling_config import RBLNCompileConfig, RBLNConfig
|
@@ -38,7 +38,6 @@ if TYPE_CHECKING:
|
|
38
38
|
|
39
39
|
|
40
40
|
class RBLNBertModel(RBLNModel):
|
41
|
-
auto_model_class = AutoModel # feature extraction
|
42
41
|
original_model_class = BertModel
|
43
42
|
original_config_class = BertConfig
|
44
43
|
|
@@ -20,8 +20,9 @@
|
|
20
20
|
# are the intellectual property of Rebellions Inc. and may not be
|
21
21
|
# copied, modified, or distributed without prior written permission
|
22
22
|
# from Rebellions Inc.
|
23
|
+
import functools
|
23
24
|
import glob
|
24
|
-
import
|
25
|
+
import os
|
25
26
|
from abc import ABC
|
26
27
|
from dataclasses import dataclass
|
27
28
|
from pathlib import Path
|
@@ -36,11 +37,12 @@ from transformers.utils import ModelOutput
|
|
36
37
|
|
37
38
|
from ....modeling_base import RBLNModel
|
38
39
|
from ....modeling_config import DEFAULT_COMPILED_MODEL_NAME, RBLNCompileConfig, RBLNConfig
|
40
|
+
from ....utils.logging import get_logger
|
39
41
|
from ....utils.runtime_utils import RBLNPytorchRuntime
|
40
42
|
from ....utils.timer_utils import rbln_timer
|
41
43
|
|
42
44
|
|
43
|
-
logger =
|
45
|
+
logger = get_logger()
|
44
46
|
|
45
47
|
if TYPE_CHECKING:
|
46
48
|
from transformers import (
|
@@ -97,7 +99,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
|
|
97
99
|
@dataclass
|
98
100
|
class RBLNDecoderOnlyOutput(ModelOutput):
|
99
101
|
logits: torch.FloatTensor = None
|
100
|
-
|
102
|
+
generate_idx: torch.Tensor = None
|
101
103
|
|
102
104
|
|
103
105
|
class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
@@ -243,6 +245,54 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
|
243
245
|
|
244
246
|
return model
|
245
247
|
|
248
|
+
@staticmethod
def validate_quantization_config(quantize_config):
    """Check a quantization config against ``SUPPORTED_QUANTIZATIONS``.

    Declared as a static method because it takes neither ``self`` nor
    ``cls``; this keeps ``cls.validate_quantization_config(cfg)`` working
    and also makes the call valid on an instance.

    Args:
        quantize_config: ``None``, or a mapping read via ``.get("format")``
            and ``.get("precision")``.

    Returns:
        ``quantize_config`` unchanged (``None`` stays ``None``).

    Raises:
        ValueError: if the format is not a key of ``SUPPORTED_QUANTIZATIONS``
            or the precision is not listed for that format.
    """
    if quantize_config is None:
        return None

    q_format = quantize_config.get("format")
    q_precision = quantize_config.get("precision")

    if q_format not in SUPPORTED_QUANTIZATIONS:
        raise ValueError(
            f"Invalid quantization format: {q_format}. "
            f"Supported formats are: {list(SUPPORTED_QUANTIZATIONS.keys())}"
        )

    if q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
        raise ValueError(
            f"Invalid precision: {q_precision} for format: {q_format}. "
            f"Supported precisions are: {SUPPORTED_QUANTIZATIONS[q_format]}"
        )

    return quantize_config
|
266
|
+
|
267
|
+
@classmethod
def set_quantize_env(cls, quantize_config):
    """Export the quantization bit width to the process environment.

    The config is validated first. When it is ``None`` nothing is set and
    ``None`` is returned; otherwise the ``RBLN_QUANT_BITS`` variable is
    written and its name is returned so the caller can clear it later.
    """
    RBLN_QUANT_BITS_ENV = "RBLN_QUANT_BITS"
    validated = cls.validate_quantization_config(quantize_config)
    if validated is None:
        return None

    # Keep the text that follows the first "w" and stop at the next "a".
    precision = validated.get("precision")
    after_w = precision.split("w")[1]
    os.environ[RBLN_QUANT_BITS_ENV] = after_w.split("a")[0]
    return RBLN_QUANT_BITS_ENV
|
277
|
+
|
278
|
+
@classmethod
def reset_quantize_env(cls, env_var_name):
    """Remove ``env_var_name`` from the process environment.

    Does nothing when ``env_var_name`` is ``None`` or the variable is not
    currently set.
    """
    if env_var_name is not None:
        # pop() with a default removes the key in one step and ignores a missing key.
        os.environ.pop(env_var_name, None)
|
282
|
+
|
283
|
+
@classmethod
def manage_quantize_env(cls, func):
    """Decorator that scopes the quantization environment variable to one call.

    The wrapper reads the ``quantize_config`` keyword argument, sets the
    environment through ``cls.set_quantize_env``, invokes ``func`` with the
    unchanged arguments, and always calls ``cls.reset_quantize_env``
    afterwards, even when ``func`` raises.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        env_name = cls.set_quantize_env(kwargs.get("quantize_config"))
        try:
            result = func(*args, **kwargs)
        finally:
            cls.reset_quantize_env(env_name)
        return result

    return wrapper
|
295
|
+
|
246
296
|
@classmethod
|
247
297
|
@torch.inference_mode()
|
248
298
|
def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
|
@@ -252,7 +302,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
|
252
302
|
prefill_rbln_compile_config = rbln_compile_configs[0]
|
253
303
|
dec_rbln_compile_config = rbln_compile_configs[1]
|
254
304
|
|
255
|
-
@rbln_timer("
|
305
|
+
@rbln_timer("JIT trace")
|
256
306
|
def get_scripted_model():
|
257
307
|
# This function is nested to dealloc the example inputs before compilation.
|
258
308
|
prefill_example_inputs = prefill_rbln_compile_config.get_dummy_inputs(fill=0)
|
@@ -271,7 +321,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
|
271
321
|
|
272
322
|
prefill_scripted_model, dec_scripted_model = get_scripted_model()
|
273
323
|
|
274
|
-
@rbln_timer("
|
324
|
+
@rbln_timer("Model conversion")
|
275
325
|
def scripted_model_to_ir():
|
276
326
|
prefill_ir = rebel.torchscript_to_ir(
|
277
327
|
prefill_scripted_model,
|
@@ -291,7 +341,18 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
|
291
341
|
for i in range(model.config.num_hidden_layers * 2)
|
292
342
|
]
|
293
343
|
|
294
|
-
|
344
|
+
# Extract quantize_config from rbln_config
|
345
|
+
quantize_config = rbln_config.model_cfg.get("quantization", None)
|
346
|
+
|
347
|
+
@cls.manage_quantize_env
|
348
|
+
def compile_model(*args, **kwargs):
|
349
|
+
# Remove quantize_config from kwargs
|
350
|
+
kwargs.pop("quantize_config", None)
|
351
|
+
|
352
|
+
# Call rebel.compile with the updated kwargs
|
353
|
+
return rebel.compile(*args, **kwargs)
|
354
|
+
|
355
|
+
compiled_model = compile_model(
|
295
356
|
prefill_ir,
|
296
357
|
dec_ir,
|
297
358
|
connections=connections,
|
@@ -299,7 +360,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
|
299
360
|
npu=prefill_rbln_compile_config.npu,
|
300
361
|
tensor_parallel_size=prefill_rbln_compile_config.tensor_parallel_size,
|
301
362
|
use_weight_sharing=True,
|
363
|
+
quantize_config=quantize_config,
|
302
364
|
)
|
365
|
+
|
303
366
|
return compiled_model
|
304
367
|
|
305
368
|
@classmethod
|
@@ -314,6 +377,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
|
314
377
|
rbln_quantization = rbln_kwargs.get("quantization", None)
|
315
378
|
rbln_use_inputs_embeds = rbln_kwargs.get("use_inputs_embeds", None)
|
316
379
|
|
380
|
+
rbln_quantization = cls.validate_quantization_config(rbln_quantization)
|
381
|
+
|
317
382
|
prefill_chunk_size = 128
|
318
383
|
if rbln_max_seq_len is None:
|
319
384
|
rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
|
@@ -330,16 +395,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
|
330
395
|
head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
|
331
396
|
hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
|
332
397
|
|
333
|
-
if rbln_quantization is not None:
|
334
|
-
q_format = rbln_quantization.get("format", None)
|
335
|
-
q_precision = rbln_quantization.get("precision", None)
|
336
|
-
|
337
|
-
if q_format not in SUPPORTED_QUANTIZATIONS.keys() or q_precision not in SUPPORTED_QUANTIZATIONS[q_format]:
|
338
|
-
raise ValueError(
|
339
|
-
f'rbln_quantization="{rbln_quantization}" is not a supported quantization format or precesion, '
|
340
|
-
f"Possible: {SUPPORTED_QUANTIZATIONS}"
|
341
|
-
)
|
342
|
-
|
343
398
|
def get_input_info(
|
344
399
|
batch_size,
|
345
400
|
query_length,
|
@@ -439,50 +494,41 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
|
439
494
|
def prepare_inputs_for_generation(
|
440
495
|
self,
|
441
496
|
input_ids: torch.LongTensor,
|
442
|
-
|
497
|
+
generate_idx: Optional[torch.Tensor] = None,
|
443
498
|
attention_mask: Optional[torch.LongTensor] = None,
|
444
499
|
inputs_embeds: Optional[torch.Tensor] = None,
|
445
500
|
**kwargs,
|
446
501
|
):
|
447
502
|
model_inputs = {}
|
448
|
-
|
449
|
-
if past_cached_length is None:
|
450
|
-
# huggingface make dummy_input_ids if model_input_name is "input_embeds"
|
451
|
-
# https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L469
|
452
|
-
if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
|
453
|
-
input_tensors = inputs_embeds
|
454
|
-
else:
|
455
|
-
input_tensors = input_ids
|
503
|
+
is_prefill_phase = generate_idx is None
|
456
504
|
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
past_cached_length = torch.zeros((batch_size, 1), dtype=torch.int32)
|
461
|
-
for i in range(batch_size):
|
462
|
-
input_tensor = input_tensors[i]
|
463
|
-
input_tensor = input_tensor[attention_mask[i] == 1]
|
464
|
-
valid_len = input_tensor.shape[0]
|
465
|
-
cache_position = torch.arange(0, valid_len, dtype=torch.int32)
|
466
|
-
past_cached_length[i] = valid_len
|
467
|
-
l_input_tensors.append(input_tensor.unsqueeze(0))
|
468
|
-
cache_positions.append(cache_position.unsqueeze(0))
|
469
|
-
|
470
|
-
input_tensors = l_input_tensors
|
471
|
-
if self.rbln_config.model_cfg["use_inputs_embeds"] and inputs_embeds is not None:
|
472
|
-
model_inputs.update({"inputs_embeds": input_tensors, "input_ids": input_ids})
|
473
|
-
else:
|
474
|
-
model_inputs.update({"input_ids": input_tensors, "inputs_embeds": inputs_embeds})
|
475
|
-
# decoder phase
|
505
|
+
if is_prefill_phase:
|
506
|
+
generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
|
507
|
+
cache_position = None
|
476
508
|
else:
|
509
|
+
if inputs_embeds is not None:
|
510
|
+
raise NotImplementedError("Specifying inputs_embeds in decoder phase is not supported.")
|
511
|
+
|
477
512
|
input_ids = input_ids[:, -1:]
|
478
|
-
|
479
|
-
|
513
|
+
cache_position = generate_idx
|
514
|
+
generate_idx = generate_idx + 1
|
515
|
+
model_inputs.update({"input_ids": input_ids})
|
516
|
+
|
517
|
+
if inputs_embeds is not None:
|
518
|
+
if self.rbln_config.model_cfg["use_inputs_embeds"]:
|
519
|
+
model_inputs.update({"inputs_embeds": inputs_embeds})
|
520
|
+
else:
|
521
|
+
raise ValueError(
|
522
|
+
"The specifying inputs_embedst is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
|
523
|
+
)
|
524
|
+
else:
|
480
525
|
model_inputs.update({"input_ids": input_ids})
|
481
526
|
|
482
527
|
model_inputs.update(
|
483
528
|
{
|
484
|
-
"
|
485
|
-
"
|
529
|
+
"attention_mask": attention_mask,
|
530
|
+
"cache_position": cache_position,
|
531
|
+
"generate_idx": generate_idx,
|
486
532
|
}
|
487
533
|
)
|
488
534
|
|
@@ -494,42 +540,46 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
|
494
540
|
model_kwargs: Dict[str, Any],
|
495
541
|
**kwargs,
|
496
542
|
) -> Dict[str, Any]:
|
497
|
-
# update
|
498
|
-
model_kwargs["
|
543
|
+
# update generate_idx
|
544
|
+
model_kwargs["generate_idx"] = outputs.generate_idx
|
499
545
|
|
500
546
|
return model_kwargs
|
501
547
|
|
502
548
|
def forward(
|
503
549
|
self,
|
504
|
-
input_ids: Optional[
|
505
|
-
inputs_embeds: Optional[
|
506
|
-
cache_position:
|
550
|
+
input_ids: Optional[torch.LongTensor] = None,
|
551
|
+
inputs_embeds: Optional[torch.Tensor] = None,
|
552
|
+
cache_position: Optional[torch.Tensor] = None,
|
553
|
+
attention_mask: Optional[torch.LongTensor] = None,
|
554
|
+
generate_idx: Optional[torch.Tensor] = None,
|
555
|
+
# from llava_next forward args
|
507
556
|
batch_idx: Optional[int] = None,
|
508
|
-
past_cached_length: Optional[torch.Tensor] = None,
|
509
557
|
**kwargs,
|
510
558
|
) -> Tuple[torch.FloatTensor]:
|
511
|
-
# prefll
|
512
|
-
if
|
559
|
+
# prefll
|
560
|
+
if cache_position is None:
|
513
561
|
logits = []
|
514
|
-
input_tensors =
|
515
|
-
|
562
|
+
input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
|
563
|
+
batch_size = input_tensors.shape[0]
|
564
|
+
|
565
|
+
for b_idx in range(batch_size):
|
566
|
+
# Transform inputs as vllm format
|
567
|
+
if attention_mask is not None:
|
568
|
+
input_tensor = input_tensors[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
|
569
|
+
else:
|
570
|
+
input_tensor = input_tensors[b_idx : b_idx + 1]
|
571
|
+
|
572
|
+
cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
|
573
|
+
|
516
574
|
logit = self._forward_prefill(
|
517
575
|
input_ids=input_tensor if inputs_embeds is None else None,
|
518
576
|
inputs_embeds=input_tensor if inputs_embeds is not None else None,
|
519
|
-
cache_position=
|
520
|
-
batch_idx=batch_idx,
|
577
|
+
cache_position=cache_position,
|
578
|
+
batch_idx=b_idx if batch_idx is None else batch_idx, # Llava-next prefill
|
521
579
|
)
|
522
580
|
logits.append(logit)
|
523
581
|
logits = torch.cat(logits, dim=0)
|
524
|
-
#
|
525
|
-
elif cache_position.shape[-1] > 1:
|
526
|
-
logits = self._forward_prefill(
|
527
|
-
input_ids=input_ids,
|
528
|
-
inputs_embeds=inputs_embeds,
|
529
|
-
cache_position=cache_position,
|
530
|
-
batch_idx=batch_idx,
|
531
|
-
)
|
532
|
-
# common decoder
|
582
|
+
# decoder
|
533
583
|
else:
|
534
584
|
logits = self._forward_decoder(
|
535
585
|
input_ids=input_ids,
|
@@ -539,7 +589,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
|
539
589
|
|
540
590
|
return RBLNDecoderOnlyOutput(
|
541
591
|
logits=logits,
|
542
|
-
|
592
|
+
generate_idx=generate_idx,
|
543
593
|
)
|
544
594
|
|
545
595
|
def _forward_prefill(
|
@@ -567,23 +617,18 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
|
567
617
|
torch.empty(size=[], dtype=torch.int16, device="cpu"),
|
568
618
|
]
|
569
619
|
|
570
|
-
|
571
|
-
model_input_name = "inputs_embeds"
|
572
|
-
else:
|
573
|
-
model_input_name = "input_ids"
|
574
|
-
|
575
|
-
input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
|
576
|
-
|
620
|
+
input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
|
577
621
|
query_length = input_tensors.shape[1]
|
578
|
-
|
622
|
+
_attention_mask = self.prefill_attention_mask.clone()
|
623
|
+
|
579
624
|
for step in range(0, query_length, self.prefill_chunk_size):
|
580
|
-
|
581
|
-
|
582
|
-
|
583
|
-
if
|
584
|
-
input_tensors = torch.nn.functional.pad(input_tensors, (0,
|
625
|
+
# pad input_tensors & cache_position for prefill_chunk
|
626
|
+
if (step + self.prefill_chunk_size) > query_length:
|
627
|
+
pad_to_chunk = step + self.prefill_chunk_size - query_length
|
628
|
+
if inputs_embeds is not None:
|
629
|
+
input_tensors = torch.nn.functional.pad(input_tensors, (0, 0, 0, pad_to_chunk))
|
585
630
|
else:
|
586
|
-
input_tensors = torch.nn.functional.pad(input_tensors, (0,
|
631
|
+
input_tensors = torch.nn.functional.pad(input_tensors, (0, pad_to_chunk))
|
587
632
|
|
588
633
|
cache_position = torch.cat(
|
589
634
|
[
|
@@ -597,25 +642,28 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
|
597
642
|
dim=-1,
|
598
643
|
)
|
599
644
|
|
600
|
-
|
601
|
-
|
645
|
+
# slice input_tensor & cache_position with prefill_chunk_size
|
646
|
+
_input_tensors = input_tensors[:, step : step + self.prefill_chunk_size]
|
647
|
+
_cache_position = cache_position[:, step : step + self.prefill_chunk_size]
|
602
648
|
|
649
|
+
# update attention_mask
|
603
650
|
if step >= self.prefill_chunk_size:
|
604
|
-
|
605
|
-
|
651
|
+
_attention_mask[:, :, :, step - self.prefill_chunk_size : step] = 1
|
652
|
+
_attention_mask[:, :, :, step : step + self.prefill_chunk_size] = self.causal_mask
|
606
653
|
|
607
|
-
query_idx = query_length % self.prefill_chunk_size
|
654
|
+
query_idx = (query_length - 1) % self.prefill_chunk_size
|
608
655
|
|
609
656
|
logits, _ = self.prefill_decoder(
|
610
|
-
input_ids=
|
611
|
-
inputs_embeds=
|
612
|
-
attention_mask=
|
613
|
-
cache_position=
|
657
|
+
input_ids=_input_tensors.contiguous() if inputs_embeds is None else None,
|
658
|
+
inputs_embeds=_input_tensors.contiguous() if inputs_embeds is not None else None,
|
659
|
+
attention_mask=_attention_mask.contiguous(),
|
660
|
+
cache_position=_cache_position.contiguous(),
|
614
661
|
batch_position=torch.tensor(batch_idx, dtype=torch.int16),
|
615
662
|
query_idx=torch.tensor(query_idx, dtype=torch.int16),
|
616
663
|
out=out_buffers,
|
617
664
|
)
|
618
665
|
|
666
|
+
# update decoder_attn_mask with preprocessed kv-cache length in prefill phase
|
619
667
|
self.dec_attn_mask[batch_idx] = self.dec_attn_mask_init.clone()
|
620
668
|
self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
|
621
669
|
|
@@ -627,11 +675,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
|
627
675
|
inputs_embeds: torch.Tensor = None,
|
628
676
|
cache_position: torch.Tensor = None,
|
629
677
|
) -> torch.FloatTensor:
|
630
|
-
|
631
|
-
model_input_name = "inputs_embeds"
|
632
|
-
else:
|
633
|
-
model_input_name = "input_ids"
|
634
|
-
input_tensors = input_ids if model_input_name == "input_ids" else inputs_embeds
|
678
|
+
input_tensors = inputs_embeds if inputs_embeds is not None else input_ids
|
635
679
|
|
636
680
|
batch_size = input_tensors.shape[0]
|
637
681
|
|
@@ -640,8 +684,8 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
|
640
684
|
self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
|
641
685
|
|
642
686
|
logits, _ = self.decoder(
|
643
|
-
input_ids=input_tensors.contiguous() if
|
644
|
-
inputs_embeds=input_tensors.contiguous() if
|
687
|
+
input_ids=input_tensors.contiguous() if inputs_embeds is None else None,
|
688
|
+
inputs_embeds=input_tensors.contiguous() if inputs_embeds is not None else None,
|
645
689
|
attention_mask=self.dec_attn_mask.contiguous(),
|
646
690
|
cache_position=cache_position.contiguous(),
|
647
691
|
batch_position=torch.tensor(0, dtype=torch.int16),
|
@@ -649,3 +693,31 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel, ABC):
|
|
649
693
|
)
|
650
694
|
|
651
695
|
return logits
|
696
|
+
|
697
|
+
def vllm_forward(
    self,
    input_ids: torch.LongTensor = None,
    inputs_embeds: torch.Tensor = None,
    cache_position: torch.Tensor = None,
    batch_idx: Optional[int] = None,
    **kwargs,
) -> Tuple[torch.FloatTensor]:
    """Dispatch one step to the prefill or decoder path.

    The last dimension of ``cache_position`` selects the path: more than
    one position goes to ``_forward_prefill`` (which also receives
    ``batch_idx``), otherwise ``_forward_decoder`` is used. Extra keyword
    arguments are accepted and ignored.

    Returns:
        An ``RBLNDecoderOnlyOutput`` carrying only ``logits``.
    """
    shared_inputs = {
        "input_ids": input_ids,
        "inputs_embeds": inputs_embeds,
        "cache_position": cache_position,
    }

    if cache_position.shape[-1] > 1:
        # prefill
        logits = self._forward_prefill(batch_idx=batch_idx, **shared_inputs)
    else:
        # decoder
        logits = self._forward_decoder(**shared_inputs)

    return RBLNDecoderOnlyOutput(logits=logits)
|
@@ -0,0 +1,32 @@
|
|
1
|
+
# Copyright 2024 Rebellions Inc.
|
2
|
+
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
# Portions of this software are licensed under the Apache License,
|
16
|
+
# Version 2.0. See the NOTICE file distributed with this work for
|
17
|
+
# additional information regarding copyright ownership.
|
18
|
+
|
19
|
+
# All other portions of this software, including proprietary code,
|
20
|
+
# are the intellectual property of Rebellions Inc. and may not be
|
21
|
+
# copied, modified, or distributed without prior written permission
|
22
|
+
# from Rebellions Inc.
|
23
|
+
|
24
|
+
import os
|
25
|
+
from os import environ
|
26
|
+
|
27
|
+
|
28
|
+
this_path = os.path.abspath(__file__)
# "hf_hub_cached" directory located next to this file; built with os.path
# helpers so the result does not depend on "/" being the path separator.
local_dir = os.path.join(os.path.dirname(this_path), "hf_hub_cached")
# NOTE(review): the key name mentions MIDM although this is the exaone
# package - confirm with the consumer of this variable before renaming.
environ["LOCAL_CACHE_ROOT_CUSTOM_CODE_MIDM"] = local_dir
|
31
|
+
|
32
|
+
from .modeling_exaone import RBLNExaoneForCausalLM
|