optimum-rbln 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. optimum/rbln/__init__.py +22 -12
  2. optimum/rbln/__version__.py +16 -1
  3. optimum/rbln/diffusers/__init__.py +22 -2
  4. optimum/rbln/diffusers/models/__init__.py +34 -3
  5. optimum/rbln/{transformers/generation → diffusers/models/autoencoders}/__init__.py +1 -2
  6. optimum/rbln/diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +44 -58
  7. optimum/rbln/diffusers/models/autoencoders/vae.py +84 -0
  8. optimum/rbln/diffusers/models/controlnet.py +54 -14
  9. optimum/rbln/diffusers/models/transformers/__init__.py +24 -0
  10. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +203 -0
  11. optimum/rbln/diffusers/models/unets/__init__.py +24 -0
  12. optimum/rbln/diffusers/models/{unet_2d_condition.py → unets/unet_2d_condition.py} +78 -16
  13. optimum/rbln/diffusers/pipelines/__init__.py +22 -2
  14. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +5 -26
  15. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -0
  16. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +1 -0
  17. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -0
  18. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  19. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +1 -0
  20. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +0 -11
  21. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +32 -0
  22. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +26 -0
  23. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +32 -0
  24. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +32 -0
  25. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +32 -0
  26. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +1 -0
  27. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +14 -6
  28. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +14 -6
  29. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +32 -0
  30. optimum/rbln/modeling.py +572 -0
  31. optimum/rbln/modeling_alias.py +1 -1
  32. optimum/rbln/modeling_base.py +164 -758
  33. optimum/rbln/modeling_diffusers.py +51 -122
  34. optimum/rbln/transformers/__init__.py +0 -2
  35. optimum/rbln/transformers/models/auto/auto_factory.py +117 -23
  36. optimum/rbln/transformers/models/auto/modeling_auto.py +37 -12
  37. optimum/rbln/transformers/models/bart/modeling_bart.py +3 -6
  38. optimum/rbln/transformers/models/bert/modeling_bert.py +3 -6
  39. optimum/rbln/transformers/models/clip/modeling_clip.py +8 -25
  40. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -3
  41. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +672 -412
  42. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +38 -155
  43. optimum/rbln/transformers/models/dpt/modeling_dpt.py +1 -1
  44. optimum/rbln/transformers/models/exaone/exaone_architecture.py +61 -45
  45. optimum/rbln/transformers/models/exaone/modeling_exaone.py +4 -2
  46. optimum/rbln/transformers/models/gemma/gemma_architecture.py +33 -104
  47. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +50 -238
  48. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +3 -2
  49. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +2 -75
  50. optimum/rbln/transformers/models/midm/midm_architecture.py +88 -242
  51. optimum/rbln/transformers/models/midm/modeling_midm.py +6 -6
  52. optimum/rbln/transformers/models/phi/phi_architecture.py +61 -261
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +1 -46
  54. optimum/rbln/transformers/models/t5/modeling_t5.py +102 -4
  55. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  56. optimum/rbln/transformers/models/whisper/modeling_whisper.py +1 -1
  57. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +3 -35
  58. optimum/rbln/transformers/utils/rbln_quantization.py +120 -3
  59. optimum/rbln/utils/decorator_utils.py +10 -6
  60. optimum/rbln/utils/hub.py +131 -0
  61. optimum/rbln/utils/import_utils.py +15 -1
  62. optimum/rbln/utils/model_utils.py +53 -0
  63. optimum/rbln/utils/runtime_utils.py +1 -1
  64. optimum/rbln/utils/submodule.py +114 -0
  65. optimum_rbln-0.1.15.dist-info/METADATA +106 -0
  66. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/RECORD +69 -66
  67. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/WHEEL +1 -1
  68. optimum/rbln/transformers/generation/streamers.py +0 -139
  69. optimum/rbln/transformers/generation/utils.py +0 -397
  70. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +0 -181
  71. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +0 -1725
  72. optimum/rbln/transformers/models/midm/hf_hub_cached/configuration_midm.py +0 -22
  73. optimum/rbln/transformers/models/midm/hf_hub_cached/midm_bitext_tokenization.py +0 -304
  74. optimum/rbln/transformers/models/midm/hf_hub_cached/modeling_midm.py +0 -1469
  75. optimum/rbln/transformers/models/midm/hf_hub_cached/rotary_position_embedding.py +0 -98
  76. optimum/rbln/utils/context.py +0 -58
  77. optimum_rbln-0.1.13.dist-info/METADATA +0 -120
  78. optimum_rbln-0.1.13.dist-info/entry_points.txt +0 -4
  79. {optimum_rbln-0.1.13.dist-info → optimum_rbln-0.1.15.dist-info}/licenses/LICENSE +0 -0
@@ -21,262 +21,74 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.

- from typing import Dict, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Tuple

  import torch
  import torch.nn as nn
- from transformers.modeling_outputs import BaseModelOutputWithPast

- from ...cache_utils import RebelDynamicCache_4D
+ from ..decoderonly.decoderonly_architecture import (
+     DecoderOnlyAttention,
+     DecoderOnlyForCausalLM,
+     DecoderOnlyLayer,
+     DecoderOnlyModel,
+     DecoderOnlyWrapper,
+ )


- class GPT2LMHeadModelWrapper(torch.nn.Module):
-     def __init__(self, model, max_seq_len):
-         super().__init__()
-         self.model = model.transformer
-         self.lm_head = model.lm_head
-         self.config = model.config
-         self.max_seq_len = max_seq_len
-         self.forward_dict = self.get_forward_dict()
+ if TYPE_CHECKING:
+     from transformers import GPT2LMHeadModel

-     def get_forward_dict(self):
-         forward_dict = {
-             "wrapper": _GPT2Model.forward,
-             "model": _GPT2Block.forward,
-             "decoder_layer": _GPT2Attention.forward,
-         }
-         return forward_dict

-     def forward(
-         self,
-         input_ids,
-         attention_mask,
-         cache_position,
-         batch_position,
-         query_idx,
-         *past_key_values,
-     ):
-         if input_ids.shape[1] == 1:
-             rbln_batch_position = None
-         else:
-             rbln_batch_position = batch_position
+ class GPT2Wrapper(DecoderOnlyWrapper):
+     def convert_to_rbln_causal_lm(self, causal_lm: "GPT2LMHeadModel"):
+         if self.attn_impl != "eager":
+             raise NotImplementedError(f"flash attention ({self.attn_impl}) is not implemented for {self.__class__}")
+         new_layers = []
+         for layer in causal_lm.transformer.h:
+             new_self_attn = GPT2Attention(layer.attn)
+             new_layer = GPT2Layer(layer, new_self_attn)
+             new_layers.append(new_layer)
+         new_model = GPT2Model(causal_lm.transformer, new_layers)
+         new_causal_lm = DecoderOnlyForCausalLM(causal_lm, new_model)
+         return new_causal_lm

-         # Formatting list of past_kv to DynamicCache class.
-         past_key_value = RebelDynamicCache_4D.from_input_format(
-             cache_position,
-             self.config.n_layer,
-             *past_key_values,
-         )
-
-         outputs = self.forward_dict["wrapper"](
-             self.model,
-             input_ids=input_ids,
-             attention_mask=attention_mask,
-             position_ids=cache_position,
-             past_key_value=past_key_value,
-             batch_ids=rbln_batch_position,
-             forward_dict=self.forward_dict,
-             # rotary_emb differenct from_llama
-         )
-
-         hidden_states = outputs[0]
-         if batch_position >= 0:
-             hidden_states = hidden_states[:, query_idx].unsqueeze(1)
-         logits = self.lm_head(hidden_states)
-
-         output = (logits,) + outputs[1:]

-         return output, batch_position + query_idx
+ class GPT2Model(DecoderOnlyModel):
+     mask_fmin = torch.finfo(torch.float32).min

+     def get_last_layernorm(self) -> nn.LayerNorm:
+         return self._original_mod.ln_f

- class _GPT2Model:
-     def forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_value: Optional[RebelDynamicCache_4D] = None,
-         batch_ids: Optional[torch.LongTensor] = None,
-         forward_dict: Optional[Dict[str, classmethod]] = None,
-     ) -> BaseModelOutputWithPast:
-         b_size, q_len = input_ids.shape
-         inputs_embeds = self.wte(input_ids)
+     def get_embedding(self) -> nn.Embedding:
+         return self._original_mod.wte

-         if position_ids.shape[0] > 1:
-             position_embeds = []
-             for b_idx in range(b_size):
-                 position_embed = self.wpe(position_ids[b_idx])
-                 # position_embed = position_embed.dtype(inputs_embeds.dtype)
-                 position_embeds.append(position_embed)
+     def get_pos_embedding(self) -> nn.Embedding:
+         return self._original_mod.wpe

-             position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
-         else:
-             position_embeds = self.wpe(position_ids)

-         hidden_states = inputs_embeds + position_embeds
+ class GPT2Layer(DecoderOnlyLayer):
+     def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+         return self._original_mod.ln_1

-         # GPT2Attention mask.
-         # Here we assume mask is causal mask, (batch, 1, query_length, key_length + query_length)
-         attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
+     def get_post_attention_layernorm(self) -> nn.LayerNorm:
+         return self._original_mod.ln_2

-         for layer_idx, block in enumerate(self.h):
-             hidden_states, updated_cache = forward_dict["model"](
-                 block,
-                 hidden_states,
-                 layer_idx,
-                 attention_mask=attention_mask,
-                 past_key_value=past_key_value,
-                 position_ids=position_ids,
-                 batch_ids=batch_ids,
-                 forward_dict=forward_dict,
-             )
-
-         hidden_states = self.ln_f(hidden_states)
-         output_shape = (-1,) + (q_len,) + (hidden_states.size(-1),)
-         hidden_states = hidden_states.view(output_shape)
-
-         # convert RebelDynamicCache to legacy Tuple[Tuple[torch.Tensor]]
-         next_cache = updated_cache.to_legacy_cache()
-
-         return BaseModelOutputWithPast(
-             last_hidden_state=hidden_states,
-             past_key_values=next_cache,
-         )

+ class GPT2Attention(DecoderOnlyAttention):
+     def __post_init__(self):
+         self.c_attn = self._original_mod.c_attn
+         self.o_proj = self._original_mod.c_proj
+         self.split_size = self._original_mod.split_size
+         self.num_key_value_heads = self._original_mod.num_heads

- class _GPT2Block:
-     def forward(
-         self,
-         hidden_states: Optional[Tuple[torch.FloatTensor]],
-         layer_idx: int,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         past_key_value: Optional[RebelDynamicCache_4D] = None,
-         batch_ids: Optional[torch.LongTensor] = None,
-         forward_dict: Optional[Dict[str, classmethod]] = None,
-         **kwargs,
-     ) -> Tuple[torch.Tensor, RebelDynamicCache_4D]:
-         residual = hidden_states
-         hidden_states = self.ln_1(hidden_states)
+     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+         return query_states, key_states, value_states

-         hidden_states, k, v = forward_dict["decoder_layer"](
-             self.attn,
-             hidden_states=hidden_states,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             past_key_value=past_key_value,
-             batch_index=batch_ids,
+     def rbln_attention(self, *args, **kwargs):
+         return super().rbln_attention(
+             *args,
+             **kwargs,
+             layer_idx=self.layer_idx,
+             scale_attn_by_inverse_layer_idx=self._original_mod.scale_attn_by_inverse_layer_idx,
          )
-         past_key_value.assign(k, v, layer_idx)
-
-         # residual connection
-         hidden_states = residual + hidden_states
-
-         residual = hidden_states
-         hidden_states = self.ln_2(hidden_states)
-         hidden_states = self.mlp(hidden_states)
-         hidden_states = residual + hidden_states
-
-         return hidden_states, past_key_value
-
-
- class _GPT2Attention:
-     def _attn(self, query, key, value, attention_mask=None, head_mask=None):
-         attn_weights = torch.matmul(query, key.transpose(-1, -2))
-
-         if self.scale_attn_weights:
-             attn_weights = attn_weights / torch.full(
-                 [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
-             )
-
-         # Layer-wise attention scaling
-         if self.scale_attn_by_inverse_layer_idx:
-             attn_weights = attn_weights / float(self.layer_idx + 1)
-
-         # -------------------
-         # Below are deleted since "where" op does not supported on RBLN graph.
-         # -------------------
-         # if not self.is_cross_attention:
-         #     # if only "normal" attention layer implements causal mask
-         #     query_length, key_length = query.size(-2), key.size(-2)
-         #     causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
-         #     mask_value = torch.finfo(attn_weights.dtype).min
-         #     # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
-         #     # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
-         #     mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
-         #     attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
-
-         # Apply the attention mask
-         attn_weights.view(
-             -1,
-         )
-         attn_weights = attn_weights + attention_mask
-         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-         attn_output = torch.matmul(attn_weights, value)
-
-         return attn_output, attn_weights
-
-     def forward(
-         self,
-         hidden_states: Optional[Tuple[torch.FloatTensor]],
-         attention_mask: Optional[torch.FloatTensor] = None,
-         past_key_value: Optional[RebelDynamicCache_4D] = None,
-         batch_index: Optional[int] = None,
-         **kwargs,
-     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
-         bsz, q_len, _ = hidden_states.size()
-         query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
-
-         querys = self._split_heads(query, self.num_heads, self.head_dim) # (batch, head, seq_length, head_features)
-         keys = self._split_heads(key, self.num_heads, self.head_dim)
-         values = self._split_heads(value, self.num_heads, self.head_dim)
-
-         # Decoder
-         if (batch_index is None or batch_index == -1) and bsz > 1:
-             all_keys = []
-             all_values = []
-             all_attn_output = []
-
-             for b in range(bsz):
-                 query = querys[b].unsqueeze(0)
-                 attn_mask = attention_mask[b].unsqueeze(0)
-                 key = keys[b].unsqueeze(0)
-                 value = values[b].unsqueeze(0)
-
-                 key, value = past_key_value.update(
-                     key,
-                     value,
-                     self.layer_idx,
-                     b,
-                 )
-
-                 attn_output, _ = _GPT2Attention._attn(self, query, key, value, attn_mask)
-                 attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-                 all_keys.append(key)
-                 all_values.append(value)
-                 all_attn_output.append(attn_output)
-
-             keys = torch.cat(all_keys, dim=0)
-             values = torch.cat(all_values, dim=0)
-             attn_output = torch.cat(all_attn_output, dim=0)
-
-         # Prefill
-         else:
-             if batch_index is None or batch_index == -1:
-                 batch_index = 0
-
-             keys, values = past_key_value.update(
-                 keys,
-                 values,
-                 self.layer_idx,
-                 batch_index,
-                 read_first_step=True,
-             )
-
-             attn_output, _ = _GPT2Attention._attn(self, querys, keys, values, attention_mask)
-             attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
-
-         attn_output = self.c_proj(attn_output)
-
-         return attn_output, keys, values
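
Note: the hunk above replaces the 0.1.13 approach of monkey-patching GPT-2's forward methods through a forward_dict with small subclasses that only override accessor hooks (get_embedding, get_last_layernorm, ln_1/ln_2, c_attn/c_proj) on the shared DecoderOnly* base classes. The following standalone sketch is not part of the diff and uses a hypothetical BaseModelAdapter in place of the real DecoderOnlyModel; it only illustrates the accessor-override pattern the refactor adopts.

import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel


class BaseModelAdapter(nn.Module):
    """Generic adapter: architecture-specific subclasses only say where things live."""

    def __init__(self, original_mod: nn.Module):
        super().__init__()
        self._original_mod = original_mod

    # Subclasses override these accessors instead of rewriting forward().
    def get_embedding(self) -> nn.Embedding:
        raise NotImplementedError

    def get_last_layernorm(self) -> nn.LayerNorm:
        raise NotImplementedError

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden = self.get_embedding()(input_ids)
        return self.get_last_layernorm()(hidden)


class GPT2Adapter(BaseModelAdapter):
    """GPT-2 only needs to map the shared hooks onto its own module names."""

    def get_embedding(self) -> nn.Embedding:
        return self._original_mod.wte

    def get_last_layernorm(self) -> nn.LayerNorm:
        return self._original_mod.ln_f


if __name__ == "__main__":
    gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
    adapter = GPT2Adapter(gpt2.transformer)
    out = adapter(torch.tensor([[0, 1, 2]]))
    print(out.shape)  # torch.Size([1, 3, 768]) for the base gpt2 checkpoint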
@@ -23,7 +23,7 @@

  from ....utils import logging
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
- from .gpt2_architecture import GPT2LMHeadModelWrapper
+ from .gpt2_architecture import GPT2Wrapper # GPT2LMHeadModelWrapper


  logger = logging.get_logger(__name__)
@@ -43,4 +43,5 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):

      """

-     _decoder_wrapper_cls = GPT2LMHeadModelWrapper
+     _decoder_wrapper_cls = GPT2Wrapper
+     _use_rotary_emb = False
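
With the wrapper swap, RBLNGPT2LMHeadModel is driven entirely by the shared decoder-only code path (note _use_rotary_emb = False, since GPT-2 uses learned positional embeddings rather than RoPE). A hedged usage sketch follows; export=True and the rbln_* keywords mirror the usual optimum-rbln convention, but the exact set of supported arguments should be verified against the 0.1.15 documentation.

from optimum.rbln import RBLNGPT2LMHeadModel

# Compile a Hugging Face GPT-2 checkpoint for RBLN NPUs, then save the
# compiled artifacts so they can be reloaded without recompiling.
model = RBLNGPT2LMHeadModel.from_pretrained(
    "gpt2",
    export=True,            # compile from the original PyTorch weights
    rbln_max_seq_len=1024,  # assumed compile-time kwargs; check the release docs
    rbln_batch_size=1,
)
model.save_pretrained("gpt2-rbln")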
@@ -23,7 +23,7 @@
  import inspect
  import logging
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

  import numpy as np
  import torch
@@ -36,7 +36,7 @@ from transformers import (
  from transformers.modeling_outputs import BaseModelOutputWithPooling
  from transformers.models.llava_next.modeling_llava_next import LlavaNextCausalLMOutputWithPast

- from ....modeling_base import RBLNModel
+ from ....modeling import RBLNModel
  from ....modeling_config import RBLNCompileConfig, RBLNConfig
  from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyOutput

@@ -166,19 +166,6 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
          self._padding_side = "left" # set it to left by default, user can use setter to change padding_sides
          return super().__post_init__(**kwargs)

-     @classmethod
-     def get_pytorch_model(
-         cls,
-         model_id: str,
-         *args,
-         rbln_kwargs: Optional[Dict[str, Any]] = None,
-         **kwargs,
-     ) -> "PreTrainedModel":
-         # Optimum's TasksManager does not handle Llava.
-         kwargs = cls.update_kwargs(kwargs)
-         model = LlavaNextForConditionalGeneration.from_pretrained(model_id, *args, **kwargs)
-         return model
-
      def get_input_embeddings(self):
          return self.language_model.get_input_embeddings()

@@ -422,66 +409,6 @@

          return outputs

-     def vllm_forward(
-         self,
-         input_ids: torch.LongTensor = None,
-         pixel_values: torch.FloatTensor = None,
-         image_sizes: Optional[torch.LongTensor] = None,
-         inputs_embeds: Optional[torch.FloatTensor] = None,
-         vision_feature_layer: Optional[int] = None,
-         vision_feature_select_strategy: Optional[str] = None,
-         cache_position: Union[List[torch.Tensor], torch.Tensor] = None, # vllm keyword argument
-         batch_idx: Optional[int] = None,
-         **kwargs,
-     ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
-         is_prefill = cache_position.shape[-1] > 1
-
-         if inputs_embeds is not None:
-             raise NotImplementedError("Specifying inputs_embeds is not supported.")
-
-         if is_prefill:
-             # Get text_embeds
-             inputs_embeds = self.text_embedding(input_ids)
-
-             # If any images in the prompt, get image_embeds and merge with text
-             if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
-                 image_features, _ = self.image_embedding(
-                     image_sizes, pixel_values, vision_feature_layer, vision_feature_select_strategy
-                 )
-
-                 def merge_vllm_multimodal_embeddings(
-                     input_ids: torch.Tensor,
-                     inputs_embeds: torch.Tensor,
-                     multimodal_embeddings: torch.Tensor,
-                     placeholder_token_id: int,
-                 ) -> torch.Tensor:
-                     mask = input_ids == placeholder_token_id
-                     num_expected_tokens = mask.sum().item()
-
-                     if multimodal_embeddings.shape[0] != num_expected_tokens:
-                         raise ValueError(
-                             f"Attempted to assign {inputs_embeds[mask].shape} = {multimodal_embeddings.shape} "
-                             f"multimodal tokens to {num_expected_tokens} placeholders"
-                         )
-
-                     inputs_embeds[mask] = multimodal_embeddings
-                     return inputs_embeds
-
-                 inputs_embeds = merge_vllm_multimodal_embeddings(
-                     input_ids, inputs_embeds, image_features, self.config.image_token_index
-                 )
-
-         else:
-             inputs_embeds = self.text_embedding(input_ids=input_ids)
-
-         outputs: RBLNDecoderOnlyOutput = self.language_model.vllm_forward(
-             inputs_embeds=inputs_embeds,
-             batch_idx=batch_idx,
-             cache_position=cache_position,
-         )
-
-         return outputs
-
      # Almost copied from : https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/llava_next/modeling_llava_next.py
      def pack_image_features(self, image_features, image_sizes, image_newline=None):
          """