optimum-rbln 0.8.2a7__py3-none-any.whl → 0.8.3a0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- optimum/rbln/__init__.py +8 -9
- optimum/rbln/__version__.py +16 -3
- optimum/rbln/configuration_utils.py +4 -4
- optimum/rbln/diffusers/__init__.py +1 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
- optimum/rbln/diffusers/modeling_diffusers.py +1 -1
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/pipelines/__init__.py +1 -5
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +11 -6
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +2 -2
- optimum/rbln/modeling_base.py +12 -4
- optimum/rbln/ops/attn.py +158 -0
- optimum/rbln/ops/flash_attn.py +166 -0
- optimum/rbln/transformers/__init__.py +6 -0
- optimum/rbln/transformers/configuration_generic.py +4 -4
- optimum/rbln/transformers/modeling_generic.py +1 -4
- optimum/rbln/transformers/modeling_outputs.py +37 -0
- optimum/rbln/transformers/models/__init__.py +10 -16
- optimum/rbln/transformers/models/auto/__init__.py +1 -0
- optimum/rbln/transformers/models/auto/modeling_auto.py +7 -0
- optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
- optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
- optimum/rbln/transformers/models/clip/configuration_clip.py +3 -3
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +43 -174
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -93
- optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +450 -0
- optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +88 -0
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +297 -987
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +9 -0
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +14 -3
- optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +217 -0
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +58 -257
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +2 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
- optimum/rbln/transformers/models/llama/modeling_llama.py +12 -3
- optimum/rbln/transformers/models/llava/configuration_llava.py +2 -2
- optimum/rbln/transformers/models/llava/modeling_llava.py +53 -14
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +2 -2
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
- optimum/rbln/transformers/models/opt/modeling_opt.py +2 -30
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +4 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +2 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +1 -3
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +2 -2
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +1 -4
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +6 -15
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +4 -7
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +77 -3
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -4
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +19 -2
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +20 -1
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
- optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -2
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +20 -1
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/transformers/utils/rbln_quantization.py +249 -46
- optimum/rbln/utils/runtime_utils.py +3 -3
- {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/RECORD +90 -86
- {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.2a7.dist-info → optimum_rbln-0.8.3a0.dist-info}/licenses/LICENSE +0 -0
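The bulk of this release is a refactor of the decoder-only modeling stack: the runtime wrapper, page-table bookkeeping, and generation helpers move out of `modeling_decoderonly.py` (+297 -987 above) into the new `decoderonly_runtime_utils.py` and `generation_decoderonly.py` modules, and the shared output dataclass moves to `transformers/modeling_outputs.py`. A minimal sketch of where the relocated pieces can be imported from after the split, using only the module and class names visible in the diff below (class bodies are not shown in this diff):

```python
# Sketch only: locations of the relocated decoder-only pieces, per the import changes below.
from optimum.rbln.transformers.models.decoderonly.decoderonly_runtime_utils import (
    RBLNPageTableManager,  # KV-cache block-table bookkeeping shared by prefill/decode runtimes
    RBLNRuntimeModel,      # wrapper around a single compiled rebel.Runtime
)
from optimum.rbln.transformers.models.decoderonly.generation_decoderonly import (
    RBLNDecoderOnlyGenerationMixin,  # generate()-side helpers split out of the modeling file
)
from optimum.rbln.transformers.modeling_outputs import RBLNDecoderOnlyOutput
```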
@@ -13,10 +13,8 @@
 # limitations under the License.
 
 import inspect
-from collections import deque
-from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable,
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
 
 import rebel
 import torch
@@ -24,21 +22,22 @@ from rebel.compile_context import CompileContext
 from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_utils import no_init_weights
-from transformers.utils import ModelOutput
 
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
-from ....utils.runtime_utils import RBLNPytorchRuntime
 from ...modeling_attention_utils import (
     RBLNDecoderOnlyFlashAttentionMixin,
     set_default_values,
     validate_attention_method,
     validate_sliding_window,
 )
+from ...modeling_outputs import RBLNDecoderOnlyOutput
 from ...utils.rbln_quantization import prepare_model_for_quantization
 from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
 from .decoderonly_architecture import DecoderOnlyWrapper
+from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
+from .generation_decoderonly import RBLNDecoderOnlyGenerationMixin
 
 
 logger = get_logger()
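With `ModelOutput` no longer imported, the module-local `RBLNDecoderOnlyForCausalLMOutput` dataclass (removed further down in this diff) is replaced by `RBLNDecoderOnlyOutput` from the shared `modeling_outputs` module. A minimal sketch of such an output container, assuming it keeps the three fields of the removed dataclass; the actual definition lives in `optimum/rbln/transformers/modeling_outputs.py` and may differ:

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.utils import ModelOutput


@dataclass
class DecoderOnlyOutputSketch(ModelOutput):
    # Field names mirror the removed RBLNDecoderOnlyForCausalLMOutput.
    logits: torch.FloatTensor = None
    generate_idx: Optional[torch.Tensor] = None
    padded_cache_lengths: Optional[int] = None
```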
@@ -47,419 +46,6 @@ if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
 
 
-class RBLNRuntimeModel(RBLNPytorchRuntime):
-    mandatory_members = ["main_input_name", "embed_tokens"]
-
-    def __init__(
-        self,
-        runtime: rebel.Runtime,
-        phase: str,
-        batch_size: int,
-        dec_attn_mask: torch.Tensor,
-        block_tables: torch.Tensor,
-        free_block_pool: Deque,
-        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-        **kwargs: Any,
-    ) -> None:
-        super().__init__(runtime, **kwargs)
-        self.phase = phase
-        self.batch_size = batch_size
-        self.rbln_config = rbln_config
-
-        # shared tensor between prefill and decode phase
-        self.dec_attn_mask = dec_attn_mask
-        self.block_tables = block_tables
-        self.free_block_pool = free_block_pool
-
-        self.empty_block = -1
-        if self.phase == "prefill":
-            vocab_size = kwargs.pop("vocab_size")
-            self.output_size = [1, 1, vocab_size]
-            self.causal_mask = 1 - torch.triu(
-                torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
-            )
-
-    def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None) -> torch.Tensor:
-        """
-        Manages and returns the KV cache block tables.
-        Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
-
-        Args:
-            cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
-            batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
-
-        Returns:
-            Updated block tables.
-        """
-
-        NO_BLOCKS_ERROR = (
-            "No memory blocks are available for allocation. "
-            "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln. "
-            "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html). "
-            "Using vllm-rbln should fix this issue and enhance inference performance."
-        )
-
-        def update_block(batch_idx: int, block_idx: int):
-            """
-            If the block is empty (empty_block), allocates a block from the free_block_pool.
-            """
-            if self.block_tables[batch_idx][block_idx] == self.empty_block:
-                if self.free_block_pool:
-                    block = self.free_block_pool.popleft()
-                    self.block_tables[batch_idx][block_idx] = block
-                else:
-                    raise RuntimeError(NO_BLOCKS_ERROR)
-
-        def replace_empty_block(block_tables: torch.Tensor):
-            """
-            Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
-            """
-            if not torch.any(block_tables == self.empty_block):
-                return block_tables.clone()
-            elif self.free_block_pool:
-                _free_block = self.free_block_pool[0]
-                return torch.where(block_tables == self.empty_block, _free_block, block_tables)
-            else:
-                raise RuntimeError(NO_BLOCKS_ERROR)
-
-        def get_global_block_tables(batch_idx: int):
-            if self.rbln_config.cache_impl == "sliding_window":
-                return None
-
-            if self.phase == "prefill":
-                # Track previously used blocks and return them to the free_block_pool and
-                # reset the current batch's block table to empty blocks
-                prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
-                self.free_block_pool.extend(prev_blocks)
-                self.block_tables[batch_idx].fill_(self.empty_block)
-
-                # Get the start (s) and end (e) positions from cache_position and
-                # iterate over the cache positions to allocate necessary blocks
-                s, e = cache_position[0][0].item(), cache_position[0][-1].item()
-                for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
-                    block_idx = position // self.rbln_config.kvcache_block_size
-                    if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
-                        raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
-                    update_block(batch_idx, block_idx)
-
-                return replace_empty_block(self.block_tables[batch_idx])
-            # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
-            else:
-                for b_idx in range(self.batch_size):
-                    position = cache_position[b_idx][0].item()
-                    block_idx = position // self.rbln_config.kvcache_block_size
-                    update_block(b_idx, block_idx)
-
-                return replace_empty_block(self.block_tables)
-
-        def get_local_block_tables(batch_idx: int):
-            if self.rbln_config.cache_impl == "static":
-                return None
-            else:
-                return (
-                    torch.tensor([batch_idx], dtype=torch.int16)
-                    if self.phase == "prefill"
-                    else torch.arange(self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-                )
-
-        return get_global_block_tables(batch_idx), get_local_block_tables(batch_idx)
-
-    def is_external_block_tables(
-        self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
-    ):
-        if self.rbln_config.cache_impl == "static" and block_tables is None:
-            return False
-        elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
-            return False
-        elif self.rbln_config.cache_impl == "hybrid":
-            if (block_tables is not None) != (local_block_tables is not None):
-                raise ValueError(
-                    "Both block_tables and local_block_tables must be provided or neither of them must be provided."
-                )
-            elif block_tables is None and local_block_tables is None:
-                return False
-
-        return True
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        cache_position: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: Optional[int] = None,
-        block_tables: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ):
-        if input_ids is None and inputs_embeds is None:
-            raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
-
-        if inputs_embeds is None:
-            inputs = input_ids
-            if self.embed_tokens is not None:
-                inputs = self.embed_tokens(inputs)
-        else:
-            inputs = inputs_embeds
-
-        is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
-        if not is_external_block_tables:
-            block_tables, local_block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
-
-        if self.phase == "decode":
-            return self.decode_forward(
-                inputs,
-                cache_position,
-                block_tables,
-                is_external_block_tables,
-                attention_mask=attention_mask,
-                position_embed=position_embed,
-                position_ids=position_ids,
-                local_block_tables=local_block_tables,
-            )
-        else:
-            return self.prefill_forward(
-                inputs,
-                cache_position,
-                attention_mask,
-                batch_idx,
-                block_tables,
-                is_external_block_tables=is_external_block_tables,
-                position_embed=position_embed,
-                token_type_ids=token_type_ids,
-                local_block_tables=local_block_tables,
-            )
-
-    def decode_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        batch_size = inputs.shape[0]
-        if batch_size != self.batch_size:
-            raise RuntimeError(
-                f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-            )
-
-        if batch_size != cache_position.shape[0]:
-            raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-        if self.rbln_config.use_attention_mask and attention_mask is None:
-            for b_idx in range(batch_size):
-                decoding_step = cache_position[b_idx].item()
-                if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                    raise ValueError(
-                        f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                    )
-
-                if is_external_block_tables:
-                    self.dec_attn_mask[b_idx].fill_(0)
-                    self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
-                else:
-                    self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
-
-            attention_mask = self.dec_attn_mask
-
-        if self.rbln_config.use_global_attention and self.batch_size < block_tables.shape[0]:
-            block_tables = block_tables[: self.batch_size]
-
-        if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-            attention_mask = attention_mask[: self.batch_size]
-
-        logits = super().forward(
-            inputs,
-            cache_position,
-            block_tables,
-            local_block_tables,
-            position_embed,
-            attention_mask if self.rbln_config.use_attention_mask else None,
-            position_ids if self.rbln_config.use_position_ids else None,
-        )
-
-        return RBLNDecoderOnlyForCausalLMOutput(logits=logits)
-
-    def _prepare_prefill_inputs(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-    ):
-        """
-        Prepare inputs for prefill phase.
-        """
-        # Handle continuous batching in a compiled graph by extracting valid inputs
-        # If an attention mask is provided, select only the valid (non-masked) inputs
-        inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
-        if position_embed is not None:
-            position_embed = (
-                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-            )
-        if token_type_ids is not None:
-            token_type_ids = token_type_ids[:, attention_mask.bool()] if attention_mask is not None else token_type_ids
-
-        query_length = inputs.shape[1]
-        if query_length > self.rbln_config.max_seq_len:
-            raise ValueError(
-                f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
-            )
-
-        # Initialize attention mask for chunked processing
-        chunked_attention_mask = (
-            torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-            if self.rbln_config.use_attention_mask
-            else None
-        )
-
-        # Buffer for storing output logits
-        out_buffers = [
-            torch.empty(
-                size=self.output_size,
-                dtype=torch.float32,
-                device="cpu",
-            )
-        ]
-
-        # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-        padding_size = 0
-        if query_length % self.rbln_config.prefill_chunk_size != 0:
-            padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
-            # inputs_embeds
-            if inputs.dim() == 3:
-                inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-            # inputs_ids
-            else:
-                inputs = torch.nn.functional.pad(inputs, (0, padding_size))
-
-            cache_position = torch.cat(
-                [
-                    cache_position,
-                    torch.arange(
-                        query_length,
-                        query_length + padding_size,
-                        dtype=torch.int32,
-                    ).unsqueeze(0),
-                ],
-                dim=-1,
-            )
-
-            if position_embed is not None:
-                position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
-
-            if token_type_ids is not None:
-                token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
-
-        # Overwrite position_ids and padded_cache_lengths
-        position_ids = cache_position.clone()
-        padded_cache_lengths = 0
-
-        return (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        )
-
-    def prefill_forward(
-        self,
-        inputs: torch.Tensor,
-        cache_position: torch.Tensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        batch_idx: int = None,
-        block_tables: torch.Tensor = None,
-        is_external_block_tables: bool = False,
-        position_embed: Optional[torch.Tensor] = None,
-        token_type_ids: Optional[torch.Tensor] = None,
-        local_block_tables: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
-        """
-        Performs chunked prefill for efficient KV-cache updates and memory optimization.
-        Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-        and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
-        """
-        (
-            inputs,
-            cache_position,
-            chunked_attention_mask,
-            out_buffers,
-            position_ids,
-            position_embed,
-            padded_cache_lengths,
-            query_length,
-            token_type_ids,
-        ) = self._prepare_prefill_inputs(
-            inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
-        )
-
-        # Process input in chunks of size `prefill_chunk_size`
-        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
-            # Extract the current chunk of inputs and cache positions
-            input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-            position_ids_chunk = (
-                position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
-                if position_ids is not None
-                else None
-            )
-            if position_embed is not None:
-                position_embed_chunk = position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]
-
-            if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
-                # Update attention mask to ensure proper causal behavior
-                if step >= self.rbln_config.prefill_chunk_size:
-                    chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
-                chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
-
-            # Define query position
-            if step + self.rbln_config.prefill_chunk_size >= query_length:
-                query_position = torch.tensor(
-                    (query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16
-                )
-            else:
-                query_position = torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)
-
-            # Forward pass for the current chunk
-            logits = super().forward(
-                input_chunk,
-                cache_pos_chunk,
-                block_tables,
-                local_block_tables,
-                position_embed_chunk if position_embed is not None else None,
-                query_position,
-                chunked_attention_mask if self.rbln_config.use_attention_mask else None,
-                position_ids_chunk if self.rbln_config.use_position_ids else None,
-                out=out_buffers,
-            )
-
-        # Update decoder attention mask with processed KV-cache length from prefill phase
-        if not is_external_block_tables and self.rbln_config.use_attention_mask:
-            self.dec_attn_mask[batch_idx].fill_(0)
-            self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
-
-        return RBLNDecoderOnlyForCausalLMOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)
-
-
-@dataclass
-class RBLNDecoderOnlyForCausalLMOutput(ModelOutput):
-    logits: torch.FloatTensor = None
-    generate_idx: torch.Tensor = None
-    padded_cache_lengths: int = None
-
-
 class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     """
     A base class for decoder-only transformer models outputting raw hidden-states without any specific head on top.
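The `get_block_tables` logic removed above is the paged KV-cache bookkeeping that now lives behind `RBLNPageTableManager` in `decoderonly_runtime_utils.py`: a fixed pool of cache blocks is handed out from a free list and recorded per batch row, with -1 marking an unallocated slot. A self-contained, simplified sketch of that allocation pattern (not the actual implementation):

```python
from collections import deque

import torch


class BlockTableSketch:
    """Simplified paged KV-cache bookkeeping; -1 marks an unallocated slot."""

    def __init__(self, batch_size: int, max_seq_len: int, block_size: int, num_blocks: int):
        self.block_size = block_size
        self.block_tables = torch.full((batch_size, max_seq_len // block_size), -1, dtype=torch.int16)
        self.free_blocks = deque(range(num_blocks))

    def allocate_prefill(self, batch_idx: int, seq_len: int) -> torch.Tensor:
        # Return this row's previously used blocks to the pool, then allocate fresh ones.
        used = self.block_tables[batch_idx][self.block_tables[batch_idx] != -1].tolist()
        self.free_blocks.extend(used)
        self.block_tables[batch_idx].fill_(-1)
        for block_idx in range((seq_len + self.block_size - 1) // self.block_size):
            if not self.free_blocks:
                raise RuntimeError("No free KV-cache blocks left")
            self.block_tables[batch_idx, block_idx] = self.free_blocks.popleft()
        return self.block_tables[batch_idx]


tables = BlockTableSketch(batch_size=2, max_seq_len=4096, block_size=1024, num_blocks=8)
print(tables.allocate_prefill(batch_idx=0, seq_len=2500))  # three blocks allocated, last slot stays -1
```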
@@ -495,18 +81,116 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         else:
             self.embed_tokens = None
 
-
-
-
-        #
-
-
-
-
-
-
-
+        self.setup_runtime()
+
+    def setup_runtime(self):
+        # Initialize resources to be used across Runtime instances (prefill and decode phases)
+        page_table_manager = RBLNPageTableManager(self.rbln_config)
+        dec_attn_mask = torch.zeros(
+            self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
+        )
+        out_buffers = [torch.empty(self.prefill_output_size, dtype=torch.float32, device="cpu")]
+
+        common_kwargs = {
+            "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+            "embed_tokens": self.embed_tokens,
+            "dec_attn_mask": dec_attn_mask,
+            "page_table_manager": page_table_manager,
+            "rbln_config": self.rbln_config,
+        }
+        self.prefill_decoder = RBLNRuntimeModel(
+            runtime=self.model[0],
+            phase="prefill",
+            batch_size=self.rbln_config.batch_size,
+            out_buffers=out_buffers,
+            **common_kwargs,
+        )
+        if self.can_generate():
+            self.decoders = {}
+            for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
+                self.decoders[batch_size] = RBLNRuntimeModel(
+                    runtime=self.model[i + 1],
+                    phase="decode",
+                    batch_size=batch_size,
+                    **common_kwargs,
+                )
+
+            # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
+            self.decoder = self.decoders[self.rbln_config.batch_size]
+
+    @property
+    def prefill_output_size(self):
+        return (
+            1,
+            self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
+            self.config.hidden_size,
+        )
+
+    @classmethod
+    def get_quantized_model(
+        cls,
+        model_id: str,
+        config: Optional[PretrainedConfig] = None,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+        trust_remote_code: bool = False,
+        rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
+        **kwargs,
+    ):
+        kwargs = cls.update_kwargs(kwargs)
+
+        if config is None:
+            config = AutoConfig.from_pretrained(
+                model_id,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                force_download=force_download,
+                cache_dir=cache_dir,
+                trust_remote_code=trust_remote_code,
+                **kwargs,
             )
+        if config.torch_dtype == torch.bfloat16:
+            # FIXME: bfloat16 is not supported by rebel-compiler
+            config.torch_dtype = torch.float32
+
+        with no_init_weights():
+            model = cls.auto_model_class.from_config(config)
+
+        model = prepare_model_for_quantization(
+            model,
+            model_id,
+            kwargs.get("num_hidden_layers"),
+            use_auth_token=use_auth_token,
+            revision=revision,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            local_files_only=local_files_only,
+            rbln_quantization=rbln_config.quantization,
+        )
+        return model
+
+    def __getattr__(self, __name: str) -> Any:
+        # Special method to delegate attribute access to the original Huggingface LM class.
+        # This method is called when an attribute is not found in the current instance's dictionary.
+        # It enables transparent access to the original model's attributes and methods while maintaining
+        # proper method binding.
+
+        # The method implements a delegation pattern that:
+
+        # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
+        # 2. For other attributes: Returns them directly from the original class
+
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+        return val
 
     @classmethod
     def save_torch_artifacts(
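The new `setup_runtime` builds one prefill runtime plus a dictionary of decode runtimes keyed by the compiled batch sizes, with `self.decoder` defaulting to the entry that matches the model's main batch size. A stripped-down sketch of that mapping (plain strings stand in for the `RBLNRuntimeModel` instances; the batch sizes are assumed values):

```python
# Illustrative stand-in for the runtimes built in setup_runtime().
batch_size = 2
decoder_batch_sizes = [2, 1]  # assumed rbln_config.decoder_batch_sizes
model = ["prefill-graph", "decode-bs2", "decode-bs1"]  # mirrors self.model[0] and self.model[1:]

prefill_decoder = model[0]
decoders = {bs: model[i + 1] for i, bs in enumerate(decoder_batch_sizes)}
decoder = decoders[batch_size]  # default decoder, as in the diff
print(prefill_decoder, decoder)  # prefill-graph decode-bs2
```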
@@ -532,6 +216,14 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         )
         return embed_tokens
 
+    def get_decoder(self):
+        if not self.can_generate():
+            raise ValueError("Decode stage is not supported in this model.")
+        return self.decoder
+
+    def can_generate(self):
+        return self.rbln_config.can_generate
+
     def get_input_embeddings(self):
         return self.embed_tokens
 
@@ -543,20 +235,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
 
     @classmethod
     def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
-        wrapper_cfg = {
-            "max_seq_len": rbln_config.max_seq_len,
-            "attn_impl": rbln_config.attn_impl,
-            "kvcache_partition_len": rbln_config.kvcache_partition_len,
-            "kvcache_block_size": rbln_config.kvcache_block_size,
-            "use_rotary_emb": cls._use_rotary_emb,
-            "use_attention_mask": rbln_config.use_attention_mask,
-            "use_position_ids": rbln_config.use_position_ids,
-            "use_inputs_embeds": rbln_config.use_inputs_embeds,
-            "cache_impl": rbln_config.cache_impl,
-            "sliding_window": rbln_config.sliding_window,
-            "sliding_window_layers": rbln_config.sliding_window_layers,
-        }
-        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+        return cls._decoder_wrapper_cls(model, rbln_config, cls._use_rotary_emb).eval()
 
     @classmethod
     def _compile_model(
@@ -608,38 +287,58 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
 
     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(
-        cls,
-        model: PreTrainedModel,
-        rbln_config: RBLNDecoderOnlyModelConfig,
-    ):
+    def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
-
+        prefill_compile_config = rbln_config.compile_cfgs[0]
 
         # Here we use meta tensor, for the memory efficiency.
-        meta_tensor_names = [name for name, _, _ in
-
-            context,
+        meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
+        prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+        context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
 
-
-
+        compiled_models = {}
+        compiled_models["prefill"] = cls._compile_model(
+            wrapped_model,
+            prefill_compile_config,
+            prefill_example_inputs,
+            context,
+            rbln_config,
+            rbln_config.quantization,
+            phase="prefill",
         )
-        compiled_models = {"prefill": compiled_model}
 
-
+        if rbln_config.can_generate:
+            wrapped_model.phase = "decode"
+            for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
+                dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+                compiled_decoder = cls._compile_model(
+                    wrapped_model,
+                    dec_compile_config,
+                    dec_example_inputs,
+                    context,
+                    rbln_config,
+                    rbln_config.quantization,
+                    phase="decode",
+                )
+                compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
 
-
-
-
-
-
+        # check if the memory is enough to have additional blocks
+        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+        if rbln_config.kvcache_num_blocks < required_num_blocks:
+            cls.maybe_suggest_kvcache_num_blocks(
+                compiled_models=compiled_models,
+                model_config=model.config,
+                rbln_config=rbln_config,
+            )
+
+        return compiled_models
 
     @classmethod
     def get_pytorch_model(
         cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
     ) -> PreTrainedModel:
         if rbln_config and rbln_config.quantization:
-            model = cls.get_quantized_model(*args, **kwargs)
+            model = cls.get_quantized_model(*args, rbln_config=rbln_config, **kwargs)
         else:
             model = super().get_pytorch_model(*args, **kwargs)
 
@@ -664,48 +363,40 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
         is_prefill = query_length > 1
 
-
+        input_info = []
         if rbln_config.use_inputs_embeds:
-
+            input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], "float32"))
         else:
-
-
-
-        input_info = [
-            main_input,
-            (
-                "cache_position",
-                [batch_size, query_length],
-                "int32",
-            ),
-        ]
+            input_info.append(("input_ids", [batch_size, query_length], "int64"))
+
+        input_info.append(("cache_position", [batch_size, query_length], "int32"))
 
-        # 3. block_tables
         if rbln_config.use_global_attention:
             max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
-            input_info.
-
+            input_info.append(
+                ("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")
            )
         if rbln_config.use_local_attention:
-            input_info.
+            input_info.append(("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16"))
 
-        # 4. query_position for sliding window attention
         if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
-            input_info.
+            input_info.append(("query_position", [], "int16"))
 
-        # 5. attention_mask & position_ids
         if rbln_config.use_attention_mask:
-
-            [
-
-
-
-
-
+            if rbln_config.use_position_ids:
+                input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], "float32"))
+            else:
+                input_info.append(
+                    ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], "float32")
+                )
+
         if rbln_config.use_position_ids:
             input_info.append(("position_ids", [batch_size, query_length], "int32"))
 
-
+        kvcache_dtype = "float32"
+        if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
+            kvcache_dtype = "float8_e4m3fn"
+
         global_kvcache_shape = [
             rbln_config.kvcache_num_blocks,
             num_key_value_heads,
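`get_input_info` now appends its entries one by one: token ids or embeddings, `cache_position`, the global and/or local block tables, an optional scalar `query_position`, an attention mask that is 2-D when `use_position_ids` is set and 4-D otherwise, `position_ids`, and finally per-layer KV-cache shapes whose dtype switches to `float8_e4m3fn` when fp8 KV-cache quantization is configured. A reduced sketch of the resulting (name, shape, dtype) list for a prefill graph, with assumed example dimensions:

```python
# Assumed example dimensions; the real values come from the model config and rbln_config.
batch_size, query_length, max_seq_len, block_size = 1, 128, 4096, 1024

input_info = []
input_info.append(("input_ids", [batch_size, query_length], "int64"))
input_info.append(("cache_position", [batch_size, query_length], "int32"))
max_block_cnt = max_seq_len // block_size
input_info.append(("block_tables", [max_block_cnt], "int16"))  # prefill uses a 1-D table
input_info.append(("query_position", [], "int16"))             # only when query_position is used
kvcache_dtype = "float32"                                       # "float8_e4m3fn" for fp8 KV-caches

for name, shape, dtype in input_info:
    print(f"{name:16s}{shape!s:12s}{dtype}")
```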
@@ -720,7 +411,7 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
                     local_kvcache_shape
                     if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
                     else global_kvcache_shape,
-
+                    kvcache_dtype,
                 )
                 for i in range(num_hidden_layers * 2)
             ]
@@ -784,15 +475,62 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
             max_seq_len=rbln_config.max_seq_len,
         )
 
-
-        rbln_config.kvcache_num_blocks = (
-            rbln_config.max_seq_len // rbln_config.kvcache_block_size
-        ) * rbln_config.batch_size
+        num_full_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
 
-
+        # Update kvcache_num_blocks based on the attention implementation.
+        if rbln_config.attn_impl == "flash_attn":
+            estimated_max_num_blocks = cls.get_maximum_num_blocks(
+                config=model_config,
+                tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
+                kvcache_block_size=rbln_config.kvcache_block_size,
+                nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
+                n_model_params=sum(p.numel() for p in model.parameters()),
+                num_runtimes=1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
+            )
 
-
-
+            if rbln_config.kvcache_num_blocks is None:
+                if estimated_max_num_blocks < num_full_blocks:
+                    # lower bound of the number of blocks for flash attention.
+                    min_blocks_for_flash = min(
+                        rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1, num_full_blocks
+                    )
+                    if min_blocks_for_flash > estimated_max_num_blocks:
+                        # NOTE: Just try to compile with lower bound of blocks for flash attention.
+                        # Even if it's larger than the estimated maximum number of blocks.
+                        rbln_config.kvcache_num_blocks = min_blocks_for_flash
+                    else:
+                        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
+                        rbln_config.kvcache_num_blocks = estimated_max_num_blocks
+
+                    if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
+                        raise RuntimeError(
+                            f"Batch size ({rbln_config.batch_size}) exceeds num_blocks ({rbln_config.kvcache_num_blocks}). "
+                            "Ensure the number of blocks is at least equal to the batch size."
+                        )
+                else:
+                    rbln_config.kvcache_num_blocks = num_full_blocks
+            elif rbln_config.kvcache_num_blocks > estimated_max_num_blocks:
+                logger.warning(
+                    f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                    f" than the estimated maximum number of blocks ({estimated_max_num_blocks})."
+                    "This can cause a failure during model compilation."
+                )
+        else:
+            if rbln_config.kvcache_num_blocks is None:
+                rbln_config.kvcache_num_blocks = num_full_blocks
+            elif rbln_config.kvcache_num_blocks > num_full_blocks:
+                logger.warning(
+                    f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                    f" than the required number of blocks ({num_full_blocks})."
+                    "This can cause a failure during model compilation."
+                )
+
+        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
+
+        return rbln_config
+
+    @classmethod
+    def _update_rbln_config(
         cls,
         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
         model: Optional[PreTrainedModel] = None,
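The rewritten block-count logic first computes `num_full_blocks`, enough blocks for every sequence in the batch to reach `max_seq_len`, and for flash attention only falls back to a lower bound (one block more than a single full sequence) when the memory-based estimate cannot cover the full amount. A worked example with assumed numbers:

```python
# Assumed configuration, for illustration only.
max_seq_len, kvcache_block_size, batch_size = 8192, 1024, 4

num_full_blocks = (max_seq_len // kvcache_block_size) * batch_size                  # 8 * 4 = 32
min_blocks_for_flash = min(max_seq_len // kvcache_block_size + 1, num_full_blocks)  # min(9, 32) = 9

estimated_max_num_blocks = 20  # assumed result of get_maximum_num_blocks()
if estimated_max_num_blocks < num_full_blocks:
    # Same branch structure as the diff: prefer the estimate, but never go below the lower bound.
    kvcache_num_blocks = (
        min_blocks_for_flash if min_blocks_for_flash > estimated_max_num_blocks else estimated_max_num_blocks
    )
else:
    kvcache_num_blocks = num_full_blocks

print(num_full_blocks, min_blocks_for_flash, kvcache_num_blocks)  # 32 9 20
```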
@@ -823,7 +561,20 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         )
 
         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
-
+        compile_cfgs = [prefill_compile_config]
+
+        if rbln_config.can_generate:
+            for batch_size in rbln_config.decoder_batch_sizes:
+                dec_input_info = cls.get_input_info(
+                    batch_size=batch_size,
+                    query_length=1,
+                    rbln_config=rbln_config,
+                    model_config=model_config,
+                )
+                compile_cfgs.append(
+                    RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
+                )
+        rbln_config.set_compile_cfgs(compile_cfgs)
 
         return rbln_config
 
@@ -833,128 +584,37 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
         compiled_models: List[rebel.RBLNCompiledModel],
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ) -> List[rebel.Runtime]:
-        expected_model_names = [
-
-
+        expected_model_names = ["prefill"]
+        if rbln_config.can_generate:
+            expected_model_names.extend(
+                [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
+            )
         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
             cls._raise_missing_compiled_file_error(expected_model_names)
 
-
+        ret_val = [
            rebel.Runtime(
                compiled_models[0],
                tensor_type="pt",
                device=rbln_config.device_map["prefill"],
                activate_profiler=rbln_config.activate_profiler,
-
-        ]
-
-    def _preprocess_chunked_prefill(
-        self,
-        inputs: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-    ):
-        # valid sequence length of inputs_embeds
-        query_length = inputs.shape[1] if attention_mask is None else torch.sum(attention_mask.view(-1)).item()
-
-        # extract valid inputs
-        inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
-
-        if inputs.dim() == 2 and self.rbln_config.use_inputs_embeds:
-            inputs = self.get_input_embeddings()(inputs)
-
-        if position_embed is not None:
-            position_embed = (
-                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-            )
-
-        # padding for chunked prefill
-        padding_size = (
-            self.rbln_config.prefill_chunk_size - (query_length % self.rbln_config.prefill_chunk_size)
-        ) % self.rbln_config.prefill_chunk_size
-        padded_len = query_length + padding_size
-
-        inputs = (
-            torch.nn.functional.pad(inputs, (0, padding_size))
-            if not self.rbln_config.use_inputs_embeds
-            else torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-        )
-        position_embed = (
-            None if position_embed is None else torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
-        )
-        cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)
-
-        chunked_attention_mask = (
-            torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-            if self.rbln_config.use_attention_mask
-            else None
-        )
-
-        return inputs, position_embed, cache_position, query_length, chunked_attention_mask
-
-    def _chunked_prefill_forward(
-        self,
-        inputs: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_embed: Optional[torch.Tensor] = None,
-    ):
-        padded_input, padded_position_embed, cache_position, query_length, chunked_attention_mask = (
-            self._preprocess_chunked_prefill(inputs, attention_mask, position_embed)
-        )
-
-        # chunked prefill
-        last_hidden_states = []
-        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
-            # Extract the current chunk of inputs and cache positions
-            input_chunk = padded_input[:, step : step + self.rbln_config.prefill_chunk_size]
-            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-
-            valid_length = (
-                self.rbln_config.prefill_chunk_size
-                if (step + self.rbln_config.prefill_chunk_size) <= query_length
-                else query_length - step
-            )
-            if self.rbln_config.use_local_attention:
-                query_position = torch.tensor(valid_length - 1, dtype=torch.int16)
-            else:
-                query_position = None
-
-            if self.rbln_config.use_attention_mask:
-                if step > 0:
-                    chunked_attention_mask[:, :, :, :step] = 1
-                chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
-
-            # Forward pass for the current chunk
-            last_hidden_states_chunk = self.prefill_decoder(
-                input_ids=input_chunk if not self.rbln_config.use_inputs_embeds else None,
-                inputs_embeds=input_chunk if self.rbln_config.use_inputs_embeds else None,
-                cache_position=cache_pos_chunk,
-                block_tables=self.block_tables if self.rbln_config.use_global_attention else None,
-                local_block_tables=self.local_block_tables if self.rbln_config.use_local_attention else None,
-                query_position=query_position,
-                attention_mask=chunked_attention_mask,
-                position_emb=padded_position_embed,
+                timeout=rbln_config.timeout,
            )
-
-
-
-
-
-
-
-
-
-
-
-
-
-                dtype=last_hidden_states.dtype,
+        ]
+        if rbln_config.can_generate:
+            ret_val.extend(
+                [
+                    rebel.Runtime(
+                        compiled_models[i + 1],
+                        tensor_type="pt",
+                        device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
+                        activate_profiler=rbln_config.activate_profiler,
+                        timeout=rbln_config.timeout,
+                    )
+                    for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
+                ]
            )
-
-            new_last_hidden_states.index_copy_(dim=-2, index=mask_indices, source=last_hidden_states)
-        else:
-            new_last_hidden_states = last_hidden_states
-        return new_last_hidden_states
+        return ret_val
 
     def forward(
         self,
@@ -966,20 +626,32 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     ) -> Tuple[torch.FloatTensor]:
         inputs = inputs_embeds if inputs_embeds is not None else input_ids
         batch_size = inputs.shape[0]
+
+        if batch_size != self.rbln_config.batch_size:
+            raise ValueError(
+                f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
+            )
+
         all_last_hidden_states = []
-        for b_idx in range(batch_size):
-
-
-                attention_mask[b_idx] if attention_mask is not None else None,
-                position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
+        for b_idx in range(self.rbln_config.batch_size):
+            query_length = (
+                attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
            )
+            cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
+            last_hidden_states = self.prefill_decoder(
+                inputs[b_idx : b_idx + 1],
+                attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
+                cache_position=cache_position,
+                batch_idx=b_idx,
+            ).logits
             all_last_hidden_states.append(last_hidden_states)
 
         last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
         return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
 
 
-class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel):
+class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
     """
     A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
     This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
@@ -1002,380 +674,18 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel):
 
     auto_model_class = AutoModelForCausalLM
 
-
-
-
-
-
-
-            self.embed_tokens = self._create_embedding_layer()
-            self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
-        else:
-            self.embed_tokens = None
-
-        # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
-        dec_attn_mask = torch.zeros(
-            self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
-        )
-        block_tables = torch.zeros(
-            self.rbln_config.batch_size,
-            self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
-            dtype=torch.int16,
-        ).fill_(-1)
-        free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
-
-        self.prefill_decoder = RBLNRuntimeModel(
-            runtime=self.model[0],
-            main_input_name=main_input_name,
-            embed_tokens=self.embed_tokens,
-            phase="prefill",
-            batch_size=self.rbln_config.batch_size,
-            dec_attn_mask=dec_attn_mask,
-            block_tables=block_tables,
-            free_block_pool=free_block_pool,
-            rbln_config=self.rbln_config,
-            vocab_size=self.config.vocab_size,
-        )
-
-        if self.can_generate():
-            self.decoders = {}
-            for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
-                self.decoders[batch_size] = RBLNRuntimeModel(
-                    runtime=self.model[i + 1],
-                    main_input_name=main_input_name,
-                    embed_tokens=self.embed_tokens,
-                    phase="decode",
-                    batch_size=batch_size,
-                    dec_attn_mask=dec_attn_mask,
-                    block_tables=block_tables,
-                    free_block_pool=free_block_pool,
-                    rbln_config=self.rbln_config,
-                )
-
-            # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
-            self.decoder = self.decoders[self.rbln_config.batch_size]
-
-    @classmethod
-    def get_quantized_model(
-        cls,
-        model_id: str,
-        config: Optional[PretrainedConfig] = None,
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        **kwargs,
-    ):
-        kwargs = cls.update_kwargs(kwargs)
-
-        if config is None:
-            config = AutoConfig.from_pretrained(
-                model_id,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                force_download=force_download,
-                cache_dir=cache_dir,
-                trust_remote_code=trust_remote_code,
-                **kwargs,
-            )
-
-        with no_init_weights():
-            model = AutoModelForCausalLM.from_config(config)
-
-        model = prepare_model_for_quantization(
-            model,
-            model_id,
-            kwargs.get("num_hidden_layers"),
-            use_auth_token=use_auth_token,
-            revision=revision,
-            cache_dir=cache_dir,
-            force_download=force_download,
-            local_files_only=local_files_only,
-        )
-        return model
-
-    def __getattr__(self, __name: str) -> Any:
-        # Special method to delegate attribute access to the original Huggingface LM class.
-        # This method is called when an attribute is not found in the current instance's dictionary.
-        # It enables transparent access to the original model's attributes and methods while maintaining
-        # proper method binding.
-
-        # The method implements a delegation pattern that:
-
-        # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
-        # 2. For other attributes: Returns them directly from the original class
-
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
-
-    @classmethod
-    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
-        wrapper_cfg = {
-            "max_seq_len": rbln_config.max_seq_len,
-            "attn_impl": rbln_config.attn_impl,
-            "kvcache_partition_len": rbln_config.kvcache_partition_len,
-            "kvcache_block_size": rbln_config.kvcache_block_size,
-            "use_rotary_emb": cls._use_rotary_emb,
-            "use_attention_mask": rbln_config.use_attention_mask,
-            "use_position_ids": rbln_config.use_position_ids,
-            "use_inputs_embeds": rbln_config.use_inputs_embeds,
-            "cache_impl": rbln_config.cache_impl,
-            "sliding_window": rbln_config.sliding_window,
-            "sliding_window_layers": rbln_config.sliding_window_layers,
-        }
-        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
-
-    @classmethod
-    @torch.inference_mode()
-    def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
-        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
-        prefill_compile_config = rbln_config.compile_cfgs[0]
-
-        # Here we use meta tensor, for the memory efficiency.
-        meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
-        prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names
prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
|
|
1145
|
-
context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
|
|
1146
|
-
|
|
1147
|
-
compiled_models = {}
|
|
1148
|
-
compiled_models["prefill"] = cls._compile_model(
|
|
1149
|
-
wrapped_model,
|
|
1150
|
-
prefill_compile_config,
|
|
1151
|
-
prefill_example_inputs,
|
|
1152
|
-
context,
|
|
1153
|
-
rbln_config,
|
|
1154
|
-
rbln_config.quantization,
|
|
1155
|
-
phase="prefill",
|
|
677
|
+
@property
|
|
678
|
+
def prefill_output_size(self):
|
|
679
|
+
return (
|
|
680
|
+
1,
|
|
681
|
+
self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
|
|
682
|
+
self.config.vocab_size,
|
|
1156
683
|
)
|
|
1157
684
|
|
|
1158
|
-
if rbln_config.can_generate:
|
|
1159
|
-
wrapped_model.phase = "decode"
|
|
1160
|
-
for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
|
|
1161
|
-
dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
|
|
1162
|
-
compiled_decoder = cls._compile_model(
|
|
1163
|
-
wrapped_model,
|
|
1164
|
-
dec_compile_config,
|
|
1165
|
-
dec_example_inputs,
|
|
1166
|
-
context,
|
|
1167
|
-
rbln_config,
|
|
1168
|
-
rbln_config.quantization,
|
|
1169
|
-
phase="decode",
|
|
1170
|
-
)
|
|
1171
|
-
compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
|
|
1172
|
-
|
|
1173
|
-
# check if the memory is enough to have additional blocks
|
|
1174
|
-
required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
|
|
1175
|
-
if rbln_config.kvcache_num_blocks < required_num_blocks:
|
|
1176
|
-
cls.maybe_suggest_kvcache_num_blocks(
|
|
1177
|
-
compiled_models=compiled_models,
|
|
1178
|
-
model_config=model.config,
|
|
1179
|
-
rbln_config=rbln_config,
|
|
1180
|
-
)
|
|
1181
|
-
|
|
1182
|
-
return compiled_models
|
|
1183
|
-
|
|
1184
685
|
@classmethod
|
|
1185
686
|
def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
|
|
1186
687
|
return is_prefill
|
|
1187
688
|
|
|
1188
|
-
@classmethod
|
|
1189
|
-
def _update_attention_config(
|
|
1190
|
-
cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
|
|
1191
|
-
):
|
|
1192
|
-
rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
|
|
1193
|
-
attn_impl=rbln_config.attn_impl,
|
|
1194
|
-
kvcache_partition_len=rbln_config.kvcache_partition_len,
|
|
1195
|
-
kvcache_block_size=rbln_config.kvcache_block_size,
|
|
1196
|
-
max_seq_len=rbln_config.max_seq_len,
|
|
1197
|
-
)
|
|
1198
|
-
|
|
1199
|
-
validate_attention_method(
|
|
1200
|
-
attn_impl=rbln_config.attn_impl,
|
|
1201
|
-
kvcache_partition_len=rbln_config.kvcache_partition_len,
|
|
1202
|
-
kvcache_block_size=rbln_config.kvcache_block_size,
|
|
1203
|
-
max_seq_len=rbln_config.max_seq_len,
|
|
1204
|
-
)
|
|
1205
|
-
|
|
1206
|
-
required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
|
|
1207
|
-
max_num_blocks = required_num_blocks
|
|
1208
|
-
|
|
1209
|
-
if rbln_config.attn_impl == "flash_attn":
|
|
1210
|
-
estimated_max_num_blocks = cls.get_maximum_num_blocks(
|
|
1211
|
-
config=model_config,
|
|
1212
|
-
tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
|
|
1213
|
-
kvcache_block_size=rbln_config.kvcache_block_size,
|
|
1214
|
-
nbits_per_param=16 if not rbln_config.quantization else 4, # TODO(jongho): FIX Ad-hoc
|
|
1215
|
-
n_model_params=sum(p.numel() for p in model.parameters()),
|
|
1216
|
-
num_runtimes=1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
|
|
1217
|
-
)
|
|
1218
|
-
|
|
1219
|
-
max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
|
|
1220
|
-
|
|
1221
|
-
flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
|
|
1222
|
-
if rbln_config.batch_size > 1 and max_num_blocks < flash_min_blocks:
|
|
1223
|
-
max_num_blocks = flash_min_blocks
|
|
1224
|
-
|
|
1225
|
-
if max_num_blocks < rbln_config.batch_size:
|
|
1226
|
-
raise RuntimeError(
|
|
1227
|
-
f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
|
|
1228
|
-
"Ensure the number of blocks is at least equal to the batch size."
|
|
1229
|
-
)
|
|
1230
|
-
|
|
1231
|
-
if rbln_config.kvcache_num_blocks is None:
|
|
1232
|
-
rbln_config.kvcache_num_blocks = max_num_blocks
|
|
1233
|
-
elif rbln_config.kvcache_num_blocks > max_num_blocks:
|
|
1234
|
-
logger.warning(
|
|
1235
|
-
f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
|
|
1236
|
-
f" than the estimated maximum number of blocks ({max_num_blocks})."
|
|
1237
|
-
"This can cause a failure during model compilation."
|
|
1238
|
-
)
|
|
1239
|
-
logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
|
|
1240
|
-
|
|
1241
|
-
return rbln_config
|
|
1242
|
-
|
|
1243
|
-
@classmethod
|
|
1244
|
-
def _update_rbln_config(
|
|
1245
|
-
cls,
|
|
1246
|
-
preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
|
|
1247
|
-
model: Optional[PreTrainedModel] = None,
|
|
1248
|
-
model_config: Optional[PretrainedConfig] = None,
|
|
1249
|
-
rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
|
|
1250
|
-
) -> RBLNDecoderOnlyModelForCausalLMConfig:
|
|
1251
|
-
rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
|
|
1252
|
-
if rbln_config.can_generate:
|
|
1253
|
-
compile_configs = rbln_config.compile_cfgs
|
|
1254
|
-
for batch_size in rbln_config.decoder_batch_sizes:
|
|
1255
|
-
dec_input_info = cls.get_input_info(
|
|
1256
|
-
batch_size=batch_size,
|
|
1257
|
-
query_length=1,
|
|
1258
|
-
rbln_config=rbln_config,
|
|
1259
|
-
model_config=model_config,
|
|
1260
|
-
)
|
|
1261
|
-
compile_configs.append(
|
|
1262
|
-
RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
|
|
1263
|
-
)
|
|
1264
|
-
rbln_config.set_compile_cfgs(compile_configs)
|
|
1265
|
-
|
|
1266
|
-
return rbln_config
|
|
1267
|
-
|
|
1268
|
-
@classmethod
|
|
1269
|
-
def _create_runtimes(
|
|
1270
|
-
cls,
|
|
1271
|
-
compiled_models: List[rebel.RBLNCompiledModel],
|
|
1272
|
-
rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
|
|
1273
|
-
) -> List[rebel.Runtime]:
|
|
1274
|
-
expected_model_names = ["prefill"]
|
|
1275
|
-
if rbln_config.can_generate:
|
|
1276
|
-
expected_model_names.extend(
|
|
1277
|
-
[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
|
|
1278
|
-
)
|
|
1279
|
-
if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
|
|
1280
|
-
cls._raise_missing_compiled_file_error(expected_model_names)
|
|
1281
|
-
|
|
1282
|
-
ret_val = [
|
|
1283
|
-
rebel.Runtime(
|
|
1284
|
-
compiled_models[0],
|
|
1285
|
-
tensor_type="pt",
|
|
1286
|
-
device=rbln_config.device_map["prefill"],
|
|
1287
|
-
activate_profiler=rbln_config.activate_profiler,
|
|
1288
|
-
timeout=rbln_config.timeout,
|
|
1289
|
-
)
|
|
1290
|
-
]
|
|
1291
|
-
if rbln_config.can_generate:
|
|
1292
|
-
ret_val.extend(
|
|
1293
|
-
[
|
|
1294
|
-
rebel.Runtime(
|
|
1295
|
-
compiled_models[i + 1],
|
|
1296
|
-
tensor_type="pt",
|
|
1297
|
-
device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
|
|
1298
|
-
activate_profiler=rbln_config.activate_profiler,
|
|
1299
|
-
timeout=rbln_config.timeout,
|
|
1300
|
-
)
|
|
1301
|
-
for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
|
|
1302
|
-
]
|
|
1303
|
-
)
|
|
1304
|
-
return ret_val
|
|
1305
|
-
|
|
1306
|
-
def get_decoder(self):
|
|
1307
|
-
if not self.can_generate():
|
|
1308
|
-
raise ValueError("Decode stage is not supported in this model.")
|
|
1309
|
-
return self.decoder
|
|
1310
|
-
|
|
1311
|
-
def can_generate(self):
|
|
1312
|
-
return self.rbln_config.can_generate
|
|
1313
|
-
|
|
1314
|
-
def _reorder_cache(self, past_key_values, beam_idx):
|
|
1315
|
-
raise NotImplementedError
|
|
1316
|
-
|
|
1317
|
-
def prepare_inputs_for_generation(
|
|
1318
|
-
self,
|
|
1319
|
-
input_ids: torch.LongTensor,
|
|
1320
|
-
generate_idx: Optional[torch.Tensor] = None,
|
|
1321
|
-
attention_mask: Optional[torch.LongTensor] = None,
|
|
1322
|
-
inputs_embeds: Optional[torch.Tensor] = None,
|
|
1323
|
-
padded_cache_lengths: Optional[torch.Tensor] = None,
|
|
1324
|
-
**kwargs,
|
|
1325
|
-
):
|
|
1326
|
-
model_inputs = {}
|
|
1327
|
-
is_prefill_phase = generate_idx is None
|
|
1328
|
-
|
|
1329
|
-
if is_prefill_phase:
|
|
1330
|
-
generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
|
|
1331
|
-
padded_cache_lengths = torch.zeros_like(generate_idx)
|
|
1332
|
-
cache_position = None
|
|
1333
|
-
position_ids = None
|
|
1334
|
-
else:
|
|
1335
|
-
if inputs_embeds is not None:
|
|
1336
|
-
# if `inputs_embeds` are passed, only use them in the 1st generation step for every prompt.
|
|
1337
|
-
inputs_embeds = None
|
|
1338
|
-
|
|
1339
|
-
input_ids = input_ids[:, -1:]
|
|
1340
|
-
position_ids = generate_idx
|
|
1341
|
-
cache_position = generate_idx + padded_cache_lengths if padded_cache_lengths is not None else generate_idx
|
|
1342
|
-
generate_idx = generate_idx + 1
|
|
1343
|
-
model_inputs.update({"input_ids": input_ids})
|
|
1344
|
-
|
|
1345
|
-
if inputs_embeds is not None:
|
|
1346
|
-
if self.rbln_config.use_inputs_embeds:
|
|
1347
|
-
model_inputs.update({"inputs_embeds": inputs_embeds})
|
|
1348
|
-
else:
|
|
1349
|
-
raise ValueError(
|
|
1350
|
-
"The specifying inputs_embeds is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
|
|
1351
|
-
)
|
|
1352
|
-
else:
|
|
1353
|
-
model_inputs.update({"input_ids": input_ids})
|
|
1354
|
-
|
|
1355
|
-
model_inputs.update(
|
|
1356
|
-
{
|
|
1357
|
-
"attention_mask": attention_mask,
|
|
1358
|
-
"cache_position": cache_position,
|
|
1359
|
-
"generate_idx": generate_idx,
|
|
1360
|
-
"position_ids": position_ids,
|
|
1361
|
-
"padded_cache_lengths": padded_cache_lengths,
|
|
1362
|
-
}
|
|
1363
|
-
)
|
|
1364
|
-
|
|
1365
|
-
return model_inputs
|
|
1366
|
-
|
|
1367
|
-
def _update_model_kwargs_for_generation(
|
|
1368
|
-
self,
|
|
1369
|
-
outputs: RBLNDecoderOnlyForCausalLMOutput,
|
|
1370
|
-
model_kwargs: Dict[str, Any],
|
|
1371
|
-
**kwargs,
|
|
1372
|
-
) -> Dict[str, Any]:
|
|
1373
|
-
# update generate_idx
|
|
1374
|
-
model_kwargs["generate_idx"] = outputs.generate_idx
|
|
1375
|
-
model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths
|
|
1376
|
-
|
|
1377
|
-
return model_kwargs
|
|
1378
|
-
|
|
1379
689
|
def forward(
|
|
1380
690
|
self,
|
|
1381
691
|
input_ids: Optional[torch.LongTensor] = None,
|
|
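The new `prefill_output_size` property added in the hunk above fixes the logits shape the prefill graph is compiled for: batch dimension 1, then either the full `prefill_chunk_size` positions when `rbln_config.logits_to_keep == 0` or a single position otherwise, over the vocabulary. A small, hypothetical helper that mirrors that shape rule:

```python
# Hypothetical helper mirroring the prefill_output_size rule shown in the diff.
def prefill_output_size(prefill_chunk_size: int, logits_to_keep: int, vocab_size: int) -> tuple:
    # logits_to_keep == 0 -> keep logits for every position in the prefill chunk,
    # otherwise keep only the last position of the chunk.
    seq_len = prefill_chunk_size if logits_to_keep == 0 else 1
    return (1, seq_len, vocab_size)

print(prefill_output_size(128, 0, 32000))  # (1, 128, 32000)
print(prefill_output_size(128, 1, 32000))  # (1, 1, 32000)
```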
@@ -1403,7 +713,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel):
             )
             padded_cache_lengths = torch.zeros_like(generate_idx)
 
-        #
+        # Prefill
         if cache_position is None:
             logits = []
             inputs = inputs_embeds if inputs_embeds is not None else input_ids
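The renamed `# Prefill` comment marks the branch taken when `forward` is called without a `cache_position`: each prompt is prefilled to populate the KV cache and its logits are collected before decoding begins. A rough, illustrative sketch of that control flow (the `run_prefill` callable and its signature are assumptions, not the library's API):

```python
# Illustrative sketch of a per-sequence prefill loop; not the actual implementation.
import torch

def prefill_all(run_prefill, inputs: torch.Tensor, batch_size: int) -> torch.Tensor:
    logits = []
    for b in range(batch_size):
        # Prefill one sequence at a time so its KV-cache blocks are populated,
        # keeping only the logits needed to start decoding.
        logits.append(run_prefill(inputs[b : b + 1], batch_idx=b))
    return torch.cat(logits, dim=0)
```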
@@ -1441,6 +751,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel):
         if not return_dict:
             return logits, generate_idx, padded_cache_lengths
         else:
-            return
+            return RBLNDecoderOnlyOutput(
                 logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
             )
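The last hunk makes the `return_dict=True` path of `forward` return an `RBLNDecoderOnlyOutput` carrying the same three values as the tuple path. A short sketch of how a caller might consume both paths, assuming the output exposes exactly the fields named in the hunk:

```python
# Illustrative only: consuming the forward() results referenced in this hunk.
def next_token_from_forward(model, input_ids, attention_mask):
    out = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
    # RBLNDecoderOnlyOutput fields shown in the diff: logits, generate_idx, padded_cache_lengths
    next_token = out.logits[:, -1, :].argmax(dim=-1)
    return next_token, out.generate_idx, out.padded_cache_lengths

# With return_dict=False the same values come back as a plain tuple:
# logits, generate_idx, padded_cache_lengths = model(input_ids=..., attention_mask=..., return_dict=False)
```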