optimum-rbln 0.8.1rc1__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (119)
  1. optimum/rbln/__init__.py +58 -9
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +24 -5
  4. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +2 -2
  5. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +2 -2
  6. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +2 -2
  7. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +2 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +5 -2
  9. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +2 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +2 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +2 -2
  12. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +3 -3
  13. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +2 -2
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +4 -4
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +2 -2
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +2 -2
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +2 -2
  18. optimum/rbln/diffusers/modeling_diffusers.py +4 -5
  19. optimum/rbln/diffusers/models/__init__.py +3 -13
  20. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +1 -0
  21. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1 -0
  22. optimum/rbln/diffusers/models/autoencoders/vq_model.py +1 -0
  23. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +1 -1
  24. optimum/rbln/diffusers/pipelines/__init__.py +1 -5
  25. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +12 -4
  26. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +4 -28
  27. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  28. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -1
  29. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  30. optimum/rbln/modeling.py +4 -5
  31. optimum/rbln/modeling_base.py +18 -14
  32. optimum/rbln/ops/kv_cache_update.py +5 -0
  33. optimum/rbln/ops/linear.py +7 -0
  34. optimum/rbln/transformers/__init__.py +60 -0
  35. optimum/rbln/transformers/configuration_generic.py +4 -4
  36. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  37. optimum/rbln/transformers/modeling_generic.py +1 -4
  38. optimum/rbln/transformers/models/__init__.py +45 -30
  39. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  40. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +2 -2
  41. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +1 -5
  42. optimum/rbln/transformers/models/clip/configuration_clip.py +14 -3
  43. optimum/rbln/transformers/models/clip/modeling_clip.py +123 -28
  44. optimum/rbln/transformers/models/colpali/colpali_architecture.py +1 -4
  45. optimum/rbln/transformers/models/colpali/configuration_colpali.py +2 -2
  46. optimum/rbln/transformers/models/colpali/modeling_colpali.py +2 -10
  47. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -2
  48. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +214 -45
  49. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -454
  50. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +579 -362
  51. optimum/rbln/transformers/models/exaone/exaone_architecture.py +17 -42
  52. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  53. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  54. optimum/rbln/transformers/models/gemma/gemma_architecture.py +3 -44
  55. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  56. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +21 -9
  57. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +9 -63
  58. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +200 -292
  59. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  60. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  61. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +19 -24
  62. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  63. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +2 -2
  64. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -9
  65. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  66. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  67. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  68. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  69. optimum/rbln/transformers/models/llava/configuration_llava.py +54 -0
  70. optimum/rbln/transformers/models/llava/modeling_llava.py +419 -0
  71. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +20 -3
  72. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +6 -16
  73. optimum/rbln/transformers/models/midm/midm_architecture.py +14 -22
  74. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  75. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  76. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  77. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  78. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  79. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  80. optimum/rbln/transformers/models/opt/modeling_opt.py +41 -1
  81. optimum/rbln/transformers/models/opt/opt_architecture.py +16 -25
  82. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  83. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +34 -0
  84. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +69 -0
  85. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  86. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  87. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  88. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  89. optimum/rbln/transformers/models/phi/phi_architecture.py +16 -22
  90. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  91. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  92. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +315 -0
  93. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  94. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  95. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  96. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  97. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +3 -3
  98. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -15
  99. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +1 -4
  100. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  101. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  102. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  103. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  104. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -12
  105. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  106. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  107. optimum/rbln/transformers/models/siglip/modeling_siglip.py +2 -2
  108. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +2 -2
  109. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -5
  110. optimum/rbln/transformers/models/whisper/configuration_whisper.py +3 -12
  111. optimum/rbln/transformers/models/whisper/modeling_whisper.py +8 -2
  112. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  113. optimum/rbln/utils/depreacate_utils.py +16 -0
  114. optimum/rbln/utils/hub.py +8 -47
  115. optimum/rbln/utils/runtime_utils.py +31 -5
  116. {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2.dist-info}/METADATA +1 -1
  117. {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2.dist-info}/RECORD +119 -102
  118. {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2.dist-info}/WHEEL +0 -0
  119. {optimum_rbln-0.8.1rc1.dist-info → optimum_rbln-0.8.2.dist-info}/licenses/LICENSE +0 -0
@@ -13,7 +13,6 @@
  # limitations under the License.

  import inspect
- import math
  from collections import deque
  from dataclasses import dataclass
  from pathlib import Path
@@ -22,7 +21,8 @@ from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tu
  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithPast
  from transformers.modeling_utils import no_init_weights
  from transformers.utils import ModelOutput

@@ -30,14 +30,15 @@ from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
  from ....utils.logging import get_logger
  from ....utils.runtime_utils import RBLNPytorchRuntime
- from ...utils.rbln_quantization import prepare_model_for_quantization
- from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
- from .decoderonly_architecture import (
- DecoderOnlyWrapper,
+ from ...modeling_attention_utils import (
+ RBLNDecoderOnlyFlashAttentionMixin,
  set_default_values,
  validate_attention_method,
- validate_sliding_window_size,
+ validate_sliding_window,
  )
+ from ...utils.rbln_quantization import prepare_model_for_quantization
+ from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+ from .decoderonly_architecture import DecoderOnlyWrapper


  logger = get_logger()
@@ -267,7 +268,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):

  attention_mask = self.dec_attn_mask

- if self.rbln_config.cache_impl in ["hybrid", "static"] and self.batch_size < block_tables.shape[0]:
+ if self.rbln_config.use_global_attention and self.batch_size < block_tables.shape[0]:
  block_tables = block_tables[: self.batch_size]

  if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
@@ -283,7 +284,7 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  position_ids if self.rbln_config.use_position_ids else None,
  )

- return RBLNDecoderOnlyOutput(logits=logits)
+ return RBLNDecoderOnlyForCausalLMOutput(logits=logits)

  def _prepare_prefill_inputs(
  self,
@@ -303,6 +304,8 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  position_embed = (
  position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
  )
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids[:, attention_mask.bool()] if attention_mask is not None else token_type_ids

  query_length = inputs.shape[1]
  if query_length > self.rbln_config.max_seq_len:
@@ -352,8 +355,11 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  if position_embed is not None:
  position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))

+ if token_type_ids is not None:
+ token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
+
  # Overwrite position_ids and padded_cache_lengths
- position_ids = None
+ position_ids = cache_position.clone()
  padded_cache_lengths = 0

  return (
@@ -365,6 +371,7 @@
  position_embed,
  padded_cache_lengths,
  query_length,
+ token_type_ids,
  )

  def prefill_forward(
@@ -393,6 +400,7 @@
  position_embed,
  padded_cache_lengths,
  query_length,
+ token_type_ids,
  ) = self._prepare_prefill_inputs(
  inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
  )
@@ -442,94 +450,64 @@ class RBLNRuntimeModel(RBLNPytorchRuntime):
  self.dec_attn_mask[batch_idx].fill_(0)
  self.dec_attn_mask[batch_idx, :, :, :query_length] = 1

- return RBLNDecoderOnlyOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)
+ return RBLNDecoderOnlyForCausalLMOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)


  @dataclass
- class RBLNDecoderOnlyOutput(ModelOutput):
+ class RBLNDecoderOnlyForCausalLMOutput(ModelOutput):
  logits: torch.FloatTensor = None
  generate_idx: torch.Tensor = None
  padded_cache_lengths: int = None


- class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
+ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
  """
- A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+ A base class for decoder-only transformer models outputting raw hidden-states without any specific head on top.
+ This class is used for RBLN-optimized models that are not causal language models.
  This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.

  The class provides core functionality for:

  1. Converting pre-trained transformer models to RBLN-optimized format
  2. Handling the compilation process for RBLN devices
- 3. Managing inference operations for causal language modeling
+ 3. Managing inference operations for decoder-only architectures

  This class inherits from RBLNModel and implements specific methods required for
- decoder-only architectures and causal language modeling tasks.
+ decoder-only architectures.

  Note:
  - This class is designed to be subclassed by specific model implementations
- (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+ (e.g., RBLNLlamaModel, RBLNQwen2Model)
  - Subclasses should implement model-specific conversion logic.
  - The class handles RBLN-specific optimizations automatically during compilation
  """

  main_input_name = "input_ids"
- auto_model_class = AutoModelForCausalLM
+ auto_model_class = AutoModel
  _decoder_wrapper_cls = DecoderOnlyWrapper
  _use_rotary_emb = True

  def __post_init__(self, **kwargs):
- main_input_name = self.main_input_name
-
  if self.rbln_config.use_inputs_embeds:
- main_input_name = "inputs_embeds"
  artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
  self.embed_tokens = self._create_embedding_layer()
  self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
  else:
  self.embed_tokens = None

- # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
- dec_attn_mask = torch.zeros(
- self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
- )
- block_tables = torch.zeros(
- self.rbln_config.batch_size,
- self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
- dtype=torch.int16,
- ).fill_(-1)
- free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
-
- self.prefill_decoder = RBLNRuntimeModel(
- runtime=self.model[0],
- main_input_name=main_input_name,
- embed_tokens=self.embed_tokens,
- phase="prefill",
- batch_size=self.rbln_config.batch_size,
- dec_attn_mask=dec_attn_mask,
- block_tables=block_tables,
- free_block_pool=free_block_pool,
- rbln_config=self.rbln_config,
- vocab_size=self.config.vocab_size,
- )
+ # TODO: add prefill runtime class.
+ self.prefill_decoder = RBLNPytorchRuntime(runtime=self.model[0])

- self.decoders = {}
- for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
- self.decoders[batch_size] = RBLNRuntimeModel(
- runtime=self.model[i + 1],
- main_input_name=main_input_name,
- embed_tokens=self.embed_tokens,
- phase="decode",
- batch_size=batch_size,
- dec_attn_mask=dec_attn_mask,
- block_tables=block_tables,
- free_block_pool=free_block_pool,
- rbln_config=self.rbln_config,
+ # attributes for prefill
+ if self.rbln_config.use_global_attention:
+ self.block_tables = torch.arange(self.rbln_config.kvcache_num_blocks, dtype=torch.int16)
+ if self.rbln_config.use_local_attention:
+ self.local_block_tables = torch.tensor([0], dtype=torch.int16)
+ if self.rbln_config.use_attention_mask:
+ self.causal_mask = 1 - torch.triu(
+ torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
  )

- # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
- self.decoder = self.decoders[self.rbln_config.batch_size]
-
  @classmethod
  def save_torch_artifacts(
  cls,
@@ -564,79 +542,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  return self.rbln_config.kvcache_num_blocks

  @classmethod
- def get_quantized_model(
- cls,
- model_id: str,
- config: Optional[PretrainedConfig] = None,
- use_auth_token: Optional[Union[bool, str]] = None,
- revision: Optional[str] = None,
- force_download: bool = False,
- cache_dir: Optional[str] = None,
- subfolder: str = "",
- local_files_only: bool = False,
- trust_remote_code: bool = False,
- **kwargs,
- ):
- kwargs = cls.update_kwargs(kwargs)
-
- if config is None:
- config = AutoConfig.from_pretrained(
- model_id,
- use_auth_token=use_auth_token,
- revision=revision,
- force_download=force_download,
- cache_dir=cache_dir,
- trust_remote_code=trust_remote_code,
- **kwargs,
- )
-
- with no_init_weights():
- model = AutoModelForCausalLM.from_config(config)
-
- model = prepare_model_for_quantization(
- model,
- model_id,
- kwargs.get("num_hidden_layers"),
- use_auth_token=use_auth_token,
- revision=revision,
- cache_dir=cache_dir,
- force_download=force_download,
- local_files_only=local_files_only,
- )
- return model
-
- def __getattr__(self, __name: str) -> Any:
- # Special method to delegate attribute access to the original Huggingface LM class.
- # This method is called when an attribute is not found in the current instance's dictionary.
- # It enables transparent access to the original model's attributes and methods while maintaining
- # proper method binding.
-
- # The method implements a delegation pattern that:
-
- # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
- # 2. For other attributes: Returns them directly from the original class
-
- def redirect(func):
- return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
- val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
- if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
- return redirect(val)
- return val
-
- @classmethod
- def get_pytorch_model(
- cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
- ) -> PreTrainedModel:
- if rbln_config and rbln_config.quantization:
- model = cls.get_quantized_model(*args, **kwargs)
- else:
- model = super().get_pytorch_model(*args, **kwargs)
-
- return model
-
- @classmethod
- def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+ def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
  wrapper_cfg = {
  "max_seq_len": rbln_config.max_seq_len,
  "attn_impl": rbln_config.attn_impl,
@@ -653,205 +559,95 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()

  @classmethod
- @torch.inference_mode()
- def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
- wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
-
- rbln_compile_configs = rbln_config.compile_cfgs
- prefill_compile_config = rbln_compile_configs[0]
+ def _compile_model(
+ cls,
+ wrapped_model,
+ compile_config,
+ example_inputs,
+ compile_context,
+ rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+ quantization=None,
+ phase: str = "prefill",
+ ):
+ try:
+ wrapped_model.phase = phase
+ if quantization:
+ quantization.maybe_set_quantization_env()
+ original_linear = torch.nn.functional.linear
+ torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
+ compiled_model = cls.compile(
+ wrapped_model,
+ compile_config,
+ create_runtimes=rbln_config.create_runtimes,
+ device=rbln_config.device,
+ example_inputs=example_inputs,
+ compile_context=compile_context,
+ )
+ return compiled_model
+ finally:
+ torch.nn.functional.linear = original_linear
+ if quantization:
+ quantization.maybe_reset_quantization_env()

+ @classmethod
+ def _get_compile_context(
+ cls,
+ compile_config: RBLNCompileConfig,
+ example_inputs: List[torch.Tensor],
+ ):
  context = CompileContext(use_weight_sharing=True)

- # Here we use meta tensor, for the memory efficiency.
- meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
- prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
-
  # Mark static tensors (self kv states)
  static_tensors = {}
- for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
+ for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
  if "past_key_values" in name:
  static_tensors[name] = tensor
  context.mark_static_address(tensor)

- def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
- try:
- if quantization:
- quantization.maybe_set_quantization_env()
- original_linear = torch.nn.functional.linear
- torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
- compiled_model = cls.compile(
- wrapped_model,
- compile_config,
- create_runtimes=rbln_config.create_runtimes,
- device=rbln_config.device,
- example_inputs=example_inputs,
- compile_context=compile_context,
- )
- return compiled_model
- finally:
- torch.nn.functional.linear = original_linear
- if quantization:
- quantization.maybe_reset_quantization_env()
-
- wrapped_model.phase = "prefill"
- compiled_prefill = compile_model(
- wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
- )
-
- wrapped_model.phase = "decode"
- compiled_models = {"prefill": compiled_prefill}
- for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
- dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
- compiled_decoder = compile_model(
- wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
- )
- compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
-
- # check if the memory is enough to have additional blocks
- required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
- if rbln_config.kvcache_num_blocks < required_num_blocks:
- cls.maybe_suggest_kvcache_num_blocks(
- compiled_models=compiled_models,
- model_config=model.config,
- rbln_config=rbln_config,
- )
-
- return compiled_models
+ return context, static_tensors

  @classmethod
- def maybe_suggest_kvcache_num_blocks(
+ @torch.inference_mode()
+ def get_compiled_model(
  cls,
- compiled_models: Dict[str, rebel.RBLNCompiledModel],
- model_config: PretrainedConfig,
- rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
- ) -> None:
- # Get the actual memory allocation of each node by key
- alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
- alloc_memory_by_key: Dict[str, int] = {
- key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
- }
- for batch_size in rbln_config.decoder_batch_sizes:
- for key, memory_per_node in (
- compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
- ):
- alloc_memory_by_key[key] += sum(memory_per_node)
- alloc_memory_by_key.pop("PortRecur", None) # Old compiler's kv-cache Key
- alloc_memory_by_key.pop("DramTensor", None) # kv-cache
- kernel_size = alloc_memory_by_key.pop("Kernel") # model weight
-
- # Get the maximum number of blocks that can be allocated
- buffer = sum(alloc_memory_by_key.values())
- max_num_blocks = cls.get_maximum_num_blocks(
- config=model_config,
- tensor_parallel_size=rbln_config.tensor_parallel_size,
- kvcache_block_size=rbln_config.kvcache_block_size,
- kernel_size=kernel_size,
- buffer=buffer,
+ model: PreTrainedModel,
+ rbln_config: RBLNDecoderOnlyModelConfig,
+ ):
+ wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+ compile_config = rbln_config.compile_cfgs[0]
+
+ # Here we use meta tensor, for the memory efficiency.
+ meta_tensor_names = [name for name, _, _ in compile_config.input_info if "past_key_values" in name]
+ example_inputs = compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+ context, _ = cls._get_compile_context(compile_config, example_inputs)
+
+ compiled_model = cls._compile_model(
+ wrapped_model, compile_config, example_inputs, context, rbln_config, rbln_config.quantization, "prefill"
  )
+ compiled_models = {"prefill": compiled_model}

- # Since our estimation logic is not always accurate,
- # users can set `kvcache_num_blocks` to `max_num_blocks`.
- # If the memory is not enough, the model will fail to compile.
- if rbln_config.kvcache_num_blocks < max_num_blocks:
- logger.warning(
- f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
- "Our analysis indicates that additional memory is available for more blocks. "
- f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
- "Please be advised that our memory estimation algorithm has limitations, "
- "and increasing this value may not guarantee successful model compilation."
- )
+ return compiled_models

  @classmethod
- def get_maximum_num_blocks(
- cls,
- config: PretrainedConfig,
- tensor_parallel_size: int,
- kvcache_block_size: int,
- nbits_per_param: Optional[int] = None,
- n_model_params: Optional[int] = None,
- kernel_size: Optional[int] = None,
- buffer: Optional[int] = None,
- num_runtimes: int = 2,
- ) -> int:
- # We are finding max_n_blocks(x) that satisfies the following equation:
-
- # available_dram - kernel_size - buffer
- # - num_layers * 2 * tensor_parallel_size
- # * align_2MB(
- # x
- # * block_size
- # * align_64(head_dim)
- # * math.ceil(num_key_value_heads / tensor_parallel_size)
- # * 2
- # ) > 0
-
- # This inequality can be rewritten as follows:
-
- # a - c * align_2MB(b * x) > 0
- # where
- # a = available_dram - kernel_size - buffer
- # b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
- # c = num_layers * 2 * tensor_parallel_size
-
- # We can rewrite the inequality as follows:
- # k > align_2MB(b*x)
- # where
- # k = a / c
-
- # After that, we can derive the following equation:
- # x = floor(2**21 / b * floor((k - 1) / 2**21))
-
- def align(x: int, nbytes: int) -> int:
- return int(math.ceil(x / nbytes) * nbytes)
-
- def align_2MB(x: int) -> int:
- return align(x, 2**21)
-
- num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
- num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
- head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
- vocab_size = config.vocab_size
- hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
- num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
-
- # TODO(jongho): Update if target npu is REBEL.
- ATOM_DRAM_NBYTES = 16 * 2**30
- ATOM_SYS_DRAM_NBYTES = 288 * 2**20
- available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
-
- if kernel_size is None:
- if n_model_params is None:
- raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
- # Get estimated kernel size (approximated)
- lm_heads_params = align(vocab_size, 64) * hidden_size
- lm_heads_nbytes = (
- align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
- )
- params = n_model_params - lm_heads_params
- layer_nbytes = (
- align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
- * num_layers
- * tensor_parallel_size
- )
- kernel_size = layer_nbytes + lm_heads_nbytes
- elif n_model_params is not None:
- raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
-
- available_dram -= kernel_size
+ def get_quantized_model(
+ cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
+ ) -> PreTrainedModel:
+ raise NotImplementedError

- if buffer is None:
- # TODO: Accurate buffer estimation
- buffer_per_runtime_per_core = 2**28 # 256MB per runtime
- buffer_per_core = buffer_per_runtime_per_core * num_runtimes # 1 for prefill, 1 for decoder
- buffer = buffer_per_core * tensor_parallel_size
- available_dram -= buffer
+ @classmethod
+ def get_pytorch_model(
+ cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
+ ) -> PreTrainedModel:
+ if rbln_config and rbln_config.quantization:
+ model = cls.get_quantized_model(*args, **kwargs)
+ else:
+ model = super().get_pytorch_model(*args, **kwargs)

- b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
- c = num_layers * 2 * tensor_parallel_size
- k = available_dram / c
- max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
+ return model

- return max_n_blocks
+ @classmethod
+ def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+ return use_local_attention

  @classmethod
  def get_input_info(
@@ -861,13 +657,12 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
  model_config: PretrainedConfig,
  ):
- is_prefill: bool = query_length > 1
  num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
  num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
  num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
  hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
  head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
- local_kvcache_num_blocks = max(rbln_config.decoder_batch_sizes)
+ is_prefill = query_length > 1

  # 1. main input
  if rbln_config.use_inputs_embeds:
@@ -886,16 +681,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  ]

  # 3. block_tables
- if rbln_config.cache_impl in ["static", "hybrid"]:
+ if rbln_config.use_global_attention:
  max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
  input_info.extend(
  [("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")]
  )
- if rbln_config.cache_impl in ["hybrid", "sliding_window"]:
+ if rbln_config.use_local_attention:
  input_info.extend([("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16")])

- # 4. query_position
- if is_prefill:
+ # 4. query_position for sliding window attention
+ if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
  input_info.extend([("query_position", [], "int16")])

  # 5. attention_mask & position_ids
@@ -917,7 +712,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  rbln_config.kvcache_block_size,
  head_dim,
  ]
- local_kvcache_shape = [local_kvcache_num_blocks, num_key_value_heads, rbln_config.sliding_window, head_dim]
+ local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
  input_info.extend(
  [
  (
@@ -964,13 +759,38 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  # ```

  # Returns:
- # RBLNDecoderOnlyModelForCausalLMConfig: The updated RBLN model configuration.
+ # RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.

  raise NotImplementedError(
  "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
  "See method docstring for required configuration details."
  )

+ @classmethod
+ def _update_attention_config(
+ cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+ ):
+ rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
+ attn_impl=rbln_config.attn_impl,
+ kvcache_partition_len=rbln_config.kvcache_partition_len,
+ kvcache_block_size=rbln_config.kvcache_block_size,
+ max_seq_len=rbln_config.max_seq_len,
+ )
+
+ validate_attention_method(
+ attn_impl=rbln_config.attn_impl,
+ kvcache_partition_len=rbln_config.kvcache_partition_len,
+ kvcache_block_size=rbln_config.kvcache_block_size,
+ max_seq_len=rbln_config.max_seq_len,
+ )
+
+ if rbln_config.kvcache_num_blocks is None:
+ rbln_config.kvcache_num_blocks = (
+ rbln_config.max_seq_len // rbln_config.kvcache_block_size
+ ) * rbln_config.batch_size
+
+ return rbln_config
+
  @classmethod
  def _update_rbln_config(
  cls,
@@ -991,8 +811,384 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  ):
  rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
  if rbln_config.sliding_window is not None:
- validate_sliding_window_size(rbln_config.sliding_window, rbln_config.prefill_chunk_size)
+ validate_sliding_window(rbln_config)
+
+ rbln_config = cls._update_attention_config(model, model_config, rbln_config)
+
+ prefill_input_info = cls.get_input_info(
+ batch_size=1,
+ query_length=rbln_config.prefill_chunk_size,
+ rbln_config=rbln_config,
+ model_config=model_config,
+ )
+
+ prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+ rbln_config.set_compile_cfgs([prefill_compile_config])
+
+ return rbln_config
+
+ @classmethod
+ def _create_runtimes(
+ cls,
+ compiled_models: List[rebel.RBLNCompiledModel],
+ rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+ ) -> List[rebel.Runtime]:
+ expected_model_names = [
+ "prefill",
+ ]
+ if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
+ cls._raise_missing_compiled_file_error(expected_model_names)
+
+ return [
+ rebel.Runtime(
+ compiled_models[0],
+ tensor_type="pt",
+ device=rbln_config.device_map["prefill"],
+ activate_profiler=rbln_config.activate_profiler,
+ ),
+ ]
+
+ def _preprocess_chunked_prefill(
+ self,
+ inputs: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_embed: Optional[torch.Tensor] = None,
+ ):
+ # valid sequence length of inputs_embeds
+ query_length = inputs.shape[1] if attention_mask is None else torch.sum(attention_mask.view(-1)).item()
+
+ # extract valid inputs
+ inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
+
+ if inputs.dim() == 2 and self.rbln_config.use_inputs_embeds:
+ inputs = self.get_input_embeddings()(inputs)
+
+ if position_embed is not None:
+ position_embed = (
+ position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
+ )
+
+ # padding for chunked prefill
+ padding_size = (
+ self.rbln_config.prefill_chunk_size - (query_length % self.rbln_config.prefill_chunk_size)
+ ) % self.rbln_config.prefill_chunk_size
+ padded_len = query_length + padding_size
+
+ inputs = (
+ torch.nn.functional.pad(inputs, (0, padding_size))
+ if not self.rbln_config.use_inputs_embeds
+ else torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
+ )
+ position_embed = (
+ None if position_embed is None else torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
+ )
+ cache_position = torch.arange(padded_len, dtype=torch.int32).unsqueeze(0)
+
+ chunked_attention_mask = (
+ torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
+ if self.rbln_config.use_attention_mask
+ else None
+ )
+
+ return inputs, position_embed, cache_position, query_length, chunked_attention_mask
+
+ def _chunked_prefill_forward(
+ self,
+ inputs: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_embed: Optional[torch.Tensor] = None,
+ ):
+ padded_input, padded_position_embed, cache_position, query_length, chunked_attention_mask = (
+ self._preprocess_chunked_prefill(inputs, attention_mask, position_embed)
+ )
+
+ # chunked prefill
+ last_hidden_states = []
+ for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
+ # Extract the current chunk of inputs and cache positions
+ input_chunk = padded_input[:, step : step + self.rbln_config.prefill_chunk_size]
+ cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
+
+ valid_length = (
+ self.rbln_config.prefill_chunk_size
+ if (step + self.rbln_config.prefill_chunk_size) <= query_length
+ else query_length - step
+ )
+ if self.rbln_config.use_local_attention:
+ query_position = torch.tensor(valid_length - 1, dtype=torch.int16)
+ else:
+ query_position = None
+
+ if self.rbln_config.use_attention_mask:
+ if step > 0:
+ chunked_attention_mask[:, :, :, :step] = 1
+ chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
+
+ # Forward pass for the current chunk
+ last_hidden_states_chunk = self.prefill_decoder(
+ input_ids=input_chunk if not self.rbln_config.use_inputs_embeds else None,
+ inputs_embeds=input_chunk if self.rbln_config.use_inputs_embeds else None,
+ cache_position=cache_pos_chunk,
+ block_tables=self.block_tables if self.rbln_config.use_global_attention else None,
+ local_block_tables=self.local_block_tables if self.rbln_config.use_local_attention else None,
+ query_position=query_position,
+ attention_mask=chunked_attention_mask,
+ position_emb=padded_position_embed,
+ )
+ last_hidden_states.append(last_hidden_states_chunk)
+ last_hidden_states = torch.concat(last_hidden_states, dim=-2)[:, :query_length]
+
+ return self._postprocess_chunked_prefill(last_hidden_states, attention_mask)
+
+ def _postprocess_chunked_prefill(
+ self, last_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
+ ):
+ # index copy for attention mask
+ if attention_mask is not None:
+ new_last_hidden_states = torch.full(
+ (1, attention_mask.shape[-1], last_hidden_states.shape[-1]),
+ fill_value=1e-10,
+ dtype=last_hidden_states.dtype,
+ )
+ mask_indices = torch.nonzero(attention_mask, as_tuple=True)[0]
+ new_last_hidden_states.index_copy_(dim=-2, index=mask_indices, source=last_hidden_states)
+ else:
+ new_last_hidden_states = last_hidden_states
+ return new_last_hidden_states
+
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_embed: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.FloatTensor]:
+ inputs = inputs_embeds if inputs_embeds is not None else input_ids
+ batch_size = inputs.shape[0]
+ all_last_hidden_states = []
+ for b_idx in range(batch_size):
+ last_hidden_states = self._chunked_prefill_forward(
+ inputs[b_idx : b_idx + 1],
+ attention_mask[b_idx] if attention_mask is not None else None,
+ position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
+ )
+ all_last_hidden_states.append(last_hidden_states)
+
+ last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
+ return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
+
+
+ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel):
+ """
+ A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+ This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
+
+ The class provides core functionality for:
+
+ 1. Converting pre-trained transformer models to RBLN-optimized format
+ 2. Handling the compilation process for RBLN devices
+ 3. Managing inference operations for causal language modeling
+
+ This class inherits from RBLNModel and implements specific methods required for
+ decoder-only architectures and causal language modeling tasks.
+
+ Note:
+ - This class is designed to be subclassed by specific model implementations
+ (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+ - Subclasses should implement model-specific conversion logic.
+ - The class handles RBLN-specific optimizations automatically during compilation
+ """
+
+ auto_model_class = AutoModelForCausalLM
+
+ def __post_init__(self, **kwargs):
+ main_input_name = self.main_input_name
+
+ if self.rbln_config.use_inputs_embeds:
+ main_input_name = "inputs_embeds"
+ artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+ self.embed_tokens = self._create_embedding_layer()
+ self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+ else:
+ self.embed_tokens = None
+
+ # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
+ dec_attn_mask = torch.zeros(
+ self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
+ )
+ block_tables = torch.zeros(
+ self.rbln_config.batch_size,
+ self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
+ dtype=torch.int16,
+ ).fill_(-1)
+ free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+
+ self.prefill_decoder = RBLNRuntimeModel(
+ runtime=self.model[0],
+ main_input_name=main_input_name,
+ embed_tokens=self.embed_tokens,
+ phase="prefill",
+ batch_size=self.rbln_config.batch_size,
+ dec_attn_mask=dec_attn_mask,
+ block_tables=block_tables,
+ free_block_pool=free_block_pool,
+ rbln_config=self.rbln_config,
+ vocab_size=self.config.vocab_size,
+ )
+
+ if self.can_generate():
+ self.decoders = {}
+ for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
+ self.decoders[batch_size] = RBLNRuntimeModel(
+ runtime=self.model[i + 1],
+ main_input_name=main_input_name,
+ embed_tokens=self.embed_tokens,
+ phase="decode",
+ batch_size=batch_size,
+ dec_attn_mask=dec_attn_mask,
+ block_tables=block_tables,
+ free_block_pool=free_block_pool,
+ rbln_config=self.rbln_config,
+ )
+
+ # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
+ self.decoder = self.decoders[self.rbln_config.batch_size]
+
+ @classmethod
+ def get_quantized_model(
+ cls,
+ model_id: str,
+ config: Optional[PretrainedConfig] = None,
+ use_auth_token: Optional[Union[bool, str]] = None,
+ revision: Optional[str] = None,
+ force_download: bool = False,
+ cache_dir: Optional[str] = None,
+ subfolder: str = "",
+ local_files_only: bool = False,
+ trust_remote_code: bool = False,
+ **kwargs,
+ ):
+ kwargs = cls.update_kwargs(kwargs)
+
+ if config is None:
+ config = AutoConfig.from_pretrained(
+ model_id,
+ use_auth_token=use_auth_token,
+ revision=revision,
+ force_download=force_download,
+ cache_dir=cache_dir,
+ trust_remote_code=trust_remote_code,
+ **kwargs,
+ )
+
+ with no_init_weights():
+ model = AutoModelForCausalLM.from_config(config)
+
+ model = prepare_model_for_quantization(
+ model,
+ model_id,
+ kwargs.get("num_hidden_layers"),
+ use_auth_token=use_auth_token,
+ revision=revision,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ local_files_only=local_files_only,
+ )
+ return model
+
+ def __getattr__(self, __name: str) -> Any:
+ # Special method to delegate attribute access to the original Huggingface LM class.
+ # This method is called when an attribute is not found in the current instance's dictionary.
+ # It enables transparent access to the original model's attributes and methods while maintaining
+ # proper method binding.
+
+ # The method implements a delegation pattern that:
+
+ # 1. For methods: Creates a wrapper that properly binds 'self' to method calls
+ # 2. For other attributes: Returns them directly from the original class

+ def redirect(func):
+ return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+ val = getattr(self.get_hf_class(), __name, None) or getattr(PreTrainedModel, __name)
+ if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+ return redirect(val)
+ return val
+
+ @classmethod
+ def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
+ wrapper_cfg = {
+ "max_seq_len": rbln_config.max_seq_len,
+ "attn_impl": rbln_config.attn_impl,
+ "kvcache_partition_len": rbln_config.kvcache_partition_len,
+ "kvcache_block_size": rbln_config.kvcache_block_size,
+ "use_rotary_emb": cls._use_rotary_emb,
+ "use_attention_mask": rbln_config.use_attention_mask,
+ "use_position_ids": rbln_config.use_position_ids,
+ "use_inputs_embeds": rbln_config.use_inputs_embeds,
+ "cache_impl": rbln_config.cache_impl,
+ "sliding_window": rbln_config.sliding_window,
+ "sliding_window_layers": rbln_config.sliding_window_layers,
+ }
+ return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+
+ @classmethod
+ @torch.inference_mode()
+ def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+ wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
+ prefill_compile_config = rbln_config.compile_cfgs[0]
+
+ # Here we use meta tensor, for the memory efficiency.
+ meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
+ prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+ context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
+
+ compiled_models = {}
+ compiled_models["prefill"] = cls._compile_model(
+ wrapped_model,
+ prefill_compile_config,
+ prefill_example_inputs,
+ context,
+ rbln_config,
+ rbln_config.quantization,
+ phase="prefill",
+ )
+
+ if rbln_config.can_generate:
+ wrapped_model.phase = "decode"
+ for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
+ dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+ compiled_decoder = cls._compile_model(
+ wrapped_model,
+ dec_compile_config,
+ dec_example_inputs,
+ context,
+ rbln_config,
+ rbln_config.quantization,
+ phase="decode",
+ )
+ compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
+
+ # check if the memory is enough to have additional blocks
+ required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+ if rbln_config.kvcache_num_blocks < required_num_blocks:
+ cls.maybe_suggest_kvcache_num_blocks(
+ compiled_models=compiled_models,
+ model_config=model.config,
+ rbln_config=rbln_config,
+ )
+
+ return compiled_models
+
+ @classmethod
+ def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+ return is_prefill
+
+ @classmethod
+ def _update_attention_config(
+ cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+ ):
  rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
  attn_impl=rbln_config.attn_impl,
  kvcache_partition_len=rbln_config.kvcache_partition_len,
@@ -1017,13 +1213,13 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  kvcache_block_size=rbln_config.kvcache_block_size,
  nbits_per_param=16 if not rbln_config.quantization else 4, # TODO(jongho): FIX Ad-hoc
  n_model_params=sum(p.numel() for p in model.parameters()),
- num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
+ num_runtimes=1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
  )

  max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)

  flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
- if max_num_blocks < flash_min_blocks:
+ if rbln_config.batch_size > 1 and max_num_blocks < flash_min_blocks:
  max_num_blocks = flash_min_blocks

  if max_num_blocks < rbln_config.batch_size:
@@ -1042,27 +1238,30 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  )
  logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")

- prefill_input_info = cls.get_input_info(
- batch_size=1,
- query_length=rbln_config.prefill_chunk_size,
- rbln_config=rbln_config,
- model_config=model_config,
- )
-
- prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
+ return rbln_config

- dec_compile_configs = []
- for batch_size in rbln_config.decoder_batch_sizes:
- dec_input_info = cls.get_input_info(
- batch_size=batch_size,
- query_length=1,
- rbln_config=rbln_config,
- model_config=model_config,
- )
- dec_compile_configs.append(
- RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
- )
- rbln_config.set_compile_cfgs([prefill_compile_config, *dec_compile_configs])
+ @classmethod
+ def _update_rbln_config(
+ cls,
+ preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+ model: Optional[PreTrainedModel] = None,
+ model_config: Optional[PretrainedConfig] = None,
+ rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
+ ) -> RBLNDecoderOnlyModelForCausalLMConfig:
+ rbln_config = super()._update_rbln_config(preprocessors, model, model_config, rbln_config)
+ if rbln_config.can_generate:
+ compile_configs = rbln_config.compile_cfgs
+ for batch_size in rbln_config.decoder_batch_sizes:
+ dec_input_info = cls.get_input_info(
+ batch_size=batch_size,
+ query_length=1,
+ rbln_config=rbln_config,
+ model_config=model_config,
+ )
+ compile_configs.append(
+ RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
+ )
+ rbln_config.set_compile_cfgs(compile_configs)

  return rbln_config

@@ -1072,36 +1271,45 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  compiled_models: List[rebel.RBLNCompiledModel],
  rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
  ) -> List[rebel.Runtime]:
- expected_model_names = [
- "prefill",
- *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
- ]
+ expected_model_names = ["prefill"]
+ if rbln_config.can_generate:
+ expected_model_names.extend(
+ [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
+ )
  if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
  cls._raise_missing_compiled_file_error(expected_model_names)

- return [
+ ret_val = [
  rebel.Runtime(
  compiled_models[0],
  tensor_type="pt",
  device=rbln_config.device_map["prefill"],
  activate_profiler=rbln_config.activate_profiler,
- ),
- *[
- rebel.Runtime(
- compiled_models[i + 1],
- tensor_type="pt",
- device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
- activate_profiler=rbln_config.activate_profiler,
- )
- for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
- ],
+ timeout=rbln_config.timeout,
+ )
  ]
+ if rbln_config.can_generate:
+ ret_val.extend(
+ [
+ rebel.Runtime(
+ compiled_models[i + 1],
+ tensor_type="pt",
+ device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
+ activate_profiler=rbln_config.activate_profiler,
+ timeout=rbln_config.timeout,
+ )
+ for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
+ ]
+ )
+ return ret_val

  def get_decoder(self):
+ if not self.can_generate():
+ raise ValueError("Decode stage is not supported in this model.")
  return self.decoder

  def can_generate(self):
- return True
+ return self.rbln_config.can_generate

  def _reorder_cache(self, past_key_values, beam_idx):
  raise NotImplementedError
@@ -1158,7 +1366,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):

  def _update_model_kwargs_for_generation(
  self,
- outputs: RBLNDecoderOnlyOutput,
+ outputs: RBLNDecoderOnlyForCausalLMOutput,
  model_kwargs: Dict[str, Any],
  **kwargs,
  ) -> Dict[str, Any]:
@@ -1186,7 +1394,16 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  # A for-loop ensures synchronization with the HuggingFace generate API.
  # The decoder stage operates as usual, processing inputs in batch mode.

- # Prefll
+ # for only use forward
+ if generate_idx is None:
+ generate_idx = (
+ attention_mask.sum(dim=-1, keepdim=True).int()
+ if attention_mask is not None
+ else torch.full((input_ids.shape[0], 1), input_ids.shape[1], dtype=torch.int32)
+ )
+ padded_cache_lengths = torch.zeros_like(generate_idx)
+
+ # Prefill
  if cache_position is None:
  logits = []
  inputs = inputs_embeds if inputs_embeds is not None else input_ids
@@ -1224,6 +1441,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
  if not return_dict:
  return logits, generate_idx, padded_cache_lengths
  else:
- return RBLNDecoderOnlyOutput(
+ return RBLNDecoderOnlyForCausalLMOutput(
  logits=logits, generate_idx=generate_idx, padded_cache_lengths=padded_cache_lengths
  )
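
For context, a minimal usage sketch of the generation path kept by RBLNDecoderOnlyModelForCausalLM. The subclass name RBLNLlamaForCausalLM comes from the docstring in this diff; the from_pretrained/export keywords and rbln_* options are assumed optimum-rbln conventions and are not shown in this diff.

# Hypothetical sketch -- keyword arguments are assumed conventions, not taken from this diff.
from transformers import AutoTokenizer
from optimum.rbln import RBLNLlamaForCausalLM  # causal-LM subclass named in the docstring above

model = RBLNLlamaForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # example checkpoint (assumed)
    export=True,                 # compile the checkpoint for the RBLN NPU (assumed flag)
    rbln_max_seq_len=4096,       # assumed rbln_* configuration knobs
    rbln_batch_size=1,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hello, my name is", return_tensors="pt")
generated = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(generated[0], skip_special_tokens=True))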