optimum-rbln 0.8.2a0__py3-none-any.whl → 0.9.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (197)
  1. optimum/rbln/__init__.py +116 -9
  2. optimum/rbln/__version__.py +16 -3
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +171 -43
  5. optimum/rbln/diffusers/__init__.py +19 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +3 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +3 -3
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +1 -1
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +3 -3
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +4 -4
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +12 -4
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +9 -4
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +3 -3
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +3 -3
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +3 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +35 -19
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +14 -11
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +30 -20
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +13 -9
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +17 -13
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +17 -10
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +33 -18
  27. optimum/rbln/diffusers/models/__init__.py +4 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +32 -3
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +32 -6
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +27 -8
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +32 -3
  34. optimum/rbln/diffusers/models/controlnet.py +16 -1
  35. optimum/rbln/diffusers/models/transformers/prior_transformer.py +17 -3
  36. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +26 -3
  37. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +23 -2
  38. optimum/rbln/diffusers/models/unets/__init__.py +1 -0
  39. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +23 -4
  40. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  41. optimum/rbln/diffusers/pipelines/__init__.py +15 -5
  42. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  43. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +20 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +23 -12
  45. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +16 -46
  46. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +31 -1
  47. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +31 -1
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -6
  49. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  50. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  51. optimum/rbln/modeling.py +50 -24
  52. optimum/rbln/modeling_base.py +116 -35
  53. optimum/rbln/ops/attn.py +158 -0
  54. optimum/rbln/ops/flash_attn.py +166 -0
  55. optimum/rbln/ops/kv_cache_update.py +5 -0
  56. optimum/rbln/ops/linear.py +7 -0
  57. optimum/rbln/transformers/__init__.py +100 -0
  58. optimum/rbln/transformers/configuration_generic.py +7 -32
  59. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  60. optimum/rbln/transformers/modeling_generic.py +48 -65
  61. optimum/rbln/transformers/modeling_outputs.py +37 -0
  62. optimum/rbln/transformers/models/__init__.py +93 -30
  63. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +28 -2
  64. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +68 -5
  65. optimum/rbln/transformers/models/auto/__init__.py +2 -0
  66. optimum/rbln/transformers/models/auto/auto_factory.py +92 -17
  67. optimum/rbln/transformers/models/auto/modeling_auto.py +45 -0
  68. optimum/rbln/transformers/models/bart/bart_architecture.py +2 -7
  69. optimum/rbln/transformers/models/bart/configuration_bart.py +2 -0
  70. optimum/rbln/transformers/models/bart/modeling_bart.py +23 -2
  71. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  72. optimum/rbln/transformers/models/bert/modeling_bert.py +93 -4
  73. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +42 -11
  74. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +135 -44
  75. optimum/rbln/transformers/models/clip/configuration_clip.py +21 -7
  76. optimum/rbln/transformers/models/clip/modeling_clip.py +183 -27
  77. optimum/rbln/transformers/models/colpali/colpali_architecture.py +3 -6
  78. optimum/rbln/transformers/models/colpali/configuration_colpali.py +37 -21
  79. optimum/rbln/transformers/models/colpali/modeling_colpali.py +82 -104
  80. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  82. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  83. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  84. optimum/rbln/transformers/models/decoderonly/__init__.py +3 -2
  85. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +114 -37
  86. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  87. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +323 -316
  88. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  89. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  90. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  91. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +486 -892
  92. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  93. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  94. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  95. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +24 -0
  96. optimum/rbln/transformers/models/dpt/modeling_dpt.py +17 -0
  97. optimum/rbln/transformers/models/exaone/modeling_exaone.py +42 -4
  98. optimum/rbln/transformers/models/gemma/__init__.py +2 -2
  99. optimum/rbln/transformers/models/gemma/configuration_gemma.py +9 -1
  100. optimum/rbln/transformers/models/gemma/gemma_architecture.py +1 -4
  101. optimum/rbln/transformers/models/gemma/modeling_gemma.py +22 -1
  102. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +49 -14
  103. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +12 -2
  104. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  105. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +212 -504
  106. optimum/rbln/transformers/models/gpt2/__init__.py +2 -2
  107. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +31 -3
  108. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +10 -8
  109. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +18 -1
  110. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  111. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  112. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  113. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  114. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +35 -7
  115. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +29 -32
  116. optimum/rbln/transformers/models/llama/__init__.py +2 -2
  117. optimum/rbln/transformers/models/llama/configuration_llama.py +9 -1
  118. optimum/rbln/transformers/models/llama/modeling_llama.py +22 -1
  119. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  120. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  121. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  122. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +21 -6
  123. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +234 -376
  124. optimum/rbln/transformers/models/midm/midm_architecture.py +4 -1
  125. optimum/rbln/transformers/models/midm/modeling_midm.py +42 -4
  126. optimum/rbln/transformers/models/mistral/__init__.py +2 -2
  127. optimum/rbln/transformers/models/mistral/configuration_mistral.py +9 -1
  128. optimum/rbln/transformers/models/mistral/mistral_architecture.py +1 -1
  129. optimum/rbln/transformers/models/mistral/modeling_mistral.py +26 -3
  130. optimum/rbln/transformers/models/opt/__init__.py +2 -2
  131. optimum/rbln/transformers/models/opt/configuration_opt.py +8 -1
  132. optimum/rbln/transformers/models/opt/modeling_opt.py +29 -17
  133. optimum/rbln/transformers/models/opt/opt_architecture.py +4 -4
  134. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  135. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  136. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  137. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  138. optimum/rbln/transformers/models/phi/__init__.py +2 -2
  139. optimum/rbln/transformers/models/phi/configuration_phi.py +9 -1
  140. optimum/rbln/transformers/models/phi/modeling_phi.py +10 -1
  141. optimum/rbln/transformers/models/phi/phi_architecture.py +11 -7
  142. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  143. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  144. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  145. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  146. optimum/rbln/transformers/models/qwen2/__init__.py +2 -2
  147. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +9 -1
  148. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +27 -1
  149. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +21 -6
  150. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +15 -22
  151. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +28 -7
  152. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  153. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  154. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  155. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  156. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  158. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  159. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  160. optimum/rbln/transformers/models/resnet/configuration_resnet.py +17 -0
  161. optimum/rbln/transformers/models/resnet/modeling_resnet.py +73 -0
  162. optimum/rbln/transformers/models/roberta/modeling_roberta.py +33 -0
  163. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +21 -16
  164. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +60 -13
  165. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +2 -2
  166. optimum/rbln/transformers/models/siglip/__init__.py +2 -6
  167. optimum/rbln/transformers/models/siglip/configuration_siglip.py +1 -1
  168. optimum/rbln/transformers/models/siglip/modeling_siglip.py +21 -16
  169. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  170. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  171. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  172. optimum/rbln/transformers/models/t5/configuration_t5.py +2 -0
  173. optimum/rbln/transformers/models/t5/modeling_t5.py +2 -2
  174. optimum/rbln/transformers/models/t5/t5_architecture.py +8 -1
  175. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +3 -3
  176. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +22 -16
  177. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +7 -1
  178. optimum/rbln/transformers/models/vit/modeling_vit.py +19 -0
  179. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +15 -3
  180. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +61 -8
  181. optimum/rbln/transformers/models/whisper/configuration_whisper.py +12 -13
  182. optimum/rbln/transformers/models/whisper/generation_whisper.py +62 -6
  183. optimum/rbln/transformers/models/whisper/modeling_whisper.py +32 -5
  184. optimum/rbln/transformers/models/xlm_roberta/__init__.py +2 -8
  185. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +43 -0
  186. optimum/rbln/transformers/utils/rbln_quantization.py +400 -75
  187. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  188. optimum/rbln/utils/deprecation.py +213 -0
  189. optimum/rbln/utils/hub.py +22 -50
  190. optimum/rbln/utils/runtime_utils.py +85 -17
  191. optimum/rbln/utils/submodule.py +31 -9
  192. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/METADATA +8 -7
  193. optimum_rbln-0.9.3.dist-info/RECORD +264 -0
  194. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/WHEEL +1 -1
  195. optimum_rbln-0.9.3.dist-info/entry_points.txt +2 -0
  196. optimum_rbln-0.8.2a0.dist-info/RECORD +0 -211
  197. {optimum_rbln-0.8.2a0.dist-info → optimum_rbln-0.9.3.dist-info}/licenses/LICENSE +0 -0
@@ -20,7 +20,9 @@ import rebel
20
20
  import torch
21
21
  from rebel.compile_context import CompileContext
22
22
  from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
23
- from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
23
+ from transformers.generation.configuration_utils import GenerationConfig
24
+ from transformers.generation.utils import GenerationMixin
25
+ from transformers.modeling_outputs import BaseModelOutput, ModelOutput, Seq2SeqLMOutput
24
26
 
25
27
  from ....configuration_utils import RBLNCompileConfig
26
28
  from ....modeling import RBLNModel
@@ -32,13 +34,13 @@ from .configuration_seq2seq import RBLNModelForSeq2SeqLMConfig
32
34
  logger = get_logger(__name__)
33
35
 
34
36
  if TYPE_CHECKING:
35
- from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, GenerationConfig, PretrainedConfig
37
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
36
38
 
37
39
 
38
40
  class RBLNRuntimeEncoder(RBLNPytorchRuntime):
39
41
  mandatory_members = ["main_input_name"]
40
42
 
41
- def forward(self, *args: List[torch.Tensor], **kwargs: Dict[str, torch.Tensor]):
43
+ def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
42
44
  output = super().forward(*args, **kwargs)
43
45
  return BaseModelOutput(last_hidden_state=output)
44
46
 
@@ -83,7 +85,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
83
85
  decoding_step = cache_position[b_idx].item()
84
86
  if not (0 <= decoding_step < self.dec_max_seq_len):
85
87
  raise ValueError(
86
- f"Decoding step {decoding_step} out of bounds for attention mask with shape {self.dec_attn_mask.shape}."
88
+ f"Decoding step {decoding_step} out of bounds for decoder_max_seq_len ({self.dec_max_seq_len})."
87
89
  )
88
90
  decoder_attention_mask[b_idx, : decoding_step + 1] = 1
89
91
 
@@ -101,7 +103,7 @@ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
101
103
  return Seq2SeqLMOutput(logits=lm_logits)
102
104
 
103
105
 
104
- class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
106
+ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
105
107
  """
106
108
  This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
107
109
  This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
@@ -117,6 +119,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
117
119
  main_input_name = "input_ids"
118
120
  auto_model_class = AutoModelForSeq2SeqLM
119
121
  support_causal_attn = None
122
+ _is_stateful = False
120
123
 
121
124
  def __post_init__(self, **kwargs):
122
125
  batch_size = self.rbln_config.batch_size
@@ -138,7 +141,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
138
141
  @classmethod
139
142
  @torch.inference_mode()
140
143
  def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
141
- wrapped_model = cls.wrap_model_if_needed(model, rbln_config)
144
+ wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
142
145
 
143
146
  enc_compile_config = rbln_config.compile_cfgs[0]
144
147
  dec_compile_config = rbln_config.compile_cfgs[1]
@@ -181,6 +184,21 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
181
184
 
182
185
  return {"encoder": compiled_encoder, "decoder": compiled_decoder}
183
186
 
187
+ @classmethod
188
+ def _update_paged_attention_config(cls, model_config: PretrainedConfig, rbln_config: RBLNModelForSeq2SeqLMConfig):
189
+ rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
190
+ rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
191
+
192
+ if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
193
+ raise NotImplementedError(
194
+ f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
195
+ )
196
+
197
+ if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
198
+ raise NotImplementedError(
199
+ f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
200
+ )
201
+
184
202
  @classmethod
185
203
  def _update_rbln_config(
186
204
  cls,
@@ -204,12 +222,6 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
204
222
  model_config, "max_position_embeddings", None
205
223
  )
206
224
 
207
- pad_token_id = getattr(model_config, "pad_token_id", None)
208
- pad_token_id = pad_token_id or getattr(model_config, "bos_token_id", None)
209
- pad_token_id = pad_token_id or getattr(model_config, "eos_token_id", None)
210
- pad_token_id = pad_token_id or -1
211
- rbln_config.pad_token_id = pad_token_id
212
-
213
225
  if rbln_config.enc_max_seq_len is None:
214
226
  enc_max_seq_len = max_position_embeddings
215
227
  for tokenizer in preprocessors:
@@ -238,6 +250,9 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
238
250
  if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
239
251
  raise ValueError("`dec_max_seq_len` should be less or equal than max_position_embeddings!")
240
252
 
253
+ if rbln_config.support_paged_attention:
254
+ cls._update_paged_attention_config(model_config, rbln_config)
255
+
241
256
  # model input info
242
257
  enc_input_info = [
243
258
  ("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
@@ -310,6 +325,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
310
325
  dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
311
326
 
312
327
  rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
328
+
313
329
  return rbln_config
314
330
 
315
331
  @classmethod
@@ -327,12 +343,14 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
327
343
  tensor_type="pt",
328
344
  device=rbln_config.device_map["encoder"],
329
345
  activate_profiler=rbln_config.activate_profiler,
346
+ timeout=rbln_config.timeout,
330
347
  ),
331
348
  rebel.Runtime(
332
349
  compiled_models[1],
333
350
  tensor_type="pt",
334
351
  device=rbln_config.device_map["decoder"],
335
352
  activate_profiler=rbln_config.activate_profiler,
353
+ timeout=rbln_config.timeout,
336
354
  ),
337
355
  ]
338
356
 
@@ -409,7 +427,7 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
409
427
  inputs_tensor = torch.nn.functional.pad(
410
428
  inputs_tensor,
411
429
  (0, self.rbln_config.enc_max_seq_len - input_len),
412
- value=self.rbln_config.pad_token_id,
430
+ value=self.config.pad_token_id,
413
431
  )
414
432
  model_kwargs["attention_mask"] = torch.nn.functional.pad(
415
433
  model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
@@ -428,3 +446,32 @@ class RBLNModelForSeq2SeqLM(RBLNModel, ABC):
428
446
  model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, block_tables=block_tables)
429
447
 
430
448
  return model_kwargs
449
+
450
+ def generate(
451
+ self,
452
+ input_ids: torch.LongTensor,
453
+ attention_mask: Optional[torch.LongTensor] = None,
454
+ generation_config: Optional[GenerationConfig] = None,
455
+ **kwargs,
456
+ ) -> Union[ModelOutput, torch.LongTensor]:
457
+ """
458
+ The generate function is utilized in its standard form as in the HuggingFace transformers library. User can use this function to generate text from the model.
459
+ Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) for more details.
460
+
461
+ Args:
462
+ input_ids (torch.LongTensor): The input ids to the model.
463
+ attention_mask (torch.LongTensor, optional): The attention mask to the model.
464
+ generation_config (GenerationConfig, optional): The generation configuration to be used as base parametrization for the generation call. **kwargs passed to generate matching the attributes of generation_config will override them.
465
+ If generation_config is not provided, the default will be used, which had the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
466
+ Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
467
+ kwargs (dict[str, Any], optional): Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
468
+
469
+ Returns:
470
+ Generates sequences of token ids for models with a language modeling head.
471
+ """
472
+ if generation_config is not None:
473
+ kwargs["generation_config"] = generation_config
474
+ if attention_mask is not None:
475
+ kwargs["attention_mask"] = attention_mask
476
+
477
+ return super().generate(input_ids, **kwargs)
@@ -31,7 +31,7 @@ class Seq2SeqWrapper:
31
31
  Args:
32
32
  model (nn.Module): The Seq2Seq model to wrap.
33
33
  enc_max_seq_len (int): Maximum sequence length for the encoder's position embeddings and cache sizes.
34
- **kwargs: Additional arguments to pass to the decoder wrapper.
34
+ kwargs: Additional arguments to pass to the decoder wrapper.
35
35
  """
36
36
 
37
37
  def __init__(self, model: nn.Module, enc_max_seq_len: int, **kwargs):
@@ -125,7 +125,7 @@ class Seq2SeqDecoderWrapper(nn.Module):
125
125
 
126
126
  Args:
127
127
  model (nn.Module): The Seq2Seq model containing the decoder.
128
- **kwargs: Additional arguments for decoder configuration.
128
+ kwargs: Additional arguments for decoder configuration.
129
129
  """
130
130
 
131
131
  def __init__(self, model: nn.Module, use_attention_mask: bool = True, **kwargs):
@@ -12,9 +12,5 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from .configuration_siglip import (
16
- RBLNSiglipVisionModelConfig,
17
- )
18
- from .modeling_siglip import (
19
- RBLNSiglipVisionModel,
20
- )
15
+ from .configuration_siglip import RBLNSiglipVisionModelConfig
16
+ from .modeling_siglip import RBLNSiglipVisionModel
@@ -42,7 +42,7 @@ class RBLNSiglipVisionModelConfig(RBLNModelConfig):
42
42
  interpolate_pos_encoding (Optional[bool]): Whether to interpolate the position encoding.
43
43
  output_hidden_states: (Optional[bool]): Whether to return hidden states.
44
44
  output_attentions: (Optional[bool]): Whether to return attentions.
45
- **kwargs: Additional arguments passed to the parent RBLNModelConfig.
45
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
46
46
 
47
47
  Raises:
48
48
  ValueError: If batch_size is not a positive integer.
@@ -12,7 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
15
+ from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
16
16
 
17
17
  import torch
18
18
  from transformers import SiglipVisionConfig, SiglipVisionModel
@@ -29,8 +29,6 @@ logger = get_logger(__name__)
29
29
  if TYPE_CHECKING:
30
30
  from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
31
31
 
32
- from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
33
-
34
32
 
35
33
  class _SiglipVisionModel(torch.nn.Module):
36
34
  def __init__(
@@ -65,8 +63,12 @@ class RBLNSiglipVisionModel(RBLNModel):
65
63
  on RBLN devices, supporting image encoding for multimodal vision-language tasks.
66
64
  """
67
65
 
66
+ _tp_support = False
67
+
68
68
  @classmethod
69
- def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig) -> torch.nn.Module:
69
+ def _wrap_model_if_needed(
70
+ cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig
71
+ ) -> torch.nn.Module:
70
72
  wrapper_cfg = {
71
73
  "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
72
74
  "output_hidden_states": rbln_config.output_hidden_states,
@@ -74,12 +76,6 @@ class RBLNSiglipVisionModel(RBLNModel):
74
76
  }
75
77
  return _SiglipVisionModel(model, **wrapper_cfg).eval()
76
78
 
77
- @classmethod
78
- def update_rbln_config_using_pipe(
79
- cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
80
- ) -> "RBLNDiffusionMixinConfig":
81
- return rbln_config
82
-
83
79
  @classmethod
84
80
  def _update_rbln_config(
85
81
  cls,
@@ -126,12 +122,21 @@ class RBLNSiglipVisionModel(RBLNModel):
126
122
  output_attentions: bool = None,
127
123
  output_hidden_states: bool = None,
128
124
  interpolate_pos_encoding: bool = False,
129
- **kwargs: Dict[str, Any],
125
+ **kwargs: Any,
130
126
  ) -> Union[Tuple, BaseModelOutputWithPooling]:
131
- if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
132
- logger.warning(
133
- f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
134
- )
127
+ """
128
+ Forward pass for the RBLN-optimized SigLIP vision model.
129
+
130
+ Args:
131
+ pixel_values (torch.FloatTensor of shape (batch_size, num_channels, image_size, image_size), optional): The tensors corresponding to the input images. Pixel values can be obtained using ViTImageProcessor. See ViTImageProcessor.call() for details (processor_class uses ViTImageProcessor for processing images).
132
+ return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
133
+ output_attentions (bool, optional): Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
134
+ output_hidden_states (bool, optional): Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
135
+ interpolate_pos_encoding (bool, defaults to False): Whether to interpolate the pre-trained position encodings.
136
+
137
+ Returns:
138
+ The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
139
+ """
135
140
 
136
141
  output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
137
142
  output_hidden_states = (
@@ -156,7 +161,7 @@ class RBLNSiglipVisionModel(RBLNModel):
156
161
  f"Please compile again with the correct argument."
157
162
  )
158
163
 
159
- output = super().forward(pixel_values, return_dict=return_dict)
164
+ output = super().forward(pixel_values, return_dict=return_dict, **kwargs)
160
165
  return output
161
166
 
162
167
  def _prepare_output(self, output, return_dict):
@@ -0,0 +1,16 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .configuration_swin import RBLNSwinBackboneConfig
16
+ from .modeling_swin import RBLNSwinBackbone
@@ -0,0 +1,42 @@
1
+ # Licensed under the Apache License, Version 2.0 (the "License");
2
+ # you may not use this file except in compliance with the License.
3
+ # You may obtain a copy of the License at:
4
+
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+
7
+ # Unless required by applicable law or agreed to in writing, software
8
+ # distributed under the License is distributed on an "AS IS" BASIS,
9
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10
+ # See the License for the specific language governing permissions and
11
+ # limitations under the License.
12
+
13
+ from typing import Any, Optional, Tuple, Union
14
+
15
+ from ...configuration_generic import RBLNModelForImageClassificationConfig
16
+
17
+
18
+ class RBLNSwinBackboneConfig(RBLNModelForImageClassificationConfig):
19
+ def __init__(
20
+ self,
21
+ image_size: Optional[Union[int, Tuple[int, int]]] = None,
22
+ batch_size: Optional[int] = None,
23
+ output_hidden_states: Optional[bool] = None,
24
+ output_attentions: Optional[bool] = None,
25
+ **kwargs: Any,
26
+ ):
27
+ """
28
+ Args:
29
+ batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
30
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
31
+
32
+ Raises:
33
+ ValueError: If batch_size is not a positive integer.
34
+ """
35
+ super().__init__(**kwargs)
36
+ self.batch_size = batch_size or 1
37
+ if not isinstance(self.batch_size, int) or self.batch_size < 0:
38
+ raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
39
+
40
+ self.image_size = image_size
41
+ self.output_hidden_states = output_hidden_states
42
+ self.output_attentions = output_attentions