optimum-rbln 0.8.2a4__py3-none-any.whl → 0.8.2a6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (64)
  1. optimum/rbln/__init__.py +44 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +4 -0
  4. optimum/rbln/ops/kv_cache_update.py +5 -0
  5. optimum/rbln/ops/linear.py +7 -0
  6. optimum/rbln/transformers/__init__.py +48 -0
  7. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  8. optimum/rbln/transformers/models/__init__.py +35 -14
  9. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  10. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
  11. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +122 -205
  12. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +569 -366
  13. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  14. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  15. optimum/rbln/transformers/models/gemma/modeling_gemma.py +13 -1
  16. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +7 -5
  17. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +82 -59
  18. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  19. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  20. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -7
  21. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +16 -1
  22. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +2 -2
  23. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  24. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  25. optimum/rbln/transformers/models/llama/modeling_llama.py +13 -1
  26. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  27. optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
  28. optimum/rbln/transformers/models/llava/modeling_llava.py +379 -0
  29. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
  30. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  31. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  32. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  33. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  34. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  35. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  36. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  37. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  38. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  39. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  40. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  41. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +163 -0
  42. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  43. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  44. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  45. optimum/rbln/transformers/models/phi/phi_architecture.py +6 -6
  46. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  47. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  48. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +318 -0
  49. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  50. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  51. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  52. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  53. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +3 -3
  54. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  55. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +10 -328
  56. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +0 -241
  57. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +0 -10
  58. optimum/rbln/transformers/models/whisper/configuration_whisper.py +1 -10
  59. optimum/rbln/transformers/models/whisper/modeling_whisper.py +5 -1
  60. optimum/rbln/utils/depreacate_utils.py +16 -0
  61. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/METADATA +1 -1
  62. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/RECORD +64 -51
  63. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/WHEEL +0 -0
  64. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.8.2a6.dist-info}/licenses/LICENSE +0 -0
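
The hunks below appear to come from optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py (entry 12 above): the former RBLNDecoderOnlyModelForCausalLM is split into a new RBLNDecoderOnlyModel base class that runs chunked prefill only and returns raw hidden states, with the causal-LM class layered on top of it. End-user usage of the causal-LM classes stays the same; as a minimal sketch, assuming the usual optimum-style from_pretrained(..., export=True) entry point and rbln_* compile-time keyword arguments (the checkpoint id is only an example, not anything this release ships):

# Hedged usage sketch; exact kwargs may differ between releases.
from optimum.rbln import RBLNLlamaForCausalLM
from transformers import AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # hypothetical example checkpoint

# Compile for RBLN NPUs; prefill and decode graphs are built at export time.
model = RBLNLlamaForCausalLM.from_pretrained(
    model_id,
    export=True,
    rbln_batch_size=1,
    rbln_max_seq_len=4096,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Hello, RBLN!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
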
@@ -13,7 +13,6 @@
  # limitations under the License.

  import inspect
- import math
  from collections import deque
  from dataclasses import dataclass
  from pathlib import Path
@@ -22,7 +21,8 @@ from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tu
  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithPast
  from transformers.modeling_utils import no_init_weights
  from transformers.utils import ModelOutput

@@ -30,14 +30,15 @@ from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
  from ....utils.logging import get_logger
  from ....utils.runtime_utils import RBLNPytorchRuntime
- from ...utils.rbln_quantization import prepare_model_for_quantization
- from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
- from .decoderonly_architecture import (
- DecoderOnlyWrapper,
+ from ...modeling_attention_utils import (
+ RBLNDecoderOnlyFlashAttentionMixin,
  set_default_values,
  validate_attention_method,
- validate_sliding_window_size,
+ validate_sliding_window,
  )
+ from ...utils.rbln_quantization import prepare_model_for_quantization
+ from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+ from .decoderonly_architecture import DecoderOnlyWrapper


  logger = get_logger()
@@ -267,7 +268,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):

  attention_mask = self.dec_attn_mask

- if self.rbln_config.cache_impl in ["hybrid", "static"] and self.batch_size < block_tables.shape[0]:
+ if self.rbln_config.use_global_attention and self.batch_size < block_tables.shape[0]:
  block_tables = block_tables[: self.batch_size]

  if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
@@ -283,7 +284,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  position_ids if self.rbln_config.use_position_ids else None,
  )

- return RBLNDecoderOnlyOutput(logits=logits)
+ return RBLNDecoderOnlyForCausalLMOutput(logits=logits)

  def _prepare_prefill_inputs(
  self,
@@ -449,94 +450,64 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  self.dec_attn_mask[batch_idx].fill_(0)
  self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

- return RBLNDecoderOnlyOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)
+ return RBLNDecoderOnlyForCausalLMOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)


  @dataclass
- class RBLNDecoderOnlyOutput(ModelOutput):
+ class RBLNDecoderOnlyForCausalLMOutput(ModelOutput):
  logits: torch.FloatTensor = None
  generate_idx: torch.Tensor = None
  padded_cache_lengths: int = None


- class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
+ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  """
- A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+ A base class for decoder-only transformer models outputting raw hidden-states without any specific head on top.
+ This class is used for RBLN-optimized models that are not causal language models.
  This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.

  The class provides core functionality for:

  1. Converting pre-trained transformer models to RBLN-optimized format
  2. Handling the compilation process for RBLN devices
- 3. Managing inference operations for causal language modeling
+ 3. Managing inference operations for decoder-only architectures

  This class inherits from RBLNModel and implements specific methods required for
- decoder-only architectures and causal language modeling tasks.
+ decoder-only architectures.

  Note:
  - This class is designed to be subclassed by specific model implementations
- (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+ (e.g., RBLNLlamaModel, RBLNQwen2Model)
  - Subclasses should implement model-specific conversion logic.
  - The class handles RBLN-specific optimizations automatically during compilation
  """

  main_input_name = "input_ids"
- auto_model_class = AutoModelForCausalLM
+ auto_model_class = AutoModel
  _decoder_wrapper_cls = DecoderOnlyWrapper
  _use_rotary_emb = True

  def __post_init__(self, **kwargs):
- main_input_name = self.main_input_name
-
  if self.rbln_config.use_inputs_embeds:
- main_input_name = "inputs_embeds"
  artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
  self.embed_tokens = self._create_embedding_layer()
  self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
  else:
  self.embed_tokens = None

- # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
- dec_attn_mask = torch.zeros(
- self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
- )
- block_tables = torch.zeros(
- self.rbln_config.batch_size,
- self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
- dtype=torch.int16,
- ).fill_(-1)
- free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
-
- self.prefill_decoder = RBLNRuntimeModel(
- runtime=self.model[0],
- main_input_name=main_input_name,
- embed_tokens=self.embed_tokens,
- phase="prefill",
- batch_size=self.rbln_config.batch_size,
- dec_attn_mask=dec_attn_mask,
- block_tables=block_tables,
- free_block_pool=free_block_pool,
- rbln_config=self.rbln_config,
- vocab_size=self.config.vocab_size,
- )
+ # TODO: add prefill runtime class.
+ self.prefill_decoder = RBLNPytorchRuntime(runtime=self.model[0])

- self.decoders = {}
- for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
- self.decoders[batch_size] = RBLNRuntimeModel(
- runtime=self.model[i + 1],
- main_input_name=main_input_name,
- embed_tokens=self.embed_tokens,
- phase="decode",
- batch_size=batch_size,
- dec_attn_mask=dec_attn_mask,
- block_tables=block_tables,
- free_block_pool=free_block_pool,
- rbln_config=self.rbln_config,
+ # attributes for prefill
+ if self.rbln_config.use_global_attention:
+ self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+ if self.rbln_config.use_local_attention:
+ self.local_block_tables = torch.tensor([0], dtype=torch.int16)
+ if self.rbln_config.use_attention_mask:
+ self.causal_mask = 1 - torch.triu(
+ torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
  )

- # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
- self.decoder = self.decoders[self.rbln_config.batch_size]
-
  @classmethod
  def save_torch_artifacts(
  cls,
@@ -571,79 +542,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  return self.rbln_config.kvcache_num_blocks

  @classmethod
- def get_quantized_model(
- cls,
- model_id: str,
- config: Optional[PretrainedConfig] = None,
- use_auth_token: Optional[Union[bool, str]] = None,
- revision: Optional[str] = None,
- force_download: bool = False,
- cache_dir: Optional[str] = None,
- subfolder: str = "",
- local_files_only: bool = False,
- trust_remote_code: bool = False,
- **kwargs,
- ):
- kwargs = cls.update_kwargs(kwargs)
-
- if config is None:
- config = AutoConfig.from_pretrained(
- model_id,
- use_auth_token=use_auth_token,
- revision=revision,
- force_download=force_download,
- cache_dir=cache_dir,
- trust_remote_code=trust_remote_code,
- **kwargs,
- )
-
- with no_init_weights():
- model = AutoModelForCausalLM.from_config(config)
-
- model = prepare_model_for_quantization(
- model,
- model_id,
- kwargs.get("num_hidden_layers"),
- use_auth_token=use_auth_token,
- revision=revision,
- cache_dir=cache_dir,
- force_download=force_download,
- local_files_only=local_files_only,
- )
- return model
-
- def __getattr__(self, __name: str) -> Any:
- # Special method to delegate attribute access to the original Huggingface LM class.
- # This method is called when an attribute is not found in the current instance's dictionary.
- # It enables transparent access to the original model's attributes and methods while maintaining
- # proper method binding.
-
- # The method implements a delegation pattern that:
-
- # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
- # 2. For other attributes: Returns them directly from the original class
-
- def redirect(func):
- return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
- val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
- if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
- return redirect(val)
- return val
-
- @classmethod
- def get_pytorch_model(
- cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
- ) -> PreTrainedModel:
- if rbln_config and rbln_config.quantization:
- model = cls.get_quantized_model(*args, **kwargs)
- else:
- model = super().get_pytorch_model(*args, **kwargs)
-
- return model
-
- @classmethod
- def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+ def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
  wrapper_cfg = {
  "max_seq_len": rbln_config.max_seq_len,
  "attn_impl": rbln_config.attn_impl,
@@ -660,205 +559,95 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

  @classmethod
- @torch.inference_mode()
- def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
- wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
-
- rbln_compile_configs = rbln_config.compile_cfgs
- prefill_compile_config = rbln_compile_configs[0]
+ def _compile_model(
+ cls,
+ wrapped_model,
+ compile_config,
+ example_inputs,
+ compile_context,
+ rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+ quantization=None,
+ phase: str = "prefill",
+ ):
+ try:
+ wrapped_model.phase = phase
+ if quantization:
+ quantization.maybe_set_quantization_env()
+ original_linear = torch.nn.functional.linear
+ torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
+ compiled_model = cls.compile(
+ wrapped_model,
+ compile_config,
+ create_runtimes=rbln_config.create_runtimes,
+ device=rbln_config.device,
+ example_inputs=example_inputs,
+ compile_context=compile_context,
+ )
+ return compiled_model
+ finally:
+ torch.nn.functional.linear = original_linear
+ if quantization:
+ quantization.maybe_reset_quantization_env()

+ @classmethod
+ def _get_compile_context(
+ cls,
+ compile_config: RBLNCompileConfig,
+ example_inputs: List[torch.Tensor],
+ ):
  context = CompileContext(use_weight_sharing=True)

- # Here we use meta tensor, for the memory efficiency.
- meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
- prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
-
  # Mark static tensors (self kv states)
  static_tensors = {}
- for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
+ for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
  if "past_key_values" in name:
  static_tensors[name] = tensor
  context.mark_static_address(tensor)

- def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
- try:
- if quantization:
- quantization.maybe_set_quantization_env()
- original_linear = torch.nn.functional.linear
- torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
- compiled_model = cls.compile(
- wrapped_model,
- compile_config,
- create_runtimes=rbln_config.create_runtimes,
- device=rbln_config.device,
- example_inputs=example_inputs,
- compile_context=compile_context,
- )
- return compiled_model
- finally:
- torch.nn.functional.linear = original_linear
- if quantization:
- quantization.maybe_reset_quantization_env()
-
- wrapped_model.phase = "prefill"
- compiled_prefill = compile_model(
- wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
- )
-
- wrapped_model.phase = "decode"
- compiled_models = {"prefill": compiled_prefill}
- for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
- dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
- compiled_decoder = compile_model(
- wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
- )
- compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
-
- # check if the memory is enough to have additional blocks
- required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
- if rbln_config.kvcache_num_blocks < required_num_blocks:
- cls.maybe_suggest_kvcache_num_blocks(
- compiled_models=compiled_models,
- model_config=model.config,
- rbln_config=rbln_config,
- )
-
- return compiled_models
+ return context, static_tensors

  @classmethod
- def maybe_suggest_kvcache_num_blocks(
+ @torch.inference_mode()
+ def get_compiled_model(
  cls,
- compiled_models: Dict[str, rebel.RBLNCompiledModel],
- model_config: PretrainedConfig,
- rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
- ) -> None:
- # Get the actual memory allocation of each node by key
- alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
- alloc_memory_by_key: Dict[str, int] = {
- key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
- }
- for batch_size in rbln_config.decoder_batch_sizes:
- for key, memory_per_node in (
- compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
- ):
- alloc_memory_by_key[key] += sum(memory_per_node)
- alloc_memory_by_key.pop("PortRecur", None) # Old compiler's kv-cache Key
- alloc_memory_by_key.pop("DramTensor", None) # kv-cache
- kernel_size = alloc_memory_by_key.pop("Kernel") # model weight
-
- # Get the maximum number of blocks that can be allocated
- buffer = sum(alloc_memory_by_key.values())
- max_num_blocks = cls.get_maximum_num_blocks(
- config=model_config,
- tensor_parallel_size=rbln_config.tensor_parallel_size,
- kvcache_block_size=rbln_config.kvcache_block_size,
- kernel_size=kernel_size,
- buffer=buffer,
+ model: PreTrainedModel,
+ rbln_config: RBLNDecoderOnlyModelConfig,
+ ):
+ wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+ compile_config = rbln_config.compile_cfgs[0]
+
+ # Here we use meta tensor, for the memory efficiency.
+ meta_tensor_names = [name for name, _, _ in compile_config.input_info if "past_key_values" in name]
+ example_inputs = compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+ context, _ = cls._get_compile_context(compile_config, example_inputs)
+
+ compiled_model = cls._compile_model(
+ wrapped_model, compile_config, example_inputs, context, rbln_config, rbln_config.quantization, "prefill"
  )
+ compiled_models = {"prefill": compiled_model}

- # Since our estimation logic is not always accurate,
- # users can set `kvcache_num_blocks` to `max_num_blocks`.
- # If the memory is not enough, the model will fail to compile.
- if rbln_config.kvcache_num_blocks < max_num_blocks:
- logger.warning(
- f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
- "Our analysis indicates that additional memory is available for more blocks. "
- f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
- "Please be advised that our memory estimation algorithm has limitations, "
- "and increasing this value may not guarantee successful model compilation."
- )
+ return compiled_models

  @classmethod
- def get_maximum_num_blocks(
- cls,
- config: PretrainedConfig,
- tensor_parallel_size: int,
- kvcache_block_size: int,
- nbits_per_param: Optional[int] = None,
- n_model_params: Optional[int] = None,
- kernel_size: Optional[int] = None,
- buffer: Optional[int] = None,
- num_runtimes: int = 2,
- ) -> int:
- # We are finding max_n_blocks(x) that satisfies the following equation:
-
- # available_dram - kernel_size - buffer
- # - num_layers * 2 * tensor_parallel_size
- # * align_2MB(
- # x
- # * block_size
- # * align_64(head_dim)
- # * math.ceil(num_key_value_heads / tensor_parallel_size)
- # * 2
- # ) > 0
-
- # This inequality can be rewritten as follows:
-
- # a - c * align_2MB(b * x) > 0
- # where
- # a = available_dram - kernel_size - buffer
- # b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
- # c = num_layers * 2 * tensor_parallel_size
-
- # We can rewrite the inequality as follows:
- # k > align_2MB(b*x)
- # where
- # k = a / c
-
- # After that, we can derive the following equation:
- # x = floor(2**21 / b * floor((k - 1) / 2**21))
-
- def align(x: int, nbytes: int) -> int:
- return int(math.ceil(x / nbytes) * nbytes)
-
- def align_2MB(x: int) -> int:
- return align(x, 2**21)
-
- num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
- num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
- head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
- vocab_size = config.vocab_size
- hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
- num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
-
- # TODO(jongho): Update if target npu is REBEL.
- ATOM_DRAM_NBYTES = 16 * 2**30
- ATOM_SYS_DRAM_NBYTES = 288 * 2**20
- available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
-
- if kernel_size is None:
- if n_model_params is None:
- raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
- # Get estimated kernel size (approximated)
- lm_heads_params = align(vocab_size, 64) * hidden_size
- lm_heads_nbytes = (
- align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
- )
- params = n_model_params - lm_heads_params
- layer_nbytes = (
- align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
- * num_layers
- * tensor_parallel_size
- )
- kernel_size = layer_nbytes + lm_heads_nbytes
- elif n_model_params is not None:
- raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
-
- available_dram -= kernel_size
+ def get_quantized_model(
+ cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
+ ) -> PreTrainedModel:
+ raise NotImplementedError

- if buffer is None:
- # TODO: Accurate buffer estimation
- buffer_per_runtime_per_core = 2**28 # 256MB per runtime
- buffer_per_core = buffer_per_runtime_per_core * num_runtimes # 1 for prefill, 1 for decoder
- buffer = buffer_per_core * tensor_parallel_size
- available_dram -= buffer
+ @classmethod
+ def get_pytorch_model(
+ cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
+ ) -> PreTrainedModel:
+ if rbln_config and rbln_config.quantization:
+ model = cls.get_quantized_model(*args, **kwargs)
+ else:
+ model = super().get_pytorch_model(*args, **kwargs)

- b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
- c = num_layers * 2 * tensor_parallel_size
- k = available_dram / c
- max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
+ return model

- return max_n_blocks
+ @classmethod
+ def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+ return use_local_attention

  @classmethod
  def get_input_info(
@@ -868,13 +657,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
  model_config: PretrainedConfig,
  ):
- is_prefill: bool = query_length > 1
  num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
  num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
  num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
  hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
  head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
- local_kvcache_num_blocks = max(rbln_config.decoder_batch_sizes)
+ is_prefill = query_length > 1

  # 1. main input
  if rbln_config.use_inputs_embeds:
@@ -893,16 +681,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  ]

  # 3. block_tables
- if rbln_config.cache_impl in ["static", "hybrid"]:
+ if rbln_config.use_global_attention:
  max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
  input_info.extend(
  [("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")]
  )
- if rbln_config.cache_impl in ["hybrid", "sliding_window"]:
+ if rbln_config.use_local_attention:
  input_info.extend([("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16")])

- # 4. query_position
- if is_prefill:
+ # 4. query_position for sliding window attention
+ if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
  input_info.extend([("query_position", [], "int16")])

  # 5. attention_mask & position_ids
@@ -924,7 +712,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  rbln_config.kvcache_block_size,
  head_dim,
  ]
- local_kvcache_shape = [local_kvcache_num_blocks, num_key_value_heads, rbln_config.sliding_window, head_dim]
+ local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
  input_info.extend(
  [
  (
@@ -971,13 +759,38 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  # ```

  # Returns:
- # RBLNDecoderOnlyModelForCausalLMConfig: The updated RBLN model configuration.
+ # RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.

  raise NotImplementedError(
  "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
  "See method docstring for required configuration details."
  )

+ @classmethod
+ def _update_attention_config(
+ cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+ ):
+ rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
+ attn_impl=rbln_config.attn_impl,
+ kvcache_partition_len=rbln_config.kvcache_partition_len,
+ kvcache_block_size=rbln_config.kvcache_block_size,
+ max_seq_len=rbln_config.max_seq_len,
+ )
+
+ validate_attention_method(
+ attn_impl=rbln_config.attn_impl,
+ kvcache_partition_len=rbln_config.kvcache_partition_len,
+ kvcache_block_size=rbln_config.kvcache_block_size,
+ max_seq_len=rbln_config.max_seq_len,
+ )
+
+ if rbln_config.kvcache_num_blocks is None:
+ rbln_config.kvcache_num_blocks = (
+ rbln_config.max_seq_len // rbln_config.kvcache_block_size
+ ) * rbln_config.batch_size
+
+ return rbln_config
+
  @classmethod
  def _update_rbln_config(
  cls,
@@ -998,8 +811,384 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  ):
  rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
  if rbln_config.sliding_window is not None:
- validate_sliding_window_size(rbln_config.sliding_window, rbln_config.prefill_chunk_size)
+ validate_sliding_window(rbln_config)
+
+ rbln_config = cls._update_attention_config(model, model_config, rbln_config)
+
+ prefill_input_info = cls.get_input_info(
+ batch_size=1,
+ query_length=rbln_config.prefill_chunk_size,
+ rbln_config=rbln_config,
+ model_config=model_config,
+ )
+
+ prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+ rbln_config.set_compile_cfgs([prefill_compile_config])

+ return rbln_config
+
+ @classmethod
+ def _create_runtimes(
+ cls,
+ compiled_models: List[rebel.RBLNCompiledModel],
+ rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+ ) -> List[rebel.Runtime]:
+ expected_model_names = [
+ "prefill",
+ ]
+ if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
+ cls._raise_missing_compiled_file_error(expected_model_names)
+
+ return [
+ rebel.Runtime(
+ compiled_models[0],
+ tensor_type="pt",
+ device=rbln_config.device_map["prefill"],
+ activate_profiler=rbln_config.activate_profiler,
+ ),
+ ]
+
+ def _preprocess_chunked_prefill(
+ self,
+ inputs: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_embed: Optional[torch.Tensor] = None,
+ ):
+ # valid sequence length of inputs_embeds
+ query_length = inputs.shape[1] if attention_mask is None else torch.sum(attention_mask.view(-1)).item()
+
+ # extract valid inputs
+ inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
+
+ if inputs.dim() == 2 and self.rbln_config.use_inputs_embeds:
+ inputs = self.get_input_embeddings()(inputs)
+
+ if position_embed is not None:
+ position_embed = (
+ position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
+ )
+
+ # padding for chunked prefill
+ padding_size = (
+ self.rbln_config.prefill_chunk_size - (query_length % self.rbln_config.prefill_chunk_size)
+ ) % self.rbln_config.prefill_chunk_size
+ padded_len = query_length + padding_size
+
+ inputs = (
+ torch.nn.functional.pad(inputs, (0, padding_size))
+ if not self.rbln_config.use_inputs_embeds
+ else torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+ )
+ position_embed = (
+ None if position_embed is None else torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+ )
+ cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)
+
+ chunked_attention_mask = (
+ torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
+ if self.rbln_config.use_attention_mask
+ else None
+ )
+
+ return inputs, position_embed, cache_position, query_length, chunked_attention_mask
+
+ def _chunked_prefill_forward(
+ self,
+ inputs: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_embed: Optional[torch.Tensor] = None,
+ ):
+ padded_input, padded_position_embed, cache_position, query_length, chunked_attention_mask = (
+ self._preprocess_chunked_prefill(inputs, attention_mask, position_embed)
+ )
+
+ # chunked prefill
+ last_hidden_states = []
+ for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+ # Extract the current chunk of inputs and cache positions
+ input_chunk = padded_input[:, step : step + self.rbln_config.prefill_chunk_size]
+ cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
+
+ valid_length = (
+ self.rbln_config.prefill_chunk_size
+ if (step + self.rbln_config.prefill_chunk_size) <= query_length
+ else query_length - step
+ )
+ if self.rbln_config.use_local_attention:
+ query_position = torch.tensor(valid_length - 1, dtype=torch.int16)
+ else:
+ query_position = None
+
+ if self.rbln_config.use_attention_mask:
+ if step > 0:
+ chunked_attention_mask[:, :, :, :step] = 1
+ chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
+
+ # Forward pass for the current chunk
+ last_hidden_states_chunk = self.prefill_decoder(
+ input_ids=input_chunk if not self.rbln_config.use_inputs_embeds else None,
+ inputs_embeds=input_chunk if self.rbln_config.use_inputs_embeds else None,
+ cache_position=cache_pos_chunk,
+ block_tables=self.block_tables if self.rbln_config.use_global_attention else None,
+ local_block_tables=self.local_block_tables if self.rbln_config.use_local_attention else None,
+ query_position=query_position,
+ attention_mask=chunked_attention_mask,
+ position_emb=padded_position_embed,
+ )
+ last_hidden_states.append(last_hidden_states_chunk)
+ last_hidden_states = torch.concat(last_hidden_states, dim=-2)[:, :query_length]
+
+ return self._postprocess_chunked_prefill(last_hidden_states, attention_mask)
+
+ def _postprocess_chunked_prefill(
+ self, last_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+ ):
+ # index copy for attention mask
+ if attention_mask is not None:
+ new_last_hidden_states = torch.full(
+ (1, attention_mask.shape[-1], last_hidden_states.shape[-1]),
+ fill_value=1e-10,
+ dtype=last_hidden_states.dtype,
+ )
+ mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+ new_last_hidden_states.index_copy_(dim=-2, index=mask_indices, source=last_hidden_states)
+ else:
+ new_last_hidden_states = last_hidden_states
+ return new_last_hidden_states
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_embed: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor]:
+ inputs = inputs_embeds if inputs_embeds is not None else input_ids
+ batch_size = inputs.shape[0]
+ all_last_hidden_states = []
+ for b_idx in range(batch_size):
+ last_hidden_states = self._chunked_prefill_forward(
+ inputs[b_idx : b_idx + 1],
+ attention_mask[b_idx] if attention_mask is not None else None,
+ position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
+ )
+ all_last_hidden_states.append(last_hidden_states)
+
+ last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
+ return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
+
+
+ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel):
+ """
+ A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+ This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
+
+ The class provides core functionality for:
+
+ 1. Converting pre-trained transformer models to RBLN-optimized format
+ 2. Handling the compilation process for RBLN devices
+ 3. Managing inference operations for causal language modeling
+
+ This class inherits from RBLNModel and implements specific methods required for
+ decoder-only architectures and causal language modeling tasks.
+
+ Note:
+ - This class is designed to be subclassed by specific model implementations
+ (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+ - Subclasses should implement model-specific conversion logic.
+ - The class handles RBLN-specific optimizations automatically during compilation
+ """
+
+ auto_model_class = AutoModelForCausalLM
+
+ def __post_init__(self, **kwargs):
+ main_input_name = self.main_input_name
+
+ if self.rbln_config.use_inputs_embeds:
+ main_input_name = "inputs_embeds"
+ artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+ self.embed_tokens = self._create_embedding_layer()
+ self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+ else:
+ self.embed_tokens = None
+
+ # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
+ dec_attn_mask = torch.zeros(
+ self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
+ )
+ block_tables = torch.zeros(
+ self.rbln_config.batch_size,
+ self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
+ dtype=torch.int16,
+ ).fill_(-1)
+ free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+
+ self.prefill_decoder = RBLNRuntimeModel(
+ runtime=self.model[0],
+ main_input_name=main_input_name,
+ embed_tokens=self.embed_tokens,
+ phase="prefill",
+ batch_size=self.rbln_config.batch_size,
+ dec_attn_mask=dec_attn_mask,
+ block_tables=block_tables,
+ free_block_pool=free_block_pool,
+ rbln_config=self.rbln_config,
+ vocab_size=self.config.vocab_size,
+ )
+
+ if self.can_generate():
+ self.decoders = {}
+ for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
+ self.decoders[batch_size] = RBLNRuntimeModel(
+ runtime=self.model[i + 1],
+ main_input_name=main_input_name,
+ embed_tokens=self.embed_tokens,
+ phase="decode",
+ batch_size=batch_size,
+ dec_attn_mask=dec_attn_mask,
+ block_tables=block_tables,
+ free_block_pool=free_block_pool,
+ rbln_config=self.rbln_config,
+ )
+
+ # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
+ self.decoder = self.decoders[self.rbln_config.batch_size]
+
+ @classmethod
+ def get_quantized_model(
+ cls,
+ model_id: str,
+ config: Optional[PretrainedConfig] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ force_download: bool = False,
+ cache_dir: Optional[str] = None,
+ subfolder: str = "",
+ local_files_only: bool = False,
+ trust_remote_code: bool = False,
+ **kwargs,
+ ):
+ kwargs = cls.update_kwargs(kwargs)
+
+ if config is None:
+ config = AutoConfig.from_pretrained(
+ model_id,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ force_download=force_download,
+ cache_dir=cache_dir,
+ trust_remote_code=trust_remote_code,
+ **kwargs,
+ )
+
+ with no_init_weights():
+ model = AutoModelForCausalLM.from_config(config)
+
+ model = prepare_model_for_quantization(
+ model,
+ model_id,
+ kwargs.get("num_hidden_layers"),
+ use_auth_token=use_auth_token,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ )
+ return model
+
+ def __getattr__(self, __name: str) -> Any:
+ # Special method to delegate attribute access to the original Huggingface LM class.
+ # This method is called when an attribute is not found in the current instance's dictionary.
+ # It enables transparent access to the original model's attributes and methods while maintaining
+ # proper method binding.
+
+ # The method implements a delegation pattern that:
+
+ # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
+ # 2. For other attributes: Returns them directly from the original class
+
+ def redirect(func):
+ return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+ val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
+ if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+ return redirect(val)
+ return val
+
+ @classmethod
+ def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+ wrapper_cfg = {
+ "max_seq_len": rbln_config.max_seq_len,
+ "attn_impl": rbln_config.attn_impl,
+ "kvcache_partition_len": rbln_config.kvcache_partition_len,
+ "kvcache_block_size": rbln_config.kvcache_block_size,
+ "use_rotary_emb": cls._use_rotary_emb,
+ "use_attention_mask": rbln_config.use_attention_mask,
+ "use_position_ids": rbln_config.use_position_ids,
+ "use_inputs_embeds": rbln_config.use_inputs_embeds,
+ "cache_impl": rbln_config.cache_impl,
+ "sliding_window": rbln_config.sliding_window,
+ "sliding_window_layers": rbln_config.sliding_window_layers,
+ }
+ return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+
+ @classmethod
+ @torch.inference_mode()
+ def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+ wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+ prefill_compile_config = rbln_config.compile_cfgs[0]
+
+ # Here we use meta tensor, for the memory efficiency.
+ meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
+ prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+ context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
+
+ compiled_models = {}
+ compiled_models["prefill"] = cls._compile_model(
+ wrapped_model,
+ prefill_compile_config,
+ prefill_example_inputs,
+ context,
+ rbln_config,
+ rbln_config.quantization,
+ phase="prefill",
+ )
+
+ if rbln_config.can_generate:
+ wrapped_model.phase = "decode"
+ for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
+ dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+ compiled_decoder = cls._compile_model(
+ wrapped_model,
+ dec_compile_config,
+ dec_example_inputs,
+ context,
+ rbln_config,
+ rbln_config.quantization,
+ phase="decode",
+ )
+ compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
+
+ # check if the memory is enough to have additional blocks
+ required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+ if rbln_config.kvcache_num_blocks < required_num_blocks:
+ cls.maybe_suggest_kvcache_num_blocks(
+ compiled_models=compiled_models,
+ model_config=model.config,
+ rbln_config=rbln_config,
+ )
+
+ return compiled_models
+
+ @classmethod
+ def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+ return is_prefill
+
+ @classmethod
+ def _update_attention_config(
+ cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+ ):
  rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
  attn_impl=rbln_config.attn_impl,
  kvcache_partition_len=rbln_config.kvcache_partition_len,
@@ -1024,13 +1213,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  kvcache_block_size=rbln_config.kvcache_block_size,
  nbits_per_param=16 if not rbln_config.quantization else 4, # TODO(jongho): FIX Ad-hoc
  n_model_params=sum(p.numel() for p in model.parameters()),
- num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
+ num_runtimes=1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
  )

  max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)

  flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
- if max_num_blocks < flash_min_blocks:
+ if rbln_config.batch_size > 1 and max_num_blocks < flash_min_blocks:
  max_num_blocks = flash_min_blocks

  if max_num_blocks < rbln_config.batch_size:
@@ -1049,27 +1238,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  )
  logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")

- prefill_input_info = cls.get_input_info(
- batch_size=1,
- query_length=rbln_config.prefill_chunk_size,
- rbln_config=rbln_config,
- model_config=model_config,
- )
-
- prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+ return rbln_config

- dec_compile_configs = []
- for batch_size in rbln_config.decoder_batch_sizes:
- dec_input_info = cls.get_input_info(
- batch_size=batch_size,
- query_length=1,
- rbln_config=rbln_config,
- model_config=model_config,
- )
- dec_compile_configs.append(
- RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
- )
- rbln_config.set_compile_cfgs([prefill_compile_config, *dec_compile_configs])
+ @classmethod
+ def _update_rbln_config(
+ cls,
+ preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+ model: Optional[PreTrainedModel] = None,
+ model_config: Optional[PretrainedConfig] = None,
+ rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+ ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+ rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
+ if rbln_config.can_generate:
+ compile_configs = rbln_config.compile_cfgs
+ for batch_size in rbln_config.decoder_batch_sizes:
+ dec_input_info = cls.get_input_info(
+ batch_size=batch_size,
+ query_length=1,
+ rbln_config=rbln_config,
+ model_config=model_config,
+ )
+ compile_configs.append(
+ RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
+ )
+ rbln_config.set_compile_cfgs(compile_configs)

  return rbln_config

@@ -1079,38 +1271,45 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  compiled_models: List[rebel.RBLNCompiledModel],
  rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
  ) -> List[rebel.Runtime]:
- expected_model_names = [
- "prefill",
- *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
- ]
+ expected_model_names = ["prefill"]
+ if rbln_config.can_generate:
+ expected_model_names.extend(
+ [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
+ )
  if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
  cls._raise_missing_compiled_file_error(expected_model_names)

- return [
+ ret_val = [
  rebel.Runtime(
  compiled_models[0],
  tensor_type="pt",
  device=rbln_config.device_map["prefill"],
  activate_profiler=rbln_config.activate_profiler,
  timeout=rbln_config.timeout,
- ),
- *[
- rebel.Runtime(
- compiled_models[i + 1],
- tensor_type="pt",
- device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
- activate_profiler=rbln_config.activate_profiler,
- timeout=rbln_config.timeout,
- )
- for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
- ],
+ )
  ]
+ if rbln_config.can_generate:
+ ret_val.extend(
+ [
+ rebel.Runtime(
+ compiled_models[i + 1],
+ tensor_type="pt",
+ device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
+ activate_profiler=rbln_config.activate_profiler,
+ timeout=rbln_config.timeout,
+ )
+ for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
+ ]
+ )
+ return ret_val

  def get_decoder(self):
+ if not self.can_generate():
+ raise ValueError("Decode stage is not supported in this model.")
  return self.decoder

  def can_generate(self):
- return True
+ return self.rbln_config.can_generate

  def _reorder_cache(self, past_key_values, beam_idx):
  raise NotImplementedError
@@ -1167,7 +1366,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

  def _update_model_kwargs_for_generation(
  self,
- outputs: RBLNDecoderOnlyOutput,
+ outputs: RBLNDecoderOnlyForCausalLMOutput,
  model_kwargs: Dict[str, Any],
  **kwargs,
  ) -> Dict[str, Any]:
@@ -1195,15 +1394,19 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  # A for-loop ensures synchronization with the HuggingFace generate API.
  # The decoder stage operates as usual, processing inputs in batch mode.

+ # for only use forward
+ if generate_idx is None:
+ generate_idx = (
+ attention_mask.sum(dim=-1, keepdim=True).int()
+ if attention_mask is not None
+ else torch.full((input_ids.shape[0], 1), input_ids.shape[1], dtype=torch.int32)
+ )
+ padded_cache_lengths = torch.zeros_like(generate_idx)
+
  # Prefll
  if cache_position is None:
  logits = []
  inputs = inputs_embeds if inputs_embeds is not None else input_ids
- # for only use forward
- if generate_idx is None:
- generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
- if padded_cache_lengths is None:
- padded_cache_lengths = torch.zeros_like(generate_idx)
  batch_size = inputs.shape[0]
  for b_idx in range(batch_size):
  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
@@ -1238,6 +1441,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  if not return_dict:
  return logits, generate_idx, padded_cache_lengths
  else:
- return RBLNDecoderOnlyOutput(
+ return RBLNDecoderOnlyForCausalLMOutput(
  logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
  )
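
For reference, the chunked-prefill path added in RBLNDecoderOnlyModel pads each sequence up to a multiple of prefill_chunk_size before iterating chunk by chunk, mirroring _preprocess_chunked_prefill above. A standalone sketch of that padding arithmetic (the chunk size and sample lengths are illustrative values, not package defaults):

def padded_prefill_length(query_length: int, prefill_chunk_size: int) -> int:
    # Round the valid query length up to the next multiple of the chunk size,
    # mirroring: (chunk - (q % chunk)) % chunk
    padding_size = (prefill_chunk_size - (query_length % prefill_chunk_size)) % prefill_chunk_size
    return query_length + padding_size

assert padded_prefill_length(100, 128) == 128  # one partially filled chunk
assert padded_prefill_length(128, 128) == 128  # already aligned, no padding
assert padded_prefill_length(300, 128) == 384  # rounded up to three chunks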