optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3__py3-none-any.whl
This diff shows the changes between two publicly available package versions that have been released to one of the supported registries. The information it contains is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +108 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +156 -43
- optimum/rbln/diffusers/__init__.py +19 -0
- optimum/rbln/diffusers/configurations/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +4 -0
- optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +31 -6
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
- optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/controlnet.py +16 -1
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +25 -2
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
- optimum/rbln/diffusers/models/unets/__init__.py +1 -0
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
- optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
- optimum/rbln/diffusers/pipelines/__init__.py +15 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
- optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
- optimum/rbln/modeling.py +48 -21
- optimum/rbln/modeling_base.py +99 -22
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +7 -32
- optimum/rbln/transformers/modeling_attention_utils.py +385 -0
- optimum/rbln/transformers/modeling_generic.py +48 -65
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
- optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +67 -6
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +485 -905
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
- optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -351
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
- optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
- optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +58 -13
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +20 -16
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +30 -5
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
- optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/deprecation.py +213 -0
- optimum/rbln/utils/hub.py +14 -3
- optimum/rbln/utils/runtime_utils.py +60 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
- optimum_rbln-0.9.3.dist-info/RECORD +264 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
- optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
- optimum_rbln-0.8.2a4.dist-info/RECORD +0 -215
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -13,31 +13,31 @@
 # limitations under the License.
 
 import inspect
-import math
-from collections import deque
-from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable,
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
 
 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import
+from transformers import AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_utils import no_init_weights
-from transformers.utils import ModelOutput
 
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from
-
-from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
-from .decoderonly_architecture import (
-    DecoderOnlyWrapper,
+from ...modeling_attention_utils import (
+    RBLNDecoderOnlyFlashAttentionMixin,
     set_default_values,
     validate_attention_method,
-
+    validate_sliding_window,
 )
+from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ...utils.rbln_quantization import get_quantized_model
+from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+from .decoderonly_architecture import DecoderOnlyWrapper
+from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
+from .generation_decoderonly import RBLNDecoderOnlyGenerationMixin
 
 
 logger = get_logger()
@@ -46,529 +46,85 @@ if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
-class
-    mandatory_members = ["main_input_name", "embed_tokens"]
-
-    def __init__(
-        self,
-        runtime: rebel.Runtime,
-        phase: str,
-        batch_size: int,
-        dec_attn_mask: torch.Tensor,
-        block_tables: torch.Tensor,
-        free_block_pool: Deque,
-        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(runtime, **kwargs)
-        self.phase = phase
-        self.batch_size = batch_size
-        self.rbln_config = rbln_config
-
-        # shared tensor between prefill and decode phase
-        self.dec_attn_mask = dec_attn_mask
-        self.block_tables = block_tables
-        self.free_block_pool = free_block_pool
-
-        self.empty_block = -1
-        if self.phase == "prefill":
-            vocab_size = kwargs.pop("vocab_size")
-            self.output_size = [1, 1, vocab_size]
-            self.causal_mask = 1 - torch.triu(
-                torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
-            )
-
-    def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None) -> torch.Tensor:
-        """
-        Manages and returns the KV cache block tables.
-        Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
-
-        Args:
-            cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
-            batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
-
-        Returns:
-            Updated block tables.
-        """
-
-        NO_BLOCKS_ERROR = (
-            "No memory blocks are available for allocation. "
-            "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln. "
-            "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html). "
-            "Using vllm-rbln should fix this issue and enhance inference performance."
-        )
-
-        def update_block(batch_idx: int, block_idx: int):
-            """
-            If the block is empty (empty_block), allocates a block from the free_block_pool.
-            """
-            if self.block_tables[batch_idx][block_idx] == self.empty_block:
-                if self.free_block_pool:
-                    block = self.free_block_pool.popleft()
-                    self.block_tables[batch_idx][block_idx] = block
-                else:
-                    raise RuntimeError(NO_BLOCKS_ERROR)
-
-        def replace_empty_block(block_tables: torch.Tensor):
-            """
-            Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
-            """
-            if not torch.any(block_tables == self.empty_block):
-                return block_tables.clone()
-            elif self.free_block_pool:
-                _free_block = self.free_block_pool[0]
-                return torch.where(block_tables == self.empty_block, _free_block, block_tables)
-            else:
-                raise RuntimeError(NO_BLOCKS_ERROR)
-
-        def get_global_block_tables(batch_idx: int):
-            if self.rbln_config.cache_impl == "sliding_window":
-                return None
-
-            if self.phase == "prefill":
-                # Track previously used blocks and return them to the free_block_pool and
-                # reset the current batch's block table to empty blocks
-                prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
-                self.free_block_pool.extend(prev_blocks)
-                self.block_tables[batch_idx].fill_(self.empty_block)
-
-                # Get the start (s) and end (e) positions from cache_position and
-                # iterate over the cache positions to allocate necessary blocks
-                s, e = cache_position[0][0].item(), cache_position[0][-1].item()
-                for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
-                    block_idx = position // self.rbln_config.kvcache_block_size
-                    if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
-                        raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
-                    update_block(batch_idx, block_idx)
-
-                return replace_empty_block(self.block_tables[batch_idx])
-            # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
-            else:
-                for b_idx in range(self.batch_size):
-                    position = cache_position[b_idx][0].item()
-                    block_idx = position // self.rbln_config.kvcache_block_size
-                    update_block(b_idx, block_idx)
-
-                return replace_empty_block(self.block_tables)
-
-        def get_local_block_tables(batch_idx: int):
-            if self.rbln_config.cache_impl == "static":
-                return None
-            else:
-                return (
-                    torch.tensor([batch_idx], dtype=torch.int16)
-                    if self.phase == "prefill"
-                    else torch.arange(self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-                )
-
-        return get_global_block_tables(batch_idx), get_local_block_tables(batch_idx)
-
-    def is_external_block_tables(
-        self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
-    ):
-        if self.rbln_config.cache_impl == "static" and block_tables is None:
-            return False
-        elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
-            return False
-        elif self.rbln_config.cache_impl == "hybrid":
-            if (block_tables is not None) != (local_block_tables is not None):
-                raise ValueError(
-                    "Both block_tables and local_block_tables must be provided or neither of them must be provided."
-                )
-        elif block_tables is None and local_block_tables is None:
-            return False
-
-        return True
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        cache_position: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: Optional[int] = None,
-        block_tables: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ):
-        if input_ids is None and inputs_embeds is None:
-            raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
-
-        if inputs_embeds is None:
-            inputs = input_ids
-            if self.embed_tokens is not None:
-                inputs = self.embed_tokens(inputs)
-        else:
-            inputs = inputs_embeds
-
-        is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
-        if not is_external_block_tables:
-            block_tables, local_block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
-
-        if self.phase == "decode":
-            return self.decode_forward(
-                inputs,
-                cache_position,
-                block_tables,
-                is_external_block_tables,
-                attention_mask=attention_mask,
-                position_embed=position_embed,
-                position_ids=position_ids,
-                local_block_tables=local_block_tables,
-            )
-        else:
-            return self.prefill_forward(
-                inputs,
-                cache_position,
-                attention_mask,
-                batch_idx,
-                block_tables,
-                is_external_block_tables=is_external_block_tables,
-                position_embed=position_embed,
-                token_type_ids=token_type_ids,
-                local_block_tables=local_block_tables,
-            )
-
-    def decode_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size = inputs.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
-
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-        if self.rbln_config.use_attention_mask and attention_mask is None:
-            for b_idx in range(batch_size):
-                decoding_step = cache_position[b_idx].item()
-                if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                    raise ValueError(
-                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                    )
-
-                if is_external_block_tables:
-                    self.dec_attn_mask[b_idx].fill_(0)
-                    self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
-                else:
-                    self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
-
-            attention_mask = self.dec_attn_mask
-
-        if self.rbln_config.cache_impl in ["hybrid", "static"] and self.batch_size < block_tables.shape[0]:
-            block_tables = block_tables[: self.batch_size]
-
-        if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-            attention_mask = attention_mask[: self.batch_size]
-
-        logits = super().forward(
-            inputs,
-            cache_position,
-            block_tables,
-            local_block_tables,
-            position_embed,
-            attention_mask if self.rbln_config.use_attention_mask else None,
-            position_ids if self.rbln_config.use_position_ids else None,
-        )
-
-        return RBLNDecoderOnlyOutput(logits=logits)
-
-    def _prepare_prefill_inputs(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-    ):
-        """
-        Prepare inputs for prefill phase.
-        """
-        # Handle continuous batching in a compiled graph by extracting valid inputs
-        # If an attention mask is provided, select only the valid (non-masked) inputs
-        inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
-        if position_embed is not None:
-            position_embed = (
-                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-            )
-        if token_type_ids is not None:
-            token_type_ids = token_type_ids[:, attention_mask.bool()] if attention_mask is not None else token_type_ids
-
-        query_length = inputs.shape[1]
-        if query_length > self.rbln_config.max_seq_len:
-            raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
-            )
-
-        # Initialize attention mask for chunked processing
-        chunked_attention_mask = (
-            torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-            if self.rbln_config.use_attention_mask
-            else None
-        )
-
-        # Buffer for storing output logits
-        out_buffers = [
-            torch.empty(
-                size=self.output_size,
-                dtype=torch.float32,
-                device="cpu",
-            )
-        ]
-
-        # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        padding_size = 0
-        if query_length % self.rbln_config.prefill_chunk_size != 0:
-            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
-            # inputs_embeds
-            if inputs.dim() == 3:
-                inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-            # inputs_ids
-            else:
-                inputs = torch.nn.functional.pad(inputs, (0, padding_size))
-
-            cache_position = torch.cat(
-                [
-                    cache_position,
-                    torch.arange(
-                        query_length,
-                        query_length + padding_size,
-                        dtype=torch.int32,
-                    ).unsqueeze(0),
-                ],
-                dim=-1,
-            )
-
-            if position_embed is not None:
-                position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
-
-            if token_type_ids is not None:
-                token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
-
-        # Overwrite position_ids and padded_cache_lengths
-        position_ids = cache_position.clone()
-        padded_cache_lengths = 0
-
-        return (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        )
-
-    def prefill_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: int = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = False,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        """
-        Performs chunked prefill for efficient KV-cache updates and memory optimization.
-        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
-        """
-        (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        ) = self._prepare_prefill_inputs(
-            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
-        )
-
-        # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
-            # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-            position_ids_chunk = (
-                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
-                if position_ids is not None
-                else None
-            )
-            if position_embed is not None:
-                position_embed_chunk = position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]
-
-            if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
-                # Update attention mask to ensure proper causal behavior
-                if step >= self.rbln_config.prefill_chunk_size:
-                    chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
-                chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
-
-            # Define query position
-            if step + self.rbln_config.prefill_chunk_size >= query_length:
-                query_position = torch.tensor(
-                    (query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16
-                )
-            else:
-                query_position = torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)
-
-            # Forward pass for the current chunk
-            logits = super().forward(
-                input_chunk,
-                cache_pos_chunk,
-                block_tables,
-                local_block_tables,
-                position_embed_chunk if position_embed is not None else None,
-                query_position,
-                chunked_attention_mask if self.rbln_config.use_attention_mask else None,
-                position_ids_chunk if self.rbln_config.use_position_ids else None,
-                out=out_buffers,
-            )
-
-        # Update decoder attention mask with processed KV-cache length from prefill phase
-        if not is_external_block_tables and self.rbln_config.use_attention_mask:
-            self.dec_attn_mask[batch_idx].fill_(0)
-            self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
-
-        return RBLNDecoderOnlyOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)
-
-
-@dataclass
-class RBLNDecoderOnlyOutput(ModelOutput):
-    logits: torch.FloatTensor = None
-    generate_idx: torch.Tensor = None
-    padded_cache_lengths: int = None
-
-
-class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
+class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     """
-    A base class for decoder-only transformer models
+    A base class for decoder-only transformer models outputting raw hidden-states without any specific head on top.
+    This class is used for RBLN-optimized models that are not causal language models.
     This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
 
     The class provides core functionality for:
 
     1. Converting pre-trained transformer models to RBLN-optimized format
     2. Handling the compilation process for RBLN devices
-    3. Managing inference operations for
-
+    3. Managing inference operations for decoder-only architectures
     This class inherits from RBLNModel and implements specific methods required for
-    decoder-only architectures
+    decoder-only architectures.
 
     Note:
         - This class is designed to be subclassed by specific model implementations
-          (e.g.,
+          (e.g., RBLNLlamaModel, RBLNQwen2Model)
         - Subclasses should implement model-specific conversion logic.
        - The class handles RBLN-specific optimizations automatically during compilation
     """
 
+    _tp_support = True
+
     main_input_name = "input_ids"
-    auto_model_class =
+    auto_model_class = AutoModel
     _decoder_wrapper_cls = DecoderOnlyWrapper
     _use_rotary_emb = True
+    _supports_non_fp32 = True
 
     def __post_init__(self, **kwargs):
-        main_input_name = self.main_input_name
-
         if self.rbln_config.use_inputs_embeds:
-            main_input_name = "inputs_embeds"
             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
             self.embed_tokens = self._create_embedding_layer()
            self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
         else:
             self.embed_tokens = None
 
-
-
-
-        )
-
-
-
-            dtype=torch.int16,
-        ).fill_(-1)
-        free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+        self.setup_runtime()
+
+    def setup_runtime(self):
+        # Initialize resources to be used across Runtime instances (prefill and decode phases)
+        page_table_manager = RBLNPageTableManager(self.rbln_config)
+        dec_attn_mask = torch.zeros(self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype)
+        out_buffers = [torch.empty(self.prefill_output_size, dtype=self.dtype)]
 
+        common_kwargs = {
+            "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+            "embed_tokens": self.embed_tokens,
+            "dec_attn_mask": dec_attn_mask,
+            "page_table_manager": page_table_manager,
+            "rbln_config": self.rbln_config,
+        }
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
-            main_input_name=main_input_name,
-            embed_tokens=self.embed_tokens,
             phase="prefill",
             batch_size=self.rbln_config.batch_size,
-
-
-            free_block_pool=free_block_pool,
-            rbln_config=self.rbln_config,
-            vocab_size=self.config.vocab_size,
+            out_buffers=out_buffers,
+            **common_kwargs,
         )
+        if self.can_generate():
+            self.decoders = {}
+            for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
+                self.decoders[batch_size] = RBLNRuntimeModel(
+                    runtime=self.model[i + 1],
+                    phase="decode",
+                    batch_size=batch_size,
+                    **common_kwargs,
+                )
 
-
-
-            self.decoders[batch_size] = RBLNRuntimeModel(
-                runtime=self.model[i + 1],
-                main_input_name=main_input_name,
-                embed_tokens=self.embed_tokens,
-                phase="decode",
-                batch_size=batch_size,
-                dec_attn_mask=dec_attn_mask,
-                block_tables=block_tables,
-                free_block_pool=free_block_pool,
-                rbln_config=self.rbln_config,
-            )
-
-        # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
-        self.decoder = self.decoders[self.rbln_config.batch_size]
-
-    @classmethod
-    def save_torch_artifacts(
-        cls,
-        model: PreTrainedModel,
-        save_dir_path: Path,
-        subfolder: str,
-        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-    ):
-        # If you are unavoidably running on a CPU rather than an RBLN device,
-        # store the torch tensor, weight, etc. in this function.
-        if rbln_config.use_inputs_embeds:
-            save_dict = {}
-            save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
-            torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
-
-    def _create_embedding_layer(self):
-        with no_init_weights():
-            embed_tokens = torch.nn.Embedding(
-                self.config.vocab_size,
-                self.config.hidden_size,
-                self.config.pad_token_id,
-            )
-        return embed_tokens
-
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def get_attn_impl(self) -> str:
-        return self.rbln_config.attn_impl
+            # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
+            self.decoder = self.decoders[self.rbln_config.batch_size]
 
-
-
+    @property
+    def prefill_output_size(self):
+        return (
+            1,
+            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
+            self.config.hidden_size,
+        )
 
     @classmethod
     def get_quantized_model(
@@ -582,35 +138,22 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
+        rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
         **kwargs,
     ):
         kwargs = cls.update_kwargs(kwargs)
 
-
-
-            model_id,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
-        )
-
-        with no_init_weights():
-            model = AutoModelForCausalLM.from_config(config)
-
-        model = prepare_model_for_quantization(
-            model,
+        return get_quantized_model(
+            cls.auto_model_class,
             model_id,
-            kwargs.get("num_hidden_layers"),
             use_auth_token=use_auth_token,
             revision=revision,
             cache_dir=cache_dir,
             force_download=force_download,
             local_files_only=local_files_only,
+            rbln_quantization=rbln_config.quantization,
+            **kwargs,
         )
-        return model
 
     def __getattr__(self, __name: str) -> Any:
         # Special method to delegate attribute access to the original Huggingface LM class.
@@ -632,233 +175,162 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
|
|
|
632
175
|
return val
|
|
633
176
|
|
|
634
177
|
@classmethod
|
|
635
|
-
def
|
|
636
|
-
cls,
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
178
|
+
def save_torch_artifacts(
|
|
179
|
+
cls,
|
|
180
|
+
model: PreTrainedModel,
|
|
181
|
+
save_dir_path: Path,
|
|
182
|
+
subfolder: str,
|
|
183
|
+
rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
|
|
184
|
+
):
|
|
185
|
+
# If you are unavoidably running on a CPU rather than an RBLN device,
|
|
186
|
+
# store the torch tensor, weight, etc. in this function.
|
|
187
|
+
if rbln_config.use_inputs_embeds:
|
|
188
|
+
save_dict = {}
|
|
189
|
+
save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
|
|
190
|
+
torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
|
|
642
191
|
|
|
643
|
-
|
|
192
|
+
def _create_embedding_layer(self):
|
|
193
|
+
with no_init_weights():
|
|
194
|
+
embed_tokens = torch.nn.Embedding(
|
|
195
|
+
self.config.vocab_size,
|
|
196
|
+
self.config.hidden_size,
|
|
197
|
+
self.config.pad_token_id,
|
|
198
|
+
)
|
|
199
|
+
return embed_tokens
|
|
644
200
|
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
return
|
|
201
|
+
def get_decoder(self):
|
|
202
|
+
if not self.can_generate():
|
|
203
|
+
raise ValueError("Decode stage is not supported in this model.")
|
|
204
|
+
return self.decoder
|
|
205
|
+
|
|
206
|
+
def can_generate(self):
|
|
207
|
+
return self.rbln_config.can_generate
|
|
208
|
+
|
|
209
|
+
def get_input_embeddings(self):
|
|
210
|
+
return self.embed_tokens
|
|
211
|
+
|
|
212
|
+
def get_attn_impl(self) -> str:
|
|
213
|
+
return self.rbln_config.attn_impl
|
|
214
|
+
|
|
215
|
+
def get_kvcache_num_blocks(self) -> int:
|
|
216
|
+
return self.rbln_config.kvcache_num_blocks
|
|
661
217
|
|
|
662
218
|
@classmethod
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
|
|
219
|
+
def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
|
|
220
|
+
return cls._decoder_wrapper_cls(model, rbln_config, cls._use_rotary_emb).eval()
|
|
666
221
|
|
|
667
|
-
|
|
668
|
-
|
|
222
|
+
@classmethod
|
|
223
|
+
def _compile_model(
|
|
224
|
+
cls,
|
|
225
|
+
wrapped_model,
|
|
226
|
+
compile_config,
|
|
227
|
+
example_inputs,
|
|
228
|
+
compile_context,
|
|
229
|
+
rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
|
|
230
|
+
quantization=None,
|
|
231
|
+
phase: str = "prefill",
|
|
232
|
+
):
|
|
233
|
+
try:
|
|
234
|
+
wrapped_model.phase = phase
|
|
235
|
+
if quantization:
|
|
236
|
+
quantization.maybe_set_quantization_env()
|
|
237
|
+
original_linear = torch.nn.functional.linear
|
|
238
|
+
torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
|
|
239
|
+
compiled_model = cls.compile(
|
|
240
|
+
wrapped_model,
|
|
241
|
+
compile_config,
|
|
242
|
+
create_runtimes=rbln_config.create_runtimes,
|
|
243
|
+
device=rbln_config.device,
|
|
244
|
+
example_inputs=example_inputs,
|
|
245
|
+
compile_context=compile_context,
|
|
246
|
+
)
|
|
247
|
+
return compiled_model
|
|
248
|
+
finally:
|
|
249
|
+
torch.nn.functional.linear = original_linear
|
|
250
|
+
if quantization:
|
|
251
|
+
quantization.maybe_reset_quantization_env()
|
|
669
252
|
|
|
253
|
+
@classmethod
|
|
254
|
+
def _get_compile_context(
|
|
255
|
+
cls,
|
|
256
|
+
compile_config: RBLNCompileConfig,
|
|
257
|
+
example_inputs: List[torch.Tensor],
|
|
258
|
+
):
|
|
670
259
|
context = CompileContext(use_weight_sharing=True)
|
|
671
260
|
|
|
672
|
-
# Here we use meta tensor, for the memory efficiency.
|
|
673
|
-
meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
|
|
674
|
-
prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
|
|
675
|
-
|
|
676
261
|
# Mark static tensors (self kv states)
|
|
677
262
|
static_tensors = {}
|
|
678
|
-
|
|
263
|
+
idx = 0
|
|
264
|
+
for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
|
|
679
265
|
if "past_key_values" in name:
|
|
680
266
|
static_tensors[name] = tensor
|
|
681
|
-
context.mark_static_address(tensor)
|
|
682
|
-
|
|
683
|
-
def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
|
|
684
|
-
try:
|
|
685
|
-
if quantization:
|
|
686
|
-
quantization.maybe_set_quantization_env()
|
|
687
|
-
original_linear = torch.nn.functional.linear
|
|
688
|
-
torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
|
|
689
|
-
compiled_model = cls.compile(
|
|
690
|
-
wrapped_model,
|
|
691
|
-
compile_config,
|
|
692
|
-
create_runtimes=rbln_config.create_runtimes,
|
|
693
|
-
device=rbln_config.device,
|
|
694
|
-
example_inputs=example_inputs,
|
|
695
|
-
compile_context=compile_context,
|
|
696
|
-
)
|
|
697
|
-
return compiled_model
|
|
698
|
-
finally:
|
|
699
|
-
torch.nn.functional.linear = original_linear
|
|
700
|
-
if quantization:
|
|
701
|
-
quantization.maybe_reset_quantization_env()
|
|
702
|
-
|
|
703
|
-
wrapped_model.phase = "prefill"
|
|
704
|
-
compiled_prefill = compile_model(
|
|
705
|
-
wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
|
|
706
|
-
)
|
|
707
|
-
|
|
708
|
-
wrapped_model.phase = "decode"
|
|
709
|
-
compiled_models = {"prefill": compiled_prefill}
|
|
710
|
-
for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
|
|
711
|
-
dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
|
|
712
|
-
compiled_decoder = compile_model(
|
|
713
|
-
wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
|
|
714
|
-
)
|
|
715
|
-
compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
|
|
716
|
-
|
|
717
|
-
# check if the memory is enough to have additional blocks
|
|
718
|
-
required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
|
|
719
|
-
if rbln_config.kvcache_num_blocks < required_num_blocks:
|
|
720
|
-
cls.maybe_suggest_kvcache_num_blocks(
|
|
721
|
-
compiled_models=compiled_models,
|
|
722
|
-
model_config=model.config,
|
|
723
|
-
rbln_config=rbln_config,
|
|
724
|
-
)
|
|
267
|
+
context.mark_static_address(tensor, f"kv_cache_{idx}")
|
|
268
|
+
idx += 1
|
|
725
269
|
|
|
726
|
-
return
|
|
270
|
+
return context, static_tensors
|
|
727
271
|
|
|
728
272
|
@classmethod
|
|
729
|
-
|
|
730
|
-
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
|
|
734
|
-
) -> None:
|
|
735
|
-
# Get the actual memory allocation of each node by key
|
|
736
|
-
alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
|
|
737
|
-
alloc_memory_by_key: Dict[str, int] = {
|
|
738
|
-
key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
|
|
739
|
-
}
|
|
740
|
-
for batch_size in rbln_config.decoder_batch_sizes:
|
|
741
|
-
for key, memory_per_node in (
|
|
742
|
-
compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
|
|
743
|
-
):
|
|
744
|
-
alloc_memory_by_key[key] += sum(memory_per_node)
|
|
745
|
-
alloc_memory_by_key.pop("PortRecur", None) # Old compiler's kv-cache Key
|
|
746
|
-
alloc_memory_by_key.pop("DramTensor", None) # kv-cache
|
|
747
|
-
kernel_size = alloc_memory_by_key.pop("Kernel") # model weight
|
|
748
|
-
|
|
749
|
-
# Get the maximum number of blocks that can be allocated
|
|
750
|
-
buffer = sum(alloc_memory_by_key.values())
|
|
751
|
-
max_num_blocks = cls.get_maximum_num_blocks(
|
|
752
|
-
config=model_config,
|
|
753
|
-
tensor_parallel_size=rbln_config.tensor_parallel_size,
|
|
754
|
-
kvcache_block_size=rbln_config.kvcache_block_size,
|
|
755
|
-
kernel_size=kernel_size,
|
|
756
|
-
buffer=buffer,
|
|
757
|
-
)
|
|
273
|
+
@torch.inference_mode()
|
|
274
|
+
def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
|
|
275
|
+
wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
|
|
276
|
+
prefill_compile_config = rbln_config.compile_cfgs[0]
|
|
758
277
|
|
|
759
|
-
#
|
|
760
|
-
|
|
761
|
-
|
|
762
|
-
|
|
763
|
-
|
|
764
|
-
|
|
765
|
-
|
|
766
|
-
|
|
767
|
-
|
|
768
|
-
|
|
769
|
-
|
|
278
|
+
# Here we use meta tensor, for the memory efficiency.
|
|
279
|
+
meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
|
|
280
|
+
prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
|
|
281
|
+
context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
|
|
282
|
+
|
|
283
|
+
compiled_models = {}
|
|
284
|
+
compiled_models["prefill"] = cls._compile_model(
|
|
285
|
+
wrapped_model,
|
|
286
|
+
prefill_compile_config,
|
|
287
|
+
prefill_example_inputs,
|
|
288
|
+
context,
|
|
289
|
+
rbln_config,
|
|
290
|
+
rbln_config.quantization,
|
|
291
|
+
phase="prefill",
|
|
292
|
+
)
|
|
770
293
|
|
|
771
|
-
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
|
|
776
|
-
|
|
777
|
-
|
|
778
|
-
|
|
779
|
-
|
|
780
|
-
|
|
781
|
-
|
|
782
|
-
|
|
783
|
-
|
|
784
|
-
|
|
785
|
-
|
|
786
|
-
|
|
787
|
-
|
|
788
|
-
|
|
789
|
-
|
|
790
|
-
|
|
791
|
-
|
|
792
|
-
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
# This inequality can be rewritten as follows:
|
|
796
|
-
|
|
797
|
-
# a - c * align_2MB(b * x) > 0
|
|
798
|
-
# where
|
|
799
|
-
# a = available_dram - kernel_size - buffer
|
|
800
|
-
# b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
|
|
801
|
-
# c = num_layers * 2 * tensor_parallel_size
|
|
802
|
-
|
|
803
|
-
# We can rewrite the inequality as follows:
|
|
804
|
-
# k > align_2MB(b*x)
|
|
805
|
-
# where
|
|
806
|
-
# k = a / c
|
|
807
|
-
|
|
808
|
-
# After that, we can derive the following equation:
|
|
809
|
-
# x = floor(2**21 / b * floor((k - 1) / 2**21))
|
|
810
|
-
|
|
811
|
-
def align(x: int, nbytes: int) -> int:
|
|
812
|
-
return int(math.ceil(x / nbytes) * nbytes)
|
|
813
|
-
|
|
814
|
-
-        def align_2MB(x: int) -> int:
-            return align(x, 2**21)
-
-        num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
-        num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
-        head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
-        vocab_size = config.vocab_size
-        hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
-        num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
-
-        # TODO(jongho): Update if target npu is REBEL.
-        ATOM_DRAM_NBYTES = 16 * 2**30
-        ATOM_SYS_DRAM_NBYTES = 288 * 2**20
-        available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
-
-        if kernel_size is None:
-            if n_model_params is None:
-                raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
-            # Get estimated kernel size (approximated)
-            lm_heads_params = align(vocab_size, 64) * hidden_size
-            lm_heads_nbytes = (
-                align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
-            )
-            params = n_model_params - lm_heads_params
-            layer_nbytes = (
-                align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
-                * num_layers
-                * tensor_parallel_size
-            )
-            kernel_size = layer_nbytes + lm_heads_nbytes
-        elif n_model_params is not None:
-            raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
+        if rbln_config.can_generate:
+            wrapped_model.phase = "decode"
+            for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
+                dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+                compiled_decoder = cls._compile_model(
+                    wrapped_model,
+                    dec_compile_config,
+                    dec_example_inputs,
+                    context,
+                    rbln_config,
+                    rbln_config.quantization,
+                    phase="decode",
+                )
+                compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
+
+        # check if the memory is enough to have additional blocks
+        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+        if rbln_config.kvcache_num_blocks < required_num_blocks:
+            cls.maybe_suggest_kvcache_num_blocks(
+                compiled_models=compiled_models,
+                model_config=model.config,
+                rbln_config=rbln_config,
+            )

-
+        return compiled_models

-
-
-
-
-
-
+    @classmethod
+    def get_pytorch_model(
+        cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
+    ) -> PreTrainedModel:
+        if rbln_config and rbln_config.quantization:
+            model = cls.get_quantized_model(*args, rbln_config=rbln_config, **kwargs)
+        else:
+            model = super().get_pytorch_model(*args, **kwargs)

-
-        c = num_layers * 2 * tensor_parallel_size
-        k = available_dram / c
-        max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
+        return model

-
+    @classmethod
+    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+        return use_local_attention

     @classmethod
     def get_input_info(
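The helper deleted above estimated how much device DRAM the model weights occupy by rounding every tensor-parallel shard up to a 2 MiB page. A standalone restatement of that arithmetic is sketched below for readers skimming the removed lines; `estimated_kernel_nbytes` is a hypothetical name used only for this illustration and is not part of the package API.

```python
import math


def align(x: float, multiple: int) -> int:
    # Round x up to the nearest multiple (e.g. a 2 MiB DRAM page or a 64-wide vocab pad).
    return int(math.ceil(x / multiple) * multiple)


def align_2MB(x: float) -> int:
    return align(x, 2**21)


def estimated_kernel_nbytes(
    n_model_params: int,
    vocab_size: int,
    hidden_size: int,
    num_layers: int,
    nbits_per_param: int = 16,
    tensor_parallel_size: int = 1,
) -> int:
    # Mirrors the removed estimate: lm-head bytes plus per-layer bytes,
    # each shard rounded up to a 2 MiB page before summing.
    lm_heads_params = align(vocab_size, 64) * hidden_size
    lm_heads_nbytes = (
        align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
    )
    params = n_model_params - lm_heads_params
    layer_nbytes = (
        align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
        * num_layers
        * tensor_parallel_size
    )
    return layer_nbytes + lm_heads_nbytes
```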
@@ -868,63 +340,57 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         model_config: PretrainedConfig,
     ):
-        is_prefill: bool = query_length > 1
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
         head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
-
+        is_prefill = query_length > 1

-
+        input_info = []
         if rbln_config.use_inputs_embeds:
-
+            input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.torch_dtype))
         else:
-
-
-
-            input_info = [
-                main_input,
-                (
-                    "cache_position",
-                    [batch_size, query_length],
-                    "int32",
-                ),
-            ]
+            input_info.append(("input_ids", [batch_size, query_length], "int64"))
+
+        input_info.append(("cache_position", [batch_size, query_length], "int32"))

-
-        if rbln_config.cache_impl in ["static", "hybrid"]:
+        if rbln_config.use_global_attention:
             max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
-            input_info.
-
+            input_info.append(
+                ("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")
             )
-        if rbln_config.
-            input_info.
+        if rbln_config.use_local_attention:
+            input_info.append(("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16"))

-
-
-            input_info.extend([("query_position", [], "int16")])
+        if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
+            input_info.append(("query_position", [], "int16"))

-        # 5. attention_mask & position_ids
         if rbln_config.use_attention_mask:
-
-            [
-
-
-
-
-
+            if rbln_config.use_position_ids:
+                input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.torch_dtype))
+            else:
+                input_info.append(
+                    ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.torch_dtype)
+                )
+
         if rbln_config.use_position_ids:
             input_info.append(("position_ids", [batch_size, query_length], "int32"))

-
+        if rbln_config.use_lora:
+            input_info.append(("lora_int_ids", [batch_size], "int32"))
+
+        kvcache_dtype = rbln_config.torch_dtype
+        if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
+            kvcache_dtype = "float8_e4m3fn"
+
         global_kvcache_shape = [
             rbln_config.kvcache_num_blocks,
             num_key_value_heads,
             rbln_config.kvcache_block_size,
             head_dim,
         ]
-        local_kvcache_shape = [
+        local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
         input_info.extend(
             [
                 (
@@ -932,7 +398,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                     local_kvcache_shape
                     if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
                     else global_kvcache_shape,
-
+                    kvcache_dtype,
                 )
                 for i in range(num_hidden_layers * 2)
             ]
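To make the two cache layouts in this hunk concrete, the snippet below plugs made-up numbers into the same shape expressions; none of the values are library defaults, they only illustrate how the global (paged) and local (sliding-window) KV-cache tensors are shaped.

```python
# Toy numbers, purely illustrative (not library defaults).
batch_size = 2
num_key_value_heads = 8
head_dim = 64
kvcache_block_size = 1024
kvcache_num_blocks = 17   # e.g. enough blocks for every sequence plus some headroom
sliding_window = 512

# Shapes mirrored from the hunk above: layers listed in sliding_window_layers use
# the per-batch local layout, all other layers use the paged global layout.
global_kvcache_shape = [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim]
local_kvcache_shape = [batch_size, num_key_value_heads, sliding_window, head_dim]

print(global_kvcache_shape)  # [17, 8, 1024, 64]
print(local_kvcache_shape)   # [2, 8, 512, 64]
```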
@@ -971,7 +437,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         # ```

         # Returns:
-        #
+        #     RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.

         raise NotImplementedError(
             "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
@@ -979,27 +445,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         )

     @classmethod
-    def
-        cls,
-
-        model: Optional[PreTrainedModel] = None,
-        model_config: Optional[PretrainedConfig] = None,
-        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
-    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
-        if rbln_config.max_seq_len is None:
-            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
-                model_config, "n_positions", None
-            )
-        if rbln_config.max_seq_len is None:
-            raise ValueError("`max_seq_len` should be specified.")
-
-        if getattr(model_config, "sliding_window", None) is not None and getattr(
-            model_config, "use_sliding_window", True
-        ):
-            rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
-        if rbln_config.sliding_window is not None:
-            validate_sliding_window_size(rbln_config.sliding_window, rbln_config.prefill_chunk_size)
-
+    def _update_attention_config(
+        cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
         rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
             attn_impl=rbln_config.attn_impl,
             kvcache_partition_len=rbln_config.kvcache_partition_len,
@@ -1014,40 +462,77 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             max_seq_len=rbln_config.max_seq_len,
         )

-
-        max_num_blocks = required_num_blocks
+        num_full_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size

+        # Update kvcache_num_blocks based on the attention implementation.
         if rbln_config.attn_impl == "flash_attn":
-            estimated_max_num_blocks = cls.
-
-                tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
-                kvcache_block_size=rbln_config.kvcache_block_size,
-                nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
-                n_model_params=sum(p.numel() for p in model.parameters()),
-                num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
+            estimated_max_num_blocks = cls.get_maximum_num_blocks_by_model(
+                model=model, model_config=model_config, rbln_config=rbln_config
             )

-
-
-
-
-
-
-
-
-
-
+            if rbln_config.kvcache_num_blocks is None:
+                if estimated_max_num_blocks < num_full_blocks:
+                    # lower bound of the number of blocks for flash attention.
+                    min_blocks_for_flash = min(
+                        rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1, num_full_blocks
+                    )
+                    if min_blocks_for_flash > estimated_max_num_blocks:
+                        # NOTE: Just try to compile with lower bound of blocks for flash attention.
+                        # Even if it's larger than the estimated maximum number of blocks.
+                        rbln_config.kvcache_num_blocks = min_blocks_for_flash
+                    else:
+                        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
+                        rbln_config.kvcache_num_blocks = estimated_max_num_blocks
+
+                    if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
+                        raise RuntimeError(
+                            f"Batch size ({rbln_config.batch_size}) exceeds num_blocks ({rbln_config.kvcache_num_blocks}). "
+                            "Ensure the number of blocks is at least equal to the batch size."
+                        )
+                else:
+                    rbln_config.kvcache_num_blocks = num_full_blocks
+            elif rbln_config.kvcache_num_blocks > estimated_max_num_blocks:
+                logger.warning(
+                    f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                    f" than the estimated maximum number of blocks ({estimated_max_num_blocks})."
+                    "This can cause a failure during model compilation."
+                )
+        else:
+            if rbln_config.kvcache_num_blocks is None:
+                rbln_config.kvcache_num_blocks = num_full_blocks
+            elif rbln_config.kvcache_num_blocks > num_full_blocks:
+                logger.warning(
+                    f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                    f" than the required number of blocks ({num_full_blocks})."
+                    "This can cause a failure during model compilation."
                 )
+        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")

-
-
-
-
-
-
-
+        return rbln_config
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional[PreTrainedModel] = None,
+        model_config: Optional[PretrainedConfig] = None,
+        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+        if rbln_config.max_seq_len is None:
+            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
+                model_config, "n_positions", None
             )
-
+        if rbln_config.max_seq_len is None:
+            raise ValueError("`max_seq_len` should be specified.")
+
+        if getattr(model_config, "sliding_window", None) is not None and getattr(
+            model_config, "use_sliding_window", True
+        ):
+            rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
+        if rbln_config.sliding_window is not None:
+            validate_sliding_window(rbln_config)
+
+        rbln_config = cls._update_attention_config(model, model_config, rbln_config)

         prefill_input_info = cls.get_input_info(
             batch_size=1,
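The branching above reduces to a small piece of block-count arithmetic. The function below paraphrases the flash-attention path for the case where `kvcache_num_blocks` is left unset; it is an illustrative sketch under that assumption, not the library's implementation.

```python
def choose_kvcache_num_blocks(
    max_seq_len: int,
    kvcache_block_size: int,
    batch_size: int,
    estimated_max_num_blocks: int,
) -> int:
    # Paraphrase of the flash-attention branch when the user did not set kvcache_num_blocks.
    num_full_blocks = (max_seq_len // kvcache_block_size) * batch_size
    if estimated_max_num_blocks >= num_full_blocks:
        # Enough memory for every sequence in the batch to hold a full-length cache.
        return num_full_blocks
    # Otherwise fall back to the estimate, but never below the lower bound that a
    # single full-length sequence needs under flash attention (one extra block).
    min_blocks_for_flash = min(max_seq_len // kvcache_block_size + 1, num_full_blocks)
    return max(min_blocks_for_flash, estimated_max_num_blocks)


# Example: 8k context, 1k blocks, batch of 4 -> 32 full blocks; with only 20 blocks
# estimated to fit, the sketch settles on 20 (above the 9-block lower bound).
assert choose_kvcache_num_blocks(8192, 1024, 4, estimated_max_num_blocks=20) == 20
```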
@@ -1057,19 +542,20 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         )

         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
-
-
-
-
-
-
-
-
-
-
-
-
-
+        compile_cfgs = [prefill_compile_config]
+
+        if rbln_config.can_generate:
+            for batch_size in rbln_config.decoder_batch_sizes:
+                dec_input_info = cls.get_input_info(
+                    batch_size=batch_size,
+                    query_length=1,
+                    rbln_config=rbln_config,
+                    model_config=model_config,
+                )
+                compile_cfgs.append(
+                    RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
+                )
+        rbln_config.set_compile_cfgs(compile_cfgs)

         return rbln_config

@@ -1079,103 +565,164 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         compiled_models: List[rebel.RBLNCompiledModel],
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ) -> List[rebel.Runtime]:
-        expected_model_names = [
-
-
-
+        expected_model_names = ["prefill"]
+        if rbln_config.can_generate:
+            expected_model_names.extend(
+                [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
+            )
         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
             cls._raise_missing_compiled_file_error(expected_model_names)

-
+        ret_val = [
             rebel.Runtime(
                 compiled_models[0],
                 tensor_type="pt",
                 device=rbln_config.device_map["prefill"],
                 activate_profiler=rbln_config.activate_profiler,
                 timeout=rbln_config.timeout,
-            )
-            *[
-                rebel.Runtime(
-                    compiled_models[i + 1],
-                    tensor_type="pt",
-                    device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
-                    activate_profiler=rbln_config.activate_profiler,
-                    timeout=rbln_config.timeout,
-                )
-                for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-            ],
+            )
         ]
+        if rbln_config.can_generate:
+            ret_val.extend(
+                [
+                    rebel.Runtime(
+                        compiled_models[i + 1],
+                        tensor_type="pt",
+                        device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
+                        activate_profiler=rbln_config.activate_profiler,
+                        timeout=rbln_config.timeout,
+                    )
+                    for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
+                ]
+            )
+        return ret_val

-    def
-        return self.decoder
-
-    def can_generate(self):
-        return True
-
-    def _reorder_cache(self, past_key_values, beam_idx):
-        raise NotImplementedError
-
-    def prepare_inputs_for_generation(
+    def forward(
         self,
-        input_ids: torch.LongTensor,
-        generate_idx: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
+        input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-
+        attention_mask: Optional[torch.LongTensor] = None,
         **kwargs,
-    ):
-
-
+    ) -> BaseModelOutputWithPast:
+        """
+        Args:
+            input_ids (torch.LongTensor, optional): The input IDs to the model.
+            inputs_embeds (torch.Tensor, optional): The input embeddings to the model.
+            attention_mask (torch.LongTensor, optional): The attention mask to the model.
+            kwargs (dict[str, Any], optional): Additional keyword arguments.

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        Returns:
+            Dataclass containing the last hidden states of the model.
+        """
+        inputs = inputs_embeds if inputs_embeds is not None else input_ids
+        batch_size = inputs.shape[0]
+        position_embed = kwargs.get("position_embed", None)
+
+        if batch_size != self.rbln_config.batch_size:
+            raise ValueError(
+                f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
+            )
+
+        all_last_hidden_states = []
+        for b_idx in range(self.rbln_config.batch_size):
+            query_length = (
+                attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
+            )
+            cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
+            last_hidden_states = self.prefill_decoder(
+                inputs[b_idx : b_idx + 1],
+                attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
+                cache_position=cache_position,
+                batch_idx=b_idx,
+            ).logits
+            all_last_hidden_states.append(last_hidden_states)
+
+        last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
+
+        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
+
+
+class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
+    """
+    A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+    This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
+
+    The class provides core functionality for:
+
+    1. Converting pre-trained transformer models to RBLN-optimized format
+    2. Handling the compilation process for RBLN devices
+    3. Managing inference operations for causal language modeling
+    This class inherits from RBLNModel and implements specific methods required for
+    decoder-only architectures and causal language modeling tasks.
+
+    Note:
+        - This class is designed to be subclassed by specific model implementations
+          (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+        - Subclasses should implement model-specific conversion logic.
+        - The class handles RBLN-specific optimizations automatically during compilation
+    """
+
+    auto_model_class = AutoModelForCausalLM
+
+    @property
+    def prefill_output_size(self):
+        return (
+            1,
+            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
+            self.config.vocab_size,
         )

-
+    @classmethod
+    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+        return is_prefill
+
+    def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
+        if isinstance(lora_int_ids, int):
+            lora_int_ids = torch.tensor([lora_int_ids], dtype=torch.int32)
+        elif isinstance(lora_int_ids, list):
+            lora_int_ids = torch.tensor(lora_int_ids, dtype=torch.int32)

-
-        self,
-        outputs: RBLNDecoderOnlyOutput,
-        model_kwargs: Dict[str, Any],
-        **kwargs,
-    ) -> Dict[str, Any]:
-        # update generate_idx
-        model_kwargs["generate_idx"] = outputs.generate_idx
-        model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths
+        self.lora_int_ids = lora_int_ids

-
+        self.prefill_decoder.lora_int_ids = lora_int_ids
+        if self.rbln_config.can_generate:
+            for batch_size in self.rbln_config.decoder_batch_sizes:
+                self.decoders[batch_size].lora_int_ids = lora_int_ids
+
+    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+        """
+        Sets the active adapter(s) for the model using adapter name(s).
+
+        Args:
+            adapter_name (Union[str, List[str]]): The name(s) of the adapter(s) to be activated.
+                Can be a single adapter name or a list of adapter names.
+
+        Raises:
+            ValueError: If the model is not configured with LoRA or if the adapter name is not found.
+        """
+        if not hasattr(self.rbln_config, "lora_config") or self.rbln_config.lora_config is None:
+            raise ValueError("Model is not configured with LoRA. Cannot set adapter.")
+
+        # Convert single adapter name to list for uniform processing
+        if isinstance(adapter_name, str):
+            adapter_names = [adapter_name]
+        else:
+            adapter_names = adapter_name
+
+        # Validate that all adapter names exist
+        available_adapters = {
+            adapter.lora_name: adapter.lora_int_id for adapter in self.rbln_config.lora_config.adapters
+        }
+        missing_adapters = [name for name in adapter_names if name not in available_adapters]
+        if missing_adapters:
+            raise ValueError(
+                f"Adapter(s) {missing_adapters} not found. Available adapters: {list(available_adapters.keys())}"
+            )
+
+        # Get the adapter IDs and set them
+        lora_int_ids = [available_adapters[name] for name in adapter_names]
+        self.set_lora_int_ids(torch.tensor(lora_int_ids, dtype=torch.int32))

     def forward(
         self,
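The new `set_adapter`/`set_lora_int_ids` pair selects which compiled LoRA adapter the prefill and decode runtimes apply. A hedged usage sketch follows: the model directory, the adapter name, and the assumption that a tokenizer is saved alongside the compiled model are all illustrative, and the compile-time LoRA configuration must have registered the adapter beforehand.

```python
from transformers import AutoTokenizer

from optimum.rbln import RBLNLlamaForCausalLM

# "compiled-llama-with-lora" and "summarization" are placeholder names; this assumes
# the model was compiled with a LoRA configuration that registered that adapter.
model = RBLNLlamaForCausalLM.from_pretrained("compiled-llama-with-lora")
tokenizer = AutoTokenizer.from_pretrained("compiled-llama-with-lora")

# Resolve the adapter name to its lora_int_id and propagate it to every runtime.
model.set_adapter("summarization")

inputs = tokenizer("Summarize: RBLN NPUs run compiled decoder-only models.", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```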
@@ -1187,6 +734,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
@@ -1194,17 +742,38 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
         # A for-loop ensures synchronization with the HuggingFace generate API.
         # The decoder stage operates as usual, processing inputs in batch mode.
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+            lora_int_ids = self.lora_int_ids
+
+        # for only use forward
+        if generate_idx is None:
+            generate_idx = (
+                attention_mask.sum(dim=-1, keepdim=True).int()
+                if attention_mask is not None
+                else torch.full((input_ids.shape[0], 1), input_ids.shape[1], dtype=torch.int32)
+            )
+            padded_cache_lengths = torch.zeros_like(generate_idx)

-        #
+        # Prefill
         if cache_position is None:
             logits = []
             inputs = inputs_embeds if inputs_embeds is not None else input_ids
-            # for only use forward
-            if generate_idx is None:
-                generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
-            if padded_cache_lengths is None:
-                padded_cache_lengths = torch.zeros_like(generate_idx)
             batch_size = inputs.shape[0]
+            input_len = inputs.shape[1]
+            if batch_size > self.rbln_config.batch_size:
+                raise ValueError(
+                    f"Input's batch({batch_size}) exceeds compiled batch_size({self.rbln_config.batch_size})"
+                )
+            if input_len > self.rbln_config.max_seq_len:
+                raise ValueError(
+                    f"Input's length({input_len}) exceeds compiled max_seq_len({self.rbln_config.max_seq_len})."
+                )
+
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
                 output = self.prefill_decoder(
@@ -1214,6 +783,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                     cache_position=cache_position,
                     batch_idx=b_idx,
                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
+                    lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
                 )
                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
                 logits.append(output.logits)
@@ -1228,11 +798,21 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                     f"Available batch sizes are: {list(self.decoders.keys())}. "
                     f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
                 )
+            if max(cache_position.reshape(-1)) >= self.rbln_config.max_seq_len:
+                raise ValueError(
+                    f"Cache position exceeds the maximum sequence length.\n"
+                    f" - Current max cache position: {int(torch.max(cache_position).item())}\n"
+                    f" - Allowed max_seq_len: {self.rbln_config.max_seq_len}\n"
+                    f"Solution: Reduce the generation length by adjusting `max_new_tokens` "
+                    f"or `max_length` in the generation config."
+                )
+
             logits = self.decoders[batch_size](
                 input_ids=input_ids,
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_ids=position_ids if self.rbln_config.use_position_ids else None,
+                lora_int_ids=lora_int_ids,
             ).logits

         if not return_dict:
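The added decode-stage guard means the prompt length plus the number of generated tokens must stay within the compiled `max_seq_len`. A tiny helper illustrating the implied budget (an assumption about how a caller might cap `max_new_tokens`, not library code):

```python
def safe_max_new_tokens(prompt_len: int, compiled_max_seq_len: int) -> int:
    # The last legal cache position is compiled_max_seq_len - 1, so at most
    # compiled_max_seq_len - prompt_len new tokens can be generated.
    return max(compiled_max_seq_len - prompt_len, 0)


assert safe_max_new_tokens(prompt_len=100, compiled_max_seq_len=8192) == 8092
```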