optimum-rbln 0.8.2a3__py3-none-any.whl → 0.8.2a5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +36 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +4 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +40 -0
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/models/__init__.py +31 -14
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -4
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +204 -44
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +124 -208
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +567 -366
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +13 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +0 -6
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +10 -6
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -7
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +16 -1
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -2
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +13 -1
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +163 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +6 -6
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -3
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +10 -328
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +0 -241
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +0 -10
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +1 -10
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +5 -1
- optimum/rbln/utils/depreacate_utils.py +16 -0
- {optimum_rbln-0.8.2a3.dist-info → optimum_rbln-0.8.2a5.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.2a3.dist-info → optimum_rbln-0.8.2a5.dist-info}/RECORD +58 -52
- {optimum_rbln-0.8.2a3.dist-info → optimum_rbln-0.8.2a5.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a3.dist-info → optimum_rbln-0.8.2a5.dist-info}/licenses/LICENSE +0 -0
|
@@ -13,7 +13,6 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
import inspect
|
|
16
|
-
import math
|
|
17
16
|
from collections import deque
|
|
18
17
|
from dataclasses import dataclass
|
|
19
18
|
from pathlib import Path
|
|
@@ -22,7 +21,8 @@ from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tu
|
|
|
22
21
|
import rebel
|
|
23
22
|
import torch
|
|
24
23
|
from rebel.compile_context import CompileContext
|
|
25
|
-
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
|
|
24
|
+
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
|
|
25
|
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
|
26
26
|
from transformers.modeling_utils import no_init_weights
|
|
27
27
|
from transformers.utils import ModelOutput
|
|
28
28
|
|
|
@@ -30,14 +30,15 @@ from ....configuration_utils import RBLNCompileConfig
|
|
|
30
30
|
from ....modeling import RBLNModel
|
|
31
31
|
from ....utils.logging import get_logger
|
|
32
32
|
from ....utils.runtime_utils import RBLNPytorchRuntime
|
|
33
|
-
from ...
|
|
34
|
-
|
|
35
|
-
from .decoderonly_architecture import (
|
|
36
|
-
DecoderOnlyWrapper,
|
|
33
|
+
from ...modeling_attention_utils import (
|
|
34
|
+
RBLNDecoderOnlyFlashAttentionMixin,
|
|
37
35
|
set_default_values,
|
|
38
36
|
validate_attention_method,
|
|
39
|
-
|
|
37
|
+
validate_sliding_window,
|
|
40
38
|
)
|
|
39
|
+
from ...utils.rbln_quantization import prepare_model_for_quantization
|
|
40
|
+
from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
|
|
41
|
+
from .decoderonly_architecture import DecoderOnlyWrapper
|
|
41
42
|
|
|
42
43
|
|
|
43
44
|
logger = get_logger()
|
|
@@ -267,7 +268,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
|
|
|
267
268
|
|
|
268
269
|
attention_mask = self.dec_attn_mask
|
|
269
270
|
|
|
270
|
-
if self.rbln_config.
|
|
271
|
+
if self.rbln_config.use_global_attention and self.batch_size < block_tables.shape[0]:
|
|
271
272
|
block_tables = block_tables[: self.batch_size]
|
|
272
273
|
|
|
273
274
|
if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
|
|
@@ -283,7 +284,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
|
|
|
283
284
|
position_ids if self.rbln_config.use_position_ids else None,
|
|
284
285
|
)
|
|
285
286
|
|
|
286
|
-
return
|
|
287
|
+
return RBLNDecoderOnlyForCausalLMOutput(logits=logits)
|
|
287
288
|
|
|
288
289
|
def _prepare_prefill_inputs(
|
|
289
290
|
self,
|
|
@@ -303,6 +304,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
|
|
|
303
304
|
position_embed = (
|
|
304
305
|
position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
|
|
305
306
|
)
|
|
307
|
+
if token_type_ids is not None:
|
|
308
|
+
token_type_ids = token_type_ids[:, attention_mask.bool()] if attention_mask is not None else token_type_ids
|
|
306
309
|
|
|
307
310
|
query_length = inputs.shape[1]
|
|
308
311
|
if query_length > self.rbln_config.max_seq_len:
|
|
@@ -447,94 +450,64 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
|
|
|
447
450
|
self.dec_attn_mask[batch_idx].fill_(0)
|
|
448
451
|
self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
|
|
449
452
|
|
|
450
|
-
return
|
|
453
|
+
return RBLNDecoderOnlyForCausalLMOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)
|
|
451
454
|
|
|
452
455
|
|
|
453
456
|
@dataclass
|
|
454
|
-
class
|
|
457
|
+
class RBLNDecoderOnlyForCausalLMOutput(ModelOutput):
|
|
455
458
|
logits: torch.FloatTensor = None
|
|
456
459
|
generate_idx: torch.Tensor = None
|
|
457
460
|
padded_cache_lengths: int = None
|
|
458
461
|
|
|
459
462
|
|
|
460
|
-
class
|
|
463
|
+
class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
|
|
461
464
|
"""
|
|
462
|
-
A base class for decoder-only transformer models
|
|
465
|
+
A base class for decoder-only transformer models outputting raw hidden-states without any specific head on top.
|
|
466
|
+
This class is used for RBLN-optimized models that are not causal language models.
|
|
463
467
|
This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
|
|
464
468
|
|
|
465
469
|
The class provides core functionality for:
|
|
466
470
|
|
|
467
471
|
1. Converting pre-trained transformer models to RBLN-optimized format
|
|
468
472
|
2. Handling the compilation process for RBLN devices
|
|
469
|
-
3. Managing inference operations for
|
|
473
|
+
3. Managing inference operations for decoder-only architectures
|
|
470
474
|
|
|
471
475
|
This class inherits from RBLNModel and implements specific methods required for
|
|
472
|
-
decoder-only architectures
|
|
476
|
+
decoder-only architectures.
|
|
473
477
|
|
|
474
478
|
Note:
|
|
475
479
|
- This class is designed to be subclassed by specific model implementations
|
|
476
|
-
(e.g.,
|
|
480
|
+
(e.g., RBLNLlamaModel, RBLNQwen2Model)
|
|
477
481
|
- Subclasses should implement model-specific conversion logic.
|
|
478
482
|
- The class handles RBLN-specific optimizations automatically during compilation
|
|
479
483
|
"""
|
|
480
484
|
|
|
481
485
|
main_input_name = "input_ids"
|
|
482
|
-
auto_model_class =
|
|
486
|
+
auto_model_class = AutoModel
|
|
483
487
|
_decoder_wrapper_cls = DecoderOnlyWrapper
|
|
484
488
|
_use_rotary_emb = True
|
|
485
489
|
|
|
486
490
|
def __post_init__(self, **kwargs):
|
|
487
|
-
main_input_name = self.main_input_name
|
|
488
|
-
|
|
489
491
|
if self.rbln_config.use_inputs_embeds:
|
|
490
|
-
main_input_name = "inputs_embeds"
|
|
491
492
|
artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
|
|
492
493
|
self.embed_tokens = self._create_embedding_layer()
|
|
493
494
|
self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
|
|
494
495
|
else:
|
|
495
496
|
self.embed_tokens = None
|
|
496
497
|
|
|
497
|
-
#
|
|
498
|
-
|
|
499
|
-
self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
|
|
500
|
-
)
|
|
501
|
-
block_tables = torch.zeros(
|
|
502
|
-
self.rbln_config.batch_size,
|
|
503
|
-
self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
|
|
504
|
-
dtype=torch.int16,
|
|
505
|
-
).fill_(-1)
|
|
506
|
-
free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
|
|
507
|
-
|
|
508
|
-
self.prefill_decoder = RBLNRuntimeModel(
|
|
509
|
-
runtime=self.model[0],
|
|
510
|
-
main_input_name=main_input_name,
|
|
511
|
-
embed_tokens=self.embed_tokens,
|
|
512
|
-
phase="prefill",
|
|
513
|
-
batch_size=self.rbln_config.batch_size,
|
|
514
|
-
dec_attn_mask=dec_attn_mask,
|
|
515
|
-
block_tables=block_tables,
|
|
516
|
-
free_block_pool=free_block_pool,
|
|
517
|
-
rbln_config=self.rbln_config,
|
|
518
|
-
vocab_size=self.config.vocab_size,
|
|
519
|
-
)
|
|
498
|
+
# TODO: add prefill runtime class.
|
|
499
|
+
self.prefill_decoder = RBLNPytorchRuntime(runtime=self.model[0])
|
|
520
500
|
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
self.
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
dec_attn_mask=dec_attn_mask,
|
|
530
|
-
block_tables=block_tables,
|
|
531
|
-
free_block_pool=free_block_pool,
|
|
532
|
-
rbln_config=self.rbln_config,
|
|
501
|
+
# attributes for prefill
|
|
502
|
+
if self.rbln_config.use_global_attention:
|
|
503
|
+
self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
|
|
504
|
+
if self.rbln_config.use_local_attention:
|
|
505
|
+
self.local_block_tables = torch.tensor([0], dtype=torch.int16)
|
|
506
|
+
if self.rbln_config.use_attention_mask:
|
|
507
|
+
self.causal_mask = 1 - torch.triu(
|
|
508
|
+
torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
|
|
533
509
|
)
|
|
534
510
|
|
|
535
|
-
# NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
|
|
536
|
-
self.decoder = self.decoders[self.rbln_config.batch_size]
|
|
537
|
-
|
|
538
511
|
@classmethod
|
|
539
512
|
def save_torch_artifacts(
|
|
540
513
|
cls,
|
|
@@ -569,79 +542,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
569
542
|
return self.rbln_config.kvcache_num_blocks
|
|
570
543
|
|
|
571
544
|
@classmethod
|
|
572
|
-
def
|
|
573
|
-
cls,
|
|
574
|
-
model_id: str,
|
|
575
|
-
config: Optional[PretrainedConfig] = None,
|
|
576
|
-
use_auth_token: Optional[Union[bool, str]] = None,
|
|
577
|
-
revision: Optional[str] = None,
|
|
578
|
-
force_download: bool = False,
|
|
579
|
-
cache_dir: Optional[str] = None,
|
|
580
|
-
subfolder: str = "",
|
|
581
|
-
local_files_only: bool = False,
|
|
582
|
-
trust_remote_code: bool = False,
|
|
583
|
-
**kwargs,
|
|
584
|
-
):
|
|
585
|
-
kwargs = cls.update_kwargs(kwargs)
|
|
586
|
-
|
|
587
|
-
if config is None:
|
|
588
|
-
config = AutoConfig.from_pretrained(
|
|
589
|
-
model_id,
|
|
590
|
-
use_auth_token=use_auth_token,
|
|
591
|
-
revision=revision,
|
|
592
|
-
force_download=force_download,
|
|
593
|
-
cache_dir=cache_dir,
|
|
594
|
-
trust_remote_code=trust_remote_code,
|
|
595
|
-
**kwargs,
|
|
596
|
-
)
|
|
597
|
-
|
|
598
|
-
with no_init_weights():
|
|
599
|
-
model = AutoModelForCausalLM.from_config(config)
|
|
600
|
-
|
|
601
|
-
model = prepare_model_for_quantization(
|
|
602
|
-
model,
|
|
603
|
-
model_id,
|
|
604
|
-
kwargs.get("num_hidden_layers"),
|
|
605
|
-
use_auth_token=use_auth_token,
|
|
606
|
-
revision=revision,
|
|
607
|
-
cache_dir=cache_dir,
|
|
608
|
-
force_download=force_download,
|
|
609
|
-
local_files_only=local_files_only,
|
|
610
|
-
)
|
|
611
|
-
return model
|
|
612
|
-
|
|
613
|
-
def __getattr__(self, __name: str) -> Any:
|
|
614
|
-
# Special method to delegate attribute access to the original Huggingface LM class.
|
|
615
|
-
# This method is called when an attribute is not found in the current instance's dictionary.
|
|
616
|
-
# It enables transparent access to the original model's attributes and methods while maintaining
|
|
617
|
-
# proper method binding.
|
|
618
|
-
|
|
619
|
-
# The method implements a delegation pattern that:
|
|
620
|
-
|
|
621
|
-
# 1. For methods: Creates a wrapper that properly binds 'self' to method calls
|
|
622
|
-
# 2. For other attributes: Returns them directly from the original class
|
|
623
|
-
|
|
624
|
-
def redirect(func):
|
|
625
|
-
return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
|
|
626
|
-
|
|
627
|
-
val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
|
|
628
|
-
if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
|
|
629
|
-
return redirect(val)
|
|
630
|
-
return val
|
|
631
|
-
|
|
632
|
-
@classmethod
|
|
633
|
-
def get_pytorch_model(
|
|
634
|
-
cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
|
|
635
|
-
) -> PreTrainedModel:
|
|
636
|
-
if rbln_config and rbln_config.quantization:
|
|
637
|
-
model = cls.get_quantized_model(*args, **kwargs)
|
|
638
|
-
else:
|
|
639
|
-
model = super().get_pytorch_model(*args, **kwargs)
|
|
640
|
-
|
|
641
|
-
return model
|
|
642
|
-
|
|
643
|
-
@classmethod
|
|
644
|
-
def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
|
|
545
|
+
def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
|
|
645
546
|
wrapper_cfg = {
|
|
646
547
|
"max_seq_len": rbln_config.max_seq_len,
|
|
647
548
|
"attn_impl": rbln_config.attn_impl,
|
|
@@ -658,205 +559,95 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
658
559
|
return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
|
|
659
560
|
|
|
660
561
|
@classmethod
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
wrapped_model
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
562
|
+
def _compile_model(
|
|
563
|
+
cls,
|
|
564
|
+
wrapped_model,
|
|
565
|
+
compile_config,
|
|
566
|
+
example_inputs,
|
|
567
|
+
compile_context,
|
|
568
|
+
rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
|
|
569
|
+
quantization=None,
|
|
570
|
+
phase: str = "prefill",
|
|
571
|
+
):
|
|
572
|
+
try:
|
|
573
|
+
wrapped_model.phase = phase
|
|
574
|
+
if quantization:
|
|
575
|
+
quantization.maybe_set_quantization_env()
|
|
576
|
+
original_linear = torch.nn.functional.linear
|
|
577
|
+
torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
|
|
578
|
+
compiled_model = cls.compile(
|
|
579
|
+
wrapped_model,
|
|
580
|
+
compile_config,
|
|
581
|
+
create_runtimes=rbln_config.create_runtimes,
|
|
582
|
+
device=rbln_config.device,
|
|
583
|
+
example_inputs=example_inputs,
|
|
584
|
+
compile_context=compile_context,
|
|
585
|
+
)
|
|
586
|
+
return compiled_model
|
|
587
|
+
finally:
|
|
588
|
+
torch.nn.functional.linear = original_linear
|
|
589
|
+
if quantization:
|
|
590
|
+
quantization.maybe_reset_quantization_env()
|
|
667
591
|
|
|
592
|
+
@classmethod
|
|
593
|
+
def _get_compile_context(
|
|
594
|
+
cls,
|
|
595
|
+
compile_config: RBLNCompileConfig,
|
|
596
|
+
example_inputs: List[torch.Tensor],
|
|
597
|
+
):
|
|
668
598
|
context = CompileContext(use_weight_sharing=True)
|
|
669
599
|
|
|
670
|
-
# Here we use meta tensor, for the memory efficiency.
|
|
671
|
-
meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
|
|
672
|
-
prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
|
|
673
|
-
|
|
674
600
|
# Mark static tensors (self kv states)
|
|
675
601
|
static_tensors = {}
|
|
676
|
-
for (name, _, _), tensor in zip(
|
|
602
|
+
for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
|
|
677
603
|
if "past_key_values" in name:
|
|
678
604
|
static_tensors[name] = tensor
|
|
679
605
|
context.mark_static_address(tensor)
|
|
680
606
|
|
|
681
|
-
|
|
682
|
-
try:
|
|
683
|
-
if quantization:
|
|
684
|
-
quantization.maybe_set_quantization_env()
|
|
685
|
-
original_linear = torch.nn.functional.linear
|
|
686
|
-
torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
|
|
687
|
-
compiled_model = cls.compile(
|
|
688
|
-
wrapped_model,
|
|
689
|
-
compile_config,
|
|
690
|
-
create_runtimes=rbln_config.create_runtimes,
|
|
691
|
-
device=rbln_config.device,
|
|
692
|
-
example_inputs=example_inputs,
|
|
693
|
-
compile_context=compile_context,
|
|
694
|
-
)
|
|
695
|
-
return compiled_model
|
|
696
|
-
finally:
|
|
697
|
-
torch.nn.functional.linear = original_linear
|
|
698
|
-
if quantization:
|
|
699
|
-
quantization.maybe_reset_quantization_env()
|
|
700
|
-
|
|
701
|
-
wrapped_model.phase = "prefill"
|
|
702
|
-
compiled_prefill = compile_model(
|
|
703
|
-
wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
|
|
704
|
-
)
|
|
705
|
-
|
|
706
|
-
wrapped_model.phase = "decode"
|
|
707
|
-
compiled_models = {"prefill": compiled_prefill}
|
|
708
|
-
for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
|
|
709
|
-
dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
|
|
710
|
-
compiled_decoder = compile_model(
|
|
711
|
-
wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
|
|
712
|
-
)
|
|
713
|
-
compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
|
|
714
|
-
|
|
715
|
-
# check if the memory is enough to have additional blocks
|
|
716
|
-
required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
|
|
717
|
-
if rbln_config.kvcache_num_blocks < required_num_blocks:
|
|
718
|
-
cls.maybe_suggest_kvcache_num_blocks(
|
|
719
|
-
compiled_models=compiled_models,
|
|
720
|
-
model_config=model.config,
|
|
721
|
-
rbln_config=rbln_config,
|
|
722
|
-
)
|
|
723
|
-
|
|
724
|
-
return compiled_models
|
|
607
|
+
return context, static_tensors
|
|
725
608
|
|
|
726
609
|
@classmethod
|
|
727
|
-
|
|
610
|
+
@torch.inference_mode()
|
|
611
|
+
def get_compiled_model(
|
|
728
612
|
cls,
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
alloc_memory_by_key[key] += sum(memory_per_node)
|
|
743
|
-
alloc_memory_by_key.pop("PortRecur", None) # Old compiler's kv-cache Key
|
|
744
|
-
alloc_memory_by_key.pop("DramTensor", None) # kv-cache
|
|
745
|
-
kernel_size = alloc_memory_by_key.pop("Kernel") # model weight
|
|
746
|
-
|
|
747
|
-
# Get the maximum number of blocks that can be allocated
|
|
748
|
-
buffer = sum(alloc_memory_by_key.values())
|
|
749
|
-
max_num_blocks = cls.get_maximum_num_blocks(
|
|
750
|
-
config=model_config,
|
|
751
|
-
tensor_parallel_size=rbln_config.tensor_parallel_size,
|
|
752
|
-
kvcache_block_size=rbln_config.kvcache_block_size,
|
|
753
|
-
kernel_size=kernel_size,
|
|
754
|
-
buffer=buffer,
|
|
613
|
+
model: PreTrainedModel,
|
|
614
|
+
rbln_config: RBLNDecoderOnlyModelConfig,
|
|
615
|
+
):
|
|
616
|
+
wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
|
|
617
|
+
compile_config = rbln_config.compile_cfgs[0]
|
|
618
|
+
|
|
619
|
+
# Here we use meta tensor, for the memory efficiency.
|
|
620
|
+
meta_tensor_names = [name for name, _, _ in compile_config.input_info if "past_key_values" in name]
|
|
621
|
+
example_inputs = compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
|
|
622
|
+
context, _ = cls._get_compile_context(compile_config, example_inputs)
|
|
623
|
+
|
|
624
|
+
compiled_model = cls._compile_model(
|
|
625
|
+
wrapped_model, compile_config, example_inputs, context, rbln_config, rbln_config.quantization, "prefill"
|
|
755
626
|
)
|
|
627
|
+
compiled_models = {"prefill": compiled_model}
|
|
756
628
|
|
|
757
|
-
|
|
758
|
-
# users can set `kvcache_num_blocks` to `max_num_blocks`.
|
|
759
|
-
# If the memory is not enough, the model will fail to compile.
|
|
760
|
-
if rbln_config.kvcache_num_blocks < max_num_blocks:
|
|
761
|
-
logger.warning(
|
|
762
|
-
f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
|
|
763
|
-
"Our analysis indicates that additional memory is available for more blocks. "
|
|
764
|
-
f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
|
|
765
|
-
"Please be advised that our memory estimation algorithm has limitations, "
|
|
766
|
-
"and increasing this value may not guarantee successful model compilation."
|
|
767
|
-
)
|
|
629
|
+
return compiled_models
|
|
768
630
|
|
|
769
631
|
@classmethod
|
|
770
|
-
def
|
|
771
|
-
cls,
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
kvcache_block_size: int,
|
|
775
|
-
nbits_per_param: Optional[int] = None,
|
|
776
|
-
n_model_params: Optional[int] = None,
|
|
777
|
-
kernel_size: Optional[int] = None,
|
|
778
|
-
buffer: Optional[int] = None,
|
|
779
|
-
num_runtimes: int = 2,
|
|
780
|
-
) -> int:
|
|
781
|
-
# We are finding max_n_blocks(x) that satisfies the following equation:
|
|
782
|
-
|
|
783
|
-
# available_dram - kernel_size - buffer
|
|
784
|
-
# - num_layers * 2 * tensor_parallel_size
|
|
785
|
-
# * align_2MB(
|
|
786
|
-
# x
|
|
787
|
-
# * block_size
|
|
788
|
-
# * align_64(head_dim)
|
|
789
|
-
# * math.ceil(num_key_value_heads / tensor_parallel_size)
|
|
790
|
-
# * 2
|
|
791
|
-
# ) > 0
|
|
792
|
-
|
|
793
|
-
# This inequality can be rewritten as follows:
|
|
794
|
-
|
|
795
|
-
# a - c * align_2MB(b * x) > 0
|
|
796
|
-
# where
|
|
797
|
-
# a = available_dram - kernel_size - buffer
|
|
798
|
-
# b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
|
|
799
|
-
# c = num_layers * 2 * tensor_parallel_size
|
|
800
|
-
|
|
801
|
-
# We can rewrite the inequality as follows:
|
|
802
|
-
# k > align_2MB(b*x)
|
|
803
|
-
# where
|
|
804
|
-
# k = a / c
|
|
805
|
-
|
|
806
|
-
# After that, we can derive the following equation:
|
|
807
|
-
# x = floor(2**21 / b * floor((k - 1) / 2**21))
|
|
808
|
-
|
|
809
|
-
def align(x: int, nbytes: int) -> int:
|
|
810
|
-
return int(math.ceil(x / nbytes) * nbytes)
|
|
811
|
-
|
|
812
|
-
def align_2MB(x: int) -> int:
|
|
813
|
-
return align(x, 2**21)
|
|
814
|
-
|
|
815
|
-
num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
|
|
816
|
-
num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
|
|
817
|
-
head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
|
|
818
|
-
vocab_size = config.vocab_size
|
|
819
|
-
hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
|
|
820
|
-
num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
|
|
821
|
-
|
|
822
|
-
# TODO(jongho): Update if target npu is REBEL.
|
|
823
|
-
ATOM_DRAM_NBYTES = 16 * 2**30
|
|
824
|
-
ATOM_SYS_DRAM_NBYTES = 288 * 2**20
|
|
825
|
-
available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
|
|
826
|
-
|
|
827
|
-
if kernel_size is None:
|
|
828
|
-
if n_model_params is None:
|
|
829
|
-
raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
|
|
830
|
-
# Get estimated kernel size (approximated)
|
|
831
|
-
lm_heads_params = align(vocab_size, 64) * hidden_size
|
|
832
|
-
lm_heads_nbytes = (
|
|
833
|
-
align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
|
|
834
|
-
)
|
|
835
|
-
params = n_model_params - lm_heads_params
|
|
836
|
-
layer_nbytes = (
|
|
837
|
-
align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
|
|
838
|
-
* num_layers
|
|
839
|
-
* tensor_parallel_size
|
|
840
|
-
)
|
|
841
|
-
kernel_size = layer_nbytes + lm_heads_nbytes
|
|
842
|
-
elif n_model_params is not None:
|
|
843
|
-
raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
|
|
844
|
-
|
|
845
|
-
available_dram -= kernel_size
|
|
632
|
+
def get_quantized_model(
|
|
633
|
+
cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
|
|
634
|
+
) -> PreTrainedModel:
|
|
635
|
+
raise NotImplementedError
|
|
846
636
|
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
637
|
+
@classmethod
|
|
638
|
+
def get_pytorch_model(
|
|
639
|
+
cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
|
|
640
|
+
) -> PreTrainedModel:
|
|
641
|
+
if rbln_config and rbln_config.quantization:
|
|
642
|
+
model = cls.get_quantized_model(*args, **kwargs)
|
|
643
|
+
else:
|
|
644
|
+
model = super().get_pytorch_model(*args, **kwargs)
|
|
853
645
|
|
|
854
|
-
|
|
855
|
-
c = num_layers * 2 * tensor_parallel_size
|
|
856
|
-
k = available_dram / c
|
|
857
|
-
max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
|
|
646
|
+
return model
|
|
858
647
|
|
|
859
|
-
|
|
648
|
+
@classmethod
|
|
649
|
+
def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
|
|
650
|
+
return use_local_attention
|
|
860
651
|
|
|
861
652
|
@classmethod
|
|
862
653
|
def get_input_info(
|
|
@@ -866,13 +657,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
866
657
|
rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
|
|
867
658
|
model_config: PretrainedConfig,
|
|
868
659
|
):
|
|
869
|
-
is_prefill: bool = query_length > 1
|
|
870
660
|
num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
|
|
871
661
|
num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
|
|
872
662
|
num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
|
|
873
663
|
hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
|
|
874
664
|
head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
|
|
875
|
-
|
|
665
|
+
is_prefill = query_length > 1
|
|
876
666
|
|
|
877
667
|
# 1. main input
|
|
878
668
|
if rbln_config.use_inputs_embeds:
|
|
@@ -891,16 +681,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
891
681
|
]
|
|
892
682
|
|
|
893
683
|
# 3. block_tables
|
|
894
|
-
if rbln_config.
|
|
684
|
+
if rbln_config.use_global_attention:
|
|
895
685
|
max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
|
|
896
686
|
input_info.extend(
|
|
897
687
|
[("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")]
|
|
898
688
|
)
|
|
899
|
-
if rbln_config.
|
|
689
|
+
if rbln_config.use_local_attention:
|
|
900
690
|
input_info.extend([("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16")])
|
|
901
691
|
|
|
902
|
-
# 4. query_position
|
|
903
|
-
if is_prefill:
|
|
692
|
+
# 4. query_position for sliding window attention
|
|
693
|
+
if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
|
|
904
694
|
input_info.extend([("query_position", [], "int16")])
|
|
905
695
|
|
|
906
696
|
# 5. attention_mask & position_ids
|
|
@@ -922,7 +712,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
922
712
|
rbln_config.kvcache_block_size,
|
|
923
713
|
head_dim,
|
|
924
714
|
]
|
|
925
|
-
local_kvcache_shape = [
|
|
715
|
+
local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
|
|
926
716
|
input_info.extend(
|
|
927
717
|
[
|
|
928
718
|
(
|
|
@@ -969,13 +759,38 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
969
759
|
# ```
|
|
970
760
|
|
|
971
761
|
# Returns:
|
|
972
|
-
#
|
|
762
|
+
# RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.
|
|
973
763
|
|
|
974
764
|
raise NotImplementedError(
|
|
975
765
|
"Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
|
|
976
766
|
"See method docstring for required configuration details."
|
|
977
767
|
)
|
|
978
768
|
|
|
769
|
+
@classmethod
|
|
770
|
+
def _update_attention_config(
|
|
771
|
+
cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
|
|
772
|
+
):
|
|
773
|
+
rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
|
|
774
|
+
attn_impl=rbln_config.attn_impl,
|
|
775
|
+
kvcache_partition_len=rbln_config.kvcache_partition_len,
|
|
776
|
+
kvcache_block_size=rbln_config.kvcache_block_size,
|
|
777
|
+
max_seq_len=rbln_config.max_seq_len,
|
|
778
|
+
)
|
|
779
|
+
|
|
780
|
+
validate_attention_method(
|
|
781
|
+
attn_impl=rbln_config.attn_impl,
|
|
782
|
+
kvcache_partition_len=rbln_config.kvcache_partition_len,
|
|
783
|
+
kvcache_block_size=rbln_config.kvcache_block_size,
|
|
784
|
+
max_seq_len=rbln_config.max_seq_len,
|
|
785
|
+
)
|
|
786
|
+
|
|
787
|
+
if rbln_config.kvcache_num_blocks is None:
|
|
788
|
+
rbln_config.kvcache_num_blocks = (
|
|
789
|
+
rbln_config.max_seq_len // rbln_config.kvcache_block_size
|
|
790
|
+
) * rbln_config.batch_size
|
|
791
|
+
|
|
792
|
+
return rbln_config
|
|
793
|
+
|
|
979
794
|
@classmethod
|
|
980
795
|
def _update_rbln_config(
|
|
981
796
|
cls,
|
|
@@ -996,8 +811,384 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
996
811
|
):
|
|
997
812
|
rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
|
|
998
813
|
if rbln_config.sliding_window is not None:
|
|
999
|
-
|
|
814
|
+
validate_sliding_window(rbln_config)
|
|
815
|
+
|
|
816
|
+
rbln_config = cls._update_attention_config(model, model_config, rbln_config)
|
|
817
|
+
|
|
818
|
+
prefill_input_info = cls.get_input_info(
|
|
819
|
+
batch_size=1,
|
|
820
|
+
query_length=rbln_config.prefill_chunk_size,
|
|
821
|
+
rbln_config=rbln_config,
|
|
822
|
+
model_config=model_config,
|
|
823
|
+
)
|
|
824
|
+
|
|
825
|
+
prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
|
|
826
|
+
rbln_config.set_compile_cfgs([prefill_compile_config])
|
|
1000
827
|
|
|
828
|
+
return rbln_config
|
|
829
|
+
|
|
830
|
+
@classmethod
|
|
831
|
+
def _create_runtimes(
|
|
832
|
+
cls,
|
|
833
|
+
compiled_models: List[rebel.RBLNCompiledModel],
|
|
834
|
+
rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
|
|
835
|
+
) -> List[rebel.Runtime]:
|
|
836
|
+
expected_model_names = [
|
|
837
|
+
"prefill",
|
|
838
|
+
]
|
|
839
|
+
if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
|
|
840
|
+
cls._raise_missing_compiled_file_error(expected_model_names)
|
|
841
|
+
|
|
842
|
+
return [
|
|
843
|
+
rebel.Runtime(
|
|
844
|
+
compiled_models[0],
|
|
845
|
+
tensor_type="pt",
|
|
846
|
+
device=rbln_config.device_map["prefill"],
|
|
847
|
+
activate_profiler=rbln_config.activate_profiler,
|
|
848
|
+
),
|
|
849
|
+
]
|
|
850
|
+
|
|
851
|
+
def _preprocess_chunked_prefill(
|
|
852
|
+
self,
|
|
853
|
+
inputs: torch.Tensor,
|
|
854
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
855
|
+
position_embed: Optional[torch.Tensor] = None,
|
|
856
|
+
):
|
|
857
|
+
# valid sequence length of inputs_embeds
|
|
858
|
+
query_length = inputs.shape[1] if attention_mask is None else torch.sum(attention_mask.view(-1)).item()
|
|
859
|
+
|
|
860
|
+
# extract valid inputs
|
|
861
|
+
inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
|
|
862
|
+
|
|
863
|
+
if inputs.dim() == 2 and self.rbln_config.use_inputs_embeds:
|
|
864
|
+
inputs = self.get_input_embeddings()(inputs)
|
|
865
|
+
|
|
866
|
+
if position_embed is not None:
|
|
867
|
+
position_embed = (
|
|
868
|
+
position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
|
|
869
|
+
)
|
|
870
|
+
|
|
871
|
+
# padding for chunked prefill
|
|
872
|
+
padding_size = (
|
|
873
|
+
self.rbln_config.prefill_chunk_size - (query_length % self.rbln_config.prefill_chunk_size)
|
|
874
|
+
) % self.rbln_config.prefill_chunk_size
|
|
875
|
+
padded_len = query_length + padding_size
|
|
876
|
+
|
|
877
|
+
inputs = (
|
|
878
|
+
torch.nn.functional.pad(inputs, (0, padding_size))
|
|
879
|
+
if not self.rbln_config.use_inputs_embeds
|
|
880
|
+
else torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
|
|
881
|
+
)
|
|
882
|
+
position_embed = (
|
|
883
|
+
None if position_embed is None else torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
|
|
884
|
+
)
|
|
885
|
+
cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)
|
|
886
|
+
|
|
887
|
+
chunked_attention_mask = (
|
|
888
|
+
torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
|
|
889
|
+
if self.rbln_config.use_attention_mask
|
|
890
|
+
else None
|
|
891
|
+
)
|
|
892
|
+
|
|
893
|
+
return inputs, position_embed, cache_position, query_length, chunked_attention_mask
|
|
894
|
+
|
|
895
|
+
def _chunked_prefill_forward(
|
|
896
|
+
self,
|
|
897
|
+
inputs: torch.Tensor,
|
|
898
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
899
|
+
position_embed: Optional[torch.Tensor] = None,
|
|
900
|
+
):
|
|
901
|
+
padded_input, padded_position_embed, cache_position, query_length, chunked_attention_mask = (
|
|
902
|
+
self._preprocess_chunked_prefill(inputs, attention_mask, position_embed)
|
|
903
|
+
)
|
|
904
|
+
|
|
905
|
+
# chunked prefill
|
|
906
|
+
last_hidden_states = []
|
|
907
|
+
for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
|
|
908
|
+
# Extract the current chunk of inputs and cache positions
|
|
909
|
+
input_chunk = padded_input[:, step : step + self.rbln_config.prefill_chunk_size]
|
|
910
|
+
cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
|
|
911
|
+
|
|
912
|
+
valid_length = (
|
|
913
|
+
self.rbln_config.prefill_chunk_size
|
|
914
|
+
if (step + self.rbln_config.prefill_chunk_size) <= query_length
|
|
915
|
+
else query_length - step
|
|
916
|
+
)
|
|
917
|
+
if self.rbln_config.use_local_attention:
|
|
918
|
+
query_position = torch.tensor(valid_length - 1, dtype=torch.int16)
|
|
919
|
+
else:
|
|
920
|
+
query_position = None
|
|
921
|
+
|
|
922
|
+
if self.rbln_config.use_attention_mask:
|
|
923
|
+
if step > 0:
|
|
924
|
+
chunked_attention_mask[:, :, :, :step] = 1
|
|
925
|
+
chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
|
|
926
|
+
|
|
927
|
+
# Forward pass for the current chunk
|
|
928
|
+
last_hidden_states_chunk = self.prefill_decoder(
|
|
929
|
+
input_ids=input_chunk if not self.rbln_config.use_inputs_embeds else None,
|
|
930
|
+
inputs_embeds=input_chunk if self.rbln_config.use_inputs_embeds else None,
|
|
931
|
+
cache_position=cache_pos_chunk,
|
|
932
|
+
block_tables=self.block_tables if self.rbln_config.use_global_attention else None,
|
|
933
|
+
local_block_tables=self.local_block_tables if self.rbln_config.use_local_attention else None,
|
|
934
|
+
query_position=query_position,
|
|
935
|
+
attention_mask=chunked_attention_mask,
|
|
936
|
+
position_emb=padded_position_embed,
|
|
937
|
+
)
|
|
938
|
+
last_hidden_states.append(last_hidden_states_chunk)
|
|
939
|
+
last_hidden_states = torch.concat(last_hidden_states, dim=-2)[:, :query_length]
|
|
940
|
+
|
|
941
|
+
return self._postprocess_chunked_prefill(last_hidden_states, attention_mask)
|
|
942
|
+
|
|
943
|
+
def _postprocess_chunked_prefill(
|
|
944
|
+
self, last_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
|
|
945
|
+
):
|
|
946
|
+
# index copy for attention mask
|
|
947
|
+
if attention_mask is not None:
|
|
948
|
+
new_last_hidden_states = torch.full(
|
|
949
|
+
(1, attention_mask.shape[-1], last_hidden_states.shape[-1]),
|
|
950
|
+
fill_value=1e-10,
|
|
951
|
+
dtype=last_hidden_states.dtype,
|
|
952
|
+
)
|
|
953
|
+
mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
|
|
954
|
+
new_last_hidden_states.index_copy_(dim=-2, index=mask_indices, source=last_hidden_states)
|
|
955
|
+
else:
|
|
956
|
+
new_last_hidden_states = last_hidden_states
|
|
957
|
+
return new_last_hidden_states
|
|
958
|
+
|
|
959
|
+
def forward(
|
|
960
|
+
self,
|
|
961
|
+
input_ids: Optional[torch.LongTensor] = None,
|
|
962
|
+
inputs_embeds: Optional[torch.Tensor] = None,
|
|
963
|
+
attention_mask: Optional[torch.LongTensor] = None,
|
|
964
|
+
position_embed: Optional[torch.Tensor] = None,
|
|
965
|
+
**kwargs,
|
|
966
|
+
) -> Tuple[torch.FloatTensor]:
|
|
967
|
+
inputs = inputs_embeds if inputs_embeds is not None else input_ids
|
|
968
|
+
batch_size = inputs.shape[0]
|
|
969
|
+
all_last_hidden_states = []
|
|
970
|
+
for b_idx in range(batch_size):
|
|
971
|
+
last_hidden_states = self._chunked_prefill_forward(
|
|
972
|
+
inputs[b_idx : b_idx + 1],
|
|
973
|
+
attention_mask[b_idx] if attention_mask is not None else None,
|
|
974
|
+
position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
|
|
975
|
+
)
|
|
976
|
+
all_last_hidden_states.append(last_hidden_states)
|
|
977
|
+
|
|
978
|
+
last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
|
|
979
|
+
return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
|
|
980
|
+
|
|
981
|
+
|
|
982
|
+
class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel):
|
|
983
|
+
"""
|
|
984
|
+
A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
|
|
985
|
+
This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
|
|
986
|
+
|
|
987
|
+
The class provides core functionality for:
|
|
988
|
+
|
|
989
|
+
1. Converting pre-trained transformer models to RBLN-optimized format
|
|
990
|
+
2. Handling the compilation process for RBLN devices
|
|
991
|
+
3. Managing inference operations for causal language modeling
|
|
992
|
+
|
|
993
|
+
This class inherits from RBLNModel and implements specific methods required for
|
|
994
|
+
decoder-only architectures and causal language modeling tasks.
|
|
995
|
+
|
|
996
|
+
Note:
|
|
997
|
+
- This class is designed to be subclassed by specific model implementations
|
|
998
|
+
(e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
|
|
999
|
+
- Subclasses should implement model-specific conversion logic.
|
|
1000
|
+
- The class handles RBLN-specific optimizations automatically during compilation
|
|
1001
|
+
"""
|
|
1002
|
+
|
|
1003
|
+
auto_model_class = AutoModelForCausalLM
|
|
1004
|
+
|
|
1005
|
+
def __post_init__(self, **kwargs):
|
|
1006
|
+
main_input_name = self.main_input_name
|
|
1007
|
+
|
|
1008
|
+
if self.rbln_config.use_inputs_embeds:
|
|
1009
|
+
main_input_name = "inputs_embeds"
|
|
1010
|
+
artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
|
|
1011
|
+
self.embed_tokens = self._create_embedding_layer()
|
|
1012
|
+
self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
|
|
1013
|
+
else:
|
|
1014
|
+
self.embed_tokens = None
|
|
1015
|
+
|
|
1016
|
+
# Initialize shared resources to be used across Runtime instances (prefill and decode phases)
|
|
1017
|
+
dec_attn_mask = torch.zeros(
|
|
1018
|
+
self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
|
|
1019
|
+
)
|
|
1020
|
+
block_tables = torch.zeros(
|
|
1021
|
+
self.rbln_config.batch_size,
|
|
1022
|
+
self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
|
|
1023
|
+
dtype=torch.int16,
|
|
1024
|
+
).fill_(-1)
|
|
1025
|
+
free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
|
|
1026
|
+
|
|
1027
|
+
self.prefill_decoder = RBLNRuntimeModel(
|
|
1028
|
+
runtime=self.model[0],
|
|
1029
|
+
main_input_name=main_input_name,
|
|
1030
|
+
embed_tokens=self.embed_tokens,
|
|
1031
|
+
phase="prefill",
|
|
1032
|
+
batch_size=self.rbln_config.batch_size,
|
|
1033
|
+
dec_attn_mask=dec_attn_mask,
|
|
1034
|
+
block_tables=block_tables,
|
|
1035
|
+
free_block_pool=free_block_pool,
|
|
1036
|
+
rbln_config=self.rbln_config,
|
|
1037
|
+
vocab_size=self.config.vocab_size,
|
|
1038
|
+
)
|
|
1039
|
+
|
|
1040
|
+
if self.can_generate():
|
|
1041
|
+
self.decoders = {}
|
|
1042
|
+
for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
|
|
1043
|
+
self.decoders[batch_size] = RBLNRuntimeModel(
|
|
1044
|
+
runtime=self.model[i + 1],
|
|
1045
|
+
main_input_name=main_input_name,
|
|
1046
|
+
embed_tokens=self.embed_tokens,
|
|
1047
|
+
phase="decode",
|
|
1048
|
+
batch_size=batch_size,
|
|
1049
|
+
dec_attn_mask=dec_attn_mask,
|
|
1050
|
+
block_tables=block_tables,
|
|
1051
|
+
free_block_pool=free_block_pool,
|
|
1052
|
+
rbln_config=self.rbln_config,
|
|
1053
|
+
)
|
|
1054
|
+
|
|
1055
|
+
# NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
|
|
1056
|
+
self.decoder = self.decoders[self.rbln_config.batch_size]
|
|
1057
|
+
|
|
1058
|
+
@classmethod
|
|
1059
|
+
def get_quantized_model(
|
|
1060
|
+
cls,
|
|
1061
|
+
model_id: str,
|
|
1062
|
+
config: Optional[PretrainedConfig] = None,
|
|
1063
|
+
use_auth_token: Optional[Union[bool, str]] = None,
|
|
1064
|
+
revision: Optional[str] = None,
|
|
1065
|
+
force_download: bool = False,
|
|
1066
|
+
cache_dir: Optional[str] = None,
|
|
1067
|
+
subfolder: str = "",
|
|
1068
|
+
local_files_only: bool = False,
|
|
1069
|
+
trust_remote_code: bool = False,
|
|
1070
|
+
**kwargs,
|
|
1071
|
+
):
|
|
1072
|
+
kwargs = cls.update_kwargs(kwargs)
|
|
1073
|
+
|
|
1074
|
+
if config is None:
|
|
1075
|
+
config = AutoConfig.from_pretrained(
|
|
1076
|
+
model_id,
|
|
1077
|
+
use_auth_token=use_auth_token,
|
|
1078
|
+
revision=revision,
|
|
1079
|
+
force_download=force_download,
|
|
1080
|
+
cache_dir=cache_dir,
|
|
1081
|
+
trust_remote_code=trust_remote_code,
|
|
1082
|
+
**kwargs,
|
|
1083
|
+
)
|
|
1084
|
+
|
|
1085
|
+
with no_init_weights():
|
|
1086
|
+
model = AutoModelForCausalLM.from_config(config)
|
|
1087
|
+
|
|
1088
|
+
model = prepare_model_for_quantization(
|
|
1089
|
+
model,
|
|
1090
|
+
model_id,
|
|
1091
|
+
kwargs.get("num_hidden_layers"),
|
|
1092
|
+
use_auth_token=use_auth_token,
|
|
1093
|
+
revision=revision,
|
|
1094
|
+
cache_dir=cache_dir,
|
|
1095
|
+
force_download=force_download,
|
|
1096
|
+
local_files_only=local_files_only,
|
|
1097
|
+
)
|
|
1098
|
+
return model
|
|
1099
|
+
|
|
1100
|
+
def __getattr__(self, __name: str) -> Any:
|
|
1101
|
+
# Special method to delegate attribute access to the original Huggingface LM class.
|
|
1102
|
+
# This method is called when an attribute is not found in the current instance's dictionary.
|
|
1103
|
+
# It enables transparent access to the original model's attributes and methods while maintaining
|
|
1104
|
+
# proper method binding.
|
|
1105
|
+
|
|
1106
|
+
# The method implements a delegation pattern that:
|
|
1107
|
+
|
|
1108
|
+
# 1. For methods: Creates a wrapper that properly binds 'self' to method calls
|
|
1109
|
+
# 2. For other attributes: Returns them directly from the original class
|
|
1110
|
+
|
|
1111
|
+
def redirect(func):
|
|
1112
|
+
return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
|
|
1113
|
+
|
|
1114
|
+
val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
|
|
1115
|
+
if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
|
|
1116
|
+
return redirect(val)
|
|
1117
|
+
return val
|
|
1118
|
+
|
|
1119
|
+
@classmethod
|
|
1120
|
+
def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
|
|
1121
|
+
wrapper_cfg = {
|
|
1122
|
+
"max_seq_len": rbln_config.max_seq_len,
|
|
1123
|
+
"attn_impl": rbln_config.attn_impl,
|
|
1124
|
+
"kvcache_partition_len": rbln_config.kvcache_partition_len,
|
|
1125
|
+
"kvcache_block_size": rbln_config.kvcache_block_size,
|
|
1126
|
+
"use_rotary_emb": cls._use_rotary_emb,
|
|
1127
|
+
"use_attention_mask": rbln_config.use_attention_mask,
|
|
1128
|
+
"use_position_ids": rbln_config.use_position_ids,
|
|
1129
|
+
"use_inputs_embeds": rbln_config.use_inputs_embeds,
|
|
1130
|
+
"cache_impl": rbln_config.cache_impl,
|
|
1131
|
+
"sliding_window": rbln_config.sliding_window,
|
|
1132
|
+
"sliding_window_layers": rbln_config.sliding_window_layers,
|
|
1133
|
+
}
|
|
1134
|
+
return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
|
|
1135
|
+
|
|
1136
|
+
@classmethod
|
|
1137
|
+
@torch.inference_mode()
|
|
1138
|
+
def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
|
|
1139
|
+
wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
|
|
1140
|
+
prefill_compile_config = rbln_config.compile_cfgs[0]
|
|
1141
|
+
|
|
1142
|
+
# Here we use meta tensor, for the memory efficiency.
|
|
1143
|
+
meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
|
|
1144
|
+
prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
|
|
1145
|
+
context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
|
|
1146
|
+
|
|
1147
|
+
compiled_models = {}
|
|
1148
|
+
compiled_models["prefill"] = cls._compile_model(
|
|
1149
|
+
wrapped_model,
|
|
1150
|
+
prefill_compile_config,
|
|
1151
|
+
prefill_example_inputs,
|
|
1152
|
+
context,
|
|
1153
|
+
rbln_config,
|
|
1154
|
+
rbln_config.quantization,
|
|
1155
|
+
phase="prefill",
|
|
1156
|
+
)
|
|
1157
|
+
|
|
1158
|
+
if rbln_config.can_generate:
|
|
1159
|
+
wrapped_model.phase = "decode"
|
|
1160
|
+
for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
|
|
1161
|
+
dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
|
|
1162
|
+
compiled_decoder = cls._compile_model(
|
|
1163
|
+
wrapped_model,
|
|
1164
|
+
dec_compile_config,
|
|
1165
|
+
dec_example_inputs,
|
|
1166
|
+
context,
|
|
1167
|
+
rbln_config,
|
|
1168
|
+
rbln_config.quantization,
|
|
1169
|
+
phase="decode",
|
|
1170
|
+
)
|
|
1171
|
+
compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
|
|
1172
|
+
|
|
1173
|
+
# check if the memory is enough to have additional blocks
|
|
1174
|
+
required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
|
|
1175
|
+
if rbln_config.kvcache_num_blocks < required_num_blocks:
|
|
1176
|
+
cls.maybe_suggest_kvcache_num_blocks(
|
|
1177
|
+
compiled_models=compiled_models,
|
|
1178
|
+
model_config=model.config,
|
|
1179
|
+
rbln_config=rbln_config,
|
|
1180
|
+
)
|
|
1181
|
+
|
|
1182
|
+
return compiled_models
|
|
1183
|
+
|
|
1184
|
+
@classmethod
|
|
1185
|
+
def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
|
|
1186
|
+
return is_prefill
|
|
1187
|
+
|
|
1188
|
+
@classmethod
|
|
1189
|
+
def _update_attention_config(
|
|
1190
|
+
cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
|
|
1191
|
+
):
|
|
1001
1192
|
rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
|
|
1002
1193
|
attn_impl=rbln_config.attn_impl,
|
|
1003
1194
|
kvcache_partition_len=rbln_config.kvcache_partition_len,
|
|
@@ -1022,13 +1213,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
1022
1213
|
kvcache_block_size=rbln_config.kvcache_block_size,
|
|
1023
1214
|
nbits_per_param=16 if not rbln_config.quantization else 4, # TODO(jongho): FIX Ad-hoc
|
|
1024
1215
|
n_model_params=sum(p.numel() for p in model.parameters()),
|
|
1025
|
-
num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
|
|
1216
|
+
num_runtimes=1 if rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
|
|
1026
1217
|
)
|
|
1027
1218
|
|
|
1028
1219
|
max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
|
|
1029
1220
|
|
|
1030
1221
|
flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
|
|
1031
|
-
if max_num_blocks < flash_min_blocks:
|
|
1222
|
+
if rbln_config.batch_size > 1 and max_num_blocks < flash_min_blocks:
|
|
1032
1223
|
max_num_blocks = flash_min_blocks
|
|
1033
1224
|
|
|
1034
1225
|
if max_num_blocks < rbln_config.batch_size:
|
|
@@ -1047,27 +1238,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
1047
1238
|
)
|
|
1048
1239
|
logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
|
|
1049
1240
|
|
|
1050
|
-
|
|
1051
|
-
batch_size=1,
|
|
1052
|
-
query_length=rbln_config.prefill_chunk_size,
|
|
1053
|
-
rbln_config=rbln_config,
|
|
1054
|
-
model_config=model_config,
|
|
1055
|
-
)
|
|
1056
|
-
|
|
1057
|
-
prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
|
|
1241
|
+
return rbln_config
|
|
1058
1242
|
|
|
1059
|
-
|
|
1060
|
-
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1243
|
+
@classmethod
|
|
1244
|
+
def _update_rbln_config(
|
|
1245
|
+
cls,
|
|
1246
|
+
preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
|
|
1247
|
+
model: Optional[PreTrainedModel] = None,
|
|
1248
|
+
model_config: Optional[PretrainedConfig] = None,
|
|
1249
|
+
rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
|
|
1250
|
+
) -> RBLNDecoderOnlyModelForCausalLMConfig:
|
|
1251
|
+
rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
|
|
1252
|
+
if rbln_config.can_generate:
|
|
1253
|
+
compile_configs = rbln_config.compile_cfgs
|
|
1254
|
+
for batch_size in rbln_config.decoder_batch_sizes:
|
|
1255
|
+
dec_input_info = cls.get_input_info(
|
|
1256
|
+
batch_size=batch_size,
|
|
1257
|
+
query_length=1,
|
|
1258
|
+
rbln_config=rbln_config,
|
|
1259
|
+
model_config=model_config,
|
|
1260
|
+
)
|
|
1261
|
+
compile_configs.append(
|
|
1262
|
+
RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
|
|
1263
|
+
)
|
|
1264
|
+
rbln_config.set_compile_cfgs(compile_configs)
|
|
1071
1265
|
|
|
1072
1266
|
return rbln_config
|
|
1073
1267
|
|
|
@@ -1077,38 +1271,45 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
1077
1271
|
compiled_models: List[rebel.RBLNCompiledModel],
|
|
1078
1272
|
rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
|
|
1079
1273
|
) -> List[rebel.Runtime]:
|
|
1080
|
-
expected_model_names = [
|
|
1081
|
-
|
|
1082
|
-
|
|
1083
|
-
|
|
1274
|
+
expected_model_names = ["prefill"]
|
|
1275
|
+
if rbln_config.can_generate:
|
|
1276
|
+
expected_model_names.extend(
|
|
1277
|
+
[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
|
|
1278
|
+
)
|
|
1084
1279
|
if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
|
|
1085
1280
|
cls._raise_missing_compiled_file_error(expected_model_names)
|
|
1086
1281
|
|
|
1087
|
-
|
|
1282
|
+
ret_val = [
|
|
1088
1283
|
rebel.Runtime(
|
|
1089
1284
|
compiled_models[0],
|
|
1090
1285
|
tensor_type="pt",
|
|
1091
1286
|
device=rbln_config.device_map["prefill"],
|
|
1092
1287
|
activate_profiler=rbln_config.activate_profiler,
|
|
1093
1288
|
timeout=rbln_config.timeout,
|
|
1094
|
-
)
|
|
1095
|
-
*[
|
|
1096
|
-
rebel.Runtime(
|
|
1097
|
-
compiled_models[i + 1],
|
|
1098
|
-
tensor_type="pt",
|
|
1099
|
-
device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
|
|
1100
|
-
activate_profiler=rbln_config.activate_profiler,
|
|
1101
|
-
timeout=rbln_config.timeout,
|
|
1102
|
-
)
|
|
1103
|
-
for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
|
|
1104
|
-
],
|
|
1289
|
+
)
|
|
1105
1290
|
]
|
|
1291
|
+
if rbln_config.can_generate:
|
|
1292
|
+
ret_val.extend(
|
|
1293
|
+
[
|
|
1294
|
+
rebel.Runtime(
|
|
1295
|
+
compiled_models[i + 1],
|
|
1296
|
+
tensor_type="pt",
|
|
1297
|
+
device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
|
|
1298
|
+
activate_profiler=rbln_config.activate_profiler,
|
|
1299
|
+
timeout=rbln_config.timeout,
|
|
1300
|
+
)
|
|
1301
|
+
for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
|
|
1302
|
+
]
|
|
1303
|
+
)
|
|
1304
|
+
return ret_val
|
|
1106
1305
|
|
|
1107
1306
|
def get_decoder(self):
|
|
1307
|
+
if not self.can_generate():
|
|
1308
|
+
raise ValueError("Decode stage is not supported in this model.")
|
|
1108
1309
|
return self.decoder
|
|
1109
1310
|
|
|
1110
1311
|
def can_generate(self):
|
|
1111
|
-
return
|
|
1312
|
+
return self.rbln_config.can_generate
|
|
1112
1313
|
|
|
1113
1314
|
def _reorder_cache(self, past_key_values, beam_idx):
|
|
1114
1315
|
raise NotImplementedError
|
|
@@ -1165,7 +1366,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
1165
1366
|
|
|
1166
1367
|
def _update_model_kwargs_for_generation(
|
|
1167
1368
|
self,
|
|
1168
|
-
outputs:
|
|
1369
|
+
outputs: RBLNDecoderOnlyForCausalLMOutput,
|
|
1169
1370
|
model_kwargs: Dict[str, Any],
|
|
1170
1371
|
**kwargs,
|
|
1171
1372
|
) -> Dict[str, Any]:
|
|
@@ -1193,15 +1394,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
1193
1394
|
# A for-loop ensures synchronization with the HuggingFace generate API.
|
|
1194
1395
|
# The decoder stage operates as usual, processing inputs in batch mode.
|
|
1195
1396
|
|
|
1397
|
+
# for only use forward
|
|
1398
|
+
if not self.can_generate():
|
|
1399
|
+
generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
|
|
1400
|
+
padded_cache_lengths = torch.zeros_like(generate_idx)
|
|
1401
|
+
|
|
1196
1402
|
# Prefll
|
|
1197
1403
|
if cache_position is None:
|
|
1198
1404
|
logits = []
|
|
1199
1405
|
inputs = inputs_embeds if inputs_embeds is not None else input_ids
|
|
1200
|
-
# for only use forward
|
|
1201
|
-
if generate_idx is None:
|
|
1202
|
-
generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
|
|
1203
|
-
if padded_cache_lengths is None:
|
|
1204
|
-
padded_cache_lengths = torch.zeros_like(generate_idx)
|
|
1205
1406
|
batch_size = inputs.shape[0]
|
|
1206
1407
|
for b_idx in range(batch_size):
|
|
1207
1408
|
cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
|
|
@@ -1236,6 +1437,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
1236
1437
|
if not return_dict:
|
|
1237
1438
|
return logits, generate_idx, padded_cache_lengths
|
|
1238
1439
|
else:
|
|
1239
|
-
return
|
|
1440
|
+
return RBLNDecoderOnlyForCausalLMOutput(
|
|
1240
1441
|
logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
|
|
1241
1442
|
)
|