optimum-rbln 0.1.11__py3-none-any.whl → 0.1.13__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (72)
  1. optimum/rbln/__init__.py +14 -7
  2. optimum/rbln/__version__.py +1 -1
  3. optimum/rbln/diffusers/models/autoencoder_kl.py +30 -63
  4. optimum/rbln/diffusers/models/controlnet.py +36 -62
  5. optimum/rbln/diffusers/models/unet_2d_condition.py +57 -156
  6. optimum/rbln/diffusers/pipelines/__init__.py +40 -12
  7. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +11 -0
  8. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -187
  9. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -192
  10. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +8 -206
  11. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -207
  12. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -111
  13. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +12 -117
  14. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +4 -123
  15. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +4 -126
  16. optimum/rbln/modeling_alias.py +4 -9
  17. optimum/rbln/modeling_base.py +117 -144
  18. optimum/rbln/modeling_config.py +51 -0
  19. optimum/rbln/modeling_diffusers.py +400 -0
  20. optimum/rbln/transformers/__init__.py +10 -0
  21. optimum/rbln/transformers/cache_utils.py +5 -9
  22. optimum/rbln/transformers/modeling_rope_utils.py +283 -0
  23. optimum/rbln/transformers/models/__init__.py +80 -28
  24. optimum/rbln/transformers/models/auto/modeling_auto.py +1 -0
  25. optimum/rbln/transformers/models/bart/__init__.py +1 -1
  26. optimum/rbln/transformers/models/bart/bart_architecture.py +18 -12
  27. optimum/rbln/transformers/models/bart/modeling_bart.py +25 -6
  28. optimum/rbln/transformers/models/bert/modeling_bert.py +1 -2
  29. optimum/rbln/transformers/models/clip/modeling_clip.py +13 -23
  30. optimum/rbln/transformers/models/decoderonly/__init__.py +0 -2
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +376 -218
  32. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +246 -116
  33. optimum/rbln/transformers/models/dpt/modeling_dpt.py +0 -1
  34. optimum/rbln/transformers/models/exaone/__init__.py +32 -0
  35. optimum/rbln/transformers/models/exaone/exaone_architecture.py +81 -0
  36. optimum/rbln/transformers/models/exaone/hf_hub_cached/configuration_exaone.py +181 -0
  37. optimum/rbln/transformers/models/exaone/hf_hub_cached/modeling_exaone.py +1725 -0
  38. optimum/rbln/transformers/models/exaone/modeling_exaone.py +53 -0
  39. optimum/rbln/transformers/models/gemma/gemma_architecture.py +12 -2
  40. optimum/rbln/transformers/models/gemma/modeling_gemma.py +4 -28
  41. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +4 -30
  42. optimum/rbln/transformers/models/llama/modeling_llama.py +4 -28
  43. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +166 -151
  44. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -15
  45. optimum/rbln/transformers/models/midm/modeling_midm.py +8 -28
  46. optimum/rbln/transformers/models/mistral/modeling_mistral.py +4 -29
  47. optimum/rbln/transformers/models/phi/modeling_phi.py +5 -31
  48. optimum/rbln/transformers/models/phi/phi_architecture.py +75 -159
  49. optimum/rbln/transformers/models/qwen2/__init__.py +24 -0
  50. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +43 -0
  51. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +29 -0
  52. optimum/rbln/transformers/models/seq2seq/__init__.py +24 -0
  53. optimum/rbln/{modeling_seq2seq.py → transformers/models/seq2seq/modeling_seq2seq.py} +107 -166
  54. optimum/rbln/transformers/models/t5/__init__.py +1 -0
  55. optimum/rbln/transformers/models/t5/modeling_t5.py +108 -0
  56. optimum/rbln/transformers/models/t5/t5_architecture.py +46 -32
  57. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +0 -1
  58. optimum/rbln/transformers/models/whisper/modeling_whisper.py +38 -13
  59. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +1 -2
  60. optimum/rbln/transformers/utils/rbln_quantization.py +8 -2
  61. optimum/rbln/utils/context.py +58 -0
  62. optimum/rbln/utils/decorator_utils.py +55 -0
  63. optimum/rbln/utils/import_utils.py +21 -0
  64. optimum/rbln/utils/logging.py +1 -1
  65. optimum/rbln/utils/runtime_utils.py +4 -4
  66. optimum/rbln/utils/timer_utils.py +26 -2
  67. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/METADATA +11 -9
  68. optimum_rbln-0.1.13.dist-info/RECORD +107 -0
  69. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/WHEEL +1 -1
  70. optimum_rbln-0.1.11.dist-info/RECORD +0 -93
  71. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/entry_points.txt +0 -0
  72. {optimum_rbln-0.1.11.dist-info → optimum_rbln-0.1.13.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/exaone/modeling_exaone.py
@@ -0,0 +1,53 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ from ....utils import logging
+ from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from .exaone_architecture import ExaoneForCausalLMWrapper
+ from .hf_hub_cached.modeling_exaone import ExaoneForCausalLM
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class RBLNExaoneForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+     """
+     The Exaone Model transformer with a language modeling head on top (linear layer with weights tied to the input
+     embeddings).
+
+     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
+     library implements for all its model.
+
+     It implements the methods to convert a pre-trained transformers Exaone model into a RBLN transformer model by:
+     - transferring the checkpoint weights of the original into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
+
+     """
+
+     _decoder_wrapper_cls = ExaoneForCausalLMWrapper
+     _original_cls = ExaoneForCausalLM
+
+     @classmethod
+     def from_pretrained(cls, *args, **kwargs):
+         kwargs.setdefault("trust_remote_code", True)
+         return super().from_pretrained(*args, **kwargs)
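
Note (not part of the diff): the `from_pretrained` override above defaults `trust_remote_code=True` because the EXAONE checkpoint relies on custom modeling code from the Hugging Face Hub. A hypothetical usage sketch, assuming the class is re-exported from `optimum.rbln` and accepts the usual `export` / `rbln_max_seq_len` compile-time keywords (both are assumptions, not confirmed by this excerpt):

    # Hypothetical sketch; the model id and compile kwargs are assumptions.
    from optimum.rbln import RBLNExaoneForCausalLM

    model = RBLNExaoneForCausalLM.from_pretrained(
        "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",  # assumed checkpoint id
        export=True,                             # assumed: compile the HF checkpoint for RBLN
        rbln_max_seq_len=4096,                   # assumed: feeds rbln_config.model_cfg["max_seq_len"]
    )
    # trust_remote_code=True is injected automatically by the override via kwargs.setdefault(...).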
optimum/rbln/transformers/models/gemma/gemma_architecture.py
@@ -29,11 +29,11 @@ from transformers.modeling_outputs import (
  )
 
  from ...models.decoderonly import (
-     DecoderOnlyAttention,
      DecoderOnlyDecoderLayer,
      DecoderOnlyWrapper,
      slice_and_unsqueeze_cos_sin,
  )
+ from ...models.decoderonly.decoderonly_architecture import DECODERONLY_ATTENTION_CLASSES
 
 
  class GemmaWrapper(DecoderOnlyWrapper):
@@ -43,7 +43,7 @@ class GemmaWrapper(DecoderOnlyWrapper):
              {
                  "wrapper": GemmaModel.forward,
                  "model": DecoderOnlyDecoderLayer.forward,
-                 "decoder_layer": DecoderOnlyAttention.forward,
+                 "decoder_layer": DECODERONLY_ATTENTION_CLASSES[self.attn_implementation].forward,
              }
          )
          return forward_dict
@@ -61,9 +61,17 @@ class GemmaModel:
          use_cache: Optional[bool] = True,
          output_attentions: Optional[bool] = False,
          output_hidden_states: Optional[bool] = False,
+         cache_pos_for_partitions: Optional[torch.Tensor] = None,
+         kvcache_partition_size: Optional[torch.Tensor] = None,
          forward_dict: Optional[Dict[str, classmethod]] = None,
          rotary_pos_emb=None,
      ) -> Union[Tuple, BaseModelOutputWithPast]:
+         # retrieve input_ids and inputs_embeds
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError(
+                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+             )
+
          # embed positions
          inputs_embeds = self.embed_tokens(input_ids)
          hidden_states = inputs_embeds
@@ -96,6 +104,8 @@ class GemmaModel:
                  batch_ids=batch_ids,
                  cos=cos,
                  sin=sin,
+                 cache_pos_for_partitions=cache_pos_for_partitions,
+                 kvcache_partition_size=kvcache_partition_size,
                  forward_dict=forward_dict,
              )
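
The gemma_architecture.py hunks above switch the wrapper from a hard-coded `DecoderOnlyAttention.forward` to a lookup in `DECODERONLY_ATTENTION_CLASSES`, keyed by `self.attn_implementation`, and thread the new `cache_pos_for_partitions` / `kvcache_partition_size` arguments through the layers. A minimal, self-contained sketch of that registry pattern (the class names and keys below are illustrative placeholders, not optimum-rbln internals):

    # Illustrative only: a string-keyed registry of attention classes whose unbound
    # .forward is resolved once and later called with an existing module as `self`.
    class EagerAttention:
        def forward(self, hidden_states):
            return f"eager({hidden_states})"

    class PartitionedAttention:
        def forward(self, hidden_states):
            return f"partitioned({hidden_states})"

    ATTENTION_CLASSES = {"eager": EagerAttention, "partitioned": PartitionedAttention}

    attn_implementation = "partitioned"
    forward = ATTENTION_CLASSES[attn_implementation].forward  # unbound function, as in forward_dict
    print(forward(PartitionedAttention(), "hidden_states"))   # the wrapper passes the HF module here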
 
optimum/rbln/transformers/models/gemma/modeling_gemma.py
@@ -21,28 +21,18 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.
 
- import inspect
- import logging
- from typing import TYPE_CHECKING, Any, Callable
-
- from transformers import GemmaForCausalLM
-
+ from ....utils import logging
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
  from .gemma_architecture import GemmaWrapper
 
 
- if TYPE_CHECKING:
-     from transformers import PreTrainedModel
-
-     from ....modeling_config import RBLNConfig
-
- logger = logging.getLogger(__name__)
+ logger = logging.get_logger(__name__)
 
 
  class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      """
      The Gemma Model transformer with a language modeling head (linear layer) on top.
-     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
      A class to convert and run pre-trained transformers based GemmaForCausalLM model on RBLN devices.
      It implements the methods to convert a pre-trained transformers GemmaForCausalLM model into a RBLN transformer model by:
@@ -50,18 +40,4 @@ class RBLNGemmaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      - compiling the resulting graph using the RBLN compiler.
      """
 
-     @classmethod
-     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-         rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-         return GemmaWrapper(model, rbln_max_seq_len).eval()
-
-     def __getattr__(self, __name: str) -> Any:
-         def redirect(func):
-             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-         val = getattr(GemmaForCausalLM, __name)
-
-         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-             return redirect(val)
-
-         return val
+     _decoder_wrapper_cls = GemmaWrapper
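
The same refactor repeats for GPT2 and Llama below: the per-model `wrap_model_if_needed` / `__getattr__` boilerplate moves into the shared `RBLNDecoderOnlyModelForCausalLM` base class, and each subclass now only declares `_decoder_wrapper_cls`. A plausible sketch of how the base class could consume that attribute, mirroring the removed helper (this is an assumption about modeling_decoderonly.py, which is not shown in this excerpt):

    # Plausible sketch only; the real 0.1.13 base class may differ.
    class RBLNDecoderOnlyModelForCausalLM:
        _decoder_wrapper_cls = None  # e.g. GemmaWrapper, GPT2LMHeadModelWrapper, LlamaWrapper
        _original_cls = None         # optional custom HF class, e.g. ExaoneForCausalLM

        @classmethod
        def wrap_model_if_needed(cls, model, rbln_config):
            # Same body as the removed per-model helpers, written once for all decoder-only models.
            max_seq_len = rbln_config.model_cfg["max_seq_len"]
            return cls._decoder_wrapper_cls(model, max_seq_len).eval()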
optimum/rbln/transformers/models/gpt2/modeling_gpt2.py
@@ -21,20 +21,12 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.
 
- import inspect
- import logging
- from typing import TYPE_CHECKING, Any, Callable
-
- from transformers import GPT2LMHeadModel
-
- from ....modeling_config import RBLNConfig
+ from ....utils import logging
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
  from .gpt2_architecture import GPT2LMHeadModelWrapper
 
 
- logger = logging.getLogger(__name__)
- if TYPE_CHECKING:
-     from transformers import PreTrainedModel
+ logger = logging.get_logger(__name__)
 
 
  class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
@@ -42,7 +34,7 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
      The GPT2 Model transformer with a language modeling head on top (linear layer with weights tied to the input
      embeddings).
 
-     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the
+     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the
      library implements for all its model.
 
      It implements the methods to convert a pre-trained transformers GPT2 model into a RBLN transformer model by:
@@ -51,22 +43,4 @@ class RBLNGPT2LMHeadModel(RBLNDecoderOnlyModelForCausalLM):
 
      """
 
-     @classmethod
-     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-         rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-         return GPT2LMHeadModelWrapper(model, rbln_max_seq_len).eval()
-
-     def __getattr__(self, __name: str) -> Any:
-         """This is the key method to implement RBLN-GPT2.
-
-         Returns:
-             Any: GPT2's corresponding method
-         """
-
-         def redirect(func):
-             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-         val = getattr(GPT2LMHeadModel, __name)
-         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-             return redirect(val)
-         return val
+     _decoder_wrapper_cls = GPT2LMHeadModelWrapper
optimum/rbln/transformers/models/llama/modeling_llama.py
@@ -21,28 +21,18 @@
  # copied, modified, or distributed without prior written permission
  # from Rebellions Inc.
 
- import inspect
- import logging
- from typing import TYPE_CHECKING, Any, Callable
-
- from transformers import LlamaForCausalLM
-
+ from ....utils import logging
  from ...models.decoderonly import RBLNDecoderOnlyModelForCausalLM
  from .llama_architecture import LlamaWrapper
 
 
- if TYPE_CHECKING:
-     from transformers import PreTrainedModel
-
-     from ....modeling_config import RBLNConfig
-
- logger = logging.getLogger(__name__)
+ logger = logging.get_logger(__name__)
 
 
  class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      """
      The Llama Model transformer with a language modeling head (linear layer) on top.
-     This model inherits from [`RBLNMultiModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
 
      A class to convert and run pre-trained transformers based LlamaForCausalLM model on RBLN devices.
      It implements the methods to convert a pre-trained transformers LlamaForCausalLM model into a RBLN transformer model by:
@@ -50,18 +40,4 @@ class RBLNLlamaForCausalLM(RBLNDecoderOnlyModelForCausalLM):
      - compiling the resulting graph using the RBLN compiler.
      """
 
-     @classmethod
-     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: "RBLNConfig"):
-         rbln_max_seq_len = rbln_config.model_cfg["max_seq_len"]
-         return LlamaWrapper(model, rbln_max_seq_len).eval()
-
-     def __getattr__(self, __name: str) -> Any:
-         def redirect(func):
-             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
-
-         val = getattr(LlamaForCausalLM, __name)
-
-         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
-             return redirect(val)
-
-         return val
+     _decoder_wrapper_cls = LlamaWrapper
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py
@@ -114,7 +114,7 @@ class LoopProjector:
          return self.forward(*args, **kwds)
 
      def __repr__(self) -> str:
-         return repr(self.vision_tower)
+         return repr(self.multi_modal_projector)
 
 
  class RBLNLlavaNextForConditionalGeneration(RBLNModel):
@@ -228,29 +228,26 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
          pixel_values=None,
          image_sizes=None,
          attention_mask=None,
-         past_cached_length=None,
+         generate_idx=None,
          **kwargs,
      ):
          # Prepare HF generation
-         is_prefill_phase = past_cached_length is None
+         is_prefill_phase = generate_idx is None
          batch_size = input_ids.shape[0]
 
          model_inputs = self.language_model.prepare_inputs_for_generation(
              input_ids=input_ids,
              inputs_embeds=inputs_embeds,
-             past_cached_length=past_cached_length, # Not affect
+             generate_idx=generate_idx, # Not affect
              attention_mask=attention_mask,
              **kwargs,
          )
 
          if is_prefill_phase:
-             model_inputs["past_cached_length"] = torch.zeros((batch_size, 1), dtype=torch.int32)
-         else:
-             model_inputs["past_cached_length"] = past_cached_length + 1
+             model_inputs["generate_idx"] = torch.zeros((batch_size, 1), dtype=torch.int32)
 
          model_inputs.update(
              {
-                 # "position_ids": position_ids or cache_positions,
                  "pixel_values": pixel_values,
                  "image_sizes": image_sizes,
                  "attention_mask": attention_mask,
@@ -264,43 +261,28 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
          model_kwargs: Dict[str, Any],
          **kwargs,
      ) -> Dict[str, Any]:
-         # update past_cached_length
-         model_kwargs["past_cached_length"] = outputs.past_cached_length
+         # update generate_idx
+         model_kwargs["generate_idx"] = outputs.generate_idx
 
          return model_kwargs
 
-     def _merge_vllm_multimodal_embeddings(
+     def text_embedding(
          self,
-         input_ids: torch.Tensor,
-         inputs_embeds: torch.Tensor,
-         multimodal_embeddings: torch.Tensor,
-         placeholder_token_id: int,
+         input_ids: torch.LongTensor,
      ) -> torch.Tensor:
-         mask = input_ids == placeholder_token_id
-         num_expected_tokens = mask.sum().item()
-         assert isinstance(num_expected_tokens, int)
-
-         if multimodal_embeddings.shape[0] != num_expected_tokens:
-             raise ValueError(
-                 f"Attempted to assign {inputs_embeds[mask].shape} = {multimodal_embeddings.shape} "
-                 f"multimodal tokens to {num_expected_tokens} placeholders"
-             )
+         for_inputs_embeds_ids = input_ids.clone()
+         for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
+         inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
 
-         inputs_embeds[mask] = multimodal_embeddings
          return inputs_embeds
 
-     def _embed(
+     def image_embedding(
          self,
-         input_ids: torch.LongTensor,
          image_sizes: torch.LongTensor,
-         attention_mask: torch.Tensor,
         pixel_values: torch.FloatTensor,
          vision_feature_layer: int,
          vision_feature_select_strategy: str,
-         cache_position: torch.Tensor,
-         past_cached_length: torch.Tensor,
-         from_vllm_prefill: bool = False,
-     ) -> List[torch.Tensor]:
+     ) -> torch.Tensor:
          vision_feature_layer = (
              vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
          )
@@ -310,84 +292,137 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
              else self.config.vision_feature_select_strategy
          )
 
-         # 1. Extract the input embeddings
-         # In case image_token_index is not in the embeddings (extra token but embedding don't have it)
-         for_inputs_embeds_ids = input_ids.clone()
-         for_inputs_embeds_ids[(input_ids == self.config.image_token_index)] = 0
+         # ! infer image_num_patches from image_sizes
+         image_num_patches = [
+             image_size_to_num_patches(
+                 image_size=imsize,
+                 grid_pinpoints=self.config.image_grid_pinpoints,
+                 patch_size=self.config.vision_config.image_size,
+             )
+             for imsize in image_sizes
+         ]
 
-         inputs_embeds = self.get_input_embeddings()(for_inputs_embeds_ids)
+         # figure out if pixel_values is concatenated or stacked
+         if pixel_values.dim() == 5:
+             # stacking when input is (batch_size, num_patches, num_channels, height, width)
+             _pixel_values_list = [pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)]
+             pixel_values = torch.cat(_pixel_values_list, dim=0)
+         elif pixel_values.dim() != 4:
+             # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
+             raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
 
-         # 2. Merge text and images
-         if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
-             # ! infer image_num_patches from image_sizes
-             image_num_patches = [
-                 image_size_to_num_patches(
-                     image_size=imsize,
-                     grid_pinpoints=self.config.image_grid_pinpoints,
-                     patch_size=self.config.vision_config.image_size,
-                 )
-                 for imsize in image_sizes
-             ]
-             # figure out if pixel_values is concatenated or stacked
-             if pixel_values.dim() == 5:
-                 # stacking when input is (batch_size, num_patches, num_channels, height, width)
-                 _pixel_values_list = [
-                     pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
-                 ]
-                 pixel_values = torch.cat(_pixel_values_list, dim=0)
-             elif pixel_values.dim() != 4:
-                 # otherwise has to be stacked from list of (num_patches, num_channels, height, width)
-                 raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")
-
-             image_features = self.vision_tower(pixel_values, output_hidden_states=True)
-             selected_image_feature = image_features.hidden_states[vision_feature_layer]
-
-             if vision_feature_select_strategy == "default":
-                 selected_image_feature = selected_image_feature[:, 1:]
-             elif vision_feature_select_strategy == "full":
-                 selected_image_feature = selected_image_feature
-
-             image_features = self.multi_modal_projector(selected_image_feature)
-             image_features = torch.split(image_features, image_num_patches, dim=0)
-
-             # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
-             image_features, feature_lens = self.pack_image_features(
-                 image_features,
-                 image_sizes,
-                 image_newline=self.image_newline,
-             )
+         image_features = self.vision_tower(pixel_values, output_hidden_states=True)
+         selected_image_feature = image_features.hidden_states[vision_feature_layer]
 
-             inputs_embeds = inputs_embeds.to(image_features.dtype)
+         if vision_feature_select_strategy == "default":
+             selected_image_feature = selected_image_feature[:, 1:]
+         elif vision_feature_select_strategy == "full":
+             selected_image_feature = selected_image_feature
+
+         image_features = self.multi_modal_projector(selected_image_feature)
+         image_features = torch.split(image_features, image_num_patches, dim=0)
+
+         # NOTE we only support multimodal_patch_merge_type == "spatial_unpad"
+         image_features, feature_lens = self.pack_image_features(
+             image_features,
+             image_sizes,
+             image_newline=self.image_newline,
+         )
 
-         if from_vllm_prefill:
-             self._merge_vllm_multimodal_embeddings(
-                 input_ids, inputs_embeds, image_features, self.config.image_token_index
-             )
+         return image_features, feature_lens
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: torch.LongTensor = None,
+         pixel_values: torch.FloatTensor = None,
+         image_sizes: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         vision_feature_layer: Optional[int] = None,
+         vision_feature_select_strategy: Optional[str] = None,
+         cache_position: torch.Tensor = None,
+         generate_idx: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
+         if inputs_embeds is not None:
+             raise NotImplementedError("Specifying inputs_embeds is not supported.")
+
+         is_prefill_phase = not generate_idx.bool().all()
+
+         if is_prefill_phase:
+             # if the number of image tokens is more than image embeddings seq length, then prob we expanded it in processing
+             # not very reliable, but we don't expect one to actually pass 500+ images for one prompt
+             # In case we're in decoding stage, legacy behavior is checked by presence of pixel values even if use_cache=True
+             legacy_processing = (
+                 (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
+             ) or (input_ids.shape[-1] == 1 and pixel_values is not None)
+
+             # Get the number of images in the prompt
+             special_image_token_masks = [input_id == self.config.image_token_index for input_id in input_ids]
+             if legacy_processing:
+                 num_special_image_tokens = [torch.sum(mask, dim=-1) for mask in special_image_token_masks]
              else:
-                 inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
-                     image_features,
-                     feature_lens,
-                     inputs_embeds,
-                     input_ids,
-                     attention_mask,
+                 image_tokens_masks_diff = [
+                     torch.diff(mask, prepend=torch.tensor([0])) for mask in special_image_token_masks
+                 ]
+                 num_special_image_tokens = [int(torch.sum((diff == 1).int())) for diff in image_tokens_masks_diff]
+
+             # Split images for each prompt
+             if pixel_values is not None and pixel_values.size(0) > 0:
+                 pixel_values = pixel_values.split(num_special_image_tokens, dim=0)
+                 image_sizes = image_sizes.split(num_special_image_tokens, dim=0)
+
+             logits = []
+             for b_idx in range(input_ids.shape[0]):
+                 # Get text_embeds from input_id
+                 input_id = input_ids[b_idx : b_idx + 1, attention_mask[b_idx].bool()]
+                 inputs_embed = self.text_embedding(input_id)
+
+                 # If any images in the prompt, get image_embeds and merge with text
+                 if num_special_image_tokens[b_idx] > 0:
+                     image_features, feature_lens = self.image_embedding(
+                         image_sizes[b_idx], pixel_values[b_idx], vision_feature_layer, vision_feature_select_strategy
+                     )
+                     if legacy_processing:
+                         inputs_embed, _, _, _, _ = self._merge_input_ids_with_image_features(
+                             image_features,
+                             feature_lens,
+                             inputs_embed.to(image_features.dtype),
+                             input_id,
+                             torch.ones_like(input_id, dtype=torch.long),
+                         )
+                     else:
+                         special_image_mask = (
+                             (input_id == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embed)
+                         )
+                         inputs_embed = inputs_embed.masked_scatter(special_image_mask, image_features)
+
+                 # Update generate_idx according to inputs_embed
+                 generate_idx[b_idx] = inputs_embed.shape[1]
+
+                 logit = self.language_model._forward_prefill(
+                     inputs_embeds=inputs_embed,
+                     batch_idx=b_idx,
+                     cache_position=torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0),
                  )
 
-             cache_position = torch.arange(0, inputs_embeds.shape[1], dtype=torch.int32).unsqueeze_(0)
+                 logits.append(logit)
 
-         # pixel_values is not None but is empty ---> text only cases
-         elif (
-             pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) == 0 or pixel_values is None
-         ):
-             pass
+             logits = torch.cat(logits, dim=0)
+             outputs = RBLNDecoderOnlyOutput(logits=logits, generate_idx=generate_idx)
+
+         else:
+             inputs_embeds = self.text_embedding(input_ids)
 
-         # In case input_ids.shape[1] == 1 & pixel_values==None & past_key_values != None, we are in the case of
-         # generation with cache
-         elif pixel_values is not None and input_ids.shape[1] == 1 and past_cached_length is not None:
-             cache_position = past_cached_length
+             outputs: RBLNDecoderOnlyOutput = self.language_model(
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+                 generate_idx=generate_idx,
+             )
 
-         return inputs_embeds, cache_position
+         return outputs
 
-     def forward(
+     def vllm_forward(
          self,
          input_ids: torch.LongTensor = None,
         pixel_values: torch.FloatTensor = None,
@@ -397,72 +432,52 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel):
          vision_feature_select_strategy: Optional[str] = None,
          cache_position: Union[List[torch.Tensor], torch.Tensor] = None, # vllm keyword argument
          batch_idx: Optional[int] = None,
-         past_cached_length: Optional[torch.Tensor] = None,
          **kwargs,
      ) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
-         from_vllm_prefill = isinstance(cache_position, torch.Tensor) and cache_position.shape[-1] > 1
-         from_hf_generate_prefill = isinstance(input_ids, list)
+         is_prefill = cache_position.shape[-1] > 1
 
          if inputs_embeds is not None:
             raise NotImplementedError("Specifying inputs_embeds is not supported.")
 
-         if from_hf_generate_prefill:
-             inputs_embeds = []
-             batch_size = len(input_ids)
+         if is_prefill:
+             # Get text_embeds
+             inputs_embeds = self.text_embedding(input_ids)
 
-             # Get the number of images in the prompt
-             special_image_token_masks = [input_id == self.config.image_token_index for input_id in input_ids]
-             num_special_image_tokens = [torch.sum(mask, dim=-1) for mask in special_image_token_masks]
+             # If any images in the prompt, get image_embeds and merge with text
+             if pixel_values is not None and input_ids.shape[1] != 1 and pixel_values.size(0) > 0:
+                 image_features, _ = self.image_embedding(
+                     image_sizes, pixel_values, vision_feature_layer, vision_feature_select_strategy
+                 )
 
-             # Split images for each prompt
-             pixel_values = pixel_values.split(num_special_image_tokens, dim=0)
-             image_sizes = image_sizes.split(num_special_image_tokens, dim=0)
-
-             for b_idx in range(batch_size):
-                 embed, cache_pos = self._embed(
-                     input_ids=input_ids[b_idx],
-                     image_sizes=image_sizes[b_idx] if image_sizes is not None else None,
-                     attention_mask=torch.ones_like(input_ids[b_idx]),
-                     pixel_values=pixel_values[b_idx] if pixel_values is not None else None,
-                     vision_feature_layer=vision_feature_layer,
-                     vision_feature_select_strategy=vision_feature_select_strategy,
-                     cache_position=cache_position[b_idx],
-                     past_cached_length=past_cached_length[b_idx : b_idx + 1],
+                 def merge_vllm_multimodal_embeddings(
+                     input_ids: torch.Tensor,
+                     inputs_embeds: torch.Tensor,
+                     multimodal_embeddings: torch.Tensor,
+                     placeholder_token_id: int,
+                 ) -> torch.Tensor:
+                     mask = input_ids == placeholder_token_id
+                     num_expected_tokens = mask.sum().item()
+
+                     if multimodal_embeddings.shape[0] != num_expected_tokens:
+                         raise ValueError(
+                             f"Attempted to assign {inputs_embeds[mask].shape} = {multimodal_embeddings.shape} "
+                             f"multimodal tokens to {num_expected_tokens} placeholders"
+                         )
+
+                     inputs_embeds[mask] = multimodal_embeddings
+                     return inputs_embeds
+
+                 inputs_embeds = merge_vllm_multimodal_embeddings(
+                     input_ids, inputs_embeds, image_features, self.config.image_token_index
                  )
-                 inputs_embeds.append(embed)
-                 cache_position[b_idx] = cache_pos
-                 past_cached_length[b_idx] += embed.shape[1]
-
-         elif from_vllm_prefill:
-             inputs_embeds, cache_position = self._embed(
-                 input_ids=input_ids,
-                 image_sizes=image_sizes,
-                 attention_mask=torch.ones_like(input_ids),
-                 pixel_values=pixel_values,
-                 vision_feature_layer=vision_feature_layer,
-                 vision_feature_select_strategy=vision_feature_select_strategy,
-                 cache_position=cache_position,
-                 past_cached_length=past_cached_length,
-                 from_vllm_prefill=from_vllm_prefill,
-             )
+
          else:
-             # Decoding step
-             inputs_embeds, cache_position = self._embed(
-                 input_ids=input_ids,
-                 image_sizes=image_sizes,
-                 attention_mask=torch.ones_like(input_ids),
-                 pixel_values=pixel_values,
-                 vision_feature_layer=vision_feature_layer,
-                 vision_feature_select_strategy=vision_feature_select_strategy,
-                 cache_position=cache_position,
-                 past_cached_length=past_cached_length,
-             )
+             inputs_embeds = self.text_embedding(input_ids=input_ids)
 
-         outputs: RBLNDecoderOnlyOutput = self.language_model(
+         outputs: RBLNDecoderOnlyOutput = self.language_model.vllm_forward(
              inputs_embeds=inputs_embeds,
             batch_idx=batch_idx,
             cache_position=cache_position,
-             past_cached_length=past_cached_length,
          )
 
          return outputs
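
The non-legacy prefill path above merges the projected image features into the text embeddings with `masked_scatter` over a broadcast image-token mask. A self-contained toy illustration of that merge (the placeholder id and shapes are made up, not LLaVA-Next's real values):

    import torch

    image_token_index = 32000                        # assumed placeholder id
    input_id = torch.tensor([[1, 32000, 32000, 2]])  # (1, seq_len) with two image slots
    inputs_embed = torch.zeros(1, 4, 8)              # (1, seq_len, hidden) text embeddings
    image_features = torch.ones(2, 8)                # (num_image_tokens, hidden) from the projector

    # Broadcast the token mask over the hidden dimension, then scatter the image rows
    # into the masked positions, as the new forward() does for each batch element.
    special_image_mask = (input_id == image_token_index).unsqueeze(-1).expand_as(inputs_embed)
    merged = inputs_embed.masked_scatter(special_image_mask, image_features)

    assert torch.equal(merged[0, 1], torch.ones(8))   # image slot filled
    assert torch.equal(merged[0, 0], torch.zeros(8))  # text slot untouched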