optimum-rbln 0.9.1__py3-none-any.whl → 0.9.2a0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.
Files changed (36)
  1. optimum/rbln/__version__.py +2 -2
  2. optimum/rbln/configuration_utils.py +54 -7
  3. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +30 -14
  4. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +11 -8
  5. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +23 -13
  6. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +10 -6
  7. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +14 -10
  8. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +14 -7
  9. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +9 -11
  10. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +35 -3
  11. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +21 -22
  12. optimum/rbln/transformers/models/clip/modeling_clip.py +4 -0
  13. optimum/rbln/transformers/models/colpali/colpali_architecture.py +2 -2
  14. optimum/rbln/transformers/models/colpali/configuration_colpali.py +17 -1
  15. optimum/rbln/transformers/models/colpali/modeling_colpali.py +72 -79
  16. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +2 -2
  17. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +11 -3
  18. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +58 -43
  19. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +27 -3
  20. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +22 -15
  21. optimum/rbln/transformers/models/llava/configuration_llava.py +16 -2
  22. optimum/rbln/transformers/models/llava/modeling_llava.py +106 -49
  23. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +11 -13
  24. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +232 -342
  25. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +6 -11
  26. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +11 -1
  27. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +22 -0
  28. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +11 -1
  29. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +22 -0
  30. optimum/rbln/transformers/models/siglip/modeling_siglip.py +3 -14
  31. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  32. optimum/rbln/utils/submodule.py +21 -5
  33. {optimum_rbln-0.9.1.dist-info → optimum_rbln-0.9.2a0.dist-info}/METADATA +2 -2
  34. {optimum_rbln-0.9.1.dist-info → optimum_rbln-0.9.2a0.dist-info}/RECORD +36 -35
  35. {optimum_rbln-0.9.1.dist-info → optimum_rbln-0.9.2a0.dist-info}/WHEEL +0 -0
  36. {optimum_rbln-0.9.1.dist-info → optimum_rbln-0.9.2a0.dist-info}/licenses/LICENSE +0 -0

optimum/rbln/transformers/models/colpali/configuration_colpali.py

@@ -14,6 +14,10 @@
 from typing import Any, List, Optional, Union
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
@@ -47,6 +51,7 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
 
     def __init__(
         self,
+        batch_size: Optional[int] = None,
         max_seq_lens: Union[int, List[int]] = None,
         output_hidden_states: Optional[bool] = None,
         vision_tower: Optional[RBLNModelConfig] = None,
@@ -54,6 +59,8 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
     ):
         """
         Args:
+            batch_size (Optional[int]): The batch size for the model.
+            vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
            max_seq_lens (Union[int, List[int]]): The maximum sequence lengths for the language model.
                This can be multiple values, and the model will be compiled for each max_seq_len, allowing selection of the most appropriate max_seq_len at inference time.
            output_hidden_states (Optional[bool]): Whether to output the hidden states of the language model.
@@ -63,6 +70,15 @@ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
            ValueError: If batch_size is not a positive integer.
        """
        super().__init__(**kwargs)
-        self.vision_tower = vision_tower
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for ColPali vision tower. It will be set to 1.")
+
+        self.vision_tower = self.initialize_submodule_config(
+            submodule_config=vision_tower, batch_size=1, force_kwargs=True
+        )
        self.max_seq_lens = max_seq_lens
        self.output_hidden_states = output_hidden_states
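
For reference, a minimal sketch of how the updated constructor might be used. This is illustrative only and assumes the config class is importable from the package root, as other RBLN configs are; the parameter values are made up.

# Illustrative sketch, not part of the diff: constructing the updated ColPali config.
from optimum.rbln import RBLNColPaliForRetrievalConfig

config = RBLNColPaliForRetrievalConfig(
    batch_size=1,                  # negative values raise ValueError; None falls back to 1
    max_seq_lens=[2048, 4096],     # one compiled graph per max_seq_len, picked at inference time
    output_hidden_states=False,
)
# The vision tower sub-config is always pinned to batch_size=1 via
# initialize_submodule_config(..., batch_size=1, force_kwargs=True); a warning is
# logged if a different top-level batch_size was requested.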

optimum/rbln/transformers/models/colpali/modeling_colpali.py

@@ -26,6 +26,7 @@ from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModal
 
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
+from ...utils.rbln_runtime_wrapper import LoopProcessor
 from .colpali_architecture import RBLNColPaliForRetrievalWrapper
 
 
@@ -33,93 +34,64 @@ if TYPE_CHECKING:
 from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
 
 
-class LoopVisionTower:
-    def __init__(self, vision_tower: RBLNModel) -> None:
-        self.vision_tower = vision_tower
+class LoopVisionTower(LoopProcessor):
+    def __init__(self, vision_tower: "RBLNModel"):
+        super().__init__(model=vision_tower.model[0])
 
-    def forward(self, pixel_values, **kwargs):
-        batch_size = pixel_values.shape[0]
-        outputs = []
-        for i in range(batch_size):
-            outputs.append(self.vision_tower(pixel_values[i : i + 1]))
+    def _get_batch_size(self, pixel_values, **kwargs):
+        return pixel_values.shape[0]
 
-        last_hidden_states = [output.last_hidden_state for output in outputs]
-        last_hidden_states = torch.cat(last_hidden_states, dim=0)
+    def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+        pixel_values_item = pixel_values[index : index + 1]
+        out_buffer = kwargs["out"][index : index + 1]
+        return ([pixel_values_item], {"out": out_buffer})
 
+    def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
         return BaseModelOutputWithPooling(
-            last_hidden_state=last_hidden_states,
+            last_hidden_state=kwargs["out"],
         )
 
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
-
-    def __repr__(self) -> str:
-        return repr(self.vision_tower)
-
 
-class LoopLanguageModel:
-    def __init__(self, language_model: RBLNModel, rbln_config: RBLNModelConfig) -> None:
-        self.language_model = language_model
+class LoopLanguageModel(LoopProcessor):
+    def __init__(self, language_model: RBLNModel, rbln_config: RBLNModelConfig):
+        super().__init__(model=language_model)
         self.rbln_config = rbln_config
 
-    def prepare_inputs(self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor):
+    def _get_batch_size(self, inputs_embeds, **kwargs):
+        return inputs_embeds.shape[0]
+
+    def _prepare_inputs_before_loop(self, *, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, **kwargs):
        input_len = inputs_embeds.shape[1]
        idx = bisect.bisect_left(self.rbln_config.max_seq_lens, input_len)
        if idx == len(self.rbln_config.max_seq_lens):
            raise ValueError(
                f"Required seq_len({input_len}) is larger than available max_seq_lens({self.rbln_config.max_seq_lens})."
            )
-        else:
-            max_seq_len = self.rbln_config.max_seq_lens[idx]
-
-        inputs_embed = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, max_seq_len - input_len))
-        attn_mask = torch.nn.functional.pad(attention_mask, (0, max_seq_len - input_len)).to(torch.float32)
-        position_ids = torch.arange(max_seq_len, dtype=torch.int32).view(1, -1)
-
-        return inputs_embed, attn_mask, position_ids
-
-    def forward(self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, **kwargs):
-        padded_inputs_embed, padded_attn_mask, padded_position_ids = self.prepare_inputs(inputs_embeds, attention_mask)
-        input_batch_size = inputs_embeds.shape[0]
-        input_seq_len = inputs_embeds.shape[1]
-
-        all_embeddings = []
-        all_hidden_states = []
-        for i in range(input_batch_size):
-            outputs = self.language_model(
-                inputs_embeds=padded_inputs_embed[i : i + 1],
-                attention_mask=padded_attn_mask[i : i + 1],
-                position_ids=padded_position_ids,
-            )
-
-            if self.rbln_config.output_hidden_states:
-                embedding = outputs[0]
-                hidden_states = outputs[1:]
-            else:
-                embedding = outputs
-                hidden_states = None
+        max_seq_len = self.rbln_config.max_seq_lens[idx]
+        padded_inputs_embed = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, max_seq_len - input_len))
+        padded_attn_mask = torch.nn.functional.pad(attention_mask, (0, max_seq_len - input_len)).to(torch.float32)
+        padded_position_ids = torch.arange(max_seq_len, dtype=torch.int32).view(1, -1)
+
+        return {
+            "padded_inputs_embed": padded_inputs_embed,
+            "padded_attn_mask": padded_attn_mask,
+            "padded_position_ids": padded_position_ids,
+        }
 
-            all_embeddings.append(embedding)
-            all_hidden_states.append(hidden_states)
+    def _prepare_inputs_for_iteration(self, index: int, common_inputs, *args, **kwargs):
+        item_kwargs = {
+            "inputs_embeds": common_inputs["padded_inputs_embed"][index : index + 1],
+            "attention_mask": common_inputs["padded_attn_mask"][index : index + 1],
+            "position_ids": common_inputs["padded_position_ids"],
+            "out": [tensor[index : index + 1] for tensor in kwargs["out"]],
+        }
+        return ([], item_kwargs)
 
-        embeddings = torch.cat(all_embeddings, dim=0)[:, :input_seq_len]
+    def _process_outputs(self, outputs: list, **kwargs):
        if self.rbln_config.output_hidden_states:
-            hidden_states = [
-                torch.cat(
-                    [batch_hidden_states[layer_idx][:, :input_seq_len] for batch_hidden_states in all_hidden_states],
-                    dim=0,
-                )
-                for layer_idx in range(len(all_hidden_states[0]))
-            ]
-            return embeddings, tuple(hidden_states)
+            return kwargs["out"][0], tuple(kwargs["out"][1:])
        else:
-            return embeddings
-
-    def __call__(self, *args: Any, **kwds: Any) -> Any:
-        return self.forward(*args, **kwds)
-
-    def __repr__(self) -> str:
-        return repr(self.language_model)
+            return kwargs["out"]
 
 
 class RBLNColPaliForRetrieval(RBLNModel):
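
The `LoopProcessor` base class lives in the new `optimum/rbln/transformers/utils/rbln_runtime_wrapper.py` (+79 lines), which is not part of this excerpt. Judging only from the overrides above, it factors the per-sample loop into a template method. The following is a rough sketch of that contract; the method names match the diff, but the body is an assumption, not the packaged implementation.

# Rough sketch of the LoopProcessor contract implied by the subclasses above
# (illustrative only; the real class is in rbln_runtime_wrapper.py).
class LoopProcessorSketch:
    def __init__(self, model):
        self.model = model

    def _get_batch_size(self, *args, **kwargs) -> int:
        raise NotImplementedError

    def _prepare_inputs_before_loop(self, **kwargs) -> dict:
        return {}  # shared, batch-independent preprocessing (e.g. padding)

    def _prepare_inputs_for_iteration(self, index, common_inputs, *args, **kwargs):
        raise NotImplementedError  # returns (item_args, item_kwargs) for sample `index`

    def _process_outputs(self, outputs: list, **kwargs):
        return outputs

    def forward(self, *args, **kwargs):
        common = self._prepare_inputs_before_loop(**kwargs)
        outputs = []
        for i in range(self._get_batch_size(*args, **kwargs)):
            item_args, item_kwargs = self._prepare_inputs_for_iteration(i, common, *args, **kwargs)
            # each call writes directly into the caller-provided `out` slices
            outputs.append(self.model(*item_args, **item_kwargs))
        return self._process_outputs(outputs, **kwargs)

    __call__ = forward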
@@ -212,7 +184,7 @@ class RBLNColPaliForRetrieval(RBLNModel):
     @classmethod
     def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
         return RBLNColPaliForRetrievalWrapper(
-            causal_lm=model.vlm.language_model,
+            causal_lm=model.vlm,
             embedding_proj_layer=model.embedding_proj_layer,
             max_seq_len=max(rbln_config.max_seq_lens),
             output_hidden_states=rbln_config.output_hidden_states,
@@ -252,9 +224,9 @@ class RBLNColPaliForRetrieval(RBLNModel):
         input_infos = []
         for max_seq_len in rbln_config.max_seq_lens:
             input_info = [
-                ("inputs_embeds", [1, max_seq_len, hidden_size], "float32"),
-                ("attention_mask", [1, max_seq_len], "float32"),
-                ("position_ids", [1, max_seq_len], "int32"),
+                ("inputs_embeds", [rbln_config.vision_tower.batch_size, max_seq_len, hidden_size], "float32"),
+                ("attention_mask", [rbln_config.vision_tower.batch_size, max_seq_len], "float32"),
+                ("position_ids", [rbln_config.vision_tower.batch_size, max_seq_len], "int32"),
             ]
             input_infos.append(input_info)
 
@@ -306,8 +278,7 @@ class RBLNColPaliForRetrieval(RBLNModel):
     def get_pytorch_model(cls, *args, **kwargs):
         model = super().get_pytorch_model(*args, **kwargs)
         model.vision_tower = model.vlm.vision_tower
-        del model.vlm.vision_tower
-
+        del model.vlm.model.vision_tower
         return model
 
     def get_image_features(self, pixel_values: torch.Tensor):
@@ -318,8 +289,14 @@ class RBLNColPaliForRetrieval(RBLNModel):
         # Returns:
         #     image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
 
-        vision_outputs = self.vision_tower(pixel_values).last_hidden_state
-        image_features = self.multi_modal_projector(vision_outputs)
+        vision_output_size = [
+            pixel_values.shape[0],
+            self.config.vlm_config.vision_config.num_image_tokens,
+            self.config.vlm_config.vision_config.hidden_size,
+        ]
+        vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu")
+        self.vision_tower(pixel_values, out=vision_output)
+        image_features = self.multi_modal_projector(vision_output)
         image_features = image_features / (self.config.text_config.hidden_size**0.5)
         return image_features
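
The recurring pattern in this release is to pre-allocate the output tensor on the host and hand per-sample slices to the runtime via `out=`, instead of collecting per-sample results and concatenating them with `torch.cat`. A minimal, self-contained illustration of the idea in plain PyTorch; the RBLN runtime's `out=` keyword is mimicked here by an explicit in-place `copy_`, so nothing below is the actual runtime API.

# Illustrative sketch of the "write into a pre-allocated buffer" pattern (not RBLN API).
import torch

def fake_encoder(x: torch.Tensor, out: torch.Tensor) -> None:
    # stands in for a compiled runtime call that fills `out` in place
    out.copy_(x * 2.0)

batch, tokens, hidden = 4, 16, 8
pixel_like = torch.randn(batch, tokens, hidden)

# allocate the full result once, then let each iteration fill its own slice
result = torch.empty(batch, tokens, hidden, dtype=torch.float32, device="cpu")
for i in range(batch):
    fake_encoder(pixel_like[i : i + 1], out=result[i : i + 1])

assert torch.allclose(result, pixel_like * 2.0)  # no torch.cat needed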
@@ -385,11 +362,27 @@ class RBLNColPaliForRetrieval(RBLNModel):
             input_ids=input_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values
         )
 
+        outputs = []
+        language_model_out_size = [inputs_embeds.shape[0], self.rbln_config.max_seq_lens[0], self.config.embedding_dim]
+        language_model_hidden_states_size = [
+            inputs_embeds.shape[0],
+            self.rbln_config.max_seq_lens[0],
+            self.rbln_config.max_seq_lens[0],
+        ]
+        outputs.append(torch.empty(size=language_model_out_size, dtype=torch.float32, device="cpu"))
+        if self.rbln_config.output_hidden_states:
+            for i in range(self.config.vlm_config.text_config.num_hidden_layers + 1):
+                outputs.append(torch.empty(size=language_model_hidden_states_size, dtype=torch.float32, device="cpu"))
+
         # Embedding_proj_layer is fused on the bottom of the language model.
-        outputs = self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
+        self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, out=outputs)
 
-        embeddings = outputs if not self.rbln_config.output_hidden_states else outputs[0]
-        hidden_states = None if not self.rbln_config.output_hidden_states else outputs[1]
+        embeddings = outputs[0][:, : inputs_embeds.shape[1]]
+        hidden_states = (
+            None
+            if not self.rbln_config.output_hidden_states
+            else [tensor[0][:, : inputs_embeds.shape[1]] for tensor in outputs[1:]]
+        )
 
         # L2 normalization
         embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
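
The language-model path pads each request up to the smallest compiled `max_seq_len` that fits (chosen with `bisect` in the wrapper above), runs the static-shape graph, then slices the padded outputs back to the original length, as the hunk does with `outputs[0][:, : inputs_embeds.shape[1]]`. A small stand-alone sketch of that bucketing logic, with made-up values:

# Illustrative sketch of max_seq_len bucketing (values are made up).
import bisect
import torch

max_seq_lens = [1024, 2048, 4096]          # compiled buckets, ascending
inputs_embeds = torch.randn(1, 1500, 128)  # actual request length: 1500

input_len = inputs_embeds.shape[1]
idx = bisect.bisect_left(max_seq_lens, input_len)
if idx == len(max_seq_lens):
    raise ValueError(f"seq_len({input_len}) exceeds available max_seq_lens({max_seq_lens})")
max_seq_len = max_seq_lens[idx]            # -> 2048

padded = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, max_seq_len - input_len))
# ... run the compiled graph on `padded` ...
outputs = padded                           # placeholder for the runtime result
trimmed = outputs[:, :input_len]           # slice back to the real length
assert trimmed.shape[1] == 1500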

optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py

@@ -57,7 +57,6 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     1. Converting pre-trained transformer models to RBLN-optimized format
     2. Handling the compilation process for RBLN devices
     3. Managing inference operations for decoder-only architectures
-
     This class inherits from RBLNModel and implements specific methods required for
     decoder-only architectures.
 
@@ -68,6 +67,8 @@ class RBLNDecoderOnlyModel(RBLNModel, RBLNDecoderOnlyFlashAttentionMixin):
     - The class handles RBLN-specific optimizations automatically during compilation
     """
 
+    _tp_support = True
+
     main_input_name = "input_ids"
     auto_model_class = AutoModel
     _decoder_wrapper_cls = DecoderOnlyWrapper
@@ -642,7 +643,6 @@ class RBLNDecoderOnlyModelForCausalLM(RBLNDecoderOnlyModel, RBLNDecoderOnlyGener
     1. Converting pre-trained transformer models to RBLN-optimized format
     2. Handling the compilation process for RBLN devices
     3. Managing inference operations for causal language modeling
-
     This class inherits from RBLNModel and implements specific methods required for
     decoder-only architectures and causal language modeling tasks.
 

optimum/rbln/transformers/models/gemma3/configuration_gemma3.py

@@ -14,8 +14,11 @@
 from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
 from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
-from ..siglip.configuration_siglip import RBLNSiglipVisionModelConfig
+
+
+logger = get_logger(__name__)
 
 
 class RBLNGemma3ForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
@@ -89,8 +92,13 @@ class RBLNGemma3ForConditionalGenerationConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
-        self.vision_tower = self.init_submodule_config(RBLNSiglipVisionModelConfig, vision_tower)
-        self.language_model = self.init_submodule_config(RBLNGemma3ForCausalLMConfig, language_model)
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for Gemma3 vision tower. It will be set to 1.")
+
+        self.vision_tower = self.initialize_submodule_config(
+            submodule_config=vision_tower, batch_size=1, force_kwargs=True
+        )
+        self.language_model = self.initialize_submodule_config(submodule_config=language_model)
 
     @property
     def image_prefill_chunk_size(self):
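
`init_submodule_config` is replaced by `initialize_submodule_config` across these configs; the helper itself is updated in `configuration_utils.py` (+54 -7), which is not shown here. Based purely on how it is called in this diff, it appears to build or update the submodule config and force selected kwargs onto it. The following dict-based toy is a hypothetical sketch of that behavior, not the actual implementation:

# Hypothetical sketch of what initialize_submodule_config(...) appears to do here.
def initialize_submodule_config_sketch(submodule_config=None, force_kwargs=False, **overrides):
    cfg = dict(submodule_config or {})      # the real helper works on RBLNModelConfig objects
    for key, value in overrides.items():
        if force_kwargs or key not in cfg:  # force_kwargs=True overrides user-supplied values
            cfg[key] = value
    return cfg

print(initialize_submodule_config_sketch({"batch_size": 4}, force_kwargs=True, batch_size=1))
# -> {'batch_size': 1}  (vision towers are pinned to batch_size=1)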

optimum/rbln/transformers/models/gemma3/modeling_gemma3.py

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import inspect
 from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
@@ -25,6 +26,7 @@ from transformers.models.gemma3.modeling_gemma3 import Gemma3TextScaledWordEmbed
 
 from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
 from ....modeling import RBLNModel
 from ...modeling_outputs import RBLNDecoderOnlyOutput
+from ...utils.rbln_runtime_wrapper import LoopProcessor
 from ..decoderonly.decoderonly_runtime_utils import RBLNPageTableManager
 from ..decoderonly.modeling_decoderonly import (
     RBLNDecoderOnlyModelForCausalLM,
  from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, Gemma3ForConditionalGeneration
39
41
 
40
42
 
41
- class LoopVisionTower:
42
- def __init__(self, vision_tower: RBLNModel) -> None:
43
- self.vision_tower = vision_tower
43
+ class LoopVisionTower(LoopProcessor):
44
+ def __init__(self, vision_tower: "RBLNModel"):
45
+ super().__init__(model=vision_tower)
44
46
 
45
- def forward(self, *args, **kwargs):
46
- # Loop instead of batch
47
- # shape of pixel_values : [batch, num_channel, height, width]
48
- pixel_values = args[0]
47
+ def _get_batch_size(self, pixel_values, **kwargs):
48
+ return pixel_values.shape[0]
49
49
 
50
- batch_size = pixel_values.shape[0]
51
- outputs = []
52
- for i in range(batch_size):
53
- outputs.append(self.vision_tower(pixel_values=pixel_values[i : i + 1], return_dict=True))
50
+ def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
51
+ pixel_values_item = pixel_values[index : index + 1]
52
+ out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
53
+ return ([pixel_values_item], {"out": out_buffer})
54
54
 
55
- last_hidden_states = [output.last_hidden_state for output in outputs]
56
-
57
- # FIXME:: This can be optimized using out= API of rbln runtime.
58
- last_hidden_states = torch.cat(last_hidden_states, dim=0)
55
+ def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
56
+ output = kwargs["out"]
59
57
 
60
58
  return BaseModelOutputWithPooling(
61
- last_hidden_state=last_hidden_states,
59
+ last_hidden_state=output[0],
62
60
  )
63
61
 
64
- def __call__(self, *args: Any, **kwds: Any) -> Any:
65
- return self.forward(*args, **kwds)
66
-
67
- def __repr__(self) -> str:
68
- return repr(self.vision_tower)
69
-
70
62
 
71
- class LoopProjector:
72
- def __init__(self, multi_modal_projector) -> None:
73
- self.multi_modal_projector = multi_modal_projector
63
+ class LoopProjector(LoopProcessor):
64
+ def __init__(self, multi_modal_projector: "RBLNModel"):
65
+ super().__init__(model=multi_modal_projector)
74
66
 
75
- def forward(self, *args, **kwargs):
76
- # Loop instead of batch
77
- image_feature = args[0]
67
+ def _get_batch_size(self, image_feature, **kwargs):
68
+ return image_feature.shape[0]
78
69
 
79
- batch_size = image_feature.shape[0]
80
- outputs = []
81
- for i in range(batch_size):
82
- outputs.append(self.multi_modal_projector(image_feature[i : i + 1]))
70
+ def _prepare_inputs_for_iteration(self, index, common_inputs, image_feature, **kwargs):
71
+ image_feature_item = image_feature[index : index + 1]
72
+ out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
73
+ return ([image_feature_item], {"out": out_buffer})
83
74
 
84
- # FIXME:: This can be optimized using out= API of rbln runtime.
85
- outputs = torch.cat(outputs, dim=0)
86
- return outputs
87
-
88
- def __call__(self, *args: Any, **kwds: Any) -> Any:
89
- return self.forward(*args, **kwds)
90
-
91
- def __repr__(self) -> str:
92
- return repr(self.multi_modal_projector)
75
+ def _process_outputs(self, outputs: list, **kwargs):
76
+ output = kwargs["out"]
77
+ return output[0]
93
78
 
94
79
 
95
80
  class RBLNGemma3ForConditionalGeneration(RBLNModel):
@@ -112,6 +97,23 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
112
97
  def can_generate(self):
113
98
  return True
114
99
 
100
+ @classmethod
101
+ def get_pytorch_model(cls, *args, **kwargs):
102
+ model = super().get_pytorch_model(*args, **kwargs)
103
+
104
+ with no_init_weights():
105
+ model_cls_name = model.model.language_model.__class__.__name__
106
+ causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")
107
+ causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
108
+ new_language_model = causal_model_cls(model.model.language_model.config)
109
+
110
+ new_language_model.lm_head = model.lm_head
111
+ new_language_model.model = model.model.language_model
112
+ model.model.language_model = new_language_model
113
+ model.lm_head = None
114
+ del model.lm_head
115
+ return model
116
+
115
117
  def __post_init__(self, **kwargs):
116
118
  self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
117
119
  self.language_model = self.rbln_submodules[1]
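
The new `get_pytorch_model` override rebuilds a `*ForCausalLM` wrapper around the bare text model so that the fused `lm_head` ends up inside the language model before compilation. The dynamic class lookup it relies on works as in the stand-alone sketch below; for Gemma3 the names resolve to the transformers classes `Gemma3TextModel` and `Gemma3ForCausalLM` (this assumes a transformers version that ships Gemma3).

# Illustrative sketch of the TextModel -> ForCausalLM class swap used above.
import importlib

model_cls_name = "Gemma3TextModel"  # in the diff: model.model.language_model.__class__.__name__
causal_model_cls_name = model_cls_name.replace("TextModel", "ForCausalLM")  # "Gemma3ForCausalLM"
causal_model_cls = getattr(importlib.import_module("transformers"), causal_model_cls_name)
print(causal_model_cls.__name__)  # Gemma3ForCausalLM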
@@ -210,8 +212,21 @@ class RBLNGemma3ForConditionalGeneration(RBLNModel):
         # Returns:
         #     Image feature tensor of shape `(num_images, image_length, embed_dim)`.
 
-        vision_outputs = self.vision_tower(pixel_values).last_hidden_state
-        image_features = self.multi_modal_projector(vision_outputs)
+        vision_out_buffer = []
+        vision_out_size = [
+            pixel_values.shape[0],
+            (self.config.vision_config.image_size // self.config.vision_config.patch_size) ** 2,
+            self.config.vision_config.hidden_size,
+        ]
+        projector_out_size = [
+            pixel_values.shape[0],
+            self.config.mm_tokens_per_image,
+            self.config.text_config.hidden_size,
+        ]
+        vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
+        projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
+        vision_outputs = self.vision_tower(pixel_values, out=vision_out_buffer).last_hidden_state
+        image_features = self.multi_modal_projector(vision_outputs, out=projector_out_buffer)
         return image_features
 
     def _preprocess_prefill(
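
The two buffer shapes come straight from the model config: the vision tower emits one token per patch, and the multi-modal projector pools that down to `mm_tokens_per_image`. A worked example with typical Gemma3 vision settings (image_size 896, patch_size 14, mm_tokens_per_image 256; these are illustrative defaults, not values read from the diff):

# Worked example of the two out-buffer shapes (typical Gemma3 defaults, for illustration only).
image_size, patch_size, vision_hidden = 896, 14, 1152
mm_tokens_per_image, text_hidden = 256, 2560
num_images = 2

vision_out_shape = [num_images, (image_size // patch_size) ** 2, vision_hidden]  # [2, 4096, 1152]
projector_out_shape = [num_images, mm_tokens_per_image, text_hidden]            # [2, 256, 2560]
print(vision_out_shape, projector_out_shape)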

optimum/rbln/transformers/models/idefics3/configuration_idefics3.py

@@ -15,10 +15,29 @@
 from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 class RBLNIdefics3VisionTransformerConfig(RBLNModelConfig):
-    pass
+    """
+    Configuration class for RBLNIdefics3VisionTransformer.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized Idefics3 vision transformer.
+    """
+
+    def __init__(
+        self,
+        batch_size: Optional[int] = None,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+        self.batch_size = batch_size or 1
+        if not isinstance(self.batch_size, int) or self.batch_size < 0:
+            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
 
 class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
@@ -61,5 +80,10 @@ class RBLNIdefics3ForConditionalGenerationConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
-        self.vision_model = vision_model
-        self.text_model = text_model
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for Idefics3 vision transformer. It will be set to 1.")
+
+        self.vision_model = self.initialize_submodule_config(
+            submodule_config=vision_model, batch_size=1, force_kwargs=True
+        )
+        self.text_model = self.initialize_submodule_config(submodule_config=text_model)

optimum/rbln/transformers/models/idefics3/modeling_idefics3.py

@@ -75,10 +75,12 @@ class RBLNRuntimeVisionModel(RBLNPytorchRuntime):
 
         hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask)
 
-        return super().forward(hidden_states.contiguous())
+        return super().forward(hidden_states.contiguous(), **kwargs)
 
 
 class RBLNIdefics3VisionTransformer(RBLNModel):
+    _tp_support = False
+
     def __post_init__(self, **kwargs):
         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
         with no_init_weights():
@@ -140,8 +142,7 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
             (
                 "hidden_states",
                 [
-                    # batch_size * num_patches (dependent on image size) -> compile with 1 and use for loop
-                    1,
+                    rbln_config.batch_size,
                     (model_config.image_size // model_config.patch_size) ** 2,
                     model_config.hidden_size,
                 ],
@@ -160,22 +161,24 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
         return_dict: Optional[bool] = None,
         **kwargs,
     ) -> Union[Tuple, BaseModelOutput]:
-        batch_size = pixel_values.shape[0]
-        last_hidden_state = []
-        for i in range(batch_size):
+        last_hidden_state_size = [
+            pixel_values.shape[0],
+            (self.config.image_size // self.config.patch_size) ** 2,
+            self.config.hidden_size,
+        ]
+        last_hidden_state = torch.empty(size=last_hidden_state_size, dtype=torch.float32, device="cpu")
+        for i in range(pixel_values.shape[0]):
             if patch_attention_mask is not None:
                 batch_attention_mask = patch_attention_mask[i : i + 1,]
             else:
                 batch_attention_mask = None
 
-            batch_hidden_state = self.model(
+            self.model(
                 pixel_values[i : i + 1,],
                 batch_attention_mask,
+                out=last_hidden_state[i : i + 1,],
                 return_dict=False,
             )
-            last_hidden_state.append(batch_hidden_state)
-        last_hidden_state = torch.cat(last_hidden_state, dim=0)
-
         if not return_dict:
             return (last_hidden_state,)
         else:
@@ -285,8 +288,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):
             (
                 "image_hidden_states",
                 [
-                    # batch_size * num_patches (dependent on image size) -> compile with 1 and use for loop
-                    1,
+                    rbln_config.vision_model.batch_size,
                     (model_config.vision_config.image_size // model_config.vision_config.patch_size) ** 2,
                     model_config.vision_config.hidden_size,
                 ],
@@ -425,10 +427,15 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel):
                 pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, return_dict=True
             ).last_hidden_state
 
-            connector_outputs = []
+            connector_output_size = [
+                image_hidden_states.shape[0],
+                image_hidden_states.shape[1] // self.config.scale_factor**2,
+                self.config.text_config.hidden_size,
+            ]
+            connector_outputs = torch.empty(size=connector_output_size, dtype=torch.float32, device="cpu")
             for i in range(image_hidden_states.shape[0]):
-                connector_outputs.append(self.connector(image_hidden_states[i : i + 1,]))
-            image_hidden_states = torch.cat(connector_outputs, dim=0)
+                self.connector(image_hidden_states[i : i + 1,], out=connector_outputs[i : i + 1,])
+            image_hidden_states = connector_outputs
 
         elif image_hidden_states is not None:
             image_hidden_states = image_hidden_states.to(dtype=self.dtype, device=input_ids.device)
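
Same idea for the Idefics3 connector: the pixel-shuffle connector reduces the patch sequence by `scale_factor**2`, so the pre-allocated buffer is `num_patches // scale_factor**2` tokens long per image. A quick shape check with illustrative numbers:

# Illustrative shape check for the pre-allocated connector buffer (numbers are made up).
num_images, num_patches, scale_factor, text_hidden = 3, 1024, 2, 2048

connector_shape = [num_images, num_patches // scale_factor**2, text_hidden]  # [3, 256, 2048]
print(connector_shape)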

optimum/rbln/transformers/models/llava/configuration_llava.py

@@ -15,6 +15,10 @@
 from typing import Any, Optional
 
 from ....configuration_utils import RBLNModelConfig
+from ....utils.logging import get_logger
+
+
+logger = get_logger(__name__)
 
 
 class RBLNLlavaForConditionalGenerationConfig(RBLNModelConfig):
@@ -54,5 +58,15 @@ class RBLNLlavaForConditionalGenerationConfig(RBLNModelConfig):
         if not isinstance(self.batch_size, int) or self.batch_size < 0:
             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
 
-        self.vision_tower = vision_tower
-        self.language_model = language_model
+        if self.batch_size != 1:
+            logger.warning("Ignore batch_size for Llava vision tower. It will be set to 1.")
+
+        self.vision_tower = self.initialize_submodule_config(
+            submodule_config=vision_tower,
+            batch_size=1,  # vision_tower batch_size is always 1 in Llava
+            force_kwargs=True,
+        )
+
+        self.language_model = self.initialize_submodule_config(
+            submodule_config=language_model,
+        )
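
And the Llava variant of the same config change, sketched as it might be used. This is illustrative only: it assumes the config class is importable from the package root and that submodule configs can be passed as plain dicts or config objects; the keyword values are made up.

# Illustrative sketch, not part of the diff: nested submodule configs for Llava.
from optimum.rbln import RBLNLlavaForConditionalGenerationConfig

config = RBLNLlavaForConditionalGenerationConfig(
    batch_size=2,                          # kept for the language model side
    language_model={"max_seq_len": 8192},  # forwarded through initialize_submodule_config
    vision_tower=None,                     # built automatically, always batch_size=1
)
# A warning is logged because batch_size != 1 cannot apply to the vision tower.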