optimum-rbln 0.9.3rc0__py3-none-any.whl → 0.9.5a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (157)
  1. optimum/rbln/__init__.py +48 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +50 -21
  4. optimum/rbln/diffusers/__init__.py +12 -0
  5. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  9. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  10. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  11. optimum/rbln/diffusers/modeling_diffusers.py +1 -1
  12. optimum/rbln/diffusers/models/__init__.py +17 -3
  13. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  14. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -3
  15. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  16. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  17. optimum/rbln/diffusers/models/controlnet.py +17 -2
  18. optimum/rbln/diffusers/models/transformers/prior_transformer.py +16 -2
  19. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +16 -1
  20. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +14 -1
  21. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  22. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +18 -2
  23. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  24. optimum/rbln/diffusers/pipelines/__init__.py +4 -0
  25. optimum/rbln/diffusers/pipelines/auto_pipeline.py +2 -2
  26. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  27. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +13 -4
  28. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +13 -4
  29. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +13 -4
  30. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -4
  31. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +1 -1
  32. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +1 -1
  33. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +1 -2
  34. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  35. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  36. optimum/rbln/modeling.py +20 -45
  37. optimum/rbln/modeling_base.py +18 -14
  38. optimum/rbln/ops/__init__.py +1 -0
  39. optimum/rbln/ops/attn.py +10 -0
  40. optimum/rbln/ops/flash_attn.py +8 -0
  41. optimum/rbln/ops/moe.py +180 -0
  42. optimum/rbln/ops/sliding_window_attn.py +9 -0
  43. optimum/rbln/transformers/__init__.py +36 -0
  44. optimum/rbln/transformers/configuration_generic.py +0 -27
  45. optimum/rbln/transformers/modeling_attention_utils.py +156 -127
  46. optimum/rbln/transformers/modeling_generic.py +2 -61
  47. optimum/rbln/transformers/modeling_outputs.py +26 -0
  48. optimum/rbln/transformers/modeling_rope_utils.py +78 -42
  49. optimum/rbln/transformers/models/__init__.py +28 -0
  50. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  51. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  52. optimum/rbln/transformers/models/auto/auto_factory.py +1 -0
  53. optimum/rbln/transformers/models/bart/bart_architecture.py +24 -24
  54. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  55. optimum/rbln/transformers/models/bert/modeling_bert.py +86 -1
  56. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +42 -15
  57. optimum/rbln/transformers/models/clip/modeling_clip.py +40 -2
  58. optimum/rbln/transformers/models/colpali/colpali_architecture.py +14 -20
  59. optimum/rbln/transformers/models/colpali/configuration_colpali.py +12 -17
  60. optimum/rbln/transformers/models/colpali/modeling_colpali.py +66 -221
  61. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +38 -23
  62. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +107 -371
  63. optimum/rbln/transformers/models/decoderonly/__init__.py +2 -0
  64. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +128 -17
  65. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +2 -2
  66. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +211 -89
  67. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +205 -64
  68. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +17 -9
  69. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +1 -1
  70. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +194 -132
  71. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +17 -0
  72. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  73. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  74. optimum/rbln/transformers/models/exaone/exaone_architecture.py +0 -36
  75. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -1
  76. optimum/rbln/transformers/models/gemma2/__init__.py +16 -0
  77. optimum/rbln/transformers/models/gemma2/configuration_gemma2.py +45 -0
  78. optimum/rbln/transformers/models/gemma2/gemma2_architecture.py +83 -0
  79. optimum/rbln/transformers/models/gemma2/modeling_gemma2.py +101 -0
  80. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +23 -19
  81. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +42 -70
  82. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +46 -31
  83. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +8 -34
  84. optimum/rbln/transformers/models/gpt_oss/__init__.py +16 -0
  85. optimum/rbln/transformers/models/gpt_oss/configuration_gpt_oss.py +41 -0
  86. optimum/rbln/transformers/models/gpt_oss/gpt_oss_architecture.py +122 -0
  87. optimum/rbln/transformers/models/gpt_oss/modeling_gpt_oss.py +165 -0
  88. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +8 -5
  89. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +7 -5
  90. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +24 -9
  91. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +3 -5
  92. optimum/rbln/transformers/models/llava/modeling_llava.py +37 -26
  93. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +3 -5
  94. optimum/rbln/transformers/models/midm/midm_architecture.py +29 -22
  95. optimum/rbln/transformers/models/mistral/modeling_mistral.py +0 -22
  96. optimum/rbln/transformers/models/opt/modeling_opt.py +2 -2
  97. optimum/rbln/transformers/models/opt/opt_architecture.py +1 -44
  98. optimum/rbln/transformers/models/paligemma/__init__.py +16 -0
  99. optimum/rbln/transformers/models/paligemma/configuration_paligemma.py +129 -0
  100. optimum/rbln/transformers/models/paligemma/modeling_paligemma.py +564 -0
  101. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +1 -1
  102. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +24 -24
  103. optimum/rbln/transformers/models/phi/phi_architecture.py +13 -21
  104. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +13 -1
  105. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +2 -2
  106. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +0 -28
  107. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +6 -1
  108. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +11 -1
  109. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +278 -130
  110. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +43 -39
  111. optimum/rbln/transformers/models/qwen2_moe/__init__.py +16 -0
  112. optimum/rbln/transformers/models/qwen2_moe/configuration_qwen2_moe.py +38 -0
  113. optimum/rbln/transformers/models/qwen2_moe/modeling_qwen2_moe.py +68 -0
  114. optimum/rbln/transformers/models/qwen2_moe/qwen2_moe_architecture.py +94 -0
  115. optimum/rbln/transformers/models/qwen2_vl/__init__.py +6 -1
  116. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +11 -1
  117. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +268 -111
  118. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +27 -35
  119. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +0 -20
  120. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +7 -7
  121. optimum/rbln/transformers/models/qwen3_moe/__init__.py +16 -0
  122. optimum/rbln/transformers/models/qwen3_moe/configuration_qwen3_moe.py +38 -0
  123. optimum/rbln/transformers/models/qwen3_moe/modeling_qwen3_moe.py +68 -0
  124. optimum/rbln/transformers/models/qwen3_moe/qwen3_moe_architecture.py +100 -0
  125. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  126. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  127. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  128. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +2 -4
  129. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +36 -12
  130. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +14 -12
  131. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -19
  132. optimum/rbln/transformers/models/swin/configuration_swin.py +1 -6
  133. optimum/rbln/transformers/models/swin/modeling_swin.py +17 -4
  134. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  135. optimum/rbln/transformers/models/t5/t5_architecture.py +16 -17
  136. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +25 -10
  137. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +0 -3
  138. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  139. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  140. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +60 -8
  141. optimum/rbln/transformers/models/whisper/generation_whisper.py +48 -14
  142. optimum/rbln/transformers/models/whisper/modeling_whisper.py +2 -2
  143. optimum/rbln/transformers/models/whisper/whisper_architecture.py +0 -3
  144. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +53 -0
  145. optimum/rbln/transformers/utils/rbln_quantization.py +29 -12
  146. optimum/rbln/utils/deprecation.py +213 -0
  147. optimum/rbln/utils/hub.py +14 -3
  148. optimum/rbln/utils/import_utils.py +23 -2
  149. optimum/rbln/utils/runtime_utils.py +42 -6
  150. optimum/rbln/utils/submodule.py +27 -1
  151. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/METADATA +6 -6
  152. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/RECORD +155 -129
  153. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/WHEEL +1 -1
  154. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +0 -233
  155. optimum/rbln/utils/depreacate_utils.py +0 -16
  156. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/entry_points.txt +0 -0
  157. {optimum_rbln-0.9.3rc0.dist-info → optimum_rbln-0.9.5a4.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py
@@ -13,7 +13,7 @@
  # limitations under the License.

  from pathlib import Path
- from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union

  import torch
  from torch import Tensor, nn
@@ -206,8 +206,7 @@ class RBLNGroundingDinoForObjectDetection(RBLNModel):
          torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

      @classmethod
-     def get_pytorch_model(cls, *args, **kwargs):
-         model = super().get_pytorch_model(*args, **kwargs)
+     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
          model.encoder = model.model.encoder
          model.decoder = model.model.decoder
          model.text_backbone = model.model.text_backbone
@@ -217,7 +216,7 @@ class RBLNGroundingDinoForObjectDetection(RBLNModel):
          return model

      @classmethod
-     def wrap_model_if_needed(
+     def _wrap_model_if_needed(
          cls, model: torch.nn.Module, rbln_config: RBLNGroundingDinoForObjectDetectionConfig
      ) -> torch.nn.Module:
          return model.model.text_projection
@@ -305,7 +304,6 @@ class RBLNGroundingDinoForObjectDetection(RBLNModel):
          for feature_map, mask in vision_features:
              # position encoding
              position_embeddings_list.append(self.backbone_position_embedding(feature_map, mask).to(feature_map.dtype))
-         vision_features, position_embeddings_list

          # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
          feature_maps = []
@@ -530,9 +528,26 @@ class RBLNGroundingDinoForObjectDetection(RBLNModel):
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
-         labels: List[Dict[str, Union[torch.LongTensor, torch.FloatTensor]]] = None,
          **kwargs,
-     ):
+     ) -> Union[GroundingDinoObjectDetectionOutput, Tuple]:
+         """
+         Forward pass for the RBLN-optimized GroundingDinoForObjectDetection model.
+
+         Args:
+             pixel_values (torch.Tensor of shape (batch_size, num_channels, image_size, image_size)): The tensors corresponding to the input images.
+             input_ids (torch.LongTensor of shape (batch_size, text_sequence_length)): Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.
+             token_type_ids (torch.LongTensor of shape (batch_size, text_sequence_length), optional): Segment token indices to indicate first and second portions of the inputs.
+             attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional): Mask to avoid performing attention on padding token indices.
+             pixel_mask (torch.Tensor of shape (batch_size, height, width), optional): Mask to avoid performing attention on padding pixel values.
+             encoder_outputs (Tuple consists of last_hidden_state of shape (batch_size, sequence_length, hidden_size), optional): A sequence of hidden-states at the output of the last layer of the encoder.
+             output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers.
+             output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers.
+             return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+
+         Returns:
+             The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a GroundingDinoObjectDetectionOutput object.
+         """
+
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          # Pad image to rbln_config.image_height and rbln_config.image_width
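The docstring added above describes the public forward API. As a rough usage sketch (the RBLN class name comes from this diff; the checkpoint id, processor usage, and `export=True` option follow the usual transformers / optimum-rbln conventions but are illustrative, not verified against this release):

```python
# Hypothetical usage sketch for the class documented above.
import torch
from PIL import Image
from transformers import AutoProcessor

from optimum.rbln import RBLNGroundingDinoForObjectDetection

processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
model = RBLNGroundingDinoForObjectDetection.from_pretrained(
    "IDEA-Research/grounding-dino-tiny",
    export=True,  # compile the PyTorch checkpoint for the RBLN NPU
)

image = Image.new("RGB", (640, 480))  # stand-in image
inputs = processor(images=image, text="a cat. a remote control.", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)  # GroundingDinoObjectDetectionOutput when return_dict=True
```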
@@ -663,7 +678,7 @@ class RBLNGroundingDinoEncoder(RBLNModel):
          self.encoder_runtime = RBLNPytorchRuntime(self.model[0])

      @classmethod
-     def wrap_model_if_needed(
+     def _wrap_model_if_needed(
          cls, model: torch.nn.Module, rbln_config: RBLNGroundingDinoForObjectDetectionConfig
      ) -> torch.nn.Module:
          model = _GroundingDinoEncoder(model, rbln_config).eval()
@@ -861,7 +876,7 @@ class RBLNGroundingDinoDecoder(RBLNModel):
          self.decoder_runtime = RBLNPytorchRuntime(self.model[0])

      @classmethod
-     def wrap_model_if_needed(
+     def _wrap_model_if_needed(
          cls, model: torch.nn.Module, rbln_config: RBLNGroundingDinoForObjectDetectionConfig
      ) -> torch.nn.Module:
          return _GroundingDinoDecoder(model, rbln_config).eval()
optimum/rbln/transformers/models/idefics3/modeling_idefics3.py
@@ -110,7 +110,7 @@ class RBLNIdefics3VisionTransformer(RBLNModel):
          return self.embeddings

      @classmethod
-     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+     def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
          class Idefics3VisionTransformerWrapper(torch.nn.Module):
              def __init__(self, model: "Idefics3VisionTransformer"):
                  super().__init__()
@@ -240,9 +240,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
          return True

      @classmethod
-     def get_pytorch_model(cls, *args, **kwargs):
-         model = super().get_pytorch_model(*args, **kwargs)
-
+     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
          with no_init_weights():
              model_cls_name = model.model.text_model.__class__.__name__
              causal_model_cls_name = model_cls_name.replace("Model", "ForCausalLM")
@@ -271,7 +269,7 @@ class RBLNIdefics3ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
          return self.text_model.get_input_embeddings()

      @classmethod
-     def wrap_model_if_needed(cls, model, rbln_config):
+     def _wrap_model_if_needed(cls, model, rbln_config):
          return model.model.connector

      @classmethod
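These hunks repeat a rename that runs through the whole diff: the broad `get_pytorch_model` override (load, then mutate) gives way to a narrower `_reconstruct_model_if_needed` hook that only restructures an already-loaded model. A minimal sketch of the implied template-method split; the base-class names and call site here are assumptions for illustration, not the actual optimum-rbln code:

```python
# Sketch of the template-method pattern implied by this rename: the base class
# owns loading, subclasses only reshape the loaded module. Illustrative only.
import torch


class BaseLoader:
    @classmethod
    def get_pytorch_model(cls, *args, **kwargs) -> torch.nn.Module:
        model = cls._load(*args, **kwargs)  # shared loading logic
        return cls._reconstruct_model_if_needed(model)  # narrow subclass hook

    @classmethod
    def _load(cls, *args, **kwargs) -> torch.nn.Module:
        raise NotImplementedError

    @classmethod
    def _reconstruct_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
        return model  # default: no restructuring


class MyLoader(BaseLoader):
    @classmethod
    def _reconstruct_model_if_needed(cls, model: torch.nn.Module) -> torch.nn.Module:
        # e.g. expose inner submodules as top-level attributes
        model.encoder = model.model.encoder
        return model
```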
optimum/rbln/transformers/models/llava/modeling_llava.py
@@ -88,15 +88,22 @@ class LoopVisionTower(LoopProcessor):


  class LoopProjector(LoopProcessor):
-     def __init__(self, multi_modal_projector: "RBLNModel"):
+     def __init__(self, multi_modal_projector: "RBLNModel", rbln_config=None):
          super().__init__(model=multi_modal_projector)
+         self.rbln_config = rbln_config

      def _get_batch_size(self, image_feature, **kwargs):
          return image_feature.shape[0]

      def _prepare_inputs_for_iteration(self, index, common_inputs, image_feature, **kwargs):
          image_feature_item = image_feature[index : index + 1]
-         out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
+         if hasattr(self.rbln_config.vision_tower, "max_image_size"):
+             out_buffer = [
+                 tensor[:, index * image_feature.shape[1] : (index + 1) * image_feature.shape[1], :]
+                 for tensor in kwargs["out"]
+             ]
+         else:
+             out_buffer = [tensor[index : index + 1] for tensor in kwargs["out"]]
          return ([image_feature_item], {"out": out_buffer})

      def _process_outputs(self, outputs: list, **kwargs):
@@ -109,7 +116,6 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
      RBLNLlavaForConditionalGeneration is a multi-modal model that combines vision and language processing capabilities,
      optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
      This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
-
      Important Note:
      This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
      tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
@@ -175,9 +181,7 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
          return True

      @classmethod
-     def get_pytorch_model(cls, *args, **kwargs):
-         model = super().get_pytorch_model(*args, **kwargs)
-
+     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
          with no_init_weights():
              model_cls_name = model.model.language_model.__class__.__name__
              causal_model_cls_name = model_cls_name.replace("Model", "ForCausalLM")
@@ -194,7 +198,7 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
      def __post_init__(self, **kwargs):
          self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
          self.language_model = self.rbln_submodules[1]
-         self.multi_modal_projector = LoopProjector(self.model[0])
+         self.multi_modal_projector = LoopProjector(self.model[0], rbln_config=self.rbln_config)
          self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
          return super().__post_init__(**kwargs)

@@ -208,7 +212,7 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
          return self.language_model.get_input_embeddings()

      @classmethod
-     def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+     def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
          return model.multi_modal_projector

      @classmethod
@@ -221,10 +225,8 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
      ) -> RBLNModelConfig:
          # support for pixtral that needs padding
          if hasattr(rbln_config.vision_tower, "max_image_size"):
-             num_positions = (
-                 rbln_config.batch_size
-                 * (rbln_config.vision_tower.max_image_size[0] // model_config.vision_config.patch_size)
-                 * (rbln_config.vision_tower.max_image_size[1] // model_config.vision_config.patch_size)
+             num_positions = (rbln_config.vision_tower.max_image_size[0] // model_config.vision_config.patch_size) * (
+                 rbln_config.vision_tower.max_image_size[1] // model_config.vision_config.patch_size
              )
              selected_image_feature_dim = num_positions

@@ -334,7 +336,7 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
          pooler_out_size = [pixel_values.shape[0], self.config.vision_config.hidden_size]

          vision_out_buffer = []
-         for i in range(self.config.vision_config.num_hidden_layers + 2):
+         for _ in range(self.config.vision_config.num_hidden_layers + 2):
              vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
          if pooler_out_size is not None:
              vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=torch.float32, device="cpu"))
@@ -353,23 +355,32 @@ class RBLNLlavaForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):

          if hasattr(self.rbln_config.vision_tower, "max_image_size"):
              num_real_patches = selected_image_feature.shape[1]
-             max_patches = (
-                 (self.rbln_config.vision_tower.max_image_size[0] // self.config.vision_config.patch_size)
-                 * (self.rbln_config.vision_tower.max_image_size[1] // self.config.vision_config.patch_size)
-                 * pixel_values.shape[0]
+             max_patches = (self.rbln_config.vision_tower.max_image_size[0] // self.config.vision_config.patch_size) * (
+                 self.rbln_config.vision_tower.max_image_size[1] // self.config.vision_config.patch_size
              )
-             num_padding_patches = max_patches - num_real_patches

-             projector_out_size = [1, max_patches, self.config.text_config.hidden_size]
+             chunks = []
+             for i in range(0, num_real_patches, max_patches):
+                 chunk = selected_image_feature[:, i : i + max_patches, :]
+                 chunk_size = chunk.shape[1]
+
+                 if chunk_size < max_patches:
+                     padding_tensor = torch.zeros(
+                         (selected_image_feature.shape[0], max_patches - chunk_size, selected_image_feature.shape[2]),
+                         dtype=selected_image_feature.dtype,
+                     )
+                     chunk = torch.cat([chunk, padding_tensor], dim=1)
+                 chunks.append(chunk)
+
+             split_features = torch.cat(chunks, dim=0)
+             num_chunks = len(chunks)
+             projector_out_size = [1, max_patches * num_chunks, self.config.text_config.hidden_size]
              projector_out_buffer = [torch.empty(size=projector_out_size, dtype=torch.float32, device="cpu")]
-
-             padding_tensor = torch.zeros(
-                 (selected_image_feature.shape[0], num_padding_patches, selected_image_feature.shape[2]),
-                 dtype=selected_image_feature.dtype,
+             projected_features = self.multi_modal_projector(split_features, out=projector_out_buffer)
+             projected_features = projected_features.view(
+                 selected_image_feature.shape[0], num_chunks * max_patches, self.config.text_config.hidden_size
              )
-             padded_feature = torch.cat([selected_image_feature, padding_tensor], dim=1)
-             padded_projected_feature = self.multi_modal_projector(padded_feature, out=projector_out_buffer)
-             image_features = padded_projected_feature[:, :num_real_patches, :]
+             image_features = projected_features[:, :num_real_patches, :]
          else:
              projector_out_size = [
                  pixel_values.shape[0] * pixel_values.shape[1],
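The new logic above splits an arbitrary number of image patches into fixed-size chunks, zero-pads the final partial chunk, runs the fixed-shape projector over the chunk batch, and strips the padding afterwards. A standalone toy reproduction of that shape bookkeeping, with a plain `Linear` standing in for the compiled projector and all sizes chosen only for illustration:

```python
# Toy reproduction of the chunk/pad/unpad scheme used above.
import torch

max_patches = 16                      # fixed shape the projector was compiled for
feature = torch.randn(1, 37, 8)       # (batch=1, num_real_patches=37, dim=8)
projector = torch.nn.Linear(8, 8)     # stand-in for the compiled projector

num_real = feature.shape[1]
chunks = []
for i in range(0, num_real, max_patches):
    chunk = feature[:, i : i + max_patches, :]
    if chunk.shape[1] < max_patches:  # zero-pad the final partial chunk
        pad = torch.zeros(1, max_patches - chunk.shape[1], 8)
        chunk = torch.cat([chunk, pad], dim=1)
    chunks.append(chunk)

batched = torch.cat(chunks, dim=0)    # (num_chunks, max_patches, dim)
projected = projector(batched)
projected = projected.view(1, len(chunks) * max_patches, -1)
features = projected[:, :num_real, :]  # drop the padding again
assert features.shape == (1, 37, 8)
```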
optimum/rbln/transformers/models/llava_next/modeling_llava_next.py
@@ -139,9 +139,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
          return True

      @classmethod
-     def get_pytorch_model(cls, *args, **kwargs):
-         model = super().get_pytorch_model(*args, **kwargs)
-
+     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
          with no_init_weights():
              model_cls_name = model.model.language_model.__class__.__name__
              causal_model_cls_name = model_cls_name.replace("Model", "ForCausalLM")
@@ -192,7 +190,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
          return self.language_model.get_input_embeddings()

      @classmethod
-     def wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+     def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
          return model.multi_modal_projector

      @classmethod
@@ -302,7 +300,7 @@ class RBLNLlavaNextForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
          ]
          pooler_out_size = [pixel_values.shape[0] * pixel_values.shape[1], self.config.vision_config.hidden_size]
          vision_out_buffer = []
-         for i in range(self.config.vision_config.num_hidden_layers + 2):
+         for _ in range(self.config.vision_config.num_hidden_layers + 2):
              vision_out_buffer.append(torch.empty(size=vision_out_size, dtype=torch.float32, device="cpu"))
          vision_out_buffer.insert(1, torch.empty(size=pooler_out_size, dtype=torch.float32, device="cpu"))

optimum/rbln/transformers/models/midm/midm_architecture.py
@@ -71,6 +71,12 @@ class MidmLMHeadModelWrapper(DecoderOnlyWrapper):


  class MidmModel(DecoderOnlyModel):
+     def __init__(self, model, layers, rbln_config, use_learned_pos_emb=None, use_rotary_emb=True):
+         super().__init__(
+             model, layers, rbln_config, use_learned_pos_emb=use_learned_pos_emb, use_rotary_emb=use_rotary_emb
+         )
+         self.use_layernorm1p = getattr(model, "use_layernorm1p", False)
+
      def get_layernorm1p(self, module: nn.LayerNorm):
          def layernorm1p(input: torch.Tensor):
              """Applies Layer Normalization with a slight modification on the weights."""
@@ -81,19 +87,22 @@ class MidmModel(DecoderOnlyModel):
          return layernorm1p

      def get_last_layernorm(self) -> nn.LayerNorm:
-         if self._original_mod.use_layernorm1p:
-             return self.get_layernorm1p(self._original_mod.ln_f)
-         else:
-             return self._original_mod.ln_f
+         if self.use_layernorm1p:
+             return self.get_layernorm1p(self.norm)
+         return self.norm

      def get_embedding(self) -> nn.Embedding:
-         return self._original_mod.wte
+         return self.embed_tokens

      def get_pos_embedding(self) -> nn.Embedding:
-         return self._original_mod.wpe
+         return self.embed_positions


  class MidmLayer(DecoderOnlyLayer):
+     def __init__(self, layer, self_attn: DecoderOnlyAttention, lora_config=None):
+         super().__init__(layer, self_attn, lora_config)
+         self.use_layernorm1p = getattr(layer, "use_layernorm1p", False)
+
      def get_layernorm1p(self, module: nn.LayerNorm):
          def layernorm1p(input: torch.Tensor):
              """Applies Layer Normalization with a slight modification on the weights."""
@@ -104,24 +113,22 @@ class MidmLayer(DecoderOnlyLayer):
          return layernorm1p

      def get_pre_attention_layernorm(self) -> nn.LayerNorm:
-         if self._original_mod.use_layernorm1p:
-             return self.get_layernorm1p(self._original_mod.ln_1)
-         else:
-             return self._original_mod.ln_1
+         if self.use_layernorm1p:
+             return self.get_layernorm1p(self.pre_attention_layernorm)
+         return self.pre_attention_layernorm

      def get_post_attention_layernorm(self) -> nn.LayerNorm:
-         if self._original_mod.use_layernorm1p:
-             return self.get_layernorm1p(self._original_mod.ln_2)
-         else:
-             return self._original_mod.ln_2
+         if self.use_layernorm1p:
+             return self.get_layernorm1p(self.post_attention_layernorm)
+         return self.post_attention_layernorm


  class MidmAttention(DecoderOnlyAttention):
-     def __post_init__(self):
-         self.c_attn = self._original_mod.c_attn
-         self.o_proj = self._original_mod.c_proj
-         self.split_size = self._original_mod.split_size
-         self.num_key_value_heads = self._original_mod.num_heads
+     def __post_init__(self, self_attn):
+         self.c_attn = self_attn.c_attn
+         self.o_proj = self_attn.c_proj
+         self.split_size = self_attn.split_size
+         self.num_key_value_heads = self_attn.num_heads

      def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
          if lora_int_id is not None:
@@ -130,12 +137,12 @@ class MidmAttention(DecoderOnlyAttention):
              query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
          return query_states, key_states, value_states

-     def get_attn_scale(self):
+     def get_attn_scale(self, self_attn):
          scale = 1.0
-         if self._original_mod.scale_attn_weights:
+         if self_attn.scale_attn_weights:
              scale /= math.sqrt(self.head_dim)

-         if self._original_mod.scale_attn_by_inverse_layer_idx and not self._original_mod.scale_qk_by_inverse_layer_idx:
+         if self_attn.scale_attn_by_inverse_layer_idx:
              scale /= 1 + self.layer_idx

          return scale
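The `layernorm1p` helper these hunks keep referring to is, in the common Megatron-style formulation, a standard layer norm whose learned scale is stored as an offset from one. A sketch under that assumption; the authoritative definition lives in the wrapped Midm module:

```python
# Assumed "layernorm1p" formulation: standard layer norm, but the learned
# weight is applied as (weight + 1), Megatron-style.
import torch
import torch.nn.functional as F


def layernorm1p(module: torch.nn.LayerNorm, input: torch.Tensor) -> torch.Tensor:
    return F.layer_norm(
        input, module.normalized_shape, module.weight + 1, module.bias, module.eps
    )
```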
optimum/rbln/transformers/models/mistral/modeling_mistral.py
@@ -12,13 +12,11 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from transformers import PretrainedConfig

  from ....utils import logging
  from ...models.decoderonly import (
      RBLNDecoderOnlyModel,
      RBLNDecoderOnlyModelForCausalLM,
-     RBLNDecoderOnlyModelForCausalLMConfig,
  )
  from .mistral_architecture import MistralWrapper

@@ -85,16 +83,6 @@ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):

      _decoder_wrapper_cls = MistralWrapper

-     @classmethod
-     def _update_sliding_window_config(
-         cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
-     ):
-         rbln_config.cache_impl = "sliding_window"
-         rbln_config.sliding_window = model_config.sliding_window
-         rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
-
-         return rbln_config
-

  class RBLNMistralModel(RBLNDecoderOnlyModel):
      """
@@ -103,13 +91,3 @@ class RBLNMistralModel(RBLNDecoderOnlyModel):
      """

      _decoder_wrapper_cls = MistralWrapper
-
-     @classmethod
-     def _update_sliding_window_config(
-         cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
-     ):
-         rbln_config.cache_impl = "sliding_window"
-         rbln_config.sliding_window = model_config.sliding_window
-         rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
-
-         return rbln_config
optimum/rbln/transformers/models/opt/modeling_opt.py
@@ -69,7 +69,7 @@ class RBLNOPTForCausalLM(RBLNDecoderOnlyModelForCausalLM):
          return layer

      @classmethod
-     def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+     def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
          for i in range(len(model.model.decoder.layers)):
              model.model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.model.decoder.layers[i])

@@ -95,7 +95,7 @@ class RBLNOPTModel(RBLNDecoderOnlyModel):
          return layer

      @classmethod
-     def wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+     def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
          for i in range(len(model.decoder.layers)):
              model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.decoder.layers[i])

optimum/rbln/transformers/models/opt/opt_architecture.py
@@ -14,14 +14,7 @@

  from typing import TYPE_CHECKING

- import torch.nn as nn
-
- from ...models.decoderonly.decoderonly_architecture import (
-     DecoderOnlyAttention,
-     DecoderOnlyLayer,
-     DecoderOnlyModel,
-     DecoderOnlyWrapper,
- )
+ from ...models.decoderonly.decoderonly_architecture import DecoderOnlyWrapper


  if TYPE_CHECKING:
@@ -31,44 +24,8 @@ if TYPE_CHECKING:
  class OPTWrapper(DecoderOnlyWrapper):
      _use_learned_pos_emb = True

-     def get_rbln_attn_class(self):
-         return OPTAttention
-
-     def get_rbln_layer_class(self):
-         return OPTDecoderLayer
-
-     def get_rbln_model_class(self):
-         return OPTModel
-
      def get_model_layer(self, model: "OPTForCausalLM"):
          return model.model.decoder if self.is_causal_lm else model.decoder

      def get_decoder_layers(self, model: "OPTForCausalLM"):
          return model.model.decoder.layers if self.is_causal_lm else model.decoder.layers
-
-
- class OPTAttention(DecoderOnlyAttention):
-     def __post_init__(self):
-         self.k_proj = self._original_mod.k_proj
-         self.v_proj = self._original_mod.v_proj
-         self.q_proj = self._original_mod.q_proj
-         self.o_proj = self._original_mod.out_proj
-
-
- class OPTModel(DecoderOnlyModel):
-     def get_embedding(self) -> nn.Embedding:
-         return self._original_mod.embed_tokens
-
-     def get_pos_embedding(self):
-         return self._original_mod.embed_positions
-
-     def get_last_layernorm(self) -> nn.LayerNorm:
-         return self._original_mod.final_layer_norm
-
-
- class OPTDecoderLayer(DecoderOnlyLayer):
-     def get_pre_attention_layernorm(self) -> nn.LayerNorm:
-         return self._original_mod.self_attn_layer_norm
-
-     def get_post_attention_layernorm(self) -> nn.LayerNorm:
-         return self._original_mod.final_layer_norm
optimum/rbln/transformers/models/paligemma/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_paligemma import RBLNPaliGemmaForConditionalGenerationConfig, RBLNPaliGemmaModelConfig
+ from .modeling_paligemma import RBLNPaliGemmaForConditionalGeneration, RBLNPaliGemmaModel
optimum/rbln/transformers/models/paligemma/configuration_paligemma.py
@@ -0,0 +1,129 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Optional
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....utils.logging import get_logger
+
+
+ logger = get_logger(__name__)
+
+
+ class RBLNPaliGemmaForConditionalGenerationConfig(RBLNModelConfig):
+     """
+     Configuration class for RBLNPaliGemmaForConditionalGenerationConfig.
+     This configuration class stores the configuration parameters specific to
+     RBLN-optimized PaliGemma models for multimodal conditional generation tasks
+     that combine vision and language processing capabilities.
+     """
+
+     submodules = ["vision_tower", "language_model"]
+     _allow_no_compile_cfgs = True
+
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         vision_tower: Optional[RBLNModelConfig] = None,
+         language_model: Optional[RBLNModelConfig] = None,
+         output_hidden_states: Optional[bool] = None,
+         **kwargs: Any,
+     ):
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+             vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+                 This can include settings specific to the vision encoder, such as input resolution or other vision-related parameters.
+                 If not provided, default settings will be used.
+             language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
+                 This can include settings specific to the language model, such as tensor parallelism or other text-related parameters.
+                 If not provided, default settings will be used.
+             output_hidden_states (Optional[bool]): Whether to output the hidden states of the decoder. Defaults to False.
+             kwargs: Additional arguments passed to the parent RBLNModelConfig.
+         Raises:
+             ValueError: If `batch_size` is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         if self.batch_size != 1:
+             logger.warning("Ignore batch_size for PaliGemma vision tower. It will be set to 1.")
+
+         self.output_hidden_states = output_hidden_states or False
+
+         self.vision_tower = self.initialize_submodule_config(
+             submodule_config=vision_tower,
+             batch_size=1,  # vision_tower batch_size is always 1 in PaliGemma
+             force_kwargs=True,
+         )
+         self.language_model = self.initialize_submodule_config(
+             submodule_config=language_model,
+             batch_size=batch_size,
+             use_position_ids=True,
+             use_attention_mask=True,
+             use_inputs_embeds=True,
+         )
+
+
+ class RBLNPaliGemmaModelConfig(RBLNModelConfig):
+     submodules = ["vision_tower", "language_model"]
+     _allow_no_compile_cfgs = True
+
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         vision_tower: Optional[RBLNModelConfig] = None,
+         language_model: Optional[RBLNModelConfig] = None,
+         output_hidden_states: Optional[bool] = None,
+         **kwargs: Any,
+     ):
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for inference. Defaults to 1.
+             vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+                 This can include settings specific to the vision encoder, such as input resolution or other vision-related parameters.
+                 If not provided, default settings will be used.
+             language_model (Optional[RBLNModelConfig]): Configuration for the language model component.
+                 This can include settings specific to the language model, such as tensor parallelism or other text-related parameters.
+                 If not provided, default settings will be used.
+             output_hidden_states (Optional[bool]): Whether to output the hidden states of the decoder. Defaults to False.
+             kwargs: Additional arguments passed to the parent RBLNModelConfig.
+         Raises:
+             ValueError: If `batch_size` is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         if self.batch_size != 1:
+             logger.warning("Ignore batch_size for PaliGemma vision tower. It will be set to 1.")
+
+         self.output_hidden_states = output_hidden_states or False
+         self.vision_tower = self.initialize_submodule_config(
+             submodule_config=vision_tower,
+             batch_size=1,  # vision_tower batch_size is always 1 in PaliGemma
+             force_kwargs=True,
+         )
+
+         self.language_model = self.initialize_submodule_config(
+             submodule_config=language_model,
+             batch_size=batch_size,
+             use_position_ids=True,
+             use_attention_mask=True,
+             use_inputs_embeds=True,
+             output_hidden_states=output_hidden_states,
+         )
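A rough sketch of how these new configuration classes might be combined with the RBLNPaliGemmaForConditionalGeneration class exported alongside them; the checkpoint id, tensor-parallel value, and dict-style submodule config are illustrative assumptions, not verified against this release:

```python
# Hypothetical usage of the PaliGemma classes added in this diff.
from optimum.rbln import (
    RBLNPaliGemmaForConditionalGeneration,
    RBLNPaliGemmaForConditionalGenerationConfig,
)

rbln_config = RBLNPaliGemmaForConditionalGenerationConfig(
    batch_size=1,
    language_model={"tensor_parallel_size": 4},  # assumed dict-style submodule config
)

model = RBLNPaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma-3b-mix-224",  # illustrative checkpoint id
    export=True,
    rbln_config=rbln_config,
)
```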