optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +116 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +171 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +33 -18
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +50 -24
- optimum/rbln/modeling_base.py +116 -35
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +100 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +93 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +22 -50
- optimum/rbln/utils/runtime_utils.py +85 -17
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
- {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
@@ -13,31 +13,31 @@
 # limitations under the License.

 import inspect
-import math
-from collections import deque
-from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable,
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union

 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import
+from transformers import AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_utils import no_init_weights
-from transformers.utils import ModelOutput

 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from
-
-from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
-from .decoderonly_architecture import (
-    DecoderOnlyWrapper,
+from ...modeling_attention_utils import (
+    RBLNDecoderOnlyFlashAttentionMixin,
     set_default_values,
     validate_attention_method,
-
+    validate_sliding_window,
 )
+from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ...utils.rbln_quantization import get_quantized_model
+from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+from .decoderonly_architecture import DecoderOnlyWrapper
+from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
+from .generation_decoderonly import RBLNDecoderOnlyGenerationMixin


 logger = get_logger()
@@ -46,522 +46,85 @@ if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer


-class
-    mandatory_members = ["main_input_name", "embed_tokens"]
-
-    def __init__(
-        self,
-        runtime: rebel.Runtime,
-        phase: str,
-        batch_size: int,
-        dec_attn_mask: torch.Tensor,
-        block_tables: torch.Tensor,
-        free_block_pool: Deque,
-        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(runtime, **kwargs)
-        self.phase = phase
-        self.batch_size = batch_size
-        self.rbln_config = rbln_config
-
-        # shared tensor between prefill and decode phase
-        self.dec_attn_mask = dec_attn_mask
-        self.block_tables = block_tables
-        self.free_block_pool = free_block_pool
-
-        self.empty_block = -1
-        if self.phase == "prefill":
-            vocab_size = kwargs.pop("vocab_size")
-            self.output_size = [1, 1, vocab_size]
-            self.causal_mask = 1 - torch.triu(
-                torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
-            )
-
-    def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None) -> torch.Tensor:
-        """
-        Manages and returns the KV cache block tables.
-        Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
-
-        Args:
-            cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
-            batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
-
-        Returns:
-            Updated block tables.
-        """
-
-        NO_BLOCKS_ERROR = (
-            "No memory blocks are available for allocation. "
-            "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln. "
-            "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html). "
-            "Using vllm-rbln should fix this issue and enhance inference performance."
-        )
-
-        def update_block(batch_idx: int, block_idx: int):
-            """
-            If the block is empty (empty_block), allocates a block from the free_block_pool.
-            """
-            if self.block_tables[batch_idx][block_idx] == self.empty_block:
-                if self.free_block_pool:
-                    block = self.free_block_pool.popleft()
-                    self.block_tables[batch_idx][block_idx] = block
-                else:
-                    raise RuntimeError(NO_BLOCKS_ERROR)
-
-        def replace_empty_block(block_tables: torch.Tensor):
-            """
-            Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
-            """
-            if not torch.any(block_tables == self.empty_block):
-                return block_tables.clone()
-            elif self.free_block_pool:
-                _free_block = self.free_block_pool[0]
-                return torch.where(block_tables == self.empty_block, _free_block, block_tables)
-            else:
-                raise RuntimeError(NO_BLOCKS_ERROR)
-
-        def get_global_block_tables(batch_idx: int):
-            if self.rbln_config.cache_impl == "sliding_window":
-                return None
-
-            if self.phase == "prefill":
-                # Track previously used blocks and return them to the free_block_pool and
-                # reset the current batch's block table to empty blocks
-                prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
-                self.free_block_pool.extend(prev_blocks)
-                self.block_tables[batch_idx].fill_(self.empty_block)
-
-                # Get the start (s) and end (e) positions from cache_position and
-                # iterate over the cache positions to allocate necessary blocks
-                s, e = cache_position[0][0].item(), cache_position[0][-1].item()
-                for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
-                    block_idx = position // self.rbln_config.kvcache_block_size
-                    if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
-                        raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
-                    update_block(batch_idx, block_idx)
-
-                return replace_empty_block(self.block_tables[batch_idx])
-            # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
-            else:
-                for b_idx in range(self.batch_size):
-                    position = cache_position[b_idx][0].item()
-                    block_idx = position // self.rbln_config.kvcache_block_size
-                    update_block(b_idx, block_idx)
-
-                return replace_empty_block(self.block_tables)
-
-        def get_local_block_tables(batch_idx: int):
-            if self.rbln_config.cache_impl == "static":
-                return None
-            else:
-                return (
-                    torch.tensor([batch_idx], dtype=torch.int16)
-                    if self.phase == "prefill"
-                    else torch.arange(self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-                )
-
-        return get_global_block_tables(batch_idx), get_local_block_tables(batch_idx)
-
-    def is_external_block_tables(
-        self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
-    ):
-        if self.rbln_config.cache_impl == "static" and block_tables is None:
-            return False
-        elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
-            return False
-        elif self.rbln_config.cache_impl == "hybrid":
-            if (block_tables is not None) != (local_block_tables is not None):
-                raise ValueError(
-                    "Both block_tables and local_block_tables must be provided or neither of them must be provided."
-                )
-            elif block_tables is None and local_block_tables is None:
-                return False
-
-        return True
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        cache_position: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: Optional[int] = None,
-        block_tables: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ):
-        if input_ids is None and inputs_embeds is None:
-            raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
-
-        if inputs_embeds is None:
-            inputs = input_ids
-            if self.embed_tokens is not None:
-                inputs = self.embed_tokens(inputs)
-        else:
-            inputs = inputs_embeds
-
-        is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
-        if not is_external_block_tables:
-            block_tables, local_block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
-
-        if self.phase == "decode":
-            return self.decode_forward(
-                inputs,
-                cache_position,
-                block_tables,
-                is_external_block_tables,
-                attention_mask=attention_mask,
-                position_embed=position_embed,
-                position_ids=position_ids,
-                local_block_tables=local_block_tables,
-            )
-        else:
-            return self.prefill_forward(
-                inputs,
-                cache_position,
-                attention_mask,
-                batch_idx,
-                block_tables,
-                is_external_block_tables=is_external_block_tables,
-                position_embed=position_embed,
-                token_type_ids=token_type_ids,
-                local_block_tables=local_block_tables,
-            )
-
-    def decode_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size = inputs.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
-
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-        if self.rbln_config.use_attention_mask and attention_mask is None:
-            for b_idx in range(batch_size):
-                decoding_step = cache_position[b_idx].item()
-                if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                    raise ValueError(
-                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                    )
-
-                if is_external_block_tables:
-                    self.dec_attn_mask[b_idx].fill_(0)
-                    self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
-                else:
-                    self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
-
-            attention_mask = self.dec_attn_mask
-
-        if self.rbln_config.cache_impl in ["hybrid", "static"] and self.batch_size < block_tables.shape[0]:
-            block_tables = block_tables[: self.batch_size]
-
-        if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-            attention_mask = attention_mask[: self.batch_size]
-
-        logits = super().forward(
-            inputs,
-            cache_position,
-            block_tables,
-            local_block_tables,
-            position_embed,
-            attention_mask if self.rbln_config.use_attention_mask else None,
-            position_ids if self.rbln_config.use_position_ids else None,
-        )
-
-        return RBLNDecoderOnlyOutput(logits=logits)
-
-    def _prepare_prefill_inputs(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-    ):
-        """
-        Prepare inputs for prefill phase.
-        """
-        # Handle continuous batching in a compiled graph by extracting valid inputs
-        # If an attention mask is provided, select only the valid (non-masked) inputs
-        inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
-        if position_embed is not None:
-            position_embed = (
-                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-            )
-
-        query_length = inputs.shape[1]
-        if query_length > self.rbln_config.max_seq_len:
-            raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
-            )
-
-        # Initialize attention mask for chunked processing
-        chunked_attention_mask = (
-            torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-            if self.rbln_config.use_attention_mask
-            else None
-        )
-
-        # Buffer for storing output logits
-        out_buffers = [
-            torch.empty(
-                size=self.output_size,
-                dtype=torch.float32,
-                device="cpu",
-            )
-        ]
-
-        # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        padding_size = 0
-        if query_length % self.rbln_config.prefill_chunk_size != 0:
-            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
-            # inputs_embeds
-            if inputs.dim() == 3:
-                inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-            # inputs_ids
-            else:
-                inputs = torch.nn.functional.pad(inputs, (0, padding_size))
-
-            cache_position = torch.cat(
-                [
-                    cache_position,
-                    torch.arange(
-                        query_length,
-                        query_length + padding_size,
-                        dtype=torch.int32,
-                    ).unsqueeze(0),
-                ],
-                dim=-1,
-            )
-
-            if position_embed is not None:
-                position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
-
-        # Overwrite position_ids and padded_cache_lengths
-        position_ids = None
-        padded_cache_lengths = 0
-
-        return (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-        )
-
-    def prefill_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: int = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = False,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        """
-        Performs chunked prefill for efficient KV-cache updates and memory optimization.
-        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
-        """
-        (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-        ) = self._prepare_prefill_inputs(
-            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
-        )
-
-        # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
-            # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-            position_ids_chunk = (
-                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
-                if position_ids is not None
-                else None
-            )
-            if position_embed is not None:
-                position_embed_chunk = position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]
-
-            if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
-                # Update attention mask to ensure proper causal behavior
-                if step >= self.rbln_config.prefill_chunk_size:
-                    chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
-                chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
-
-            # Define query position
-            if step + self.rbln_config.prefill_chunk_size >= query_length:
-                query_position = torch.tensor(
-                    (query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16
-                )
-            else:
-                query_position = torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)
-
-            # Forward pass for the current chunk
-            logits = super().forward(
-                input_chunk,
-                cache_pos_chunk,
-                block_tables,
-                local_block_tables,
-                position_embed_chunk if position_embed is not None else None,
-                query_position,
-                chunked_attention_mask if self.rbln_config.use_attention_mask else None,
-                position_ids_chunk if self.rbln_config.use_position_ids else None,
-                out=out_buffers,
-            )
-
-        # Update decoder attention mask with processed KV-cache length from prefill phase
-        if not is_external_block_tables and self.rbln_config.use_attention_mask:
-            self.dec_attn_mask[batch_idx].fill_(0)
-            self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
-
-        return RBLNDecoderOnlyOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)
-
-
-@dataclass
-class RBLNDecoderOnlyOutput(ModelOutput):
-    logits: torch.FloatTensor = None
-    generate_idx: torch.Tensor = None
-    padded_cache_lengths: int = None
-
-
-class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
+class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     """
-    A base class for decoder-only transformer models
+    A base class for decoder-only transformer models outputting raw hidden-states without any specific head on top.
+    This class is used for RBLN-optimized models that are not causal language models.
     This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.

     The class provides core functionality for:

     1. Converting pre-trained transformer models to RBLN-optimized format
     2. Handling the compilation process for RBLN devices
-    3. Managing inference operations for
-
+    3. Managing inference operations for decoder-only architectures
     This class inherits from RBLNModel and implements specific methods required for
-    decoder-only architectures
+    decoder-only architectures.

     Note:
         - This class is designed to be subclassed by specific model implementations
-          (e.g.,
+          (e.g., RBLNLlamaModel, RBLNQwen2Model)
         - Subclasses should implement model-specific conversion logic.
        - The class handles RBLN-specific optimizations automatically during compilation
     """

+    _tp_support = True
+
     main_input_name = "input_ids"
-    auto_model_class =
+    auto_model_class = AutoModel
     _decoder_wrapper_cls = DecoderOnlyWrapper
     _use_rotary_emb = True
+    _supports_non_fp32 = True

     def __post_init__(self, **kwargs):
-        main_input_name = self.main_input_name
-
         if self.rbln_config.use_inputs_embeds:
-            main_input_name = "inputs_embeds"
             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
             self.embed_tokens = self._create_embedding_layer()
             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
         else:
             self.embed_tokens = None

-
-        dec_attn_mask = torch.zeros(
-            self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
-        )
-        block_tables = torch.zeros(
-            self.rbln_config.batch_size,
-            self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
-            dtype=torch.int16,
-        ).fill_(-1)
-        free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+        self.setup_runtime()

+    def setup_runtime(self):
+        # Initialize resources to be used across Runtime instances (prefill and decode phases)
+        page_table_manager = RBLNPageTableManager(self.rbln_config)
+        dec_attn_mask = torch.zeros(self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype)
+        out_buffers = [torch.empty(self.prefill_output_size, dtype=self.dtype)]
+
+        common_kwargs = {
+            "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+            "embed_tokens": self.embed_tokens,
+            "dec_attn_mask": dec_attn_mask,
+            "page_table_manager": page_table_manager,
+            "rbln_config": self.rbln_config,
+        }
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
-            main_input_name=main_input_name,
-            embed_tokens=self.embed_tokens,
             phase="prefill",
             batch_size=self.rbln_config.batch_size,
-
-
-            free_block_pool=free_block_pool,
-            rbln_config=self.rbln_config,
-            vocab_size=self.config.vocab_size,
+            out_buffers=out_buffers,
+            **common_kwargs,
         )
+        if self.can_generate():
+            self.decoders = {}
+            for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
+                self.decoders[batch_size] = RBLNRuntimeModel(
+                    runtime=self.model[i + 1],
+                    phase="decode",
+                    batch_size=batch_size,
+                    **common_kwargs,
+                )

-
-
-        self.decoders[batch_size] = RBLNRuntimeModel(
-            runtime=self.model[i + 1],
-            main_input_name=main_input_name,
-            embed_tokens=self.embed_tokens,
-            phase="decode",
-            batch_size=batch_size,
-            dec_attn_mask=dec_attn_mask,
-            block_tables=block_tables,
-            free_block_pool=free_block_pool,
-            rbln_config=self.rbln_config,
-        )
-
-        # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
-        self.decoder = self.decoders[self.rbln_config.batch_size]
-
-    @classmethod
-    def save_torch_artifacts(
-        cls,
-        model: PreTrainedModel,
-        save_dir_path: Path,
-        subfolder: str,
-        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-    ):
-        # If you are unavoidably running on a CPU rather than an RBLN device,
-        # store the torch tensor, weight, etc. in this function.
-        if rbln_config.use_inputs_embeds:
-            save_dict = {}
-            save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
-            torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
-
-    def _create_embedding_layer(self):
-        with no_init_weights():
-            embed_tokens = torch.nn.Embedding(
-                self.config.vocab_size,
-                self.config.hidden_size,
-                self.config.pad_token_id,
-            )
-        return embed_tokens
-
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def get_attn_impl(self) -> str:
-        return self.rbln_config.attn_impl
+            # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
+            self.decoder = self.decoders[self.rbln_config.batch_size]

-
-
+    @property
+    def prefill_output_size(self):
+        return (
+            1,
+            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
+            self.config.hidden_size,
+        )

     @classmethod
     def get_quantized_model(
@@ -575,35 +138,22 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
+        rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
         **kwargs,
     ):
         kwargs = cls.update_kwargs(kwargs)

-
-
-            model_id,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
-        )
-
-        with no_init_weights():
-            model = AutoModelForCausalLM.from_config(config)
-
-        model = prepare_model_for_quantization(
-            model,
+        return get_quantized_model(
+            cls.auto_model_class,
             model_id,
-            kwargs.get("num_hidden_layers"),
             use_auth_token=use_auth_token,
             revision=revision,
             cache_dir=cache_dir,
             force_download=force_download,
             local_files_only=local_files_only,
+            rbln_quantization=rbln_config.quantization,
+            **kwargs,
         )
-        return model

     def __getattr__(self, __name: str) -> Any:
         # Special method to delegate attribute access to the original Huggingface LM class.
@@ -625,233 +175,162 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
625
175
|
return val
|
|
626
176
|
|
|
627
177
|
@classmethod
|
|
628
|
-
def
|
|
629
|
-
cls,
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
178
|
+
def save_torch_artifacts(
|
|
179
|
+
cls,
|
|
180
|
+
model: PreTrainedModel,
|
|
181
|
+
save_dir_path: Path,
|
|
182
|
+
subfolder: str,
|
|
183
|
+
rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
|
|
184
|
+
):
|
|
185
|
+
# If you are unavoidably running on a CPU rather than an RBLN device,
|
|
186
|
+
# store the torch tensor, weight, etc. in this function.
|
|
187
|
+
if rbln_config.use_inputs_embeds:
|
|
188
|
+
save_dict = {}
|
|
189
|
+
save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
|
|
190
|
+
torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
|
|
635
191
|
|
|
636
|
-
|
|
192
|
+
def _create_embedding_layer(self):
|
|
193
|
+
with no_init_weights():
|
|
194
|
+
embed_tokens = torch.nn.Embedding(
|
|
195
|
+
self.config.vocab_size,
|
|
196
|
+
self.config.hidden_size,
|
|
197
|
+
self.config.pad_token_id,
|
|
198
|
+
)
|
|
199
|
+
return embed_tokens
|
|
637
200
|
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
return
|
|
201
|
+
def get_decoder(self):
|
|
202
|
+
if not self.can_generate():
|
|
203
|
+
raise ValueError("Decode stage is not supported in this model.")
|
|
204
|
+
return self.decoder
|
|
205
|
+
|
|
206
|
+
def can_generate(self):
|
|
207
|
+
return self.rbln_config.can_generate
|
|
208
|
+
|
|
209
|
+
def get_input_embeddings(self):
|
|
210
|
+
return self.embed_tokens
|
|
211
|
+
|
|
212
|
+
def get_attn_impl(self) -> str:
|
|
213
|
+
return self.rbln_config.attn_impl
|
|
214
|
+
|
|
215
|
+
def get_kvcache_num_blocks(self) -> int:
|
|
216
|
+
return self.rbln_config.kvcache_num_blocks
|
|
654
217
|
|
|
655
218
|
@classmethod
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
|
|
219
|
+
def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
|
|
220
|
+
return cls._decoder_wrapper_cls(model, rbln_config, cls._use_rotary_emb).eval()
|
|
659
221
|
|
|
660
|
-
|
|
661
|
-
|
|
222
|
+
@classmethod
|
|
223
|
+
def _compile_model(
|
|
224
|
+
cls,
|
|
225
|
+
wrapped_model,
|
|
226
|
+
compile_config,
|
|
227
|
+
example_inputs,
|
|
228
|
+
compile_context,
|
|
229
|
+
rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
|
|
230
|
+
quantization=None,
|
|
231
|
+
phase: str = "prefill",
|
|
232
|
+
):
|
|
233
|
+
try:
|
|
234
|
+
wrapped_model.phase = phase
|
|
235
|
+
if quantization:
|
|
236
|
+
quantization.maybe_set_quantization_env()
|
|
237
|
+
original_linear = torch.nn.functional.linear
|
|
238
|
+
torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
|
|
239
|
+
compiled_model = cls.compile(
|
|
240
|
+
wrapped_model,
|
|
241
|
+
compile_config,
|
|
242
|
+
create_runtimes=rbln_config.create_runtimes,
|
|
243
|
+
device=rbln_config.device,
|
|
244
|
+
example_inputs=example_inputs,
|
|
245
|
+
compile_context=compile_context,
|
|
246
|
+
)
|
|
247
|
+
return compiled_model
|
|
248
|
+
finally:
|
|
249
|
+
torch.nn.functional.linear = original_linear
|
|
250
|
+
if quantization:
|
|
251
|
+
quantization.maybe_reset_quantization_env()
|
|
662
252
|
|
|
253
|
+
@classmethod
|
|
254
|
+
def _get_compile_context(
|
|
255
|
+
cls,
|
|
256
|
+
compile_config: RBLNCompileConfig,
|
|
257
|
+
example_inputs: List[torch.Tensor],
|
|
258
|
+
):
|
|
663
259
|
context = CompileContext(use_weight_sharing=True)
|
|
664
260
|
|
|
665
|
-
# Here we use meta tensor, for the memory efficiency.
|
|
666
|
-
meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
|
|
667
|
-
prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
|
|
668
|
-
|
|
669
261
|
# Mark static tensors (self kv states)
|
|
670
262
|
static_tensors = {}
|
|
671
|
-
|
|
263
|
+
idx = 0
|
|
264
|
+
for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
|
|
672
265
|
if "past_key_values" in name:
|
|
673
266
|
static_tensors[name] = tensor
|
|
674
|
-
context.mark_static_address(tensor)
|
|
675
|
-
|
|
676
|
-
def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
|
|
677
|
-
try:
|
|
678
|
-
if quantization:
|
|
679
|
-
quantization.maybe_set_quantization_env()
|
|
680
|
-
original_linear = torch.nn.functional.linear
|
|
681
|
-
torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
|
|
682
|
-
compiled_model = cls.compile(
|
|
683
|
-
wrapped_model,
|
|
684
|
-
compile_config,
|
|
685
|
-
create_runtimes=rbln_config.create_runtimes,
|
|
686
|
-
device=rbln_config.device,
|
|
687
|
-
example_inputs=example_inputs,
|
|
688
|
-
compile_context=compile_context,
|
|
689
|
-
)
|
|
690
|
-
return compiled_model
|
|
691
|
-
finally:
|
|
692
|
-
torch.nn.functional.linear = original_linear
|
|
693
|
-
if quantization:
|
|
694
|
-
quantization.maybe_reset_quantization_env()
|
|
695
|
-
|
|
696
|
-
wrapped_model.phase = "prefill"
|
|
697
|
-
compiled_prefill = compile_model(
|
|
698
|
-
wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
|
|
699
|
-
)
|
|
700
|
-
|
|
701
|
-
wrapped_model.phase = "decode"
|
|
702
|
-
compiled_models = {"prefill": compiled_prefill}
|
|
703
|
-
for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
|
|
704
|
-
dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
|
|
705
|
-
compiled_decoder = compile_model(
|
|
706
|
-
wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
|
|
707
|
-
)
|
|
708
|
-
compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
|
|
709
|
-
|
|
710
|
-
# check if the memory is enough to have additional blocks
|
|
711
|
-
required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
|
|
712
|
-
if rbln_config.kvcache_num_blocks < required_num_blocks:
|
|
713
|
-
cls.maybe_suggest_kvcache_num_blocks(
|
|
714
|
-
compiled_models=compiled_models,
|
|
715
|
-
model_config=model.config,
|
|
716
|
-
rbln_config=rbln_config,
|
|
717
|
-
)
|
|
267
|
+
context.mark_static_address(tensor, f"kv_cache_{idx}")
|
|
268
|
+
idx += 1
|
|
718
269
|
|
|
719
|
-
return
|
|
270
|
+
return context, static_tensors
|
|
720
271
|
|
|
721
272
|
@classmethod
|
|
722
|
-
|
|
723
|
-
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
|
|
727
|
-
) -> None:
|
|
728
|
-
# Get the actual memory allocation of each node by key
|
|
729
|
-
alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
|
|
730
|
-
alloc_memory_by_key: Dict[str, int] = {
|
|
731
|
-
key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
|
|
732
|
-
}
|
|
733
|
-
for batch_size in rbln_config.decoder_batch_sizes:
|
|
734
|
-
for key, memory_per_node in (
|
|
735
|
-
compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
|
|
736
|
-
):
|
|
737
|
-
alloc_memory_by_key[key] += sum(memory_per_node)
|
|
738
|
-
alloc_memory_by_key.pop("PortRecur", None) # Old compiler's kv-cache Key
|
|
739
|
-
alloc_memory_by_key.pop("DramTensor", None) # kv-cache
|
|
740
|
-
kernel_size = alloc_memory_by_key.pop("Kernel") # model weight
|
|
741
|
-
|
|
742
|
-
# Get the maximum number of blocks that can be allocated
|
|
743
|
-
buffer = sum(alloc_memory_by_key.values())
|
|
744
|
-
max_num_blocks = cls.get_maximum_num_blocks(
|
|
745
|
-
config=model_config,
|
|
746
|
-
tensor_parallel_size=rbln_config.tensor_parallel_size,
|
|
747
|
-
kvcache_block_size=rbln_config.kvcache_block_size,
|
|
748
|
-
kernel_size=kernel_size,
|
|
749
|
-
buffer=buffer,
|
|
750
|
-
)
|
|
273
|
+
@torch.inference_mode()
|
|
274
|
+
def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
|
|
275
|
+
wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
|
|
276
|
+
prefill_compile_config = rbln_config.compile_cfgs[0]
|
|
751
277
|
|
|
752
|
-
#
|
|
753
|
-
|
|
754
|
-
|
|
755
|
-
|
|
756
|
-
|
|
757
|
-
|
|
758
|
-
|
|
759
|
-
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
278
|
+
# Here we use meta tensor, for the memory efficiency.
|
|
279
|
+
meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
|
|
280
|
+
prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
|
|
281
|
+
context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
|
|
282
|
+
|
|
283
|
+
compiled_models = {}
|
|
284
|
+
compiled_models["prefill"] = cls._compile_model(
|
|
285
|
+
wrapped_model,
|
|
286
|
+
prefill_compile_config,
|
|
287
|
+
prefill_example_inputs,
|
|
288
|
+
context,
|
|
289
|
+
rbln_config,
|
|
290
|
+
rbln_config.quantization,
|
|
291
|
+
phase="prefill",
|
|
292
|
+
)
|
|
763
293
|
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
770
|
-
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
# This inequality can be rewritten as follows:
|
|
789
|
-
|
|
790
|
-
# a - c * align_2MB(b * x) > 0
|
|
791
|
-
# where
|
|
792
|
-
# a = available_dram - kernel_size - buffer
|
|
793
|
-
# b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
|
|
794
|
-
# c = num_layers * 2 * tensor_parallel_size
|
|
795
|
-
|
|
796
|
-
# We can rewrite the inequality as follows:
|
|
797
|
-
# k > align_2MB(b*x)
|
|
798
|
-
# where
|
|
799
|
-
# k = a / c
|
|
800
|
-
|
|
801
|
-
# After that, we can derive the following equation:
|
|
802
|
-
# x = floor(2**21 / b * floor((k - 1) / 2**21))
|
|
803
|
-
|
|
804
|
-
def align(x: int, nbytes: int) -> int:
|
|
805
|
-
return int(math.ceil(x / nbytes) * nbytes)
|
|
806
|
-
|
|
807
|
-
def align_2MB(x: int) -> int:
|
|
808
|
-
return align(x, 2**21)
|
|
809
|
-
|
|
810
|
-
num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
|
|
811
|
-
num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
|
|
812
|
-
head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
|
|
813
|
-
vocab_size = config.vocab_size
|
|
814
|
-
hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
|
|
815
|
-
num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
|
|
816
|
-
|
|
817
|
-
# TODO(jongho): Update if target npu is REBEL.
|
|
818
|
-
ATOM_DRAM_NBYTES = 16 * 2**30
|
|
819
|
-
ATOM_SYS_DRAM_NBYTES = 288 * 2**20
|
|
820
|
-
available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
|
|
821
|
-
|
|
822
|
-
if kernel_size is None:
|
|
823
|
-
if n_model_params is None:
|
|
824
|
-
raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
|
|
825
|
-
# Get estimated kernel size (approximated)
|
|
826
|
-
lm_heads_params = align(vocab_size, 64) * hidden_size
|
|
827
|
-
lm_heads_nbytes = (
|
|
828
|
-
align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
|
|
829
|
-
)
|
|
830
|
-
params = n_model_params - lm_heads_params
|
|
831
|
-
layer_nbytes = (
|
|
832
|
-
align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
|
|
833
|
-
* num_layers
|
|
834
|
-
* tensor_parallel_size
|
|
835
|
-
)
|
|
836
|
-
kernel_size = layer_nbytes + lm_heads_nbytes
|
|
837
|
-
elif n_model_params is not None:
|
|
838
|
-
raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
|
|
294
|
+
if rbln_config.can_generate:
|
|
295
|
+
wrapped_model.phase = "decode"
|
|
296
|
+
for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
|
|
297
|
+
dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
|
|
298
|
+
compiled_decoder = cls._compile_model(
|
|
299
|
+
wrapped_model,
|
|
300
|
+
dec_compile_config,
|
|
301
|
+
dec_example_inputs,
|
|
302
|
+
context,
|
|
303
|
+
rbln_config,
|
|
304
|
+
rbln_config.quantization,
|
|
305
|
+
phase="decode",
|
|
306
|
+
)
|
|
307
|
+
compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
|
|
308
|
+
|
|
309
|
+
# check if the memory is enough to have additional blocks
|
|
310
|
+
required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
|
|
311
|
+
if rbln_config.kvcache_num_blocks < required_num_blocks:
|
|
312
|
+
cls.maybe_suggest_kvcache_num_blocks(
|
|
313
|
+
compiled_models=compiled_models,
|
|
314
|
+
model_config=model.config,
|
|
315
|
+
rbln_config=rbln_config,
|
|
316
|
+
)
|
|
839
317
|
|
|
840
|
-
|
|
318
|
+
return compiled_models
|
|
841
319
-
-
-
-
-
-
+    @classmethod
+    def get_pytorch_model(
+        cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
+    ) -> PreTrainedModel:
+        if rbln_config and rbln_config.quantization:
+            model = cls.get_quantized_model(*args, rbln_config=rbln_config, **kwargs)
+        else:
+            model = super().get_pytorch_model(*args, **kwargs)
 
-
-        c = num_layers * 2 * tensor_parallel_size
-        k = available_dram / c
-        max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
+        return model
 
-
+    @classmethod
+    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+        return use_local_attention
 
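A note on the helper added above: in this base implementation the compiled graph only receives a `query_position` input when local (sliding-window) attention is used, while the causal-LM subclass later in this diff overrides it to request `query_position` for every prefill graph. A self-contained sketch mirroring the two behaviours (the class placement is inferred from this diff, so the mirror functions below are illustrative):

```python
def base_use_query_position(use_local_attention: bool, is_prefill: bool = True) -> bool:
    # Mirrors the base-class helper above: query_position only for local attention.
    return use_local_attention

def causal_lm_use_query_position(use_local_attention: bool, is_prefill: bool = True) -> bool:
    # Mirrors the causal-LM override added further down: query_position whenever prefilling.
    return is_prefill

assert base_use_query_position(False, is_prefill=True) is False
assert causal_lm_use_query_position(False, is_prefill=True) is True
```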
     @classmethod
     def get_input_info(
@@ -861,63 +340,57 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         model_config: PretrainedConfig,
     ):
-        is_prefill: bool = query_length > 1
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
         head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
-
+        is_prefill = query_length > 1
 
-
+        input_info = []
         if rbln_config.use_inputs_embeds:
-
+            input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.torch_dtype))
         else:
-
-
-
-            input_info = [
-                main_input,
-                (
-                    "cache_position",
-                    [batch_size, query_length],
-                    "int32",
-                ),
-            ]
+            input_info.append(("input_ids", [batch_size, query_length], "int64"))
+
+        input_info.append(("cache_position", [batch_size, query_length], "int32"))
 
-
-        if rbln_config.cache_impl in ["static", "hybrid"]:
+        if rbln_config.use_global_attention:
             max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
-            input_info.
-
+            input_info.append(
+                ("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")
             )
-        if rbln_config.
-            input_info.
+        if rbln_config.use_local_attention:
+            input_info.append(("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16"))
 
-
-
-            input_info.extend([("query_position", [], "int16")])
+        if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
+            input_info.append(("query_position", [], "int16"))
 
-        # 5. attention_mask & position_ids
         if rbln_config.use_attention_mask:
-
-            [
-
-
-
-
-
+            if rbln_config.use_position_ids:
+                input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.torch_dtype))
+            else:
+                input_info.append(
+                    ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.torch_dtype)
+                )
+
         if rbln_config.use_position_ids:
             input_info.append(("position_ids", [batch_size, query_length], "int32"))
 
-
+        if rbln_config.use_lora:
+            input_info.append(("lora_int_ids", [batch_size], "int32"))
+
+        kvcache_dtype = rbln_config.torch_dtype
+        if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
+            kvcache_dtype = "float8_e4m3fn"
+
         global_kvcache_shape = [
             rbln_config.kvcache_num_blocks,
             num_key_value_heads,
             rbln_config.kvcache_block_size,
             head_dim,
         ]
-        local_kvcache_shape = [
+        local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
         input_info.extend(
             [
                 (
@@ -925,7 +398,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                     local_kvcache_shape
                     if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
                     else global_kvcache_shape,
-
+                    kvcache_dtype,
                 )
                 for i in range(num_hidden_layers * 2)
             ]
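For orientation, `get_input_info` assembles a list of `(name, shape, dtype)` tuples that is handed to `RBLNCompileConfig`. A hedged sketch of what a prefill entry list could look like for a toy setup; the KV-cache input names and all concrete numbers here are assumptions for illustration, not values taken from this release:

```python
# Toy setup: batch_size=1, query_length=128, max_seq_len=4096,
# kvcache_block_size=4096 (one block per sequence), 2 layers, 4 KV heads, head_dim=64.
illustrative_prefill_input_info = [
    ("input_ids", [1, 128], "int64"),
    ("cache_position", [1, 128], "int32"),
    ("block_tables", [1], "int16"),   # prefill uses the flat [max_block_cnt] form
    ("query_position", [], "int16"),
    # One K and one V cache per layer; the global shape is
    # [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim].
    ("past_key_values.0", [1, 4, 4096, 64], "float32"),
    ("past_key_values.1", [1, 4, 4096, 64], "float32"),
    ("past_key_values.2", [1, 4, 4096, 64], "float32"),
    ("past_key_values.3", [1, 4, 4096, 64], "float32"),
]
```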
@@ -964,7 +437,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         # ```
 
         # Returns:
-        #
+        #     RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.
 
         raise NotImplementedError(
             "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
@@ -972,27 +445,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         )
 
     @classmethod
-    def
-        cls,
-
-        model: Optional[PreTrainedModel] = None,
-        model_config: Optional[PretrainedConfig] = None,
-        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
-    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
-        if rbln_config.max_seq_len is None:
-            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
-                model_config, "n_positions", None
-            )
-        if rbln_config.max_seq_len is None:
-            raise ValueError("`max_seq_len` should be specified.")
-
-        if getattr(model_config, "sliding_window", None) is not None and getattr(
-            model_config, "use_sliding_window", True
-        ):
-            rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
-        if rbln_config.sliding_window is not None:
-            validate_sliding_window_size(rbln_config.sliding_window, rbln_config.prefill_chunk_size)
-
+    def _update_attention_config(
+        cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
         rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
             attn_impl=rbln_config.attn_impl,
             kvcache_partition_len=rbln_config.kvcache_partition_len,
@@ -1007,40 +462,77 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             max_seq_len=rbln_config.max_seq_len,
         )
 
-
-        max_num_blocks = required_num_blocks
+        num_full_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
 
+        # Update kvcache_num_blocks based on the attention implementation.
         if rbln_config.attn_impl == "flash_attn":
-            estimated_max_num_blocks = cls.
-
-                tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
-                kvcache_block_size=rbln_config.kvcache_block_size,
-                nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
-                n_model_params=sum(p.numel() for p in model.parameters()),
-                num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
+            estimated_max_num_blocks = cls.get_maximum_num_blocks_by_model(
+                model=model, model_config=model_config, rbln_config=rbln_config
             )
 
-
-
-
-
-
-
-
-
-
-
+            if rbln_config.kvcache_num_blocks is None:
+                if estimated_max_num_blocks < num_full_blocks:
+                    # lower bound of the number of blocks for flash attention.
+                    min_blocks_for_flash = min(
+                        rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1, num_full_blocks
+                    )
+                    if min_blocks_for_flash > estimated_max_num_blocks:
+                        # NOTE: Just try to compile with lower bound of blocks for flash attention.
+                        # Even if it's larger than the estimated maximum number of blocks.
+                        rbln_config.kvcache_num_blocks = min_blocks_for_flash
+                    else:
+                        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
+                        rbln_config.kvcache_num_blocks = estimated_max_num_blocks
+
+                    if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
+                        raise RuntimeError(
+                            f"Batch size ({rbln_config.batch_size}) exceeds num_blocks ({rbln_config.kvcache_num_blocks}). "
+                            "Ensure the number of blocks is at least equal to the batch size."
+                        )
+                else:
+                    rbln_config.kvcache_num_blocks = num_full_blocks
+            elif rbln_config.kvcache_num_blocks > estimated_max_num_blocks:
+                logger.warning(
+                    f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                    f" than the estimated maximum number of blocks ({estimated_max_num_blocks})."
+                    "This can cause a failure during model compilation."
+                )
+        else:
+            if rbln_config.kvcache_num_blocks is None:
+                rbln_config.kvcache_num_blocks = num_full_blocks
+            elif rbln_config.kvcache_num_blocks > num_full_blocks:
+                logger.warning(
+                    f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                    f" than the required number of blocks ({num_full_blocks})."
+                    "This can cause a failure during model compilation."
                 )
+        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
 
-
-
-
-
-
-
-
+        return rbln_config
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional[PreTrainedModel] = None,
+        model_config: Optional[PretrainedConfig] = None,
+        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+        if rbln_config.max_seq_len is None:
+            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
+                model_config, "n_positions", None
             )
-
+        if rbln_config.max_seq_len is None:
+            raise ValueError("`max_seq_len` should be specified.")
+
+        if getattr(model_config, "sliding_window", None) is not None and getattr(
+            model_config, "use_sliding_window", True
+        ):
+            rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
+        if rbln_config.sliding_window is not None:
+            validate_sliding_window(rbln_config)
+
+        rbln_config = cls._update_attention_config(model, model_config, rbln_config)
 
         prefill_input_info = cls.get_input_info(
             batch_size=1,
@@ -1050,19 +542,20 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         )
 
         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
-
-
-
-
-
-
-
-
-
-
-
-
-
+        compile_cfgs = [prefill_compile_config]
+
+        if rbln_config.can_generate:
+            for batch_size in rbln_config.decoder_batch_sizes:
+                dec_input_info = cls.get_input_info(
+                    batch_size=batch_size,
+                    query_length=1,
+                    rbln_config=rbln_config,
+                    model_config=model_config,
+                )
+                compile_cfgs.append(
+                    RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
+                )
+        rbln_config.set_compile_cfgs(compile_cfgs)
 
         return rbln_config
 
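Taken together, `_update_rbln_config` now derives `max_seq_len` and the sliding-window/attention settings, then registers one prefill compile config plus one decoder compile config per entry in `decoder_batch_sizes`. A hypothetical export call wiring those knobs through `rbln_config` might look like the sketch below; the model id, class choice, and exact keyword spelling are assumptions, not taken from this diff:

```python
from optimum.rbln import RBLNLlamaForCausalLM  # assumed entry point

# Hypothetical export: one prefill graph plus decoders for batch sizes 4 and 1.
model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",   # illustrative model id
    export=True,
    rbln_config={
        "max_seq_len": 8192,
        "batch_size": 4,
        "decoder_batch_sizes": [4, 1],
        "attn_impl": "flash_attn",
    },
)
model.save_pretrained("llama-3.1-8b-rbln")
```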
@@ -1072,101 +565,164 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         compiled_models: List[rebel.RBLNCompiledModel],
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ) -> List[rebel.Runtime]:
-        expected_model_names = [
-
-
-
+        expected_model_names = ["prefill"]
+        if rbln_config.can_generate:
+            expected_model_names.extend(
+                [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
+            )
         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
             cls._raise_missing_compiled_file_error(expected_model_names)
 
-
+        ret_val = [
             rebel.Runtime(
                 compiled_models[0],
                 tensor_type="pt",
                 device=rbln_config.device_map["prefill"],
                 activate_profiler=rbln_config.activate_profiler,
-
-
-            rebel.Runtime(
-                compiled_models[i + 1],
-                tensor_type="pt",
-                device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
-                activate_profiler=rbln_config.activate_profiler,
-            )
-            for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-        ],
+                timeout=rbln_config.timeout,
+            )
         ]
+        if rbln_config.can_generate:
+            ret_val.extend(
+                [
+                    rebel.Runtime(
+                        compiled_models[i + 1],
+                        tensor_type="pt",
+                        device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
+                        activate_profiler=rbln_config.activate_profiler,
+                        timeout=rbln_config.timeout,
+                    )
+                    for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
+                ]
+            )
+        return ret_val
 
-    def
-        return self.decoder
-
-    def can_generate(self):
-        return True
-
-    def _reorder_cache(self, past_key_values, beam_idx):
-        raise NotImplementedError
-
-    def prepare_inputs_for_generation(
+    def forward(
         self,
-        input_ids: torch.LongTensor,
-        generate_idx: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
+        input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-
+        attention_mask: Optional[torch.LongTensor] = None,
         **kwargs,
-    ):
-
-
+    ) -> BaseModelOutputWithPast:
+        """
+        Args:
+            input_ids (torch.LongTensor, optional): The input IDs to the model.
+            inputs_embeds (torch.Tensor, optional): The input embeddings to the model.
+            attention_mask (torch.LongTensor, optional): The attention mask to the model.
+            kwargs (dict[str, Any], optional): Additional keyword arguments.
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        Returns:
+            Dataclass containing the last hidden states of the model.
+        """
+        inputs = inputs_embeds if inputs_embeds is not None else input_ids
+        batch_size = inputs.shape[0]
+        position_embed = kwargs.get("position_embed", None)
+
+        if batch_size != self.rbln_config.batch_size:
+            raise ValueError(
+                f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
+            )
+
+        all_last_hidden_states = []
+        for b_idx in range(self.rbln_config.batch_size):
+            query_length = (
+                attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
+            )
+            cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
+            last_hidden_states = self.prefill_decoder(
+                inputs[b_idx : b_idx + 1],
+                attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
+                cache_position=cache_position,
+                batch_idx=b_idx,
+            ).logits
+            all_last_hidden_states.append(last_hidden_states)
+
+        last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
+
+        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
+
+
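The `forward` above runs the prefill graph once per batch entry and concatenates the per-sequence hidden states into a `BaseModelOutputWithPast`. A hedged usage sketch; `model` is a placeholder for an already-compiled RBLN decoder-only model without a language-model head, and only the call pattern follows from the code above:

```python
import torch

# Placeholder: `model` is a compiled RBLN decoder-only model (no LM head).
input_ids = torch.randint(0, 32000, (2, 16), dtype=torch.long)  # batch must match the compiled batch_size
attention_mask = torch.ones_like(input_ids)

out = model(input_ids=input_ids, attention_mask=attention_mask)
print(out.last_hidden_state.shape)  # last hidden states, batched along dim 0
```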
+class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
+    """
+    A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+    This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
+
+    The class provides core functionality for:
+
+    1. Converting pre-trained transformer models to RBLN-optimized format
+    2. Handling the compilation process for RBLN devices
+    3. Managing inference operations for causal language modeling
+    This class inherits from RBLNModel and implements specific methods required for
+    decoder-only architectures and causal language modeling tasks.
+
+    Note:
+        - This class is designed to be subclassed by specific model implementations
+          (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+        - Subclasses should implement model-specific conversion logic.
+        - The class handles RBLN-specific optimizations automatically during compilation
+    """
+
+    auto_model_class = AutoModelForCausalLM
+
+    @property
+    def prefill_output_size(self):
+        return (
+            1,
+            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
+            self.config.vocab_size,
         )
 
-
+    @classmethod
+    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+        return is_prefill
+
+    def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
+        if isinstance(lora_int_ids, int):
+            lora_int_ids = torch.tensor([lora_int_ids], dtype=torch.int32)
+        elif isinstance(lora_int_ids, list):
+            lora_int_ids = torch.tensor(lora_int_ids, dtype=torch.int32)
 
-
-        self,
-        outputs: RBLNDecoderOnlyOutput,
-        model_kwargs: Dict[str, Any],
-        **kwargs,
-    ) -> Dict[str, Any]:
-        # update generate_idx
-        model_kwargs["generate_idx"] = outputs.generate_idx
-        model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths
+        self.lora_int_ids = lora_int_ids
 
-
+        self.prefill_decoder.lora_int_ids = lora_int_ids
+        if self.rbln_config.can_generate:
+            for batch_size in self.rbln_config.decoder_batch_sizes:
+                self.decoders[batch_size].lora_int_ids = lora_int_ids
+
+    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+        """
+        Sets the active adapter(s) for the model using adapter name(s).
+
+        Args:
+            adapter_name (Union[str, List[str]]): The name(s) of the adapter(s) to be activated.
+                Can be a single adapter name or a list of adapter names.
+
+        Raises:
+            ValueError: If the model is not configured with LoRA or if the adapter name is not found.
+        """
+        if not hasattr(self.rbln_config, "lora_config") or self.rbln_config.lora_config is None:
+            raise ValueError("Model is not configured with LoRA. Cannot set adapter.")
+
+        # Convert single adapter name to list for uniform processing
+        if isinstance(adapter_name, str):
+            adapter_names = [adapter_name]
+        else:
+            adapter_names = adapter_name
+
+        # Validate that all adapter names exist
+        available_adapters = {
+            adapter.lora_name: adapter.lora_int_id for adapter in self.rbln_config.lora_config.adapters
+        }
+        missing_adapters = [name for name in adapter_names if name not in available_adapters]
+        if missing_adapters:
+            raise ValueError(
+                f"Adapter(s) {missing_adapters} not found. Available adapters: {list(available_adapters.keys())}"
+            )
+
+        # Get the adapter IDs and set them
+        lora_int_ids = [available_adapters[name] for name in adapter_names]
+        self.set_lora_int_ids(torch.tensor(lora_int_ids, dtype=torch.int32))
 
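With LoRA configured, an adapter can be selected by name via `set_adapter` or directly by integer id via `set_lora_int_ids`; both propagate the ids to the prefill runtime and to every compiled decoder. A hedged usage sketch, where the adapter names and the surrounding LoRA config are assumptions for illustration:

```python
import torch

# Hypothetical: `model` was compiled with a LoRA config registering adapters "ko" and "en".
model.set_adapter("ko")               # resolves the name to its lora_int_id and activates it
# or pass the integer ids directly, one per batch entry:
model.set_lora_int_ids([1, 1, 0, 0])  # int / list / torch.int32 tensor are all accepted

input_ids = torch.randint(0, 32000, (4, 16), dtype=torch.long)  # batch of 4, matching the ids above
output = model.generate(input_ids=input_ids, max_new_tokens=64)
```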
     def forward(
         self,
@@ -1178,6 +734,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
@@ -1185,12 +742,38 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
         # A for-loop ensures synchronization with the HuggingFace generate API.
         # The decoder stage operates as usual, processing inputs in batch mode.
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+            lora_int_ids = self.lora_int_ids
+
+        # for only use forward
+        if generate_idx is None:
+            generate_idx = (
+                attention_mask.sum(dim=-1, keepdim=True).int()
+                if attention_mask is not None
+                else torch.full((input_ids.shape[0], 1), input_ids.shape[1], dtype=torch.int32)
+            )
+            padded_cache_lengths = torch.zeros_like(generate_idx)
 
-        #
+        # Prefill
         if cache_position is None:
             logits = []
             inputs = inputs_embeds if inputs_embeds is not None else input_ids
             batch_size = inputs.shape[0]
+            input_len = inputs.shape[1]
+            if batch_size > self.rbln_config.batch_size:
+                raise ValueError(
+                    f"Input's batch({batch_size}) exceeds compiled batch_size({self.rbln_config.batch_size})"
+                )
+            if input_len > self.rbln_config.max_seq_len:
+                raise ValueError(
+                    f"Input's length({input_len}) exceeds compiled max_seq_len({self.rbln_config.max_seq_len})."
+                )
+
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
                 output = self.prefill_decoder(
@@ -1200,6 +783,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                     cache_position=cache_position,
                     batch_idx=b_idx,
                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
+                    lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
                 )
                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
                 logits.append(output.logits)
@@ -1214,11 +798,21 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                     f"Available batch sizes are: {list(self.decoders.keys())}. "
                     f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
                 )
+            if max(cache_position.reshape(-1)) >= self.rbln_config.max_seq_len:
+                raise ValueError(
+                    f"Cache position exceeds the maximum sequence length.\n"
+                    f" - Current max cache position: {int(torch.max(cache_position).item())}\n"
+                    f" - Allowed max_seq_len: {self.rbln_config.max_seq_len}\n"
+                    f"Solution: Reduce the generation length by adjusting `max_new_tokens` "
+                    f"or `max_length` in the generation config."
+                )
+
             logits = self.decoders[batch_size](
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_ids=position_ids if self.rbln_config.use_position_ids else None,
+                lora_int_ids=lora_int_ids,
             ).logits
 
         if not return_dict: