optimum-rbln 0.8.0.post2__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162)
  1. optimum/rbln/__init__.py +24 -0
  2. optimum/rbln/__version__.py +2 -2
  3. optimum/rbln/configuration_utils.py +45 -33
  4. optimum/rbln/diffusers/__init__.py +21 -1
  5. optimum/rbln/diffusers/configurations/__init__.py +4 -0
  6. optimum/rbln/diffusers/configurations/models/__init__.py +2 -0
  7. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +9 -2
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +4 -2
  10. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +9 -2
  11. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +70 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +4 -2
  13. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +9 -2
  14. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +9 -2
  15. optimum/rbln/diffusers/configurations/pipelines/__init__.py +1 -0
  16. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +29 -9
  17. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +114 -0
  18. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +28 -12
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +18 -6
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +13 -6
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +12 -6
  22. optimum/rbln/diffusers/modeling_diffusers.py +72 -65
  23. optimum/rbln/diffusers/models/__init__.py +4 -0
  24. optimum/rbln/diffusers/models/autoencoders/__init__.py +1 -0
  25. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +17 -1
  26. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +219 -0
  27. optimum/rbln/diffusers/models/autoencoders/vae.py +45 -8
  28. optimum/rbln/diffusers/models/autoencoders/vq_model.py +17 -1
  29. optimum/rbln/diffusers/models/controlnet.py +14 -8
  30. optimum/rbln/diffusers/models/transformers/__init__.py +1 -0
  31. optimum/rbln/diffusers/models/transformers/prior_transformer.py +10 -0
  32. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +321 -0
  33. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +2 -0
  34. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +11 -1
  35. optimum/rbln/diffusers/pipelines/__init__.py +10 -0
  36. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +1 -4
  37. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -0
  38. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -0
  39. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +7 -0
  40. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +7 -0
  41. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  42. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +102 -0
  43. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +455 -0
  44. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +98 -0
  45. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +98 -0
  46. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +7 -0
  47. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +48 -27
  48. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +7 -0
  49. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +7 -0
  50. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -0
  51. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +7 -0
  52. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +7 -0
  53. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +7 -0
  54. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -0
  55. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -0
  56. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -0
  57. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +7 -0
  58. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -0
  59. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +7 -0
  60. optimum/rbln/modeling.py +71 -37
  61. optimum/rbln/modeling_base.py +63 -109
  62. optimum/rbln/transformers/__init__.py +41 -47
  63. optimum/rbln/transformers/configuration_generic.py +16 -13
  64. optimum/rbln/transformers/modeling_generic.py +21 -22
  65. optimum/rbln/transformers/modeling_rope_utils.py +5 -2
  66. optimum/rbln/transformers/models/__init__.py +54 -4
  67. optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py} +2 -4
  68. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +21 -0
  69. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +28 -0
  70. optimum/rbln/transformers/models/auto/auto_factory.py +35 -12
  71. optimum/rbln/transformers/models/bart/bart_architecture.py +14 -1
  72. optimum/rbln/transformers/models/bart/configuration_bart.py +12 -2
  73. optimum/rbln/transformers/models/bart/modeling_bart.py +16 -7
  74. optimum/rbln/transformers/models/bert/configuration_bert.py +18 -3
  75. optimum/rbln/transformers/models/bert/modeling_bert.py +24 -0
  76. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +15 -3
  77. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +50 -4
  78. optimum/rbln/transformers/models/clip/configuration_clip.py +15 -5
  79. optimum/rbln/transformers/models/clip/modeling_clip.py +38 -13
  80. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  81. optimum/rbln/transformers/models/colpali/colpali_architecture.py +221 -0
  82. optimum/rbln/transformers/models/colpali/configuration_colpali.py +68 -0
  83. optimum/rbln/transformers/models/colpali/modeling_colpali.py +383 -0
  84. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +111 -14
  85. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +102 -35
  86. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +253 -195
  87. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  88. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  89. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +27 -0
  90. optimum/rbln/transformers/models/dpt/configuration_dpt.py +6 -1
  91. optimum/rbln/transformers/models/dpt/modeling_dpt.py +6 -1
  92. optimum/rbln/transformers/models/exaone/configuration_exaone.py +24 -1
  93. optimum/rbln/transformers/models/exaone/exaone_architecture.py +5 -1
  94. optimum/rbln/transformers/models/exaone/modeling_exaone.py +66 -5
  95. optimum/rbln/transformers/models/gemma/configuration_gemma.py +24 -1
  96. optimum/rbln/transformers/models/gemma/gemma_architecture.py +5 -1
  97. optimum/rbln/transformers/models/gemma/modeling_gemma.py +49 -0
  98. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +3 -3
  99. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +18 -250
  100. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +89 -244
  101. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +4 -1
  102. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +6 -1
  103. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +12 -2
  104. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +41 -4
  105. optimum/rbln/transformers/models/llama/configuration_llama.py +24 -1
  106. optimum/rbln/transformers/models/llama/modeling_llama.py +49 -0
  107. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +10 -2
  108. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +32 -4
  109. optimum/rbln/transformers/models/midm/configuration_midm.py +24 -1
  110. optimum/rbln/transformers/models/midm/midm_architecture.py +6 -1
  111. optimum/rbln/transformers/models/midm/modeling_midm.py +66 -5
  112. optimum/rbln/transformers/models/mistral/configuration_mistral.py +24 -1
  113. optimum/rbln/transformers/models/mistral/modeling_mistral.py +62 -4
  114. optimum/rbln/transformers/models/opt/configuration_opt.py +4 -1
  115. optimum/rbln/transformers/models/opt/modeling_opt.py +10 -0
  116. optimum/rbln/transformers/models/opt/opt_architecture.py +7 -1
  117. optimum/rbln/transformers/models/phi/configuration_phi.py +24 -1
  118. optimum/rbln/transformers/models/phi/modeling_phi.py +49 -0
  119. optimum/rbln/transformers/models/phi/phi_architecture.py +1 -1
  120. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +24 -1
  121. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +67 -4
  122. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +31 -3
  123. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +54 -25
  124. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +6 -4
  125. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  126. optimum/rbln/transformers/models/resnet/configuration_resnet.py +25 -0
  127. optimum/rbln/transformers/models/resnet/modeling_resnet.py +26 -0
  128. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  129. optimum/rbln/transformers/{configuration_alias.py → models/roberta/configuration_roberta.py} +12 -28
  130. optimum/rbln/transformers/{modeling_alias.py → models/roberta/modeling_roberta.py} +14 -28
  131. optimum/rbln/transformers/models/seq2seq/__init__.py +1 -1
  132. optimum/rbln/transformers/models/seq2seq/{configuration_seq2seq2.py → configuration_seq2seq.py} +2 -2
  133. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +7 -3
  134. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +41 -3
  135. optimum/rbln/transformers/models/siglip/configuration_siglip.py +10 -0
  136. optimum/rbln/transformers/models/siglip/modeling_siglip.py +69 -21
  137. optimum/rbln/transformers/models/t5/configuration_t5.py +12 -2
  138. optimum/rbln/transformers/models/t5/modeling_t5.py +56 -8
  139. optimum/rbln/transformers/models/t5/t5_architecture.py +5 -1
  140. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/__init__.py +1 -1
  141. optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/configuration_time_series_transformer.py +9 -2
  142. optimum/rbln/transformers/models/{time_series_transformers/modeling_time_series_transformers.py → time_series_transformer/modeling_time_series_transformer.py} +20 -11
  143. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  144. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  145. optimum/rbln/transformers/models/vit/modeling_vit.py +25 -0
  146. optimum/rbln/transformers/models/wav2vec2/__init__.py +1 -1
  147. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +26 -0
  148. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +1 -1
  149. optimum/rbln/transformers/models/whisper/configuration_whisper.py +10 -1
  150. optimum/rbln/transformers/models/whisper/modeling_whisper.py +41 -17
  151. optimum/rbln/transformers/models/xlm_roberta/__init__.py +16 -2
  152. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +15 -2
  153. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +12 -3
  154. optimum/rbln/utils/model_utils.py +20 -0
  155. optimum/rbln/utils/runtime_utils.py +49 -1
  156. optimum/rbln/utils/submodule.py +6 -8
  157. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/METADATA +6 -6
  158. optimum_rbln-0.8.1.dist-info/RECORD +211 -0
  159. optimum_rbln-0.8.0.post2.dist-info/RECORD +0 -184
  160. /optimum/rbln/transformers/models/{time_series_transformers → time_series_transformer}/time_series_transformers_architecture.py +0 -0
  161. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/WHEEL +0 -0
  162. {optimum_rbln-0.8.0.post2.dist-info → optimum_rbln-0.8.1.dist-info}/licenses/LICENSE +0 -0
optimum/rbln/transformers/__init__.py
@@ -18,6 +18,10 @@ from transformers.utils import _LazyModule
 
 
 _import_structure = {
+    "audio_spectrogram_transformer": [
+        "RBLNASTForAudioClassification",
+        "RBLNASTForAudioClassificationConfig",
+    ],
     "auto": [
         "RBLNAutoModel",
         "RBLNAutoModelForAudioClassification",
@@ -65,6 +69,14 @@ _import_structure = {
         "RBLNCLIPVisionModelWithProjection",
         "RBLNCLIPVisionModelWithProjectionConfig",
     ],
+    "colpali": [
+        "RBLNColPaliForRetrieval",
+        "RBLNColPaliForRetrievalConfig",
+    ],
+    "distilbert": [
+        "RBLNDistilBertForQuestionAnswering",
+        "RBLNDistilBertForQuestionAnsweringConfig",
+    ],
     "qwen2_5_vl": [
         "RBLNQwen2_5_VisionTransformerPretrainedModel",
         "RBLNQwen2_5_VisionTransformerPretrainedModelConfig",
@@ -101,11 +113,18 @@ _import_structure = {
     "mistral": ["RBLNMistralForCausalLM", "RBLNMistralForCausalLMConfig"],
     "phi": ["RBLNPhiForCausalLM", "RBLNPhiForCausalLMConfig"],
     "qwen2": ["RBLNQwen2ForCausalLM", "RBLNQwen2ForCausalLMConfig"],
+    "resnet": ["RBLNResNetForImageClassification", "RBLNResNetForImageClassificationConfig"],
+    "roberta": [
+        "RBLNRobertaForMaskedLM",
+        "RBLNRobertaForMaskedLMConfig",
+        "RBLNRobertaForSequenceClassification",
+        "RBLNRobertaForSequenceClassificationConfig",
+    ],
     "siglip": [
         "RBLNSiglipVisionModel",
         "RBLNSiglipVisionModelConfig",
     ],
-    "time_series_transformers": [
+    "time_series_transformer": [
         "RBLNTimeSeriesTransformerForPrediction",
         "RBLNTimeSeriesTransformerForPredictionConfig",
     ],
@@ -115,12 +134,22 @@ _import_structure = {
         "RBLNT5EncoderModelConfig",
         "RBLNT5ForConditionalGenerationConfig",
     ],
+    "vit": ["RBLNViTForImageClassification", "RBLNViTForImageClassificationConfig"],
     "wav2vec2": ["RBLNWav2Vec2ForCTC", "RBLNWav2Vec2ForCTCConfig"],
     "whisper": ["RBLNWhisperForConditionalGeneration", "RBLNWhisperForConditionalGenerationConfig"],
-    "xlm_roberta": ["RBLNXLMRobertaModel", "RBLNXLMRobertaModelConfig"],
+    "xlm_roberta": [
+        "RBLNXLMRobertaModel",
+        "RBLNXLMRobertaModelConfig",
+        "RBLNXLMRobertaForSequenceClassification",
+        "RBLNXLMRobertaForSequenceClassificationConfig",
+    ],
 }
 
 if TYPE_CHECKING:
+    from .audio_spectrogram_transformer import (
+        RBLNASTForAudioClassification,
+        RBLNASTForAudioClassificationConfig,
+    )
     from .auto import (
         RBLNAutoModel,
         RBLNAutoModelForAudioClassification,
@@ -168,10 +197,18 @@ if TYPE_CHECKING:
         RBLNCLIPVisionModelWithProjection,
         RBLNCLIPVisionModelWithProjectionConfig,
     )
+    from .colpali import (
+        RBLNColPaliForRetrieval,
+        RBLNColPaliForRetrievalConfig,
+    )
     from .decoderonly import (
         RBLNDecoderOnlyModelForCausalLM,
         RBLNDecoderOnlyModelForCausalLMConfig,
     )
+    from .distilbert import (
+        RBLNDistilBertForQuestionAnswering,
+        RBLNDistilBertForQuestionAnsweringConfig,
+    )
     from .dpt import (
         RBLNDPTForDepthEstimation,
         RBLNDPTForDepthEstimationConfig,
@@ -204,6 +241,13 @@ if TYPE_CHECKING:
         RBLNQwen2_5_VLForConditionalGeneration,
         RBLNQwen2_5_VLForConditionalGenerationConfig,
     )
+    from .resnet import RBLNResNetForImageClassification, RBLNResNetForImageClassificationConfig
+    from .roberta import (
+        RBLNRobertaForMaskedLM,
+        RBLNRobertaForMaskedLMConfig,
+        RBLNRobertaForSequenceClassification,
+        RBLNRobertaForSequenceClassificationConfig,
+    )
     from .siglip import RBLNSiglipVisionModel, RBLNSiglipVisionModelConfig
     from .t5 import (
         RBLNT5EncoderModel,
@@ -211,13 +255,19 @@ if TYPE_CHECKING:
         RBLNT5ForConditionalGeneration,
         RBLNT5ForConditionalGenerationConfig,
     )
-    from .time_series_transformers import (
+    from .time_series_transformer import (
         RBLNTimeSeriesTransformerForPrediction,
         RBLNTimeSeriesTransformerForPredictionConfig,
     )
+    from .vit import RBLNViTForImageClassification, RBLNViTForImageClassificationConfig
     from .wav2vec2 import RBLNWav2Vec2ForCTC, RBLNWav2Vec2ForCTCConfig
     from .whisper import RBLNWhisperForConditionalGeneration, RBLNWhisperForConditionalGenerationConfig
-    from .xlm_roberta import RBLNXLMRobertaModel, RBLNXLMRobertaModelConfig
+    from .xlm_roberta import (
+        RBLNXLMRobertaForSequenceClassification,
+        RBLNXLMRobertaForSequenceClassificationConfig,
+        RBLNXLMRobertaModel,
+        RBLNXLMRobertaModelConfig,
+    )
 
 else:
     import sys
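
Taken together, this hunk widens the public surface of optimum.rbln.transformers considerably. A minimal sketch of the imports it enables, assuming the top-level optimum/rbln/__init__.py (also +24 lines in this release) re-exports these names, as the auto factory's getattr lookup later in this diff implies:

```python
# Illustrative only: class names come from the _import_structure hunk above;
# the top-level re-export is an assumption based on optimum/rbln/__init__.py (+24 -0).
from optimum.rbln import (
    RBLNASTForAudioClassification,
    RBLNColPaliForRetrieval,
    RBLNDistilBertForQuestionAnswering,
    RBLNResNetForImageClassification,
    RBLNViTForImageClassification,
    RBLNXLMRobertaForSequenceClassification,
)
```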
optimum/rbln/transformers/models/{wav2vec2/configuration_wav2vec.py → audio_spectrogram_transformer/__init__.py}
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ...configuration_generic import RBLNModelForMaskedLMConfig
 
-
-class RBLNWav2Vec2ForCTCConfig(RBLNModelForMaskedLMConfig):
-    rbln_model_input_names = ["input_values"]
+from .configuration_audio_spectrogram_transformer import RBLNASTForAudioClassificationConfig
+from .modeling_audio_spectrogram_transformer import RBLNASTForAudioClassification
optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py (new file)
@@ -0,0 +1,21 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...configuration_generic import RBLNModelForAudioClassificationConfig
+
+
+class RBLNASTForAudioClassificationConfig(RBLNModelForAudioClassificationConfig):
+    """
+    Configuration class for RBLNASTForAudioClassification.
+    """
optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py (new file)
@@ -0,0 +1,28 @@
+# Copyright 2025 Rebellions Inc. All rights reserved.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at:
+
+#     http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...modeling_generic import RBLNModelForAudioClassification
+
+
+class RBLNASTForAudioClassification(RBLNModelForAudioClassification):
+    """
+    Audio Spectrogram Transformer model with an audio classification head on top (a linear layer on top of the pooled output) e.g. for datasets like AudioSet, Speech Commands v2.
+    This model inherits from [`RBLNModelForAudioClassification`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    A class to convert and run pre-trained transformer-based `ASTForAudioClassification` models on RBLN devices.
+    It implements the methods to convert a pre-trained transformers `ASTForAudioClassification` model into a RBLN transformer model by:
+
+    - transferring the checkpoint weights of the original into an optimized RBLN graph,
+    - compiling the resulting graph using the RBLN Compiler.
+    """
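
With the configuration and modeling files in place, compiling an AST checkpoint should follow the library's usual export flow, the same from_pretrained(..., export=True) / save_pretrained pattern shown in the BLIP-2 docstring later in this diff. A hedged sketch; the checkpoint ID is an assumption, not taken from this release:

```python
from optimum.rbln import RBLNASTForAudioClassification

# Convert and compile a Hugging Face ASTForAudioClassification checkpoint
# (checkpoint ID is illustrative) into an RBLN graph.
model = RBLNASTForAudioClassification.from_pretrained(
    "MIT/ast-finetuned-audioset-10-10-0.4593",
    export=True,
)

# Persist the compiled artifacts so later loads skip recompilation.
model.save_pretrained("compiled-ast-audio-classification")
```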
optimum/rbln/transformers/models/auto/auto_factory.py
@@ -11,10 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import importlib
 import inspect
 import warnings
+from typing import Type
 
 from transformers import AutoConfig, PretrainedConfig
 from transformers.dynamic_module_utils import get_class_from_dynamic_module
@@ -22,7 +22,12 @@ from transformers.models.auto.auto_factory import _get_model_class
 
 from optimum.rbln.configuration_utils import RBLNAutoConfig
 from optimum.rbln.modeling_base import RBLNBaseModel
-from optimum.rbln.utils.model_utils import convert_hf_to_rbln_model_name, convert_rbln_to_hf_model_name
+from optimum.rbln.utils.model_utils import (
+    MODEL_MAPPING,
+    convert_hf_to_rbln_model_name,
+    convert_rbln_to_hf_model_name,
+    get_rbln_model_cls,
+)
 
 
 class _BaseAutoModelClass:
@@ -58,7 +63,7 @@ class _BaseAutoModelClass:
             hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
             rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
         else:
-            rbln_class_name = cls.get_rbln_model_class_name(pretrained_model_name_or_path, **kwargs)
+            rbln_class_name = cls.get_rbln_model_cls_name(pretrained_model_name_or_path, **kwargs)
 
         if convert_rbln_to_hf_model_name(rbln_class_name) not in cls._model_mapping_names.values():
             raise ValueError(
@@ -68,8 +73,7 @@ class _BaseAutoModelClass:
             )
 
         try:
-            module = importlib.import_module("optimum.rbln")
-            rbln_cls = getattr(module, rbln_class_name)
+            rbln_cls = get_rbln_model_cls(rbln_class_name)
         except AttributeError as e:
             raise AttributeError(
                 f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{pretrained_model_name_or_path}'. "
@@ -136,7 +140,7 @@ class _BaseAutoModelClass:
         return model_class
 
     @classmethod
-    def get_rbln_model_class_name(cls, pretrained_model_name_or_path, **kwargs):
+    def get_rbln_model_cls_name(cls, pretrained_model_name_or_path, **kwargs):
         """
         Retrieve the path to the compiled model directory for a given RBLN model.
 
@@ -159,11 +163,30 @@ class _BaseAutoModelClass:
         return rbln_config.rbln_model_cls_name
 
     @classmethod
-    def from_pretrained(
-        cls,
-        model_id,
-        *args,
-        **kwargs,
-    ):
+    def from_pretrained(cls, model_id, *args, **kwargs):
         rbln_cls = cls.get_rbln_cls(model_id, *args, **kwargs)
         return rbln_cls.from_pretrained(model_id, *args, **kwargs)
+
+    @classmethod
+    def from_model(cls, model, *args, **kwargs):
+        rbln_cls = get_rbln_model_cls(f"RBLN{model.__class__.__name__}")
+        return rbln_cls.from_model(model, *args, **kwargs)
+
+    @staticmethod
+    def register(rbln_cls: Type[RBLNBaseModel], exist_ok=False):
+        """
+        Register a new RBLN model class.
+
+        Args:
+            rbln_cls (Type[RBLNBaseModel]): The RBLN model class to register.
+            exist_ok (bool): Whether to allow registering an already registered model.
+        """
+        if not issubclass(rbln_cls, RBLNBaseModel):
+            raise ValueError("`rbln_cls` must be a subclass of RBLNBaseModel.")
+
+        native_cls = getattr(importlib.import_module("optimum.rbln"), rbln_cls.__name__, None)
+        if rbln_cls.__name__ in MODEL_MAPPING or native_cls is not None:
+            if not exist_ok:
+                raise ValueError(f"Model for {rbln_cls.__name__} already registered.")
+
+        MODEL_MAPPING[rbln_cls.__name__] = rbln_cls
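
The new from_model() and register() hooks turn the auto factory into an extension point: third-party RBLN model classes can now be added to MODEL_MAPPING at runtime instead of having to live inside optimum.rbln. A sketch under the assumption that RBLNMyModel is a user-defined subclass (hypothetical, for illustration only):

```python
from optimum.rbln import RBLNAutoModel
from optimum.rbln.modeling_base import RBLNBaseModel


class RBLNMyModel(RBLNBaseModel):  # hypothetical custom model class
    ...


# register() validates the subclass and stores it in MODEL_MAPPING;
# exist_ok=True would permit overwriting an already-registered name.
RBLNAutoModel.register(RBLNMyModel)

# from_model() resolves "RBLN" + the wrapped model's class name through
# get_rbln_model_cls(), then delegates to that class's own from_model().
# rbln_model = RBLNAutoModel.from_model(my_torch_model)
```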
optimum/rbln/transformers/models/bart/bart_architecture.py
@@ -22,6 +22,7 @@ from transformers.modeling_attn_mask_utils import (
 from transformers.utils import logging
 
 from ..seq2seq.seq2seq_architecture import (
+    Seq2SeqCrossAttention,
     Seq2SeqDecoder,
     Seq2SeqDecoderLayer,
     Seq2SeqDecoderWrapper,
@@ -45,7 +46,8 @@ class BartDecoderWrapper(Seq2SeqDecoderWrapper):
         new_layers = []
         for layer in model.get_decoder().layers:
             self_attn = BartSelfAttention(layer.self_attn, use_attention_mask=self.use_attention_mask)
-            new_layers.append(BartDecoderLayer(layer, self_attn))
+            cross_attn = BartCrossAttention(layer.encoder_attn)
+            new_layers.append(BartDecoderLayer(layer, self_attn, cross_attn))
 
         decoder_model = BartDecoder(model.get_decoder(), new_layers)
         new_model = BartForConditionalGeneration(model, decoder_model)
@@ -153,3 +155,14 @@ class BartSelfAttention(Seq2SeqSelfAttention):
         key_states = self.k_proj(hidden_states)
         value_states = self.v_proj(hidden_states)
         return query_states, key_states, value_states
+
+
+class BartCrossAttention(Seq2SeqCrossAttention):
+    def __post_init__(self):
+        self.q_proj = self._original_mod.q_proj
+        self.k_proj = self._original_mod.k_proj
+        self.v_proj = self._original_mod.v_proj
+        self.out_proj = self._original_mod.out_proj
+        self.num_heads = self._original_mod.num_heads
+        self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+        self.embed_dim = self._original_mod.embed_dim
optimum/rbln/transformers/models/bart/configuration_bart.py
@@ -17,8 +17,18 @@ from ..seq2seq import RBLNModelForSeq2SeqLMConfig
 
 
 class RBLNBartModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
-    pass
+    """
+    Configuration class for RBLNBartModel.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BART models for feature extraction tasks.
+    """
 
 
 class RBLNBartForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
-    pass
+    """
+    Configuration class for RBLNBartForConditionalGeneration.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BART models for conditional text generation tasks.
+    """
optimum/rbln/transformers/models/bart/modeling_bart.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import inspect
-from typing import TYPE_CHECKING, Any, Callable
+from typing import Any, Callable
 
 from transformers import BartForConditionalGeneration, PreTrainedModel
 
@@ -27,19 +27,28 @@ from .configuration_bart import RBLNBartForConditionalGenerationConfig
 logger = get_logger()
 
 
-if TYPE_CHECKING:
-    from transformers import PreTrainedModel
-
-
 class RBLNBartModel(RBLNTransformerEncoderForFeatureExtraction):
-    pass
+    """
+    RBLN optimized BART model for feature extraction tasks.
+
+    This class provides hardware-accelerated inference for BART encoder models
+    on RBLN devices, optimized for feature extraction use cases.
+    """
 
 
 class RBLNBartForConditionalGeneration(RBLNModelForSeq2SeqLM):
+    """
+    RBLN optimized BART model for conditional text generation tasks.
+
+    This class provides hardware-accelerated inference for BART models
+    on RBLN devices, supporting sequence-to-sequence generation tasks
+    such as summarization, translation, and text generation.
+    """
+
     support_causal_attn = True
 
     @classmethod
-    def wrap_model_if_needed(self, model: "PreTrainedModel", rbln_config: RBLNBartForConditionalGenerationConfig):
+    def wrap_model_if_needed(self, model: PreTrainedModel, rbln_config: RBLNBartForConditionalGenerationConfig):
         return BartWrapper(
             model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
         )
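
For context, the wrap_model_if_needed() hook above is what routes a checkpoint through BartWrapper during export; from the user's side the flow stays the standard one. A minimal sketch, with an illustrative checkpoint ID that is not taken from this diff:

```python
from optimum.rbln import RBLNBartForConditionalGeneration

# Compile BART for sequence-to-sequence generation on an RBLN device;
# "facebook/bart-base" is an assumed checkpoint, chosen only for illustration.
model = RBLNBartForConditionalGeneration.from_pretrained(
    "facebook/bart-base",
    export=True,
)
```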
optimum/rbln/transformers/models/bert/configuration_bert.py
@@ -20,12 +20,27 @@ from ...configuration_generic import (
 
 
 class RBLNBertModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
-    pass
+    """
+    Configuration class for RBLNBertModel.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BERT models for feature extraction tasks.
+    """
 
 
 class RBLNBertForMaskedLMConfig(RBLNModelForMaskedLMConfig):
-    pass
+    """
+    Configuration class for RBLNBertForMaskedLM.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BERT models for masked language modeling tasks.
+    """
 
 
 class RBLNBertForQuestionAnsweringConfig(RBLNModelForQuestionAnsweringConfig):
-    pass
+    """
+    Configuration class for RBLNBertForQuestionAnswering.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BERT models for question answering tasks.
+    """
optimum/rbln/transformers/models/bert/modeling_bert.py
@@ -24,12 +24,36 @@ logger = get_logger(__name__)
 
 
 class RBLNBertModel(RBLNTransformerEncoderForFeatureExtraction):
+    """
+    RBLN optimized BERT model for feature extraction tasks.
+
+    This class provides hardware-accelerated inference for BERT models
+    on RBLN devices, optimized for extracting contextualized embeddings
+    and features from text sequences.
+    """
+
     rbln_model_input_names = ["input_ids", "attention_mask"]
 
 
 class RBLNBertForMaskedLM(RBLNModelForMaskedLM):
+    """
+    RBLN optimized BERT model for masked language modeling tasks.
+
+    This class provides hardware-accelerated inference for BERT models
+    on RBLN devices, supporting masked language modeling tasks such as
+    token prediction and text completion.
+    """
+
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
 
 
 class RBLNBertForQuestionAnswering(RBLNModelForQuestionAnswering):
+    """
+    RBLN optimized BERT model for question answering tasks.
+
+    This class provides hardware-accelerated inference for BERT models
+    on RBLN devices, supporting extractive question answering tasks where
+    the model predicts start and end positions of answers in text.
+    """
+
     rbln_model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
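
The rbln_model_input_names lists above pin down exactly which tokenizer outputs each compiled graph consumes. A hedged sketch of what that implies for callers (the checkpoint ID is illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
inputs = tokenizer(
    "Where does the answer start?",
    "The answer starts at the span the model predicts.",
    return_tensors="pt",
)
# For RBLNBertForQuestionAnswering, `inputs` must supply input_ids,
# attention_mask, and token_type_ids, the exact names listed above.
```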
optimum/rbln/transformers/models/blip_2/configuration_blip_2.py
@@ -12,16 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from ....configuration_utils import RBLNModelConfig
 
 
 class RBLNBlip2VisionModelConfig(RBLNModelConfig):
-    pass
+    """
+    Configuration class for RBLNBlip2VisionModel.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BLIP-2 vision encoder models for multimodal tasks.
+    """
 
 
 class RBLNBlip2QFormerModelConfig(RBLNModelConfig):
+    """
+    Configuration class for RBLNBlip2QFormerModel.
+
+    This configuration class stores the configuration parameters specific to
+    RBLN-optimized BLIP-2 Q-Former models that bridge vision and language modalities.
+    """
+
     def __init__(
         self,
         num_query_tokens: Optional[int] = None,
@@ -50,7 +62,7 @@ class RBLNBlip2ForConditionalGenerationConfig(RBLNModelConfig):
         vision_model: Optional[RBLNModelConfig] = None,
         qformer: Optional[RBLNModelConfig] = None,
         language_model: Optional[RBLNModelConfig] = None,
-        **kwargs,
+        **kwargs: Dict[str, Any],
     ):
         """
         Args:
optimum/rbln/transformers/models/blip_2/modeling_blip_2.py
@@ -65,6 +65,13 @@ class LoopProjector:
 
 
 class RBLNBlip2VisionModel(RBLNModel):
+    """
+    RBLN optimized BLIP-2 vision encoder model.
+
+    This class provides hardware-accelerated inference for BLIP-2 vision encoders
+    on RBLN devices, supporting image encoding for multimodal vision-language tasks.
+    """
+
     def get_input_embeddings(self):
         return self.embeddings
 
@@ -136,6 +143,14 @@ class RBLNBlip2VisionModel(RBLNModel):
 
 
 class RBLNBlip2QFormerModel(RBLNModel):
+    """
+    RBLN optimized BLIP-2 Q-Former model.
+
+    This class provides hardware-accelerated inference for BLIP-2 Q-Former models
+    on RBLN devices, which bridge vision and language modalities through cross-attention
+    mechanisms for multimodal understanding tasks.
+    """
+
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
 
@@ -251,6 +266,38 @@ class RBLNBlip2QFormerModel(RBLNModel):
 
 
 class RBLNBlip2ForConditionalGeneration(RBLNModel):
+    """
+    RBLNBlip2ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+    optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+    This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+    Important Note:
+        This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
+        tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+        `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNBlip2ForConditionalGeneration class for details.
+
+    Examples:
+        ```python
+        from optimum.rbln import RBLNBlip2ForConditionalGeneration
+
+        model = RBLNBlip2ForConditionalGeneration.from_pretrained(
+            "Salesforce/blip2-opt-2.7b",
+            export=True,
+            rbln_config={
+                "language_model": {
+                    "batch_size": 1,
+                    "max_seq_len": 2048,
+                    "tensor_parallel_size": 1,
+                    "use_inputs_embeds": True,
+                },
+            },
+        )
+
+        model.save_pretrained("compiled-blip2-opt-2.7b")
+        ```
+    """
+
     auto_model_class = AutoModelForVisualQuestionAnswering
     _rbln_submodules = [{"name": "vision_model"}, {"name": "qformer"}, {"name": "language_model"}]
 
@@ -275,10 +322,9 @@ class RBLNBlip2ForConditionalGeneration(RBLNModel):
         subfolder: str,
         rbln_config: RBLNModelConfig,
     ):
-        """
-        If you are unavoidably running on a CPU rather than an RBLN device,
-        store the torch tensor, weight, etc. in this function.
-        """
+        # If you are unavoidably running on a CPU rather than an RBLN device,
+        # store the torch tensor, weight, etc. in this function.
+
         save_dict = {}
         save_dict["query_tokens"] = model.query_tokens
         torch.save(save_dict, save_dir_path / subfolder / "query_tokens.pth")
optimum/rbln/transformers/models/clip/configuration_clip.py
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Any, Dict, Optional
 
 from ....configuration_utils import RBLNModelConfig
 
 
 class RBLNCLIPTextModelConfig(RBLNModelConfig):
-    def __init__(self, batch_size: Optional[int] = None, **kwargs):
+    def __init__(self, batch_size: Optional[int] = None, **kwargs: Dict[str, Any]):
         """
         Args:
             batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
@@ -34,11 +34,16 @@ class RBLNCLIPTextModelConfig(RBLNModelConfig):
 
 
 class RBLNCLIPTextModelWithProjectionConfig(RBLNCLIPTextModelConfig):
-    pass
+    """
+    Configuration class for RBLNCLIPTextModelWithProjection.
+
+    This configuration inherits from RBLNCLIPTextModelConfig and stores
+    configuration parameters for CLIP text models with projection layers.
+    """
 
 
 class RBLNCLIPVisionModelConfig(RBLNModelConfig):
-    def __init__(self, batch_size: Optional[int] = None, image_size: Optional[int] = None, **kwargs):
+    def __init__(self, batch_size: Optional[int] = None, image_size: Optional[int] = None, **kwargs: Dict[str, Any]):
         """
         Args:
             batch_size (Optional[int]): The batch size for image processing. Defaults to 1.
@@ -76,4 +81,9 @@ class RBLNCLIPVisionModelConfig(RBLNModelConfig):
 
 
 class RBLNCLIPVisionModelWithProjectionConfig(RBLNCLIPVisionModelConfig):
-    pass
+    """
+    Configuration class for RBLNCLIPVisionModelWithProjection.
+
+    This configuration inherits from RBLNCLIPVisionModelConfig and stores
+    configuration parameters for CLIP vision models with projection layers.
+    """
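
Since the docstrings and the new **kwargs annotations above describe the constructor surface of these configs, a short sketch of building them directly; whether both names are re-exported at the package top level is an assumption, and the values are illustrative:

```python
from optimum.rbln import RBLNCLIPTextModelConfig, RBLNCLIPVisionModelConfig

# batch_size defaults to 1 per the docstrings above; image_size is the
# vision-specific knob added by RBLNCLIPVisionModelConfig.
text_config = RBLNCLIPTextModelConfig(batch_size=1)
vision_config = RBLNCLIPVisionModelConfig(batch_size=1, image_size=224)
```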