optimum-rbln 0.8.1rc0__py3-none-any.whl → 0.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- optimum/rbln/__init__.py +58 -9
- optimum/rbln/__version__.py +2 -2
- optimum/rbln/configuration_utils.py +24 -5
- optimum/rbln/diffusers/configurations/models/__init__.py +1 -1
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +5 -3
- optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
- optimum/rbln/diffusers/configurations/models/{configuration_cosmos_transformer.py → configuration_transformer_cosmos.py} +7 -2
- optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
- optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
- optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +10 -6
- optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
- optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
- optimum/rbln/diffusers/modeling_diffusers.py +4 -5
- optimum/rbln/diffusers/models/__init__.py +3 -13
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1 -0
- optimum/rbln/diffusers/models/autoencoders/vq_model.py +1 -0
- optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +1 -1
- optimum/rbln/diffusers/pipelines/__init__.py +1 -5
- optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +12 -4
- optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +4 -26
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +2 -2
- optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +2 -2
- optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
- optimum/rbln/modeling.py +4 -5
- optimum/rbln/modeling_base.py +18 -14
- optimum/rbln/ops/kv_cache_update.py +5 -0
- optimum/rbln/ops/linear.py +7 -0
- optimum/rbln/transformers/__init__.py +60 -0
- optimum/rbln/transformers/configuration_generic.py +4 -4
- optimum/rbln/transformers/modeling_attention_utils.py +252 -0
- optimum/rbln/transformers/modeling_generic.py +1 -4
- optimum/rbln/transformers/models/__init__.py +45 -30
- optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
- optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
- optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
- optimum/rbln/transformers/models/clip/configuration_clip.py +14 -3
- optimum/rbln/transformers/models/clip/modeling_clip.py +123 -28
- optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
- optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
- optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
- optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
- optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
- optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -454
- optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +579 -362
- optimum/rbln/transformers/models/exaone/exaone_architecture.py +17 -42
- optimum/rbln/transformers/models/gemma/__init__.py +2 -2
- optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
- optimum/rbln/transformers/models/gemma/gemma_architecture.py +3 -44
- optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
- optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +21 -9
- optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +9 -63
- optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +200 -292
- optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
- optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
- optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +19 -24
- optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
- optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
- optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
- optimum/rbln/transformers/models/llama/__init__.py +2 -2
- optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
- optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
- optimum/rbln/transformers/models/llava/__init__.py +16 -0
- optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
- optimum/rbln/transformers/models/llava/modeling_llava.py +419 -0
- optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +20 -3
- optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
- optimum/rbln/transformers/models/midm/midm_architecture.py +14 -22
- optimum/rbln/transformers/models/mistral/__init__.py +2 -2
- optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
- optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
- optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
- optimum/rbln/transformers/models/opt/__init__.py +2 -2
- optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
- optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
- optimum/rbln/transformers/models/opt/opt_architecture.py +16 -25
- optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
- optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
- optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
- optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
- optimum/rbln/transformers/models/phi/__init__.py +2 -2
- optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
- optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
- optimum/rbln/transformers/models/phi/phi_architecture.py +16 -22
- optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
- optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
- optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +315 -0
- optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
- optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
- optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
- optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
- optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
- optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -15
- optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +1 -4
- optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
- optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
- optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
- optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
- optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -12
- optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
- optimum/rbln/transformers/models/siglip/__init__.py +2 -6
- optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
- optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
- optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -5
- optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -12
- optimum/rbln/transformers/models/whisper/modeling_whisper.py +8 -2
- optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
- optimum/rbln/utils/depreacate_utils.py +16 -0
- optimum/rbln/utils/hub.py +8 -47
- optimum/rbln/utils/runtime_utils.py +31 -5
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/METADATA +1 -1
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/RECORD +120 -103
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/WHEEL +0 -0
- {optimum_rbln-0.8.1rc0.dist-info → optimum_rbln-0.8.2.dist-info}/licenses/LICENSE +0 -0
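The diff reproduced below appears to cover optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py, where the decoder-only support is split into a prefill-only RBLNDecoderOnlyModel base class (registered against AutoModel and returning last hidden states) and an RBLNDecoderOnlyModelForCausalLM subclass that keeps the decode runtimes behind a can_generate flag. The base class runs prefill in fixed-size chunks; the padding arithmetic it uses (see _preprocess_chunked_prefill in the hunks below) can be summarized with this minimal sketch, in which the helper name chunk_layout and the example numbers are illustrative and not part of the package:

# Sketch of the chunked-prefill padding logic shown in the diff below (illustrative only).
def chunk_layout(query_length: int, prefill_chunk_size: int):
    # Pad the valid sequence length up to a multiple of the chunk size.
    padding_size = (prefill_chunk_size - (query_length % prefill_chunk_size)) % prefill_chunk_size
    padded_len = query_length + padding_size
    # Prefill then walks the padded sequence chunk by chunk: [step, step + prefill_chunk_size).
    steps = list(range(0, query_length, prefill_chunk_size))
    return padding_size, padded_len, steps

print(chunk_layout(300, 128))  # (84, 384, [0, 128, 256])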
optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import inspect
-import math
 from collections import deque
 from dataclasses import dataclass
 from pathlib import Path
@@ -22,7 +21,8 @@ from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tu
 import rebel
 import torch
 from rebel.compile_context import CompileContext
-from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.modeling_utils import no_init_weights
 from transformers.utils import ModelOutput
 
@@ -30,14 +30,15 @@ from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
 from ....utils.runtime_utils import RBLNPytorchRuntime
-from ...
-
-from .decoderonly_architecture import (
-    DecoderOnlyWrapper,
+from ...modeling_attention_utils import (
+    RBLNDecoderOnlyFlashAttentionMixin,
     set_default_values,
     validate_attention_method,
-
+    validate_sliding_window,
 )
+from ...utils.rbln_quantization import prepare_model_for_quantization
+from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+from .decoderonly_architecture import DecoderOnlyWrapper
 
 
 logger = get_logger()
@@ -267,7 +268,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
 
         attention_mask = self.dec_attn_mask
 
-        if self.rbln_config.
+        if self.rbln_config.use_global_attention and self.batch_size < block_tables.shape[0]:
             block_tables = block_tables[: self.batch_size]
 
         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
@@ -283,7 +284,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
             position_ids if self.rbln_config.use_position_ids else None,
         )
 
-        return
+        return RBLNDecoderOnlyForCausalLMOutput(logits=logits)
 
     def _prepare_prefill_inputs(
         self,
@@ -303,6 +304,8 @@
         position_embed = (
             position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
         )
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids[:, attention_mask.bool()] if attention_mask is not None else token_type_ids
 
         query_length = inputs.shape[1]
         if query_length > self.rbln_config.max_seq_len:
@@ -352,8 +355,11 @@
         if position_embed is not None:
             position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
 
+        if token_type_ids is not None:
+            token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
+
         # Overwrite position_ids and padded_cache_lengths
-        position_ids =
+        position_ids = cache_position.clone()
         padded_cache_lengths = 0
 
         return (
@@ -365,6 +371,7 @@
             position_embed,
             padded_cache_lengths,
             query_length,
+            token_type_ids,
         )
 
     def prefill_forward(
@@ -393,6 +400,7 @@
             position_embed,
             padded_cache_lengths,
             query_length,
+            token_type_ids,
         ) = self._prepare_prefill_inputs(
             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
         )
@@ -442,94 +450,64 @@
             self.dec_attn_mask[batch_idx].fill_(0)
             self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
 
-        return
+        return RBLNDecoderOnlyForCausalLMOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)
 
 
 @dataclass
-class
+class RBLNDecoderOnlyForCausalLMOutput(ModelOutput):
     logits: torch.FloatTensor = None
     generate_idx: torch.Tensor = None
     padded_cache_lengths: int = None
 
 
-class
+class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     """
-    A base class for decoder-only transformer models
+    A base class for decoder-only transformer models outputting raw hidden-states without any specific head on top.
+    This class is used for RBLN-optimized models that are not causal language models.
     This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
 
     The class provides core functionality for:
 
     1. Converting pre-trained transformer models to RBLN-optimized format
     2. Handling the compilation process for RBLN devices
-    3. Managing inference operations for
+    3. Managing inference operations for decoder-only architectures
 
     This class inherits from RBLNModel and implements specific methods required for
-    decoder-only architectures
+    decoder-only architectures.
 
     Note:
     - This class is designed to be subclassed by specific model implementations
-      (e.g.,
+      (e.g., RBLNLlamaModel, RBLNQwen2Model)
     - Subclasses should implement model-specific conversion logic.
     - The class handles RBLN-specific optimizations automatically during compilation
     """
 
     main_input_name = "input_ids"
-    auto_model_class =
+    auto_model_class = AutoModel
     _decoder_wrapper_cls = DecoderOnlyWrapper
     _use_rotary_emb = True
 
     def __post_init__(self, **kwargs):
-        main_input_name = self.main_input_name
-
         if self.rbln_config.use_inputs_embeds:
-            main_input_name = "inputs_embeds"
             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
             self.embed_tokens = self._create_embedding_layer()
             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
         else:
             self.embed_tokens = None
 
-        #
-
-            self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
-        )
-        block_tables = torch.zeros(
-            self.rbln_config.batch_size,
-            self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
-            dtype=torch.int16,
-        ).fill_(-1)
-        free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
-
-        self.prefill_decoder = RBLNRuntimeModel(
-            runtime=self.model[0],
-            main_input_name=main_input_name,
-            embed_tokens=self.embed_tokens,
-            phase="prefill",
-            batch_size=self.rbln_config.batch_size,
-            dec_attn_mask=dec_attn_mask,
-            block_tables=block_tables,
-            free_block_pool=free_block_pool,
-            rbln_config=self.rbln_config,
-            vocab_size=self.config.vocab_size,
-        )
+        # TODO: add prefill runtime class.
+        self.prefill_decoder = RBLNPytorchRuntime(runtime=self.model[0])
 
-
-
-            self.
-
-
-
-
-
-                dec_attn_mask=dec_attn_mask,
-                block_tables=block_tables,
-                free_block_pool=free_block_pool,
-                rbln_config=self.rbln_config,
+        # attributes for prefill
+        if self.rbln_config.use_global_attention:
+            self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+        if self.rbln_config.use_local_attention:
+            self.local_block_tables = torch.tensor([0], dtype=torch.int16)
+        if self.rbln_config.use_attention_mask:
+            self.causal_mask = 1 - torch.triu(
+                torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
             )
 
-        # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
-        self.decoder = self.decoders[self.rbln_config.batch_size]
-
     @classmethod
     def save_torch_artifacts(
         cls,
@@ -564,79 +542,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return self.rbln_config.kvcache_num_blocks
 
     @classmethod
-    def
-        cls,
-        model_id: str,
-        config: Optional[PretrainedConfig] = None,
-        use_auth_token: Optional[Union[bool, str]] = None,
-        revision: Optional[str] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
-        subfolder: str = "",
-        local_files_only: bool = False,
-        trust_remote_code: bool = False,
-        **kwargs,
-    ):
-        kwargs = cls.update_kwargs(kwargs)
-
-        if config is None:
-            config = AutoConfig.from_pretrained(
-                model_id,
-                use_auth_token=use_auth_token,
-                revision=revision,
-                force_download=force_download,
-                cache_dir=cache_dir,
-                trust_remote_code=trust_remote_code,
-                **kwargs,
-            )
-
-        with no_init_weights():
-            model = AutoModelForCausalLM.from_config(config)
-
-        model = prepare_model_for_quantization(
-            model,
-            model_id,
-            kwargs.get("num_hidden_layers"),
-            use_auth_token=use_auth_token,
-            revision=revision,
-            cache_dir=cache_dir,
-            force_download=force_download,
-            local_files_only=local_files_only,
-        )
-        return model
-
-    def __getattr__(self, __name: str) -> Any:
-        # Special method to delegate attribute access to the original Huggingface LM class.
-        # This method is called when an attribute is not found in the current instance's dictionary.
-        # It enables transparent access to the original model's attributes and methods while maintaining
-        # proper method binding.
-
-        # The method implements a delegation pattern that:
-
-        # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
-        # 2. For other attributes: Returns them directly from the original class
-
-        def redirect(func):
-            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-        val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
-        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-            return redirect(val)
-        return val
-
-    @classmethod
-    def get_pytorch_model(
-        cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
-    ) -> PreTrainedModel:
-        if rbln_config and rbln_config.quantization:
-            model = cls.get_quantized_model(*args, **kwargs)
-        else:
-            model = super().get_pytorch_model(*args, **kwargs)
-
-        return model
-
-    @classmethod
-    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
         wrapper_cfg = {
             "max_seq_len": rbln_config.max_seq_len,
             "attn_impl": rbln_config.attn_impl,
@@ -653,205 +559,95 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
 
     @classmethod
-
-
-        wrapped_model
-
-
-
+    def _compile_model(
+        cls,
+        wrapped_model,
+        compile_config,
+        example_inputs,
+        compile_context,
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+        quantization=None,
+        phase: str = "prefill",
+    ):
+        try:
+            wrapped_model.phase = phase
+            if quantization:
+                quantization.maybe_set_quantization_env()
+            original_linear = torch.nn.functional.linear
+            torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
+            compiled_model = cls.compile(
+                wrapped_model,
+                compile_config,
+                create_runtimes=rbln_config.create_runtimes,
+                device=rbln_config.device,
+                example_inputs=example_inputs,
+                compile_context=compile_context,
+            )
+            return compiled_model
+        finally:
+            torch.nn.functional.linear = original_linear
+            if quantization:
+                quantization.maybe_reset_quantization_env()
 
+    @classmethod
+    def _get_compile_context(
+        cls,
+        compile_config: RBLNCompileConfig,
+        example_inputs: List[torch.Tensor],
+    ):
         context = CompileContext(use_weight_sharing=True)
 
-        # Here we use meta tensor, for the memory efficiency.
-        meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
-        prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
-
         # Mark static tensors (self kv states)
         static_tensors = {}
-        for (name, _, _), tensor in zip(
+        for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
             if "past_key_values" in name:
                 static_tensors[name] = tensor
                 context.mark_static_address(tensor)
 
-
-            try:
-                if quantization:
-                    quantization.maybe_set_quantization_env()
-                original_linear = torch.nn.functional.linear
-                torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
-                compiled_model = cls.compile(
-                    wrapped_model,
-                    compile_config,
-                    create_runtimes=rbln_config.create_runtimes,
-                    device=rbln_config.device,
-                    example_inputs=example_inputs,
-                    compile_context=compile_context,
-                )
-                return compiled_model
-            finally:
-                torch.nn.functional.linear = original_linear
-                if quantization:
-                    quantization.maybe_reset_quantization_env()
-
-        wrapped_model.phase = "prefill"
-        compiled_prefill = compile_model(
-            wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
-        )
-
-        wrapped_model.phase = "decode"
-        compiled_models = {"prefill": compiled_prefill}
-        for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
-            dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
-            compiled_decoder = compile_model(
-                wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
-            )
-            compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
-
-        # check if the memory is enough to have additional blocks
-        required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-        if rbln_config.kvcache_num_blocks < required_num_blocks:
-            cls.maybe_suggest_kvcache_num_blocks(
-                compiled_models=compiled_models,
-                model_config=model.config,
-                rbln_config=rbln_config,
-            )
-
-        return compiled_models
+        return context, static_tensors
 
     @classmethod
-
+    @torch.inference_mode()
+    def get_compiled_model(
         cls,
-
-
-
-
-
-
-
-
-
-
-
-
-
-        alloc_memory_by_key[key] += sum(memory_per_node)
-        alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache Key
-        alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
-        kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
-
-        # Get the maximum number of blocks that can be allocated
-        buffer = sum(alloc_memory_by_key.values())
-        max_num_blocks = cls.get_maximum_num_blocks(
-            config=model_config,
-            tensor_parallel_size=rbln_config.tensor_parallel_size,
-            kvcache_block_size=rbln_config.kvcache_block_size,
-            kernel_size=kernel_size,
-            buffer=buffer,
+        model: PreTrainedModel,
+        rbln_config: RBLNDecoderOnlyModelConfig,
+    ):
+        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+        compile_config = rbln_config.compile_cfgs[0]
+
+        # Here we use meta tensor, for the memory efficiency.
+        meta_tensor_names = [name for name, _, _ in compile_config.input_info if "past_key_values" in name]
+        example_inputs = compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+        context, _ = cls._get_compile_context(compile_config, example_inputs)
+
+        compiled_model = cls._compile_model(
+            wrapped_model, compile_config, example_inputs, context, rbln_config, rbln_config.quantization, "prefill"
         )
+        compiled_models = {"prefill": compiled_model}
 
-
-        # users can set `kvcache_num_blocks` to `max_num_blocks`.
-        # If the memory is not enough, the model will fail to compile.
-        if rbln_config.kvcache_num_blocks < max_num_blocks:
-            logger.warning(
-                f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
-                "Our analysis indicates that additional memory is available for more blocks. "
-                f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
-                "Please be advised that our memory estimation algorithm has limitations, "
-                "and increasing this value may not guarantee successful model compilation."
-            )
+        return compiled_models
 
     @classmethod
-    def
-        cls,
-
-
-        kvcache_block_size: int,
-        nbits_per_param: Optional[int] = None,
-        n_model_params: Optional[int] = None,
-        kernel_size: Optional[int] = None,
-        buffer: Optional[int] = None,
-        num_runtimes: int = 2,
-    ) -> int:
-        # We are finding max_n_blocks(x) that satisfies the following equation:
-
-        # available_dram - kernel_size - buffer
-        # - num_layers * 2 * tensor_parallel_size
-        # * align_2MB(
-        #     x
-        #     * block_size
-        #     * align_64(head_dim)
-        #     * math.ceil(num_key_value_heads / tensor_parallel_size)
-        #     * 2
-        # ) > 0
-
-        # This inequality can be rewritten as follows:
-
-        # a - c * align_2MB(b * x) > 0
-        # where
-        # a = available_dram - kernel_size - buffer
-        # b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
-        # c = num_layers * 2 * tensor_parallel_size
-
-        # We can rewrite the inequality as follows:
-        # k > align_2MB(b*x)
-        # where
-        # k = a / c
-
-        # After that, we can derive the following equation:
-        # x = floor(2**21 / b * floor((k - 1) / 2**21))
-
-        def align(x: int, nbytes: int) -> int:
-            return int(math.ceil(x / nbytes) * nbytes)
-
-        def align_2MB(x: int) -> int:
-            return align(x, 2**21)
-
-        num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
-        num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
-        head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
-        vocab_size = config.vocab_size
-        hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
-        num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
-
-        # TODO(jongho): Update if target npu is REBEL.
-        ATOM_DRAM_NBYTES = 16 * 2**30
-        ATOM_SYS_DRAM_NBYTES = 288 * 2**20
-        available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
-
-        if kernel_size is None:
-            if n_model_params is None:
-                raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
-            # Get estimated kernel size (approximated)
-            lm_heads_params = align(vocab_size, 64) * hidden_size
-            lm_heads_nbytes = (
-                align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
-            )
-            params = n_model_params - lm_heads_params
-            layer_nbytes = (
-                align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
-                * num_layers
-                * tensor_parallel_size
-            )
-            kernel_size = layer_nbytes + lm_heads_nbytes
-        elif n_model_params is not None:
-            raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
-
-        available_dram -= kernel_size
+    def get_quantized_model(
+        cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
+    ) -> PreTrainedModel:
+        raise NotImplementedError
 
-
-
-
-
-
-
+    @classmethod
+    def get_pytorch_model(
+        cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
+    ) -> PreTrainedModel:
+        if rbln_config and rbln_config.quantization:
+            model = cls.get_quantized_model(*args, **kwargs)
+        else:
+            model = super().get_pytorch_model(*args, **kwargs)
 
-
-        c = num_layers * 2 * tensor_parallel_size
-        k = available_dram / c
-        max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
+        return model
 
-
+    @classmethod
+    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+        return use_local_attention
 
     @classmethod
     def get_input_info(
@@ -861,13 +657,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
         model_config: PretrainedConfig,
     ):
-        is_prefill: bool = query_length > 1
         num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
         num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
         num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
         hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
         head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
-
+        is_prefill = query_length > 1
 
         # 1. main input
         if rbln_config.use_inputs_embeds:
@@ -886,16 +681,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         ]
 
         # 3. block_tables
-        if rbln_config.
+        if rbln_config.use_global_attention:
            max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
            input_info.extend(
                [("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")]
            )
-        if rbln_config.
+        if rbln_config.use_local_attention:
             input_info.extend([("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16")])
 
-        # 4. query_position
-        if is_prefill:
+        # 4. query_position for sliding window attention
+        if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
             input_info.extend([("query_position", [], "int16")])
 
         # 5. attention_mask & position_ids
@@ -917,7 +712,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                 rbln_config.kvcache_block_size,
                 head_dim,
             ]
-            local_kvcache_shape = [
+            local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
             input_info.extend(
                 [
                     (
@@ -964,13 +759,38 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         # ```
 
         # Returns:
-        #
+        #     RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.
 
         raise NotImplementedError(
             "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
             "See method docstring for required configuration details."
         )
 
+    @classmethod
+    def _update_attention_config(
+        cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
+        rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
+            attn_impl=rbln_config.attn_impl,
+            kvcache_partition_len=rbln_config.kvcache_partition_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            max_seq_len=rbln_config.max_seq_len,
+        )
+
+        validate_attention_method(
+            attn_impl=rbln_config.attn_impl,
+            kvcache_partition_len=rbln_config.kvcache_partition_len,
+            kvcache_block_size=rbln_config.kvcache_block_size,
+            max_seq_len=rbln_config.max_seq_len,
+        )
+
+        if rbln_config.kvcache_num_blocks is None:
+            rbln_config.kvcache_num_blocks = (
+                rbln_config.max_seq_len // rbln_config.kvcache_block_size
+            ) * rbln_config.batch_size
+
+        return rbln_config
+
     @classmethod
     def _update_rbln_config(
         cls,
@@ -991,8 +811,384 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
     ):
         rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
         if rbln_config.sliding_window is not None:
-
+            validate_sliding_window(rbln_config)
+
+        rbln_config = cls._update_attention_config(model, model_config, rbln_config)
+
+        prefill_input_info = cls.get_input_info(
+            batch_size=1,
+            query_length=rbln_config.prefill_chunk_size,
+            rbln_config=rbln_config,
+            model_config=model_config,
+        )
+
+        prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+        rbln_config.set_compile_cfgs([prefill_compile_config])
+
+        return rbln_config
+
+    @classmethod
+    def _create_runtimes(
+        cls,
+        compiled_models: List[rebel.RBLNCompiledModel],
+        rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+    ) -> List[rebel.Runtime]:
+        expected_model_names = [
+            "prefill",
+        ]
+        if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
+            cls._raise_missing_compiled_file_error(expected_model_names)
+
+        return [
+            rebel.Runtime(
+                compiled_models[0],
+                tensor_type="pt",
+                device=rbln_config.device_map["prefill"],
+                activate_profiler=rbln_config.activate_profiler,
+            ),
+        ]
+
+    def _preprocess_chunked_prefill(
+        self,
+        inputs: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
+    ):
+        # valid sequence length of inputs_embeds
+        query_length = inputs.shape[1] if attention_mask is None else torch.sum(attention_mask.view(-1)).item()
+
+        # extract valid inputs
+        inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
+
+        if inputs.dim() == 2 and self.rbln_config.use_inputs_embeds:
+            inputs = self.get_input_embeddings()(inputs)
+
+        if position_embed is not None:
+            position_embed = (
+                position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
+            )
+
+        # padding for chunked prefill
+        padding_size = (
+            self.rbln_config.prefill_chunk_size - (query_length % self.rbln_config.prefill_chunk_size)
+        ) % self.rbln_config.prefill_chunk_size
+        padded_len = query_length + padding_size
+
+        inputs = (
+            torch.nn.functional.pad(inputs, (0, padding_size))
+            if not self.rbln_config.use_inputs_embeds
+            else torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+        )
+        position_embed = (
+            None if position_embed is None else torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+        )
+        cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)
+
+        chunked_attention_mask = (
+            torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
+            if self.rbln_config.use_attention_mask
+            else None
+        )
+
+        return inputs, position_embed, cache_position, query_length, chunked_attention_mask
+
+    def _chunked_prefill_forward(
+        self,
+        inputs: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
+    ):
+        padded_input, padded_position_embed, cache_position, query_length, chunked_attention_mask = (
+            self._preprocess_chunked_prefill(inputs, attention_mask, position_embed)
+        )
+
+        # chunked prefill
+        last_hidden_states = []
+        for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+            # Extract the current chunk of inputs and cache positions
+            input_chunk = padded_input[:, step : step + self.rbln_config.prefill_chunk_size]
+            cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
+
+            valid_length = (
+                self.rbln_config.prefill_chunk_size
+                if (step + self.rbln_config.prefill_chunk_size) <= query_length
+                else query_length - step
+            )
+            if self.rbln_config.use_local_attention:
+                query_position = torch.tensor(valid_length - 1, dtype=torch.int16)
+            else:
+                query_position = None
+
+            if self.rbln_config.use_attention_mask:
+                if step > 0:
+                    chunked_attention_mask[:, :, :, :step] = 1
+                chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
+
+            # Forward pass for the current chunk
+            last_hidden_states_chunk = self.prefill_decoder(
+                input_ids=input_chunk if not self.rbln_config.use_inputs_embeds else None,
+                inputs_embeds=input_chunk if self.rbln_config.use_inputs_embeds else None,
+                cache_position=cache_pos_chunk,
+                block_tables=self.block_tables if self.rbln_config.use_global_attention else None,
+                local_block_tables=self.local_block_tables if self.rbln_config.use_local_attention else None,
+                query_position=query_position,
+                attention_mask=chunked_attention_mask,
+                position_emb=padded_position_embed,
+            )
+            last_hidden_states.append(last_hidden_states_chunk)
+        last_hidden_states = torch.concat(last_hidden_states, dim=-2)[:, :query_length]
+
+        return self._postprocess_chunked_prefill(last_hidden_states, attention_mask)
+
+    def _postprocess_chunked_prefill(
+        self, last_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+    ):
+        # index copy for attention mask
+        if attention_mask is not None:
+            new_last_hidden_states = torch.full(
+                (1, attention_mask.shape[-1], last_hidden_states.shape[-1]),
+                fill_value=1e-10,
+                dtype=last_hidden_states.dtype,
+            )
+            mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+            new_last_hidden_states.index_copy_(dim=-2, index=mask_indices, source=last_hidden_states)
+        else:
+            new_last_hidden_states = last_hidden_states
+        return new_last_hidden_states
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_embed: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.FloatTensor]:
+        inputs = inputs_embeds if inputs_embeds is not None else input_ids
+        batch_size = inputs.shape[0]
+        all_last_hidden_states = []
+        for b_idx in range(batch_size):
+            last_hidden_states = self._chunked_prefill_forward(
+                inputs[b_idx : b_idx + 1],
+                attention_mask[b_idx] if attention_mask is not None else None,
+                position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
+            )
+            all_last_hidden_states.append(last_hidden_states)
+
+        last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
+        return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
+
+
+class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel):
+    """
+    A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+    This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
+
+    The class provides core functionality for:
+
+    1. Converting pre-trained transformer models to RBLN-optimized format
+    2. Handling the compilation process for RBLN devices
+    3. Managing inference operations for causal language modeling
+
+    This class inherits from RBLNModel and implements specific methods required for
+    decoder-only architectures and causal language modeling tasks.
+
+    Note:
+    - This class is designed to be subclassed by specific model implementations
+      (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+    - Subclasses should implement model-specific conversion logic.
+    - The class handles RBLN-specific optimizations automatically during compilation
+    """
+
+    auto_model_class = AutoModelForCausalLM
+
+    def __post_init__(self, **kwargs):
+        main_input_name = self.main_input_name
+
+        if self.rbln_config.use_inputs_embeds:
+            main_input_name = "inputs_embeds"
+            artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+            self.embed_tokens = self._create_embedding_layer()
+            self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+        else:
+            self.embed_tokens = None
+
+        # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
+        dec_attn_mask = torch.zeros(
+            self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
+        )
+        block_tables = torch.zeros(
+            self.rbln_config.batch_size,
+            self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
+            dtype=torch.int16,
+        ).fill_(-1)
+        free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+
+        self.prefill_decoder = RBLNRuntimeModel(
+            runtime=self.model[0],
+            main_input_name=main_input_name,
+            embed_tokens=self.embed_tokens,
+            phase="prefill",
+            batch_size=self.rbln_config.batch_size,
+            dec_attn_mask=dec_attn_mask,
+            block_tables=block_tables,
+            free_block_pool=free_block_pool,
+            rbln_config=self.rbln_config,
+            vocab_size=self.config.vocab_size,
+        )
+
+        if self.can_generate():
+            self.decoders = {}
+            for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
+                self.decoders[batch_size] = RBLNRuntimeModel(
+                    runtime=self.model[i + 1],
+                    main_input_name=main_input_name,
+                    embed_tokens=self.embed_tokens,
+                    phase="decode",
+                    batch_size=batch_size,
+                    dec_attn_mask=dec_attn_mask,
+                    block_tables=block_tables,
+                    free_block_pool=free_block_pool,
+                    rbln_config=self.rbln_config,
+                )
+
+            # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
+            self.decoder = self.decoders[self.rbln_config.batch_size]
+
+    @classmethod
+    def get_quantized_model(
+        cls,
+        model_id: str,
+        config: Optional[PretrainedConfig] = None,
+        use_auth_token: Optional[Union[bool, str]] = None,
+        revision: Optional[str] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        subfolder: str = "",
+        local_files_only: bool = False,
+        trust_remote_code: bool = False,
+        **kwargs,
+    ):
+        kwargs = cls.update_kwargs(kwargs)
+
+        if config is None:
+            config = AutoConfig.from_pretrained(
+                model_id,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                force_download=force_download,
+                cache_dir=cache_dir,
+                trust_remote_code=trust_remote_code,
+                **kwargs,
+            )
+
+        with no_init_weights():
+            model = AutoModelForCausalLM.from_config(config)
+
+        model = prepare_model_for_quantization(
+            model,
+            model_id,
+            kwargs.get("num_hidden_layers"),
+            use_auth_token=use_auth_token,
+            revision=revision,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            local_files_only=local_files_only,
+        )
+        return model
+
+    def __getattr__(self, __name: str) -> Any:
+        # Special method to delegate attribute access to the original Huggingface LM class.
+        # This method is called when an attribute is not found in the current instance's dictionary.
+        # It enables transparent access to the original model's attributes and methods while maintaining
+        # proper method binding.
+
+        # The method implements a delegation pattern that:
+
+        # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
+        # 2. For other attributes: Returns them directly from the original class
 
+        def redirect(func):
+            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+        val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
+        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+            return redirect(val)
+        return val
+
+    @classmethod
+    def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+        wrapper_cfg = {
+            "max_seq_len": rbln_config.max_seq_len,
+            "attn_impl": rbln_config.attn_impl,
+            "kvcache_partition_len": rbln_config.kvcache_partition_len,
+            "kvcache_block_size": rbln_config.kvcache_block_size,
+            "use_rotary_emb": cls._use_rotary_emb,
+            "use_attention_mask": rbln_config.use_attention_mask,
+            "use_position_ids": rbln_config.use_position_ids,
+            "use_inputs_embeds": rbln_config.use_inputs_embeds,
+            "cache_impl": rbln_config.cache_impl,
+            "sliding_window": rbln_config.sliding_window,
+            "sliding_window_layers": rbln_config.sliding_window_layers,
+        }
+        return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+
+    @classmethod
+    @torch.inference_mode()
+    def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+        wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+        prefill_compile_config = rbln_config.compile_cfgs[0]
+
+        # Here we use meta tensor, for the memory efficiency.
+        meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
+        prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+        context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
+
+        compiled_models = {}
+        compiled_models["prefill"] = cls._compile_model(
+            wrapped_model,
+            prefill_compile_config,
+            prefill_example_inputs,
+            context,
+            rbln_config,
+            rbln_config.quantization,
+            phase="prefill",
+        )
+
+        if rbln_config.can_generate:
+            wrapped_model.phase = "decode"
+            for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
+                dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+                compiled_decoder = cls._compile_model(
+                    wrapped_model,
+                    dec_compile_config,
+                    dec_example_inputs,
+                    context,
+                    rbln_config,
+                    rbln_config.quantization,
+                    phase="decode",
+                )
+                compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
+
+            # check if the memory is enough to have additional blocks
+            required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+            if rbln_config.kvcache_num_blocks < required_num_blocks:
+                cls.maybe_suggest_kvcache_num_blocks(
+                    compiled_models=compiled_models,
+                    model_config=model.config,
+                    rbln_config=rbln_config,
+                )
+
+        return compiled_models
+
+    @classmethod
+    def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+        return is_prefill
+
+    @classmethod
+    def _update_attention_config(
+        cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+    ):
         rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
             attn_impl=rbln_config.attn_impl,
             kvcache_partition_len=rbln_config.kvcache_partition_len,
@@ -1017,13 +1213,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
             kvcache_block_size=rbln_config.kvcache_block_size,
             nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
             n_model_params=sum(p.numel() for p in model.parameters()),
-            num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
+            num_runtimes=1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
         )
 
         max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
 
         flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
-        if max_num_blocks < flash_min_blocks:
+        if rbln_config.batch_size > 1 and max_num_blocks < flash_min_blocks:
             max_num_blocks = flash_min_blocks
 
         if max_num_blocks < rbln_config.batch_size:
@@ -1042,27 +1238,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         )
         logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
 
-
-            batch_size=1,
-            query_length=rbln_config.prefill_chunk_size,
-            rbln_config=rbln_config,
-            model_config=model_config,
-        )
-
-        prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+        return rbln_config
 
-
-
-
-
-
-
-
-
-
-
-
-
+    @classmethod
+    def _update_rbln_config(
+        cls,
+        preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+        model: Optional[PreTrainedModel] = None,
+        model_config: Optional[PretrainedConfig] = None,
+        rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+    ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+        rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
+        if rbln_config.can_generate:
+            compile_configs = rbln_config.compile_cfgs
+            for batch_size in rbln_config.decoder_batch_sizes:
+                dec_input_info = cls.get_input_info(
+                    batch_size=batch_size,
+                    query_length=1,
+                    rbln_config=rbln_config,
+                    model_config=model_config,
+                )
+                compile_configs.append(
+                    RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
+                )
+            rbln_config.set_compile_cfgs(compile_configs)
 
         return rbln_config
 
@@ -1072,36 +1271,45 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         compiled_models: List[rebel.RBLNCompiledModel],
         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
     ) -> List[rebel.Runtime]:
-        expected_model_names = [
-
-
-
+        expected_model_names = ["prefill"]
+        if rbln_config.can_generate:
+            expected_model_names.extend(
+                [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
+            )
         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
             cls._raise_missing_compiled_file_error(expected_model_names)
 
-
+        ret_val = [
             rebel.Runtime(
                 compiled_models[0],
                 tensor_type="pt",
                 device=rbln_config.device_map["prefill"],
                 activate_profiler=rbln_config.activate_profiler,
-
-
-            rebel.Runtime(
-                compiled_models[i + 1],
-                tensor_type="pt",
-                device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
-                activate_profiler=rbln_config.activate_profiler,
-            )
-            for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-            ],
+                timeout=rbln_config.timeout,
+            )
         ]
+        if rbln_config.can_generate:
+            ret_val.extend(
+                [
+                    rebel.Runtime(
+                        compiled_models[i + 1],
+                        tensor_type="pt",
+                        device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
+                        activate_profiler=rbln_config.activate_profiler,
+                        timeout=rbln_config.timeout,
+                    )
+                    for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
+                ]
+            )
+        return ret_val
 
     def get_decoder(self):
+        if not self.can_generate():
+            raise ValueError("Decode stage is not supported in this model.")
         return self.decoder
 
     def can_generate(self):
-        return
+        return self.rbln_config.can_generate
 
     def _reorder_cache(self, past_key_values, beam_idx):
         raise NotImplementedError
@@ -1158,7 +1366,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
     def _update_model_kwargs_for_generation(
         self,
-        outputs:
+        outputs: RBLNDecoderOnlyForCausalLMOutput,
         model_kwargs: Dict[str, Any],
         **kwargs,
     ) -> Dict[str, Any]:
@@ -1186,7 +1394,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         # A for-loop ensures synchronization with the HuggingFace generate API.
         # The decoder stage operates as usual, processing inputs in batch mode.
 
-        #
+        # for only use forward
+        if generate_idx is None:
+            generate_idx = (
+                attention_mask.sum(dim=-1, keepdim=True).int()
+                if attention_mask is not None
+                else torch.full((input_ids.shape[0], 1), input_ids.shape[1], dtype=torch.int32)
+            )
+            padded_cache_lengths = torch.zeros_like(generate_idx)
+
+        # Prefill
         if cache_position is None:
             logits = []
             inputs = inputs_embeds if inputs_embeds is not None else input_ids
@@ -1224,6 +1441,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
         if not return_dict:
             return logits, generate_idx, padded_cache_lengths
         else:
-            return
+            return RBLNDecoderOnlyForCausalLMOutput(
                 logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
             )