optimum-rbln 0.7.4a4__py3-none-any.whl → 0.7.4a6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +164 -36
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +772 -0
- optimum/rbln/diffusers/__init__.py +56 -0
- optimum/rbln/diffusers/configurations/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/models/__init__.py +6 -0
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +54 -0
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +44 -0
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +48 -0
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +66 -0
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +67 -0
- optimum/rbln/diffusers/configurations/pipelines/__init__.py +30 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +221 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +285 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +118 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +143 -0
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +124 -0
- optimum/rbln/diffusers/modeling_diffusers.py +63 -122
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +109 -128
- optimum/rbln/diffusers/models/autoencoders/vae.py +4 -6
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +84 -85
- optimum/rbln/diffusers/models/controlnet.py +55 -70
- optimum/rbln/diffusers/models/transformers/prior_transformer.py +40 -77
- optimum/rbln/diffusers/models/transformers/transformer_sd3.py +43 -68
- optimum/rbln/diffusers/models/unets/unet_2d_condition.py +110 -113
- optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +3 -4
- optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- optimum/rbln/modeling.py +58 -39
- optimum/rbln/modeling_base.py +107 -78
- optimum/rbln/transformers/__init__.py +87 -8
- optimum/rbln/transformers/configuration_alias.py +49 -0
- optimum/rbln/transformers/configuration_generic.py +142 -0
- optimum/rbln/transformers/modeling_generic.py +193 -280
- optimum/rbln/transformers/models/__init__.py +108 -34
- optimum/rbln/transformers/models/auto/auto_factory.py +3 -3
- optimum/rbln/transformers/models/bart/__init__.py +1 -0
- optimum/rbln/transformers/models/bart/configuration_bart.py +24 -0
- optimum/rbln/transformers/models/bart/modeling_bart.py +10 -84
- optimum/rbln/transformers/models/bert/__init__.py +1 -0
- optimum/rbln/transformers/models/bert/configuration_bert.py +31 -0
- optimum/rbln/transformers/models/bert/modeling_bert.py +7 -80
- optimum/rbln/transformers/models/clip/__init__.py +6 -0
- optimum/rbln/transformers/models/clip/configuration_clip.py +79 -0
- optimum/rbln/transformers/models/clip/modeling_clip.py +72 -75
- optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +90 -0
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +115 -84
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +282 -216
- optimum/rbln/transformers/models/dpt/__init__.py +1 -0
- optimum/rbln/transformers/models/dpt/configuration_dpt.py +19 -0
- optimum/rbln/transformers/models/dpt/modeling_dpt.py +3 -76
- optimum/rbln/transformers/models/exaone/__init__.py +1 -0
- optimum/rbln/transformers/models/exaone/configuration_exaone.py +19 -0
- optimum/rbln/transformers/models/gemma/__init__.py +1 -0
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +19 -0
- optimum/rbln/transformers/models/gpt2/__init__.py +1 -0
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +19 -0
- optimum/rbln/transformers/models/llama/__init__.py +1 -0
- optimum/rbln/transformers/models/llama/configuration_llama.py +19 -0
- optimum/rbln/transformers/models/llava_next/__init__.py +1 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +46 -0
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +12 -23
- optimum/rbln/transformers/models/midm/__init__.py +1 -0
- optimum/rbln/transformers/models/midm/configuration_midm.py +19 -0
- optimum/rbln/transformers/models/mistral/__init__.py +1 -0
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +19 -0
- optimum/rbln/transformers/models/phi/__init__.py +1 -0
- optimum/rbln/transformers/models/phi/configuration_phi.py +19 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +1 -0
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +68 -0
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +608 -0
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +214 -0
- optimum/rbln/transformers/models/seq2seq/__init__.py +1 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq2.py +66 -0
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +80 -97
- optimum/rbln/transformers/models/t5/__init__.py +1 -0
- optimum/rbln/transformers/models/t5/configuration_t5.py +24 -0
- optimum/rbln/transformers/models/t5/modeling_t5.py +22 -150
- optimum/rbln/transformers/models/time_series_transformers/__init__.py +1 -0
- optimum/rbln/transformers/models/time_series_transformers/configuration_time_series_transformer.py +34 -0
- optimum/rbln/transformers/models/time_series_transformers/modeling_time_series_transformers.py +52 -54
- optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -0
- optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec.py +19 -0
- optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +9 -72
- optimum/rbln/transformers/models/whisper/__init__.py +1 -0
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +64 -0
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +57 -72
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +1 -0
- optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +19 -0
- optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -83
- optimum/rbln/utils/runtime_utils.py +33 -2
- optimum/rbln/utils/submodule.py +26 -43
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/METADATA +1 -1
- optimum_rbln-0.7.4a6.dist-info/RECORD +166 -0
- optimum/rbln/modeling_config.py +0 -310
- optimum_rbln-0.7.4a4.dist-info/RECORD +0 -126
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.7.4a4.dist-info → optimum_rbln-0.7.4a6.dist-info}/licenses/LICENSE +0 -0
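The hunks below are from optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py (+282 -216). The headline change of this release is the configuration rework: the removed optimum/rbln/modeling_config.py (-310) is superseded by optimum/rbln/configuration_utils.py (+772) plus per-model configuration classes, and dict-style `rbln_config.model_cfg["key"]` lookups become typed attribute access. A minimal sketch of the new style, assuming the constructor accepts the same names as the attributes read in the hunks (the exact signature is not shown in this diff):

    from optimum.rbln.transformers.models.decoderonly.configuration_decoderonly import (
        RBLNDecoderOnlyModelForCausalLMConfig,
    )

    # Assumed constructor kwargs; the attribute names match the hunks below.
    rbln_config = RBLNDecoderOnlyModelForCausalLMConfig(max_seq_len=4096, batch_size=2)
    print(rbln_config.max_seq_len)  # new: typed attribute access
    # old: rbln_config.model_cfg["max_seq_len"]  (removed in 0.7.4a6)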
@@ -26,13 +26,15 @@ from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, Pre
 from transformers.modeling_utils import no_init_weights
 from transformers.utils import ModelOutput

+from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
-from ....modeling_config import RBLNCompileConfig, RBLNConfig
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ...utils.rbln_quantization import QuantizationManager
+from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
 from .decoderonly_architecture import (
     DecoderOnlyWrapper,
+    set_default_values,
     validate_attention_method,
 )

@@ -161,6 +163,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         attention_mask: Optional[torch.Tensor] = None,
         batch_idx: Optional[int] = None,
         block_tables: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
     ):
         if input_ids is None and inputs_embeds is None:
             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
@@ -185,9 +188,12 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 block_tables,
                 is_external_block_tables,
                 attention_mask=attention_mask,
+                position_embed=position_embed,
             )
         else:
-            return self.prefill_forward(inputs, cache_position, attention_mask, batch_idx, block_tables)
+            return self.prefill_forward(
+                inputs, cache_position, attention_mask, batch_idx, block_tables, position_embed=position_embed
+            )

     def decode_forward(
         self,
@@ -196,6 +202,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
         attention_mask: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size = inputs.shape[0]
         if batch_size != self.batch_size:
@@ -227,6 +234,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             cache_position,
             attention_mask if self.use_attention_mask else None,
             block_tables,
+            position_embed,
         )

         return logits
@@ -239,6 +247,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         batch_idx: int = None,
         block_tables: torch.Tensor = None,
         is_external_block_tables: bool = None,
+        position_embed: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         """
         Performs chunked prefill for efficient KV-cache updates and memory optimization.
@@ -249,6 +258,10 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
         # Handle continuous batching in a compiled graph by extracting valid inputs
         # If an attention mask is provided, select only the valid (non-masked) inputs
         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
+        if position_embed is not None:
+            position_embed = (
+                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
+            )

         query_length = inputs.shape[1]
         if query_length > self.max_seq_len:
@@ -293,9 +306,14 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
                 dim=-1,
             )

+            if position_embed is not None:
+                position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+
         # Extract the current chunk of inputs and cache positions
         input_chunk = inputs[:, step : step + self.prefill_chunk_size]
         cache_pos_chunk = cache_position[:, step : step + self.prefill_chunk_size]
+        if position_embed is not None:
+            position_embed_chunk = position_embed[:, :, :, step : step + self.prefill_chunk_size, :]

         if self.use_attention_mask:
             # Update attention mask to ensure proper causal behavior
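For context on the two hunks above: prefill runs the prompt through the compiled graph in fixed-size chunks, so long prompts are right-padded to a multiple of `prefill_chunk_size` and sliced per step; `position_embed` now follows the same padding and slicing with its 4-D indexing. A runnable sketch under assumed shapes (only the slicing pattern is taken from the diff):

    import torch

    prefill_chunk_size = 128
    inputs = torch.randint(0, 32000, (1, 300))  # (batch, query_length), values assumed

    # Right-pad the prompt to a multiple of the chunk size.
    padding_size = -inputs.shape[1] % prefill_chunk_size  # 84 for a 300-token prompt
    inputs = torch.nn.functional.pad(inputs, (0, padding_size))

    for step in range(0, inputs.shape[1], prefill_chunk_size):
        input_chunk = inputs[:, step : step + prefill_chunk_size]
        # each chunk is fed to the prefill runtime, updating the KV cache in place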
@@ -356,17 +375,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     _use_rotary_emb = True

     def __post_init__(self, **kwargs):
-        self.batch_size = self.rbln_config.model_cfg["batch_size"]
-        self.max_seq_len = self.rbln_config.model_cfg["max_seq_len"]
-        self.prefill_chunk_size = self.rbln_config.model_cfg["prefill_chunk_size"]
-        self.kvcache_block_size = self.rbln_config.model_cfg["kvcache_block_size"]
-        # FIXME get kvcache_num_blocks from compiled results.
-        self.kvcache_num_blocks = self.rbln_config.model_cfg["kvcache_num_blocks"]
-        self.use_attention_mask = self.rbln_config.model_cfg["use_attention_mask"]
-        attn_impl = self.rbln_config.model_cfg["attn_impl"]
         main_input_name = self.main_input_name

-        if self.rbln_config.model_cfg["use_inputs_embeds"]:
+        if self.rbln_config.use_inputs_embeds:
             main_input_name = "inputs_embeds"
         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
         with no_init_weights():
@@ -380,40 +391,44 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         self.embed_tokens = None

         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
-        dec_attn_mask = torch.zeros(self.batch_size, 1, 1, self.max_seq_len, dtype=torch.float32)
+        dec_attn_mask = torch.zeros(
+            self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
+        )
         block_tables = torch.zeros(
-            self.batch_size, self.max_seq_len // self.kvcache_block_size, dtype=torch.int16
+            self.rbln_config.batch_size,
+            self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
+            dtype=torch.int16,
         ).fill_(-1)
-        free_block_pool = deque(x for x in range(self.kvcache_num_blocks))
+        free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))

         self.prefill_decoder = RBLNRuntimeModel(
             runtime=self.model[0],
             main_input_name=main_input_name,
             embed_tokens=self.embed_tokens,
             phase="prefill",
-            batch_size=self.batch_size,
+            batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
             free_block_pool=free_block_pool,
-            kvcache_block_size=self.kvcache_block_size,
+            kvcache_block_size=self.rbln_config.kvcache_block_size,
             vocab_size=self.config.vocab_size,
-            prefill_chunk_size=self.prefill_chunk_size,
-            max_seq_len=self.max_seq_len,
-            use_attention_mask=self.use_attention_mask,
-            attn_impl=attn_impl,
+            prefill_chunk_size=self.rbln_config.prefill_chunk_size,
+            max_seq_len=self.rbln_config.max_seq_len,
+            use_attention_mask=self.rbln_config.use_attention_mask,
+            attn_impl=self.rbln_config.attn_impl,
         )
         self.decoder = RBLNRuntimeModel(
             runtime=self.model[1],
             main_input_name=main_input_name,
             embed_tokens=self.embed_tokens,
             phase="decode",
-            batch_size=self.batch_size,
+            batch_size=self.rbln_config.batch_size,
             dec_attn_mask=dec_attn_mask,
             block_tables=block_tables,
             free_block_pool=free_block_pool,
-            kvcache_block_size=self.kvcache_block_size,
-            use_attention_mask=self.use_attention_mask,
-            attn_impl=attn_impl,
+            kvcache_block_size=self.rbln_config.kvcache_block_size,
+            use_attention_mask=self.rbln_config.use_attention_mask,
+            attn_impl=self.rbln_config.attn_impl,
         )

     @classmethod
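The shared state built in `__post_init__` above is a paged KV cache: a per-sequence block table (initialized to -1, meaning unassigned) plus a pool of free block indices that the prefill and decode runtimes draw from. A runnable sketch with assumed sizes:

    from collections import deque

    import torch

    batch_size, max_seq_len, block_size, num_blocks = 2, 4096, 1024, 8  # assumed
    block_tables = torch.zeros(batch_size, max_seq_len // block_size, dtype=torch.int16).fill_(-1)
    free_block_pool = deque(range(num_blocks))  # indices of KV-cache blocks not yet assigned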
@@ -422,13 +437,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         model: "PreTrainedModel",
         save_dir_path: Path,
         subfolder: str,
-        rbln_config: RBLNConfig,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ):
         """
         If you are unavoidably running on a CPU rather than an RBLN device,
         store the torch tensor, weight, etc. in this function.
         """
-        if rbln_config.model_cfg["use_inputs_embeds"]:
+        if rbln_config.use_inputs_embeds:
             save_dict = {}
             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
             torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
@@ -493,33 +508,35 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             return val

     @classmethod
-    def get_pytorch_model(
-
-
-
-
-
+    def get_pytorch_model(
+        cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
+    ) -> "PreTrainedModel":
+        if (
+            rbln_config is not None
+            and "format" in rbln_config.quantization
+            and rbln_config.quantization["format"] == "rbln"
+        ):
             model = cls.get_quantized_model(*args, **kwargs)
         else:
             model = super().get_pytorch_model(*args, **kwargs)

-        logger.debug("Loaded the LLM model to the CPU.")
         return model

     @classmethod
-    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-        wrapper_cfg = {
-
-
-
-
-
-
+    def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+        wrapper_cfg = {
+            "max_seq_len": rbln_config.max_seq_len,
+            "attn_impl": rbln_config.attn_impl,
+            "kvcache_partition_len": rbln_config.kvcache_partition_len,
+            "kvcache_block_size": rbln_config.kvcache_block_size,
+            "use_rotary_emb": cls._use_rotary_emb,
+            "use_attention_mask": rbln_config.use_attention_mask,
+        }
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

     @classmethod
     @torch.inference_mode()
-    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNConfig):
+    def get_compiled_model(cls, model: "PreTrainedModel", rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)

         rbln_compile_configs = rbln_config.compile_cfgs
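The quantization gate in `get_pytorch_model` above now keys off the typed config instead of `rbln_kwargs`; a small sketch of the check (only the "format" key appears in the hunk, the dict shape is otherwise an assumption):

    quantization = {"format": "rbln"}  # e.g. rbln_config.quantization
    load_prequantized = quantization is not None and quantization.get("format") == "rbln"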
@@ -541,8 +558,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

         dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)

-        quantize_config = rbln_config.model_cfg.get("quantization", None)
-
         @QuantizationManager.with_quantization_env
         def compile_model(*args, **kwargs):
             try:
@@ -567,7 +582,57 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         finally:
             torch.nn.functional.linear = original_linear

-        return compile_model(quantize_config=quantize_config)
+        compiled_models = compile_model(quantize_config=rbln_config.quantization)
+
+        # check if the memory is enough to have additional blocks
+        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+        if rbln_config.kvcache_num_blocks < required_num_blocks:
+            cls.maybe_suggest_kvcache_num_blocks(
+                compiled_models=compiled_models,
+                model_config=model.config,
+                rbln_config=rbln_config,
+            )
+
+        return compiled_models
+
+    @classmethod
+    def maybe_suggest_kvcache_num_blocks(
+        cls,
+        compiled_models: Dict[str, rebel.RBLNCompiledModel],
+        model_config: PretrainedConfig,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+    ) -> None:
+        # Get the actual memory allocation of each node by key
+        alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
+        alloc_memory_by_key: Dict[str, int] = {
+            key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
+        }
+        for key, memory_per_node in compiled_models["decoder"].get_alloc_per_node_by_key().items():
+            alloc_memory_by_key[key] += sum(memory_per_node)
+        alloc_memory_by_key.pop("PortRecur")  # kv-cache
+        kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
+
+        # Get the maximum number of blocks that can be allocated
+        buffer = sum(alloc_memory_by_key.values())
+        max_num_blocks = cls.get_maximum_num_blocks(
+            config=model_config,
+            tensor_parallel_size=rbln_config.tensor_parallel_size,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kernel_size=kernel_size,
+            buffer=buffer,
+        )
+
+        # Since our estimation logic is not always accurate,
+        # users can set `kvcache_num_blocks` to `max_num_blocks`.
+        # If the memory is not enough, the model will fail to compile.
+        if rbln_config.kvcache_num_blocks < max_num_blocks:
+            logger.warning(
+                f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
+                "Our analysis indicates that additional memory is available for more blocks. "
+                f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
+                "Please be advised that our memory estimation algorithm has limitations, "
+                "and increasing this value may not guarantee successful model compilation."
+            )

     @classmethod
     def get_maximum_num_blocks(
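The new `maybe_suggest_kvcache_num_blocks` works from the compiler's per-key allocation report: per-node byte counts are summed per key and across the prefill and decoder models, the KV cache ("PortRecur") is dropped, the weights ("Kernel") become `kernel_size`, and the remainder is treated as buffer. A runnable sketch with made-up numbers (the "Buffer" key and all values are illustrative; "Kernel" and "PortRecur" come from the hunk):

    prefill_alloc = {"Kernel": [700, 700], "PortRecur": [400, 400], "Buffer": [50, 50]}
    decoder_alloc = {"Kernel": [700, 700], "PortRecur": [400, 400], "Buffer": [30, 30]}

    alloc_by_key = {key: sum(per_node) for key, per_node in prefill_alloc.items()}
    for key, per_node in decoder_alloc.items():
        alloc_by_key[key] += sum(per_node)

    alloc_by_key.pop("PortRecur")             # KV cache is excluded from the budget
    kernel_size = alloc_by_key.pop("Kernel")  # model weights
    buffer = sum(alloc_by_key.values())       # everything else counts as buffer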
@@ -575,8 +640,10 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         config: PretrainedConfig,
         tensor_parallel_size: int,
         kvcache_block_size: int,
-        nbits_per_param: int,
-        n_model_params: int,
+        nbits_per_param: Optional[int] = None,
+        n_model_params: Optional[int] = None,
+        kernel_size: Optional[int] = None,
+        buffer: Optional[int] = None,
     ) -> int:
         """
         We are finding max_n_blocks(x) that satisfies the following equation:
@@ -626,24 +693,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         ATOM_SYS_DRAM_NBYTES = 288 * 2**20
         available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)

-        # Get estimated kernel size (approximated)
-        lm_heads_params = align(vocab_size, 64) * hidden_size
-        lm_heads_nbytes = (
-            align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
-        )
-        params = n_model_params - lm_heads_params
-        layer_nbytes = (
-            align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
-            * num_layers
-            * tensor_parallel_size
-        )
-        kernel_size = layer_nbytes + lm_heads_nbytes
+        if kernel_size is None:
+            if n_model_params is None:
+                raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
+            # Get estimated kernel size (approximated)
+            lm_heads_params = align(vocab_size, 64) * hidden_size
+            lm_heads_nbytes = (
+                align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
+            )
+            params = n_model_params - lm_heads_params
+            layer_nbytes = (
+                align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
+                * num_layers
+                * tensor_parallel_size
+            )
+            kernel_size = layer_nbytes + lm_heads_nbytes
+        elif n_model_params is not None:
+            raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")

         available_dram -= kernel_size

-        # TODO: Accurate buffer estimation
-        buffer_per_core = 2**29  # 500MB per npu
-        buffer = buffer_per_core * tensor_parallel_size
+        if buffer is None:
+            # TODO: Accurate buffer estimation
+            buffer_per_core = 2**29  # 500MB per npu
+            buffer = buffer_per_core * tensor_parallel_size
         available_dram -= buffer

         b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
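To make the block arithmetic concrete, a worked example with assumed values; `align` and `align_2MB` are reimplemented here as the rounding helpers the hunk references (their exact definitions are not shown in this diff):

    import math

    def align(x: int, n: int) -> int:  # assumed: round x up to a multiple of n
        return math.ceil(x / n) * n

    def align_2MB(x: float) -> int:    # assumed: round up to a 2 MiB boundary
        return align(math.ceil(x), 2 * 2**20)

    max_seq_len, kvcache_block_size, batch_size = 8192, 1024, 2             # assumed
    required_num_blocks = (max_seq_len // kvcache_block_size) * batch_size  # 16

    # Per-block KV footprint, mirroring `b` above (illustrative model shape):
    head_dim, num_key_value_heads, tensor_parallel_size = 128, 8, 1
    b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2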
@@ -654,184 +727,172 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return max_n_blocks

     @classmethod
-    def _get_rbln_config(
+    def get_input_info(
         cls,
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        batch_size: int,
+        query_length: int,
+        use_inputs_embeds: bool,
+        use_attention_mask: bool,
+        max_seq_len: int,
+        kvcache_block_size: int,
+        kvcache_num_blocks: int,
+        num_key_value_heads: int,
+        num_hidden_layers: int,
+        hidden_size: int,
+        head_dim: int,
+    ):
+        if use_inputs_embeds:
+            main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
+        else:
+            main_input = ("input_ids", [batch_size, query_length], "int64")
+
+        input_info = [
+            main_input,
+            (
+                "cache_position",
+                [batch_size, query_length],
+                "int32",
+            ),
+        ]
+
+        if use_attention_mask:
+            input_info.extend(
+                [
+                    ("attention_mask", [batch_size, 1, query_length, max_seq_len], "float32"),
+                ]
             )

-        if rbln_max_seq_len is None:
-            rbln_max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
-                model_config, "n_positions", None
+        if query_length > 1:
+            input_info.extend(
+                [
+                    ("query_position", [], "int16"),
+                ]
             )
-        if rbln_max_seq_len is None:
-            raise ValueError("`rbln_max_seq_len` should be specified.")

-
-
+        max_block_cnt = max_seq_len // kvcache_block_size
+
+        if query_length > 1:
+            input_info.extend([("block_tables", [max_block_cnt], "int16")])
+        else:
+            input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
+
+        input_info.extend(
+            [
+                (
+                    f"past_key_values_{i}",
+                    [
+                        kvcache_num_blocks,
+                        num_key_value_heads,
+                        kvcache_block_size,
+                        head_dim,
+                    ],
+                    "float32",
+                )
+                for i in range(num_hidden_layers * 2)
+            ]
+        )
+
+        return input_info

-
-
-
-
-
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional["PreTrainedModel"] = None,
+        model_config: Optional["PretrainedConfig"] = None,
+        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+        if rbln_config.max_seq_len is None:
+            rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
+                model_config, "n_positions", None
+            )
+        if rbln_config.max_seq_len is None:
+            raise ValueError("`max_seq_len` should be specified.")
+
+        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
+            attn_impl=rbln_config.attn_impl,
+            kvcache_partition_len=rbln_config.kvcache_partition_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            max_seq_len=rbln_config.max_seq_len,
         )

-
-
-
-
-
+        validate_attention_method(
+            attn_impl=rbln_config.attn_impl,
+            kvcache_partition_len=rbln_config.kvcache_partition_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            max_seq_len=rbln_config.max_seq_len,
+        )
+
+        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+        max_num_blocks = required_num_blocks

-
-
-        max_num_blocks = cls.get_maximum_num_blocks(
+        if rbln_config.attn_impl == "flash_attn":
+            estimated_max_num_blocks = cls.get_maximum_num_blocks(
                 config=model_config,
-            tensor_parallel_size=
-            kvcache_block_size=rbln_kvcache_block_size,
-            nbits_per_param=16 if rbln_quantization is None else 4,
-            n_model_params=
+                tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
+                kvcache_block_size=rbln_config.kvcache_block_size,
+                nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
+                n_model_params=sum(p.numel() for p in model.parameters()),
             )
-        rbln_kvcache_num_blocks = min(rbln_kvcache_num_blocks, max_num_blocks)

-
-        if rbln_kvcache_num_blocks < required_blocks:
-            rbln_kvcache_num_blocks = required_blocks
+            max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)

-
+            flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
+            if max_num_blocks < flash_min_blocks:
+                max_num_blocks = flash_min_blocks

-        if rbln_kvcache_num_blocks < rbln_batch_size:
+        if max_num_blocks < rbln_config.batch_size:
             raise RuntimeError(
-                f"Batch size ({rbln_batch_size}) exceeds available KV cache blocks ({rbln_kvcache_num_blocks}). "
+                f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
                 "Ensure the number of blocks is at least equal to the batch size."
             )

+        if rbln_config.kvcache_num_blocks is None:
+            rbln_config.kvcache_num_blocks = max_num_blocks
+        elif rbln_config.kvcache_num_blocks > max_num_blocks:
+            logger.warning(
+                f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
+                f" than the estimated maximum number of blocks ({max_num_blocks})."
+                "This can cause a failure during model compilation."
+            )
+        logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
-        head_dim = getattr(model_config, "head_dim", None) or model_config.hidden_size // num_attention_heads
         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
+        head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads

-        def get_input_info(
-            batch_size,
-            query_length,
-            use_inputs_embeds,
-            hidden_size,
-        ):
-            if use_inputs_embeds:
-                main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
-            else:
-                main_input = ("input_ids", [batch_size, query_length], "int64")
-
-            input_info = [
-                main_input,
-                (
-                    "cache_position",
-                    [batch_size, query_length],
-                    "int32",
-                ),
-            ]
-
-            if rbln_use_attention_mask:
-                input_info.extend(
-                    [
-                        ("attention_mask", [batch_size, 1, query_length, rbln_max_seq_len], "float32"),
-                    ]
-                )
-
-            if query_length > 1:
-                input_info.extend(
-                    [
-                        ("query_position", [], "int16"),
-                    ]
-                )
-
-            max_block_cnt = rbln_max_seq_len // rbln_kvcache_block_size
-
-            if query_length > 1:
-                input_info.extend([("block_tables", [max_block_cnt], "int16")])
-            else:
-                input_info.extend([("block_tables", [batch_size, max_block_cnt], "int16")])
-
-            input_info.extend(
-                [
-                    (
-                        f"past_key_values_{i}",
-                        [
-                            rbln_kvcache_num_blocks,
-                            num_key_value_heads,
-                            rbln_kvcache_block_size,
-                            head_dim,
-                        ],
-                        "float32",
-                    )
-                    for i in range(num_hidden_layers * 2)
-                ]
-            )
-
-            return input_info
-
-        prefill_input_info = get_input_info(
+        prefill_input_info = cls.get_input_info(
             batch_size=1,
-            query_length=rbln_prefill_chunk_size,
-            use_inputs_embeds=rbln_use_inputs_embeds,
+            query_length=rbln_config.prefill_chunk_size,
+            use_inputs_embeds=rbln_config.use_inputs_embeds,
+            use_attention_mask=rbln_config.use_attention_mask,
+            max_seq_len=rbln_config.max_seq_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
             hidden_size=hidden_size,
+            head_dim=head_dim,
         )
-        dec_input_info = get_input_info(
-            batch_size=rbln_batch_size,
+        dec_input_info = cls.get_input_info(
+            batch_size=rbln_config.batch_size,
             query_length=1,
-            use_inputs_embeds=rbln_use_inputs_embeds,
+            use_inputs_embeds=rbln_config.use_inputs_embeds,
+            use_attention_mask=rbln_config.use_attention_mask,
+            max_seq_len=rbln_config.max_seq_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            kvcache_num_blocks=rbln_config.kvcache_num_blocks,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
             hidden_size=hidden_size,
+            head_dim=head_dim,
         )

         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)

-        rbln_config = RBLNConfig(
-            rbln_cls=cls.__name__,
-            compile_cfgs=[prefill_compile_config, dec_compile_config],
-            rbln_kwargs=rbln_kwargs,
-        )
-
-        rbln_config.model_cfg.update(
-            {
-                "max_seq_len": rbln_max_seq_len,
-                "batch_size": rbln_batch_size,
-                "prefill_chunk_size": rbln_prefill_chunk_size,
-                "use_attention_mask": rbln_use_attention_mask,
-                "use_inputs_embeds": rbln_use_inputs_embeds,
-                "kvcache_partition_len": rbln_kvcache_partition_len,
-                "kvcache_block_size": rbln_kvcache_block_size,
-                "attn_impl": rbln_attn_impl,
-                "kvcache_num_blocks": rbln_kvcache_num_blocks,
-            }
-        )
-
-        if rbln_quantization is not None:
-            rbln_config.model_cfg.update({"quantization": rbln_quantization})
+        rbln_config.set_compile_cfgs([prefill_compile_config, dec_compile_config])

         return rbln_config

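For reference, the decode-phase `input_info` that `get_input_info` produces has this shape; every value below is an illustrative assumption (batch_size=2, max_seq_len=4096, kvcache_block_size=1024, kvcache_num_blocks=16, one hidden layer, 8 KV heads, head_dim=128):

    dec_input_info = [
        ("input_ids", [2, 1], "int64"),
        ("cache_position", [2, 1], "int32"),
        ("block_tables", [2, 4], "int16"),  # max_block_cnt = 4096 // 1024 = 4
        ("past_key_values_0", [16, 8, 1024, 128], "float32"),  # key cache, layer 0
        ("past_key_values_1", [16, 8, 1024, 128], "float32"),  # value cache, layer 0
    ]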
@@ -839,18 +900,23 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     def _create_runtimes(
         cls,
         compiled_models: List[rebel.RBLNCompiledModel],
-        rbln_device_map: Dict[str, int],
-        activate_profiler: Optional[bool] = None,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ) -> List[rebel.Runtime]:
-        if any(model_name not in rbln_device_map for model_name in ["prefill", "decoder"]):
+        if any(model_name not in rbln_config.device_map for model_name in ["prefill", "decoder"]):
             cls._raise_missing_compiled_file_error(["prefill", "decoder"])

         return [
-            compiled_models[0].create_runtime(
-                tensor_type="pt", device=rbln_device_map["prefill"], activate_profiler=activate_profiler
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["prefill"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
-            compiled_models[1].create_runtime(
-                tensor_type="pt", device=rbln_device_map["decoder"], activate_profiler=activate_profiler
+            rebel.Runtime(
+                compiled_models[1],
+                tensor_type="pt",
+                device=rbln_config.device_map["decoder"],
+                activate_profiler=rbln_config.activate_profiler,
             ),
         ]
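`_create_runtimes` now reads both the device placement and the profiler flag from the typed config; a sketch of the fields it consumes (the device numbers are assumptions):

    device_map = {"prefill": 0, "decoder": 0}  # e.g. rbln_config.device_map
    activate_profiler = False                  # e.g. rbln_config.activate_profiler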
@@ -887,11 +953,11 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             model_inputs.update({"input_ids": input_ids})

         if inputs_embeds is not None:
-            if self.rbln_config.model_cfg["use_inputs_embeds"]:
+            if self.rbln_config.use_inputs_embeds:
                 model_inputs.update({"inputs_embeds": inputs_embeds})
             else:
                 raise ValueError(
-                    "The specifying
+                    "The specifying inputs_embeds is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
                 )
         else:
             model_inputs.update({"input_ids": input_ids})
|