optimum-rbln 0.8.2a4__py3-none-any.whl → 0.9.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (167)
  1. optimum/rbln/__init__.py +96 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +153 -42
  5. optimum/rbln/diffusers/__init__.py +7 -0
  6. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  8. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  10. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +9 -4
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  12. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  13. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  14. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  15. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  20. optimum/rbln/diffusers/modeling_diffusers.py +30 -14
  21. optimum/rbln/diffusers/models/__init__.py +3 -13
  22. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +31 -3
  23. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +28 -3
  24. optimum/rbln/diffusers/models/autoencoders/vq_model.py +31 -3
  25. optimum/rbln/diffusers/models/transformers/prior_transformer.py +1 -1
  26. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +9 -1
  27. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +9 -1
  28. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +6 -3
  29. optimum/rbln/diffusers/pipelines/__init__.py +11 -5
  30. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  31. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +19 -16
  32. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +14 -18
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  34. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  35. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  36. optimum/rbln/modeling.py +71 -19
  37. optimum/rbln/modeling_base.py +99 -21
  38. optimum/rbln/ops/attn.py +158 -0
  39. optimum/rbln/ops/flash_attn.py +166 -0
  40. optimum/rbln/ops/kv_cache_update.py +5 -0
  41. optimum/rbln/ops/linear.py +7 -0
  42. optimum/rbln/transformers/__init__.py +92 -0
  43. optimum/rbln/transformers/configuration_generic.py +9 -7
  44. optimum/rbln/transformers/modeling_attention_utils.py +252 -0
  45. optimum/rbln/transformers/modeling_generic.py +51 -9
  46. optimum/rbln/transformers/modeling_outputs.py +37 -0
  47. optimum/rbln/transformers/models/__init__.py +91 -30
  48. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  49. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  50. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  51. optimum/rbln/transformers/models/bart/bart_architecture.py +1 -3
  52. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  53. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  54. optimum/rbln/transformers/models/bert/modeling_bert.py +8 -4
  55. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +94 -30
  57. optimum/rbln/transformers/models/clip/configuration_clip.py +10 -7
  58. optimum/rbln/transformers/models/clip/modeling_clip.py +27 -4
  59. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  60. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  61. optimum/rbln/transformers/models/colpali/modeling_colpali.py +113 -96
  62. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  63. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  64. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  65. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  66. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  67. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +109 -37
  68. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  69. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +318 -309
  70. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +504 -0
  71. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +111 -0
  72. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  73. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +453 -897
  74. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  75. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  76. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +25 -0
  77. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  78. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  79. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  80. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  81. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  82. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -13
  83. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  84. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  85. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +201 -349
  86. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  87. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  88. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  89. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  90. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  91. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  92. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  93. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1032 -0
  94. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  95. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +26 -27
  96. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  97. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  98. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  99. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  100. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  101. optimum/rbln/transformers/models/llava/modeling_llava.py +478 -0
  102. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +15 -17
  103. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +235 -375
  104. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  105. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  106. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  107. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  108. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  109. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  110. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  111. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  112. optimum/rbln/transformers/models/opt/modeling_opt.py +28 -16
  113. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  114. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  115. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  116. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  117. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  118. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  119. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  120. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  121. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  122. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  123. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  124. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +310 -0
  125. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  126. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  127. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  128. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  129. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  130. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -21
  131. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  132. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  133. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  134. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +514 -0
  135. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  136. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +2 -2
  137. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +86 -330
  138. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +1 -245
  139. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +20 -13
  140. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +24 -3
  141. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  142. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  143. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  144. optimum/rbln/transformers/models/siglip/modeling_siglip.py +5 -16
  145. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  146. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  147. optimum/rbln/transformers/models/swin/modeling_swin.py +341 -0
  148. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  149. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  150. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  151. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -14
  152. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  153. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -0
  154. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  155. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  156. optimum/rbln/transformers/models/whisper/modeling_whisper.py +28 -3
  157. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  158. optimum/rbln/transformers/utils/rbln_quantization.py +391 -75
  159. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  160. optimum/rbln/utils/depreacate_utils.py +16 -0
  161. optimum/rbln/utils/runtime_utils.py +28 -18
  162. optimum/rbln/utils/submodule.py +31 -9
  163. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/METADATA +8 -7
  164. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/RECORD +167 -125
  165. optimum_rbln-0.9.3rc0.dist-info/entry_points.txt +2 -0
  166. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/WHEEL +0 -0
  167. {optimum_rbln-0.8.2a4.dist-info → optimum_rbln-0.9.3rc0.dist-info}/licenses/LICENSE +0 -0
@@ -13,31 +13,31 @@
  # limitations under the License.

  import inspect
- import math
- from collections import deque
- from dataclasses import dataclass
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union

  import rebel
  import torch
  from rebel.compile_context import CompileContext
- from transformers import AutoConfig, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+ from transformers import AutoModel, AutoModelForCausalLM, PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithPast
  from transformers.modeling_utils import no_init_weights
- from transformers.utils import ModelOutput

  from ....configuration_utils import RBLNCompileConfig
  from ....modeling import RBLNModel
  from ....utils.logging import get_logger
- from ....utils.runtime_utils import RBLNPytorchRuntime
- from ...utils.rbln_quantization import prepare_model_for_quantization
- from .configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
- from .decoderonly_architecture import (
-     DecoderOnlyWrapper,
+ from ...modeling_attention_utils import (
+     RBLNDecoderOnlyFlashAttentionMixin,
      set_default_values,
      validate_attention_method,
-     validate_sliding_window_size,
+     validate_sliding_window,
  )
+ from ...modeling_outputs import RBLNDecoderOnlyOutput
+ from ...utils.rbln_quantization import get_quantized_model
+ from .configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+ from .decoderonly_architecture import DecoderOnlyWrapper
+ from .decoderonly_runtime_utils import RBLNPageTableManager, RBLNRuntimeModel
+ from .generation_decoderonly import RBLNDecoderOnlyGenerationMixin


  logger = get_logger()
@@ -46,529 +46,85 @@ if TYPE_CHECKING:
      from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer


- class RBLNRuntimeModel(RBLNPytorchRuntime):
-     mandatory_members = ["main_input_name", "embed_tokens"]
-
-     def __init__(
-         self,
-         runtime: rebel.Runtime,
-         phase: str,
-         batch_size: int,
-         dec_attn_mask: torch.Tensor,
-         block_tables: torch.Tensor,
-         free_block_pool: Deque,
-         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-         **kwargs: Any,
-     ) -> None:
-         super().__init__(runtime, **kwargs)
-         self.phase = phase
-         self.batch_size = batch_size
-         self.rbln_config = rbln_config
-
-         # shared tensor between prefill and decode phase
-         self.dec_attn_mask = dec_attn_mask
-         self.block_tables = block_tables
-         self.free_block_pool = free_block_pool
-
-         self.empty_block = -1
-         if self.phase == "prefill":
-             vocab_size = kwargs.pop("vocab_size")
-             self.output_size = [1, 1, vocab_size]
-             self.causal_mask = 1 - torch.triu(
-                 torch.ones(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.prefill_chunk_size), diagonal=1
-             )
-
-     def get_block_tables(self, cache_position: torch.Tensor, batch_idx: int = None) -> torch.Tensor:
-         """
-         Manages and returns the KV cache block tables.
-         Updates the block tables based on the given cache_position, allocating new blocks or reusing existing ones as needed.
-
-         Args:
-             cache_position (torch.Tensor): Tensor containing cache position information, indicating positions within the cache for each batch item.
-             batch_idx (int, optional): Specific batch index, used when phase is 'prefill'.
-
-         Returns:
-             Updated block tables.
-         """
-
-         NO_BLOCKS_ERROR = (
-             "No memory blocks are available for allocation. "
-             "The generate() API cannot complete this inference task because Paged Attention is not fully supported by optimum-rbln. "
-             "This is supported by vllm-rbln (see: https://docs.rbln.ai/software/model_serving/vllm_support/vllm-rbln.html). "
-             "Using vllm-rbln should fix this issue and enhance inference performance."
-         )
-
-         def update_block(batch_idx: int, block_idx: int):
-             """
-             If the block is empty (empty_block), allocates a block from the free_block_pool.
-             """
-             if self.block_tables[batch_idx][block_idx] == self.empty_block:
-                 if self.free_block_pool:
-                     block = self.free_block_pool.popleft()
-                     self.block_tables[batch_idx][block_idx] = block
-                 else:
-                     raise RuntimeError(NO_BLOCKS_ERROR)
-
-         def replace_empty_block(block_tables: torch.Tensor):
-             """
-             Replaces all occurrences of `self.empty_block` in `block_tables` with a dummy block from `self.free_block_pool`.
-             """
-             if not torch.any(block_tables == self.empty_block):
-                 return block_tables.clone()
-             elif self.free_block_pool:
-                 _free_block = self.free_block_pool[0]
-                 return torch.where(block_tables == self.empty_block, _free_block, block_tables)
-             else:
-                 raise RuntimeError(NO_BLOCKS_ERROR)
-
-         def get_global_block_tables(batch_idx: int):
-             if self.rbln_config.cache_impl == "sliding_window":
-                 return None
-
-             if self.phase == "prefill":
-                 # Track previously used blocks and return them to the free_block_pool and
-                 # reset the current batch's block table to empty blocks
-                 prev_blocks = self.block_tables[batch_idx][self.block_tables[batch_idx] != self.empty_block].tolist()
-                 self.free_block_pool.extend(prev_blocks)
-                 self.block_tables[batch_idx].fill_(self.empty_block)
-
-                 # Get the start (s) and end (e) positions from cache_position and
-                 # iterate over the cache positions to allocate necessary blocks
-                 s, e = cache_position[0][0].item(), cache_position[0][-1].item()
-                 for position in range(s, e + 1, self.rbln_config.kvcache_block_size):
-                     block_idx = position // self.rbln_config.kvcache_block_size
-                     if batch_idx >= len(self.block_tables) or block_idx >= len(self.block_tables[batch_idx]):
-                         raise IndexError(f"Invalid index: batch_idx={batch_idx}, block_idx={block_idx}")
-                     update_block(batch_idx, block_idx)
-
-                 return replace_empty_block(self.block_tables[batch_idx])
-             # Case for 'decoder' phase, iterate over the cache positions to allocate necessary blocks
-             else:
-                 for b_idx in range(self.batch_size):
-                     position = cache_position[b_idx][0].item()
-                     block_idx = position // self.rbln_config.kvcache_block_size
-                     update_block(b_idx, block_idx)
-
-                 return replace_empty_block(self.block_tables)
-
-         def get_local_block_tables(batch_idx: int):
-             if self.rbln_config.cache_impl == "static":
-                 return None
-             else:
-                 return (
-                     torch.tensor([batch_idx], dtype=torch.int16)
-                     if self.phase == "prefill"
-                     else torch.arange(self.batch_size, dtype=torch.int16).view(self.batch_size, -1)
-                 )
-
-         return get_global_block_tables(batch_idx), get_local_block_tables(batch_idx)
-
-     def is_external_block_tables(
-         self, block_tables: Optional[torch.Tensor], local_block_tables: Optional[torch.Tensor]
-     ):
-         if self.rbln_config.cache_impl == "static" and block_tables is None:
-             return False
-         elif self.rbln_config.cache_impl == "sliding_window" and local_block_tables is None:
-             return False
-         elif self.rbln_config.cache_impl == "hybrid":
-             if (block_tables is not None) != (local_block_tables is not None):
-                 raise ValueError(
-                     "Both block_tables and local_block_tables must be provided or neither of them must be provided."
-                 )
-             elif block_tables is None and local_block_tables is None:
-                 return False
-
-         return True
-
-     def forward(
-         self,
-         input_ids: Optional[torch.LongTensor] = None,
-         inputs_embeds: Optional[torch.Tensor] = None,
-         cache_position: torch.Tensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         batch_idx: Optional[int] = None,
-         block_tables: Optional[torch.Tensor] = None,
-         position_embed: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.Tensor] = None,
-         local_block_tables: Optional[torch.Tensor] = None,
-     ):
-         if input_ids is None and inputs_embeds is None:
-             raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
-
-         if inputs_embeds is None:
-             inputs = input_ids
-             if self.embed_tokens is not None:
-                 inputs = self.embed_tokens(inputs)
-         else:
-             inputs = inputs_embeds
-
-         is_external_block_tables = self.is_external_block_tables(block_tables, local_block_tables)
-         if not is_external_block_tables:
-             block_tables, local_block_tables = self.get_block_tables(cache_position, batch_idx=batch_idx)
-
-         if self.phase == "decode":
-             return self.decode_forward(
-                 inputs,
-                 cache_position,
-                 block_tables,
-                 is_external_block_tables,
-                 attention_mask=attention_mask,
-                 position_embed=position_embed,
-                 position_ids=position_ids,
-                 local_block_tables=local_block_tables,
-             )
-         else:
-             return self.prefill_forward(
-                 inputs,
-                 cache_position,
-                 attention_mask,
-                 batch_idx,
-                 block_tables,
-                 is_external_block_tables=is_external_block_tables,
-                 position_embed=position_embed,
-                 token_type_ids=token_type_ids,
-                 local_block_tables=local_block_tables,
-             )
-
-     def decode_forward(
-         self,
-         inputs: torch.Tensor,
-         cache_position: torch.Tensor = None,
-         block_tables: torch.Tensor = None,
-         is_external_block_tables: bool = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_embed: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.Tensor] = None,
-         local_block_tables: Optional[torch.Tensor] = None,
-     ) -> torch.FloatTensor:
-         batch_size = inputs.shape[0]
-         if batch_size != self.batch_size:
-             raise RuntimeError(
-                 f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
-             )
-
-         if batch_size != cache_position.shape[0]:
-             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
-
-         if self.rbln_config.use_attention_mask and attention_mask is None:
-             for b_idx in range(batch_size):
-                 decoding_step = cache_position[b_idx].item()
-                 if not (0 <= decoding_step < self.dec_attn_mask.shape[-1]):
-                     raise ValueError(
-                         f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
-                     )
-
-                 if is_external_block_tables:
-                     self.dec_attn_mask[b_idx].fill_(0)
-                     self.dec_attn_mask[b_idx, :, :, : decoding_step + 1] = 1
-                 else:
-                     self.dec_attn_mask[b_idx, :, :, decoding_step] = 1
-
-             attention_mask = self.dec_attn_mask
-
-         if self.rbln_config.cache_impl in ["hybrid", "static"] and self.batch_size < block_tables.shape[0]:
-             block_tables = block_tables[: self.batch_size]
-
-         if attention_mask is not None and self.batch_size < attention_mask.shape[0]:
-             attention_mask = attention_mask[: self.batch_size]
-
-         logits = super().forward(
-             inputs,
-             cache_position,
-             block_tables,
-             local_block_tables,
-             position_embed,
-             attention_mask if self.rbln_config.use_attention_mask else None,
-             position_ids if self.rbln_config.use_position_ids else None,
-         )
-
-         return RBLNDecoderOnlyOutput(logits=logits)
-
-     def _prepare_prefill_inputs(
-         self,
-         inputs: torch.Tensor,
-         cache_position: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_embed: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.Tensor] = None,
-     ):
-         """
-         Prepare inputs for prefill phase.
-         """
-         # Handle continuous batching in a compiled graph by extracting valid inputs
-         # If an attention mask is provided, select only the valid (non-masked) inputs
-         inputs = inputs[:, attention_mask.bool()] if attention_mask is not None else inputs
-         if position_embed is not None:
-             position_embed = (
-                 position_embed[:, :, :, attention_mask.bool(), :] if attention_mask is not None else position_embed
-             )
-         if token_type_ids is not None:
-             token_type_ids = token_type_ids[:, attention_mask.bool()] if attention_mask is not None else token_type_ids
-
-         query_length = inputs.shape[1]
-         if query_length > self.rbln_config.max_seq_len:
-             raise ValueError(
-                 f"Input length ({query_length}) exceeds the maximum allowed sequence length ({self.rbln_config.max_seq_len})."
-             )
-
-         # Initialize attention mask for chunked processing
-         chunked_attention_mask = (
-             torch.zeros(1, 1, self.rbln_config.prefill_chunk_size, self.rbln_config.max_seq_len, dtype=torch.float32)
-             if self.rbln_config.use_attention_mask
-             else None
-         )
-
-         # Buffer for storing output logits
-         out_buffers = [
-             torch.empty(
-                 size=self.output_size,
-                 dtype=torch.float32,
-                 device="cpu",
-             )
-         ]
-
-         # Pad input and cache_position if the last chunk is smaller than `prefill_chunk_size`
-         padding_size = 0
-         if query_length % self.rbln_config.prefill_chunk_size != 0:
-             padding_size = (self.rbln_config.prefill_chunk_size - query_length) % self.rbln_config.prefill_chunk_size
-             # inputs_embeds
-             if inputs.dim() == 3:
-                 inputs = torch.nn.functional.pad(inputs, (0, 0, 0, padding_size))
-             # inputs_ids
-             else:
-                 inputs = torch.nn.functional.pad(inputs, (0, padding_size))
-
-             cache_position = torch.cat(
-                 [
-                     cache_position,
-                     torch.arange(
-                         query_length,
-                         query_length + padding_size,
-                         dtype=torch.int32,
-                     ).unsqueeze(0),
-                 ],
-                 dim=-1,
-             )
-
-             if position_embed is not None:
-                 position_embed = torch.nn.functional.pad(position_embed, (0, 0, 0, padding_size))
-
-             if token_type_ids is not None:
-                 token_type_ids = torch.nn.functional.pad(token_type_ids, (0, padding_size), value=-1)
-
-         # Overwrite position_ids and padded_cache_lengths
-         position_ids = cache_position.clone()
-         padded_cache_lengths = 0
-
-         return (
-             inputs,
-             cache_position,
-             chunked_attention_mask,
-             out_buffers,
-             position_ids,
-             position_embed,
-             padded_cache_lengths,
-             query_length,
-             token_type_ids,
-         )
-
-     def prefill_forward(
-         self,
-         inputs: torch.Tensor,
-         cache_position: torch.Tensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         batch_idx: int = None,
-         block_tables: torch.Tensor = None,
-         is_external_block_tables: bool = False,
-         position_embed: Optional[torch.Tensor] = None,
-         token_type_ids: Optional[torch.Tensor] = None,
-         local_block_tables: Optional[torch.Tensor] = None,
-     ) -> torch.FloatTensor:
-         """
-         Performs chunked prefill for efficient KV-cache updates and memory optimization.
-         Instead of processing the entire sequence at once, the input is divided into chunks of size `prefill_chunk_size`,
-         and each chunk is processed sequentially. This allows for better memory utilization and compatibility with continuous batching.
-         """
-         (
-             inputs,
-             cache_position,
-             chunked_attention_mask,
-             out_buffers,
-             position_ids,
-             position_embed,
-             padded_cache_lengths,
-             query_length,
-             token_type_ids,
-         ) = self._prepare_prefill_inputs(
-             inputs, cache_position, attention_mask, position_embed, token_type_ids=token_type_ids
-         )
-
-         # Process input in chunks of size `prefill_chunk_size`
-         for step in range(0, query_length, self.rbln_config.prefill_chunk_size):
-             # Extract the current chunk of inputs and cache positions
-             input_chunk = inputs[:, step : step + self.rbln_config.prefill_chunk_size]
-             cache_pos_chunk = cache_position[:, step : step + self.rbln_config.prefill_chunk_size]
-             position_ids_chunk = (
-                 position_ids[:, step : step + self.rbln_config.prefill_chunk_size]
-                 if position_ids is not None
-                 else None
-             )
-             if position_embed is not None:
-                 position_embed_chunk = position_embed[:, :, :, step : step + self.rbln_config.prefill_chunk_size, :]
-
-             if self.rbln_config.use_attention_mask and not self.rbln_config.use_position_ids:
-                 # Update attention mask to ensure proper causal behavior
-                 if step >= self.rbln_config.prefill_chunk_size:
-                     chunked_attention_mask[:, :, :, step - self.rbln_config.prefill_chunk_size : step] = 1
-                 chunked_attention_mask[:, :, :, step : step + self.rbln_config.prefill_chunk_size] = self.causal_mask
-
-             # Define query position
-             if step + self.rbln_config.prefill_chunk_size >= query_length:
-                 query_position = torch.tensor(
-                     (query_length - 1) % self.rbln_config.prefill_chunk_size, dtype=torch.int16
-                 )
-             else:
-                 query_position = torch.tensor(self.rbln_config.prefill_chunk_size - 1, dtype=torch.int16)
-
-             # Forward pass for the current chunk
-             logits = super().forward(
-                 input_chunk,
-                 cache_pos_chunk,
-                 block_tables,
-                 local_block_tables,
-                 position_embed_chunk if position_embed is not None else None,
-                 query_position,
-                 chunked_attention_mask if self.rbln_config.use_attention_mask else None,
-                 position_ids_chunk if self.rbln_config.use_position_ids else None,
-                 out=out_buffers,
-             )
-
-         # Update decoder attention mask with processed KV-cache length from prefill phase
-         if not is_external_block_tables and self.rbln_config.use_attention_mask:
-             self.dec_attn_mask[batch_idx].fill_(0)
-             self.dec_attn_mask[batch_idx, :, :, :query_length] = 1
-
-         return RBLNDecoderOnlyOutput(logits=logits, padded_cache_lengths=padded_cache_lengths)
-
-
- @dataclass
- class RBLNDecoderOnlyOutput(ModelOutput):
-     logits: torch.FloatTensor = None
-     generate_idx: torch.Tensor = None
-     padded_cache_lengths: int = None
-
-
- class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
+ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
      """
-     A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
+     A base class for decoder-only transformer models outputting raw hidden-states without any specific head on top.
+     This class is used for RBLN-optimized models that are not causal language models.
      This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.

      The class provides core functionality for:

      1. Converting pre-trained transformer models to RBLN-optimized format
      2. Handling the compilation process for RBLN devices
-     3. Managing inference operations for causal language modeling
-
+     3. Managing inference operations for decoder-only architectures
      This class inherits from RBLNModel and implements specific methods required for
-     decoder-only architectures and causal language modeling tasks.
+     decoder-only architectures.

      Note:
          - This class is designed to be subclassed by specific model implementations
-           (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
+           (e.g., RBLNLlamaModel, RBLNQwen2Model)
          - Subclasses should implement model-specific conversion logic.
          - The class handles RBLN-specific optimizations automatically during compilation
      """

+     _tp_support = True
+
      main_input_name = "input_ids"
-     auto_model_class = AutoModelForCausalLM
+     auto_model_class = AutoModel
      _decoder_wrapper_cls = DecoderOnlyWrapper
      _use_rotary_emb = True
+     _supports_non_fp32 = True

      def __post_init__(self, **kwargs):
-         main_input_name = self.main_input_name
-
          if self.rbln_config.use_inputs_embeds:
-             main_input_name = "inputs_embeds"
              artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
              self.embed_tokens = self._create_embedding_layer()
              self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
          else:
              self.embed_tokens = None

-         # Initialize shared resources to be used across Runtime instances (prefill and decode phases)
-         dec_attn_mask = torch.zeros(
-             self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=torch.float32
-         )
-         block_tables = torch.zeros(
-             self.rbln_config.batch_size,
-             self.rbln_config.max_seq_len // self.rbln_config.kvcache_block_size,
-             dtype=torch.int16,
-         ).fill_(-1)
-         free_block_pool = deque(x for x in range(self.rbln_config.kvcache_num_blocks))
+         self.setup_runtime()

+     def setup_runtime(self):
+         # Initialize resources to be used across Runtime instances (prefill and decode phases)
+         page_table_manager = RBLNPageTableManager(self.rbln_config)
+         dec_attn_mask = torch.zeros(self.rbln_config.batch_size, 1, 1, self.rbln_config.max_seq_len, dtype=self.dtype)
+         out_buffers = [torch.empty(self.prefill_output_size, dtype=self.dtype)]
+
+         common_kwargs = {
+             "main_input_name": "inputs_embeds" if self.rbln_config.use_inputs_embeds else "input_ids",
+             "embed_tokens": self.embed_tokens,
+             "dec_attn_mask": dec_attn_mask,
+             "page_table_manager": page_table_manager,
+             "rbln_config": self.rbln_config,
+         }
          self.prefill_decoder = RBLNRuntimeModel(
              runtime=self.model[0],
-             main_input_name=main_input_name,
-             embed_tokens=self.embed_tokens,
              phase="prefill",
              batch_size=self.rbln_config.batch_size,
-             dec_attn_mask=dec_attn_mask,
-             block_tables=block_tables,
-             free_block_pool=free_block_pool,
-             rbln_config=self.rbln_config,
-             vocab_size=self.config.vocab_size,
+             out_buffers=out_buffers,
+             **common_kwargs,
          )
+         if self.can_generate():
+             self.decoders = {}
+             for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
+                 self.decoders[batch_size] = RBLNRuntimeModel(
+                     runtime=self.model[i + 1],
+                     phase="decode",
+                     batch_size=batch_size,
+                     **common_kwargs,
+                 )

-         self.decoders = {}
-         for i, batch_size in enumerate(self.rbln_config.decoder_batch_sizes):
-             self.decoders[batch_size] = RBLNRuntimeModel(
-                 runtime=self.model[i + 1],
-                 main_input_name=main_input_name,
-                 embed_tokens=self.embed_tokens,
-                 phase="decode",
-                 batch_size=batch_size,
-                 dec_attn_mask=dec_attn_mask,
-                 block_tables=block_tables,
-                 free_block_pool=free_block_pool,
-                 rbln_config=self.rbln_config,
-             )
-
-         # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
-         self.decoder = self.decoders[self.rbln_config.batch_size]
-
-     @classmethod
-     def save_torch_artifacts(
-         cls,
-         model: PreTrainedModel,
-         save_dir_path: Path,
-         subfolder: str,
-         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-     ):
-         # If you are unavoidably running on a CPU rather than an RBLN device,
-         # store the torch tensor, weight, etc. in this function.
-         if rbln_config.use_inputs_embeds:
-             save_dict = {}
-             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
-             torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
-
-     def _create_embedding_layer(self):
-         with no_init_weights():
-             embed_tokens = torch.nn.Embedding(
-                 self.config.vocab_size,
-                 self.config.hidden_size,
-                 self.config.pad_token_id,
-             )
-         return embed_tokens
-
-     def get_input_embeddings(self):
-         return self.embed_tokens
-
-     def get_attn_impl(self) -> str:
-         return self.rbln_config.attn_impl
+             # NOTE(eunji): Use a decoder whose batch size matches the model's main batch size for compatibility.
+             self.decoder = self.decoders[self.rbln_config.batch_size]

-     def get_kvcache_num_blocks(self) -> int:
-         return self.rbln_config.kvcache_num_blocks
+     @property
+     def prefill_output_size(self):
+         return (
+             1,
+             self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
+             self.config.hidden_size,
+         )

      @classmethod
      def get_quantized_model(
@@ -582,35 +138,22 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          subfolder: str = "",
          local_files_only: bool = False,
          trust_remote_code: bool = False,
+         rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None,
          **kwargs,
      ):
          kwargs = cls.update_kwargs(kwargs)

-         if config is None:
-             config = AutoConfig.from_pretrained(
-                 model_id,
-                 use_auth_token=use_auth_token,
-                 revision=revision,
-                 force_download=force_download,
-                 cache_dir=cache_dir,
-                 trust_remote_code=trust_remote_code,
-                 **kwargs,
-             )
-
-         with no_init_weights():
-             model = AutoModelForCausalLM.from_config(config)
-
-         model = prepare_model_for_quantization(
-             model,
+         return get_quantized_model(
+             cls.auto_model_class,
              model_id,
-             kwargs.get("num_hidden_layers"),
              use_auth_token=use_auth_token,
              revision=revision,
              cache_dir=cache_dir,
              force_download=force_download,
              local_files_only=local_files_only,
+             rbln_quantization=rbln_config.quantization,
+             **kwargs,
          )
-         return model

      def __getattr__(self, __name: str) -> Any:
          # Special method to delegate attribute access to the original Huggingface LM class.
@@ -632,233 +175,162 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          return val

      @classmethod
-     def get_pytorch_model(
-         cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None, **kwargs
-     ) -> PreTrainedModel:
-         if rbln_config and rbln_config.quantization:
-             model = cls.get_quantized_model(*args, **kwargs)
-         else:
-             model = super().get_pytorch_model(*args, **kwargs)
+     def save_torch_artifacts(
+         cls,
+         model: PreTrainedModel,
+         save_dir_path: Path,
+         subfolder: str,
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+     ):
+         # If you are unavoidably running on a CPU rather than an RBLN device,
+         # store the torch tensor, weight, etc. in this function.
+         if rbln_config.use_inputs_embeds:
+             save_dict = {}
+             save_dict["embed_tokens"] = model.get_input_embeddings().state_dict()
+             torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

-         return model
+     def _create_embedding_layer(self):
+         with no_init_weights():
+             embed_tokens = torch.nn.Embedding(
+                 self.config.vocab_size,
+                 self.config.hidden_size,
+                 self.config.pad_token_id,
+             )
+         return embed_tokens
+
+     def get_decoder(self):
+         if not self.can_generate():
+             raise ValueError("Decode stage is not supported in this model.")
+         return self.decoder
+
+     def can_generate(self):
+         return self.rbln_config.can_generate
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def get_attn_impl(self) -> str:
+         return self.rbln_config.attn_impl
+
+     def get_kvcache_num_blocks(self) -> int:
+         return self.rbln_config.kvcache_num_blocks

      @classmethod
-     def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelForCausalLMConfig"):
-         wrapper_cfg = {
-             "max_seq_len": rbln_config.max_seq_len,
-             "attn_impl": rbln_config.attn_impl,
-             "kvcache_partition_len": rbln_config.kvcache_partition_len,
-             "kvcache_block_size": rbln_config.kvcache_block_size,
-             "use_rotary_emb": cls._use_rotary_emb,
-             "use_attention_mask": rbln_config.use_attention_mask,
-             "use_position_ids": rbln_config.use_position_ids,
-             "use_inputs_embeds": rbln_config.use_inputs_embeds,
-             "cache_impl": rbln_config.cache_impl,
-             "sliding_window": rbln_config.sliding_window,
-             "sliding_window_layers": rbln_config.sliding_window_layers,
-         }
-         return cls._decoder_wrapper_cls(model, **wrapper_cfg).eval()
+     def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig"):
+         return cls._decoder_wrapper_cls(model, rbln_config, cls._use_rotary_emb).eval()
+
+     @classmethod
+     def _compile_model(
+         cls,
+         wrapped_model,
+         compile_config,
+         example_inputs,
+         compile_context,
+         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
+         quantization=None,
+         phase: str = "prefill",
+     ):
+         try:
+             wrapped_model.phase = phase
+             if quantization:
+                 quantization.maybe_set_quantization_env()
+             original_linear = torch.nn.functional.linear
+             torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
+             compiled_model = cls.compile(
+                 wrapped_model,
+                 compile_config,
+                 create_runtimes=rbln_config.create_runtimes,
+                 device=rbln_config.device,
+                 example_inputs=example_inputs,
+                 compile_context=compile_context,
+             )
+             return compiled_model
+         finally:
+             torch.nn.functional.linear = original_linear
+             if quantization:
+                 quantization.maybe_reset_quantization_env()
+
+     @classmethod
+     def _get_compile_context(
+         cls,
+         compile_config: RBLNCompileConfig,
+         example_inputs: List[torch.Tensor],
+     ):
+         context = CompileContext(use_weight_sharing=True)
+
+         # Mark static tensors (self kv states)
+         static_tensors = {}
+         idx = 0
+         for (name, _, _), tensor in zip(compile_config.input_info, example_inputs):
+             if "past_key_values" in name:
+                 static_tensors[name] = tensor
+                 context.mark_static_address(tensor, f"kv_cache_{idx}")
+                 idx += 1
+
+         return context, static_tensors

      @classmethod
      @torch.inference_mode()
      def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
          wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
-
-         rbln_compile_configs = rbln_config.compile_cfgs
-         prefill_compile_config = rbln_compile_configs[0]
-
-         context = CompileContext(use_weight_sharing=True)
+         prefill_compile_config = rbln_config.compile_cfgs[0]

          # Here we use meta tensor, for the memory efficiency.
          meta_tensor_names = [name for name, _, _ in prefill_compile_config.input_info if "past_key_values" in name]
          prefill_example_inputs = prefill_compile_config.get_dummy_inputs(fill=0, meta_tensor_names=meta_tensor_names)
+         context, static_tensors = cls._get_compile_context(prefill_compile_config, prefill_example_inputs)
+
+         compiled_models = {}
+         compiled_models["prefill"] = cls._compile_model(
+             wrapped_model,
+             prefill_compile_config,
+             prefill_example_inputs,
+             context,
+             rbln_config,
+             rbln_config.quantization,
+             phase="prefill",
+         )

-         # Mark static tensors (self kv states)
-         static_tensors = {}
-         for (name, _, _), tensor in zip(prefill_compile_config.input_info, prefill_example_inputs):
-             if "past_key_values" in name:
-                 static_tensors[name] = tensor
-                 context.mark_static_address(tensor)
-
-         def compile_model(wrapped_model, compile_config, example_inputs, compile_context, quantization):
-             try:
-                 if quantization:
-                     quantization.maybe_set_quantization_env()
-                 original_linear = torch.nn.functional.linear
-                 torch.nn.functional.linear = torch.ops.rbln_custom_ops.linear
-                 compiled_model = cls.compile(
+         if rbln_config.can_generate:
+             wrapped_model.phase = "decode"
+             for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_config.compile_cfgs[1:]):
+                 dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+                 compiled_decoder = cls._compile_model(
                      wrapped_model,
-                     compile_config,
-                     create_runtimes=rbln_config.create_runtimes,
-                     device=rbln_config.device,
-                     example_inputs=example_inputs,
-                     compile_context=compile_context,
+                     dec_compile_config,
+                     dec_example_inputs,
+                     context,
+                     rbln_config,
+                     rbln_config.quantization,
+                     phase="decode",
+                 )
+                 compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
+
+             # check if the memory is enough to have additional blocks
+             required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
+             if rbln_config.kvcache_num_blocks < required_num_blocks:
+                 cls.maybe_suggest_kvcache_num_blocks(
+                     compiled_models=compiled_models,
+                     model_config=model.config,
+                     rbln_config=rbln_config,
                  )
-                 return compiled_model
-             finally:
-                 torch.nn.functional.linear = original_linear
-                 if quantization:
-                     quantization.maybe_reset_quantization_env()
-
-         wrapped_model.phase = "prefill"
-         compiled_prefill = compile_model(
-             wrapped_model, prefill_compile_config, prefill_example_inputs, context, rbln_config.quantization
-         )
-
-         wrapped_model.phase = "decode"
-         compiled_models = {"prefill": compiled_prefill}
-         for batch_size, dec_compile_config in zip(rbln_config.decoder_batch_sizes, rbln_compile_configs[1:]):
-             dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
-             compiled_decoder = compile_model(
-                 wrapped_model, dec_compile_config, dec_example_inputs, context, rbln_config.quantization
-             )
-             compiled_models[f"decoder_batch_{batch_size}"] = compiled_decoder
-
-         # check if the memory is enough to have additional blocks
-         required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
-         if rbln_config.kvcache_num_blocks < required_num_blocks:
-             cls.maybe_suggest_kvcache_num_blocks(
-                 compiled_models=compiled_models,
-                 model_config=model.config,
-                 rbln_config=rbln_config,
-             )

          return compiled_models

      @classmethod
-     def maybe_suggest_kvcache_num_blocks(
-         cls,
-         compiled_models: Dict[str, rebel.RBLNCompiledModel],
-         model_config: PretrainedConfig,
-         rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
-     ) -> None:
-         # Get the actual memory allocation of each node by key
-         alloc_memory_per_node_by_key: Dict[str, List[int]] = compiled_models["prefill"].get_alloc_per_node_by_key()
-         alloc_memory_by_key: Dict[str, int] = {
-             key: sum(memory_per_node) for key, memory_per_node in alloc_memory_per_node_by_key.items()
-         }
-         for batch_size in rbln_config.decoder_batch_sizes:
-             for key, memory_per_node in (
-                 compiled_models[f"decoder_batch_{batch_size}"].get_alloc_per_node_by_key().items()
-             ):
-                 alloc_memory_by_key[key] += sum(memory_per_node)
-         alloc_memory_by_key.pop("PortRecur", None)  # Old compiler's kv-cache Key
-         alloc_memory_by_key.pop("DramTensor", None)  # kv-cache
-         kernel_size = alloc_memory_by_key.pop("Kernel")  # model weight
-
-         # Get the maximum number of blocks that can be allocated
-         buffer = sum(alloc_memory_by_key.values())
-         max_num_blocks = cls.get_maximum_num_blocks(
-             config=model_config,
-             tensor_parallel_size=rbln_config.tensor_parallel_size,
-             kvcache_block_size=rbln_config.kvcache_block_size,
-             kernel_size=kernel_size,
-             buffer=buffer,
-         )
+     def get_pytorch_model(
+         cls, *args, rbln_config: Optional[RBLNDecoderOnlyModelConfig] = None, **kwargs
+     ) -> PreTrainedModel:
+         if rbln_config and rbln_config.quantization:
+             model = cls.get_quantized_model(*args, rbln_config=rbln_config, **kwargs)
+         else:
+             model = super().get_pytorch_model(*args, **kwargs)

-         # Since our estimation logic is not always accurate,
-         # users can set `kvcache_num_blocks` to `max_num_blocks`.
-         # If the memory is not enough, the model will fail to compile.
-         if rbln_config.kvcache_num_blocks < max_num_blocks:
-             logger.warning(
-                 f"Current `kvcache_num_blocks` setting is {rbln_config.kvcache_num_blocks}. "
-                 "Our analysis indicates that additional memory is available for more blocks. "
-                 f"Consider increasing `kvcache_num_blocks` to {max_num_blocks} for potentially improved performance. "
-                 "Please be advised that our memory estimation algorithm has limitations, "
-                 "and increasing this value may not guarantee successful model compilation."
-             )
+         return model

      @classmethod
-     def get_maximum_num_blocks(
-         cls,
-         config: PretrainedConfig,
-         tensor_parallel_size: int,
-         kvcache_block_size: int,
-         nbits_per_param: Optional[int] = None,
-         n_model_params: Optional[int] = None,
-         kernel_size: Optional[int] = None,
-         buffer: Optional[int] = None,
-         num_runtimes: int = 2,
-     ) -> int:
-         # We are finding max_n_blocks(x) that satisfies the following equation:
-
-         # available_dram - kernel_size - buffer
-         #   - num_layers * 2 * tensor_parallel_size
-         #   * align_2MB(
-         #       x
-         #       * block_size
-         #       * align_64(head_dim)
-         #       * math.ceil(num_key_value_heads / tensor_parallel_size)
-         #       * 2
-         #   ) > 0
-
-         # This inequality can be rewritten as follows:
-
-         # a - c * align_2MB(b * x) > 0
-         # where
-         #   a = available_dram - kernel_size - buffer
-         #   b = block_size * align_64(head_dim) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
-         #   c = num_layers * 2 * tensor_parallel_size
-
-         # We can rewrite the inequality as follows:
-         # k > align_2MB(b*x)
-         # where
-         #   k = a / c
-
-         # After that, we can derive the following equation:
-         # x = floor(2**21 / b * floor((k - 1) / 2**21))
-
-         def align(x: int, nbytes: int) -> int:
-             return int(math.ceil(x / nbytes) * nbytes)
-
-         def align_2MB(x: int) -> int:
-             return align(x, 2**21)
-
-         num_attention_heads = getattr(config, "n_head", None) or getattr(config, "num_attention_heads")
-         num_layers = getattr(config, "n_layer", None) or getattr(config, "num_hidden_layers")
-         head_dim = getattr(config, "head_dim", None) or config.hidden_size // num_attention_heads
-         vocab_size = config.vocab_size
-         hidden_size = getattr(config, "n_embd", None) or getattr(config, "hidden_size")
-         num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
-
-         # TODO(jongho): Update if target npu is REBEL.
-         ATOM_DRAM_NBYTES = 16 * 2**30
-         ATOM_SYS_DRAM_NBYTES = 288 * 2**20
-         available_dram = tensor_parallel_size * (ATOM_DRAM_NBYTES - ATOM_SYS_DRAM_NBYTES)
-
-         if kernel_size is None:
-             if n_model_params is None:
-                 raise ValueError("`n_model_params` should be specified to estimate the kernel memory.")
-             # Get estimated kernel size (approximated)
-             lm_heads_params = align(vocab_size, 64) * hidden_size
-             lm_heads_nbytes = (
-                 align_2MB(lm_heads_params * nbits_per_param // 8 / tensor_parallel_size) * tensor_parallel_size
-             )
-             params = n_model_params - lm_heads_params
-             layer_nbytes = (
-                 align_2MB(params * nbits_per_param // 8 / num_layers / tensor_parallel_size)
-                 * num_layers
-                 * tensor_parallel_size
-             )
-             kernel_size = layer_nbytes + lm_heads_nbytes
-         elif n_model_params is not None:
-             raise ValueError("Both `n_model_params` and `kernel_size` cannot be specified.")
-
-         available_dram -= kernel_size
-
-         if buffer is None:
-             # TODO: Accurate buffer estimation
-             buffer_per_runtime_per_core = 2**28  # 256MB per runtime
-             buffer_per_core = buffer_per_runtime_per_core * num_runtimes  # 1 for prefill, 1 for decoder
-             buffer = buffer_per_core * tensor_parallel_size
-         available_dram -= buffer
-
-         b = kvcache_block_size * align(head_dim, 64) * math.ceil(num_key_value_heads / tensor_parallel_size) * 2
-         c = num_layers * 2 * tensor_parallel_size
-         k = available_dram / c
-         max_n_blocks = math.floor(2**21 / b * math.floor((k - 1) / 2**21))
-
-         return max_n_blocks
+     def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
+         return use_local_attention

      @classmethod
      def get_input_info(
@@ -868,63 +340,57 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
          rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
          model_config: PretrainedConfig,
      ):
-         is_prefill: bool = query_length > 1
          num_attention_heads = getattr(model_config, "n_head", None) or getattr(model_config, "num_attention_heads")
          num_key_value_heads = getattr(model_config, "num_key_value_heads", None) or num_attention_heads
          num_hidden_layers = getattr(model_config, "n_layer", None) or getattr(model_config, "num_hidden_layers")
          hidden_size = getattr(model_config, "n_embd", None) or getattr(model_config, "hidden_size")
          head_dim = getattr(model_config, "head_dim", None) or hidden_size // num_attention_heads
-         local_kvcache_num_blocks = max(rbln_config.decoder_batch_sizes)
+         is_prefill = query_length > 1

-         # 1. main input
+         input_info = []
          if rbln_config.use_inputs_embeds:
-             main_input = ("inputs_embeds", [batch_size, query_length, hidden_size], "float32")
+             input_info.append(("inputs_embeds", [batch_size, query_length, hidden_size], rbln_config.torch_dtype))
          else:
-             main_input = ("input_ids", [batch_size, query_length], "int64")
-
-         # 2. cache_position
-         input_info = [
-             main_input,
-             (
-                 "cache_position",
-                 [batch_size, query_length],
-                 "int32",
-             ),
-         ]
+             input_info.append(("input_ids", [batch_size, query_length], "int64"))

-         # 3. block_tables
-         if rbln_config.cache_impl in ["static", "hybrid"]:
+         input_info.append(("cache_position", [batch_size, query_length], "int32"))
+
+         if rbln_config.use_global_attention:
              max_block_cnt = rbln_config.max_seq_len // rbln_config.kvcache_block_size
-             input_info.extend(
-                 [("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")]
+             input_info.append(
+                 ("block_tables", [max_block_cnt] if is_prefill else [batch_size, max_block_cnt], "int16")
              )
-         if rbln_config.cache_impl in ["hybrid", "sliding_window"]:
-             input_info.extend([("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16")])
+         if rbln_config.use_local_attention:
+             input_info.append(("local_block_tables", [1] if is_prefill else [batch_size, 1], "int16"))

-         # 4. query_position
-         if is_prefill:
-             input_info.extend([("query_position", [], "int16")])
+         if cls.use_query_position(rbln_config.use_local_attention, is_prefill):
+             input_info.append(("query_position", [], "int16"))

-         # 5. attention_mask & position_ids
          if rbln_config.use_attention_mask:
-             input_info.extend(
-                 [
-                     ("attention_mask", [batch_size, rbln_config.max_seq_len], "float32")
-                     if rbln_config.use_position_ids
-                     else ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], "float32")
-                 ]
-             )
+             if rbln_config.use_position_ids:
+                 input_info.append(("attention_mask", [batch_size, rbln_config.max_seq_len], rbln_config.torch_dtype))
+             else:
+                 input_info.append(
+                     ("attention_mask", [batch_size, 1, query_length, rbln_config.max_seq_len], rbln_config.torch_dtype)
+                 )
+
          if rbln_config.use_position_ids:
              input_info.append(("position_ids", [batch_size, query_length], "int32"))

-         # 6. past_key_values
+         if rbln_config.use_lora:
+             input_info.append(("lora_int_ids", [batch_size], "int32"))
+
+         kvcache_dtype = rbln_config.torch_dtype
+         if rbln_config.quantization and rbln_config.quantization.kv_caches == "fp8":
+             kvcache_dtype = "float8_e4m3fn"
+
          global_kvcache_shape = [
              rbln_config.kvcache_num_blocks,
              num_key_value_heads,
              rbln_config.kvcache_block_size,
              head_dim,
          ]
-         local_kvcache_shape = [local_kvcache_num_blocks, num_key_value_heads, rbln_config.sliding_window, head_dim]
+         local_kvcache_shape = [rbln_config.batch_size, num_key_value_heads, rbln_config.sliding_window, head_dim]
          input_info.extend(
              [
                  (
@@ -932,7 +398,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
932
398
  local_kvcache_shape
933
399
  if rbln_config.sliding_window is not None and ((i // 2) in rbln_config.sliding_window_layers)
934
400
  else global_kvcache_shape,
935
- "float32",
401
+ kvcache_dtype,
936
402
  )
937
403
  for i in range(num_hidden_layers * 2)
938
404
  ]
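To make the shape bookkeeping above concrete, here is a sketch of the kind of `input_info` list a prefill graph ends up with. The config values and the `past_key_values_{i}` naming are illustrative assumptions, not defaults of this package:

```python
# Assumed prefill configuration: global (paged) attention, no LoRA, fp32 caches.
batch_size, query_length = 1, 128
max_seq_len, kvcache_block_size, kvcache_num_blocks = 4096, 1024, 8
num_hidden_layers, num_key_value_heads, head_dim = 2, 4, 64

input_info = [
    ("input_ids", [batch_size, query_length], "int64"),
    ("cache_position", [batch_size, query_length], "int32"),
    # Prefill uses a single block-table row spanning max_seq_len.
    ("block_tables", [max_seq_len // kvcache_block_size], "int16"),
    # Causal-LM prefill graphs also receive the scalar query position.
    ("query_position", [], "int16"),
]
# One key and one value cache per layer; with fp8 KV-cache quantization the
# dtype would be "float8_e4m3fn" instead of "float32".
input_info += [
    (f"past_key_values_{i}",
     [kvcache_num_blocks, num_key_value_heads, kvcache_block_size, head_dim],
     "float32")
    for i in range(num_hidden_layers * 2)
]
for name, shape, dtype in input_info:
    print(f"{name:>20}  {shape}  {dtype}")
```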
@@ -971,7 +437,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
971
437
  # ```
972
438
 
973
439
  # Returns:
974
- # RBLNDecoderOnlyModelForCausalLMConfig: The updated RBLN model configuration.
440
+ # RBLNDecoderOnlyModelConfig: The updated RBLN model configuration.
975
441
 
976
442
  raise NotImplementedError(
977
443
  "Subclasses must implement _update_sliding_window_config to configure sliding window attention settings. "
@@ -979,27 +445,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
979
445
  )
980
446
 
981
447
  @classmethod
982
- def _update_rbln_config(
983
- cls,
984
- preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
985
- model: Optional[PreTrainedModel] = None,
986
- model_config: Optional[PretrainedConfig] = None,
987
- rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
988
- ) -> RBLNDecoderOnlyModelForCausalLMConfig:
989
- if rbln_config.max_seq_len is None:
990
- rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
991
- model_config, "n_positions", None
992
- )
993
- if rbln_config.max_seq_len is None:
994
- raise ValueError("`max_seq_len` should be specified.")
995
-
996
- if getattr(model_config, "sliding_window", None) is not None and getattr(
997
- model_config, "use_sliding_window", True
998
- ):
999
- rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
1000
- if rbln_config.sliding_window is not None:
1001
- validate_sliding_window_size(rbln_config.sliding_window, rbln_config.prefill_chunk_size)
1002
-
448
+ def _update_attention_config(
449
+ cls, model: PreTrainedModel, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
450
+ ):
1003
451
  rbln_config.attn_impl, rbln_config.kvcache_partition_len, rbln_config.kvcache_block_size = set_default_values(
1004
452
  attn_impl=rbln_config.attn_impl,
1005
453
  kvcache_partition_len=rbln_config.kvcache_partition_len,
@@ -1014,9 +462,9 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
1014
462
  max_seq_len=rbln_config.max_seq_len,
1015
463
  )
1016
464
 
1017
- required_num_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
1018
- max_num_blocks = required_num_blocks
465
+ num_full_blocks = (rbln_config.max_seq_len // rbln_config.kvcache_block_size) * rbln_config.batch_size
1019
466
 
467
+ # Update kvcache_num_blocks based on the attention implementation.
1020
468
  if rbln_config.attn_impl == "flash_attn":
1021
469
  estimated_max_num_blocks = cls.get_maximum_num_blocks(
1022
470
  config=model_config,
@@ -1024,30 +472,73 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
1024
472
  kvcache_block_size=rbln_config.kvcache_block_size,
1025
473
  nbits_per_param=16 if not rbln_config.quantization else 4, # TODO(jongho): FIX Ad-hoc
1026
474
  n_model_params=sum(p.numel() for p in model.parameters()),
1027
- num_runtimes=1 + len(rbln_config.decoder_batch_sizes),
475
+ num_runtimes=1 if not rbln_config.can_generate else 1 + len(rbln_config.decoder_batch_sizes),
1028
476
  )
1029
477
 
1030
- max_num_blocks = min(max_num_blocks, estimated_max_num_blocks)
478
+ if rbln_config.kvcache_num_blocks is None:
479
+ if estimated_max_num_blocks < num_full_blocks:
480
+ # lower bound of the number of blocks for flash attention.
481
+ min_blocks_for_flash = min(
482
+ rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1, num_full_blocks
483
+ )
484
+ if min_blocks_for_flash > estimated_max_num_blocks:
485
+ # NOTE: Try to compile with the lower bound of blocks for flash attention,
486
+ # even if this exceeds the estimated maximum number of blocks.
487
+ rbln_config.kvcache_num_blocks = min_blocks_for_flash
488
+ else:
489
+ logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
490
+ rbln_config.kvcache_num_blocks = estimated_max_num_blocks
491
+
492
+ if rbln_config.kvcache_num_blocks < rbln_config.batch_size:
493
+ raise RuntimeError(
494
+ f"Batch size ({rbln_config.batch_size}) exceeds num_blocks ({rbln_config.kvcache_num_blocks}). "
495
+ "Ensure the number of blocks is at least equal to the batch size."
496
+ )
497
+ else:
498
+ rbln_config.kvcache_num_blocks = num_full_blocks
499
+ elif rbln_config.kvcache_num_blocks > estimated_max_num_blocks:
500
+ logger.warning(
501
+ f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
502
+ f" than the estimated maximum number of blocks ({estimated_max_num_blocks})."
503
+ "This can cause a failure during model compilation."
504
+ )
505
+ else:
506
+ if rbln_config.kvcache_num_blocks is None:
507
+ rbln_config.kvcache_num_blocks = num_full_blocks
508
+ elif rbln_config.kvcache_num_blocks > num_full_blocks:
509
+ logger.warning(
510
+ f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
511
+ f" than the required number of blocks ({num_full_blocks})."
512
+ "This can cause a failure during model compilation."
513
+ )
1031
514
 
1032
- flash_min_blocks = rbln_config.max_seq_len // rbln_config.kvcache_block_size + 1
1033
- if max_num_blocks < flash_min_blocks:
1034
- max_num_blocks = flash_min_blocks
515
+ logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
1035
516
 
1036
- if max_num_blocks < rbln_config.batch_size:
1037
- raise RuntimeError(
1038
- f"Batch size ({rbln_config.batch_size}) exceeds available KV cache blocks ({max_num_blocks}). "
1039
- "Ensure the number of blocks is at least equal to the batch size."
1040
- )
517
+ return rbln_config
1041
518
 
1042
- if rbln_config.kvcache_num_blocks is None:
1043
- rbln_config.kvcache_num_blocks = max_num_blocks
1044
- elif rbln_config.kvcache_num_blocks > max_num_blocks:
1045
- logger.warning(
1046
- f"The set `kvcache_num_blocks` ({rbln_config.kvcache_num_blocks}) is greater"
1047
- f" than the estimated maximum number of blocks ({max_num_blocks})."
1048
- "This can cause a failure during model compilation."
519
+ @classmethod
520
+ def _update_rbln_config(
521
+ cls,
522
+ preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
523
+ model: Optional[PreTrainedModel] = None,
524
+ model_config: Optional[PretrainedConfig] = None,
525
+ rbln_config: Optional[RBLNDecoderOnlyModelForCausalLMConfig] = None,
526
+ ) -> RBLNDecoderOnlyModelForCausalLMConfig:
527
+ if rbln_config.max_seq_len is None:
528
+ rbln_config.max_seq_len = getattr(model_config, "max_position_embeddings", None) or getattr(
529
+ model_config, "n_positions", None
1049
530
  )
1050
- logger.info(f"[KVCache] Compiling with num_blocks: {rbln_config.kvcache_num_blocks}")
531
+ if rbln_config.max_seq_len is None:
532
+ raise ValueError("`max_seq_len` should be specified.")
533
+
534
+ if getattr(model_config, "sliding_window", None) is not None and getattr(
535
+ model_config, "use_sliding_window", True
536
+ ):
537
+ rbln_config = cls._update_sliding_window_config(model_config, rbln_config)
538
+ if rbln_config.sliding_window is not None:
539
+ validate_sliding_window(rbln_config)
540
+
541
+ rbln_config = cls._update_attention_config(model, model_config, rbln_config)
1051
542
 
1052
543
  prefill_input_info = cls.get_input_info(
1053
544
  batch_size=1,
@@ -1057,19 +548,20 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
1057
548
  )
1058
549
 
1059
550
  prefill_compile_config = RBLNCompileConfig(compiled_model_name="prefill", input_info=prefill_input_info)
1060
-
1061
- dec_compile_configs = []
1062
- for batch_size in rbln_config.decoder_batch_sizes:
1063
- dec_input_info = cls.get_input_info(
1064
- batch_size=batch_size,
1065
- query_length=1,
1066
- rbln_config=rbln_config,
1067
- model_config=model_config,
1068
- )
1069
- dec_compile_configs.append(
1070
- RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
1071
- )
1072
- rbln_config.set_compile_cfgs([prefill_compile_config, *dec_compile_configs])
551
+ compile_cfgs = [prefill_compile_config]
552
+
553
+ if rbln_config.can_generate:
554
+ for batch_size in rbln_config.decoder_batch_sizes:
555
+ dec_input_info = cls.get_input_info(
556
+ batch_size=batch_size,
557
+ query_length=1,
558
+ rbln_config=rbln_config,
559
+ model_config=model_config,
560
+ )
561
+ compile_cfgs.append(
562
+ RBLNCompileConfig(compiled_model_name=f"decoder_batch_{batch_size}", input_info=dec_input_info)
563
+ )
564
+ rbln_config.set_compile_cfgs(compile_cfgs)
1073
565
 
1074
566
  return rbln_config
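The block-count arithmetic above is easier to follow with numbers. A condensed sketch of the flash-attention branch in `_update_attention_config`, under assumed values (8K context, 1K-token blocks, batch size 2, and a DRAM estimate of 12 blocks):

```python
max_seq_len, kvcache_block_size, batch_size = 8192, 1024, 2
estimated_max_num_blocks = 12  # assumed output of get_maximum_num_blocks()

# Blocks needed to keep every sequence resident at full length.
num_full_blocks = (max_seq_len // kvcache_block_size) * batch_size          # 16

if estimated_max_num_blocks < num_full_blocks:
    # Flash attention still needs at least one full sequence plus a spare block.
    min_blocks_for_flash = min(max_seq_len // kvcache_block_size + 1,
                               num_full_blocks)                              # 9
    # Equivalent to: use the estimate, but never drop below the lower bound.
    kvcache_num_blocks = max(min_blocks_for_flash, estimated_max_num_blocks)  # 12
else:
    kvcache_num_blocks = num_full_blocks

assert kvcache_num_blocks >= batch_size  # one block per sequence is mandatory
print(kvcache_num_blocks)  # 12
```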
1075
567
 
@@ -1079,103 +571,153 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
1079
571
  compiled_models: List[rebel.RBLNCompiledModel],
1080
572
  rbln_config: RBLNDecoderOnlyModelForCausalLMConfig,
1081
573
  ) -> List[rebel.Runtime]:
1082
- expected_model_names = [
1083
- "prefill",
1084
- *[f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes],
1085
- ]
574
+ expected_model_names = ["prefill"]
575
+ if rbln_config.can_generate:
576
+ expected_model_names.extend(
577
+ [f"decoder_batch_{batch_size}" for batch_size in rbln_config.decoder_batch_sizes]
578
+ )
1086
579
  if any(model_name not in rbln_config.device_map for model_name in expected_model_names):
1087
580
  cls._raise_missing_compiled_file_error(expected_model_names)
1088
581
 
1089
- return [
582
+ ret_val = [
1090
583
  rebel.Runtime(
1091
584
  compiled_models[0],
1092
585
  tensor_type="pt",
1093
586
  device=rbln_config.device_map["prefill"],
1094
587
  activate_profiler=rbln_config.activate_profiler,
1095
588
  timeout=rbln_config.timeout,
1096
- ),
1097
- *[
1098
- rebel.Runtime(
1099
- compiled_models[i + 1],
1100
- tensor_type="pt",
1101
- device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
1102
- activate_profiler=rbln_config.activate_profiler,
1103
- timeout=rbln_config.timeout,
1104
- )
1105
- for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
1106
- ],
589
+ )
1107
590
  ]
591
+ if rbln_config.can_generate:
592
+ ret_val.extend(
593
+ [
594
+ rebel.Runtime(
595
+ compiled_models[i + 1],
596
+ tensor_type="pt",
597
+ device=rbln_config.device_map[f"decoder_batch_{batch_size}"],
598
+ activate_profiler=rbln_config.activate_profiler,
599
+ timeout=rbln_config.timeout,
600
+ )
601
+ for i, batch_size in enumerate(rbln_config.decoder_batch_sizes)
602
+ ]
603
+ )
604
+ return ret_val
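A small sketch of how the compiled-model names created above must line up with `device_map` keys when the runtimes are instantiated; the batch sizes and device index are illustrative assumptions:

```python
# Assumed setup: generation enabled, decoders compiled for batch sizes 1 and 4.
can_generate = True
decoder_batch_sizes = [1, 4]

expected_model_names = ["prefill"]
if can_generate:
    expected_model_names += [f"decoder_batch_{bs}" for bs in decoder_batch_sizes]

# Every expected name needs a device assignment, otherwise loading fails with
# the missing-compiled-file error raised above.
device_map = {"prefill": 0, "decoder_batch_1": 0, "decoder_batch_4": 0}
missing = [name for name in expected_model_names if name not in device_map]
assert not missing, f"missing compiled models: {missing}"
print(expected_model_names)  # ['prefill', 'decoder_batch_1', 'decoder_batch_4']
```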
1108
605
 
1109
- def get_decoder(self):
1110
- return self.decoder
1111
-
1112
- def can_generate(self):
1113
- return True
1114
-
1115
- def _reorder_cache(self, past_key_values, beam_idx):
1116
- raise NotImplementedError
1117
-
1118
- def prepare_inputs_for_generation(
606
+ def forward(
1119
607
  self,
1120
- input_ids: torch.LongTensor,
1121
- generate_idx: Optional[torch.Tensor] = None,
1122
- attention_mask: Optional[torch.LongTensor] = None,
608
+ input_ids: Optional[torch.LongTensor] = None,
1123
609
  inputs_embeds: Optional[torch.Tensor] = None,
1124
- padded_cache_lengths: Optional[torch.Tensor] = None,
610
+ attention_mask: Optional[torch.LongTensor] = None,
611
+ position_embed: Optional[torch.Tensor] = None,
1125
612
  **kwargs,
1126
- ):
1127
- model_inputs = {}
1128
- is_prefill_phase = generate_idx is None
613
+ ) -> Tuple[torch.FloatTensor]:
614
+ inputs = inputs_embeds if inputs_embeds is not None else input_ids
615
+ batch_size = inputs.shape[0]
1129
616
 
1130
- if is_prefill_phase:
1131
- generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
1132
- padded_cache_lengths = torch.zeros_like(generate_idx)
1133
- cache_position = None
1134
- position_ids = None
1135
- else:
1136
- if inputs_embeds is not None:
1137
- # if `inputs_embeds` are passed, only use them in the 1st generation step for every prompt.
1138
- inputs_embeds = None
1139
-
1140
- input_ids = input_ids[:, -1:]
1141
- position_ids = generate_idx
1142
- cache_position = generate_idx + padded_cache_lengths if padded_cache_lengths is not None else generate_idx
1143
- generate_idx = generate_idx + 1
1144
- model_inputs.update({"input_ids": input_ids})
1145
-
1146
- if inputs_embeds is not None:
1147
- if self.rbln_config.use_inputs_embeds:
1148
- model_inputs.update({"inputs_embeds": inputs_embeds})
1149
- else:
1150
- raise ValueError(
1151
- "The specifying inputs_embeds is only supported when using a compiled RBLN model with 'rbln_use_inputs_embeds' set to True."
1152
- )
1153
- else:
1154
- model_inputs.update({"input_ids": input_ids})
1155
-
1156
- model_inputs.update(
1157
- {
1158
- "attention_mask": attention_mask,
1159
- "cache_position": cache_position,
1160
- "generate_idx": generate_idx,
1161
- "position_ids": position_ids,
1162
- "padded_cache_lengths": padded_cache_lengths,
1163
- }
617
+ if batch_size != self.rbln_config.batch_size:
618
+ raise ValueError(
619
+ f"Batch size ({batch_size}) must be equal to the batch size of the model ({self.rbln_config.batch_size})."
620
+ )
621
+
622
+ all_last_hidden_states = []
623
+ for b_idx in range(self.rbln_config.batch_size):
624
+ query_length = (
625
+ attention_mask[b_idx].sum(dim=-1).int().item() if attention_mask is not None else inputs.shape[1]
626
+ )
627
+ cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
628
+ last_hidden_states = self.prefill_decoder(
629
+ inputs[b_idx : b_idx + 1],
630
+ attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
631
+ position_embed=position_embed[b_idx : b_idx + 1] if position_embed is not None else None,
632
+ cache_position=cache_position,
633
+ batch_idx=b_idx,
634
+ ).logits
635
+ all_last_hidden_states.append(last_hidden_states)
636
+
637
+ last_hidden_states = torch.concat(all_last_hidden_states, dim=0)
638
+ return BaseModelOutputWithPast(last_hidden_state=last_hidden_states)
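The per-sample prefill loop above can be mimicked with a stub in place of the compiled prefill runtime. A toy sketch of the control flow only; the stub, shapes, and mask values are assumptions and do not call any RBLN runtime:

```python
from types import SimpleNamespace
import torch

def stub_prefill(inputs, attention_mask=None, cache_position=None, batch_idx=None, **kw):
    # Stand-in for self.prefill_decoder: one hidden vector per input position.
    return SimpleNamespace(logits=torch.zeros(1, inputs.shape[1], 16))

input_ids = torch.randint(0, 100, (2, 8))
attention_mask = torch.tensor([[1] * 8, [1] * 5 + [0] * 3])

all_last_hidden_states = []
for b_idx in range(input_ids.shape[0]):
    # Per-sample query length comes from the attention mask (continuous batching).
    query_length = int(attention_mask[b_idx].sum())
    cache_position = torch.arange(query_length, dtype=torch.int32).unsqueeze(0)
    out = stub_prefill(input_ids[b_idx:b_idx + 1],
                       attention_mask=attention_mask[b_idx],
                       cache_position=cache_position,
                       batch_idx=b_idx)
    all_last_hidden_states.append(out.logits)

print(torch.concat(all_last_hidden_states, dim=0).shape)  # torch.Size([2, 8, 16])
```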
639
+
640
+
641
+ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGenerationMixin):
642
+ """
643
+ A base class for decoder-only transformer models optimized for causal language modeling tasks on RBLN devices.
644
+ This class serves as the foundation for various decoder-only architectures like GPT, LLaMA, etc.
645
+
646
+ The class provides core functionality for:
647
+
648
+ 1. Converting pre-trained transformer models to RBLN-optimized format
649
+ 2. Handling the compilation process for RBLN devices
650
+ 3. Managing inference operations for causal language modeling
651
+ This class inherits from RBLNDecoderOnlyModel and implements specific methods required for
652
+ decoder-only architectures and causal language modeling tasks.
653
+
654
+ Note:
655
+ - This class is designed to be subclassed by specific model implementations
656
+ (e.g., RBLNLlamaForCausalLM, RBLNGPT2LMHeadModel)
657
+ - Subclasses should implement model-specific conversion logic.
658
+ - The class handles RBLN-specific optimizations automatically during compilation
659
+ """
660
+
661
+ auto_model_class = AutoModelForCausalLM
662
+
663
+ @property
664
+ def prefill_output_size(self):
665
+ return (
666
+ 1,
667
+ self.rbln_config.prefill_chunk_size if self.rbln_config.logits_to_keep == 0 else 1,
668
+ self.config.vocab_size,
1164
669
  )
1165
670
 
1166
- return model_inputs
671
+ @classmethod
672
+ def use_query_position(cls, use_local_attention: bool, is_prefill: bool = True):
673
+ return is_prefill
1167
674
 
1168
- def _update_model_kwargs_for_generation(
1169
- self,
1170
- outputs: RBLNDecoderOnlyOutput,
1171
- model_kwargs: Dict[str, Any],
1172
- **kwargs,
1173
- ) -> Dict[str, Any]:
1174
- # update generate_idx
1175
- model_kwargs["generate_idx"] = outputs.generate_idx
1176
- model_kwargs["padded_cache_lengths"] = outputs.padded_cache_lengths
675
+ def set_lora_int_ids(self, lora_int_ids: Optional[torch.Tensor]):
676
+ if isinstance(lora_int_ids, int):
677
+ lora_int_ids = torch.tensor([lora_int_ids], dtype=torch.int32)
678
+ elif isinstance(lora_int_ids, list):
679
+ lora_int_ids = torch.tensor(lora_int_ids, dtype=torch.int32)
680
+
681
+ self.lora_int_ids = lora_int_ids
682
+
683
+ self.prefill_decoder.lora_int_ids = lora_int_ids
684
+ if self.rbln_config.can_generate:
685
+ for batch_size in self.rbln_config.decoder_batch_sizes:
686
+ self.decoders[batch_size].lora_int_ids = lora_int_ids
687
+
688
+ def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
689
+ """
690
+ Sets the active adapter(s) for the model using adapter name(s).
691
+
692
+ Args:
693
+ adapter_name (Union[str, List[str]]): The name(s) of the adapter(s) to be activated.
694
+ Can be a single adapter name or a list of adapter names.
1177
695
 
1178
- return model_kwargs
696
+ Raises:
697
+ ValueError: If the model is not configured with LoRA or if the adapter name is not found.
698
+ """
699
+ if not hasattr(self.rbln_config, "lora_config") or self.rbln_config.lora_config is None:
700
+ raise ValueError("Model is not configured with LoRA. Cannot set adapter.")
701
+
702
+ # Convert single adapter name to list for uniform processing
703
+ if isinstance(adapter_name, str):
704
+ adapter_names = [adapter_name]
705
+ else:
706
+ adapter_names = adapter_name
707
+
708
+ # Validate that all adapter names exist
709
+ available_adapters = {
710
+ adapter.lora_name: adapter.lora_int_id for adapter in self.rbln_config.lora_config.adapters
711
+ }
712
+ missing_adapters = [name for name in adapter_names if name not in available_adapters]
713
+ if missing_adapters:
714
+ raise ValueError(
715
+ f"Adapter(s) {missing_adapters} not found. Available adapters: {list(available_adapters.keys())}"
716
+ )
717
+
718
+ # Get the adapter IDs and set them
719
+ lora_int_ids = [available_adapters[name] for name in adapter_names]
720
+ self.set_lora_int_ids(torch.tensor(lora_int_ids, dtype=torch.int32))
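A short sketch of the name-to-id resolution that `set_adapter` performs above; the adapter names and ids are assumptions standing in for whatever LoRA config the model was compiled with:

```python
import torch

# Assumed adapters registered at compile time (lora_name -> lora_int_id).
available_adapters = {"math": 0, "code": 1}

def resolve_adapters(adapter_name):
    names = [adapter_name] if isinstance(adapter_name, str) else adapter_name
    missing = [n for n in names if n not in available_adapters]
    if missing:
        raise ValueError(f"Adapter(s) {missing} not found. "
                         f"Available adapters: {list(available_adapters)}")
    return torch.tensor([available_adapters[n] for n in names], dtype=torch.int32)

print(resolve_adapters("code"))            # tensor([1], dtype=torch.int32)
print(resolve_adapters(["math", "code"]))  # tensor([0, 1], dtype=torch.int32)
# model.set_lora_int_ids(resolve_adapters(["math", "code"]))  # hypothetical compiled model
```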
1179
721
 
1180
722
  def forward(
1181
723
  self,
@@ -1187,6 +729,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
1187
729
  padded_cache_lengths: Optional[torch.Tensor] = None,
1188
730
  position_ids: Optional[torch.Tensor] = None,
1189
731
  token_type_ids: Optional[torch.Tensor] = None,
732
+ lora_int_ids: Optional[torch.Tensor] = None,
1190
733
  return_dict: Optional[torch.Tensor] = None,
1191
734
  **kwargs,
1192
735
  ) -> Tuple[torch.FloatTensor]:
@@ -1194,16 +737,27 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
1194
737
  # For continuous batching, the prefill stage processes one batch at a time and updates the KV cache using batch_idx.
1195
738
  # A for-loop ensures synchronization with the HuggingFace generate API.
1196
739
  # The decoder stage operates as usual, processing inputs in batch mode.
740
+ if self.rbln_config.use_lora and lora_int_ids is None:
741
+ if self.lora_int_ids is None:
742
+ raise ValueError(
743
+ "lora_int_id is required when using LoRA. "
744
+ "You should call set_lora_int_ids() before forward() or pass lora_int_id to forward()."
745
+ )
746
+ lora_int_ids = self.lora_int_ids
747
+
748
+ # Handle a direct forward() call: generate() normally supplies generate_idx.
749
+ if generate_idx is None:
750
+ generate_idx = (
751
+ attention_mask.sum(dim=-1, keepdim=True).int()
752
+ if attention_mask is not None
753
+ else torch.full((input_ids.shape[0], 1), input_ids.shape[1], dtype=torch.int32)
754
+ )
755
+ padded_cache_lengths = torch.zeros_like(generate_idx)
1197
756
 
1198
- # Prefll
757
+ # Prefill
1199
758
  if cache_position is None:
1200
759
  logits = []
1201
760
  inputs = inputs_embeds if inputs_embeds is not None else input_ids
1202
- # for only use forward
1203
- if generate_idx is None:
1204
- generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
1205
- if padded_cache_lengths is None:
1206
- padded_cache_lengths = torch.zeros_like(generate_idx)
1207
761
  batch_size = inputs.shape[0]
1208
762
  for b_idx in range(batch_size):
1209
763
  cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
@@ -1214,6 +768,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
1214
768
  cache_position=cache_position,
1215
769
  batch_idx=b_idx,
1216
770
  token_type_ids=token_type_ids[b_idx : b_idx + 1] if token_type_ids is not None else None,
771
+ lora_int_ids=lora_int_ids[b_idx : b_idx + 1] if lora_int_ids is not None else None,
1217
772
  )
1218
773
  padded_cache_lengths[b_idx] += output.padded_cache_lengths
1219
774
  logits.append(output.logits)
@@ -1233,6 +788,7 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNModel):
1233
788
  inputs_embeds=inputs_embeds,
1234
789
  cache_position=cache_position,
1235
790
  position_ids=position_ids if self.rbln_config.use_position_ids else None,
791
+ lora_int_ids=lora_int_ids,
1236
792
  ).logits
1237
793
 
1238
794
  if not return_dict: