optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +96 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/cli.py +660 -0
- optimum/rbln/configuration_utils.py +153 -42
- optimum/rbln/diffusers/__init__.py +7 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
- optimum/rbln/diffusers/modeling_diffusers.py +30 -14
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
- optimum/rbln/diffusers/pipelines/__init__.py +11 -5
- optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +71 -19
- optimum/rbln/modeling_base.py +99 -21
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +92 -0
- optimum/rbln/transformers/configuration_generic.py +9 -7
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/modeling_generic.py +51 -9
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +91 -30
- optimum/rbln/transformers/models/auto/__init__.py +2 -0
- optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
- optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
- optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
- optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
- optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
- optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
- optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
- optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
- optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
- optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
- optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
- optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
- optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
- optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
- optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
- optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
- optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
- optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
- optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
- optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
- optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
- optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
- optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
- optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
- optimum/rbln/transformers/models/swin/__init__.py +16 -0
- optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
- optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
- optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
- optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
- optimum/rbln/utils/depreacate_utils.py +16 -0
- optimum/rbln/utils/runtime_utils.py +28 -18
- optimum/rbln/utils/submodule.py +31 -9
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
- optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
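
Most of the churn below is in the decoder-only stack: `modeling_decoderonly.py` loses roughly 900 lines as its runtime wrapper, paged KV-cache bookkeeping, and generation logic move into the new `decoderonly_runtime_utils.py`, `generation_decoderonly.py`, and `modeling_attention_utils.py` modules. For orientation, a minimal sketch of how such a model is typically compiled and run with optimum-rbln follows; the class name and the `export`/`rbln_*` arguments follow the project's documented public API and are assumptions here, not something taken from this diff.

```python
# Minimal usage sketch (assumed public optimum-rbln API; not part of this diff).
from transformers import AutoTokenizer

from optimum.rbln import RBLNLlamaForCausalLM

model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# export=True compiles the Hugging Face checkpoint for RBLN NPUs;
# the rbln_* keywords feed the RBLN model configuration.
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id,
    export=True,
    rbln_batch_size=1,
    rbln_max_seq_len=4096,
)
model.save_pretrained("llama-2-7b-chat-rbln")  # store compiled artifacts for reuse

inputs = tokenizer("Hello, my name is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```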
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -13,31 +13,31 @@
 # limitations under the License.
 
 import inspect
-import math
-from collections import deque
-from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable,
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
 
 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import
+from transformers import AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_utils import no_init_weights
-from transformers.utils import ModelOutput
 
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from
-
-from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
-from .decoderonly_architecture import (
-    DecoderOnlyWrapper,
+from ...modeling_attention_utils import (
+    RBLNDecoderOnlyFlashAttentionMixin,
     set_default_values,
     validate_attention_method,
-
+    validate_sliding_window,
 )
+from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ...utils.rbln_quantization import get_quantized_model
+from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+from .decoderonly_architecture import DecoderOnlyWrapper
+from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
+from .generation_decoderonly import RBLNDecoderOnlyGenerationMixin
 
 
 logger = get_logger()
@@ -46,529 +46,85 @@ if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
-class
-    mandatory_members = ["main_input_name", "embed_tokens"]
-
-    def __init__(
-        self,
-        runtime: rebel.Runtime,
-        phase: str,
-        batch_size: int,
-        dec_attn_mask: torch.Tensor,
-        block_tables: torch.Tensor,
-        free_block_pool: Deque,
-        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(runtime, **kwargs)
-        self.phase = phase
-        self.batch_size = batch_size
-        self.rbln_config = rbln_config
-
-        # shared tensor between prefill and decode phase
-        self.dec_attn_mask = dec_attn_mask
-        self.block_tables = block_tables
-        self.free_block_pool = free_block_pool
-
-        self.empty_block = -1
-        if self.phase == "prefill":
-            vocab_size = kwargs.pop("vocab_size")
-            self.output_size = [1, 1, vocab_size]
-            self.causal_mask = 1 - torch.triu(
-                torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
-            )
-
-    def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None) -> torch.Tensor:
-        """
-        Manages and returns the KV cache block tables.
-        Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
-
-        Args:
-            cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
-            batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
-
-        Returns:
-            Updated block tables.
-        """
-
-        NO_BLOCKS_ERROR = (
-            "No memory blocks are available for allocation. "
-            "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln. "
-            "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html). "
-            "Using vllm-rbln should fix this issue and enhance inference performance."
-        )
-
-        def update_block(batch_idx: int, block_idx: int):
-            """
-            If the block is empty (empty_block), allocates a block from the free_block_pool.
-            """
-            if self.block_tables[batch_idx][block_idx] == self.empty_block:
-                if self.free_block_pool:
-                    block = self.free_block_pool.popleft()
-                    self.block_tables[batch_idx][block_idx] = block
-                else:
-                    raise RuntimeError(NO_BLOCKS_ERROR)
-
-        def replace_empty_block(block_tables: torch.Tensor):
-            """
-            Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
-            """
-            if not torch.any(block_tables == self.empty_block):
-                return block_tables.clone()
-            elif self.free_block_pool:
-                _free_block = self.free_block_pool[0]
-                return torch.where(block_tables == self.empty_block, _free_block, block_tables)
-            else:
-                raise RuntimeError(NO_BLOCKS_ERROR)
-
-        def get_global_block_tables(batch_idx: int):
-            if self.rbln_config.cache_impl == "sliding_window":
-                return None
-
-            if self.phase == "prefill":
-                # Track previously used blocks and return them to the free_block_pool and
-                # reset the current batch's block table to empty blocks
-                prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
-                self.free_block_pool.extend(prev_blocks)
-                self.block_tables[batch_idx].fill_(self.empty_block)
-
-                # Get the start (s) and end (e) positions from cache_position and
-                # iterate over the cache positions to allocate necessary blocks
-                s, e = cache_position[0][0].item(), cache_position[0][-1].item()
-                for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
-                    block_idx = position // self.rbln_config.kvcache_block_size
-                    if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
-                        raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
-                    update_block(batch_idx, block_idx)
-
-                return replace_empty_block(self.block_tables[batch_idx])
-            # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
-            else:
-                for b_idx in range(self.batch_size):
-                    position = cache_position[b_idx][0].item()
-                    block_idx = position // self.rbln_config.kvcache_block_size
-                    update_block(b_idx, block_idx)
-
-                return replace_empty_block(self.block_tables)
-
-        def get_local_block_tables(batch_idx: int):
-            if self.rbln_config.cache_impl == "static":
-                return None
-            else:
-                return (
-                    torch.tensor([batch_idx], dtype=torch.int16)
-                    if self.phase == "prefill"
-                    else torch.arange(self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-                )
-
-        return get_global_block_tables(batch_idx), get_local_block_tables(batch_idx)
-
-    def is_external_block_tables(
-        self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
-    ):
-        if self.rbln_config.cache_impl == "static" and block_tables is None:
-            return False
-        elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
-            return False
-        elif self.rbln_config.cache_impl == "hybrid":
-            if (block_tables is not None) != (local_block_tables is not None):
-                raise ValueError(
-                    "Both block_tables and local_block_tables must be provided or neither of them must be provided."
-                )
-            elif block_tables is None and local_block_tables is None:
-                return False
-
-        return True
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        cache_position: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: Optional[int] = None,
-        block_tables: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ):
-        if input_ids is None and inputs_embeds is None:
-            raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
-
-        if inputs_embeds is None:
-            inputs = input_ids
-            if self.embed_tokens is not None:
-                inputs = self.embed_tokens(inputs)
-        else:
-            inputs = inputs_embeds
-
-        is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
-        if not is_external_block_tables:
-            block_tables, local_block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
-
-        if self.phase == "decode":
-            return self.decode_forward(
-                inputs,
-                cache_position,
-                block_tables,
-                is_external_block_tables,
-                attention_mask=attention_mask,
-                position_embed=position_embed,
-                position_ids=position_ids,
-                local_block_tables=local_block_tables,
-            )
-        else:
-            return self.prefill_forward(
-                inputs,
-                cache_position,
-                attention_mask,
-                batch_idx,
-                block_tables,
-                is_external_block_tables=is_external_block_tables,
-                position_embed=position_embed,
-                token_type_ids=token_type_ids,
-                local_block_tables=local_block_tables,
-            )
-
-    def decode_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size = inputs.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
-
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-        if self.rbln_config.use_attention_mask and attention_mask is None:
-            for b_idx in range(batch_size):
-                decoding_step = cache_position[b_idx].item()
-                if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                    raise ValueError(
-                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                    )
-
-                if is_external_block_tables:
-                    self.dec_attn_mask[b_idx].fill_(0)
-                    self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
-                else:
-                    self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
-
-            attention_mask = self.dec_attn_mask
-
-        if self.rbln_config.cache_impl in ["hybrid", "static"] and self.batch_size < block_tables.shape[0]:
-            block_tables = block_tables[: self.batch_size]
-
-        if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-            attention_mask = attention_mask[: self.batch_size]
-
-        logits = super().forward(
-            inputs,
-            cache_position,
-            block_tables,
-            local_block_tables,
-            position_embed,
-            attention_mask if self.rbln_config.use_attention_mask else None,
-            position_ids if self.rbln_config.use_position_ids else None,
-        )
-
-        return RBLNDecoderOnlyOutput(logits=logits)
-
-    def _prepare_prefill_inputs(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-    ):
-        """
-        Prepare inputs for prefill phase.
-        """
-        # Handle continuous batching in a compiled graph by extracting valid inputs
-        # If an attention mask is provided, select only the valid (non-masked) inputs
-        inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
-        if position_embed is not None:
-            position_embed = (
-                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-            )
-        if token_type_ids is not None:
-            token_type_ids = token_type_ids[:, attention_mask.bool()] if attention_mask is not None else token_type_ids
-
-        query_length = inputs.shape[1]
-        if query_length > self.rbln_config.max_seq_len:
-            raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
-            )
-
-        # Initialize attention mask for chunked processing
-        chunked_attention_mask = (
-            torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-            if self.rbln_config.use_attention_mask
-            else None
-        )
-
-        # Buffer for storing output logits
-        out_buffers = [
-            torch.empty(
-                size=self.output_size,
-                dtype=torch.float32,
-                device="cpu",
-            )
-        ]
-
-        # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        padding_size = 0
-        if query_length % self.rbln_config.prefill_chunk_size != 0:
-            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
-            # inputs_embeds
-            if inputs.dim() == 3:
-                inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-            # inputs_ids
-            else:
-                inputs = torch.nn.functional.pad(inputs, (0, padding_size))
-
-            cache_position = torch.cat(
-                [
-                    cache_position,
-                    torch.arange(
-                        query_length,
-                        query_length + padding_size,
-                        dtype=torch.int32,
-                    ).unsqueeze(0),
-                ],
-                dim=-1,
-            )
-
-            if position_embed is not None:
-                position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
-
-            if token_type_ids is not None:
-                token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
-
-        # Overwrite position_ids and padded_cache_lengths
-        position_ids = cache_position.clone()
-        padded_cache_lengths = 0
-
-        return (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        )
-
-    def prefill_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: int = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = False,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        """
-        Performs chunked prefill for efficient KV-cache updates and memory optimization.
-        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
-        """
-        (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        ) = self._prepare_prefill_inputs(
-            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
-        )
-
-        # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
-            # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-            position_ids_chunk = (
-                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
-                if position_ids is not None
-                else None
-            )
-            if position_embed is not None:
-                position_embed_chunk = position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]
-
-            if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
-                # Update attention mask to ensure proper causal behavior
-                if step >= self.rbln_config.prefill_chunk_size:
-                    chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
-                chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
-
-            # Define query position
-            if step + self.rbln_config.prefill_chunk_size >= query_length:
-                query_position = torch.tensor(
-                    (query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16
-                )
-            else:
-                query_position = torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)
-
-            # Forward pass for the current chunk
-            logits = super().forward(
-                input_chunk,
-                cache_pos_chunk,
-                block_tables,
-                local_block_tables,
-                position_embed_chunk if position_embed is not None else None,
-                query_position,
-                chunked_attention_mask if self.rbln_config.use_attention_mask else None,
-                position_ids_chunk if self.rbln_config.use_position_ids else None,
-                out=out_buffers,
-            )
-
-        # Update decoder attention mask with processed KV-cache length from prefill phase
-        if not is_external_block_tables and self.rbln_config.use_attention_mask:
-            self.dec_attn_mask[batch_idx].fill_(0)
-            self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
-
-        return RBLNDecoderOnlyOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)
-
-
-@dataclass
-class RBLNDecoderOnlyOutput(ModelOutput):
-    logits: torch.FloatTensor = None
-    generate_idx: torch.Tensor = None
-    padded_cache_lengths: int = None
-
-
-class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
+class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     """
-    A base class for decoder-only transformer models
+    A base class for decoder-only transformer models outputting raw hidden-states without any specific head on top.
+    This class is used for RBLN-optimized models that are not causal language models.
     This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
 
     The class provides core functionality for:
 
     1. Converting pre-trained transformer models to RBLN-optimized format
     2. Handling the compilation process for RBLN devices
-    3. Managing inference operations for
-
+    3. Managing inference operations for decoder-only architectures
     This class inherits from RBLNModel and implements specific methods required for
-    decoder-only architectures
+    decoder-only architectures.
 
     Note:
     - This class is designed to be subclassed by specific model implementations
-      (e.g.,
+      (e.g., RBLNLlamaModel, RBLNQwen2Model)
     - Subclasses should implement model-specific conversion logic.
     - The class handles RBLN-specific optimizations automatically during compilation
     """
 
+    _tp_support = True
+
     main_input_name = "input_ids"
-    auto_model_class =
+    auto_model_class = AutoModel
     _decoder_wrapper_cls = DecoderOnlyWrapper
     _use_rotary_emb = True
+    _supports_non_fp32 = True
 
     def __post_init__(self, **kwargs):
-        main_input_name = self.main_input_name
-
         if self.rbln_config.use_inputs_embeds:
-            main_input_name = "inputs_embeds"
            artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
            self.embed_tokens = self._create_embedding_layer()
            self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
        else:
            self.embed_tokens = None
 
-
-        dec_attn_mask = torch.zeros(
-            self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
-        )
-        block_tables = torch.zeros(
-            self.rbln_config.batch_size,
-            self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
-            dtype=torch.int16,
-        ).fill_(-1)
-        free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+        self.setup_runtime()
 
+    def setup_runtime(self):
+        # Initialize resources to be used across Runtime instances (prefill and decode phases)
+        page_table_manager = RBLNPageTableManager(self.rbln_config)
+        dec_attn_mask = torch.zeros(self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype)
+        out_buffers = [torch.empty(self.prefill_output_size, dtype=self.dtype)]
+
+        common_kwargs = {
+            "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+            "embed_tokens": self.embed_tokens,
+            "dec_attn_mask": dec_attn_mask,
+            "page_table_manager": page_table_manager,
+            "rbln_config": self.rbln_config,
+        }
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
-            main_input_name=main_input_name,
-            embed_tokens=self.embed_tokens,
             phase="prefill",
             batch_size=self.rbln_config.batch_size,
-
-
-            free_block_pool=free_block_pool,
-            rbln_config=self.rbln_config,
-            vocab_size=self.config.vocab_size,
+            out_buffers=out_buffers,
+            **common_kwargs,
         )
+        if self.can_generate():
+            self.decoders = {}
+            for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
+                self.decoders[batch_size] = RBLNRuntimeModel(
+                    runtime=self.model[i + 1],
+                    phase="decode",
+                    batch_size=batch_size,
+                    **common_kwargs,
+                )
 
-
-
-            self.decoders[batch_size] = RBLNRuntimeModel(
-                runtime=self.model[i + 1],
-                main_input_name=main_input_name,
-                embed_tokens=self.embed_tokens,
-                phase="decode",
-                batch_size=batch_size,
-                dec_attn_mask=dec_attn_mask,
-                block_tables=block_tables,
-                free_block_pool=free_block_pool,
-                rbln_config=self.rbln_config,
-            )
-
-        # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
-        self.decoder = self.decoders[self.rbln_config.batch_size]
-
-    @classmethod
-    def save_torch_artifacts(
-        cls,
-        model: PreTrainedModel,
-        save_dir_path: Path,
-        subfolder: str,
-        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-    ):
-        # If you are unavoidably running on a CPU rather than an RBLN device,
-        # store the torch tensor, weight, etc. in this function.
-        if rbln_config.use_inputs_embeds:
-            save_dict = {}
-            save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
-            torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
-
-    def _create_embedding_layer(self):
-        with no_init_weights():
-            embed_tokens = torch.nn.Embedding(
-                self.config.vocab_size,
-                self.config.hidden_size,
-                self.config.pad_token_id,
-            )
-        return embed_tokens
-
-    def get_input_embeddings(self):
-        return self.embed_tokens
-
-    def get_attn_impl(self) -> str:
-        return self.rbln_config.attn_impl
+            # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
+            self.decoder = self.decoders[self.rbln_config.batch_size]
 
-
-
+    @property
+    def prefill_output_size(self):
+        return (
+            1,
+            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
+            self.config.hidden_size,
+        )
 
     @classmethod
     def get_quantized_model(
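
In the hunk above, the free-pool block-table bookkeeping that used to live on the runtime wrapper (`get_block_tables`, `update_block`, `replace_empty_block`) is removed and handed to the new `RBLNPageTableManager` imported from `decoderonly_runtime_utils`. As a reading aid, here is a self-contained sketch of that allocation idea, modeled on the removed code; the class and method names are illustrative only and do not describe the actual `RBLNPageTableManager` API.

```python
# Illustrative sketch of paged KV-cache block allocation (not the real RBLNPageTableManager).
from collections import deque

import torch

EMPTY_BLOCK = -1


class TinyPageTable:
    def __init__(self, batch_size: int, max_seq_len: int, block_size: int, num_blocks: int):
        self.block_size = block_size
        # One logical block table row per sequence slot; -1 marks an unmapped block.
        self.block_tables = torch.full(
            (batch_size, max_seq_len // block_size), EMPTY_BLOCK, dtype=torch.int16
        )
        # Pool of physical blocks that are not yet assigned to any sequence.
        self.free_block_pool = deque(range(num_blocks))

    def allocate(self, batch_idx: int, position: int) -> torch.Tensor:
        """Ensure a physical block is mapped for `position` and return the row's table."""
        block_idx = position // self.block_size
        if self.block_tables[batch_idx, block_idx] == EMPTY_BLOCK:
            if not self.free_block_pool:
                raise RuntimeError("No memory blocks are available for allocation.")
            self.block_tables[batch_idx, block_idx] = self.free_block_pool.popleft()
        return self.block_tables[batch_idx]

    def release(self, batch_idx: int) -> None:
        """Return a finished sequence's blocks to the free pool."""
        used = self.block_tables[batch_idx][self.block_tables[batch_idx] != EMPTY_BLOCK]
        self.free_block_pool.extend(used.tolist())
        self.block_tables[batch_idx].fill_(EMPTY_BLOCK)
```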
@@ -582,35 +138,22 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         subfolder: str = "",
         local_files_only: bool = False,
         trust_remote_code: bool = False,
+        rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
         **kwargs,
     ):
         kwargs = cls.update_kwargs(kwargs)
 
-
-
-            model_id,
-            use_auth_token=use_auth_token,
-            revision=revision,
-            force_download=force_download,
-            cache_dir=cache_dir,
-            trust_remote_code=trust_remote_code,
-            **kwargs,
-        )
-
-        with no_init_weights():
-            model = AutoModelForCausalLM.from_config(config)
-
-        model = prepare_model_for_quantization(
-            model,
+        return get_quantized_model(
+            cls.auto_model_class,
             model_id,
-            kwargs.get("num_hidden_layers"),
             use_auth_token=use_auth_token,
             revision=revision,
             cache_dir=cache_dir,
             force_download=force_download,
             local_files_only=local_files_only,
+            rbln_quantization=rbln_config.quantization,
+            **kwargs,
         )
-        return model
 
     def __getattr__(self, __name: str) -> Any:
         # Special method to delegate attribute access to the original Huggingface LM class.
@@ -632,233 +175,162 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return val
 
     @classmethod
-    def
-        cls,
-
-
-
-
-
+    def save_torch_artifacts(
+        cls,
+        model: PreTrainedModel,
+        save_dir_path: Path,
+        subfolder: str,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+    ):
+        # If you are unavoidably running on a CPU rather than an RBLN device,
+        # store the torch tensor, weight, etc. in this function.
+        if rbln_config.use_inputs_embeds:
+            save_dict = {}
+            save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
+            torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
 
-
+    def _create_embedding_layer(self):
+        with no_init_weights():
+            embed_tokens = torch.nn.Embedding(
+                self.config.vocab_size,
+                self.config.hidden_size,
+                self.config.pad_token_id,
+            )
+        return embed_tokens
+
+    def get_decoder(self):
+        if not self.can_generate():
+            raise ValueError("Decode stage is not supported in this model.")
+        return self.decoder
+
+    def can_generate(self):
+        return self.rbln_config.can_generate
+
+    def get_input_embeddings(self):
+        return self.embed_tokens
+
+    def get_attn_impl(self) -> str:
+        return self.rbln_config.attn_impl
+
+    def get_kvcache_num_blocks(self) -> int:
+        return self.rbln_config.kvcache_num_blocks
 
     @classmethod
-    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
+        return cls._decoder_wrapper_cls(model, rbln_config, cls._use_rotary_emb).eval()
+
+    @classmethod
+    def _compile_model(
+        cls,
+        wrapped_model,
+        compile_config,
+        example_inputs,
+        compile_context,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+        quantization=None,
+        phase: str = "prefill",
+    ):
+        try:
+            wrapped_model.phase = phase
+            if quantization:
+                quantization.maybe_set_quantization_env()
+            original_linear = torch.nn.functional.linear
+            torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
+            compiled_model = cls.compile(
+                wrapped_model,
+                compile_config,
+                create_runtimes=rbln_config.create_runtimes,
+                device=rbln_config.device,
+                example_inputs=example_inputs,
+                compile_context=compile_context,
+            )
+            return compiled_model
+        finally:
+            torch.nn.functional.linear = original_linear
+            if quantization:
+                quantization.maybe_reset_quantization_env()
+
+    @classmethod
+    def _get_compile_context(
+        cls,
+        compile_config: RBLNCompileConfig,
+        example_inputs: List[torch.Tensor],
+    ):
+        context = CompileContext(use_weight_sharing=True)
+
+        # Mark static tensors (self kv states)
+        static_tensors = {}
+        idx = 0
+        for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
+            if "past_key_values" in name:
+                static_tensors[name] = tensor
+                context.mark_static_address(tensor, f"kv_cache_{idx}")
+                idx += 1
+
+        return context, static_tensors
 
     @classmethod
     @torch.inference_mode()
     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
-
-        rbln_compile_configs = rbln_config.compile_cfgs
-        prefill_compile_config = rbln_compile_configs[0]
-
-        context = CompileContext(use_weight_sharing=True)
+        prefill_compile_config = rbln_config.compile_cfgs[0]
 
         # Here we use meta tensor, for the memory efficiency.
         meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
         prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+        context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
+
+        compiled_models = {}
+        compiled_models["prefill"] = cls._compile_model(
+            wrapped_model,
+            prefill_compile_config,
+            prefill_example_inputs,
+            context,
+            rbln_config,
+            rbln_config.quantization,
+            phase="prefill",
+        )
 
-
-
-
-
-
-                context.mark_static_address(tensor)
-
-        def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
-            try:
-                if quantization:
-                    quantization.maybe_set_quantization_env()
-                original_linear = torch.nn.functional.linear
-                torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
-                compiled_model = cls.compile(
+        if rbln_config.can_generate:
+            wrapped_model.phase = "decode"
+            for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
+                dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+                compiled_decoder = cls._compile_model(
                     wrapped_model,
-
-
-
-
-
+                    dec_compile_config,
+                    dec_example_inputs,
+                    context,
+                    rbln_config,
+                    rbln_config.quantization,
+                    phase="decode",
+                )
+                compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
+
+            # check if the memory is enough to have additional blocks
+            required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+            if rbln_config.kvcache_num_blocks < required_num_blocks:
+                cls.maybe_suggest_kvcache_num_blocks(
+                    compiled_models=compiled_models,
+                    model_config=model.config,
+                    rbln_config=rbln_config,
                 )
-                return compiled_model
-            finally:
-                torch.nn.functional.linear = original_linear
-                if quantization:
-                    quantization.maybe_reset_quantization_env()
-
-        wrapped_model.phase = "prefill"
-        compiled_prefill = compile_model(
-            wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
-        )
-
-        wrapped_model.phase = "decode"
-        compiled_models = {"prefill": compiled_prefill}
-        for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
-            dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
-            compiled_decoder = compile_model(
-                wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
-            )
-            compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
-
-        # check if the memory is enough to have additional blocks
-        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-        if rbln_config.kvcache_num_blocks < required_num_blocks:
-            cls.maybe_suggest_kvcache_num_blocks(
-                compiled_models=compiled_models,
-                model_config=model.config,
-                rbln_config=rbln_config,
-            )
 
         return compiled_models
 
     @classmethod
-    def
-        cls,
-
-
-
-
-        alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
-        alloc_memory_by_key: Dict[str, int] = {
-            key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
-        }
-        for batch_size in rbln_config.decoder_batch_sizes:
-            for key, memory_per_node in (
-                compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
-            ):
-                alloc_memory_by_key[key] += sum(memory_per_node)
-        alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache Key
-        alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
-        kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
-
-        # Get the maximum number of blocks that can be allocated
-        buffer = sum(alloc_memory_by_key.values())
-        max_num_blocks = cls.get_maximum_num_blocks(
-            config=model_config,
-            tensor_parallel_size=rbln_config.tensor_parallel_size,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kernel_size=kernel_size,
-            buffer=buffer,
-        )
+    def get_pytorch_model(
+        cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
+    ) -> PreTrainedModel:
+        if rbln_config and rbln_config.quantization:
+            model = cls.get_quantized_model(*args, rbln_config=rbln_config, **kwargs)
+        else:
+            model = super().get_pytorch_model(*args, **kwargs)
 
-
-        # users can set `kvcache_num_blocks` to `max_num_blocks`.
-        # If the memory is not enough, the model will fail to compile.
-        if rbln_config.kvcache_num_blocks < max_num_blocks:
-            logger.warning(
-                f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
-                "Our analysis indicates that additional memory is available for more blocks. "
-                f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
-                "Please be advised that our memory estimation algorithm has limitations, "
-                "and increasing this value may not guarantee successful model compilation."
-            )
+        return model
 
     @classmethod
-    def
-
-        config: PretrainedConfig,
-        tensor_parallel_size: int,
-        kvcache_block_size: int,
-        nbits_per_param: Optional[int] = None,
-        n_model_params: Optional[int] = None,
-        kernel_size: Optional[int] = None,
-        buffer: Optional[int] = None,
-        num_runtimes: int = 2,
-    ) -> int:
-        # We are finding max_n_blocks(x) that satisfies the following equation:
-
-        # available_dram - kernel_size - buffer
-        # - num_layers * 2 * tensor_parallel_size
-        # * align_2MB(
-        #     x
-        #     * block_size
-        #     * align_64(head_dim)
-        #     * math.ceil(num_key_value_heads / tensor_parallel_size)
-        #     * 2
-        # ) > 0
-
-        # This inequality can be rewritten as follows:
-
-        # a - c * align_2MB(b * x) > 0
-        # where
-        # a = available_dram - kernel_size - buffer
-        # b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
-        # c = num_layers * 2 * tensor_parallel_size
-
-        # We can rewrite the inequality as follows:
-        # k > align_2MB(b*x)
-        # where
-        # k = a / c
-
-        # After that, we can derive the following equation:
-        # x = floor(2**21 / b * floor((k - 1) / 2**21))
-
-        def align(x: int, nbytes: int) -> int:
-            return int(math.ceil(x / nbytes) * nbytes)
-
-        def align_2MB(x: int) -> int:
-            return align(x, 2**21)
-
-        num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
-        num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
-        head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
-        vocab_size = config.vocab_size
-        hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
-        num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
-
-        # TODO(jongho): Update if target npu is REBEL.
-        ATOM_DRAM_NBYTES = 16 * 2**30
-        ATOM_SYS_DRAM_NBYTES = 288 * 2**20
-        available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
-
-        if kernel_size is None:
-            if n_model_params is None:
-                raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
-            # Get estimated kernel size (approximated)
-            lm_heads_params = align(vocab_size, 64) * hidden_size
-            lm_heads_nbytes = (
-                align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
-            )
-            params = n_model_params - lm_heads_params
-            layer_nbytes = (
-                align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
-                * num_layers
-                * tensor_parallel_size
-            )
-            kernel_size = layer_nbytes + lm_heads_nbytes
-        elif n_model_params is not None:
-            raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
-
-        available_dram -= kernel_size
-
-        if buffer is None:
-            # TODO: Accurate buffer estimation
-            buffer_per_runtime_per_core = 2**28  # 256MB per runtime
-            buffer_per_core = buffer_per_runtime_per_core * num_runtimes  # 1 for prefill, 1 for decoder
-            buffer = buffer_per_core * tensor_parallel_size
-        available_dram -= buffer
-
-        b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
-        c = num_layers * 2 * tensor_parallel_size
-        k = available_dram / c
-        max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
-
-        return max_n_blocks
+    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+        return use_local_attention
 
     @classmethod
     def get_input_info(
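
The removed `get_maximum_num_blocks` helper above sizes the paged KV cache by solving available_dram - kernel_size - buffer - c * align_2MB(b * x) > 0 for the block count x, giving x = floor(2**21 / b * floor((k - 1) / 2**21)) with k = (available_dram - kernel_size - buffer) / c. A small numeric sketch of that formula follows; every constant (DRAM size, weight footprint, block size) is an assumed, illustrative value, not a measurement.

```python
# Worked example of the removed max-block formula (illustrative numbers only).
import math


def align(x: int, nbytes: int) -> int:
    return int(math.ceil(x / nbytes) * nbytes)


# Hypothetical Llama-7B-like configuration on a single device.
tensor_parallel_size = 1
num_layers = 32
num_key_value_heads = 32
head_dim = 128
kvcache_block_size = 1024

available_dram = 16 * 2**30 - 288 * 2**20  # device DRAM minus system reserve (assumed)
kernel_size = 7 * 2**30                    # ~7 GiB of fp16 weights (assumed)
buffer = 2 * 2**28                         # 256 MiB per runtime, prefill + decode (assumed)

b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
c = num_layers * 2 * tensor_parallel_size
k = (available_dram - kernel_size - buffer) / c
max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
print(max_n_blocks)  # -> 16 with these numbers: 16 blocks of 1024 tokens ~ 8 GiB of fp16 KV cache
```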
@@ -868,63 +340,57 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         model_config: PretrainedConfig,
     ):
-        is_prefill: bool = query_length > 1
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
         head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
-
+        is_prefill = query_length > 1
 
-
+        input_info = []
         if rbln_config.use_inputs_embeds:
-
+            input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.torch_dtype))
         else:
-
-
-        # 2. cache_position
-        input_info = [
-            main_input,
-            (
-                "cache_position",
-                [batch_size, query_length],
-                "int32",
-            ),
-        ]
+            input_info.append(("input_ids", [batch_size, query_length], "int64"))
 
-
-
+        input_info.append(("cache_position", [batch_size, query_length], "int32"))
+
+        if rbln_config.use_global_attention:
             max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
-            input_info.
-
+            input_info.append(
+                ("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")
             )
-        if rbln_config.
-            input_info.
+        if rbln_config.use_local_attention:
+            input_info.append(("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16"))
 
-
-
-            input_info.extend([("query_position", [], "int16")])
+        if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
+            input_info.append(("query_position", [], "int16"))
 
-        # 5. attention_mask & position_ids
         if rbln_config.use_attention_mask:
-
-            [
-
-
-
-
-
+            if rbln_config.use_position_ids:
+                input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.torch_dtype))
+            else:
+                input_info.append(
+                    ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.torch_dtype)
+                )
+
         if rbln_config.use_position_ids:
             input_info.append(("position_ids", [batch_size, query_length], "int32"))
 
-
+        if rbln_config.use_lora:
+            input_info.append(("lora_int_ids", [batch_size], "int32"))
+
+        kvcache_dtype = rbln_config.torch_dtype
+        if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
+            kvcache_dtype = "float8_e4m3fn"
+
         global_kvcache_shape = [
             rbln_config.kvcache_num_blocks,
             num_key_value_heads,
             rbln_config.kvcache_block_size,
             head_dim,
         ]
-        local_kvcache_shape = [
+        local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
         input_info.extend(
             [
                 (
@@ -932,7 +398,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                     local_kvcache_shape
                     if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
                     else global_kvcache_shape,
-
+                    kvcache_dtype,
                 )
                 for i in range(num_hidden_layers * 2)
             ]
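Every `input_info.append(...)` above contributes one `(name, shape, dtype)` triple describing a compiler input. Purely as an illustration (the entry names, dtypes, and config values below are made up, not taken from this diff), a prefill graph for a tiny two-layer model without sliding-window layers might be described by a list of this shape:

```python
# Hypothetical shape of the flat input_info list built by get_input_info:
# batch_size=1, query_length=128, max_seq_len=4096, kvcache_block_size=4096,
# kvcache_num_blocks=2, two layers, two KV heads, head_dim=64.
example_prefill_input_info = [
    ("input_ids", [1, 128], "int64"),
    ("cache_position", [1, 128], "int32"),
    ("block_tables", [1], "int16"),   # max_seq_len // kvcache_block_size entries for prefill
    ("query_position", [], "int16"),
    # one (num_blocks, kv_heads, block_size, head_dim) tensor per K and per V of each layer
    ("past_key_values_0", [2, 2, 4096, 64], "float32"),
    ("past_key_values_1", [2, 2, 4096, 64], "float32"),
    ("past_key_values_2", [2, 2, 4096, 64], "float32"),
    ("past_key_values_3", [2, 2, 4096, 64], "float32"),
]
```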
@@ -971,7 +437,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         # ```

         # Returns:
-        #
+        #     RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.

         raise NotImplementedError(
             "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
@@ -979,27 +445,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         )

     @classmethod
-    def
-        cls,
-
-        model: Optional[PreTrainedModel] = None,
-        model_config: Optional[PretrainedConfig] = None,
-        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
-    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
-        if rbln_config.max_seq_len is None:
-            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
-                model_config, "n_positions", None
-            )
-        if rbln_config.max_seq_len is None:
-            raise ValueError("`max_seq_len` should be specified.")
-
-        if getattr(model_config, "sliding_window", None) is not None and getattr(
-            model_config, "use_sliding_window", True
-        ):
-            rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
-        if rbln_config.sliding_window is not None:
-            validate_sliding_window_size(rbln_config.sliding_window, rbln_config.prefill_chunk_size)
-
+    def _update_attention_config(
+        cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
         rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
             attn_impl=rbln_config.attn_impl,
             kvcache_partition_len=rbln_config.kvcache_partition_len,
@@ -1014,9 +462,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             max_seq_len=rbln_config.max_seq_len,
         )

-
-        max_num_blocks = required_num_blocks
+        num_full_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size

+        # Update kvcache_num_blocks based on the attention implementation.
         if rbln_config.attn_impl == "flash_attn":
             estimated_max_num_blocks = cls.get_maximum_num_blocks(
                 config=model_config,
@@ -1024,30 +472,73 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 kvcache_block_size=rbln_config.kvcache_block_size,
                 nbits_per_param=16 if not rbln_config.quantization else 4, # TODO(jongho): FIX Ad-hoc
                 n_model_params=sum(p.numel() for p in model.parameters()),
-                num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
+                num_runtimes=1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
             )

-
+            if rbln_config.kvcache_num_blocks is None:
+                if estimated_max_num_blocks < num_full_blocks:
+                    # lower bound of the number of blocks for flash attention.
+                    min_blocks_for_flash = min(
+                        rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1, num_full_blocks
+                    )
+                    if min_blocks_for_flash > estimated_max_num_blocks:
+                        # NOTE: Just try to compile with lower bound of blocks for flash attention.
+                        # Even if it's larger than the estimated maximum number of blocks.
+                        rbln_config.kvcache_num_blocks = min_blocks_for_flash
+                    else:
+                        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
+                        rbln_config.kvcache_num_blocks = estimated_max_num_blocks
+
+                    if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
+                        raise RuntimeError(
+                            f"Batch size ({rbln_config.batch_size}) exceeds num_blocks ({rbln_config.kvcache_num_blocks}). "
+                            "Ensure the number of blocks is at least equal to the batch size."
+                        )
+                else:
+                    rbln_config.kvcache_num_blocks = num_full_blocks
+            elif rbln_config.kvcache_num_blocks > estimated_max_num_blocks:
+                logger.warning(
+                    f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                    f" than the estimated maximum number of blocks ({estimated_max_num_blocks})."
+                    "This can cause a failure during model compilation."
+                )
+        else:
+            if rbln_config.kvcache_num_blocks is None:
+                rbln_config.kvcache_num_blocks = num_full_blocks
+            elif rbln_config.kvcache_num_blocks > num_full_blocks:
+                logger.warning(
+                    f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                    f" than the required number of blocks ({num_full_blocks})."
+                    "This can cause a failure during model compilation."
+                )

-
-            if max_num_blocks < flash_min_blocks:
-                max_num_blocks = flash_min_blocks
+        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")

-
-            raise RuntimeError(
-                f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
-                "Ensure the number of blocks is at least equal to the batch size."
-            )
+        return rbln_config

-
-
-
-
-
-
-
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional[PreTrainedModel] = None,
+        model_config: Optional[PretrainedConfig] = None,
+        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+        if rbln_config.max_seq_len is None:
+            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
+                model_config, "n_positions", None
             )
-
+        if rbln_config.max_seq_len is None:
+            raise ValueError("`max_seq_len` should be specified.")
+
+        if getattr(model_config, "sliding_window", None) is not None and getattr(
+            model_config, "use_sliding_window", True
+        ):
+            rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
+        if rbln_config.sliding_window is not None:
+            validate_sliding_window(rbln_config)
+
+        rbln_config = cls._update_attention_config(model, model_config, rbln_config)

         prefill_input_info = cls.get_input_info(
             batch_size=1,
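`num_full_blocks` is the budget needed to keep every sequence of every batch entry fully cached, while `estimated_max_num_blocks` is the DRAM-derived ceiling returned by `get_maximum_num_blocks`. Below is a condensed, hypothetical walk-through of the selection added above, for the case where `kvcache_num_blocks` was not set by the user:

```python
# Hypothetical numbers only; condensed from the selection logic added above.
max_seq_len = 32768
kvcache_block_size = 4096
batch_size = 4

num_full_blocks = (max_seq_len // kvcache_block_size) * batch_size  # 8 * 4 = 32
estimated_max_num_blocks = 20                                       # pretend DRAM estimate

if estimated_max_num_blocks < num_full_blocks:
    # lower bound of blocks for flash attention, as in the added code
    # (one full sequence of blocks + 1, capped at num_full_blocks).
    min_blocks_for_flash = min(max_seq_len // kvcache_block_size + 1, num_full_blocks)  # 9
    kvcache_num_blocks = (
        min_blocks_for_flash if min_blocks_for_flash > estimated_max_num_blocks else estimated_max_num_blocks
    )
else:
    kvcache_num_blocks = num_full_blocks

print(kvcache_num_blocks)  # 20: the DRAM estimate wins with these numbers
```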
@@ -1057,19 +548,20 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         )

         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
-
-
-
-
-
-
-
-
-
-
-
-
-
+        compile_cfgs = [prefill_compile_config]
+
+        if rbln_config.can_generate:
+            for batch_size in rbln_config.decoder_batch_sizes:
+                dec_input_info = cls.get_input_info(
+                    batch_size=batch_size,
+                    query_length=1,
+                    rbln_config=rbln_config,
+                    model_config=model_config,
+                )
+                compile_cfgs.append(
+                    RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
+                )
+        rbln_config.set_compile_cfgs(compile_cfgs)

         return rbln_config

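With this change, the generation path compiles one prefill graph plus one decoder graph per entry in `decoder_batch_sizes`, while non-generative uses (`can_generate` is false) compile the prefill graph only. A small illustration of the resulting compiled-model names, with a hypothetical `decoder_batch_sizes` value:

```python
# Hypothetical decoder_batch_sizes; the naming scheme follows the code above.
decoder_batch_sizes = [1, 4]
compiled_model_names = ["prefill"] + [f"decoder_batch_{b}" for b in decoder_batch_sizes]
print(compiled_model_names)  # ['prefill', 'decoder_batch_1', 'decoder_batch_4']
```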
@@ -1079,103 +571,153 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         compiled_models: List[rebel.RBLNCompiledModel],
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ) -> List[rebel.Runtime]:
-        expected_model_names = [
-
-
-
+        expected_model_names = ["prefill"]
+        if rbln_config.can_generate:
+            expected_model_names.extend(
+                [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
+            )
         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
             cls._raise_missing_compiled_file_error(expected_model_names)

-
+        ret_val = [
             rebel.Runtime(
                 compiled_models[0],
                 tensor_type="pt",
                 device=rbln_config.device_map["prefill"],
                 activate_profiler=rbln_config.activate_profiler,
                 timeout=rbln_config.timeout,
-            )
-            *[
-                rebel.Runtime(
-                    compiled_models[i + 1],
-                    tensor_type="pt",
-                    device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
-                    activate_profiler=rbln_config.activate_profiler,
-                    timeout=rbln_config.timeout,
-                )
-                for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-            ],
+            )
         ]
+        if rbln_config.can_generate:
+            ret_val.extend(
+                [
+                    rebel.Runtime(
+                        compiled_models[i + 1],
+                        tensor_type="pt",
+                        device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
+                        activate_profiler=rbln_config.activate_profiler,
+                        timeout=rbln_config.timeout,
+                    )
+                    for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
+                ]
+            )
+        return ret_val

-    def
-        return self.decoder
-
-    def can_generate(self):
-        return True
-
-    def _reorder_cache(self, past_key_values, beam_idx):
-        raise NotImplementedError
-
-    def prepare_inputs_for_generation(
+    def forward(
         self,
-        input_ids: torch.LongTensor,
-        generate_idx: Optional[torch.Tensor] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
+        input_ids: Optional[torch.LongTensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
-
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
         **kwargs,
-    ):
-
-
+    ) -> Tuple[torch.FloatTensor]:
+        inputs = inputs_embeds if inputs_embeds is not None else input_ids
+        batch_size = inputs.shape[0]

-        if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if batch_size != self.rbln_config.batch_size:
+            raise ValueError(
+                f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
+            )
+
+        all_last_hidden_states = []
+        for b_idx in range(self.rbln_config.batch_size):
+            query_length = (
+                attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
+            )
+            cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
+            last_hidden_states = self.prefill_decoder(
+                inputs[b_idx : b_idx + 1],
+                attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
+                cache_position=cache_position,
+                batch_idx=b_idx,
+            ).logits
+            all_last_hidden_states.append(last_hidden_states)
+
+        last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
+        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
+
+
+class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
+    """
+    A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+    This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
+
+    The class provides core functionality for:
+
+    1. Converting pre-trained transformer models to RBLN-optimized format
+    2. Handling the compilation process for RBLN devices
+    3. Managing inference operations for causal language modeling
+    This class inherits from RBLNModel and implements specific methods required for
+    decoder-only architectures and causal language modeling tasks.
+
+    Note:
+        - This class is designed to be subclassed by specific model implementations
+          (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+        - Subclasses should implement model-specific conversion logic.
+        - The class handles RBLN-specific optimizations automatically during compilation
+    """
+
+    auto_model_class = AutoModelForCausalLM
+
+    @property
+    def prefill_output_size(self):
+        return (
+            1,
+            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
+            self.config.vocab_size,
         )

-
+    @classmethod
+    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+        return is_prefill

-    def
-
-
-
-
-
-
-
-
+    def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
+        if isinstance(lora_int_ids, int):
+            lora_int_ids = torch.tensor([lora_int_ids], dtype=torch.int32)
+        elif isinstance(lora_int_ids, list):
+            lora_int_ids = torch.tensor(lora_int_ids, dtype=torch.int32)
+
+        self.lora_int_ids = lora_int_ids
+
+        self.prefill_decoder.lora_int_ids = lora_int_ids
+        if self.rbln_config.can_generate:
+            for batch_size in self.rbln_config.decoder_batch_sizes:
+                self.decoders[batch_size].lora_int_ids = lora_int_ids
+
+    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+        """
+        Sets the active adapter(s) for the model using adapter name(s).
+
+        Args:
+            adapter_name (Union[str, List[str]]): The name(s) of the adapter(s) to be activated.
+                Can be a single adapter name or a list of adapter names.

-
+        Raises:
+            ValueError: If the model is not configured with LoRA or if the adapter name is not found.
+        """
+        if not hasattr(self.rbln_config, "lora_config") or self.rbln_config.lora_config is None:
+            raise ValueError("Model is not configured with LoRA. Cannot set adapter.")
+
+        # Convert single adapter name to list for uniform processing
+        if isinstance(adapter_name, str):
+            adapter_names = [adapter_name]
+        else:
+            adapter_names = adapter_name
+
+        # Validate that all adapter names exist
+        available_adapters = {
+            adapter.lora_name: adapter.lora_int_id for adapter in self.rbln_config.lora_config.adapters
+        }
+        missing_adapters = [name for name in adapter_names if name not in available_adapters]
+        if missing_adapters:
+            raise ValueError(
+                f"Adapter(s) {missing_adapters} not found. Available adapters: {list(available_adapters.keys())}"
+            )
+
+        # Get the adapter IDs and set them
+        lora_int_ids = [available_adapters[name] for name in adapter_names]
+        self.set_lora_int_ids(torch.tensor(lora_int_ids, dtype=torch.int32))

     def forward(
         self,
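The LoRA plumbing added in the hunk above (`set_lora_int_ids`, `set_adapter`, and the extra `lora_int_ids` runtime input) lets a model compiled with several adapters switch between them at inference time. A usage sketch follows; the checkpoint path and adapter names are placeholders, and it assumes the model was compiled with a LoRA config that registered those adapters.

```python
# Hypothetical usage of the adapter-selection API added above. Assumes the
# checkpoint at "./llama-rbln-lora" was compiled with adapters named
# "summarize" and "translate" in its LoRA config.
from optimum.rbln import RBLNLlamaForCausalLM

model = RBLNLlamaForCausalLM.from_pretrained("./llama-rbln-lora", export=False)

# One adapter for the whole batch...
model.set_adapter("summarize")

# ...or one adapter name per batch entry, resolved to lora_int_ids internally.
model.set_adapter(["summarize", "translate"])
```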
@@ -1187,6 +729,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         padded_cache_lengths: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
+        lora_int_ids: Optional[torch.Tensor] = None,
         return_dict: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor]:
@@ -1194,16 +737,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
         # A for-loop ensures synchronization with the HuggingFace generate API.
         # The decoder stage operates as usual, processing inputs in batch mode.
+        if self.rbln_config.use_lora and lora_int_ids is None:
+            if self.lora_int_ids is None:
+                raise ValueError(
+                    "lora_int_id is required when using LoRA. "
+                    "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
+                )
+            lora_int_ids = self.lora_int_ids
+
+        # for only use forward
+        if generate_idx is None:
+            generate_idx = (
+                attention_mask.sum(dim=-1, keepdim=True).int()
+                if attention_mask is not None
+                else torch.full((input_ids.shape[0], 1), input_ids.shape[1], dtype=torch.int32)
+            )
+            padded_cache_lengths = torch.zeros_like(generate_idx)

-        #
+        # Prefill
         if cache_position is None:
             logits = []
             inputs = inputs_embeds if inputs_embeds is not None else input_ids
-            # for only use forward
-            if generate_idx is None:
-                generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
-            if padded_cache_lengths is None:
-                padded_cache_lengths = torch.zeros_like(generate_idx)
             batch_size = inputs.shape[0]
             for b_idx in range(batch_size):
                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
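The block added above lets a plain `forward` call (outside `generate`) work without a caller-supplied `generate_idx` by deriving each sequence's prompt length from the attention mask. A small illustration with made-up tensors:

```python
import torch

# Hypothetical inputs: two prompts of lengths 5 and 3, left-padded to 5 tokens.
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [0, 0, 1, 1, 1]])

generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
padded_cache_lengths = torch.zeros_like(generate_idx)

print(generate_idx.tolist())  # [[5], [3]]
```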
@@ -1214,6 +768,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                     cache_position=cache_position,
                     batch_idx=b_idx,
                     token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
+                    lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
                 )
                 padded_cache_lengths[b_idx] += output.padded_cache_lengths
                 logits.append(output.logits)
@@ -1233,6 +788,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 inputs_embeds=inputs_embeds,
                 cache_position=cache_position,
                 position_ids=position_ids if self.rbln_config.use_position_ids else None,
+                lora_int_ids=lora_int_ids,
             ).logits

         if not return_dict:
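Taken together, these hunks change how decoder-only models are compiled (a prefill graph plus optional per-batch-size decoder graphs) and how prefill, decoding, and LoRA inputs are wired at runtime, but the user-facing flow is still the usual `from_pretrained`/`generate` pattern. A minimal sketch, with the checkpoint ID and `rbln_*` values chosen only for illustration:

```python
# Minimal usage sketch; the model ID and rbln_* values are illustrative.
from transformers import AutoTokenizer
from optimum.rbln import RBLNLlamaForCausalLM

model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id,
    export=True,          # compile for the RBLN NPU
    rbln_batch_size=1,
    rbln_max_seq_len=8192,
)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```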