optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. optimum/rbln/__init__.py +24 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +45 -33
  4. optimum/rbln/diffusers/__init__.py +21 -1
  5. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  13. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  14. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  15. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
  22. optimum/rbln/diffusers/modeling_diffusers.py +72 -65
  23. optimum/rbln/diffusers/models/__init__.py +4 -0
  24. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  25. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
  26. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  27. optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
  28. optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
  29. optimum/rbln/diffusers/models/controlnet.py +14 -8
  30. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  31. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  32. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  33. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  34. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
  35. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  36. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  37. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  38. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  39. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  42. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  43. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  45. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  46. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  47. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  49. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  50. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  51. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  52. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  53. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  54. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  55. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  56. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  57. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  58. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  59. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  60. optimum/rbln/modeling.py +71 -37
  61. optimum/rbln/modeling_base.py +63 -109
  62. optimum/rbln/transformers/__init__.py +41 -47
  63. optimum/rbln/transformers/configuration_generic.py +16 -13
  64. optimum/rbln/transformers/modeling_generic.py +21 -22
  65. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  66. optimum/rbln/transformers/models/__init__.py +54 -4
  67. optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
  68. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  69. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  70. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  71. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  72. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  73. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  74. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  75. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  76. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
  77. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
  78. optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
  79. optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
  80. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
  82. optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
  83. optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
  84. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  85. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  86. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
  87. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  88. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  89. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
  90. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  91. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  92. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  93. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  94. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  95. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  96. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  97. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  98. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  99. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  100. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
  101. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  102. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  103. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  104. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  105. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  106. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  107. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
  108. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  109. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  110. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  111. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  112. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  113. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  114. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  115. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  116. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  117. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  118. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  119. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  120. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  121. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  122. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
  123. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
  124. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
  125. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  126. optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
  127. optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
  128. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  129. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
  130. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
  131. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  132. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  133. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
  134. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  135. optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
  136. optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
  137. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  138. optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
  139. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  140. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  141. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
  142. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
  143. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  144. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  145. optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
  146. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  147. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
  148. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  149. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
  150. optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
  151. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  152. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  153. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  154. optimum/rbln/utils/model_utils.py +20 -0
  155. optimum/rbln/utils/runtime_utils.py +49 -1
  156. optimum/rbln/utils/submodule.py +6 -8
  157. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
  158. optimum_rbln-0.8.1.dist-info/RECORD +211 -0
  159. optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
  160. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  161. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
  162. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/models/siglip/configuration_siglip.py

@@ -18,12 +18,20 @@ from ....configuration_utils import RBLNModelConfig
 
 
 class RBLNSiglipVisionModelConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNSiglipVisionModel.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized SigLIP vision models for image encoding in multimodal tasks.
+    """
+
     def __init__(
         self,
         batch_size: Optional[int] = None,
         image_size: Optional[int] = None,
         interpolate_pos_encoding: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
         **kwargs,
     ):
         """
@@ -33,6 +41,7 @@ class RBLNSiglipVisionModelConfig(RBLNModelConfig):
                 a tuple/list (height, width), or a dictionary with 'height' and 'width' keys.
             interpolate_pos_encoding (Optional[bool]): Whether to interpolate the position encoding.
             output_hidden_states: (Optional[bool]): Whether to return hidden states.
+            output_attentions: (Optional[bool]): Whether to return attentions.
             **kwargs: Additional arguments passed to the parent RBLNModelConfig.
 
         Raises:
@@ -46,6 +55,7 @@ class RBLNSiglipVisionModelConfig(RBLNModelConfig):
         self.image_size = image_size
         self.interpolate_pos_encoding = interpolate_pos_encoding or False
         self.output_hidden_states = output_hidden_states
+        self.output_attentions = output_attentions
 
     @property
     def image_width(self):
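
Illustrative usage, not taken from the diff: the new `output_attentions` field would presumably be set at compile time through the `rbln_<field>` keyword pattern used elsewhere in this package (see the T5 docstring examples later in this diff). The checkpoint name below is a placeholder.

```python
from optimum.rbln import RBLNSiglipVisionModel

# Hedged sketch: the checkpoint and the rbln_<field> keyword spelling are assumptions;
# only the config fields themselves (output_attentions, output_hidden_states) come from the diff.
model = RBLNSiglipVisionModel.from_pretrained(
    "google/siglip-base-patch16-224",  # placeholder checkpoint
    export=True,
    rbln_output_attentions=True,
    rbln_output_hidden_states=True,
)
model.save_pretrained("compiled-siglip-vision")
```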

optimum/rbln/transformers/models/siglip/modeling_siglip.py

@@ -12,12 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
 
 import torch
 from transformers import SiglipVisionConfig, SiglipVisionModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
-from transformers.models.siglip.modeling_siglip import SiglipVisionModelOutput
 
 from ....configuration_utils import RBLNCompileConfig
 from ....modeling import RBLNModel
@@ -34,11 +33,18 @@ if TYPE_CHECKING:
 
 
 class _SiglipVisionModel(torch.nn.Module):
-    def __init__(self, model: SiglipVisionModel, interpolate_pos_encoding: bool, output_hidden_states: bool):
+    def __init__(
+        self,
+        model: SiglipVisionModel,
+        interpolate_pos_encoding: bool,
+        output_hidden_states: bool,
+        output_attentions: bool,
+    ):
         super().__init__()
         self.vision_model = model.vision_model
         self.interpolate_pos_encoding = interpolate_pos_encoding
         self.output_hidden_states = output_hidden_states
+        self.output_attentions = output_attentions
 
     def forward(self, inp):
         enc_out = self.vision_model(
@@ -46,16 +52,25 @@ class _SiglipVisionModel(torch.nn.Module):
             output_hidden_states=self.output_hidden_states,
             return_dict=False,
             interpolate_pos_encoding=self.interpolate_pos_encoding,
+            output_attentions=self.output_attentions,
         )
         return tuple(x for x in enc_out if x is not None)
 
 
 class RBLNSiglipVisionModel(RBLNModel):
+    """
+    RBLN optimized SigLIP vision model.
+
+    This class provides hardware-accelerated inference for SigLIP vision models
+    on RBLN devices, supporting image encoding for multimodal vision-language tasks.
+    """
+
     @classmethod
     def wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNSiglipVisionModelConfig) -> torch.nn.Module:
         wrapper_cfg = {
             "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
             "output_hidden_states": rbln_config.output_hidden_states,
+            "output_attentions": rbln_config.output_attentions,
         }
         return _SiglipVisionModel(model, **wrapper_cfg).eval()
 
@@ -81,8 +96,10 @@ class RBLNSiglipVisionModel(RBLNModel):
         if rbln_config.image_size is None:
             raise ValueError("`rbln_image_size` should be specified!")
 
+        if rbln_config.output_attentions is None:
+            rbln_config.output_attentions = getattr(model_config, "output_attentions", False)
         if rbln_config.output_hidden_states is None:
-            rbln_config.output_hidden_states = model_config.output_hidden_states
+            rbln_config.output_hidden_states = getattr(model_config, "output_hidden_states", False)
 
         rbln_compile_config = RBLNCompileConfig(
             input_info=[
@@ -104,43 +121,74 @@ class RBLNSiglipVisionModel(RBLNModel):
 
     def forward(
         self,
-        pixel_values: Optional[torch.FloatTensor] = None,
+        pixel_values: torch.Tensor,
         return_dict: bool = None,
+        output_attentions: bool = None,
+        output_hidden_states: bool = None,
         interpolate_pos_encoding: bool = False,
-        **kwargs,
-    ) -> Union[Tuple, SiglipVisionModelOutput]:
-        if len(kwargs) > 0 and any(kwargs.values()):
-            logger.warning(f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__}.")
+        **kwargs: Dict[str, Any],
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
+        if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
+            logger.warning(
+                f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
+            )
+
+        output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+        )
+
+        if output_attentions != self.rbln_config.output_attentions:
+            raise ValueError(
+                f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions} "
+                f"Please compile again with the correct argument."
+            )
+
+        if output_hidden_states != self.rbln_config.output_hidden_states:
+            raise ValueError(
+                f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                f"Please compile again with the correct argument."
+            )
 
         if interpolate_pos_encoding != self.rbln_config.interpolate_pos_encoding:
             raise ValueError(
-                f"Variable interpolate_pos_encoding {interpolate_pos_encoding} is not equal to rbln_config.interpolate_pos_encoding {self.rbln_config.interpolate_pos_encoding}"
+                f"Variable interpolate_pos_encoding {interpolate_pos_encoding} is not equal to rbln_config.interpolate_pos_encoding {self.rbln_config.interpolate_pos_encoding} "
                 f"Please compile again with the correct argument."
             )
+
         output = super().forward(pixel_values, return_dict=return_dict)
         return output
 
     def _prepare_output(self, output, return_dict):
-        """
-        Prepare model output based on return_dict flag.
-        This method can be overridden by subclasses to provide task-specific output handling.
-        """
+        # Prepare model output based on return_dict flag.
+        # This method can be overridden by subclasses to provide task-specific output handling.
+
         if not return_dict:
             return (output,) if not isinstance(output, (tuple, list)) else output
         else:
-            last_hidden_state = (
-                output[0]
-                if self.rbln_config.interpolate_pos_encoding or self.rbln_config.output_hidden_states
-                else output
-            )
-            pooler_output = output[1] if self.rbln_config.interpolate_pos_encoding else None
+            last_hidden_state = output.pop(0) if isinstance(output, (tuple, list)) else output
+            vision_config = self.config.vision_config if hasattr(self.config, "vision_config") else self.config
+            pooler_output = output.pop(0) if getattr(vision_config, "vision_use_head", True) else None
+
             if self.rbln_config.output_hidden_states:
-                hidden_states = (output[2:] if self.rbln_config.interpolate_pos_encoding else output[1:],)
+                hidden_states = ()
+                num_hidden_layers = vision_config.num_hidden_layers
+                for _ in range(num_hidden_layers + 1):
+                    hidden_states += (output.pop(0),)
             else:
                 hidden_states = None
 
+            if self.rbln_config.output_attentions:
+                attentions = ()
+                num_hidden_layers = vision_config.num_hidden_layers
+                for _ in range(num_hidden_layers):
+                    attentions += (output.pop(0),)
+            else:
+                attentions = None
+
             return BaseModelOutputWithPooling(
                 last_hidden_state=last_hidden_state,
                 pooler_output=pooler_output,
                 hidden_states=hidden_states,
+                attentions=attentions,
            )
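
Illustration only, not part of the diff: the rewritten `_prepare_output` consumes the runtime's flat output tuple in a fixed order. A minimal, self-contained sketch of that ordering, assuming a hypothetical 2-layer vision config with a pooling head and both hidden states and attentions enabled:

```python
# Hypothetical flat output from the compiled runtime (num_hidden_layers = 2):
#   [last_hidden_state, pooler_output, hidden_0..hidden_2, attn_0..attn_1]
flat = ["last", "pool", "h0", "h1", "h2", "a0", "a1"]
outputs = list(flat)

last_hidden_state = outputs.pop(0)
pooler_output = outputs.pop(0)                               # only present when vision_use_head is True
hidden_states = tuple(outputs.pop(0) for _ in range(2 + 1))  # num_hidden_layers + 1 entries
attentions = tuple(outputs.pop(0) for _ in range(2))         # num_hidden_layers entries

assert (last_hidden_state, pooler_output) == ("last", "pool")
assert hidden_states == ("h0", "h1", "h2") and attentions == ("a0", "a1")
```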

optimum/rbln/transformers/models/t5/configuration_t5.py

@@ -17,8 +17,18 @@ from ..seq2seq import RBLNModelForSeq2SeqLMConfig
 
 
 class RBLNT5EncoderModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
-    pass
+    """
+    Configuration class for RBLNT5EncoderModel.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized T5 encoder models for feature extraction tasks.
+    """
 
 
 class RBLNT5ForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
-    pass
+    """
+    Configuration class for RBLNT5ForConditionalGeneration.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized T5 models for conditional text generation tasks.
+    """

optimum/rbln/transformers/models/t5/modeling_t5.py

@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Any, Callable
 
 import torch
 from transformers import AutoModelForTextEncoding, T5EncoderModel, T5ForConditionalGeneration
+from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
 from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
 from ...models.seq2seq import RBLNModelForSeq2SeqLM
 
@@ -41,8 +42,30 @@ class T5EncoderWrapper(torch.nn.Module):
 
 
 class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
+    """
+    The T5 Model transformer with an encoder-only architecture for feature extraction.
+    This model inherits from [`RBLNTransformerEncoderForFeatureExtraction`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model supports various sizes of the T5EncoderModel. For optimal performance, it is highly recommended to adjust the tensor parallelism setting
+        based on the model size. Please refer to the [Optimum RBLN Overview](../../../optimum_rbln.md) for guidance on choosing the appropriate tensor parallelism size for your model.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNT5EncoderModel
+
+        model = RBLNT5EncoderModel.from_pretrained(
+            "sentence-transformers/sentence-t5-xxl",
+            export=True,
+            rbln_tensor_parallel_size=4,
+        )
+
+        model.save_pretrained("compiled-sentence-t5-xxl")
+        ```
+    """
+
     auto_model_class = AutoModelForTextEncoding
-    rbln_model_input_names = ["input_ids", "attention_mask"]
+    output_class = BaseModelOutputWithPastAndCrossAttentions
 
     @classmethod
     def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
@@ -50,18 +73,43 @@ class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
 
     @classmethod
     def update_rbln_config_using_pipe(
-        cls,
-        pipe: "RBLNDiffusionMixin",
-        rbln_config: "RBLNDiffusionMixinConfig",
-        submodule_name: str,
+        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
     ) -> "RBLNDiffusionMixinConfig":
-        submodule_config = getattr(rbln_config, submodule_name)
-        submodule_config.max_seq_len = rbln_config.max_seq_len or 256
-        submodule_config.model_input_names = ["input_ids"]
         return rbln_config
 
+    def forward(self, input_ids=None, attention_mask=None, **kwargs):
+        input_dict = {"input_ids": input_ids.long()}
+        if attention_mask is not None:
+            input_dict["attention_mask"] = attention_mask.long()
+
+        output = super().forward(**input_dict, **kwargs)
+        return output
+
 
 class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    """
+    The T5 Model transformer with a language modeling head for conditional generation.
+    This model inherits from [`RBLNModelForSeq2SeqLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model supports various sizes of the T5ForConditionalGeneration. For optimal performance, it is highly recommended to adjust the tensor parallelism setting
+        based on the model size. Please refer to the [Optimum RBLN Overview](../../../optimum_rbln.md) for guidance on choosing the appropriate tensor parallelism size for your model.
+
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNT5ForConditionalGeneration
+
+        model = RBLNT5ForConditionalGeneration.from_pretrained(
+            "google-t5/t5-11b",
+            export=True,
+            rbln_tensor_parallel_size=4,
+        )
+
+        model.save_pretrained("compiled-sentence-t5-xxl")
+        ```
+    """
+
     support_causal_attn = False
 
     @classmethod
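
Usage sketch, not from the diff: inference with a compiled encoder might look like the snippet below. The checkpoint and directory names are placeholders; only the `forward(input_ids, attention_mask)` signature and the `.long()` casts come from the hunk above.

```python
from transformers import AutoTokenizer
from optimum.rbln import RBLNT5EncoderModel

# Placeholder paths; loading a previously compiled directory (export=False is assumed
# to be the default when pointing at compiled artifacts).
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/sentence-t5-xxl")
model = RBLNT5EncoderModel.from_pretrained("compiled-sentence-t5-xxl")

# Inputs may need to be padded to the compiled sequence length.
inputs = tokenizer("time flies like an arrow", return_tensors="pt")
outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
print(outputs.last_hidden_state.shape)
```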

optimum/rbln/transformers/models/t5/t5_architecture.py

@@ -136,10 +136,14 @@ class T5Decoder(Seq2SeqDecoder):
 
 
 class T5Block(Seq2SeqDecoderLayer):
+    def __init__(self, decoder_layer, self_attn):
+        super().__init__(decoder_layer, self_attn, cross_attn=None)
+        self.__post_init__()
+
     def __post_init__(self):
         self.self_attn_layer_norm = self._original_mod.layer[0].layer_norm
         self.encoder_attn_layer_norm = self._original_mod.layer[1].layer_norm
-        self.encoder_attn = T5CrossAttention(self._original_mod.layer[1].EncDecAttention)
+        self.cross_attn = T5CrossAttention(self._original_mod.layer[1].EncDecAttention)
         self.ff_layer = self._original_mod.layer[2]
 
     def pre_self_attn_layer_norm(self, hidden_states):

optimum/rbln/transformers/models/time_series_transformer/__init__.py

@@ -23,4 +23,4 @@
 
 from ....ops import paged_add_softmax_attn_decode, rbln_cache_update
 from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
-from .modeling_time_series_transformers import RBLNTimeSeriesTransformerForPrediction
+from .modeling_time_series_transformer import RBLNTimeSeriesTransformerForPrediction
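
Migration sketch, not part of the diff: deep imports that spelled out the old plural module name must switch to the renamed module. Both paths below are assembled from this hunk and the renames in the file list; the top-level `optimum.rbln` export is unaffected as far as this diff shows.

```python
# 0.8.0.post2 (old module name, removed by this release):
# from optimum.rbln.transformers.models.time_series_transformers.modeling_time_series_transformers import (
#     RBLNTimeSeriesTransformerForPrediction,
# )

# 0.8.1 (renamed to the singular "time_series_transformer"):
from optimum.rbln.transformers.models.time_series_transformer.modeling_time_series_transformer import (
    RBLNTimeSeriesTransformerForPrediction,
)
```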

optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py

@@ -1,16 +1,23 @@
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from ....configuration_utils import RBLNModelConfig
 
 
 class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNTimeSeriesTransformerForPrediction.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized Time Series Transformer models for time series forecasting tasks.
+    """
+
     def __init__(
         self,
         batch_size: Optional[int] = None,
         enc_max_seq_len: Optional[int] = None,
         dec_max_seq_len: Optional[int] = None,
         num_parallel_samples: Optional[int] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
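
A minimal configuration sketch using only the parameter names declared in this hunk; the values and the top-level import path are placeholders/assumptions.

```python
from optimum.rbln import RBLNTimeSeriesTransformerForPredictionConfig  # import path assumed

# Placeholder values; only the parameter names come from the __init__ signature above.
config = RBLNTimeSeriesTransformerForPredictionConfig(
    batch_size=1,
    enc_max_seq_len=64,
    dec_max_seq_len=32,
    num_parallel_samples=100,
)
```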

optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py

@@ -120,6 +120,17 @@ class RBLNSeq2SeqTSDecoderOutput(ModelOutput):
 
 
 class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
+    """
+    The Time Series Transformer Model with a distribution head on top for time-series forecasting. e.g., for datasets like M4, NN5, or other time series forecasting benchmarks.
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformer-based `TimeSeriesTransformerForPrediction` models on RBLN devices.
+    It implements the methods to convert a pre-trained transformers `TimeSeriesTransformerForPrediction` model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN Compiler.
+    """
+
     auto_model_class = None
     main_input_name = "inputs_embeds"
 
@@ -144,11 +155,6 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
         )
 
     def __getattr__(self, __name: str) -> Any:
-        """This is the key method to implement RBLN-TimeSeriesTransformersForPrediction.
-        Returns:
-            Any: TimeSeriesTransformersForPrediction's corresponding method
-        """
-
         def redirect(func):
             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
 
@@ -188,15 +194,19 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
             if "key_value_states" in name:
                 context.mark_static_address(tensor)
 
-        compiled_decoder = super().compile(
+        compiled_decoder = cls.compile(
            wrapped_model.decoder,
            dec_compile_config,
+            create_runtimes=rbln_config.create_runtimes,
+            device=rbln_config.device,
            example_inputs=dec_example_inputs,
            compile_context=context,
        )
-        compiled_encoder = super().compile(
+        compiled_encoder = cls.compile(
            wrapped_model.encoder,
            enc_compile_config,
+            create_runtimes=rbln_config.create_runtimes,
+            device=rbln_config.device,
            example_inputs=enc_example_inputs,
            compile_context=context,
        )
@@ -211,10 +221,9 @@ class RBLNTimeSeriesTransformerForPrediction(RBLNModel):
         subfolder: str,
         rbln_config: RBLNTimeSeriesTransformerForPredictionConfig,
     ):
-        """
-        If you are unavoidably running on a CPU rather than an RBLN device,
-        store the torch tensor, weight, etc. in this function.
-        """
+        # If you are unavoidably running on a CPU rather than an RBLN device,
+        # store the torch tensor, weight, etc. in this function.
+
         save_dict = {}
         save_dict["embedder"] = model.model.embedder.state_dict()
         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

optimum/rbln/transformers/models/vit/__init__.py

@@ -0,0 +1,19 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .configuration_vit import RBLNViTForImageClassificationConfig
+from .modeling_vit import RBLNViTForImageClassification
+
+
+__all__ = ["RBLNViTForImageClassificationConfig", "RBLNViTForImageClassification"]

optimum/rbln/transformers/models/vit/configuration_vit.py

@@ -0,0 +1,24 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNModelForImageClassificationConfig
+
+
+class RBLNViTForImageClassificationConfig(RBLNModelForImageClassificationConfig):
+    """
+    Configuration class for RBLNViTForImageClassification.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized Vision Transformer (ViT) models for image classification tasks.
+    """

optimum/rbln/transformers/models/vit/modeling_vit.py

@@ -0,0 +1,25 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...modeling_generic import RBLNModelForImageClassification
+
+
+class RBLNViTForImageClassification(RBLNModelForImageClassification):
+    """
+    RBLN optimized Vision Transformer (ViT) model for image classification tasks.
+
+    This class provides hardware-accelerated inference for Vision Transformer models
+    on RBLN devices, supporting image classification with transformer-based architectures
+    that process images as sequences of patches.
+    """

optimum/rbln/transformers/models/wav2vec2/__init__.py

@@ -12,5 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .configuration_wav2vec import RBLNWav2Vec2ForCTCConfig
+from .configuration_wav2vec2 import RBLNWav2Vec2ForCTCConfig
 from .modeling_wav2vec2 import RBLNWav2Vec2ForCTC

optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py

@@ -0,0 +1,26 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNModelForMaskedLMConfig
+
+
+class RBLNWav2Vec2ForCTCConfig(RBLNModelForMaskedLMConfig):
+    """
+    Configuration class for RBLNWav2Vec2ForCTC.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized Wav2Vec2 models for Connectionist Temporal Classification (CTC) tasks.
+    """
+
+    rbln_model_input_names = ["input_values"]
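
The new config pins `rbln_model_input_names = ["input_values"]`, so the compiled CTC graph is traced against audio `input_values` rather than token ids. A hedged compile sketch (checkpoint name is a placeholder; the export flow is assumed from the package's other examples):

```python
from optimum.rbln import RBLNWav2Vec2ForCTC  # top-level export assumed

# Placeholder checkpoint; the compiled graph takes "input_values" per the config above.
model = RBLNWav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",
    export=True,
)
model.save_pretrained("compiled-wav2vec2-ctc")
```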

optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py

@@ -17,7 +17,7 @@ import torch
 from transformers import AutoModelForMaskedLM, Wav2Vec2ForCTC
 
 from ...modeling_generic import RBLNModelForMaskedLM
-from .configuration_wav2vec import RBLNWav2Vec2ForCTCConfig
+from .configuration_wav2vec2 import RBLNWav2Vec2ForCTCConfig
 
 
 class _Wav2Vec2(torch.nn.Module):

optimum/rbln/transformers/models/whisper/configuration_whisper.py

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Any, Dict
+
 import rebel
 
 from ....configuration_utils import RBLNModelConfig
@@ -22,6 +24,13 @@ logger = get_logger()
 
 
 class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNWhisperForConditionalGeneration.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized Whisper models for speech recognition and transcription tasks.
+    """
+
     def __init__(
         self,
         batch_size: int = None,
@@ -29,7 +38,7 @@ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
         use_attention_mask: bool = None,
         enc_max_seq_len: int = None,
         dec_max_seq_len: int = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
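
A configuration sketch for the Whisper config touched above, using only parameter names visible in the hunk; the values and the top-level import path are placeholders/assumptions.

```python
from optimum.rbln import RBLNWhisperForConditionalGenerationConfig  # import path assumed

# Placeholder values; parameter names are taken from the __init__ signature above.
config = RBLNWhisperForConditionalGenerationConfig(
    batch_size=1,
    enc_max_seq_len=3000,
    dec_max_seq_len=448,
)
```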