optimum-rbln 0.8.2a3__py3-none-any.whl → 0.8.2a5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (58)
  1. optimum/rbln/__init__.py +36 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +4 -0
  4. optimum/rbln/ops/kv_cache_update.py +5 -0
  5. optimum/rbln/ops/linear.py +7 -0
  6. optimum/rbln/transformers/__init__.py +40 -0
  7. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  8. optimum/rbln/transformers/models/__init__.py +31 -14
  9. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -4
  10. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  11. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +204 -44
  12. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +124 -208
  13. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +567 -366
  14. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  15. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  16. optimum/rbln/transformers/models/gemma/modeling_gemma.py +13 -1
  17. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +0 -6
  18. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +10 -6
  19. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  20. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  21. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -7
  22. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +16 -1
  23. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -2
  24. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  25. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  26. optimum/rbln/transformers/models/llama/modeling_llama.py +13 -1
  27. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
  28. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  29. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  30. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  31. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  32. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  33. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  34. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  35. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  36. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  37. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  38. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  39. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +163 -0
  40. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  41. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  42. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  43. optimum/rbln/transformers/models/phi/phi_architecture.py +6 -6
  44. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  45. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  46. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  47. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -3
  48. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  49. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +10 -328
  50. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +0 -241
  51. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +0 -10
  52. optimum/rbln/transformers/models/whisper/configuration_whisper.py +1 -10
  53. optimum/rbln/transformers/models/whisper/modeling_whisper.py +5 -1
  54. optimum/rbln/utils/depreacate_utils.py +16 -0
  55. {optimum_rbln-0.8.2a3.dist-info → optimum_rbln-0.8.2a5.dist-info}/METADATA +1 -1
  56. {optimum_rbln-0.8.2a3.dist-info → optimum_rbln-0.8.2a5.dist-info}/RECORD +58 -52
  57. {optimum_rbln-0.8.2a3.dist-info → optimum_rbln-0.8.2a5.dist-info}/WHEEL +0 -0
  58. {optimum_rbln-0.8.2a3.dist-info → optimum_rbln-0.8.2a5.dist-info}/licenses/LICENSE +0 -0
@@ -13,7 +13,6 @@
  # limitations under the License.
 
  import inspect
- import math
  from collections import deque
  from dataclasses import dataclass
  from pathlib import Path
@@ -22,7 +21,8 @@ from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tu
  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithPast
  from transformers.modeling_utils import no_init_weights
  from transformers.utils import ModelOutput
 
@@ -30,14 +30,15 @@ from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
  from ....utils.logging import get_logger
  from ....utils.runtime_utils import RBLNPytorchRuntime
- from ...utils.rbln_quantization import prepare_model_for_quantization
- from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
- from .decoderonly_architecture import (
-     DecoderOnlyWrapper,
+ from ...modeling_attention_utils import (
+     RBLNDecoderOnlyFlashAttentionMixin,
      set_default_values,
      validate_attention_method,
-     validate_sliding_window_size,
+     validate_sliding_window,
  )
+ from ...utils.rbln_quantization import prepare_model_for_quantization
+ from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+ from .decoderonly_architecture import DecoderOnlyWrapper
 
 
  logger = get_logger()
@@ -267,7 +268,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
 
          attention_mask = self.dec_attn_mask
 
-         if self.rbln_config.cache_impl in ["hybrid", "static"] and self.batch_size < block_tables.shape[0]:
+         if self.rbln_config.use_global_attention and self.batch_size < block_tables.shape[0]:
              block_tables = block_tables[: self.batch_size]
 
          if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
@@ -283,7 +284,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
              position_ids if self.rbln_config.use_position_ids else None,
          )
 
-         return RBLNDecoderOnlyOutput(logits=logits)
+         return RBLNDecoderOnlyForCausalLMOutput(logits=logits)
 
      def _prepare_prefill_inputs(
          self,
@@ -303,6 +304,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          position_embed = (
              position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
          )
+         if token_type_ids is not None:
+             token_type_ids = token_type_ids[:, attention_mask.bool()] if attention_mask is not None else token_type_ids
 
          query_length = inputs.shape[1]
          if query_length > self.rbln_config.max_seq_len:
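The masking pattern used above for token_type_ids (and for the other prefill inputs) simply drops padded positions before the prefill pass. A minimal standalone sketch of that indexing, using made-up tensors rather than the runtime's real inputs:

    import torch

    # Hypothetical batch of 1, sequence of 6 with 2 padded positions.
    input_ids = torch.tensor([[5, 6, 7, 8, 0, 0]])
    attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])

    # Boolean indexing keeps only the valid tokens, mirroring
    # `inputs[:, attention_mask.bool()]` in _prepare_prefill_inputs.
    valid = input_ids[:, attention_mask[0].bool()]
    print(valid.shape)  # torch.Size([1, 4])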
@@ -447,94 +450,64 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
          self.dec_attn_mask[batch_idx].fill_(0)
          self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
 
-         return RBLNDecoderOnlyOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)
+         return RBLNDecoderOnlyForCausalLMOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)
 
 
  @dataclass
- class RBLNDecoderOnlyOutput(ModelOutput):
+ class RBLNDecoderOnlyForCausalLMOutput(ModelOutput):
      logits: torch.FloatTensor = None
      generate_idx: torch.Tensor = None
      padded_cache_lengths: int = None
 
 
- class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
+ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
      """
-     A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+     A base class for decoder-only transformer models outputting raw hidden-states without any specific head on top.
+     This class is used for RBLN-optimized models that are not causal language models.
      This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
 
      The class provides core functionality for:
 
      1. Converting pre-trained transformer models to RBLN-optimized format
      2. Handling the compilation process for RBLN devices
-     3. Managing inference operations for causal language modeling
+     3. Managing inference operations for decoder-only architectures
 
      This class inherits from RBLNModel and implements specific methods required for
-     decoder-only architectures and causal language modeling tasks.
+     decoder-only architectures.
 
      Note:
      - This class is designed to be subclassed by specific model implementations
-       (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+       (e.g., RBLNLlamaModel, RBLNQwen2Model)
      - Subclasses should implement model-specific conversion logic.
      - The class handles RBLN-specific optimizations automatically during compilation
      """
 
      main_input_name = "input_ids"
-     auto_model_class = AutoModelForCausalLM
+     auto_model_class = AutoModel
      _decoder_wrapper_cls = DecoderOnlyWrapper
      _use_rotary_emb = True
 
      def __post_init__(self, **kwargs):
-         main_input_name = self.main_input_name
-
          if self.rbln_config.use_inputs_embeds:
-             main_input_name = "inputs_embeds"
              artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
              self.embed_tokens = self._create_embedding_layer()
             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
          else:
              self.embed_tokens = None
 
-         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
-         dec_attn_mask = torch.zeros(
-             self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
-         )
-         block_tables = torch.zeros(
-             self.rbln_config.batch_size,
-             self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
-             dtype=torch.int16,
-         ).fill_(-1)
-         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
-
-         self.prefill_decoder = RBLNRuntimeModel(
-             runtime=self.model[0],
-             main_input_name=main_input_name,
-             embed_tokens=self.embed_tokens,
-             phase="prefill",
-             batch_size=self.rbln_config.batch_size,
-             dec_attn_mask=dec_attn_mask,
-             block_tables=block_tables,
-             free_block_pool=free_block_pool,
-             rbln_config=self.rbln_config,
-             vocab_size=self.config.vocab_size,
-         )
+         # TODO: add prefill runtime class.
+         self.prefill_decoder = RBLNPytorchRuntime(runtime=self.model[0])
 
-         self.decoders = {}
-         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
-             self.decoders[batch_size] = RBLNRuntimeModel(
-                 runtime=self.model[i + 1],
-                 main_input_name=main_input_name,
-                 embed_tokens=self.embed_tokens,
-                 phase="decode",
-                 batch_size=batch_size,
-                 dec_attn_mask=dec_attn_mask,
-                 block_tables=block_tables,
-                 free_block_pool=free_block_pool,
-                 rbln_config=self.rbln_config,
+         # attributes for prefill
+         if self.rbln_config.use_global_attention:
+             self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+         if self.rbln_config.use_local_attention:
+             self.local_block_tables = torch.tensor([0], dtype=torch.int16)
+         if self.rbln_config.use_attention_mask:
+             self.causal_mask = 1 - torch.triu(
+                 torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
              )
 
-         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
-         self.decoder = self.decoders[self.rbln_config.batch_size]
-
      @classmethod
      def save_torch_artifacts(
          cls,
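For orientation, a minimal usage sketch of the new hidden-state base class. The checkpoint id is illustrative, and RBLNQwen2Model is used here only because the docstring names it as an expected subclass; exact keyword names follow the existing rbln_* convention rather than anything introduced in this diff:

    import torch
    from optimum.rbln import RBLNQwen2Model  # any RBLNDecoderOnlyModel subclass

    # Compile once (export=True), then reuse the saved compiled artifacts.
    model = RBLNQwen2Model.from_pretrained(
        "Qwen/Qwen2-0.5B",      # illustrative checkpoint
        export=True,
        rbln_max_seq_len=8192,
    )
    input_ids = torch.randint(0, 1000, (1, 32))
    attention_mask = torch.ones_like(input_ids)
    out = model(input_ids=input_ids, attention_mask=attention_mask)
    print(out.last_hidden_state.shape)  # (1, 32, hidden_size)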
@@ -569,79 +542,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          return self.rbln_config.kvcache_num_blocks
 
      @classmethod
-     def get_quantized_model(
-         cls,
-         model_id: str,
-         config: Optional[PretrainedConfig] = None,
-         use_auth_token: Optional[Union[bool, str]] = None,
-         revision: Optional[str] = None,
-         force_download: bool = False,
-         cache_dir: Optional[str] = None,
-         subfolder: str = "",
-         local_files_only: bool = False,
-         trust_remote_code: bool = False,
-         **kwargs,
-     ):
-         kwargs = cls.update_kwargs(kwargs)
-
-         if config is None:
-             config = AutoConfig.from_pretrained(
-                 model_id,
-                 use_auth_token=use_auth_token,
-                 revision=revision,
-                 force_download=force_download,
-                 cache_dir=cache_dir,
-                 trust_remote_code=trust_remote_code,
-                 **kwargs,
-             )
-
-         with no_init_weights():
-             model = AutoModelForCausalLM.from_config(config)
-
-         model = prepare_model_for_quantization(
-             model,
-             model_id,
-             kwargs.get("num_hidden_layers"),
-             use_auth_token=use_auth_token,
-             revision=revision,
-             cache_dir=cache_dir,
-             force_download=force_download,
-             local_files_only=local_files_only,
-         )
-         return model
-
-     def __getattr__(self, __name: str) -> Any:
-         # Special method to delegate attribute access to the original Huggingface LM class.
-         # This method is called when an attribute is not found in the current instance's dictionary.
-         # It enables transparent access to the original model's attributes and methods while maintaining
-         # proper method binding.
-
-         # The method implements a delegation pattern that:
-
-         # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
-         # 2. For other attributes: Returns them directly from the original class
-
-         def redirect(func):
-             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-         val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
-         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-             return redirect(val)
-         return val
-
-     @classmethod
-     def get_pytorch_model(
-         cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
-     ) -> PreTrainedModel:
-         if rbln_config and rbln_config.quantization:
-             model = cls.get_quantized_model(*args, **kwargs)
-         else:
-             model = super().get_pytorch_model(*args, **kwargs)
-
-         return model
-
-     @classmethod
-     def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+     def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
          wrapper_cfg = {
              "max_seq_len": rbln_config.max_seq_len,
              "attn_impl": rbln_config.attn_impl,
@@ -658,205 +559,95 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
 
      @classmethod
-     @torch.inference_mode()
-     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
-         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
-
-         rbln_compile_configs = rbln_config.compile_cfgs
-         prefill_compile_config = rbln_compile_configs[0]
+     def _compile_model(
+         cls,
+         wrapped_model,
+         compile_config,
+         example_inputs,
+         compile_context,
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+         quantization=None,
+         phase: str = "prefill",
+     ):
+         try:
+             wrapped_model.phase = phase
+             if quantization:
+                 quantization.maybe_set_quantization_env()
+             original_linear = torch.nn.functional.linear
+             torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
+             compiled_model = cls.compile(
+                 wrapped_model,
+                 compile_config,
+                 create_runtimes=rbln_config.create_runtimes,
+                 device=rbln_config.device,
+                 example_inputs=example_inputs,
+                 compile_context=compile_context,
+             )
+             return compiled_model
+         finally:
+             torch.nn.functional.linear = original_linear
+             if quantization:
+                 quantization.maybe_reset_quantization_env()
 
+     @classmethod
+     def _get_compile_context(
+         cls,
+         compile_config: RBLNCompileConfig,
+         example_inputs: List[torch.Tensor],
+     ):
          context = CompileContext(use_weight_sharing=True)
 
-         # Here we use meta tensor, for the memory efficiency.
-         meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
-         prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
-
          # Mark static tensors (self kv states)
          static_tensors = {}
-         for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
+         for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
              if "past_key_values" in name:
                  static_tensors[name] = tensor
                  context.mark_static_address(tensor)
 
-         def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
-             try:
-                 if quantization:
-                     quantization.maybe_set_quantization_env()
-                 original_linear = torch.nn.functional.linear
-                 torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
-                 compiled_model = cls.compile(
-                     wrapped_model,
-                     compile_config,
-                     create_runtimes=rbln_config.create_runtimes,
-                     device=rbln_config.device,
-                     example_inputs=example_inputs,
-                     compile_context=compile_context,
-                 )
-                 return compiled_model
-             finally:
-                 torch.nn.functional.linear = original_linear
-                 if quantization:
-                     quantization.maybe_reset_quantization_env()
-
-         wrapped_model.phase = "prefill"
-         compiled_prefill = compile_model(
-             wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
-         )
-
-         wrapped_model.phase = "decode"
-         compiled_models = {"prefill": compiled_prefill}
-         for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
-             dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
-             compiled_decoder = compile_model(
-                 wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
-             )
-             compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
-
-         # check if the memory is enough to have additional blocks
-         required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-         if rbln_config.kvcache_num_blocks < required_num_blocks:
-             cls.maybe_suggest_kvcache_num_blocks(
-                 compiled_models=compiled_models,
-                 model_config=model.config,
-                 rbln_config=rbln_config,
-             )
-
-         return compiled_models
+         return context, static_tensors
 
      @classmethod
-     def maybe_suggest_kvcache_num_blocks(
+     @torch.inference_mode()
+     def get_compiled_model(
          cls,
-         compiled_models: Dict[str, rebel.RBLNCompiledModel],
-         model_config: PretrainedConfig,
-         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-     ) -> None:
-         # Get the actual memory allocation of each node by key
-         alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
-         alloc_memory_by_key: Dict[str, int] = {
-             key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
-         }
-         for batch_size in rbln_config.decoder_batch_sizes:
-             for key, memory_per_node in (
-                 compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
-             ):
-                 alloc_memory_by_key[key] += sum(memory_per_node)
-         alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache Key
-         alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
-         kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
-
-         # Get the maximum number of blocks that can be allocated
-         buffer = sum(alloc_memory_by_key.values())
-         max_num_blocks = cls.get_maximum_num_blocks(
-             config=model_config,
-             tensor_parallel_size=rbln_config.tensor_parallel_size,
-             kvcache_block_size=rbln_config.kvcache_block_size,
-             kernel_size=kernel_size,
-             buffer=buffer,
+         model: PreTrainedModel,
+         rbln_config: RBLNDecoderOnlyModelConfig,
+     ):
+         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+         compile_config = rbln_config.compile_cfgs[0]
+
+         # Here we use meta tensor, for the memory efficiency.
+         meta_tensor_names = [name for name, _, _ in compile_config.input_info if "past_key_values" in name]
+         example_inputs = compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+         context, _ = cls._get_compile_context(compile_config, example_inputs)
+
+         compiled_model = cls._compile_model(
+             wrapped_model, compile_config, example_inputs, context, rbln_config, rbln_config.quantization, "prefill"
          )
+         compiled_models = {"prefill": compiled_model}
 
-         # Since our estimation logic is not always accurate,
-         # users can set `kvcache_num_blocks` to `max_num_blocks`.
-         # If the memory is not enough, the model will fail to compile.
-         if rbln_config.kvcache_num_blocks < max_num_blocks:
-             logger.warning(
-                 f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
-                 "Our analysis indicates that additional memory is available for more blocks. "
-                 f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
-                 "Please be advised that our memory estimation algorithm has limitations, "
-                 "and increasing this value may not guarantee successful model compilation."
-             )
+         return compiled_models
 
      @classmethod
-     def get_maximum_num_blocks(
-         cls,
-         config: PretrainedConfig,
-         tensor_parallel_size: int,
-         kvcache_block_size: int,
-         nbits_per_param: Optional[int] = None,
-         n_model_params: Optional[int] = None,
-         kernel_size: Optional[int] = None,
-         buffer: Optional[int] = None,
-         num_runtimes: int = 2,
-     ) -> int:
-         # We are finding max_n_blocks(x) that satisfies the following equation:
-
-         # available_dram - kernel_size - buffer
-         #     - num_layers * 2 * tensor_parallel_size
-         #     * align_2MB(
-         #         x
-         #         * block_size
-         #         * align_64(head_dim)
-         #         * math.ceil(num_key_value_heads / tensor_parallel_size)
-         #         * 2
-         #     ) > 0
-
-         # This inequality can be rewritten as follows:
-
-         #     a - c * align_2MB(b * x) > 0
-         #     where
-         #         a = available_dram - kernel_size - buffer
-         #         b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
-         #         c = num_layers * 2 * tensor_parallel_size
-
-         # We can rewrite the inequality as follows:
-         #     k > align_2MB(b*x)
-         #     where
-         #         k = a / c
-
-         # After that, we can derive the following equation:
-         #     x = floor(2**21 / b * floor((k - 1) / 2**21))
-
-         def align(x: int, nbytes: int) -> int:
-             return int(math.ceil(x / nbytes) * nbytes)
-
-         def align_2MB(x: int) -> int:
-             return align(x, 2**21)
-
-         num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
-         num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
-         head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
-         vocab_size = config.vocab_size
-         hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
-         num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
-
-         # TODO(jongho): Update if target npu is REBEL.
-         ATOM_DRAM_NBYTES = 16 * 2**30
-         ATOM_SYS_DRAM_NBYTES = 288 * 2**20
-         available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
-
-         if kernel_size is None:
-             if n_model_params is None:
-                 raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
-             # Get estimated kernel size (approximated)
-             lm_heads_params = align(vocab_size, 64) * hidden_size
-             lm_heads_nbytes = (
-                 align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
-             )
-             params = n_model_params - lm_heads_params
-             layer_nbytes = (
-                 align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
-                 * num_layers
-                 * tensor_parallel_size
-             )
-             kernel_size = layer_nbytes + lm_heads_nbytes
-         elif n_model_params is not None:
-             raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
-
-         available_dram -= kernel_size
+     def get_quantized_model(
+         cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
+     ) -> PreTrainedModel:
+         raise NotImplementedError
 
-         if buffer is None:
-             # TODO: Accurate buffer estimation
-             buffer_per_runtime_per_core = 2**28  # 256MB per runtime
-             buffer_per_core = buffer_per_runtime_per_core * num_runtimes  # 1 for prefill, 1 for decoder
-             buffer = buffer_per_core * tensor_parallel_size
-         available_dram -= buffer
+     @classmethod
+     def get_pytorch_model(
+         cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
+     ) -> PreTrainedModel:
+         if rbln_config and rbln_config.quantization:
+             model = cls.get_quantized_model(*args, **kwargs)
+         else:
+             model = super().get_pytorch_model(*args, **kwargs)
 
-         b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
-         c = num_layers * 2 * tensor_parallel_size
-         k = available_dram / c
-         max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
+         return model
 
-         return max_n_blocks
+     @classmethod
+     def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+         return use_local_attention
 
      @classmethod
      def get_input_info(
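The removed get_maximum_num_blocks helper (presumably now supplied by the new RBLNDecoderOnlyFlashAttentionMixin from modeling_attention_utils.py) solves a - c * align_2MB(b * x) > 0 for x, as its comments explain. A small standalone rework of that arithmetic with illustrative Llama-7B-like numbers; the two DRAM constants come from the removed code, everything else is a made-up example rather than a measured value:

    import math

    def align(x, nbytes):
        return math.ceil(x / nbytes) * nbytes

    # Illustrative inputs (not measured).
    tensor_parallel_size = 4
    kvcache_block_size = 8192
    head_dim, num_key_value_heads, num_layers = 128, 32, 32
    kernel_size = 14 * 2**30                       # ~14 GiB of weights, example only
    buffer = 4 * 2**28 * tensor_parallel_size      # rough per-core runtime buffers

    available_dram = tensor_parallel_size * (16 * 2**30 - 288 * 2**20)
    a = available_dram - kernel_size - buffer
    b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
    c = num_layers * 2 * tensor_parallel_size
    k = a / c
    max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
    print(max_n_blocks)  # upper bound on kvcache_num_blocks for this example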
@@ -866,13 +657,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
          model_config: PretrainedConfig,
      ):
-         is_prefill: bool = query_length > 1
          num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
          num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
          num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
          hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
          head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
-         local_kvcache_num_blocks = max(rbln_config.decoder_batch_sizes)
+         is_prefill = query_length > 1
 
          # 1. main input
          if rbln_config.use_inputs_embeds:
@@ -891,16 +681,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          ]
 
          # 3. block_tables
-         if rbln_config.cache_impl in ["static", "hybrid"]:
+         if rbln_config.use_global_attention:
              max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
              input_info.extend(
                  [("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")]
              )
-         if rbln_config.cache_impl in ["hybrid", "sliding_window"]:
+         if rbln_config.use_local_attention:
              input_info.extend([("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16")])
 
-         # 4. query_position
-         if is_prefill:
+         # 4. query_position for sliding window attention
+         if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
              input_info.extend([("query_position", [], "int16")])
 
          # 5. attention_mask & position_ids
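As a concrete reading of the block_tables shapes declared above, with max_seq_len 8192 and kvcache_block_size 1024 (illustrative numbers only): prefill sees one flat table of block ids, while decode gets one row per batch element.

    max_seq_len, kvcache_block_size, batch_size = 8192, 1024, 2
    max_block_cnt = max_seq_len // kvcache_block_size     # 8

    prefill_shape = [max_block_cnt]                        # [8]
    decode_shape = [batch_size, max_block_cnt]             # [2, 8]
    print(prefill_shape, decode_shape)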
@@ -922,7 +712,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
              rbln_config.kvcache_block_size,
              head_dim,
          ]
-         local_kvcache_shape = [local_kvcache_num_blocks, num_key_value_heads, rbln_config.sliding_window, head_dim]
+         local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
          input_info.extend(
              [
                  (
@@ -969,13 +759,38 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          # ```
 
          # Returns:
-         #     RBLNDecoderOnlyModelForCausalLMConfig: The updated RBLN model configuration.
+         #     RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.
 
          raise NotImplementedError(
              "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
              "See method docstring for required configuration details."
          )
 
+     @classmethod
+     def _update_attention_config(
+         cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+     ):
+         rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
+             attn_impl=rbln_config.attn_impl,
+             kvcache_partition_len=rbln_config.kvcache_partition_len,
+             kvcache_block_size=rbln_config.kvcache_block_size,
+             max_seq_len=rbln_config.max_seq_len,
+         )
+
+         validate_attention_method(
+             attn_impl=rbln_config.attn_impl,
+             kvcache_partition_len=rbln_config.kvcache_partition_len,
+             kvcache_block_size=rbln_config.kvcache_block_size,
+             max_seq_len=rbln_config.max_seq_len,
+         )
+
+         if rbln_config.kvcache_num_blocks is None:
+             rbln_config.kvcache_num_blocks = (
+                 rbln_config.max_seq_len // rbln_config.kvcache_block_size
+             ) * rbln_config.batch_size
+
+         return rbln_config
+
      @classmethod
      def _update_rbln_config(
          cls,
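The default chosen above simply reserves enough cache blocks for every sequence in the batch at full length. With the same illustrative numbers as before:

    max_seq_len, kvcache_block_size, batch_size = 8192, 1024, 2
    kvcache_num_blocks = (max_seq_len // kvcache_block_size) * batch_size
    print(kvcache_num_blocks)  # 16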
@@ -996,8 +811,384 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
      ):
          rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
          if rbln_config.sliding_window is not None:
-             validate_sliding_window_size(rbln_config.sliding_window, rbln_config.prefill_chunk_size)
+             validate_sliding_window(rbln_config)
+
+         rbln_config = cls._update_attention_config(model, model_config, rbln_config)
+
+         prefill_input_info = cls.get_input_info(
+             batch_size=1,
+             query_length=rbln_config.prefill_chunk_size,
+             rbln_config=rbln_config,
+             model_config=model_config,
+         )
+
+         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+         rbln_config.set_compile_cfgs([prefill_compile_config])
 
+         return rbln_config
+
+     @classmethod
+     def _create_runtimes(
+         cls,
+         compiled_models: List[rebel.RBLNCompiledModel],
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+     ) -> List[rebel.Runtime]:
+         expected_model_names = [
+             "prefill",
+         ]
+         if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
+             cls._raise_missing_compiled_file_error(expected_model_names)
+
+         return [
+             rebel.Runtime(
+                 compiled_models[0],
+                 tensor_type="pt",
+                 device=rbln_config.device_map["prefill"],
+                 activate_profiler=rbln_config.activate_profiler,
+             ),
+         ]
+
+     def _preprocess_chunked_prefill(
+         self,
+         inputs: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_embed: Optional[torch.Tensor] = None,
+     ):
+         # valid sequence length of inputs_embeds
+         query_length = inputs.shape[1] if attention_mask is None else torch.sum(attention_mask.view(-1)).item()
+
+         # extract valid inputs
+         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
+
+         if inputs.dim() == 2 and self.rbln_config.use_inputs_embeds:
+             inputs = self.get_input_embeddings()(inputs)
+
+         if position_embed is not None:
+             position_embed = (
+                 position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
+             )
+
+         # padding for chunked prefill
+         padding_size = (
+             self.rbln_config.prefill_chunk_size - (query_length % self.rbln_config.prefill_chunk_size)
+         ) % self.rbln_config.prefill_chunk_size
+         padded_len = query_length + padding_size
+
+         inputs = (
+             torch.nn.functional.pad(inputs, (0, padding_size))
+             if not self.rbln_config.use_inputs_embeds
+             else torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+         )
+         position_embed = (
+             None if position_embed is None else torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+         )
+         cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)
+
+         chunked_attention_mask = (
+             torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
+             if self.rbln_config.use_attention_mask
+             else None
+         )
+
+         return inputs, position_embed, cache_position, query_length, chunked_attention_mask
+
+     def _chunked_prefill_forward(
+         self,
+         inputs: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_embed: Optional[torch.Tensor] = None,
+     ):
+         padded_input, padded_position_embed, cache_position, query_length, chunked_attention_mask = (
+             self._preprocess_chunked_prefill(inputs, attention_mask, position_embed)
+         )
+
+         # chunked prefill
+         last_hidden_states = []
+         for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+             # Extract the current chunk of inputs and cache positions
+             input_chunk = padded_input[:, step : step + self.rbln_config.prefill_chunk_size]
+             cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
+
+             valid_length = (
+                 self.rbln_config.prefill_chunk_size
+                 if (step + self.rbln_config.prefill_chunk_size) <= query_length
+                 else query_length - step
+             )
+             if self.rbln_config.use_local_attention:
+                 query_position = torch.tensor(valid_length - 1, dtype=torch.int16)
+             else:
+                 query_position = None
+
+             if self.rbln_config.use_attention_mask:
+                 if step > 0:
+                     chunked_attention_mask[:, :, :, :step] = 1
+                 chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
+
+             # Forward pass for the current chunk
+             last_hidden_states_chunk = self.prefill_decoder(
+                 input_ids=input_chunk if not self.rbln_config.use_inputs_embeds else None,
+                 inputs_embeds=input_chunk if self.rbln_config.use_inputs_embeds else None,
+                 cache_position=cache_pos_chunk,
+                 block_tables=self.block_tables if self.rbln_config.use_global_attention else None,
+                 local_block_tables=self.local_block_tables if self.rbln_config.use_local_attention else None,
+                 query_position=query_position,
+                 attention_mask=chunked_attention_mask,
+                 position_emb=padded_position_embed,
+             )
+             last_hidden_states.append(last_hidden_states_chunk)
+         last_hidden_states = torch.concat(last_hidden_states, dim=-2)[:, :query_length]
+
+         return self._postprocess_chunked_prefill(last_hidden_states, attention_mask)
+
+     def _postprocess_chunked_prefill(
+         self, last_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+     ):
+         # index copy for attention mask
+         if attention_mask is not None:
+             new_last_hidden_states = torch.full(
+                 (1, attention_mask.shape[-1], last_hidden_states.shape[-1]),
+                 fill_value=1e-10,
+                 dtype=last_hidden_states.dtype,
+             )
+             mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+             new_last_hidden_states.index_copy_(dim=-2, index=mask_indices, source=last_hidden_states)
+         else:
+             new_last_hidden_states = last_hidden_states
+         return new_last_hidden_states
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         position_embed: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor]:
+         inputs = inputs_embeds if inputs_embeds is not None else input_ids
+         batch_size = inputs.shape[0]
+         all_last_hidden_states = []
+         for b_idx in range(batch_size):
+             last_hidden_states = self._chunked_prefill_forward(
+                 inputs[b_idx : b_idx + 1],
+                 attention_mask[b_idx] if attention_mask is not None else None,
+                 position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
+             )
+             all_last_hidden_states.append(last_hidden_states)
+
+         last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
+         return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
+
+
+ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel):
+     """
+     A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+     This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
+
+     The class provides core functionality for:
+
+     1. Converting pre-trained transformer models to RBLN-optimized format
+     2. Handling the compilation process for RBLN devices
+     3. Managing inference operations for causal language modeling
+
+     This class inherits from RBLNModel and implements specific methods required for
+     decoder-only architectures and causal language modeling tasks.
+
+     Note:
+     - This class is designed to be subclassed by specific model implementations
+       (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+     - Subclasses should implement model-specific conversion logic.
+     - The class handles RBLN-specific optimizations automatically during compilation
+     """
+
+     auto_model_class = AutoModelForCausalLM
+
+     def __post_init__(self, **kwargs):
+         main_input_name = self.main_input_name
+
+         if self.rbln_config.use_inputs_embeds:
+             main_input_name = "inputs_embeds"
+             artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+             self.embed_tokens = self._create_embedding_layer()
+             self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+         else:
+             self.embed_tokens = None
+
+         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
+         dec_attn_mask = torch.zeros(
+             self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
+         )
+         block_tables = torch.zeros(
+             self.rbln_config.batch_size,
+             self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
+             dtype=torch.int16,
+         ).fill_(-1)
+         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+
+         self.prefill_decoder = RBLNRuntimeModel(
+             runtime=self.model[0],
+             main_input_name=main_input_name,
+             embed_tokens=self.embed_tokens,
+             phase="prefill",
+             batch_size=self.rbln_config.batch_size,
+             dec_attn_mask=dec_attn_mask,
+             block_tables=block_tables,
+             free_block_pool=free_block_pool,
+             rbln_config=self.rbln_config,
+             vocab_size=self.config.vocab_size,
+         )
+
+         if self.can_generate():
+             self.decoders = {}
+             for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
+                 self.decoders[batch_size] = RBLNRuntimeModel(
+                     runtime=self.model[i + 1],
+                     main_input_name=main_input_name,
+                     embed_tokens=self.embed_tokens,
+                     phase="decode",
+                     batch_size=batch_size,
+                     dec_attn_mask=dec_attn_mask,
+                     block_tables=block_tables,
+                     free_block_pool=free_block_pool,
+                     rbln_config=self.rbln_config,
+                 )
+
+             # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
+             self.decoder = self.decoders[self.rbln_config.batch_size]
+
+     @classmethod
+     def get_quantized_model(
+         cls,
+         model_id: str,
+         config: Optional[PretrainedConfig] = None,
+         use_auth_token: Optional[Union[bool, str]] = None,
+         revision: Optional[str] = None,
+         force_download: bool = False,
+         cache_dir: Optional[str] = None,
+         subfolder: str = "",
+         local_files_only: bool = False,
+         trust_remote_code: bool = False,
+         **kwargs,
+     ):
+         kwargs = cls.update_kwargs(kwargs)
+
+         if config is None:
+             config = AutoConfig.from_pretrained(
+                 model_id,
+                 use_auth_token=use_auth_token,
+                 revision=revision,
+                 force_download=force_download,
+                 cache_dir=cache_dir,
+                 trust_remote_code=trust_remote_code,
+                 **kwargs,
+             )
+
+         with no_init_weights():
+             model = AutoModelForCausalLM.from_config(config)
+
+         model = prepare_model_for_quantization(
+             model,
+             model_id,
+             kwargs.get("num_hidden_layers"),
+             use_auth_token=use_auth_token,
+             revision=revision,
+             cache_dir=cache_dir,
+             force_download=force_download,
+             local_files_only=local_files_only,
+         )
+         return model
+
+     def __getattr__(self, __name: str) -> Any:
+         # Special method to delegate attribute access to the original Huggingface LM class.
+         # This method is called when an attribute is not found in the current instance's dictionary.
+         # It enables transparent access to the original model's attributes and methods while maintaining
+         # proper method binding.
+
+         # The method implements a delegation pattern that:
+
+         # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
+         # 2. For other attributes: Returns them directly from the original class
+
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+         return val
+
+     @classmethod
+     def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+         wrapper_cfg = {
+             "max_seq_len": rbln_config.max_seq_len,
+             "attn_impl": rbln_config.attn_impl,
+             "kvcache_partition_len": rbln_config.kvcache_partition_len,
+             "kvcache_block_size": rbln_config.kvcache_block_size,
+             "use_rotary_emb": cls._use_rotary_emb,
+             "use_attention_mask": rbln_config.use_attention_mask,
+             "use_position_ids": rbln_config.use_position_ids,
+             "use_inputs_embeds": rbln_config.use_inputs_embeds,
+             "cache_impl": rbln_config.cache_impl,
+             "sliding_window": rbln_config.sliding_window,
+             "sliding_window_layers": rbln_config.sliding_window_layers,
+         }
+         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+
+     @classmethod
+     @torch.inference_mode()
+     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+         wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+         prefill_compile_config = rbln_config.compile_cfgs[0]
+
+         # Here we use meta tensor, for the memory efficiency.
+         meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
+         prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+         context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
+
+         compiled_models = {}
+         compiled_models["prefill"] = cls._compile_model(
+             wrapped_model,
+             prefill_compile_config,
+             prefill_example_inputs,
+             context,
+             rbln_config,
+             rbln_config.quantization,
+             phase="prefill",
+         )
+
+         if rbln_config.can_generate:
+             wrapped_model.phase = "decode"
+             for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
+                 dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+                 compiled_decoder = cls._compile_model(
+                     wrapped_model,
+                     dec_compile_config,
+                     dec_example_inputs,
+                     context,
+                     rbln_config,
+                     rbln_config.quantization,
+                     phase="decode",
+                 )
+                 compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
+
+             # check if the memory is enough to have additional blocks
+             required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+             if rbln_config.kvcache_num_blocks < required_num_blocks:
+                 cls.maybe_suggest_kvcache_num_blocks(
+                     compiled_models=compiled_models,
+                     model_config=model.config,
+                     rbln_config=rbln_config,
+                 )
+
+         return compiled_models
+
+     @classmethod
+     def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+         return is_prefill
+
+     @classmethod
+     def _update_attention_config(
+         cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+     ):
          rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
              attn_impl=rbln_config.attn_impl,
              kvcache_partition_len=rbln_config.kvcache_partition_len,
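The chunked-prefill path added above pads the valid query length up to a multiple of prefill_chunk_size and then walks it chunk by chunk. A worked example of just that padding and stepping arithmetic (chunk size 128 and query length 300 are illustrative):

    prefill_chunk_size = 128
    query_length = 300  # valid tokens after dropping padding

    padding_size = (prefill_chunk_size - (query_length % prefill_chunk_size)) % prefill_chunk_size
    padded_len = query_length + padding_size                   # 384
    steps = list(range(0, query_length, prefill_chunk_size))   # [0, 128, 256]
    # Last chunk holds 300 - 256 = 44 valid positions, so query_position would be 43.
    print(padding_size, padded_len, steps)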
@@ -1022,13 +1213,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
              kvcache_block_size=rbln_config.kvcache_block_size,
              nbits_per_param=16 if not rbln_config.quantization else 4,  # TODO(jongho): FIX Ad-hoc
              n_model_params=sum(p.numel() for p in model.parameters()),
-             num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
+             num_runtimes=1 if rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
          )
 
          max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
 
          flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
-         if max_num_blocks < flash_min_blocks:
+         if rbln_config.batch_size > 1 and max_num_blocks < flash_min_blocks:
              max_num_blocks = flash_min_blocks
 
          if max_num_blocks < rbln_config.batch_size:
@@ -1047,27 +1238,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          )
          logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
 
-         prefill_input_info = cls.get_input_info(
-             batch_size=1,
-             query_length=rbln_config.prefill_chunk_size,
-             rbln_config=rbln_config,
-             model_config=model_config,
-         )
-
-         prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+         return rbln_config
 
-         dec_compile_configs = []
-         for batch_size in rbln_config.decoder_batch_sizes:
-             dec_input_info = cls.get_input_info(
-                 batch_size=batch_size,
-                 query_length=1,
-                 rbln_config=rbln_config,
-                 model_config=model_config,
-             )
-             dec_compile_configs.append(
-                 RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
-             )
-         rbln_config.set_compile_cfgs([prefill_compile_config, *dec_compile_configs])
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+         model: Optional[PreTrainedModel] = None,
+         model_config: Optional[PretrainedConfig] = None,
+         rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+     ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+         rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
+         if rbln_config.can_generate:
+             compile_configs = rbln_config.compile_cfgs
+             for batch_size in rbln_config.decoder_batch_sizes:
+                 dec_input_info = cls.get_input_info(
+                     batch_size=batch_size,
+                     query_length=1,
+                     rbln_config=rbln_config,
+                     model_config=model_config,
+                 )
+                 compile_configs.append(
+                     RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
+                 )
+             rbln_config.set_compile_cfgs(compile_configs)
 
          return rbln_config
 
@@ -1077,38 +1271,45 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          compiled_models: List[rebel.RBLNCompiledModel],
          rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
      ) -> List[rebel.Runtime]:
-         expected_model_names = [
-             "prefill",
-             *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
-         ]
+         expected_model_names = ["prefill"]
+         if rbln_config.can_generate:
+             expected_model_names.extend(
+                 [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
+             )
          if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
              cls._raise_missing_compiled_file_error(expected_model_names)
 
-         return [
+         ret_val = [
              rebel.Runtime(
                  compiled_models[0],
                  tensor_type="pt",
                  device=rbln_config.device_map["prefill"],
                  activate_profiler=rbln_config.activate_profiler,
                  timeout=rbln_config.timeout,
-             ),
-             *[
-                 rebel.Runtime(
-                     compiled_models[i + 1],
-                     tensor_type="pt",
-                     device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
-                     activate_profiler=rbln_config.activate_profiler,
-                     timeout=rbln_config.timeout,
-                 )
-                 for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
-             ],
+             )
          ]
+         if rbln_config.can_generate:
+             ret_val.extend(
+                 [
+                     rebel.Runtime(
+                         compiled_models[i + 1],
+                         tensor_type="pt",
+                         device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
+                         activate_profiler=rbln_config.activate_profiler,
+                         timeout=rbln_config.timeout,
+                     )
+                     for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
+                 ]
+             )
+         return ret_val
 
      def get_decoder(self):
+         if not self.can_generate():
+             raise ValueError("Decode stage is not supported in this model.")
          return self.decoder
 
      def can_generate(self):
-         return True
+         return self.rbln_config.can_generate
 
      def _reorder_cache(self, past_key_values, beam_idx):
          raise NotImplementedError
@@ -1165,7 +1366,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
 
      def _update_model_kwargs_for_generation(
          self,
-         outputs: RBLNDecoderOnlyOutput,
+         outputs: RBLNDecoderOnlyForCausalLMOutput,
          model_kwargs: Dict[str, Any],
          **kwargs,
      ) -> Dict[str, Any]:
@@ -1193,15 +1394,15 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          # A for-loop ensures synchronization with the HuggingFace generate API.
          # The decoder stage operates as usual, processing inputs in batch mode.
 
+         # for only use forward
+         if not self.can_generate():
+             generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+             padded_cache_lengths = torch.zeros_like(generate_idx)
+
          # Prefll
          if cache_position is None:
              logits = []
              inputs = inputs_embeds if inputs_embeds is not None else input_ids
-             # for only use forward
-             if generate_idx is None:
-                 generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
-             if padded_cache_lengths is None:
-                 padded_cache_lengths = torch.zeros_like(generate_idx)
              batch_size = inputs.shape[0]
              for b_idx in range(batch_size):
                  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
@@ -1236,6 +1437,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          if not return_dict:
              return logits, generate_idx, padded_cache_lengths
          else:
-             return RBLNDecoderOnlyOutput(
+             return RBLNDecoderOnlyForCausalLMOutput(
                  logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
              )
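End-to-end, the causal-LM subclasses still drive decoding through HuggingFace's generate(). A minimal usage sketch; the model id, batch size, and parallelism settings are illustrative, and the rbln_* keyword names follow the existing optimum-rbln convention rather than anything changed in this release:

    from transformers import AutoTokenizer
    from optimum.rbln import RBLNLlamaForCausalLM

    model_id = "meta-llama/Llama-2-7b-chat-hf"  # illustrative checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = RBLNLlamaForCausalLM.from_pretrained(
        model_id,
        export=True,                 # compile for RBLN NPUs on first load
        rbln_batch_size=1,
        rbln_max_seq_len=4096,
        rbln_tensor_parallel_size=4,
    )

    inputs = tokenizer("Hello, my name is", return_tensors="pt")
    output = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(output[0], skip_special_tokens=True))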