optimum-rbln 0.7.4a9__py3-none-any.whl → 0.7.5a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. optimum/rbln/__init__.py +21 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +11 -7
  4. optimum/rbln/diffusers/models/controlnet.py +1 -1
  5. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +1 -1
  6. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +1 -1
  7. optimum/rbln/modeling.py +7 -5
  8. optimum/rbln/ops/__init__.py +1 -0
  9. optimum/rbln/ops/attn.py +10 -0
  10. optimum/rbln/ops/flash_attn.py +8 -0
  11. optimum/rbln/ops/sliding_window_attn.py +111 -0
  12. optimum/rbln/transformers/__init__.py +22 -3
  13. optimum/rbln/transformers/models/__init__.py +23 -0
  14. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  15. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +93 -0
  16. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +298 -0
  17. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +42 -6
  18. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +81 -77
  19. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +251 -135
  20. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +11 -7
  21. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +4 -4
  22. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  23. optimum/rbln/transformers/models/opt/configuration_opt.py +19 -0
  24. optimum/rbln/transformers/models/opt/modeling_opt.py +78 -0
  25. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  26. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +16 -10
  27. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +35 -52
  28. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -0
  29. optimum/rbln/transformers/models/siglip/__init__.py +20 -0
  30. optimum/rbln/transformers/models/siglip/configuration_siglip.py +66 -0
  31. optimum/rbln/transformers/models/siglip/modeling_siglip.py +146 -0
  32. optimum/rbln/transformers/models/whisper/whisper_architecture.py +1 -0
  33. optimum/rbln/transformers/utils/rbln_quantization.py +121 -72
  34. optimum/rbln/utils/import_utils.py +23 -6
  35. optimum/rbln/utils/submodule.py +13 -1
  36. {optimum_rbln-0.7.4a9.dist-info → optimum_rbln-0.7.5a1.dist-info}/METADATA +1 -1
  37. {optimum_rbln-0.7.4a9.dist-info → optimum_rbln-0.7.5a1.dist-info}/RECORD +39 -28
  38. {optimum_rbln-0.7.4a9.dist-info → optimum_rbln-0.7.5a1.dist-info}/WHEEL +0 -0
  39. {optimum_rbln-0.7.4a9.dist-info → optimum_rbln-0.7.5a1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/blip_2/modeling_blip_2.py (new file)
@@ -0,0 +1,298 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple, Union
+
+ import torch
+ from transformers import (
+     AutoModelForVisualQuestionAnswering,
+     Blip2ForConditionalGeneration,
+     Blip2QFormerModel,
+     Blip2VisionModel,
+     PretrainedConfig,
+     PreTrainedModel,
+ )
+ from transformers.modeling_outputs import BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions
+ from transformers.utils import logging
+
+ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+ from ....modeling import RBLNModel
+
+
+ logger = logging.get_logger(__name__)
+
+ if TYPE_CHECKING:
+     from transformers import (
+         AutoFeatureExtractor,
+         AutoProcessor,
+         AutoTokenizer,
+     )
+
+
+ class RBLNBlip2VisionModel(RBLNModel):
+     def get_input_embeddings(self):
+         return self.embeddings
+
+     @classmethod
+     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+         class Blip2VisionModelWrapper(torch.nn.Module):
+             def __init__(self, model: "Blip2VisionModel") -> None:
+                 super().__init__()
+                 self.model = model
+
+             def forward(self, *args, **kwargs):
+                 kwargs.pop("return_dict", None)
+                 return self.model(*args, **kwargs, return_dict=False)
+
+         return Blip2VisionModelWrapper(model).eval()
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNModelConfig] = None,
+     ) -> RBLNModelConfig:
+         input_info = [
+             (
+                 "pixel_values",
+                 [
+                     rbln_config.batch_size,
+                     model_config.num_channels,
+                     model_config.image_size,
+                     model_config.image_size,
+                 ],
+                 "float32",
+             ),
+         ]
+
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+         rbln_config.set_compile_cfgs([rbln_compile_config])
+         return rbln_config
+
+     def forward(
+         self,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         interpolate_pos_encoding: bool = False,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         output = super().forward(pixel_values, return_dict=return_dict)
+         return output
+
+     def _prepare_output(self, output, return_dict):
+         """
+         Prepare model output based on return_dict flag.
+         This method can be overridden by subclasses to provide task-specific output handling.
+         """
+         if not return_dict:
+             return (output,) if not isinstance(output, (tuple, list)) else output
+         else:
+             return BaseModelOutputWithPooling(
+                 last_hidden_state=output[0],
+                 pooler_output=output[1],
+             )
+
+
+ class RBLNBlip2QFormerModel(RBLNModel):
+     def get_input_embeddings(self):
+         return self.embeddings.word_embeddings
+
+     @classmethod
+     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+         class Blip2QFormerModelWrapper(torch.nn.Module):
+             def __init__(self, model: "Blip2QFormerModel"):
+                 super().__init__()
+                 self.model = model
+
+             def forward(
+                 self,
+                 query_embeds: torch.FloatTensor,
+                 encoder_hidden_states: Optional[torch.FloatTensor] = None,
+                 encoder_attention_mask: Optional[torch.FloatTensor] = None,
+             ) -> torch.Tensor:
+                 qformer_out = self.model(
+                     query_embeds=query_embeds,
+                     encoder_hidden_states=encoder_hidden_states,
+                     encoder_attention_mask=encoder_attention_mask,
+                     return_dict=False,
+                 )
+                 return qformer_out
+
+         return Blip2QFormerModelWrapper(model).eval()
+
+     @classmethod
+     def _update_submodule_config(cls, model: "PreTrainedModel", rbln_config: "RBLNModelConfig") -> "RBLNModelConfig":
+         if rbln_config.num_query_tokens is None:
+             rbln_config.num_query_tokens = model.config.num_query_tokens
+
+         if rbln_config.image_text_hidden_size is None:
+             rbln_config.image_text_hidden_size = model.config.image_text_hidden_size
+
+         return rbln_config
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNModelConfig] = None,
+     ) -> RBLNModelConfig:
+         input_info = [
+             (
+                 "query_embeds",
+                 [
+                     rbln_config.batch_size,
+                     rbln_config.num_query_tokens,
+                     model_config.hidden_size,
+                 ],
+                 "float32",
+             ),
+             (
+                 "encoder_hidden_states",
+                 [
+                     rbln_config.batch_size,
+                     # image_text_hidden_size + cls token
+                     rbln_config.image_text_hidden_size + 1,
+                     model_config.encoder_hidden_size,
+                 ],
+                 "float32",
+             ),
+             (
+                 "encoder_attention_mask",
+                 # image_text_hidden_size + cls token
+                 [rbln_config.batch_size, rbln_config.image_text_hidden_size + 1],
+                 "int64",
+             ),
+         ]
+
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+         rbln_config.set_compile_cfgs([rbln_compile_config])
+         return rbln_config
+
+     def forward(
+         self,
+         query_embeds: torch.FloatTensor,
+         query_length: Optional[int] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+         output = super().forward(query_embeds, encoder_hidden_states, encoder_attention_mask, return_dict=return_dict)
+         return output
+
+     def _prepare_output(self, output, return_dict):
+         """
+         Prepare model output based on return_dict flag.
+         This method can be overridden by subclasses to provide task-specific output handling.
+         """
+         if not return_dict:
+             return (output,) if not isinstance(output, (tuple, list)) else output
+         else:
+             return BaseModelOutputWithPoolingAndCrossAttentions(
+                 last_hidden_state=output[0],
+                 pooler_output=output[1],
+             )
+
+
+ class RBLNBlip2ForConditionalGeneration(RBLNModel):
+     auto_model_class = AutoModelForVisualQuestionAnswering
+     _rbln_submodules = [{"name": "vision_model"}, {"name": "qformer"}, {"name": "language_model"}]
+
+     def __getattr__(self, __name: str) -> Any:
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(Blip2ForConditionalGeneration, __name)
+
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+         return val
+
+     def can_generate(self):
+         return True
+
+     @classmethod
+     def save_torch_artifacts(
+         cls,
+         model: "Blip2ForConditionalGeneration",
+         save_dir_path: Path,
+         subfolder: str,
+         rbln_config: RBLNModelConfig,
+     ):
+         """
+         If you are unavoidably running on a CPU rather than an RBLN device,
+         store the torch tensor, weight, etc. in this function.
+         """
+         save_dict = {}
+         save_dict["query_tokens"] = model.query_tokens
+         torch.save(save_dict, save_dir_path / subfolder / "query_tokens.pth")
+
+     def __post_init__(self, **kwargs):
+         self.vision_model = self.rbln_submodules[0]
+         self.language_model = self.rbln_submodules[2]
+         self.qformer = self.rbln_submodules[1]
+         self.language_projection = self.model[0]
+
+         artifacts = torch.load(self.model_save_dir / self.subfolder / "query_tokens.pth", weights_only=False)
+         self.query_tokens = artifacts["query_tokens"]
+
+     def get_attn_impl(self) -> str:
+         return self.rbln_config.language_model.attn_impl
+
+     def get_kvcache_num_blocks(self) -> int:
+         return self.rbln_config.language_model.kvcache_num_blocks
+
+     def get_input_embeddings(self):
+         return self.language_model.get_input_embeddings()
+
+     @classmethod
+     def wrap_model_if_needed(cls, model, rbln_config):
+         return model.language_projection
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNModelConfig] = None,
+     ) -> RBLNModelConfig:
+         input_info = [
+             (
+                 "query_output",
+                 [
+                     rbln_config.batch_size,
+                     model_config.num_query_tokens,
+                     model_config.qformer_config.hidden_size,
+                 ],
+                 "float32",
+             ),
+         ]
+
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+         rbln_config.set_compile_cfgs([rbln_compile_config])
+
+         return rbln_config
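Note: a minimal usage sketch for the new Blip2 class. The `export=True` compile-and-load flow and the `rbln_batch_size` keyword follow the usual optimum-rbln `from_pretrained` convention, and the checkpoint name is only an example; both are assumptions, not part of this diff.

    from optimum.rbln import RBLNBlip2ForConditionalGeneration

    # Compiles the declared submodules (vision_model, qformer, language_model)
    # plus the language_projection, then saves the RBLN artifacts.
    model = RBLNBlip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b",  # example checkpoint (assumption)
        export=True,
        rbln_batch_size=1,
    )
    model.save_pretrained("blip2-rbln")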

optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py
@@ -12,13 +12,13 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Dict, Optional
+ from typing import Any, Dict, List, Optional, Union

  import rebel

  from ....configuration_utils import RBLNModelConfig
  from ....utils.logging import get_logger
- from ...utils.rbln_quantization import QuantizationManager
+ from ...utils.rbln_quantization import RBLNQuantizationConfig


  logger = get_logger()
@@ -31,12 +31,14 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
          max_seq_len: Optional[int] = None,
          use_inputs_embeds: Optional[bool] = None,
          use_attention_mask: Optional[bool] = None,
+         use_position_ids: Optional[bool] = None,
          attn_impl: Optional[str] = None,
          kvcache_partition_len: Optional[int] = None,
          kvcache_block_size: Optional[int] = None,
-         quantization: Optional[Dict[str, Any]] = None,
+         quantization: Optional[Union[Dict[str, Any], RBLNQuantizationConfig]] = None,
          prefill_chunk_size: Optional[int] = None,
          kvcache_num_blocks: Optional[int] = None,
+         decoder_batch_sizes: Optional[List[int]] = None,
          **kwargs,
      ):
          """
@@ -46,6 +48,7 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
              use_inputs_embeds (Optional[bool]): Whether to use input embeddings directly. Defaults to False.
              use_attention_mask (Optional[bool]): Whether to use attention masks. This is automatically set to True
                  for RBLN-CA02 devices.
+             use_position_ids (Optional[bool]): Whether to use position IDs. Defaults to False.
              attn_impl (Optional[str]): The attention implementation to use.
              kvcache_partition_len (Optional[int]): The length of each KV cache partition.
              kvcache_block_size (Optional[int]): The block size for KV cache.
@@ -53,6 +56,13 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
              prefill_chunk_size (Optional[int]): The chunk size for prefilling the KV cache. Defaults to 128,
                  and must be a positive integer divisible by 64.
              kvcache_num_blocks (Optional[int]): The number of blocks in the KV cache.
+             decoder_batch_sizes (Optional[List[int]]): A list of batch sizes for which separate decoder models will be compiled.
+                 This allows the model to handle varying batch sizes efficiently during generation. If not specified,
+                 defaults to a list containing only the model's main batch size. When specifying multiple batch sizes:
+                 1) All values must be less than or equal to the main batch size.
+                 2) The list will be sorted in descending order (larger batch sizes first).
+                 3) If using multiple decoders, at least one batch size should match the main batch size.
+
              **kwargs: Additional arguments passed to the parent RBLNModelConfig.

          Raises:
@@ -66,8 +76,9 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
          self.max_seq_len = max_seq_len
          self.use_inputs_embeds = use_inputs_embeds or False
-
+         self.use_position_ids = use_position_ids or False
          self.use_attention_mask = use_attention_mask
+
          npu = self.npu or rebel.get_npu_name()
          if npu == "RBLN-CA02":
              if self.use_attention_mask is False:
@@ -76,15 +87,40 @@ class RBLNDecoderOnlyModelForCausalLMConfig(RBLNModelConfig):
          else:
              self.use_attention_mask = self.use_attention_mask or False

+         if self.use_position_ids and not self.use_attention_mask:
+             raise ValueError("Position IDs should be used with attention mask.")
+
          self.attn_impl = attn_impl
          self.kvcache_partition_len = kvcache_partition_len
          self.kvcache_block_size = kvcache_block_size
          self.quantization = quantization or {}
-         if self.quantization:
-             QuantizationManager.validate_quantization_config(self.quantization)
+         if self.quantization and isinstance(self.quantization, dict):
+             self.quantization = RBLNQuantizationConfig(**self.quantization)

          self.prefill_chunk_size = prefill_chunk_size or 128
          if self.prefill_chunk_size % 64 != 0 or self.prefill_chunk_size <= 0:
              raise ValueError("`prefill_chunk_size` must be a positive integer divisible by 64.")

          self.kvcache_num_blocks = kvcache_num_blocks
+         self.decoder_batch_sizes = decoder_batch_sizes
+         if self.decoder_batch_sizes is None:
+             self.decoder_batch_sizes = [self.batch_size]
+
+         if self.use_multiple_decoder:
+             if max(self.decoder_batch_sizes) > self.batch_size:
+                 raise ValueError(
+                     f"Decoder batch size ({max(self.decoder_batch_sizes)}) must be less than or equal to the runtime batch size ({self.batch_size})."
+                 )
+             if max(self.decoder_batch_sizes) < self.batch_size:
+                 logger.warning(
+                     f"Maximum decoder batch size ({max(self.decoder_batch_sizes)}) is less than the model's batch size ({self.batch_size}). "
+                     "Appending the model's batch size to the decoder batch size."
+                 )
+                 self.decoder_batch_sizes.append(self.batch_size)
+
+             # Larger batch size should be at the beginning of the list.
+             self.decoder_batch_sizes.sort(reverse=True)
+
+     @property
+     def use_multiple_decoder(self):
+         return isinstance(self.decoder_batch_sizes, list) and len(self.decoder_batch_sizes) > 1
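Note: the net effect of the new `decoder_batch_sizes` handling is easiest to see distilled outside the class. A sketch of the normalization rules exactly as implemented above (illustration only):

    # Distilled from the constructor above.
    batch_size = 4
    decoder_batch_sizes = [1, 2]  # user-supplied
    if max(decoder_batch_sizes) > batch_size:
        raise ValueError("decoder batch sizes must not exceed the runtime batch size")
    if max(decoder_batch_sizes) < batch_size:
        # The constructor warns, then appends the main batch size.
        decoder_batch_sizes.append(batch_size)
    decoder_batch_sizes.sort(reverse=True)   # larger batch sizes first
    assert decoder_batch_sizes == [4, 2, 1]  # three decoders will be compiled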

optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -146,7 +146,10 @@ class DecoderOnlyWrapper(nn.Module):
          max_seq_len: int,
          use_rotary_emb: bool,
          attn_impl: str,
+         use_inputs_embeds: bool,
          use_attention_mask: bool,
+         use_position_ids: bool,
+         use_learned_pos_emb: Optional[bool] = None,
          kvcache_partition_len: Optional[int] = None,
          kvcache_block_size: Optional[int] = None,
      ):
@@ -161,6 +164,10 @@ class DecoderOnlyWrapper(nn.Module):
          self.attn_impl = attn_impl
          self.kvcache_block_size = kvcache_block_size
          self.use_attention_mask = use_attention_mask
+         self.use_position_ids = use_position_ids
+         self.use_inputs_embeds = use_inputs_embeds
+         self.use_learned_pos_emb = use_learned_pos_emb
+
          if self.attn_impl == "flash_attn":
              self.kvcache_partition_len = kvcache_partition_len or DEFAULT_FLASH_ATTN_PARTITION_LENGTH
          elif self.attn_impl == "eager":
@@ -209,6 +216,7 @@ class DecoderOnlyWrapper(nn.Module):
              partition_len=self.kvcache_partition_len,
              max_seq_len=max_seq_len,
              kvcache_block_size=self.kvcache_block_size,
+             use_learned_pos_emb=self.use_learned_pos_emb,
          )
          new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
          return new_causal_lm
@@ -222,24 +230,16 @@ class DecoderOnlyWrapper(nn.Module):
          self._phase = phase
          self.causal_lm.phase = phase

-     def forward_common(
-         self,
-         input_ids_or_inputs_embeds: torch.Tensor,
-         cache_position: torch.Tensor,
-         attention_mask: torch.Tensor,
-         query_position: torch.Tensor,
-         block_tables: torch.Tensor,
-         rotary_emb: Union[nn.Module, torch.Tensor],
-         *past_key_values: List[torch.Tensor],
-     ):
-         if input_ids_or_inputs_embeds.ndim == 2:
-             input_ids = input_ids_or_inputs_embeds
-             inputs_embeds = None
-         elif input_ids_or_inputs_embeds.ndim == 3:
-             input_ids = None
-             inputs_embeds = input_ids_or_inputs_embeds
-         else:
-             raise NotImplementedError(f"Unknown ndim of input : {input_ids_or_inputs_embeds.ndim}")
+     def prepare_forward_args(self, *args):
+         args = list(args)
+         input_ids = None if self.use_inputs_embeds else args.pop(0)
+         inputs_embeds = args.pop(0) if self.use_inputs_embeds else None
+         cache_position = args.pop(0)
+         block_tables = args.pop(0)
+         query_position = args.pop(0) if self.phase == "prefill" else None
+         attention_mask = args.pop(0) if self.use_attention_mask else None
+         position_ids = args.pop(0) if self.use_position_ids else None
+         past_key_values = args

          if len(past_key_values) != 2 * self.num_hidden_layers:
              raise ValueError(
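Note: the phase-specific tuple unpacking deleted further below is replaced by this single flag-driven pop order. A sketch of the flat positional layout `prepare_forward_args` now consumes, derived from the pops above (the helper name is hypothetical):

    # Hypothetical helper mirroring the pop order of prepare_forward_args.
    def flatten_inputs(wrapper, main_input, cache_position, block_tables,
                       query_position=None, attention_mask=None,
                       position_ids=None, past_key_values=()):
        args = [main_input, cache_position, block_tables]  # input_ids OR inputs_embeds
        if wrapper.phase == "prefill":
            args.append(query_position)
        if wrapper.use_attention_mask:
            args.append(attention_mask)
        if wrapper.use_position_ids:
            args.append(position_ids)
        return (*args, *past_key_values)  # 2 * num_hidden_layers cache tensors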
@@ -256,11 +256,37 @@ class DecoderOnlyWrapper(nn.Module):
              _past_key_values.append(past_key_value)
          past_key_values = _past_key_values

+         return (
+             input_ids,
+             inputs_embeds,
+             cache_position,
+             block_tables,
+             query_position,
+             attention_mask,
+             position_ids,
+             past_key_values,
+             self.rotary_emb,
+         )
+
+     def forward(self, *args):
+         (
+             input_ids,
+             inputs_embeds,
+             cache_position,
+             block_tables,
+             query_position,
+             attention_mask,
+             position_ids,
+             past_key_values,
+             rotary_emb,
+         ) = self.prepare_forward_args(*args)
+
          logit = self.causal_lm(
              input_ids=input_ids,
              inputs_embeds=inputs_embeds,
              attention_mask=attention_mask,
              cache_position=cache_position,
+             position_ids=position_ids,
              query_position=query_position,
              past_key_values=past_key_values,
              rotary_emb=rotary_emb,
@@ -269,58 +295,6 @@ class DecoderOnlyWrapper(nn.Module):

          return logit

-     def forward(self, *args):
-         if self.phase == "decode":
-             if self.use_attention_mask:
-                 (
-                     input_ids_or_inputs_embeds,
-                     cache_position,
-                     attention_mask,
-                     block_tables,
-                     *past_key_values,
-                 ) = args
-             else:
-                 (
-                     input_ids_or_inputs_embeds,
-                     cache_position,
-                     block_tables,
-                     *past_key_values,
-                 ) = args
-                 attention_mask = None
-             query_position = None
-         elif self.phase == "prefill":
-             if self.use_attention_mask:
-                 (
-                     input_ids_or_inputs_embeds,
-                     cache_position,
-                     attention_mask,
-                     query_position,
-                     block_tables,
-                     *past_key_values,
-                 ) = args
-             else:
-                 (
-                     input_ids_or_inputs_embeds,
-                     cache_position,
-                     query_position,
-                     block_tables,
-                     *past_key_values,
-                 ) = args
-                 attention_mask = None
-
-         else:
-             raise ValueError(f"Unknown phase: {self.phase}")
-
-         return self.forward_common(
-             input_ids_or_inputs_embeds,
-             cache_position,
-             attention_mask,
-             query_position,
-             block_tables,
-             self.rotary_emb,
-             *past_key_values,
-         )
-

  class DecoderOnlyForCausalLM(nn.Module):
      """A specialized wrapper for Causal Language Models optimized for RBLN compilation.
@@ -367,6 +341,7 @@ class DecoderOnlyForCausalLM(nn.Module):
          inputs_embeds: torch.Tensor = None,
          attention_mask: torch.Tensor = None,
          cache_position: torch.Tensor = None,
+         position_ids: torch.Tensor = None,
          query_position: torch.Tensor = None,
          past_key_values: Tuple[Tuple[torch.Tensor]] = None,
          rotary_emb: nn.Module = None,
@@ -378,6 +353,7 @@ class DecoderOnlyForCausalLM(nn.Module):
              inputs_embeds=inputs_embeds,
              attention_mask=attention_mask,
              cache_position=cache_position,
+             position_ids=position_ids,
              past_key_values=past_key_values,
              rotary_emb=rotary_emb,
              block_tables=block_tables,
@@ -404,7 +380,13 @@ class DecoderOnlyModel(nn.Module):
      """

      def __init__(
-         self, model, layers: List["DecoderOnlyLayer"], partition_len=None, max_seq_len=None, kvcache_block_size=None
+         self,
+         model,
+         layers: List["DecoderOnlyLayer"],
+         partition_len=None,
+         max_seq_len=None,
+         kvcache_block_size=None,
+         use_learned_pos_emb=None,
      ):
          super().__init__()
          self._original_mod = model
@@ -413,6 +395,7 @@ class DecoderOnlyModel(nn.Module):
          self.partition_len = partition_len
          self.kvcache_block_size = kvcache_block_size
          self.max_seq_len = max_seq_len
+         self.use_learned_pos_emb = use_learned_pos_emb

      @property
      def phase(self):
@@ -457,11 +440,12 @@ class DecoderOnlyModel(nn.Module):
      def forward(
          self,
          input_ids: torch.Tensor = None,
-         inputs_embeds: torch.Tensor = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
          attention_mask: torch.Tensor = None,
          cache_position: torch.Tensor = None,
+         position_ids: torch.Tensor = None,
          past_key_values: Tuple[Tuple[torch.Tensor]] = None,
-         rotary_emb: nn.Module = None,
+         rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
          block_tables: Optional[torch.Tensor] = None,
      ):
          # retrieve input_ids and inputs_embeds
@@ -477,24 +461,38 @@ class DecoderOnlyModel(nn.Module):
          hidden_states = inputs_embeds * self.hidden_multiplier

          # get cos,sin vector if needed
+         position_ids = position_ids if position_ids is not None else cache_position
          if rotary_emb is not None:
              if isinstance(rotary_emb, torch.Tensor):
                  cos = rotary_emb[0]
                  sin = rotary_emb[1]
              else:
                  cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
-             cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, cache_position)
+             cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+
+         elif self.use_learned_pos_emb:
+             batch_size = inputs_embeds.shape[0]
+             hidden_all = []
+             for i in range(batch_size):
+                 positions_idx = position_ids[i]
+                 position_weight = self.get_pos_embedding().weight[2:]
+                 position = position_weight[positions_idx]
+                 batch_hidden = position + inputs_embeds[i]
+                 hidden_all.append(batch_hidden)
+             hidden_states = torch.stack(hidden_all, dim=0)
+             cos, sin = None, None
+
          else:
              batch_size = inputs_embeds.shape[0]
-             if cache_position.shape[0] > 1:
+             if position_ids.shape[0] > 1:
                  position_embeds = []
                  for b_idx in range(batch_size):
-                     position_embed = self.get_pos_embedding()(cache_position[b_idx])
+                     position_embed = self.get_pos_embedding()(position_ids[b_idx])
                      position_embeds.append(position_embed)

                  position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
              else:
-                 position_embeds = self.get_pos_embedding()(cache_position)
+                 position_embeds = self.get_pos_embedding()(position_ids)
              hidden_states = hidden_states + position_embeds
              cos, sin = None, None
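Note: the `weight[2:]` slice in the new `use_learned_pos_emb` branch matches OPT's learned positional embedding, whose table reserves its first two rows (offset 2). A self-contained sketch of the equivalence:

    import torch

    emb = torch.nn.Embedding(10 + 2, 8)  # OPT-style table with 2 reserved rows
    position_ids = torch.tensor([0, 1, 2])
    vecs = emb.weight[2:][position_ids]  # what the branch above computes
    assert torch.equal(vecs, emb(position_ids + 2))  # same as an offset-2 lookup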
@@ -798,6 +796,7 @@ class AttentionOp(nn.Module):
                  scale=scale,
                  block_table=block_tables,
                  block_size=block_size,
+                 mask=None,
              )

          else:
@@ -825,6 +824,8 @@ class AttentionOp(nn.Module):
                  scale=scale,
                  block_table=block_tables,
                  block_size=block_size,
+                 is_bidirectional=False,
+                 mask=None,
              )

          attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
@@ -1058,6 +1059,7 @@ class FlashAttentionOp(AttentionOp):
                  block_table=block_tables,
                  block_size=kvcache_block_size,
                  partition=self.kvcache_partition_size,
+                 mask=None,
              )
          else:
              if self.use_attention_mask:
@@ -1086,6 +1088,8 @@ class FlashAttentionOp(AttentionOp):
                  block_table=block_tables,
                  block_size=kvcache_block_size,
                  partition=self.kvcache_partition_size,
+                 is_bidirectional=False,
+                 mask=None,
              )

          # reshape for removing repeat_kv