optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff shows the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (197)
  1. optimum/rbln/__init__.py +116 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +171 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +33 -18
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +50 -24
  52. optimum/rbln/modeling_base.py +116 -35
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +100 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +93 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  158. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  159. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  160. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  161. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  162. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  163. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  164. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
  165. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  166. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  167. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  168. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  169. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  170. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  171. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  172. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  173. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  174. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  175. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  176. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
  177. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  178. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  179. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  180. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  181. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  182. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  183. optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
  184. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  185. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  186. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  187. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  188. optimum/rbln/utils/deprecation.py +213 -0
  189. optimum/rbln/utils/hub.py +22 -50
  190. optimum/rbln/utils/runtime_utils.py +85 -17
  191. optimum/rbln/utils/submodule.py +31 -9
  192. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  193. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  194. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  195. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  196. optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
  197. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
@@ -13,31 +13,31 @@
  # limitations under the License.

  import inspect
- import math
- from collections import deque
- from dataclasses import dataclass
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union

  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+ from transformers import AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithPast
  from transformers.modeling_utils import no_init_weights
- from transformers.utils import ModelOutput

  from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
  from ....utils.logging import get_logger
- from ....utils.runtime_utils import RBLNPytorchRuntime
- from ...utils.rbln_quantization import prepare_model_for_quantization
- from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
- from .decoderonly_architecture import (
-     DecoderOnlyWrapper,
+ from ...modeling_attention_utils import (
+     RBLNDecoderOnlyFlashAttentionMixin,
      set_default_values,
      validate_attention_method,
-     validate_sliding_window_size,
+     validate_sliding_window,
  )
+ from ...modeling_outputs import RBLNDecoderOnlyOutput
+ from ...utils.rbln_quantization import get_quantized_model
+ from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+ from .decoderonly_architecture import DecoderOnlyWrapper
+ from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
+ from .generation_decoderonly import RBLNDecoderOnlyGenerationMixin


  logger = get_logger()
@@ -46,522 +46,85 @@ if TYPE_CHECKING:
46
46
  from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
47
47
 
48
48
 
49
- class RBLNRuntimeModel(RBLNPytorchRuntime):
50
- mandatory_members = ["main_input_name", "embed_tokens"]
51
-
52
- def __init__(
53
- self,
54
- runtime: rebel.Runtime,
55
- phase: str,
56
- batch_size: int,
57
- dec_attn_mask: torch.Tensor,
58
- block_tables: torch.Tensor,
59
- free_block_pool: Deque,
60
- rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
61
- **kwargs: Any,
62
- ) -> None:
63
- super().__init__(runtime, **kwargs)
64
- self.phase = phase
65
- self.batch_size = batch_size
66
- self.rbln_config = rbln_config
67
-
68
- # shared tensor between prefill and decode phase
69
- self.dec_attn_mask = dec_attn_mask
70
- self.block_tables = block_tables
71
- self.free_block_pool = free_block_pool
72
-
73
- self.empty_block = -1
74
- if self.phase == "prefill":
75
- vocab_size = kwargs.pop("vocab_size")
76
- self.output_size = [1, 1, vocab_size]
77
- self.causal_mask = 1 - torch.triu(
78
- torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
79
- )
80
-
81
- def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None) -> torch.Tensor:
82
- """
83
- Manages and returns the KV cache block tables.
84
- Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
85
-
86
- Args:
87
- cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
88
- batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
89
-
90
- Returns:
91
- Updated block tables.
92
- """
93
-
94
- NO_BLOCKS_ERROR = (
95
- "No memory blocks are available for allocation. "
96
- "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln. "
97
- "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html). "
98
- "Using vllm-rbln should fix this issue and enhance inference performance."
99
- )
100
-
101
- def update_block(batch_idx: int, block_idx: int):
102
- """
103
- If the block is empty (empty_block), allocates a block from the free_block_pool.
104
- """
105
- if self.block_tables[batch_idx][block_idx] == self.empty_block:
106
- if self.free_block_pool:
107
- block = self.free_block_pool.popleft()
108
- self.block_tables[batch_idx][block_idx] = block
109
- else:
110
- raise RuntimeError(NO_BLOCKS_ERROR)
111
-
112
- def replace_empty_block(block_tables: torch.Tensor):
113
- """
114
- Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
115
- """
116
- if not torch.any(block_tables == self.empty_block):
117
- return block_tables.clone()
118
- elif self.free_block_pool:
119
- _free_block = self.free_block_pool[0]
120
- return torch.where(block_tables == self.empty_block, _free_block, block_tables)
121
- else:
122
- raise RuntimeError(NO_BLOCKS_ERROR)
123
-
124
- def get_global_block_tables(batch_idx: int):
125
- if self.rbln_config.cache_impl == "sliding_window":
126
- return None
127
-
128
- if self.phase == "prefill":
129
- # Track previously used blocks and return them to the free_block_pool and
130
- # reset the current batch's block table to empty blocks
131
- prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
132
- self.free_block_pool.extend(prev_blocks)
133
- self.block_tables[batch_idx].fill_(self.empty_block)
134
-
135
- # Get the start (s) and end (e) positions from cache_position and
136
- # iterate over the cache positions to allocate necessary blocks
137
- s, e = cache_position[0][0].item(), cache_position[0][-1].item()
138
- for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
139
- block_idx = position // self.rbln_config.kvcache_block_size
140
- if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
141
- raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
142
- update_block(batch_idx, block_idx)
143
-
144
- return replace_empty_block(self.block_tables[batch_idx])
145
- # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
146
- else:
147
- for b_idx in range(self.batch_size):
148
- position = cache_position[b_idx][0].item()
149
- block_idx = position // self.rbln_config.kvcache_block_size
150
- update_block(b_idx, block_idx)
151
-
152
- return replace_empty_block(self.block_tables)
153
-
154
- def get_local_block_tables(batch_idx: int):
155
- if self.rbln_config.cache_impl == "static":
156
- return None
157
- else:
158
- return (
159
- torch.tensor([batch_idx], dtype=torch.int16)
160
- if self.phase == "prefill"
161
- else torch.arange(self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
162
- )
163
-
164
- return get_global_block_tables(batch_idx), get_local_block_tables(batch_idx)
165
-
166
- def is_external_block_tables(
167
- self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
168
- ):
169
- if self.rbln_config.cache_impl == "static" and block_tables is None:
170
- return False
171
- elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
172
- return False
173
- elif self.rbln_config.cache_impl == "hybrid":
174
- if (block_tables is not None) != (local_block_tables is not None):
175
- raise ValueError(
176
- "Both block_tables and local_block_tables must be provided or neither of them must be provided."
177
- )
178
- elif block_tables is None and local_block_tables is None:
179
- return False
180
-
181
- return True
182
-
183
- def forward(
184
- self,
185
- input_ids: Optional[torch.LongTensor] = None,
186
- inputs_embeds: Optional[torch.Tensor] = None,
187
- cache_position: torch.Tensor = None,
188
- attention_mask: Optional[torch.Tensor] = None,
189
- batch_idx: Optional[int] = None,
190
- block_tables: Optional[torch.Tensor] = None,
191
- position_embed: Optional[torch.Tensor] = None,
192
- position_ids: Optional[torch.Tensor] = None,
193
- token_type_ids: Optional[torch.Tensor] = None,
194
- local_block_tables: Optional[torch.Tensor] = None,
195
- ):
196
- if input_ids is None and inputs_embeds is None:
197
- raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
198
-
199
- if inputs_embeds is None:
200
- inputs = input_ids
201
- if self.embed_tokens is not None:
202
- inputs = self.embed_tokens(inputs)
203
- else:
204
- inputs = inputs_embeds
205
-
206
- is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
207
- if not is_external_block_tables:
208
- block_tables, local_block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
209
-
210
- if self.phase == "decode":
211
- return self.decode_forward(
212
- inputs,
213
- cache_position,
214
- block_tables,
215
- is_external_block_tables,
216
- attention_mask=attention_mask,
217
- position_embed=position_embed,
218
- position_ids=position_ids,
219
- local_block_tables=local_block_tables,
220
- )
221
- else:
222
- return self.prefill_forward(
223
- inputs,
224
- cache_position,
225
- attention_mask,
226
- batch_idx,
227
- block_tables,
228
- is_external_block_tables=is_external_block_tables,
229
- position_embed=position_embed,
230
- token_type_ids=token_type_ids,
231
- local_block_tables=local_block_tables,
232
- )
233
-
234
- def decode_forward(
235
- self,
236
- inputs: torch.Tensor,
237
- cache_position: torch.Tensor = None,
238
- block_tables: torch.Tensor = None,
239
- is_external_block_tables: bool = None,
240
- attention_mask: Optional[torch.Tensor] = None,
241
- position_embed: Optional[torch.Tensor] = None,
242
- position_ids: Optional[torch.Tensor] = None,
243
- local_block_tables: Optional[torch.Tensor] = None,
244
- ) -> torch.FloatTensor:
245
- batch_size = inputs.shape[0]
246
- if batch_size != self.batch_size:
247
- raise RuntimeError(
248
- f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
249
- )
250
-
251
- if batch_size != cache_position.shape[0]:
252
- raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
253
-
254
- if self.rbln_config.use_attention_mask and attention_mask is None:
255
- for b_idx in range(batch_size):
256
- decoding_step = cache_position[b_idx].item()
257
- if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
258
- raise ValueError(
259
- f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
260
- )
261
-
262
- if is_external_block_tables:
263
- self.dec_attn_mask[b_idx].fill_(0)
264
- self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
265
- else:
266
- self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
267
-
268
- attention_mask = self.dec_attn_mask
269
-
270
- if self.rbln_config.cache_impl in ["hybrid", "static"] and self.batch_size < block_tables.shape[0]:
271
- block_tables = block_tables[: self.batch_size]
272
-
273
- if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
274
- attention_mask = attention_mask[: self.batch_size]
275
-
276
- logits = super().forward(
277
- inputs,
278
- cache_position,
279
- block_tables,
280
- local_block_tables,
281
- position_embed,
282
- attention_mask if self.rbln_config.use_attention_mask else None,
283
- position_ids if self.rbln_config.use_position_ids else None,
284
- )
285
-
286
- return RBLNDecoderOnlyOutput(logits=logits)
287
-
288
- def _prepare_prefill_inputs(
289
- self,
290
- inputs: torch.Tensor,
291
- cache_position: torch.Tensor,
292
- attention_mask: Optional[torch.Tensor] = None,
293
- position_embed: Optional[torch.Tensor] = None,
294
- token_type_ids: Optional[torch.Tensor] = None,
295
- ):
296
- """
297
- Prepare inputs for prefill phase.
298
- """
299
- # Handle continuous batching in a compiled graph by extracting valid inputs
300
- # If an attention mask is provided, select only the valid (non-masked) inputs
301
- inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
302
- if position_embed is not None:
303
- position_embed = (
304
- position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
305
- )
306
-
307
- query_length = inputs.shape[1]
308
- if query_length > self.rbln_config.max_seq_len:
309
- raise ValueError(
310
- f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
311
- )
312
-
313
- # Initialize attention mask for chunked processing
314
- chunked_attention_mask = (
315
- torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
316
- if self.rbln_config.use_attention_mask
317
- else None
318
- )
319
-
320
- # Buffer for storing output logits
321
- out_buffers = [
322
- torch.empty(
323
- size=self.output_size,
324
- dtype=torch.float32,
325
- device="cpu",
326
- )
327
- ]
328
-
329
- # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
330
- padding_size = 0
331
- if query_length % self.rbln_config.prefill_chunk_size != 0:
332
- padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
333
- # inputs_embeds
334
- if inputs.dim() == 3:
335
- inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
336
- # inputs_ids
337
- else:
338
- inputs = torch.nn.functional.pad(inputs, (0, padding_size))
339
-
340
- cache_position = torch.cat(
341
- [
342
- cache_position,
343
- torch.arange(
344
- query_length,
345
- query_length + padding_size,
346
- dtype=torch.int32,
347
- ).unsqueeze(0),
348
- ],
349
- dim=-1,
350
- )
351
-
352
- if position_embed is not None:
353
- position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
354
-
355
- # Overwrite position_ids and padded_cache_lengths
356
- position_ids = None
357
- padded_cache_lengths = 0
358
-
359
- return (
360
- inputs,
361
- cache_position,
362
- chunked_attention_mask,
363
- out_buffers,
364
- position_ids,
365
- position_embed,
366
- padded_cache_lengths,
367
- query_length,
368
- )
369
-
370
- def prefill_forward(
371
- self,
372
- inputs: torch.Tensor,
373
- cache_position: torch.Tensor = None,
374
- attention_mask: Optional[torch.Tensor] = None,
375
- batch_idx: int = None,
376
- block_tables: torch.Tensor = None,
377
- is_external_block_tables: bool = False,
378
- position_embed: Optional[torch.Tensor] = None,
379
- token_type_ids: Optional[torch.Tensor] = None,
380
- local_block_tables: Optional[torch.Tensor] = None,
381
- ) -> torch.FloatTensor:
382
- """
383
- Performs chunked prefill for efficient KV-cache updates and memory optimization.
384
- Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
385
- and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
386
- """
387
- (
388
- inputs,
389
- cache_position,
390
- chunked_attention_mask,
391
- out_buffers,
392
- position_ids,
393
- position_embed,
394
- padded_cache_lengths,
395
- query_length,
396
- ) = self._prepare_prefill_inputs(
397
- inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
398
- )
399
-
400
- # Process input in chunks of size `prefill_chunk_size`
401
- for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
402
- # Extract the current chunk of inputs and cache positions
403
- input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
404
- cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
405
- position_ids_chunk = (
406
- position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
407
- if position_ids is not None
408
- else None
409
- )
410
- if position_embed is not None:
411
- position_embed_chunk = position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]
412
-
413
- if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
414
- # Update attention mask to ensure proper causal behavior
415
- if step >= self.rbln_config.prefill_chunk_size:
416
- chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
417
- chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
418
-
419
- # Define query position
420
- if step + self.rbln_config.prefill_chunk_size >= query_length:
421
- query_position = torch.tensor(
422
- (query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16
423
- )
424
- else:
425
- query_position = torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)
426
-
427
- # Forward pass for the current chunk
428
- logits = super().forward(
429
- input_chunk,
430
- cache_pos_chunk,
431
- block_tables,
432
- local_block_tables,
433
- position_embed_chunk if position_embed is not None else None,
434
- query_position,
435
- chunked_attention_mask if self.rbln_config.use_attention_mask else None,
436
- position_ids_chunk if self.rbln_config.use_position_ids else None,
437
- out=out_buffers,
438
- )
439
-
440
- # Update decoder attention mask with processed KV-cache length from prefill phase
441
- if not is_external_block_tables and self.rbln_config.use_attention_mask:
442
- self.dec_attn_mask[batch_idx].fill_(0)
443
- self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
444
-
445
- return RBLNDecoderOnlyOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)
-
-
- @dataclass
- class RBLNDecoderOnlyOutput(ModelOutput):
-     logits: torch.FloatTensor = None
-     generate_idx: torch.Tensor = None
-     padded_cache_lengths: int = None
-
-
- class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
+ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
      """
-     A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+     A base class for decoder-only transformer models outputting raw hidden-states without any specific head on top.
+     This class is used for RBLN-optimized models that are not causal language models.
      This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.

      The class provides core functionality for:

      1. Converting pre-trained transformer models to RBLN-optimized format
      2. Handling the compilation process for RBLN devices
-     3. Managing inference operations for causal language modeling
-
+     3. Managing inference operations for decoder-only architectures
      This class inherits from RBLNModel and implements specific methods required for
-     decoder-only architectures and causal language modeling tasks.
+     decoder-only architectures.

      Note:
          - This class is designed to be subclassed by specific model implementations
-           (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+           (e.g., RBLNLlamaModel, RBLNQwen2Model)
          - Subclasses should implement model-specific conversion logic.
          - The class handles RBLN-specific optimizations automatically during compilation
      """

+     _tp_support = True
+
      main_input_name = "input_ids"
-     auto_model_class = AutoModelForCausalLM
+     auto_model_class = AutoModel
      _decoder_wrapper_cls = DecoderOnlyWrapper
      _use_rotary_emb = True
+     _supports_non_fp32 = True

      def __post_init__(self, **kwargs):
-         main_input_name = self.main_input_name
-
          if self.rbln_config.use_inputs_embeds:
-             main_input_name = "inputs_embeds"
              artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
              self.embed_tokens = self._create_embedding_layer()
              self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
          else:
              self.embed_tokens = None

-         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
-         dec_attn_mask = torch.zeros(
-             self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
-         )
-         block_tables = torch.zeros(
-             self.rbln_config.batch_size,
-             self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
-             dtype=torch.int16,
-         ).fill_(-1)
-         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+         self.setup_runtime()

+     def setup_runtime(self):
+         # Initialize resources to be used across Runtime instances (prefill and decode phases)
+         page_table_manager = RBLNPageTableManager(self.rbln_config)
+         dec_attn_mask = torch.zeros(self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype)
+         out_buffers = [torch.empty(self.prefill_output_size, dtype=self.dtype)]
+
+         common_kwargs = {
+             "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+             "embed_tokens": self.embed_tokens,
+             "dec_attn_mask": dec_attn_mask,
+             "page_table_manager": page_table_manager,
+             "rbln_config": self.rbln_config,
+         }
          self.prefill_decoder = RBLNRuntimeModel(
              runtime=self.model[0],
-             main_input_name=main_input_name,
-             embed_tokens=self.embed_tokens,
              phase="prefill",
              batch_size=self.rbln_config.batch_size,
-             dec_attn_mask=dec_attn_mask,
-             block_tables=block_tables,
-             free_block_pool=free_block_pool,
-             rbln_config=self.rbln_config,
-             vocab_size=self.config.vocab_size,
+             out_buffers=out_buffers,
+             **common_kwargs,
          )
+         if self.can_generate():
+             self.decoders = {}
+             for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
+                 self.decoders[batch_size] = RBLNRuntimeModel(
+                     runtime=self.model[i + 1],
+                     phase="decode",
+                     batch_size=batch_size,
+                     **common_kwargs,
+                 )

-         self.decoders = {}
-         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
-             self.decoders[batch_size] = RBLNRuntimeModel(
-                 runtime=self.model[i + 1],
-                 main_input_name=main_input_name,
-                 embed_tokens=self.embed_tokens,
-                 phase="decode",
-                 batch_size=batch_size,
-                 dec_attn_mask=dec_attn_mask,
-                 block_tables=block_tables,
-                 free_block_pool=free_block_pool,
-                 rbln_config=self.rbln_config,
-             )
-
-         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
-         self.decoder = self.decoders[self.rbln_config.batch_size]
-
-     @classmethod
-     def save_torch_artifacts(
-         cls,
-         model: PreTrainedModel,
-         save_dir_path: Path,
-         subfolder: str,
-         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-     ):
-         # If you are unavoidably running on a CPU rather than an RBLN device,
-         # store the torch tensor, weight, etc. in this function.
-         if rbln_config.use_inputs_embeds:
-             save_dict = {}
-             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
-             torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
-
-     def _create_embedding_layer(self):
-         with no_init_weights():
-             embed_tokens = torch.nn.Embedding(
-                 self.config.vocab_size,
-                 self.config.hidden_size,
-                 self.config.pad_token_id,
-             )
-         return embed_tokens
-
-     def get_input_embeddings(self):
-         return self.embed_tokens
-
-     def get_attn_impl(self) -> str:
-         return self.rbln_config.attn_impl
+             # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
+             self.decoder = self.decoders[self.rbln_config.batch_size]

-     def get_kvcache_num_blocks(self) -> int:
-         return self.rbln_config.kvcache_num_blocks
+     @property
+     def prefill_output_size(self):
+         return (
+             1,
+             self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
+             self.config.hidden_size,
+         )

      @classmethod
      def get_quantized_model(
@@ -575,35 +138,22 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          subfolder: str = "",
          local_files_only: bool = False,
          trust_remote_code: bool = False,
+         rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
          **kwargs,
      ):
          kwargs = cls.update_kwargs(kwargs)

-         if config is None:
-             config = AutoConfig.from_pretrained(
-                 model_id,
-                 use_auth_token=use_auth_token,
-                 revision=revision,
-                 force_download=force_download,
-                 cache_dir=cache_dir,
-                 trust_remote_code=trust_remote_code,
-                 **kwargs,
-             )
-
-         with no_init_weights():
-             model = AutoModelForCausalLM.from_config(config)
-
-         model = prepare_model_for_quantization(
-             model,
+         return get_quantized_model(
+             cls.auto_model_class,
              model_id,
-             kwargs.get("num_hidden_layers"),
              use_auth_token=use_auth_token,
              revision=revision,
              cache_dir=cache_dir,
              force_download=force_download,
              local_files_only=local_files_only,
+             rbln_quantization=rbln_config.quantization,
+             **kwargs,
          )
-         return model

      def __getattr__(self, __name: str) -> Any:
          # Special method to delegate attribute access to the original Huggingface LM class.
@@ -625,233 +175,162 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
625
175
  return val
626
176
 
627
177
  @classmethod
628
- def get_pytorch_model(
629
- cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
630
- ) -> PreTrainedModel:
631
- if rbln_config and rbln_config.quantization:
632
- model = cls.get_quantized_model(*args, **kwargs)
633
- else:
634
- model = super().get_pytorch_model(*args, **kwargs)
178
+ def save_torch_artifacts(
179
+ cls,
180
+ model: PreTrainedModel,
181
+ save_dir_path: Path,
182
+ subfolder: str,
183
+ rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
184
+ ):
185
+ # If you are unavoidably running on a CPU rather than an RBLN device,
186
+ # store the torch tensor, weight, etc. in this function.
187
+ if rbln_config.use_inputs_embeds:
188
+ save_dict = {}
189
+ save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
190
+ torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
635
191
 
636
- return model
192
+ def _create_embedding_layer(self):
193
+ with no_init_weights():
194
+ embed_tokens = torch.nn.Embedding(
195
+ self.config.vocab_size,
196
+ self.config.hidden_size,
197
+ self.config.pad_token_id,
198
+ )
199
+ return embed_tokens
637
200
 
638
- @classmethod
639
- def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
640
- wrapper_cfg = {
641
- "max_seq_len": rbln_config.max_seq_len,
642
- "attn_impl": rbln_config.attn_impl,
643
- "kvcache_partition_len": rbln_config.kvcache_partition_len,
644
- "kvcache_block_size": rbln_config.kvcache_block_size,
645
- "use_rotary_emb": cls._use_rotary_emb,
646
- "use_attention_mask": rbln_config.use_attention_mask,
647
- "use_position_ids": rbln_config.use_position_ids,
648
- "use_inputs_embeds": rbln_config.use_inputs_embeds,
649
- "cache_impl": rbln_config.cache_impl,
650
- "sliding_window": rbln_config.sliding_window,
651
- "sliding_window_layers": rbln_config.sliding_window_layers,
652
- }
653
- return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
201
+ def get_decoder(self):
202
+ if not self.can_generate():
203
+ raise ValueError("Decode stage is not supported in this model.")
204
+ return self.decoder
205
+
206
+ def can_generate(self):
207
+ return self.rbln_config.can_generate
208
+
209
+ def get_input_embeddings(self):
210
+ return self.embed_tokens
211
+
212
+ def get_attn_impl(self) -> str:
213
+ return self.rbln_config.attn_impl
214
+
215
+ def get_kvcache_num_blocks(self) -> int:
216
+ return self.rbln_config.kvcache_num_blocks
654
217
 
655
218
  @classmethod
656
- @torch.inference_mode()
657
- def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
658
- wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
219
+ def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
220
+ return cls._decoder_wrapper_cls(model, rbln_config, cls._use_rotary_emb).eval()
659
221
 
660
- rbln_compile_configs = rbln_config.compile_cfgs
661
- prefill_compile_config = rbln_compile_configs[0]
222
+ @classmethod
223
+ def _compile_model(
224
+ cls,
225
+ wrapped_model,
226
+ compile_config,
227
+ example_inputs,
228
+ compile_context,
229
+ rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
230
+ quantization=None,
231
+ phase: str = "prefill",
232
+ ):
233
+ try:
234
+ wrapped_model.phase = phase
235
+ if quantization:
236
+ quantization.maybe_set_quantization_env()
237
+ original_linear = torch.nn.functional.linear
238
+ torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
239
+ compiled_model = cls.compile(
240
+ wrapped_model,
241
+ compile_config,
242
+ create_runtimes=rbln_config.create_runtimes,
243
+ device=rbln_config.device,
244
+ example_inputs=example_inputs,
245
+ compile_context=compile_context,
246
+ )
247
+ return compiled_model
248
+ finally:
249
+ torch.nn.functional.linear = original_linear
250
+ if quantization:
251
+ quantization.maybe_reset_quantization_env()
662
252
 
253
+ @classmethod
254
+ def _get_compile_context(
255
+ cls,
256
+ compile_config: RBLNCompileConfig,
257
+ example_inputs: List[torch.Tensor],
258
+ ):
663
259
  context = CompileContext(use_weight_sharing=True)
664
260
 
665
- # Here we use meta tensor, for the memory efficiency.
666
- meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
667
- prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
668
-
669
261
  # Mark static tensors (self kv states)
670
262
  static_tensors = {}
671
- for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
263
+ idx = 0
264
+ for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
672
265
  if "past_key_values" in name:
673
266
  static_tensors[name] = tensor
674
- context.mark_static_address(tensor)
675
-
676
- def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
677
- try:
678
- if quantization:
679
- quantization.maybe_set_quantization_env()
680
- original_linear = torch.nn.functional.linear
681
- torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
682
- compiled_model = cls.compile(
683
- wrapped_model,
684
- compile_config,
685
- create_runtimes=rbln_config.create_runtimes,
686
- device=rbln_config.device,
687
- example_inputs=example_inputs,
688
- compile_context=compile_context,
689
- )
690
- return compiled_model
691
- finally:
692
- torch.nn.functional.linear = original_linear
693
- if quantization:
694
- quantization.maybe_reset_quantization_env()
695
-
696
- wrapped_model.phase = "prefill"
697
- compiled_prefill = compile_model(
698
- wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
699
- )
700
-
701
- wrapped_model.phase = "decode"
702
- compiled_models = {"prefill": compiled_prefill}
703
- for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
704
- dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
705
- compiled_decoder = compile_model(
706
- wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
707
- )
708
- compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
709
-
710
- # check if the memory is enough to have additional blocks
711
- required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
712
- if rbln_config.kvcache_num_blocks < required_num_blocks:
713
- cls.maybe_suggest_kvcache_num_blocks(
714
- compiled_models=compiled_models,
715
- model_config=model.config,
716
- rbln_config=rbln_config,
717
- )
267
+ context.mark_static_address(tensor, f"kv_cache_{idx}")
268
+ idx += 1
718
269
 
719
- return compiled_models
270
+ return context, static_tensors
720
271
 
721
272
  @classmethod
722
- def maybe_suggest_kvcache_num_blocks(
723
- cls,
724
- compiled_models: Dict[str, rebel.RBLNCompiledModel],
725
- model_config: PretrainedConfig,
726
- rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
727
- ) -> None:
728
- # Get the actual memory allocation of each node by key
729
- alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
730
- alloc_memory_by_key: Dict[str, int] = {
731
- key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
732
- }
733
- for batch_size in rbln_config.decoder_batch_sizes:
734
- for key, memory_per_node in (
735
- compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
736
- ):
737
- alloc_memory_by_key[key] += sum(memory_per_node)
738
- alloc_memory_by_key.pop("PortRecur", None) # Old compiler's kv-cache Key
739
- alloc_memory_by_key.pop("DramTensor", None) # kv-cache
740
- kernel_size = alloc_memory_by_key.pop("Kernel") # model weight
741
-
742
- # Get the maximum number of blocks that can be allocated
743
- buffer = sum(alloc_memory_by_key.values())
744
- max_num_blocks = cls.get_maximum_num_blocks(
745
- config=model_config,
746
- tensor_parallel_size=rbln_config.tensor_parallel_size,
747
- kvcache_block_size=rbln_config.kvcache_block_size,
748
- kernel_size=kernel_size,
749
- buffer=buffer,
750
- )
273
+ @torch.inference_mode()
274
+ def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
275
+ wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
276
+ prefill_compile_config = rbln_config.compile_cfgs[0]
751
277
 
752
- # Since our estimation logic is not always accurate,
753
- # users can set `kvcache_num_blocks` to `max_num_blocks`.
754
- # If the memory is not enough, the model will fail to compile.
755
- if rbln_config.kvcache_num_blocks < max_num_blocks:
756
- logger.warning(
757
- f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
758
- "Our analysis indicates that additional memory is available for more blocks. "
759
- f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
760
- "Please be advised that our memory estimation algorithm has limitations, "
761
- "and increasing this value may not guarantee successful model compilation."
762
- )
278
+ # Here we use meta tensor, for the memory efficiency.
279
+ meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
280
+ prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
281
+ context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
282
+
283
+ compiled_models = {}
284
+ compiled_models["prefill"] = cls._compile_model(
285
+ wrapped_model,
286
+ prefill_compile_config,
287
+ prefill_example_inputs,
288
+ context,
289
+ rbln_config,
290
+ rbln_config.quantization,
291
+ phase="prefill",
292
+ )
763
293
 
764
- @classmethod
765
- def get_maximum_num_blocks(
766
- cls,
767
- config: PretrainedConfig,
768
- tensor_parallel_size: int,
769
- kvcache_block_size: int,
770
- nbits_per_param: Optional[int] = None,
771
- n_model_params: Optional[int] = None,
772
- kernel_size: Optional[int] = None,
773
- buffer: Optional[int] = None,
774
- num_runtimes: int = 2,
775
- ) -> int:
776
- # We are finding max_n_blocks(x) that satisfies the following equation:
777
-
778
- # available_dram - kernel_size - buffer
779
- # - num_layers * 2 * tensor_parallel_size
780
- # * align_2MB(
781
- # x
782
- # * block_size
783
- # * align_64(head_dim)
784
- # * math.ceil(num_key_value_heads / tensor_parallel_size)
785
- # * 2
786
- # ) > 0
787
-
788
- # This inequality can be rewritten as follows:
789
-
790
- # a - c * align_2MB(b * x) > 0
791
- # where
792
- # a = available_dram - kernel_size - buffer
793
- # b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
794
- # c = num_layers * 2 * tensor_parallel_size
795
-
796
- # We can rewrite the inequality as follows:
797
- # k > align_2MB(b*x)
798
- # where
799
- # k = a / c
800
-
801
- # After that, we can derive the following equation:
802
- # x = floor(2**21 / b * floor((k - 1) / 2**21))
803
-
804
- def align(x: int, nbytes: int) -> int:
805
- return int(math.ceil(x / nbytes) * nbytes)
806
-
807
- def align_2MB(x: int) -> int:
808
- return align(x, 2**21)
809
-
810
- num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
811
- num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
812
- head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
813
- vocab_size = config.vocab_size
814
- hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
815
- num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
816
-
817
- # TODO(jongho): Update if target npu is REBEL.
818
- ATOM_DRAM_NBYTES = 16 * 2**30
819
- ATOM_SYS_DRAM_NBYTES = 288 * 2**20
820
- available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
821
-
822
- if kernel_size is None:
823
- if n_model_params is None:
824
- raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
825
- # Get estimated kernel size (approximated)
826
- lm_heads_params = align(vocab_size, 64) * hidden_size
827
- lm_heads_nbytes = (
828
- align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
829
- )
830
- params = n_model_params - lm_heads_params
831
- layer_nbytes = (
832
- align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
833
- * num_layers
834
- * tensor_parallel_size
835
- )
836
- kernel_size = layer_nbytes + lm_heads_nbytes
837
- elif n_model_params is not None:
838
- raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
294
+ if rbln_config.can_generate:
295
+ wrapped_model.phase = "decode"
296
+ for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
297
+ dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
298
+ compiled_decoder = cls._compile_model(
299
+ wrapped_model,
300
+ dec_compile_config,
301
+ dec_example_inputs,
302
+ context,
303
+ rbln_config,
304
+ rbln_config.quantization,
305
+ phase="decode",
306
+ )
307
+ compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
308
+
309
+ # check if the memory is enough to have additional blocks
310
+ required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
311
+ if rbln_config.kvcache_num_blocks < required_num_blocks:
312
+ cls.maybe_suggest_kvcache_num_blocks(
313
+ compiled_models=compiled_models,
314
+ model_config=model.config,
315
+ rbln_config=rbln_config,
316
+ )
839
317
 
840
- available_dram -= kernel_size
318
+ return compiled_models
841
319
 
842
- if buffer is None:
843
- # TODO: Accurate buffer estimation
844
- buffer_per_runtime_per_core = 2**28 # 256MB per runtime
845
- buffer_per_core = buffer_per_runtime_per_core * num_runtimes # 1 for prefill, 1 for decoder
846
- buffer = buffer_per_core * tensor_parallel_size
847
- available_dram -= buffer
320
+ @classmethod
321
+ def get_pytorch_model(
322
+ cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
323
+ ) -> PreTrainedModel:
324
+ if rbln_config and rbln_config.quantization:
325
+ model = cls.get_quantized_model(*args, rbln_config=rbln_config, **kwargs)
326
+ else:
327
+ model = super().get_pytorch_model(*args, **kwargs)
848
328
 
849
- b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
850
- c = num_layers * 2 * tensor_parallel_size
851
- k = available_dram / c
852
- max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
329
+ return model
853
330
 
854
- return max_n_blocks
331
+ @classmethod
332
+ def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
333
+ return use_local_attention
855
334
 
856
335
  @classmethod
857
336
  def get_input_info(
@@ -861,63 +340,57 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
          model_config: PretrainedConfig,
      ):
-         is_prefill: bool = query_length > 1
          num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
          num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
          num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
          hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
          head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
-         local_kvcache_num_blocks = max(rbln_config.decoder_batch_sizes)
+         is_prefill = query_length > 1

-         # 1. main input
+         input_info = []
          if rbln_config.use_inputs_embeds:
-             main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
+             input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.torch_dtype))
          else:
-             main_input = ("input_ids", [batch_size, query_length], "int64")
-
-         # 2. cache_position
-         input_info = [
-             main_input,
-             (
-                 "cache_position",
-                 [batch_size, query_length],
-                 "int32",
-             ),
-         ]
+             input_info.append(("input_ids", [batch_size, query_length], "int64"))
+
+         input_info.append(("cache_position", [batch_size, query_length], "int32"))

-         # 3. block_tables
-         if rbln_config.cache_impl in ["static", "hybrid"]:
+         if rbln_config.use_global_attention:
              max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
-             input_info.extend(
-                 [("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")]
+             input_info.append(
+                 ("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")
              )
-         if rbln_config.cache_impl in ["hybrid", "sliding_window"]:
-             input_info.extend([("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16")])
+         if rbln_config.use_local_attention:
+             input_info.append(("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16"))

-         # 4. query_position
-         if is_prefill:
-             input_info.extend([("query_position", [], "int16")])
+         if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
+             input_info.append(("query_position", [], "int16"))

-         # 5. attention_mask & position_ids
          if rbln_config.use_attention_mask:
-             input_info.extend(
-                 [
-                     ("attention_mask", [batch_size, rbln_config.max_seq_len], "float32")
-                     if rbln_config.use_position_ids
-                     else ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], "float32")
-                 ]
-             )
+             if rbln_config.use_position_ids:
+                 input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.torch_dtype))
+             else:
+                 input_info.append(
+                     ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.torch_dtype)
+                 )
+
          if rbln_config.use_position_ids:
              input_info.append(("position_ids", [batch_size, query_length], "int32"))

-         # 6. past_key_values
+         if rbln_config.use_lora:
+             input_info.append(("lora_int_ids", [batch_size], "int32"))
+
+         kvcache_dtype = rbln_config.torch_dtype
+         if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
+             kvcache_dtype = "float8_e4m3fn"
+
          global_kvcache_shape = [
              rbln_config.kvcache_num_blocks,
              num_key_value_heads,
              rbln_config.kvcache_block_size,
              head_dim,
          ]
-         local_kvcache_shape = [local_kvcache_num_blocks, num_key_value_heads, rbln_config.sliding_window, head_dim]
+         local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
          input_info.extend(
              [
                  (
@@ -925,7 +398,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
                      local_kvcache_shape
                      if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
927
400
  else global_kvcache_shape,
928
- "float32",
401
+ kvcache_dtype,
929
402
  )
930
403
  for i in range(num_hidden_layers * 2)
931
404
  ]
@@ -964,7 +437,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
964
437
  # ```
965
438
 
966
439
  # Returns:
967
- # RBLNDecoderOnlyModelForCausalLMConfig: The updated RBLN model configuration.
440
+ # RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.
968
441
 
969
442
  raise NotImplementedError(
970
443
  "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
@@ -972,27 +445,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
972
445
  )
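Because the method above only raises NotImplementedError, each model class has to supply its own sliding-window settings. A hypothetical override is sketched below; the class name is made up, and it assumes the Hugging Face config exposes `sliding_window` and `num_hidden_layers` while `sliding_window` and `sliding_window_layers` are the RBLN config fields referenced elsewhere in this file:

```python
# Hypothetical subclass override; a minimal sketch, not the implementation of any shipped model class.
class RBLNMyModelForCausalLM(RBLNDecoderOnlyModelForCausalLM):
    @classmethod
    def _update_sliding_window_config(cls, model_config, rbln_config):
        # Adopt the sliding-window length advertised by the Hugging Face config and
        # treat every layer as a sliding-window (local-attention) layer.
        rbln_config.sliding_window = model_config.sliding_window
        rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
        return rbln_config
```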
973
446
 
974
447
  @classmethod
975
- def _update_rbln_config(
976
- cls,
977
- preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
978
- model: Optional[PreTrainedModel] = None,
979
- model_config: Optional[PretrainedConfig] = None,
980
- rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
981
- ) -> RBLNDecoderOnlyModelForCausalLMConfig:
982
- if rbln_config.max_seq_len is None:
983
- rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
984
- model_config, "n_positions", None
985
- )
986
- if rbln_config.max_seq_len is None:
987
- raise ValueError("`max_seq_len` should be specified.")
988
-
989
- if getattr(model_config, "sliding_window", None) is not None and getattr(
990
- model_config, "use_sliding_window", True
991
- ):
992
- rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
993
- if rbln_config.sliding_window is not None:
994
- validate_sliding_window_size(rbln_config.sliding_window, rbln_config.prefill_chunk_size)
995
-
448
+ def _update_attention_config(
449
+ cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
450
+ ):
996
451
  rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
997
452
  attn_impl=rbln_config.attn_impl,
998
453
  kvcache_partition_len=rbln_config.kvcache_partition_len,
@@ -1007,40 +462,77 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
1007
462
  max_seq_len=rbln_config.max_seq_len,
1008
463
  )
1009
464
 
1010
- required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
1011
- max_num_blocks = required_num_blocks
465
+ num_full_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
1012
466
 
467
+ # Update kvcache_num_blocks based on the attention implementation.
1013
468
  if rbln_config.attn_impl == "flash_attn":
1014
- estimated_max_num_blocks = cls.get_maximum_num_blocks(
1015
- config=model_config,
1016
- tensor_parallel_size=rbln_config.tensor_parallel_size or 1,
1017
- kvcache_block_size=rbln_config.kvcache_block_size,
1018
- nbits_per_param=16 if not rbln_config.quantization else 4, # TODO(jongho): FIX Ad-hoc
1019
- n_model_params=sum(p.numel() for p in model.parameters()),
1020
- num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
469
+ estimated_max_num_blocks = cls.get_maximum_num_blocks_by_model(
470
+ model=model, model_config=model_config, rbln_config=rbln_config
1021
471
  )
1022
472
 
1023
- max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
1024
-
1025
- flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
1026
- if max_num_blocks < flash_min_blocks:
1027
- max_num_blocks = flash_min_blocks
1028
-
1029
- if max_num_blocks < rbln_config.batch_size:
1030
- raise RuntimeError(
1031
- f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
1032
- "Ensure the number of blocks is at least equal to the batch size."
473
+ if rbln_config.kvcache_num_blocks is None:
474
+ if estimated_max_num_blocks < num_full_blocks:
475
+ # Lower bound on the number of blocks required for flash attention.
476
+ min_blocks_for_flash = min(
477
+ rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1, num_full_blocks
478
+ )
479
+ if min_blocks_for_flash > estimated_max_num_blocks:
480
+ # NOTE: Try to compile with the lower bound of blocks required for flash attention,
481
+ # even if it exceeds the estimated maximum number of blocks.
482
+ rbln_config.kvcache_num_blocks = min_blocks_for_flash
483
+ else:
484
+ logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
485
+ rbln_config.kvcache_num_blocks = estimated_max_num_blocks
486
+
487
+ if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
488
+ raise RuntimeError(
489
+ f"Batch size ({rbln_config.batch_size}) exceeds num_blocks ({rbln_config.kvcache_num_blocks}). "
490
+ "Ensure the number of blocks is at least equal to the batch size."
491
+ )
492
+ else:
493
+ rbln_config.kvcache_num_blocks = num_full_blocks
494
+ elif rbln_config.kvcache_num_blocks > estimated_max_num_blocks:
495
+ logger.warning(
496
+ f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
497
+ f" than the estimated maximum number of blocks ({estimated_max_num_blocks})."
498
+ "This can cause a failure during model compilation."
499
+ )
500
+ else:
501
+ if rbln_config.kvcache_num_blocks is None:
502
+ rbln_config.kvcache_num_blocks = num_full_blocks
503
+ elif rbln_config.kvcache_num_blocks > num_full_blocks:
504
+ logger.warning(
505
+ f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
506
+ f" than the required number of blocks ({num_full_blocks})."
507
+ "This can cause a failure during model compilation."
1033
508
  )
509
+ logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
1034
510
 
1035
- if rbln_config.kvcache_num_blocks is None:
1036
- rbln_config.kvcache_num_blocks = max_num_blocks
1037
- elif rbln_config.kvcache_num_blocks > max_num_blocks:
1038
- logger.warning(
1039
- f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
1040
- f" than the estimated maximum number of blocks ({max_num_blocks})."
1041
- "This can cause a failure during model compilation."
511
+ return rbln_config
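The branching above can be hard to follow in diff form. The function below restates the `kvcache_num_blocks` resolution as standalone Python; it is a simplified sketch (logging and warnings omitted), not a drop-in replacement, and the parameter names simply mirror the config fields:

```python
# Simplified restatement of the kvcache_num_blocks resolution; illustrative only.
def resolve_kvcache_num_blocks(requested, num_full_blocks, estimated_max,
                               max_seq_len, block_size, batch_size, flash_attn):
    if requested is not None:
        # Explicit user setting; the real code only warns if it looks too large.
        return requested
    if not flash_attn or estimated_max >= num_full_blocks:
        # Either flash attention is not in use, or memory comfortably covers every block.
        return num_full_blocks
    # Flash attention needs at least one block more than a single full-length sequence occupies.
    min_blocks_for_flash = min(max_seq_len // block_size + 1, num_full_blocks)
    resolved = max(min_blocks_for_flash, estimated_max)
    if resolved < batch_size:
        raise RuntimeError("Batch size exceeds the available KV-cache blocks.")
    return resolved
```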
512
+
513
+ @classmethod
514
+ def _update_rbln_config(
515
+ cls,
516
+ preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
517
+ model: Optional[PreTrainedModel] = None,
518
+ model_config: Optional[PretrainedConfig] = None,
519
+ rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
520
+ ) -> RBLNDecoderOnlyModelForCausalLMConfig:
521
+ if rbln_config.max_seq_len is None:
522
+ rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
523
+ model_config, "n_positions", None
1042
524
  )
1043
- logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
525
+ if rbln_config.max_seq_len is None:
526
+ raise ValueError("`max_seq_len` should be specified.")
527
+
528
+ if getattr(model_config, "sliding_window", None) is not None and getattr(
529
+ model_config, "use_sliding_window", True
530
+ ):
531
+ rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
532
+ if rbln_config.sliding_window is not None:
533
+ validate_sliding_window(rbln_config)
534
+
535
+ rbln_config = cls._update_attention_config(model, model_config, rbln_config)
1044
536
 
1045
537
  prefill_input_info = cls.get_input_info(
1046
538
  batch_size=1,
@@ -1050,19 +542,20 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
1050
542
  )
1051
543
 
1052
544
  prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
1053
-
1054
- dec_compile_configs = []
1055
- for batch_size in rbln_config.decoder_batch_sizes:
1056
- dec_input_info = cls.get_input_info(
1057
- batch_size=batch_size,
1058
- query_length=1,
1059
- rbln_config=rbln_config,
1060
- model_config=model_config,
1061
- )
1062
- dec_compile_configs.append(
1063
- RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
1064
- )
1065
- rbln_config.set_compile_cfgs([prefill_compile_config, *dec_compile_configs])
545
+ compile_cfgs = [prefill_compile_config]
546
+
547
+ if rbln_config.can_generate:
548
+ for batch_size in rbln_config.decoder_batch_sizes:
549
+ dec_input_info = cls.get_input_info(
550
+ batch_size=batch_size,
551
+ query_length=1,
552
+ rbln_config=rbln_config,
553
+ model_config=model_config,
554
+ )
555
+ compile_cfgs.append(
556
+ RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
557
+ )
558
+ rbln_config.set_compile_cfgs(compile_cfgs)
1066
559
 
1067
560
  return rbln_config
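The compile configs assembled above are named by role and decoder batch size, which is also how the corresponding runtimes and device-map entries are looked up later. A short sketch with illustrative batch sizes:

```python
# Illustrative only: how compiled-model names follow from decoder_batch_sizes.
decoder_batch_sizes = [1, 4]   # example values
can_generate = True

names = ["prefill"]
if can_generate:
    names += [f"decoder_batch_{b}" for b in decoder_batch_sizes]

print(names)  # ['prefill', 'decoder_batch_1', 'decoder_batch_4']
```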
1068
561
 
@@ -1072,101 +565,164 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
1072
565
  compiled_models: List[rebel.RBLNCompiledModel],
1073
566
  rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
1074
567
  ) -> List[rebel.Runtime]:
1075
- expected_model_names = [
1076
- "prefill",
1077
- *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
1078
- ]
568
+ expected_model_names = ["prefill"]
569
+ if rbln_config.can_generate:
570
+ expected_model_names.extend(
571
+ [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
572
+ )
1079
573
  if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
1080
574
  cls._raise_missing_compiled_file_error(expected_model_names)
1081
575
 
1082
- return [
576
+ ret_val = [
1083
577
  rebel.Runtime(
1084
578
  compiled_models[0],
1085
579
  tensor_type="pt",
1086
580
  device=rbln_config.device_map["prefill"],
1087
581
  activate_profiler=rbln_config.activate_profiler,
1088
- ),
1089
- *[
1090
- rebel.Runtime(
1091
- compiled_models[i + 1],
1092
- tensor_type="pt",
1093
- device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
1094
- activate_profiler=rbln_config.activate_profiler,
1095
- )
1096
- for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
1097
- ],
582
+ timeout=rbln_config.timeout,
583
+ )
1098
584
  ]
585
+ if rbln_config.can_generate:
586
+ ret_val.extend(
587
+ [
588
+ rebel.Runtime(
589
+ compiled_models[i + 1],
590
+ tensor_type="pt",
591
+ device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
592
+ activate_profiler=rbln_config.activate_profiler,
593
+ timeout=rbln_config.timeout,
594
+ )
595
+ for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
596
+ ]
597
+ )
598
+ return ret_val
1099
599
 
1100
- def get_decoder(self):
1101
- return self.decoder
1102
-
1103
- def can_generate(self):
1104
- return True
1105
-
1106
- def _reorder_cache(self, past_key_values, beam_idx):
1107
- raise NotImplementedError
1108
-
1109
- def prepare_inputs_for_generation(
600
+ def forward(
1110
601
  self,
1111
- input_ids: torch.LongTensor,
1112
- generate_idx: Optional[torch.Tensor] = None,
1113
- attention_mask: Optional[torch.LongTensor] = None,
602
+ input_ids: Optional[torch.LongTensor] = None,
1114
603
  inputs_embeds: Optional[torch.Tensor] = None,
1115
- padded_cache_lengths: Optional[torch.Tensor] = None,
604
+ attention_mask: Optional[torch.LongTensor] = None,
1116
605
  **kwargs,
1117
- ):
1118
- model_inputs = {}
1119
- is_prefill_phase = generate_idx is None
606
+ ) -> BaseModelOutputWithPast:
607
+ """
608
+ Args:
609
+ input_ids (torch.LongTensor, optional): The input IDs to the model.
610
+ inputs_embeds (torch.Tensor, optional): The input embeddings to the model.
611
+ attention_mask (torch.LongTensor, optional): Mask indicating which input tokens should be attended to.
612
+ kwargs (dict[str, Any], optional): Additional keyword arguments.
1120
613
 
1121
- if is_prefill_phase:
1122
- generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
1123
- padded_cache_lengths = torch.zeros_like(generate_idx)
1124
- cache_position = None
1125
- position_ids = None
1126
- else:
1127
- if inputs_embeds is not None:
1128
- # if `inputs_embeds` are passed, only use them in the 1st generation step for every prompt.
1129
- inputs_embeds = None
1130
-
1131
- input_ids = input_ids[:, -1:]
1132
- position_ids = generate_idx
1133
- cache_position = generate_idx + padded_cache_lengths if padded_cache_lengths is not None else generate_idx
1134
- generate_idx = generate_idx + 1
1135
- model_inputs.update({"input_ids": input_ids})
1136
-
1137
- if inputs_embeds is not None:
1138
- if self.rbln_config.use_inputs_embeds:
1139
- model_inputs.update({"inputs_embeds": inputs_embeds})
1140
- else:
1141
- raise ValueError(
1142
- "The specifying inputs_embeds is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
1143
- )
1144
- else:
1145
- model_inputs.update({"input_ids": input_ids})
1146
-
1147
- model_inputs.update(
1148
- {
1149
- "attention_mask": attention_mask,
1150
- "cache_position": cache_position,
1151
- "generate_idx": generate_idx,
1152
- "position_ids": position_ids,
1153
- "padded_cache_lengths": padded_cache_lengths,
1154
- }
614
+ Returns:
615
+ Dataclass containing the last hidden states of the model.
616
+ """
617
+ inputs = inputs_embeds if inputs_embeds is not None else input_ids
618
+ batch_size = inputs.shape[0]
619
+ position_embed = kwargs.get("position_embed", None)
620
+
621
+ if batch_size != self.rbln_config.batch_size:
622
+ raise ValueError(
623
+ f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
624
+ )
625
+
626
+ all_last_hidden_states = []
627
+ for b_idx in range(self.rbln_config.batch_size):
628
+ query_length = (
629
+ attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
630
+ )
631
+ cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
632
+ last_hidden_states = self.prefill_decoder(
633
+ inputs[b_idx : b_idx + 1],
634
+ attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
635
+ position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
636
+ cache_position=cache_position,
637
+ batch_idx=b_idx,
638
+ ).logits
639
+ all_last_hidden_states.append(last_hidden_states)
640
+
641
+ last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
642
+
643
+ return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
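The loop above runs the prefill decoder once per batch element and derives each sample's true length from its attention mask. A small torch sketch of that cache-position construction, assuming a right-padded mask (1 = real token, 0 = padding):

```python
import torch

# Illustrative: per-sample cache_position construction as in the prefill loop above.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

for b_idx in range(attention_mask.shape[0]):
    query_length = attention_mask[b_idx].sum(dim=-1).int().item()
    cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
    print(b_idx, cache_position)  # sample 0 -> [[0, 1, 2]], sample 1 -> [[0, 1, 2, 3, 4]]
```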
644
+
645
+
646
+ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
647
+ """
648
+ A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
649
+ This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
650
+
651
+ The class provides core functionality for:
652
+
653
+ 1. Converting pre-trained transformer models to RBLN-optimized format
654
+ 2. Handling the compilation process for RBLN devices
655
+ 3. Managing inference operations for causal language modeling
656
+
+ This class inherits from RBLNDecoderOnlyModel and RBLNDecoderOnlyGenerationMixin and implements the methods required for
657
+ decoder-only architectures and causal language modeling tasks.
658
+
659
+ Note:
660
+ - This class is designed to be subclassed by specific model implementations
661
+ (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
662
+ - Subclasses should implement model-specific conversion logic.
663
+ - The class handles RBLN-specific optimizations automatically during compilation
664
+ """
665
+
666
+ auto_model_class = AutoModelForCausalLM
667
+
668
+ @property
669
+ def prefill_output_size(self):
670
+ return (
671
+ 1,
672
+ self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
673
+ self.config.vocab_size,
1155
674
  )
1156
675
 
1157
- return model_inputs
676
+ @classmethod
677
+ def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
678
+ return is_prefill
679
+
680
+ def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
681
+ if isinstance(lora_int_ids, int):
682
+ lora_int_ids = torch.tensor([lora_int_ids], dtype=torch.int32)
683
+ elif isinstance(lora_int_ids, list):
684
+ lora_int_ids = torch.tensor(lora_int_ids, dtype=torch.int32)
1158
685
 
1159
- def _update_model_kwargs_for_generation(
1160
- self,
1161
- outputs: RBLNDecoderOnlyOutput,
1162
- model_kwargs: Dict[str, Any],
1163
- **kwargs,
1164
- ) -> Dict[str, Any]:
1165
- # update generate_idx
1166
- model_kwargs["generate_idx"] = outputs.generate_idx
1167
- model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths
686
+ self.lora_int_ids = lora_int_ids
1168
687
 
1169
- return model_kwargs
688
+ self.prefill_decoder.lora_int_ids = lora_int_ids
689
+ if self.rbln_config.can_generate:
690
+ for batch_size in self.rbln_config.decoder_batch_sizes:
691
+ self.decoders[batch_size].lora_int_ids = lora_int_ids
692
+
693
+ def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
694
+ """
695
+ Sets the active adapter(s) for the model using adapter name(s).
696
+
697
+ Args:
698
+ adapter_name (Union[str, List[str]]): The name(s) of the adapter(s) to be activated.
699
+ Can be a single adapter name or a list of adapter names.
700
+
701
+ Raises:
702
+ ValueError: If the model is not configured with LoRA or if the adapter name is not found.
703
+ """
704
+ if not hasattr(self.rbln_config, "lora_config") or self.rbln_config.lora_config is None:
705
+ raise ValueError("Model is not configured with LoRA. Cannot set adapter.")
706
+
707
+ # Convert single adapter name to list for uniform processing
708
+ if isinstance(adapter_name, str):
709
+ adapter_names = [adapter_name]
710
+ else:
711
+ adapter_names = adapter_name
712
+
713
+ # Validate that all adapter names exist
714
+ available_adapters = {
715
+ adapter.lora_name: adapter.lora_int_id for adapter in self.rbln_config.lora_config.adapters
716
+ }
717
+ missing_adapters = [name for name in adapter_names if name not in available_adapters]
718
+ if missing_adapters:
719
+ raise ValueError(
720
+ f"Adapter(s) {missing_adapters} not found. Available adapters: {list(available_adapters.keys())}"
721
+ )
722
+
723
+ # Get the adapter IDs and set them
724
+ lora_int_ids = [available_adapters[name] for name in adapter_names]
725
+ self.set_lora_int_ids(torch.tensor(lora_int_ids, dtype=torch.int32))
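A hedged usage sketch of the adapter selection API defined above. The adapter names and ids are placeholders, and `model` is assumed to be a LoRA-enabled RBLNDecoderOnlyModelForCausalLM instance loaded elsewhere:

```python
# Hypothetical usage; adapter names and ids are placeholders.
model.set_adapter("adapter_a")                 # activate a single adapter by name
model.set_adapter(["adapter_a", "adapter_b"])  # or one adapter per batch slot

# Equivalent lower-level call using integer adapter ids:
model.set_lora_int_ids([0, 1])
```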
1170
726
 
1171
727
  def forward(
1172
728
  self,
@@ -1178,6 +734,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
1178
734
  padded_cache_lengths: Optional[torch.Tensor] = None,
1179
735
  position_ids: Optional[torch.Tensor] = None,
1180
736
  token_type_ids: Optional[torch.Tensor] = None,
737
+ lora_int_ids: Optional[torch.Tensor] = None,
1181
738
  return_dict: Optional[torch.Tensor] = None,
1182
739
  **kwargs,
1183
740
  ) -> Tuple[torch.FloatTensor]:
@@ -1185,12 +742,38 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
1185
742
  # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
1186
743
  # A for-loop ensures synchronization with the HuggingFace generate API.
1187
744
  # The decoder stage operates as usual, processing inputs in batch mode.
745
+ if self.rbln_config.use_lora and lora_int_ids is None:
746
+ if self.lora_int_ids is None:
747
+ raise ValueError(
748
+ "lora_int_id is required when using LoRA. "
749
+ "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
750
+ )
751
+ lora_int_ids = self.lora_int_ids
752
+
753
+ # Handle direct forward() calls, where generate() has not yet initialized generate_idx.
754
+ if generate_idx is None:
755
+ generate_idx = (
756
+ attention_mask.sum(dim=-1, keepdim=True).int()
757
+ if attention_mask is not None
758
+ else torch.full((input_ids.shape[0], 1), input_ids.shape[1], dtype=torch.int32)
759
+ )
760
+ padded_cache_lengths = torch.zeros_like(generate_idx)
1188
761
 
1189
- # Prefll
762
+ # Prefill
1190
763
  if cache_position is None:
1191
764
  logits = []
1192
765
  inputs = inputs_embeds if inputs_embeds is not None else input_ids
1193
766
  batch_size = inputs.shape[0]
767
+ input_len = inputs.shape[1]
768
+ if batch_size > self.rbln_config.batch_size:
769
+ raise ValueError(
770
+ f"Input's batch({batch_size}) exceeds compiled batch_size({self.rbln_config.batch_size})"
771
+ )
772
+ if input_len > self.rbln_config.max_seq_len:
773
+ raise ValueError(
774
+ f"Input's length({input_len}) exceeds compiled max_seq_len({self.rbln_config.max_seq_len})."
775
+ )
776
+
1194
777
  for b_idx in range(batch_size):
1195
778
  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
1196
779
  output = self.prefill_decoder(
@@ -1200,6 +783,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
1200
783
  cache_position=cache_position,
1201
784
  batch_idx=b_idx,
1202
785
  token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
786
+ lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
1203
787
  )
1204
788
  padded_cache_lengths[b_idx] += output.padded_cache_lengths
1205
789
  logits.append(output.logits)
@@ -1214,11 +798,21 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
1214
798
  f"Available batch sizes are: {list(self.decoders.keys())}. "
1215
799
  f"Please run your model with one of these batch sizes or add support for batch size {batch_size}."
1216
800
  )
801
+ if max(cache_position.reshape(-1)) >= self.rbln_config.max_seq_len:
802
+ raise ValueError(
803
+ f"Cache position exceeds the maximum sequence length.\n"
804
+ f" - Current max cache position: {int(torch.max(cache_position).item())}\n"
805
+ f" - Allowed max_seq_len: {self.rbln_config.max_seq_len}\n"
806
+ f"Solution: Reduce the generation length by adjusting `max_new_tokens` "
807
+ f"or `max_length` in the generation config."
808
+ )
809
+
1217
810
  logits = self.decoders[batch_size](
1218
811
  input_ids=input_ids,
1219
812
  inputs_embeds=inputs_embeds,
1220
813
  cache_position=cache_position,
1221
814
  position_ids=position_ids if self.rbln_config.use_position_ids else None,
815
+ lora_int_ids=lora_int_ids,
1222
816
  ).logits
1223
817
 
1224
818
  if not return_dict: