optimum-rbln 0.8.4a8__py3-none-any.whl → 0.9.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (64)
  1. optimum/rbln/__init__.py +8 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +63 -32
  5. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +30 -14
  6. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +11 -8
  7. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +23 -13
  8. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +10 -6
  9. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +14 -10
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +14 -7
  11. optimum/rbln/diffusers/modeling_diffusers.py +5 -7
  12. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +9 -11
  13. optimum/rbln/modeling.py +50 -0
  14. optimum/rbln/modeling_base.py +1 -2
  15. optimum/rbln/transformers/__init__.py +8 -0
  16. optimum/rbln/transformers/modeling_generic.py +37 -1
  17. optimum/rbln/transformers/models/__init__.py +9 -0
  18. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +35 -3
  19. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +86 -23
  20. optimum/rbln/transformers/models/clip/modeling_clip.py +4 -0
  21. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  22. optimum/rbln/transformers/models/colpali/configuration_colpali.py +34 -18
  23. optimum/rbln/transformers/models/colpali/modeling_colpali.py +73 -80
  24. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  25. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  26. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  27. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  28. optimum/rbln/transformers/models/decoderonly/__init__.py +1 -0
  29. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +34 -0
  30. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  31. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +100 -20
  32. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +50 -2
  33. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  34. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +65 -3
  35. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +11 -3
  36. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  37. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +31 -3
  38. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +67 -44
  39. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +4 -1
  40. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +27 -3
  41. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +24 -19
  42. optimum/rbln/transformers/models/llava/configuration_llava.py +16 -2
  43. optimum/rbln/transformers/models/llava/modeling_llava.py +108 -50
  44. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +11 -13
  45. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -343
  46. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  47. optimum/rbln/transformers/models/phi/phi_architecture.py +5 -1
  48. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +6 -11
  49. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +9 -8
  50. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +24 -0
  51. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +11 -1
  52. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +24 -0
  53. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +3 -1
  54. optimum/rbln/transformers/models/siglip/modeling_siglip.py +3 -14
  55. optimum/rbln/transformers/models/whisper/generation_whisper.py +28 -6
  56. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -1
  57. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  58. optimum/rbln/utils/runtime_utils.py +25 -15
  59. optimum/rbln/utils/submodule.py +21 -5
  60. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/METADATA +7 -6
  61. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/RECORD +64 -55
  62. optimum_rbln-0.9.2.dist-info/entry_points.txt +2 -0
  63. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/WHEEL +0 -0
  64. {optimum_rbln-0.8.4a8.dist-info → optimum_rbln-0.9.2.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/gemma3/modeling_gemma3.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import inspect
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
@@ -25,7 +26,9 @@ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbed
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ...utils.rbln_runtime_wrapper import LoopProcessor
 from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
 from ..decoderonly.modeling_decoderonly import (
     RBLNDecoderOnlyModelForCausalLM,
 )
@@ -38,61 +41,44 @@ if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration


-class LoopVisionTower:
-    def __init__(self, vision_tower: RBLNModel) -> None:
-        self.vision_tower = vision_tower
+class LoopVisionTower(LoopProcessor):
+    def __init__(self, vision_tower: "RBLNModel"):
+        super().__init__(model=vision_tower)

-    def forward(self, *args, **kwargs):
-        # Loop instead of batch
-        # shape of pixel_values : [batch, num_channel, height, width]
-        pixel_values = args[0]
+    def _get_batch_size(self, pixel_values, **kwargs):
+        return pixel_values.shape[0]

-        batch_size = pixel_values.shape[0]
-        outputs = []
-        for i in range(batch_size):
-            outputs.append(self.vision_tower(pixel_values=pixel_values[i : i + 1], return_dict=True))
+    def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+        pixel_values_item = pixel_values[index : index + 1]
+        out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+        return ([pixel_values_item], {"out": out_buffer})

-        last_hidden_states = [output.last_hidden_state for output in outputs]
-
-        # FIXME:: This can be optimized using out= API of rbln runtime.
-        last_hidden_states = torch.cat(last_hidden_states, dim=0)
+    def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
+        output = kwargs["out"]

         return BaseModelOutputWithPooling(
-            last_hidden_state=last_hidden_states,
+            last_hidden_state=output[0],
         )

-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
-
-    def __repr__(self) -> str:
-        return repr(self.vision_tower)
-

-class LoopProjector:
-    def __init__(self, multi_modal_projector) -> None:
-        self.multi_modal_projector = multi_modal_projector
+class LoopProjector(LoopProcessor):
+    def __init__(self, multi_modal_projector: "RBLNModel"):
+        super().__init__(model=multi_modal_projector)

-    def forward(self, *args, **kwargs):
-        # Loop instead of batch
-        image_feature = args[0]
+    def _get_batch_size(self, image_feature, **kwargs):
+        return image_feature.shape[0]

-        batch_size = image_feature.shape[0]
-        outputs = []
-        for i in range(batch_size):
-            outputs.append(self.multi_modal_projector(image_feature[i : i + 1]))
+    def _prepare_inputs_for_iteration(self, index, common_inputs, image_feature, **kwargs):
+        image_feature_item = image_feature[index : index + 1]
+        out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+        return ([image_feature_item], {"out": out_buffer})

-        # FIXME:: This can be optimized using out= API of rbln runtime.
-        outputs = torch.cat(outputs, dim=0)
-        return outputs
+    def _process_outputs(self, outputs: list, **kwargs):
+        output = kwargs["out"]
+        return output[0]

-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)

-    def __repr__(self) -> str:
-        return repr(self.multi_modal_projector)
-
-
-class RBLNGemma3ForConditionalGeneration(RBLNModel):
+class RBLNGemma3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     auto_model_class = AutoModelForImageTextToText
     _rbln_submodules = [
         {"name": "vision_tower"},
@@ -112,6 +98,23 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
     def can_generate(self):
         return True

+    @classmethod
+    def get_pytorch_model(cls, *args, **kwargs):
+        model = super().get_pytorch_model(*args, **kwargs)
+
+        with no_init_weights():
+            model_cls_name = model.model.language_model.__class__.__name__
+            causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")
+            causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
+            new_language_model = causal_model_cls(model.model.language_model.config)
+
+        new_language_model.lm_head = model.lm_head
+        new_language_model.model = model.model.language_model
+        model.model.language_model = new_language_model
+        model.lm_head = None
+        del model.lm_head
+        return model
+
     def __post_init__(self, **kwargs):
         self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
         self.language_model = self.rbln_submodules[1]
@@ -210,8 +213,21 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         # Returns:
         # Image feature tensor of shape `(num_images, image_length, embed_dim)`.

-        vision_outputs = self.vision_tower(pixel_values).last_hidden_state
-        image_features = self.multi_modal_projector(vision_outputs)
+        vision_out_buffer = []
+        vision_out_size = [
+            pixel_values.shape[0],
+            (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2,
+            self.config.vision_config.hidden_size,
+        ]
+        projector_out_size = [
+            pixel_values.shape[0],
+            self.config.mm_tokens_per_image,
+            self.config.text_config.hidden_size,
+        ]
+        vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
+        projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+        vision_outputs = self.vision_tower(pixel_values, out=vision_out_buffer).last_hidden_state
+        image_features = self.multi_modal_projector(vision_outputs, out=projector_out_buffer)
         return image_features

     def _preprocess_prefill(
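This resolves the old "FIXME:: This can be optimized using out= API of rbln runtime" comment: the caller now pre-allocates CPU buffers with the exact output shapes and the runtime writes each per-image result in place, so no torch.cat pass is needed afterwards. A rough shape check with illustrative numbers (the real values come from the loaded model's config; the figures below merely assume a Gemma 3-style SigLIP vision tower):

# Assumed example values only; at runtime they are read from self.config.
image_size, patch_size, vision_hidden = 896, 14, 1152
mm_tokens_per_image, text_hidden = 256, 2560
num_images = 2

vision_out_size = [num_images, (image_size // patch_size) ** 2, vision_hidden]   # [2, 4096, 1152]
projector_out_size = [num_images, mm_tokens_per_image, text_hidden]              # [2, 256, 2560]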
@@ -393,6 +409,13 @@ class RBLNGemma3ForCausalLM(RBLNDecoderOnlyModelForCausalLM):
     def _update_sliding_window_config(cls, model_config: PretrainedConfig, rbln_config: RBLNGemma3ForCausalLMConfig):
         sliding_window = getattr(model_config, "sliding_window", None)
         sliding_window_pattern = getattr(model_config, "sliding_window_pattern", None)
+        if sliding_window_pattern is None:
+            if hasattr(model_config, "layer_types"):
+                first_full_attention_index = model_config.layer_types.index("full_attention")
+                sliding_window_pattern = first_full_attention_index + 1
+            else:
+                raise ValueError("Cannot determine sliding_window_pattern from model_config")
+
         if sliding_window_pattern <= model_config.num_hidden_layers:
             rbln_config.cache_impl = "hybrid"
             rbln_config.sliding_window = sliding_window
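The fallback covers configs that publish a per-layer layer_types list rather than a sliding_window_pattern attribute (as newer transformers releases do for Gemma 3); the pattern is recovered as the index of the first full-attention layer plus one. For example:

# Illustrative only: a Gemma 3-style schedule with one full-attention layer
# after every five sliding-window layers, repeated over an assumed 24-layer depth.
layer_types = (["sliding_attention"] * 5 + ["full_attention"]) * 4
sliding_window_pattern = layer_types.index("full_attention") + 1   # -> 6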
optimum/rbln/transformers/models/gpt2/gpt2_architecture.py

@@ -75,7 +75,10 @@ class GPT2Attention(DecoderOnlyAttention):
         self.o_proj = self._original_mod.c_proj
         self.split_size = self._original_mod.split_size

-    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        if lora_int_id is not None:
+            raise NotImplementedError("LoRA is not supported for GPT2Attention")
+
         query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
         return query_states, key_states, value_states
optimum/rbln/transformers/models/idefics3/configuration_idefics3.py

@@ -15,10 +15,29 @@
 from typing import Any, Optional

 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)


 class RBLNIdefics3VisionTransformerConfig(RBLNModelConfig):
-    pass
+    """
+    Configuration class for RBLNIdefics3VisionTransformer.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized Idefics3 vision transformer.
+    """
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")


 class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
@@ -61,5 +80,10 @@ class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

-        self.vision_model = vision_model
-        self.text_model = text_model
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for Idefics3 vision transformer. It will be set to 1.")
+
+        self.vision_model = self.initialize_submodule_config(
+            submodule_config=vision_model, batch_size=1, force_kwargs=True
+        )
+        self.text_model = self.initialize_submodule_config(submodule_config=text_model)
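The submodule configs are now built through the new initialize_submodule_config helper on RBLNModelConfig (part of the configuration_utils.py changes, not shown in this excerpt), which lets the parent config force fields such as the vision batch size. Expected behaviour for a plain construction, assuming the config class remains importable from the package root as in earlier releases:

# Illustrative usage; names come from the diff, the import path is assumed.
from optimum.rbln import RBLNIdefics3ForConditionalGenerationConfig

cfg = RBLNIdefics3ForConditionalGenerationConfig(batch_size=2)
# expected log: "Ignore batch_size for Idefics3 vision transformer. It will be set to 1."
assert cfg.vision_model.batch_size == 1   # forced via force_kwargs=True
assert cfg.batch_size == 2                # the top-level batch size itself is kept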
optimum/rbln/transformers/models/idefics3/modeling_idefics3.py

@@ -35,6 +35,7 @@ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ....utils.runtime_utils import RBLNPytorchRuntime
 from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin


 if TYPE_CHECKING:
@@ -75,10 +76,12 @@ class RBLNRuntimeVisionModel(RBLNPytorchRuntime):

         hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)

-        return super().forward(hidden_states.contiguous())
+        return super().forward(hidden_states.contiguous(), **kwargs)


 class RBLNIdefics3VisionTransformer(RBLNModel):
+    _tp_support = False
+
     def __post_init__(self, **kwargs):
         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
         with no_init_weights():
@@ -118,9 +121,6 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
         encoder_outputs = self.encoder(
             inputs_embeds=hidden_states,
             attention_mask=patch_attention_mask,
-            output_attentions=None,
-            output_hidden_states=None,
-            return_dict=False,
         )
         last_hidden_state = encoder_outputs[0]
         last_hidden_state = self.post_layernorm(last_hidden_state)
@@ -140,8 +140,7 @@
             (
                 "hidden_states",
                 [
-                    # batch_size * num_patches (dependent on image size) -> compile with 1 and use for loop
-                    1,
+                    rbln_config.batch_size,
                     (model_config.image_size // model_config.patch_size) ** 2,
                     model_config.hidden_size,
                 ],
@@ -160,29 +159,31 @@
         return_dict: Optional[bool] = None,
         **kwargs,
     ) -> Union[Tuple, BaseModelOutput]:
-        batch_size = pixel_values.shape[0]
-        last_hidden_state = []
-        for i in range(batch_size):
+        last_hidden_state_size = [
+            pixel_values.shape[0],
+            (self.config.image_size // self.config.patch_size) ** 2,
+            self.config.hidden_size,
+        ]
+        last_hidden_state = torch.empty(size=last_hidden_state_size, dtype=torch.float32, device="cpu")
+        for i in range(pixel_values.shape[0]):
             if patch_attention_mask is not None:
                 batch_attention_mask = patch_attention_mask[i : i + 1,]
             else:
                 batch_attention_mask = None

-            batch_hidden_state = self.model(
+            self.model(
                 pixel_values[i : i + 1,],
                 batch_attention_mask,
+                out=last_hidden_state[i : i + 1,],
                 return_dict=False,
             )
-            last_hidden_state.append(batch_hidden_state)
-        last_hidden_state = torch.cat(last_hidden_state, dim=0)
-
         if not return_dict:
             return (last_hidden_state,)
         else:
             return BaseModelOutput(last_hidden_state=last_hidden_state)


-class RBLNIdefics3ForConditionalGeneration(RBLNModel):
+class RBLNIdefics3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     """
     RBLNIdefics3ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
@@ -285,8 +286,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):
             (
                 "image_hidden_states",
                 [
-                    # batch_size * num_patches (dependent on image size) -> compile with 1 and use for loop
-                    1,
+                    rbln_config.vision_model.batch_size,
                     (model_config.vision_config.image_size // model_config.vision_config.patch_size) ** 2,
                     model_config.vision_config.hidden_size,
                 ],
@@ -425,10 +425,15 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):
                 pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, return_dict=True
             ).last_hidden_state

-            connector_outputs = []
+            connector_output_size = [
+                image_hidden_states.shape[0],
+                image_hidden_states.shape[1] // self.config.scale_factor**2,
+                self.config.text_config.hidden_size,
+            ]
+            connector_outputs = torch.empty(size=connector_output_size, dtype=torch.float32, device="cpu")
             for i in range(image_hidden_states.shape[0]):
-                connector_outputs.append(self.connector(image_hidden_states[i : i + 1,]))
-            image_hidden_states = torch.cat(connector_outputs, dim=0)
+                self.connector(image_hidden_states[i : i + 1,], out=connector_outputs[i : i + 1,])
+            image_hidden_states = connector_outputs

         elif image_hidden_states is not None:
             image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
optimum/rbln/transformers/models/llava/configuration_llava.py

@@ -15,6 +15,10 @@
 from typing import Any, Optional

 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)


 class RBLNLlavaForConditionalGenerationConfig(RBLNModelConfig):
@@ -54,5 +58,15 @@ class RBLNLlavaForConditionalGenerationConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

-        self.vision_tower = vision_tower
-        self.language_model = language_model
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for Llava vision tower. It will be set to 1.")
+
+        self.vision_tower = self.initialize_submodule_config(
+            submodule_config=vision_tower,
+            batch_size=1,  # vision_tower batch_size is always 1 in Llava
+            force_kwargs=True,
+        )
+
+        self.language_model = self.initialize_submodule_config(
+            submodule_config=language_model,
+        )
optimum/rbln/transformers/models/llava/modeling_llava.py

@@ -12,18 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import importlib
 import inspect
 from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union

 import torch
 from transformers import AutoModelForImageTextToText, LlavaForConditionalGeneration, PretrainedConfig, PreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.modeling_utils import no_init_weights
 from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast

 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ....utils.logging import get_logger
 from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ...utils.rbln_runtime_wrapper import LoopProcessor
+from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin


 logger = get_logger(__name__)
@@ -32,20 +36,32 @@ if TYPE_CHECKING:
     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig


-class LoopVisionTower:
-    def __init__(self, vision_tower: RBLNModel) -> None:
-        self.vision_tower = vision_tower
+class LoopVisionTower(LoopProcessor):
+    def __init__(self, vision_tower):
+        # FIXME: need to know RBLNModel or RuntimeWrapper
+        if hasattr(vision_tower.model, "runtime"):
+            super().__init__(model=vision_tower)
+        else:
+            super().__init__(model=vision_tower.model[0])

-    def forward(self, pixel_values, image_sizes: Optional[torch.Tensor] = None, **kwargs):
-        outputs = []
-        for i in range(pixel_values.shape[0]):
-            outputs.append(
-                self.vision_tower(
-                    pixel_values[i : i + 1], image_sizes[i : i + 1] if image_sizes is not None else None, **kwargs
-                )
-            )
+        self.rbln_config = vision_tower.rbln_config
+
+    def _get_batch_size(self, pixel_values, **kwargs):
+        return pixel_values.shape[0]
+
+    def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+        pixel_values_item = pixel_values[index : index + 1]
+        if "image_sizes" in kwargs and kwargs["image_sizes"] is not None:
+            ret_val = [pixel_values_item, kwargs["image_sizes"][index : index + 1]]
+        else:
+            ret_val = [pixel_values_item]
+
+        out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]] if "out" in kwargs else None
+        return (ret_val, {"out": out_buffer})

-        if hasattr(self.vision_tower.rbln_config, "max_image_size"):
+    def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
+        # when use another Wrapper
+        if hasattr(self.rbln_config, "max_image_size"):
             last_hidden_states = [output.last_hidden_state for output in outputs]
             last_hidden_states = torch.cat(last_hidden_states, dim=1)
             hidden_states = tuple(
@@ -55,52 +71,40 @@ class LoopVisionTower:
                 )
                 for layer_idx in range(len(outputs[0].hidden_states))
             )
-
         else:
-            last_hidden_states = [output.last_hidden_state for output in outputs]
-            last_hidden_states = torch.cat(last_hidden_states, dim=0)
-            hidden_states = [output.hidden_states for output in outputs]
-            hidden_states = tuple(
-                torch.cat(tuple((hidden_states[n][i] for n in range(pixel_values.shape[0]))), dim=0)
-                for i in range(len(hidden_states[0]))
-            )
+            output = kwargs["out"]
+            last_hidden_states = output[0]
+
+            if not output[2:]:
+                hidden_states = None
+            else:
+                hidden_states = tuple(output[2:])

         return BaseModelOutputWithPooling(
             last_hidden_state=last_hidden_states,
+            pooler_output=None,
             hidden_states=hidden_states,
         )

-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)

-    def __repr__(self) -> str:
-        return repr(self.vision_tower)
+class LoopProjector(LoopProcessor):
+    def __init__(self, multi_modal_projector: "RBLNModel"):
+        super().__init__(model=multi_modal_projector)

+    def _get_batch_size(self, image_feature, **kwargs):
+        return image_feature.shape[0]

-class LoopProjector:
-    def __init__(self, multi_modal_projector) -> None:
-        self.multi_modal_projector = multi_modal_projector
+    def _prepare_inputs_for_iteration(self, index, common_inputs, image_feature, **kwargs):
+        image_feature_item = image_feature[index : index + 1]
+        out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+        return ([image_feature_item], {"out": out_buffer})

-    def forward(self, *args, **kwargs):
-        # Loop instead of batch
-        image_feature = args[0]
+    def _process_outputs(self, outputs: list, **kwargs):
+        output = kwargs["out"]
+        return output[0]

-        outputs = []
-        for i in range(image_feature.shape[0]):
-            outputs.append(self.multi_modal_projector(image_feature[i : i + 1]))

-        # FIXME:: This can be optimized using out= API of rbln runtime.
-        outputs = torch.cat(outputs, dim=0)
-        return outputs
-
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
-
-    def __repr__(self) -> str:
-        return repr(self.multi_modal_projector)
-
-
-class RBLNLlavaForConditionalGeneration(RBLNModel):
+class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
     """
     RBLNLlavaForConditionalGeneration is a multi-modal model that combines vision and language processing capabilities,
     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
@@ -170,6 +174,23 @@ class RBLNLlavaForConditionalGeneration(RBLNModel):
     def can_generate(self):
         return True

+    @classmethod
+    def get_pytorch_model(cls, *args, **kwargs):
+        model = super().get_pytorch_model(*args, **kwargs)
+
+        with no_init_weights():
+            model_cls_name = model.model.language_model.__class__.__name__
+            causal_model_cls_name = model_cls_name.replace("Model", "ForCausalLM")
+            causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
+            new_language_model = causal_model_cls(model.model.language_model.config)
+
+        new_language_model.lm_head = model.lm_head
+        new_language_model.model = model.model.language_model
+        model.model.language_model = new_language_model
+        model.lm_head = None
+        del model.lm_head
+        return model
+
     def __post_init__(self, **kwargs):
         self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
         self.language_model = self.rbln_submodules[1]
@@ -201,7 +222,7 @@ class RBLNLlavaForConditionalGeneration(RBLNModel):
         # support for pixtral that needs padding
         if hasattr(rbln_config.vision_tower, "max_image_size"):
             num_positions = (
-                rbln_config.vision_tower.batch_size
+                rbln_config.batch_size
                 * (rbln_config.vision_tower.max_image_size[0] // model_config.vision_config.patch_size)
                 * (rbln_config.vision_tower.max_image_size[1] // model_config.vision_config.patch_size)
             )
@@ -217,7 +238,11 @@ class RBLNLlavaForConditionalGeneration(RBLNModel):
         input_info = [
             (
                 "image_features",
-                [rbln_config.batch_size, selected_image_feature_dim, model_config.vision_config.hidden_size],
+                [
+                    1,
+                    selected_image_feature_dim,
+                    model_config.vision_config.hidden_size,
+                ],
                 "float32",
             )
         ]
@@ -290,7 +315,31 @@ class RBLNLlavaForConditionalGeneration(RBLNModel):
             raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")

         kwargs = {k: v for k, v in kwargs.items() if v is not None}
-        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs)
+
+        # prepare out buffer for pre-allocation
+        if hasattr(self.rbln_config.vision_tower, "max_image_size"):
+            vision_out_size = [
+                pixel_values.shape[0],
+                (self.rbln_config.vision_tower.max_image_size[0] // self.config.vision_config.patch_size)
+                * (self.rbln_config.vision_tower.max_image_size[1] // self.config.vision_config.patch_size),
+                self.config.vision_config.hidden_size,
+            ]
+            pooler_out_size = None
+        else:
+            vision_out_size = [
+                pixel_values.shape[0],
+                (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2 + 1,
+                self.config.vision_config.hidden_size,
+            ]
+            pooler_out_size = [pixel_values.shape[0], self.config.vision_config.hidden_size]
+
+        vision_out_buffer = []
+        for i in range(self.config.vision_config.num_hidden_layers + 2):
+            vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
+        if pooler_out_size is not None:
+            vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=torch.float32, device="cpu"))
+
+        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True, out=vision_out_buffer, **kwargs)

         if isinstance(vision_feature_layer, int):
             selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
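The buffer list mirrors the index layout that LoopVisionTower._process_outputs reads back (output[0] is last_hidden_state, output[1] the pooler output, output[2:] the per-layer hidden states), so the non-pixtral path needs num_hidden_layers + 2 sequence-shaped buffers plus one pooler buffer. Sketch of the arithmetic with an assumed 24-layer CLIP-style tower (the actual value comes from self.config.vision_config.num_hidden_layers):

# Assumed example depth; only the arithmetic is taken from the diff above.
num_hidden_layers = 24
seq_buffers = num_hidden_layers + 2   # last_hidden_state + hidden_states (embeddings + 24 layers)
total_buffers = seq_buffers + 1       # plus pooler_output, inserted at index 1
# out[0] -> last_hidden_state, out[1] -> pooler_output, out[2:] -> hidden_states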
@@ -311,15 +360,24 @@ class RBLNLlavaForConditionalGeneration(RBLNModel):
             )
             num_padding_patches = max_patches - num_real_patches

+            projector_out_size = [1, max_patches, self.config.text_config.hidden_size]
+            projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+
             padding_tensor = torch.zeros(
                 (selected_image_feature.shape[0], num_padding_patches, selected_image_feature.shape[2]),
                 dtype=selected_image_feature.dtype,
             )
             padded_feature = torch.cat([selected_image_feature, padding_tensor], dim=1)
-            padded_projected_feature = self.multi_modal_projector(padded_feature)
+            padded_projected_feature = self.multi_modal_projector(padded_feature, out=projector_out_buffer)
             image_features = padded_projected_feature[:, :num_real_patches, :]
         else:
-            image_features = self.multi_modal_projector(selected_image_feature)
+            projector_out_size = [
+                pixel_values.shape[0] * pixel_values.shape[1],
+                (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2,
+                self.config.text_config.hidden_size,
+            ]
+            projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+            image_features = self.multi_modal_projector(selected_image_feature, out=projector_out_buffer)

         return image_features
optimum/rbln/transformers/models/llava_next/configuration_llava_next.py

@@ -16,7 +16,6 @@ from typing import Any, Optional

 from ....configuration_utils import RBLNModelConfig
 from ....utils.logging import get_logger
-from ...models.clip import RBLNCLIPVisionModelConfig


 logger = get_logger(__name__)
@@ -55,17 +54,16 @@ class RBLNLlavaNextForConditionalGenerationConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

-        self.vision_tower = self.init_submodule_config(
-            RBLNCLIPVisionModelConfig,
-            vision_tower,
-        )
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for LlavaNext vision tower. It will be set to 1.")

-        if self.vision_tower.output_hidden_states is False:
-            raise ValueError(
-                f"LlavaNext requires output_hidden_states to be True, but found output_hidden_states={self.vision_tower.output_hidden_states}. "
-                f"Please compile again with the correct argument."
-            )
-        else:
-            self.vision_tower.output_hidden_states = True
+        self.vision_tower = self.initialize_submodule_config(
+            submodule_config=vision_tower,
+            batch_size=1,  # vision_tower batch_size is always 1 in LlavaNext
+            output_hidden_states=True,  # LlavaNext requires output_hidden_states to be True
+            force_kwargs=True,
+        )

-        self.language_model = language_model
+        self.language_model = self.initialize_submodule_config(
+            submodule_config=language_model,
+        )
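Net effect for LlavaNext: instead of rejecting a config whose vision tower was prepared with output_hidden_states=False, the parent config now forces batch_size=1 and output_hidden_states=True on the vision tower itself (how force_kwargs reconciles a conflicting user value is defined in configuration_utils.py, which is not part of this excerpt). Expected behaviour for a plain construction, assuming the class remains importable from the package root as in earlier releases:

# Illustrative usage; names come from the diff, the import path is assumed.
from optimum.rbln import RBLNLlavaNextForConditionalGenerationConfig

cfg = RBLNLlavaNextForConditionalGenerationConfig(batch_size=2)
# expected log: "Ignore batch_size for LlavaNext vision tower. It will be set to 1."
assert cfg.vision_tower.batch_size == 1                 # forced for the vision tower only
assert cfg.vision_tower.output_hidden_states is True    # forced on instead of raising ValueError
assert cfg.batch_size == 2                              # top-level batch size is kept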