optimum-rbln 0.8.2a4__py3-none-any.whl → 0.8.2a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of optimum-rbln might be problematic. Click here for more details.
- optimum/rbln/__init__.py +44 -0
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +4 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +48 -0
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/models/__init__.py +35 -14
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -205
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +569 -366
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +13 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +7 -5
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +82 -59
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -7
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +16 -1
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -2
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +13 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +379 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +163 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +6 -6
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +318 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -3
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +10 -328
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +0 -241
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +0 -10
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +1 -10
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +5 -1
- optimum/rbln/utils/depreacate_utils.py +16 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/RECORD +64 -51
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/licenses/LICENSE +0 -0
|
@@ -13,7 +13,6 @@
|
|
|
13
13
|
# limitations under the License.
|
|
14
14
|
|
|
15
15
|
import inspect
|
|
16
|
-
import math
|
|
17
16
|
from collections import deque
|
|
18
17
|
from dataclasses import dataclass
|
|
19
18
|
from pathlib import Path
|
|
@@ -22,7 +21,8 @@ from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tu
|
|
|
22
21
|
import rebel
|
|
23
22
|
import torch
|
|
24
23
|
from rebel.compile_context import CompileContext
|
|
25
|
-
from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
|
|
24
|
+
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
|
|
25
|
+
from transformers.modeling_outputs import BaseModelOutputWithPast
|
|
26
26
|
from transformers.modeling_utils import no_init_weights
|
|
27
27
|
from transformers.utils import ModelOutput
|
|
28
28
|
|
|
@@ -30,14 +30,15 @@ from ....configuration_utils import RBLNCompileConfig
|
|
|
30
30
|
from ....modeling import RBLNModel
|
|
31
31
|
from ....utils.logging import get_logger
|
|
32
32
|
from ....utils.runtime_utils import RBLNPytorchRuntime
|
|
33
|
-
from ...
|
|
34
|
-
|
|
35
|
-
from .decoderonly_architecture import (
|
|
36
|
-
DecoderOnlyWrapper,
|
|
33
|
+
from ...modeling_attention_utils import (
|
|
34
|
+
RBLNDecoderOnlyFlashAttentionMixin,
|
|
37
35
|
set_default_values,
|
|
38
36
|
validate_attention_method,
|
|
39
|
-
|
|
37
|
+
validate_sliding_window,
|
|
40
38
|
)
|
|
39
|
+
from ...utils.rbln_quantization import prepare_model_for_quantization
|
|
40
|
+
from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
|
|
41
|
+
from .decoderonly_architecture import DecoderOnlyWrapper
|
|
41
42
|
|
|
42
43
|
|
|
43
44
|
logger = get_logger()
|
|
@@ -267,7 +268,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
|
|
|
267
268
|
|
|
268
269
|
attention_mask = self.dec_attn_mask
|
|
269
270
|
|
|
270
|
-
if self.rbln_config.
|
|
271
|
+
if self.rbln_config.use_global_attention and self.batch_size < block_tables.shape[0]:
|
|
271
272
|
block_tables = block_tables[: self.batch_size]
|
|
272
273
|
|
|
273
274
|
if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
|
|
@@ -283,7 +284,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
|
|
|
283
284
|
position_ids if self.rbln_config.use_position_ids else None,
|
|
284
285
|
)
|
|
285
286
|
|
|
286
|
-
return
|
|
287
|
+
return RBLNDecoderOnlyForCausalLMOutput(logits=logits)
|
|
287
288
|
|
|
288
289
|
def _prepare_prefill_inputs(
|
|
289
290
|
self,
|
|
@@ -449,94 +450,64 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
|
|
|
449
450
|
self.dec_attn_mask[batch_idx].fill_(0)
|
|
450
451
|
self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
|
|
451
452
|
|
|
452
|
-
return
|
|
453
|
+
return RBLNDecoderOnlyForCausalLMOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)
|
|
453
454
|
|
|
454
455
|
|
|
455
456
|
@dataclass
|
|
456
|
-
class
|
|
457
|
+
class RBLNDecoderOnlyForCausalLMOutput(ModelOutput):
|
|
457
458
|
logits: torch.FloatTensor = None
|
|
458
459
|
generate_idx: torch.Tensor = None
|
|
459
460
|
padded_cache_lengths: int = None
|
|
460
461
|
|
|
461
462
|
|
|
462
|
-
class
|
|
463
|
+
class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
|
|
463
464
|
"""
|
|
464
|
-
A base class for decoder-only transformer models
|
|
465
|
+
A base class for decoder-only transformer models outputting raw hidden-states without any specific head on top.
|
|
466
|
+
This class is used for RBLN-optimized models that are not causal language models.
|
|
465
467
|
This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
|
|
466
468
|
|
|
467
469
|
The class provides core functionality for:
|
|
468
470
|
|
|
469
471
|
1. Converting pre-trained transformer models to RBLN-optimized format
|
|
470
472
|
2. Handling the compilation process for RBLN devices
|
|
471
|
-
3. Managing inference operations for
|
|
473
|
+
3. Managing inference operations for decoder-only architectures
|
|
472
474
|
|
|
473
475
|
This class inherits from RBLNModel and implements specific methods required for
|
|
474
|
-
decoder-only architectures
|
|
476
|
+
decoder-only architectures.
|
|
475
477
|
|
|
476
478
|
Note:
|
|
477
479
|
- This class is designed to be subclassed by specific model implementations
|
|
478
|
-
(e.g.,
|
|
480
|
+
(e.g., RBLNLlamaModel, RBLNQwen2Model)
|
|
479
481
|
- Subclasses should implement model-specific conversion logic.
|
|
480
482
|
- The class handles RBLN-specific optimizations automatically during compilation
|
|
481
483
|
"""
|
|
482
484
|
|
|
483
485
|
main_input_name = "input_ids"
|
|
484
|
-
auto_model_class =
|
|
486
|
+
auto_model_class = AutoModel
|
|
485
487
|
_decoder_wrapper_cls = DecoderOnlyWrapper
|
|
486
488
|
_use_rotary_emb = True
|
|
487
489
|
|
|
488
490
|
def __post_init__(self, **kwargs):
|
|
489
|
-
main_input_name = self.main_input_name
|
|
490
|
-
|
|
491
491
|
if self.rbln_config.use_inputs_embeds:
|
|
492
|
-
main_input_name = "inputs_embeds"
|
|
493
492
|
artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
|
|
494
493
|
self.embed_tokens = self._create_embedding_layer()
|
|
495
494
|
self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
|
|
496
495
|
else:
|
|
497
496
|
self.embed_tokens = None
|
|
498
497
|
|
|
499
|
-
#
|
|
500
|
-
|
|
501
|
-
self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
|
|
502
|
-
)
|
|
503
|
-
block_tables = torch.zeros(
|
|
504
|
-
self.rbln_config.batch_size,
|
|
505
|
-
self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
|
|
506
|
-
dtype=torch.int16,
|
|
507
|
-
).fill_(-1)
|
|
508
|
-
free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
|
|
509
|
-
|
|
510
|
-
self.prefill_decoder = RBLNRuntimeModel(
|
|
511
|
-
runtime=self.model[0],
|
|
512
|
-
main_input_name=main_input_name,
|
|
513
|
-
embed_tokens=self.embed_tokens,
|
|
514
|
-
phase="prefill",
|
|
515
|
-
batch_size=self.rbln_config.batch_size,
|
|
516
|
-
dec_attn_mask=dec_attn_mask,
|
|
517
|
-
block_tables=block_tables,
|
|
518
|
-
free_block_pool=free_block_pool,
|
|
519
|
-
rbln_config=self.rbln_config,
|
|
520
|
-
vocab_size=self.config.vocab_size,
|
|
521
|
-
)
|
|
498
|
+
# TODO: add prefill runtime class.
|
|
499
|
+
self.prefill_decoder = RBLNPytorchRuntime(runtime=self.model[0])
|
|
522
500
|
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
self.
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
dec_attn_mask=dec_attn_mask,
|
|
532
|
-
block_tables=block_tables,
|
|
533
|
-
free_block_pool=free_block_pool,
|
|
534
|
-
rbln_config=self.rbln_config,
|
|
501
|
+
# attributes for prefill
|
|
502
|
+
if self.rbln_config.use_global_attention:
|
|
503
|
+
self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
|
|
504
|
+
if self.rbln_config.use_local_attention:
|
|
505
|
+
self.local_block_tables = torch.tensor([0], dtype=torch.int16)
|
|
506
|
+
if self.rbln_config.use_attention_mask:
|
|
507
|
+
self.causal_mask = 1 - torch.triu(
|
|
508
|
+
torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
|
|
535
509
|
)
|
|
536
510
|
|
|
537
|
-
# NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
|
|
538
|
-
self.decoder = self.decoders[self.rbln_config.batch_size]
|
|
539
|
-
|
|
540
511
|
@classmethod
|
|
541
512
|
def save_torch_artifacts(
|
|
542
513
|
cls,
|
|
@@ -571,79 +542,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
571
542
|
return self.rbln_config.kvcache_num_blocks
|
|
572
543
|
|
|
573
544
|
@classmethod
|
|
574
|
-
def
|
|
575
|
-
cls,
|
|
576
|
-
model_id: str,
|
|
577
|
-
config: Optional[PretrainedConfig] = None,
|
|
578
|
-
use_auth_token: Optional[Union[bool, str]] = None,
|
|
579
|
-
revision: Optional[str] = None,
|
|
580
|
-
force_download: bool = False,
|
|
581
|
-
cache_dir: Optional[str] = None,
|
|
582
|
-
subfolder: str = "",
|
|
583
|
-
local_files_only: bool = False,
|
|
584
|
-
trust_remote_code: bool = False,
|
|
585
|
-
**kwargs,
|
|
586
|
-
):
|
|
587
|
-
kwargs = cls.update_kwargs(kwargs)
|
|
588
|
-
|
|
589
|
-
if config is None:
|
|
590
|
-
config = AutoConfig.from_pretrained(
|
|
591
|
-
model_id,
|
|
592
|
-
use_auth_token=use_auth_token,
|
|
593
|
-
revision=revision,
|
|
594
|
-
force_download=force_download,
|
|
595
|
-
cache_dir=cache_dir,
|
|
596
|
-
trust_remote_code=trust_remote_code,
|
|
597
|
-
**kwargs,
|
|
598
|
-
)
|
|
599
|
-
|
|
600
|
-
with no_init_weights():
|
|
601
|
-
model = AutoModelForCausalLM.from_config(config)
|
|
602
|
-
|
|
603
|
-
model = prepare_model_for_quantization(
|
|
604
|
-
model,
|
|
605
|
-
model_id,
|
|
606
|
-
kwargs.get("num_hidden_layers"),
|
|
607
|
-
use_auth_token=use_auth_token,
|
|
608
|
-
revision=revision,
|
|
609
|
-
cache_dir=cache_dir,
|
|
610
|
-
force_download=force_download,
|
|
611
|
-
local_files_only=local_files_only,
|
|
612
|
-
)
|
|
613
|
-
return model
|
|
614
|
-
|
|
615
|
-
def __getattr__(self, __name: str) -> Any:
|
|
616
|
-
# Special method to delegate attribute access to the original Huggingface LM class.
|
|
617
|
-
# This method is called when an attribute is not found in the current instance's dictionary.
|
|
618
|
-
# It enables transparent access to the original model's attributes and methods while maintaining
|
|
619
|
-
# proper method binding.
|
|
620
|
-
|
|
621
|
-
# The method implements a delegation pattern that:
|
|
622
|
-
|
|
623
|
-
# 1. For methods: Creates a wrapper that properly binds 'self' to method calls
|
|
624
|
-
# 2. For other attributes: Returns them directly from the original class
|
|
625
|
-
|
|
626
|
-
def redirect(func):
|
|
627
|
-
return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
|
|
628
|
-
|
|
629
|
-
val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
|
|
630
|
-
if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
|
|
631
|
-
return redirect(val)
|
|
632
|
-
return val
|
|
633
|
-
|
|
634
|
-
@classmethod
|
|
635
|
-
def get_pytorch_model(
|
|
636
|
-
cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
|
|
637
|
-
) -> PreTrainedModel:
|
|
638
|
-
if rbln_config and rbln_config.quantization:
|
|
639
|
-
model = cls.get_quantized_model(*args, **kwargs)
|
|
640
|
-
else:
|
|
641
|
-
model = super().get_pytorch_model(*args, **kwargs)
|
|
642
|
-
|
|
643
|
-
return model
|
|
644
|
-
|
|
645
|
-
@classmethod
|
|
646
|
-
def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
|
|
545
|
+
def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
|
|
647
546
|
wrapper_cfg = {
|
|
648
547
|
"max_seq_len": rbln_config.max_seq_len,
|
|
649
548
|
"attn_impl": rbln_config.attn_impl,
|
|
@@ -660,205 +559,95 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
660
559
|
return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
|
|
661
560
|
|
|
662
561
|
@classmethod
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
wrapped_model
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
562
|
+
def _compile_model(
|
|
563
|
+
cls,
|
|
564
|
+
wrapped_model,
|
|
565
|
+
compile_config,
|
|
566
|
+
example_inputs,
|
|
567
|
+
compile_context,
|
|
568
|
+
rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
|
|
569
|
+
quantization=None,
|
|
570
|
+
phase: str = "prefill",
|
|
571
|
+
):
|
|
572
|
+
try:
|
|
573
|
+
wrapped_model.phase = phase
|
|
574
|
+
if quantization:
|
|
575
|
+
quantization.maybe_set_quantization_env()
|
|
576
|
+
original_linear = torch.nn.functional.linear
|
|
577
|
+
torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
|
|
578
|
+
compiled_model = cls.compile(
|
|
579
|
+
wrapped_model,
|
|
580
|
+
compile_config,
|
|
581
|
+
create_runtimes=rbln_config.create_runtimes,
|
|
582
|
+
device=rbln_config.device,
|
|
583
|
+
example_inputs=example_inputs,
|
|
584
|
+
compile_context=compile_context,
|
|
585
|
+
)
|
|
586
|
+
return compiled_model
|
|
587
|
+
finally:
|
|
588
|
+
torch.nn.functional.linear = original_linear
|
|
589
|
+
if quantization:
|
|
590
|
+
quantization.maybe_reset_quantization_env()
|
|
669
591
|
|
|
592
|
+
@classmethod
|
|
593
|
+
def _get_compile_context(
|
|
594
|
+
cls,
|
|
595
|
+
compile_config: RBLNCompileConfig,
|
|
596
|
+
example_inputs: List[torch.Tensor],
|
|
597
|
+
):
|
|
670
598
|
context = CompileContext(use_weight_sharing=True)
|
|
671
599
|
|
|
672
|
-
# Here we use meta tensor, for the memory efficiency.
|
|
673
|
-
meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
|
|
674
|
-
prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
|
|
675
|
-
|
|
676
600
|
# Mark static tensors (self kv states)
|
|
677
601
|
static_tensors = {}
|
|
678
|
-
for (name, _, _), tensor in zip(
|
|
602
|
+
for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
|
|
679
603
|
if "past_key_values" in name:
|
|
680
604
|
static_tensors[name] = tensor
|
|
681
605
|
context.mark_static_address(tensor)
|
|
682
606
|
|
|
683
|
-
|
|
684
|
-
try:
|
|
685
|
-
if quantization:
|
|
686
|
-
quantization.maybe_set_quantization_env()
|
|
687
|
-
original_linear = torch.nn.functional.linear
|
|
688
|
-
torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
|
|
689
|
-
compiled_model = cls.compile(
|
|
690
|
-
wrapped_model,
|
|
691
|
-
compile_config,
|
|
692
|
-
create_runtimes=rbln_config.create_runtimes,
|
|
693
|
-
device=rbln_config.device,
|
|
694
|
-
example_inputs=example_inputs,
|
|
695
|
-
compile_context=compile_context,
|
|
696
|
-
)
|
|
697
|
-
return compiled_model
|
|
698
|
-
finally:
|
|
699
|
-
torch.nn.functional.linear = original_linear
|
|
700
|
-
if quantization:
|
|
701
|
-
quantization.maybe_reset_quantization_env()
|
|
702
|
-
|
|
703
|
-
wrapped_model.phase = "prefill"
|
|
704
|
-
compiled_prefill = compile_model(
|
|
705
|
-
wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
|
|
706
|
-
)
|
|
707
|
-
|
|
708
|
-
wrapped_model.phase = "decode"
|
|
709
|
-
compiled_models = {"prefill": compiled_prefill}
|
|
710
|
-
for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
|
|
711
|
-
dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
|
|
712
|
-
compiled_decoder = compile_model(
|
|
713
|
-
wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
|
|
714
|
-
)
|
|
715
|
-
compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
|
|
716
|
-
|
|
717
|
-
# check if the memory is enough to have additional blocks
|
|
718
|
-
required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
|
|
719
|
-
if rbln_config.kvcache_num_blocks < required_num_blocks:
|
|
720
|
-
cls.maybe_suggest_kvcache_num_blocks(
|
|
721
|
-
compiled_models=compiled_models,
|
|
722
|
-
model_config=model.config,
|
|
723
|
-
rbln_config=rbln_config,
|
|
724
|
-
)
|
|
725
|
-
|
|
726
|
-
return compiled_models
|
|
607
|
+
return context, static_tensors
|
|
727
608
|
|
|
728
609
|
@classmethod
|
|
729
|
-
|
|
610
|
+
@torch.inference_mode()
|
|
611
|
+
def get_compiled_model(
|
|
730
612
|
cls,
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
|
|
734
|
-
|
|
735
|
-
|
|
736
|
-
|
|
737
|
-
|
|
738
|
-
|
|
739
|
-
|
|
740
|
-
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
|
|
744
|
-
alloc_memory_by_key[key] += sum(memory_per_node)
|
|
745
|
-
alloc_memory_by_key.pop("PortRecur", None) # Old compiler's kv-cache Key
|
|
746
|
-
alloc_memory_by_key.pop("DramTensor", None) # kv-cache
|
|
747
|
-
kernel_size = alloc_memory_by_key.pop("Kernel") # model weight
|
|
748
|
-
|
|
749
|
-
# Get the maximum number of blocks that can be allocated
|
|
750
|
-
buffer = sum(alloc_memory_by_key.values())
|
|
751
|
-
max_num_blocks = cls.get_maximum_num_blocks(
|
|
752
|
-
config=model_config,
|
|
753
|
-
tensor_parallel_size=rbln_config.tensor_parallel_size,
|
|
754
|
-
kvcache_block_size=rbln_config.kvcache_block_size,
|
|
755
|
-
kernel_size=kernel_size,
|
|
756
|
-
buffer=buffer,
|
|
613
|
+
model: PreTrainedModel,
|
|
614
|
+
rbln_config: RBLNDecoderOnlyModelConfig,
|
|
615
|
+
):
|
|
616
|
+
wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
|
|
617
|
+
compile_config = rbln_config.compile_cfgs[0]
|
|
618
|
+
|
|
619
|
+
# Here we use meta tensor, for the memory efficiency.
|
|
620
|
+
meta_tensor_names = [name for name, _, _ in compile_config.input_info if "past_key_values" in name]
|
|
621
|
+
example_inputs = compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
|
|
622
|
+
context, _ = cls._get_compile_context(compile_config, example_inputs)
|
|
623
|
+
|
|
624
|
+
compiled_model = cls._compile_model(
|
|
625
|
+
wrapped_model, compile_config, example_inputs, context, rbln_config, rbln_config.quantization, "prefill"
|
|
757
626
|
)
|
|
627
|
+
compiled_models = {"prefill": compiled_model}
|
|
758
628
|
|
|
759
|
-
|
|
760
|
-
# users can set `kvcache_num_blocks` to `max_num_blocks`.
|
|
761
|
-
# If the memory is not enough, the model will fail to compile.
|
|
762
|
-
if rbln_config.kvcache_num_blocks < max_num_blocks:
|
|
763
|
-
logger.warning(
|
|
764
|
-
f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
|
|
765
|
-
"Our analysis indicates that additional memory is available for more blocks. "
|
|
766
|
-
f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
|
|
767
|
-
"Please be advised that our memory estimation algorithm has limitations, "
|
|
768
|
-
"and increasing this value may not guarantee successful model compilation."
|
|
769
|
-
)
|
|
629
|
+
return compiled_models
|
|
770
630
|
|
|
771
631
|
@classmethod
|
|
772
|
-
def
|
|
773
|
-
cls,
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
kvcache_block_size: int,
|
|
777
|
-
nbits_per_param: Optional[int] = None,
|
|
778
|
-
n_model_params: Optional[int] = None,
|
|
779
|
-
kernel_size: Optional[int] = None,
|
|
780
|
-
buffer: Optional[int] = None,
|
|
781
|
-
num_runtimes: int = 2,
|
|
782
|
-
) -> int:
|
|
783
|
-
# We are finding max_n_blocks(x) that satisfies the following equation:
|
|
784
|
-
|
|
785
|
-
# available_dram - kernel_size - buffer
|
|
786
|
-
# - num_layers * 2 * tensor_parallel_size
|
|
787
|
-
# * align_2MB(
|
|
788
|
-
# x
|
|
789
|
-
# * block_size
|
|
790
|
-
# * align_64(head_dim)
|
|
791
|
-
# * math.ceil(num_key_value_heads / tensor_parallel_size)
|
|
792
|
-
# * 2
|
|
793
|
-
# ) > 0
|
|
794
|
-
|
|
795
|
-
# This inequality can be rewritten as follows:
|
|
796
|
-
|
|
797
|
-
# a - c * align_2MB(b * x) > 0
|
|
798
|
-
# where
|
|
799
|
-
# a = available_dram - kernel_size - buffer
|
|
800
|
-
# b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
|
|
801
|
-
# c = num_layers * 2 * tensor_parallel_size
|
|
802
|
-
|
|
803
|
-
# We can rewrite the inequality as follows:
|
|
804
|
-
# k > align_2MB(b*x)
|
|
805
|
-
# where
|
|
806
|
-
# k = a / c
|
|
807
|
-
|
|
808
|
-
# After that, we can derive the following equation:
|
|
809
|
-
# x = floor(2**21 / b * floor((k - 1) / 2**21))
|
|
810
|
-
|
|
811
|
-
def align(x: int, nbytes: int) -> int:
|
|
812
|
-
return int(math.ceil(x / nbytes) * nbytes)
|
|
813
|
-
|
|
814
|
-
def align_2MB(x: int) -> int:
|
|
815
|
-
return align(x, 2**21)
|
|
816
|
-
|
|
817
|
-
num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
|
|
818
|
-
num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
|
|
819
|
-
head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
|
|
820
|
-
vocab_size = config.vocab_size
|
|
821
|
-
hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
|
|
822
|
-
num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
|
|
823
|
-
|
|
824
|
-
# TODO(jongho): Update if target npu is REBEL.
|
|
825
|
-
ATOM_DRAM_NBYTES = 16 * 2**30
|
|
826
|
-
ATOM_SYS_DRAM_NBYTES = 288 * 2**20
|
|
827
|
-
available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
|
|
828
|
-
|
|
829
|
-
if kernel_size is None:
|
|
830
|
-
if n_model_params is None:
|
|
831
|
-
raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
|
|
832
|
-
# Get estimated kernel size (approximated)
|
|
833
|
-
lm_heads_params = align(vocab_size, 64) * hidden_size
|
|
834
|
-
lm_heads_nbytes = (
|
|
835
|
-
align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
|
|
836
|
-
)
|
|
837
|
-
params = n_model_params - lm_heads_params
|
|
838
|
-
layer_nbytes = (
|
|
839
|
-
align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
|
|
840
|
-
* num_layers
|
|
841
|
-
* tensor_parallel_size
|
|
842
|
-
)
|
|
843
|
-
kernel_size = layer_nbytes + lm_heads_nbytes
|
|
844
|
-
elif n_model_params is not None:
|
|
845
|
-
raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
|
|
846
|
-
|
|
847
|
-
available_dram -= kernel_size
|
|
632
|
+
def get_quantized_model(
|
|
633
|
+
cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
|
|
634
|
+
) -> PreTrainedModel:
|
|
635
|
+
raise NotImplementedError
|
|
848
636
|
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
637
|
+
@classmethod
|
|
638
|
+
def get_pytorch_model(
|
|
639
|
+
cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
|
|
640
|
+
) -> PreTrainedModel:
|
|
641
|
+
if rbln_config and rbln_config.quantization:
|
|
642
|
+
model = cls.get_quantized_model(*args, **kwargs)
|
|
643
|
+
else:
|
|
644
|
+
model = super().get_pytorch_model(*args, **kwargs)
|
|
855
645
|
|
|
856
|
-
|
|
857
|
-
c = num_layers * 2 * tensor_parallel_size
|
|
858
|
-
k = available_dram / c
|
|
859
|
-
max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
|
|
646
|
+
return model
|
|
860
647
|
|
|
861
|
-
|
|
648
|
+
@classmethod
|
|
649
|
+
def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
|
|
650
|
+
return use_local_attention
|
|
862
651
|
|
|
863
652
|
@classmethod
|
|
864
653
|
def get_input_info(
|
|
@@ -868,13 +657,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
868
657
|
rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
|
|
869
658
|
model_config: PretrainedConfig,
|
|
870
659
|
):
|
|
871
|
-
is_prefill: bool = query_length > 1
|
|
872
660
|
num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
|
|
873
661
|
num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
|
|
874
662
|
num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
|
|
875
663
|
hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
|
|
876
664
|
head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
|
|
877
|
-
|
|
665
|
+
is_prefill = query_length > 1
|
|
878
666
|
|
|
879
667
|
# 1. main input
|
|
880
668
|
if rbln_config.use_inputs_embeds:
|
|
@@ -893,16 +681,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
893
681
|
]
|
|
894
682
|
|
|
895
683
|
# 3. block_tables
|
|
896
|
-
if rbln_config.
|
|
684
|
+
if rbln_config.use_global_attention:
|
|
897
685
|
max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
|
|
898
686
|
input_info.extend(
|
|
899
687
|
[("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")]
|
|
900
688
|
)
|
|
901
|
-
if rbln_config.
|
|
689
|
+
if rbln_config.use_local_attention:
|
|
902
690
|
input_info.extend([("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16")])
|
|
903
691
|
|
|
904
|
-
# 4. query_position
|
|
905
|
-
if is_prefill:
|
|
692
|
+
# 4. query_position for sliding window attention
|
|
693
|
+
if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
|
|
906
694
|
input_info.extend([("query_position", [], "int16")])
|
|
907
695
|
|
|
908
696
|
# 5. attention_mask & position_ids
|
|
@@ -924,7 +712,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
924
712
|
rbln_config.kvcache_block_size,
|
|
925
713
|
head_dim,
|
|
926
714
|
]
|
|
927
|
-
local_kvcache_shape = [
|
|
715
|
+
local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
|
|
928
716
|
input_info.extend(
|
|
929
717
|
[
|
|
930
718
|
(
|
|
@@ -971,13 +759,38 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
971
759
|
# ```
|
|
972
760
|
|
|
973
761
|
# Returns:
|
|
974
|
-
#
|
|
762
|
+
# RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.
|
|
975
763
|
|
|
976
764
|
raise NotImplementedError(
|
|
977
765
|
"Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
|
|
978
766
|
"See method docstring for required configuration details."
|
|
979
767
|
)
|
|
980
768
|
|
|
769
|
+
@classmethod
|
|
770
|
+
def _update_attention_config(
|
|
771
|
+
cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
|
|
772
|
+
):
|
|
773
|
+
rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
|
|
774
|
+
attn_impl=rbln_config.attn_impl,
|
|
775
|
+
kvcache_partition_len=rbln_config.kvcache_partition_len,
|
|
776
|
+
kvcache_block_size=rbln_config.kvcache_block_size,
|
|
777
|
+
max_seq_len=rbln_config.max_seq_len,
|
|
778
|
+
)
|
|
779
|
+
|
|
780
|
+
validate_attention_method(
|
|
781
|
+
attn_impl=rbln_config.attn_impl,
|
|
782
|
+
kvcache_partition_len=rbln_config.kvcache_partition_len,
|
|
783
|
+
kvcache_block_size=rbln_config.kvcache_block_size,
|
|
784
|
+
max_seq_len=rbln_config.max_seq_len,
|
|
785
|
+
)
|
|
786
|
+
|
|
787
|
+
if rbln_config.kvcache_num_blocks is None:
|
|
788
|
+
rbln_config.kvcache_num_blocks = (
|
|
789
|
+
rbln_config.max_seq_len // rbln_config.kvcache_block_size
|
|
790
|
+
) * rbln_config.batch_size
|
|
791
|
+
|
|
792
|
+
return rbln_config
|
|
793
|
+
|
|
981
794
|
@classmethod
|
|
982
795
|
def _update_rbln_config(
|
|
983
796
|
cls,
|
|
@@ -998,8 +811,384 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
998
811
|
):
|
|
999
812
|
rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
|
|
1000
813
|
if rbln_config.sliding_window is not None:
|
|
1001
|
-
|
|
814
|
+
validate_sliding_window(rbln_config)
|
|
815
|
+
|
|
816
|
+
rbln_config = cls._update_attention_config(model, model_config, rbln_config)
|
|
817
|
+
|
|
818
|
+
prefill_input_info = cls.get_input_info(
|
|
819
|
+
batch_size=1,
|
|
820
|
+
query_length=rbln_config.prefill_chunk_size,
|
|
821
|
+
rbln_config=rbln_config,
|
|
822
|
+
model_config=model_config,
|
|
823
|
+
)
|
|
824
|
+
|
|
825
|
+
prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
|
|
826
|
+
rbln_config.set_compile_cfgs([prefill_compile_config])
|
|
1002
827
|
|
|
828
|
+
return rbln_config
|
|
829
|
+
|
|
830
|
+
@classmethod
|
|
831
|
+
def _create_runtimes(
|
|
832
|
+
cls,
|
|
833
|
+
compiled_models: List[rebel.RBLNCompiledModel],
|
|
834
|
+
rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
|
|
835
|
+
) -> List[rebel.Runtime]:
|
|
836
|
+
expected_model_names = [
|
|
837
|
+
"prefill",
|
|
838
|
+
]
|
|
839
|
+
if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
|
|
840
|
+
cls._raise_missing_compiled_file_error(expected_model_names)
|
|
841
|
+
|
|
842
|
+
return [
|
|
843
|
+
rebel.Runtime(
|
|
844
|
+
compiled_models[0],
|
|
845
|
+
tensor_type="pt",
|
|
846
|
+
device=rbln_config.device_map["prefill"],
|
|
847
|
+
activate_profiler=rbln_config.activate_profiler,
|
|
848
|
+
),
|
|
849
|
+
]
|
|
850
|
+
|
|
851
|
+
def _preprocess_chunked_prefill(
|
|
852
|
+
self,
|
|
853
|
+
inputs: torch.Tensor,
|
|
854
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
855
|
+
position_embed: Optional[torch.Tensor] = None,
|
|
856
|
+
):
|
|
857
|
+
# valid sequence length of inputs_embeds
|
|
858
|
+
query_length = inputs.shape[1] if attention_mask is None else torch.sum(attention_mask.view(-1)).item()
|
|
859
|
+
|
|
860
|
+
# extract valid inputs
|
|
861
|
+
inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
|
|
862
|
+
|
|
863
|
+
if inputs.dim() == 2 and self.rbln_config.use_inputs_embeds:
|
|
864
|
+
inputs = self.get_input_embeddings()(inputs)
|
|
865
|
+
|
|
866
|
+
if position_embed is not None:
|
|
867
|
+
position_embed = (
|
|
868
|
+
position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
|
|
869
|
+
)
|
|
870
|
+
|
|
871
|
+
# padding for chunked prefill
|
|
872
|
+
padding_size = (
|
|
873
|
+
self.rbln_config.prefill_chunk_size - (query_length % self.rbln_config.prefill_chunk_size)
|
|
874
|
+
) % self.rbln_config.prefill_chunk_size
|
|
875
|
+
padded_len = query_length + padding_size
|
|
876
|
+
|
|
877
|
+
inputs = (
|
|
878
|
+
torch.nn.functional.pad(inputs, (0, padding_size))
|
|
879
|
+
if not self.rbln_config.use_inputs_embeds
|
|
880
|
+
else torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
|
|
881
|
+
)
|
|
882
|
+
position_embed = (
|
|
883
|
+
None if position_embed is None else torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
|
|
884
|
+
)
|
|
885
|
+
cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)
|
|
886
|
+
|
|
887
|
+
chunked_attention_mask = (
|
|
888
|
+
torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
|
|
889
|
+
if self.rbln_config.use_attention_mask
|
|
890
|
+
else None
|
|
891
|
+
)
|
|
892
|
+
|
|
893
|
+
return inputs, position_embed, cache_position, query_length, chunked_attention_mask
|
|
894
|
+
|
|
895
|
+
def _chunked_prefill_forward(
|
|
896
|
+
self,
|
|
897
|
+
inputs: torch.Tensor,
|
|
898
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
899
|
+
position_embed: Optional[torch.Tensor] = None,
|
|
900
|
+
):
|
|
901
|
+
padded_input, padded_position_embed, cache_position, query_length, chunked_attention_mask = (
|
|
902
|
+
self._preprocess_chunked_prefill(inputs, attention_mask, position_embed)
|
|
903
|
+
)
|
|
904
|
+
|
|
905
|
+
# chunked prefill
|
|
906
|
+
last_hidden_states = []
|
|
907
|
+
for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
|
|
908
|
+
# Extract the current chunk of inputs and cache positions
|
|
909
|
+
input_chunk = padded_input[:, step : step + self.rbln_config.prefill_chunk_size]
|
|
910
|
+
cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
|
|
911
|
+
|
|
912
|
+
valid_length = (
|
|
913
|
+
self.rbln_config.prefill_chunk_size
|
|
914
|
+
if (step + self.rbln_config.prefill_chunk_size) <= query_length
|
|
915
|
+
else query_length - step
|
|
916
|
+
)
|
|
917
|
+
if self.rbln_config.use_local_attention:
|
|
918
|
+
query_position = torch.tensor(valid_length - 1, dtype=torch.int16)
|
|
919
|
+
else:
|
|
920
|
+
query_position = None
|
|
921
|
+
|
|
922
|
+
if self.rbln_config.use_attention_mask:
|
|
923
|
+
if step > 0:
|
|
924
|
+
chunked_attention_mask[:, :, :, :step] = 1
|
|
925
|
+
chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
|
|
926
|
+
|
|
927
|
+
# Forward pass for the current chunk
|
|
928
|
+
last_hidden_states_chunk = self.prefill_decoder(
|
|
929
|
+
input_ids=input_chunk if not self.rbln_config.use_inputs_embeds else None,
|
|
930
|
+
inputs_embeds=input_chunk if self.rbln_config.use_inputs_embeds else None,
|
|
931
|
+
cache_position=cache_pos_chunk,
|
|
932
|
+
block_tables=self.block_tables if self.rbln_config.use_global_attention else None,
|
|
933
|
+
local_block_tables=self.local_block_tables if self.rbln_config.use_local_attention else None,
|
|
934
|
+
query_position=query_position,
|
|
935
|
+
attention_mask=chunked_attention_mask,
|
|
936
|
+
position_emb=padded_position_embed,
|
|
937
|
+
)
|
|
938
|
+
last_hidden_states.append(last_hidden_states_chunk)
|
|
939
|
+
last_hidden_states = torch.concat(last_hidden_states, dim=-2)[:, :query_length]
|
|
940
|
+
|
|
941
|
+
return self._postprocess_chunked_prefill(last_hidden_states, attention_mask)
|
|
942
|
+
|
|
943
|
+
def _postprocess_chunked_prefill(
|
|
944
|
+
self, last_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
|
|
945
|
+
):
|
|
946
|
+
# index copy for attention mask
|
|
947
|
+
if attention_mask is not None:
|
|
948
|
+
new_last_hidden_states = torch.full(
|
|
949
|
+
(1, attention_mask.shape[-1], last_hidden_states.shape[-1]),
|
|
950
|
+
fill_value=1e-10,
|
|
951
|
+
dtype=last_hidden_states.dtype,
|
|
952
|
+
)
|
|
953
|
+
mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
|
|
954
|
+
new_last_hidden_states.index_copy_(dim=-2, index=mask_indices, source=last_hidden_states)
|
|
955
|
+
else:
|
|
956
|
+
new_last_hidden_states = last_hidden_states
|
|
957
|
+
return new_last_hidden_states
|
|
958
|
+
|
|
959
|
+
def forward(
|
|
960
|
+
self,
|
|
961
|
+
input_ids: Optional[torch.LongTensor] = None,
|
|
962
|
+
inputs_embeds: Optional[torch.Tensor] = None,
|
|
963
|
+
attention_mask: Optional[torch.LongTensor] = None,
|
|
964
|
+
position_embed: Optional[torch.Tensor] = None,
|
|
965
|
+
**kwargs,
|
|
966
|
+
) -> Tuple[torch.FloatTensor]:
|
|
967
|
+
inputs = inputs_embeds if inputs_embeds is not None else input_ids
|
|
968
|
+
batch_size = inputs.shape[0]
|
|
969
|
+
all_last_hidden_states = []
|
|
970
|
+
for b_idx in range(batch_size):
|
|
971
|
+
last_hidden_states = self._chunked_prefill_forward(
|
|
972
|
+
inputs[b_idx : b_idx + 1],
|
|
973
|
+
attention_mask[b_idx] if attention_mask is not None else None,
|
|
974
|
+
position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
|
|
975
|
+
)
|
|
976
|
+
all_last_hidden_states.append(last_hidden_states)
|
|
977
|
+
|
|
978
|
+
last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
|
|
979
|
+
return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
|
|
980
|
+
|
|
981
|
+
|
|
982
|
+
class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel):
|
|
983
|
+
"""
|
|
984
|
+
A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
|
|
985
|
+
This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
|
|
986
|
+
|
|
987
|
+
The class provides core functionality for:
|
|
988
|
+
|
|
989
|
+
1. Converting pre-trained transformer models to RBLN-optimized format
|
|
990
|
+
2. Handling the compilation process for RBLN devices
|
|
991
|
+
3. Managing inference operations for causal language modeling
|
|
992
|
+
|
|
993
|
+
This class inherits from RBLNModel and implements specific methods required for
|
|
994
|
+
decoder-only architectures and causal language modeling tasks.
|
|
995
|
+
|
|
996
|
+
Note:
|
|
997
|
+
- This class is designed to be subclassed by specific model implementations
|
|
998
|
+
(e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
|
|
999
|
+
- Subclasses should implement model-specific conversion logic.
|
|
1000
|
+
- The class handles RBLN-specific optimizations automatically during compilation
|
|
1001
|
+
"""
|
|
1002
|
+
|
|
1003
|
+
auto_model_class = AutoModelForCausalLM
|
|
1004
|
+
|
|
1005
|
+
def __post_init__(self, **kwargs):
|
|
1006
|
+
main_input_name = self.main_input_name
|
|
1007
|
+
|
|
1008
|
+
if self.rbln_config.use_inputs_embeds:
|
|
1009
|
+
main_input_name = "inputs_embeds"
|
|
1010
|
+
artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
|
|
1011
|
+
self.embed_tokens = self._create_embedding_layer()
|
|
1012
|
+
self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
|
|
1013
|
+
else:
|
|
1014
|
+
self.embed_tokens = None
|
|
1015
|
+
|
|
1016
|
+
# Initialize shared resources to be used across Runtime instances (prefill and decode phases)
|
|
1017
|
+
dec_attn_mask = torch.zeros(
|
|
1018
|
+
self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
|
|
1019
|
+
)
|
|
1020
|
+
block_tables = torch.zeros(
|
|
1021
|
+
self.rbln_config.batch_size,
|
|
1022
|
+
self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
|
|
1023
|
+
dtype=torch.int16,
|
|
1024
|
+
).fill_(-1)
|
|
1025
|
+
free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
|
|
1026
|
+
|
|
1027
|
+
self.prefill_decoder = RBLNRuntimeModel(
|
|
1028
|
+
runtime=self.model[0],
|
|
1029
|
+
main_input_name=main_input_name,
|
|
1030
|
+
embed_tokens=self.embed_tokens,
|
|
1031
|
+
phase="prefill",
|
|
1032
|
+
batch_size=self.rbln_config.batch_size,
|
|
1033
|
+
dec_attn_mask=dec_attn_mask,
|
|
1034
|
+
block_tables=block_tables,
|
|
1035
|
+
free_block_pool=free_block_pool,
|
|
1036
|
+
rbln_config=self.rbln_config,
|
|
1037
|
+
vocab_size=self.config.vocab_size,
|
|
1038
|
+
)
|
|
1039
|
+
|
|
1040
|
+
if self.can_generate():
|
|
1041
|
+
self.decoders = {}
|
|
1042
|
+
for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
|
|
1043
|
+
self.decoders[batch_size] = RBLNRuntimeModel(
|
|
1044
|
+
runtime=self.model[i + 1],
|
|
1045
|
+
main_input_name=main_input_name,
|
|
1046
|
+
embed_tokens=self.embed_tokens,
|
|
1047
|
+
phase="decode",
|
|
1048
|
+
batch_size=batch_size,
|
|
1049
|
+
dec_attn_mask=dec_attn_mask,
|
|
1050
|
+
block_tables=block_tables,
|
|
1051
|
+
free_block_pool=free_block_pool,
|
|
1052
|
+
rbln_config=self.rbln_config,
|
|
1053
|
+
)
|
|
1054
|
+
|
|
1055
|
+
# NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
|
|
1056
|
+
self.decoder = self.decoders[self.rbln_config.batch_size]
|
|
1057
|
+
|
|
1058
|
+
@classmethod
|
|
1059
|
+
def get_quantized_model(
|
|
1060
|
+
cls,
|
|
1061
|
+
model_id: str,
|
|
1062
|
+
config: Optional[PretrainedConfig] = None,
|
|
1063
|
+
use_auth_token: Optional[Union[bool, str]] = None,
|
|
1064
|
+
revision: Optional[str] = None,
|
|
1065
|
+
force_download: bool = False,
|
|
1066
|
+
cache_dir: Optional[str] = None,
|
|
1067
|
+
subfolder: str = "",
|
|
1068
|
+
local_files_only: bool = False,
|
|
1069
|
+
trust_remote_code: bool = False,
|
|
1070
|
+
**kwargs,
|
|
1071
|
+
):
|
|
1072
|
+
kwargs = cls.update_kwargs(kwargs)
|
|
1073
|
+
|
|
1074
|
+
if config is None:
|
|
1075
|
+
config = AutoConfig.from_pretrained(
|
|
1076
|
+
model_id,
|
|
1077
|
+
use_auth_token=use_auth_token,
|
|
1078
|
+
revision=revision,
|
|
1079
|
+
force_download=force_download,
|
|
1080
|
+
cache_dir=cache_dir,
|
|
1081
|
+
trust_remote_code=trust_remote_code,
|
|
1082
|
+
**kwargs,
|
|
1083
|
+
)
|
|
1084
|
+
|
|
1085
|
+
with no_init_weights():
|
|
1086
|
+
model = AutoModelForCausalLM.from_config(config)
|
|
1087
|
+
|
|
1088
|
+
model = prepare_model_for_quantization(
|
|
1089
|
+
model,
|
|
1090
|
+
model_id,
|
|
1091
|
+
kwargs.get("num_hidden_layers"),
|
|
1092
|
+
use_auth_token=use_auth_token,
|
|
1093
|
+
revision=revision,
|
|
1094
|
+
cache_dir=cache_dir,
|
|
1095
|
+
force_download=force_download,
|
|
1096
|
+
local_files_only=local_files_only,
|
|
1097
|
+
)
|
|
1098
|
+
return model
|
|
1099
|
+
|
|
1100
|
+
def __getattr__(self, __name: str) -> Any:
|
|
1101
|
+
# Special method to delegate attribute access to the original Huggingface LM class.
|
|
1102
|
+
# This method is called when an attribute is not found in the current instance's dictionary.
|
|
1103
|
+
# It enables transparent access to the original model's attributes and methods while maintaining
|
|
1104
|
+
# proper method binding.
|
|
1105
|
+
|
|
1106
|
+
# The method implements a delegation pattern that:
|
|
1107
|
+
|
|
1108
|
+
# 1. For methods: Creates a wrapper that properly binds 'self' to method calls
|
|
1109
|
+
# 2. For other attributes: Returns them directly from the original class
|
|
1110
|
+
|
|
1111
|
+
def redirect(func):
|
|
1112
|
+
return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
|
|
1113
|
+
|
|
1114
|
+
val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
|
|
1115
|
+
if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
|
|
1116
|
+
return redirect(val)
|
|
1117
|
+
return val
|
|
1118
|
+
|
|
1119
|
+
@classmethod
|
|
1120
|
+
def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
|
|
1121
|
+
wrapper_cfg = {
|
|
1122
|
+
"max_seq_len": rbln_config.max_seq_len,
|
|
1123
|
+
"attn_impl": rbln_config.attn_impl,
|
|
1124
|
+
"kvcache_partition_len": rbln_config.kvcache_partition_len,
|
|
1125
|
+
"kvcache_block_size": rbln_config.kvcache_block_size,
|
|
1126
|
+
"use_rotary_emb": cls._use_rotary_emb,
|
|
1127
|
+
"use_attention_mask": rbln_config.use_attention_mask,
|
|
1128
|
+
"use_position_ids": rbln_config.use_position_ids,
|
|
1129
|
+
"use_inputs_embeds": rbln_config.use_inputs_embeds,
|
|
1130
|
+
"cache_impl": rbln_config.cache_impl,
|
|
1131
|
+
"sliding_window": rbln_config.sliding_window,
|
|
1132
|
+
"sliding_window_layers": rbln_config.sliding_window_layers,
|
|
1133
|
+
}
|
|
1134
|
+
return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
|
|
1135
|
+
|
|
1136
|
+
@classmethod
|
|
1137
|
+
@torch.inference_mode()
|
|
1138
|
+
def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
|
|
1139
|
+
wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
|
|
1140
|
+
prefill_compile_config = rbln_config.compile_cfgs[0]
|
|
1141
|
+
|
|
1142
|
+
# Here we use meta tensor, for the memory efficiency.
|
|
1143
|
+
meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
|
|
1144
|
+
prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
|
|
1145
|
+
context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
|
|
1146
|
+
|
|
1147
|
+
compiled_models = {}
|
|
1148
|
+
compiled_models["prefill"] = cls._compile_model(
|
|
1149
|
+
wrapped_model,
|
|
1150
|
+
prefill_compile_config,
|
|
1151
|
+
prefill_example_inputs,
|
|
1152
|
+
context,
|
|
1153
|
+
rbln_config,
|
|
1154
|
+
rbln_config.quantization,
|
|
1155
|
+
phase="prefill",
|
|
1156
|
+
)
|
|
1157
|
+
|
|
1158
|
+
if rbln_config.can_generate:
|
|
1159
|
+
wrapped_model.phase = "decode"
|
|
1160
|
+
for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
|
|
1161
|
+
dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
|
|
1162
|
+
compiled_decoder = cls._compile_model(
|
|
1163
|
+
wrapped_model,
|
|
1164
|
+
dec_compile_config,
|
|
1165
|
+
dec_example_inputs,
|
|
1166
|
+
context,
|
|
1167
|
+
rbln_config,
|
|
1168
|
+
rbln_config.quantization,
|
|
1169
|
+
phase="decode",
|
|
1170
|
+
)
|
|
1171
|
+
compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
|
|
1172
|
+
|
|
1173
|
+
# check if the memory is enough to have additional blocks
|
|
1174
|
+
required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
|
|
1175
|
+
if rbln_config.kvcache_num_blocks < required_num_blocks:
|
|
1176
|
+
cls.maybe_suggest_kvcache_num_blocks(
|
|
1177
|
+
compiled_models=compiled_models,
|
|
1178
|
+
model_config=model.config,
|
|
1179
|
+
rbln_config=rbln_config,
|
|
1180
|
+
)
|
|
1181
|
+
|
|
1182
|
+
return compiled_models
|
|
1183
|
+
|
|
1184
|
+
@classmethod
|
|
1185
|
+
def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
|
|
1186
|
+
return is_prefill
|
|
1187
|
+
|
|
1188
|
+
@classmethod
|
|
1189
|
+
def _update_attention_config(
|
|
1190
|
+
cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
|
|
1191
|
+
):
|
|
1003
1192
|
rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
|
|
1004
1193
|
attn_impl=rbln_config.attn_impl,
|
|
1005
1194
|
kvcache_partition_len=rbln_config.kvcache_partition_len,
|
|
@@ -1024,13 +1213,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
1024
1213
|
kvcache_block_size=rbln_config.kvcache_block_size,
|
|
1025
1214
|
nbits_per_param=16 if not rbln_config.quantization else 4, # TODO(jongho): FIX Ad-hoc
|
|
1026
1215
|
n_model_params=sum(p.numel() for p in model.parameters()),
|
|
1027
|
-
num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
|
|
1216
|
+
num_runtimes=1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
|
|
1028
1217
|
)
|
|
1029
1218
|
|
|
1030
1219
|
max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
|
|
1031
1220
|
|
|
1032
1221
|
flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
|
|
1033
|
-
if max_num_blocks < flash_min_blocks:
|
|
1222
|
+
if rbln_config.batch_size > 1 and max_num_blocks < flash_min_blocks:
|
|
1034
1223
|
max_num_blocks = flash_min_blocks
|
|
1035
1224
|
|
|
1036
1225
|
if max_num_blocks < rbln_config.batch_size:
|
|
@@ -1049,27 +1238,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
1049
1238
|
)
|
|
1050
1239
|
logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
|
|
1051
1240
|
|
|
1052
|
-
|
|
1053
|
-
batch_size=1,
|
|
1054
|
-
query_length=rbln_config.prefill_chunk_size,
|
|
1055
|
-
rbln_config=rbln_config,
|
|
1056
|
-
model_config=model_config,
|
|
1057
|
-
)
|
|
1058
|
-
|
|
1059
|
-
prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
|
|
1241
|
+
return rbln_config
|
|
1060
1242
|
|
|
1061
|
-
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1070
|
-
|
|
1071
|
-
|
|
1072
|
-
|
|
1243
|
+
@classmethod
|
|
1244
|
+
def _update_rbln_config(
|
|
1245
|
+
cls,
|
|
1246
|
+
preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
|
|
1247
|
+
model: Optional[PreTrainedModel] = None,
|
|
1248
|
+
model_config: Optional[PretrainedConfig] = None,
|
|
1249
|
+
rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
|
|
1250
|
+
) -> RBLNDecoderOnlyModelForCausalLMConfig:
|
|
1251
|
+
rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
|
|
1252
|
+
if rbln_config.can_generate:
|
|
1253
|
+
compile_configs = rbln_config.compile_cfgs
|
|
1254
|
+
for batch_size in rbln_config.decoder_batch_sizes:
|
|
1255
|
+
dec_input_info = cls.get_input_info(
|
|
1256
|
+
batch_size=batch_size,
|
|
1257
|
+
query_length=1,
|
|
1258
|
+
rbln_config=rbln_config,
|
|
1259
|
+
model_config=model_config,
|
|
1260
|
+
)
|
|
1261
|
+
compile_configs.append(
|
|
1262
|
+
RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
|
|
1263
|
+
)
|
|
1264
|
+
rbln_config.set_compile_cfgs(compile_configs)
|
|
1073
1265
|
|
|
1074
1266
|
return rbln_config
|
|
1075
1267
|
|
|
@@ -1079,38 +1271,45 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
1079
1271
|
compiled_models: List[rebel.RBLNCompiledModel],
|
|
1080
1272
|
rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
|
|
1081
1273
|
) -> List[rebel.Runtime]:
|
|
1082
|
-
expected_model_names = [
|
|
1083
|
-
|
|
1084
|
-
|
|
1085
|
-
|
|
1274
|
+
expected_model_names = ["prefill"]
|
|
1275
|
+
if rbln_config.can_generate:
|
|
1276
|
+
expected_model_names.extend(
|
|
1277
|
+
[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
|
|
1278
|
+
)
|
|
1086
1279
|
if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
|
|
1087
1280
|
cls._raise_missing_compiled_file_error(expected_model_names)
|
|
1088
1281
|
|
|
1089
|
-
|
|
1282
|
+
ret_val = [
|
|
1090
1283
|
rebel.Runtime(
|
|
1091
1284
|
compiled_models[0],
|
|
1092
1285
|
tensor_type="pt",
|
|
1093
1286
|
device=rbln_config.device_map["prefill"],
|
|
1094
1287
|
activate_profiler=rbln_config.activate_profiler,
|
|
1095
1288
|
timeout=rbln_config.timeout,
|
|
1096
|
-
)
|
|
1097
|
-
*[
|
|
1098
|
-
rebel.Runtime(
|
|
1099
|
-
compiled_models[i + 1],
|
|
1100
|
-
tensor_type="pt",
|
|
1101
|
-
device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
|
|
1102
|
-
activate_profiler=rbln_config.activate_profiler,
|
|
1103
|
-
timeout=rbln_config.timeout,
|
|
1104
|
-
)
|
|
1105
|
-
for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
|
|
1106
|
-
],
|
|
1289
|
+
)
|
|
1107
1290
|
]
|
|
1291
|
+
if rbln_config.can_generate:
|
|
1292
|
+
ret_val.extend(
|
|
1293
|
+
[
|
|
1294
|
+
rebel.Runtime(
|
|
1295
|
+
compiled_models[i + 1],
|
|
1296
|
+
tensor_type="pt",
|
|
1297
|
+
device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
|
|
1298
|
+
activate_profiler=rbln_config.activate_profiler,
|
|
1299
|
+
timeout=rbln_config.timeout,
|
|
1300
|
+
)
|
|
1301
|
+
for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
|
|
1302
|
+
]
|
|
1303
|
+
)
|
|
1304
|
+
return ret_val
|
|
1108
1305
|
|
|
1109
1306
|
def get_decoder(self):
|
|
1307
|
+
if not self.can_generate():
|
|
1308
|
+
raise ValueError("Decode stage is not supported in this model.")
|
|
1110
1309
|
return self.decoder
|
|
1111
1310
|
|
|
1112
1311
|
def can_generate(self):
|
|
1113
|
-
return
|
|
1312
|
+
return self.rbln_config.can_generate
|
|
1114
1313
|
|
|
1115
1314
|
def _reorder_cache(self, past_key_values, beam_idx):
|
|
1116
1315
|
raise NotImplementedError
|
|
@@ -1167,7 +1366,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
1167
1366
|
|
|
1168
1367
|
def _update_model_kwargs_for_generation(
|
|
1169
1368
|
self,
|
|
1170
|
-
outputs:
|
|
1369
|
+
outputs: RBLNDecoderOnlyForCausalLMOutput,
|
|
1171
1370
|
model_kwargs: Dict[str, Any],
|
|
1172
1371
|
**kwargs,
|
|
1173
1372
|
) -> Dict[str, Any]:
|
|
@@ -1195,15 +1394,19 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
1195
1394
|
# A for-loop ensures synchronization with the HuggingFace generate API.
|
|
1196
1395
|
# The decoder stage operates as usual, processing inputs in batch mode.
|
|
1197
1396
|
|
|
1397
|
+
# for only use forward
|
|
1398
|
+
if generate_idx is None:
|
|
1399
|
+
generate_idx = (
|
|
1400
|
+
attention_mask.sum(dim=-1, keepdim=True).int()
|
|
1401
|
+
if attention_mask is not None
|
|
1402
|
+
else torch.full((input_ids.shape[0], 1), input_ids.shape[1], dtype=torch.int32)
|
|
1403
|
+
)
|
|
1404
|
+
padded_cache_lengths = torch.zeros_like(generate_idx)
|
|
1405
|
+
|
|
1198
1406
|
# Prefll
|
|
1199
1407
|
if cache_position is None:
|
|
1200
1408
|
logits = []
|
|
1201
1409
|
inputs = inputs_embeds if inputs_embeds is not None else input_ids
|
|
1202
|
-
# for only use forward
|
|
1203
|
-
if generate_idx is None:
|
|
1204
|
-
generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
|
|
1205
|
-
if padded_cache_lengths is None:
|
|
1206
|
-
padded_cache_lengths = torch.zeros_like(generate_idx)
|
|
1207
1410
|
batch_size = inputs.shape[0]
|
|
1208
1411
|
for b_idx in range(batch_size):
|
|
1209
1412
|
cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
|
|
@@ -1238,6 +1441,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
1238
1441
|
if not return_dict:
|
|
1239
1442
|
return logits, generate_idx, padded_cache_lengths
|
|
1240
1443
|
else:
|
|
1241
|
-
return
|
|
1444
|
+
return RBLNDecoderOnlyForCausalLMOutput(
|
|
1242
1445
|
logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
|
|
1243
1446
|
)
|