optimum-rbln 0.7.3.post2__py3-none-any.whl → 0.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +173 -35
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +816 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +62 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +52 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +56 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +74 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +236 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +289 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +111 -137
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +56 -71
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +44 -69
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +111 -114
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -1
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +2 -0
- optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +2 -0
- optimum/rbln/modeling.py +66 -40
- optimum/rbln/modeling_base.py +111 -86
- optimum/rbln/ops/__init__.py +4 -7
- optimum/rbln/ops/attn.py +271 -205
- optimum/rbln/ops/flash_attn.py +161 -67
- optimum/rbln/ops/kv_cache_update.py +4 -40
- optimum/rbln/ops/linear.py +25 -0
- optimum/rbln/transformers/__init__.py +97 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +120 -32
- optimum/rbln/transformers/models/auto/auto_factory.py +6 -6
- optimum/rbln/transformers/models/bart/__init__.py +2 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +12 -85
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +11 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +197 -178
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +343 -249
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +51 -0
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +459 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +18 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +99 -112
- optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +12 -21
- optimum/rbln/transformers/models/t5/__init__.py +2 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +21 -356
- optimum/rbln/transformers/models/t5/t5_architecture.py +10 -5
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +26 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +420 -0
- optimum/rbln/transformers/models/time_series_transformers/time_series_transformers_architecture.py +331 -0
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +2 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +135 -100
- optimum/rbln/transformers/models/whisper/whisper_architecture.py +73 -40
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/hub.py +2 -2
- optimum/rbln/utils/import_utils.py +23 -6
- optimum/rbln/utils/model_utils.py +4 -4
- optimum/rbln/utils/runtime_utils.py +33 -2
- optimum/rbln/utils/submodule.py +36 -44
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/METADATA +6 -6
- optimum_rbln-0.7.4.dist-info/RECORD +169 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.3.post2.dist-info/RECORD +0 -122
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.3.post2.dist-info → optimum_rbln-0.7.4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -26,13 +26,15 @@ from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, Pre
 from transformers.modeling_utils import no_init_weights
 from transformers.utils import ModelOutput
 
+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ...utils.rbln_quantization import QuantizationManager
+from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 from .decoderonly_architecture import (
     DecoderOnlyWrapper,
+    set_default_values,
     validate_attention_method,
 )
 
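The import swap above is the thread that runs through every hunk in this file: the dict-backed config from the removed `modeling_config` module is replaced by `RBLNCompileConfig` plus a typed per-model class, `RBLNDecoderOnlyModelForCausalLMConfig`. A minimal sketch of the access-pattern change (the sketch's class and field defaults are illustrative, not the package's actual definitions):

```python
# Illustrative sketch only; it contrasts the 0.7.3 dict-style lookup with the
# 0.7.4 attribute-style lookup that the hunks below migrate to.
from dataclasses import dataclass
from typing import Optional

# 0.7.3.post2: settings lived in a plain dict on the config object.
old_style = {"model_cfg": {"batch_size": 1, "max_seq_len": 4096}}
batch_size = old_style["model_cfg"]["batch_size"]


# 0.7.4: settings are typed attributes on a per-model config class.
@dataclass
class DecoderOnlyConfigSketch:
    batch_size: int = 1
    max_seq_len: Optional[int] = None
    attn_impl: Optional[str] = None


new_style = DecoderOnlyConfigSketch(max_seq_len=4096)
assert new_style.batch_size == batch_size
```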
@@ -161,6 +163,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -185,9 +188,12 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 block_tables,
                 is_external_block_tables,
                 attention_mask=attention_mask,
+                position_embed=position_embed,
             )
         else:
-            return self.prefill_forward(
+            return self.prefill_forward(
+                inputs, cache_position, attention_mask, batch_idx, block_tables, position_embed=position_embed
+            )
 
     def decode_forward(
         self,
@@ -196,6 +202,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
@@ -222,13 +229,12 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
 
             attention_mask = self.dec_attn_mask
 
-        attention_mask = self.dec_attn_mask
-
         logits = super().forward(
             inputs,
             cache_position,
             attention_mask if self.use_attention_mask else None,
             block_tables,
+            position_embed,
         )
 
         return logits
@@ -241,6 +247,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         batch_idx: int = None,
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
+        position_embed: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -251,6 +258,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
+        if position_embed is not None:
+            position_embed = (
+                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
+            )
 
         query_length = inputs.shape[1]
         if query_length > self.max_seq_len:
@@ -295,9 +306,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 dim=-1,
             )
 
+            if position_embed is not None:
+                position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+
         # Extract the current chunk of inputs and cache positions
         input_chunk = inputs[:, step : step + self.prefill_chunk_size]
         cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+        if position_embed is not None:
+            position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]
 
         if self.use_attention_mask:
             # Update attention mask to ensure proper causal behavior
@@ -315,6 +331,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 chunked_attention_mask if self.use_attention_mask else None,
                 query_position,
                 block_tables,
+                position_embed_chunk if position_embed is not None else None,
                 out=out_buffers,
             )
 
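Taken together, the hunks above thread a new `position_embed` tensor through the same continuous-batching path as the token inputs: masked selection on entry, right-padding of the sequence axis up to a chunk multiple, then per-chunk slicing inside the prefill loop. A standalone sketch of that handling (tensor shapes and the chunk size are made up for the demo, not taken from the package):

```python
import torch

prefill_chunk_size = 4
position_embed = torch.randn(2, 1, 1, 10, 8)  # (..., seq_len, dim), illustrative
attention_mask = torch.tensor([1, 1, 1, 1, 1, 1, 1, 0, 0, 0])

# Keep only valid (non-masked) positions, as in prefill_forward.
position_embed = position_embed[:, :, :, attention_mask.bool(), :]

# Right-pad the sequence axis to a multiple of the prefill chunk size.
padding_size = -position_embed.shape[-2] % prefill_chunk_size
position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))

# Slice one chunk per prefill step, matching the loop in the diff.
for step in range(0, position_embed.shape[-2], prefill_chunk_size):
    chunk = position_embed[:, :, :, step : step + prefill_chunk_size, :]
    print(step, tuple(chunk.shape))  # (2, 1, 1, 4, 8) per chunk
```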
@@ -358,17 +375,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     _use_rotary_emb = True
 
     def __post_init__(self, **kwargs):
-        self.batch_size = self.rbln_config.model_cfg["batch_size"]
-        self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
-        self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
-        self.kvcache_block_size = self.rbln_config.model_cfg["kvcache_block_size"]
-        # FIXME get kvcache_num_blocks from compiled results.
-        self.kvcache_num_blocks = self.rbln_config.model_cfg["kvcache_num_blocks"]
-        self.use_attention_mask = self.rbln_config.model_cfg["use_attention_mask"]
-        attn_impl = self.rbln_config.model_cfg["attn_impl"]
         main_input_name = self.main_input_name
 
-        if self.rbln_config.
+        if self.rbln_config.use_inputs_embeds:
             main_input_name = "inputs_embeds"
             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
             with no_init_weights():
@@ -382,40 +391,44 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             self.embed_tokens = None
 
         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
-        dec_attn_mask = torch.zeros(
+        dec_attn_mask = torch.zeros(
+            self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
+        )
         block_tables = torch.zeros(
-            self.batch_size,
+            self.rbln_config.batch_size,
+            self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
+            dtype=torch.int16,
         ).fill_(-1)
-        free_block_pool = deque(x for x in range(self.kvcache_num_blocks))
+        free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
 
         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             main_input_name=main_input_name,
             embed_tokens=self.embed_tokens,
             phase="prefill",
-            batch_size=self.batch_size,
+            batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
             free_block_pool=free_block_pool,
-            kvcache_block_size=self.kvcache_block_size,
+            kvcache_block_size=self.rbln_config.kvcache_block_size,
             vocab_size=self.config.vocab_size,
-            prefill_chunk_size=self.prefill_chunk_size,
-            max_seq_len=self.max_seq_len,
-            use_attention_mask=self.use_attention_mask,
-            attn_impl=attn_impl,
+            prefill_chunk_size=self.rbln_config.prefill_chunk_size,
+            max_seq_len=self.rbln_config.max_seq_len,
+            use_attention_mask=self.rbln_config.use_attention_mask,
+            attn_impl=self.rbln_config.attn_impl,
         )
         self.decoder = RBLNRuntimeModel(
             runtime=self.model[1],
             main_input_name=main_input_name,
            embed_tokens=self.embed_tokens,
             phase="decode",
-            batch_size=self.batch_size,
+            batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
             free_block_pool=free_block_pool,
-            kvcache_block_size=self.kvcache_block_size,
-            use_attention_mask=self.use_attention_mask,
-            attn_impl=attn_impl,
+            kvcache_block_size=self.rbln_config.kvcache_block_size,
+            use_attention_mask=self.rbln_config.use_attention_mask,
+            attn_impl=self.rbln_config.attn_impl,
         )
 
     @classmethod
@@ -424,13 +437,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         model: "PreTrainedModel",
         save_dir_path: Path,
         subfolder: str,
-        rbln_config:
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ):
         """
         If you are unavoidably running on a CPU rather than an RBLN device,
         store the torch tensor, weight, etc. in this function.
         """
-        if rbln_config.
+        if rbln_config.use_inputs_embeds:
             save_dict = {}
             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
             torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
@@ -438,6 +451,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def get_input_embeddings(self):
         return self.embed_tokens
 
+    def get_attn_impl(self) -> str:
+        return self.rbln_config.attn_impl
+
+    def get_kvcache_num_blocks(self) -> int:
+        return self.rbln_config.kvcache_num_blocks
+
     @classmethod
     def get_quantized_model(
         cls,
@@ -495,33 +514,35 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return val
 
     @classmethod
-    def get_pytorch_model(
-        [old lines not shown in the source view]
+    def get_pytorch_model(
+        cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
+    ) -> "PreTrainedModel":
+        if (
+            rbln_config is not None
+            and "format" in rbln_config.quantization
+            and rbln_config.quantization["format"] == "rbln"
+        ):
             model = cls.get_quantized_model(*args, **kwargs)
         else:
             model = super().get_pytorch_model(*args, **kwargs)
 
-        logger.debug("Loaded the LLM model to the CPU.")
         return model
 
     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "
-        wrapper_cfg = {
-        [old lines not shown in the source view]
+    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+        wrapper_cfg = {
+            "max_seq_len": rbln_config.max_seq_len,
+            "attn_impl": rbln_config.attn_impl,
+            "kvcache_partition_len": rbln_config.kvcache_partition_len,
+            "kvcache_block_size": rbln_config.kvcache_block_size,
+            "use_rotary_emb": cls._use_rotary_emb,
+            "use_attention_mask": rbln_config.use_attention_mask,
+        }
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
 
     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config:
+    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
 
         rbln_compile_configs = rbln_config.compile_cfgs
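The rewritten `get_pytorch_model` gates the quantized-checkpoint path on the config's `quantization` mapping. Restated as a hypothetical standalone predicate (a plain dict stands in for the real config object):

```python
from typing import Optional


def should_load_rbln_quantized(quantization: Optional[dict]) -> bool:
    # Mirrors the membership test in the diff: route through get_quantized_model
    # only when a quantization format is set and it is "rbln".
    return quantization is not None and quantization.get("format") == "rbln"


print(should_load_rbln_quantized({"format": "rbln"}))  # True
print(should_load_rbln_quantized({"format": "gptq"}))  # False
print(should_load_rbln_quantized(None))                # False
```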
@@ -543,28 +564,81 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
         dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
 
-        quantize_config = rbln_config.model_cfg.get("quantization", None)
-
         @QuantizationManager.with_quantization_env
         def compile_model(*args, **kwargs):
-            [old lines not shown in the source view]
+            try:
+                original_linear = torch.nn.functional.linear
+                torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
+                wrapped_model.phase = "prefill"
+                compiled_prefill = RBLNModel.compile(
+                    wrapped_model,
+                    prefill_compile_config,
+                    example_inputs=prefill_example_inputs,
+                    compile_context=context,
+                )
 
-            [old lines not shown in the source view]
+                wrapped_model.phase = "decode"
+                compiled_decoder = RBLNModel.compile(
+                    wrapped_model,
+                    dec_compile_config,
+                    example_inputs=dec_example_inputs,
+                    compile_context=context,
+                )
+                return {"prefill": compiled_prefill, "decoder": compiled_decoder}
+            finally:
+                torch.nn.functional.linear = original_linear
+
+        compiled_models = compile_model(quantize_config=rbln_config.quantization)
+
+        # check if the memory is enough to have additional blocks
+        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+        if rbln_config.kvcache_num_blocks < required_num_blocks:
+            cls.maybe_suggest_kvcache_num_blocks(
+                compiled_models=compiled_models,
+                model_config=model.config,
+                rbln_config=rbln_config,
             )
-        return {"prefill": compiled_prefill, "decoder": compiled_decoder}
 
-        return
+        return compiled_models
+
+    @classmethod
+    def maybe_suggest_kvcache_num_blocks(
+        cls,
+        compiled_models: Dict[str, rebel.RBLNCompiledModel],
+        model_config: PretrainedConfig,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+    ) -> None:
+        # Get the actual memory allocation of each node by key
+        alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
+        alloc_memory_by_key: Dict[str, int] = {
+            key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
+        }
+        for key, memory_per_node in compiled_models["decoder"].get_alloc_per_node_by_key().items():
+            alloc_memory_by_key[key] += sum(memory_per_node)
+        alloc_memory_by_key.pop("PortRecur")  # kv-cache
+        kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
+
+        # Get the maximum number of blocks that can be allocated
+        buffer = sum(alloc_memory_by_key.values())
+        max_num_blocks = cls.get_maximum_num_blocks(
+            config=model_config,
+            tensor_parallel_size=rbln_config.tensor_parallel_size,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kernel_size=kernel_size,
+            buffer=buffer,
+        )
+
+        # Since our estimation logic is not always accurate,
+        # users can set `kvcache_num_blocks` to `max_num_blocks`.
+        # If the memory is not enough, the model will fail to compile.
+        if rbln_config.kvcache_num_blocks < max_num_blocks:
+            logger.warning(
+                f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
+                "Our analysis indicates that additional memory is available for more blocks. "
+                f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
+                "Please be advised that our memory estimation algorithm has limitations, "
+                "and increasing this value may not guarantee successful model compilation."
+            )
 
     @classmethod
     def get_maximum_num_blocks(
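The new `maybe_suggest_kvcache_num_blocks` reduces the compiler's per-key allocation report to two totals: "Kernel" (model weights) and the sum of everything that is neither "Kernel" nor "PortRecur" (treated as buffer), then asks `get_maximum_num_blocks` how many KV-cache blocks would still fit. A toy restatement of that bookkeeping (the key names come from the diff; the byte counts are invented):

```python
# Each value is a list of per-node allocations, as get_alloc_per_node_by_key
# returns in the diff above.
alloc_prefill = {"Kernel": [100, 100], "PortRecur": [300, 300], "Scratch": [20, 20]}
alloc_decoder = {"Kernel": [100, 100], "PortRecur": [300, 300], "Scratch": [10, 10]}

alloc_by_key = {key: sum(per_node) for key, per_node in alloc_prefill.items()}
for key, per_node in alloc_decoder.items():
    alloc_by_key[key] += sum(per_node)

alloc_by_key.pop("PortRecur")             # kv-cache: re-derived from block count
kernel_size = alloc_by_key.pop("Kernel")  # model weights
buffer = sum(alloc_by_key.values())       # everything else is treated as buffer
print(kernel_size, buffer)                # 400 60
```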
@@ -572,14 +646,46 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         config: PretrainedConfig,
         tensor_parallel_size: int,
         kvcache_block_size: int,
-        nbits_per_param: int,
-        n_model_params: int,
+        nbits_per_param: Optional[int] = None,
+        n_model_params: Optional[int] = None,
+        kernel_size: Optional[int] = None,
+        buffer: Optional[int] = None,
     ) -> int:
+        """
+        We are finding max_n_blocks(x) that satisfies the following equation:
+
+        available_dram - kernel_size - buffer
+        - num_layers * 2 * tensor_parallel_size
+        * align_2MB(
+            x
+            * block_size
+            * align_64(head_dim)
+            * math.ceil(num_key_value_heads / tensor_parallel_size)
+            * 2
+        ) > 0
+
+        This inequality can be rewritten as follows:
+
+        a - c * align_2MB(b * x) > 0
+        where
+            a = available_dram - kernel_size - buffer
+            b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
+            c = num_layers * 2 * tensor_parallel_size
+
+        We can rewrite the inequality as follows:
+        k > align_2MB(b*x)
+        where
+            k = a / c
+
+        After that, we can derive the following equation:
+        x = floor(2**21 / b * floor((k - 1) / 2**21))
+        """
+
         def align(x: int, nbytes: int) -> int:
             return int(math.ceil(x / nbytes) * nbytes)
 
         def align_2MB(x: int) -> int:
-            return align(x, 2
+            return align(x, 2**21)
 
         num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
         num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
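The closed form at the end of the new docstring can be exercised in isolation. A direct transcription with illustrative inputs (none of the numbers below come from the package; the 2 MiB page alignment and the factor 2 for fp16 K/V entries follow the docstring):

```python
import math


def align(x: int, nbytes: int) -> int:
    return int(math.ceil(x / nbytes) * nbytes)


def max_n_blocks(available_dram: int, kernel_size: int, buffer: int,
                 num_layers: int, tensor_parallel_size: int,
                 block_size: int, head_dim: int, num_key_value_heads: int) -> int:
    a = available_dram - kernel_size - buffer
    b = block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
    c = num_layers * 2 * tensor_parallel_size
    k = a / c
    return math.floor(2**21 / b * math.floor((k - 1) / 2**21))


print(max_n_blocks(available_dram=16 * 2**30, kernel_size=4 * 2**30, buffer=2**29,
                   num_layers=32, tensor_parallel_size=1,
                   block_size=1024, head_dim=128, num_key_value_heads=8))
```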
@@ -593,223 +699,206 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         ATOM_SYS_DRAM_NBYTES = 288 * 2**20
         available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
 
-        [old lines not shown in the source view]
+        if kernel_size is None:
+            if n_model_params is None:
+                raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
+            # Get estimated kernel size (approximated)
+            lm_heads_params = align(vocab_size, 64) * hidden_size
+            lm_heads_nbytes = (
+                align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
+            )
+            params = n_model_params - lm_heads_params
+            layer_nbytes = (
+                align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
+                * num_layers
+                * tensor_parallel_size
+            )
+            kernel_size = layer_nbytes + lm_heads_nbytes
+        elif n_model_params is not None:
+            raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
 
         available_dram -= kernel_size
 
-        [old lines not shown in the source view]
-        buffer
-
+        if buffer is None:
+            # TODO: Accurate buffer estimation
+            buffer_per_core = 2**29  # 500MB per npu
+            buffer = buffer_per_core * tensor_parallel_size
         available_dram -= buffer
 
-        [old lines not shown in the source view]
-                * head_dim
-                * math.ceil(num_key_value_heads / tensor_parallel_size)  # Shard
-                * 2  # (fp16)
-            )
-            * num_layers
-            * 2  # (k, v)
-            * tensor_parallel_size
-        )
-        n_blocks = available_dram // nbytes_per_block
+        b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
+        c = num_layers * 2 * tensor_parallel_size
+        k = available_dram / c
+        max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
 
-        return
+        return max_n_blocks
 
     @classmethod
-    def
+    def get_input_info(
         cls,
-        [old lines not shown in the source view]
+        batch_size: int,
+        query_length: int,
+        use_inputs_embeds: bool,
+        use_attention_mask: bool,
+        max_seq_len: int,
+        kvcache_block_size: int,
+        kvcache_num_blocks: int,
+        num_key_value_heads: int,
+        num_hidden_layers: int,
+        hidden_size: int,
+        head_dim: int,
+    ):
+        if use_inputs_embeds:
+            main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
+        else:
+            main_input = ("input_ids", [batch_size, query_length], "int64")
+
+        input_info = [
+            main_input,
+            (
+                "cache_position",
+                [batch_size, query_length],
+                "int32",
+            ),
+        ]
+
+        if use_attention_mask:
+            input_info.extend(
+                [
+                    ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
+                ]
             )
 
-        if
-        [old lines not shown in the source view]
+        if query_length > 1:
+            input_info.extend(
+                [
+                    ("query_position", [], "int16"),
+                ]
             )
-        if rbln_max_seq_len is None:
-            raise ValueError("`rbln_max_seq_len` should be specified.")
 
-        [old lines not shown in the source view]
+        max_block_cnt = max_seq_len // kvcache_block_size
+
+        if query_length > 1:
+            input_info.extend([("block_tables", [max_block_cnt], "int16")])
+        else:
+            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
 
-        [old lines not shown in the source view]
+        input_info.extend(
+            [
+                (
+                    f"past_key_values_{i}",
+                    [
+                        kvcache_num_blocks,
+                        num_key_value_heads,
+                        kvcache_block_size,
+                        head_dim,
+                    ],
+                    "float32",
+                )
+                for i in range(num_hidden_layers * 2)
+            ]
        )
 
-        [old lines not shown in the source view]
+        return input_info
+
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+        if rbln_config.max_seq_len is None:
+            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
+                model_config, "n_positions", None
+            )
+        if rbln_config.max_seq_len is None:
+            raise ValueError("`max_seq_len` should be specified.")
+
+        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
+            attn_impl=rbln_config.attn_impl,
+            kvcache_partition_len=rbln_config.kvcache_partition_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            max_seq_len=rbln_config.max_seq_len,
+        )
+
+        validate_attention_method(
+            attn_impl=rbln_config.attn_impl,
+            kvcache_partition_len=rbln_config.kvcache_partition_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            max_seq_len=rbln_config.max_seq_len,
+        )
+
+        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+        max_num_blocks = required_num_blocks
 
-        [old lines not shown in the source view]
-        max_num_blocks, _ = cls.get_maximum_num_blocks(
+        if rbln_config.attn_impl == "flash_attn":
+            estimated_max_num_blocks = cls.get_maximum_num_blocks(
                 config=model_config,
-                tensor_parallel_size=
-                kvcache_block_size=
-                nbits_per_param=16 if
-                n_model_params=
+                tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
+                kvcache_block_size=rbln_config.kvcache_block_size,
+                nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
+                n_model_params=sum(p.numel() for p in model.parameters()),
             )
-        rbln_kvcache_num_blocks = min(rbln_kvcache_num_blocks, max_num_blocks)
 
-
-        if rbln_kvcache_num_blocks < required_blocks:
-            rbln_kvcache_num_blocks = required_blocks
+            max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
 
-
+            flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
+            if max_num_blocks < flash_min_blocks:
+                max_num_blocks = flash_min_blocks
 
-        if
+        if max_num_blocks < rbln_config.batch_size:
             raise RuntimeError(
-                f"Batch size ({
+                f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
                 "Ensure the number of blocks is at least equal to the batch size."
             )
 
+        if rbln_config.kvcache_num_blocks is None:
+            rbln_config.kvcache_num_blocks = max_num_blocks
+        elif rbln_config.kvcache_num_blocks > max_num_blocks:
+            logger.warning(
+                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                f" than the estimated maximum number of blocks ({max_num_blocks})."
+                "This can cause a failure during model compilation."
+            )
+        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
 
-
-            batch_size,
-            query_length,
-            use_inputs_embeds,
-            hidden_size,
-        ):
-            if use_inputs_embeds:
-                main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
-            else:
-                main_input = ("input_ids", [batch_size, query_length], "int64")
-
-            input_info = [
-                main_input,
-                (
-                    "cache_position",
-                    [batch_size, query_length],
-                    "int32",
-                ),
-            ]
-
-            if rbln_use_attention_mask:
-                input_info.extend(
-                    [
-                        ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
-                    ]
-                )
-
-            if query_length > 1:
-                input_info.extend(
-                    [
-                        ("query_position", [], "int16"),
-                    ]
-                )
-
-            max_block_cnt = rbln_max_seq_len // rbln_kvcache_block_size
-
-            if query_length > 1:
-                input_info.extend([("block_tables", [max_block_cnt], "int16")])
-            else:
-                input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
-
-            input_info.extend(
-                [
-                    (
-                        f"past_key_values_{i}",
-                        [
-                            rbln_kvcache_num_blocks,
-                            num_key_value_heads,
-                            rbln_kvcache_block_size,
-                            head_dim,
-                        ],
-                        "float32",
-                    )
-                    for i in range(num_hidden_layers * 2)
-                ]
-            )
-
-            return input_info
-
-        prefill_input_info = get_input_info(
+        prefill_input_info = cls.get_input_info(
             batch_size=1,
-            query_length=
-            use_inputs_embeds=
+            query_length=rbln_config.prefill_chunk_size,
+            use_inputs_embeds=rbln_config.use_inputs_embeds,
+            use_attention_mask=rbln_config.use_attention_mask,
+            max_seq_len=rbln_config.max_seq_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
             hidden_size=hidden_size,
+            head_dim=head_dim,
         )
-        dec_input_info = get_input_info(
-            batch_size=
+        dec_input_info = cls.get_input_info(
+            batch_size=rbln_config.batch_size,
             query_length=1,
-            use_inputs_embeds=
+            use_inputs_embeds=rbln_config.use_inputs_embeds,
+            use_attention_mask=rbln_config.use_attention_mask,
+            max_seq_len=rbln_config.max_seq_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
             hidden_size=hidden_size,
+            head_dim=head_dim,
         )
 
         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
 
-        rbln_config
-            rbln_cls=cls.__name__,
-            compile_cfgs=[prefill_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "max_seq_len": rbln_max_seq_len,
-                "batch_size": rbln_batch_size,
-                "prefill_chunk_size": rbln_prefill_chunk_size,
-                "use_attention_mask": rbln_use_attention_mask,
-                "use_inputs_embeds": rbln_use_inputs_embeds,
-                "kvcache_partition_len": rbln_kvcache_partition_len,
-                "kvcache_block_size": rbln_kvcache_block_size,
-                "attn_impl": rbln_attn_impl,
-                "kvcache_num_blocks": rbln_kvcache_num_blocks,
-            }
-        )
-
-        if rbln_quantization is not None:
-            rbln_config.model_cfg.update({"quantization": rbln_quantization})
+        rbln_config.set_compile_cfgs([prefill_compile_config, dec_compile_config])
 
         return rbln_config
 
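The classmethod version of `get_input_info` fixes the compiled graphs' I/O signatures: the prefill graph takes a flat `block_tables` of `max_seq_len // kvcache_block_size` entries plus a scalar `query_position`, the decode graph takes a per-batch block table, and both see `num_hidden_layers * 2` paged KV-cache tensors. A reduced re-implementation of just that shape logic (names and dtypes follow the diff; the sample numbers are arbitrary):

```python
def input_shapes(batch_size, query_length, max_seq_len, kvcache_block_size,
                 kvcache_num_blocks, num_key_value_heads, num_hidden_layers, head_dim):
    max_block_cnt = max_seq_len // kvcache_block_size
    info = [("input_ids", [batch_size, query_length], "int64"),
            ("cache_position", [batch_size, query_length], "int32")]
    if query_length > 1:  # prefill graph
        info.append(("query_position", [], "int16"))
        info.append(("block_tables", [max_block_cnt], "int16"))
    else:                 # decode graph
        info.append(("block_tables", [batch_size, max_block_cnt], "int16"))
    info += [(f"past_key_values_{i}",
              [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim],
              "float32")
             for i in range(num_hidden_layers * 2)]
    return info


for name, shape, dtype in input_shapes(1, 128, 4096, 1024, 8, 8, 2, 64)[:4]:
    print(name, shape, dtype)
```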
@@ -817,18 +906,23 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ) -> List[rebel.Runtime]:
-        if any(model_name not in
+        if any(model_name not in rbln_config.device_map for model_name in ["prefill", "decoder"]):
             cls._raise_missing_compiled_file_error(["prefill", "decoder"])
 
         return [
-            [old lines not shown in the source view]
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["prefill"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
-            [old lines not shown in the source view]
+            rebel.Runtime(
+                compiled_models[1],
+                tensor_type="pt",
+                device=rbln_config.device_map["decoder"],
+                activate_profiler=rbln_config.activate_profiler,
            ),
         ]
 
@@ -865,11 +959,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             model_inputs.update({"input_ids": input_ids})
 
         if inputs_embeds is not None:
-            if self.rbln_config.
+            if self.rbln_config.use_inputs_embeds:
                 model_inputs.update({"inputs_embeds": inputs_embeds})
             else:
                 raise ValueError(
-                    "The specifying
+                    "The specifying inputs_embeds is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
                 )
         else:
             model_inputs.update({"input_ids": input_ids})