optimum-rbln 0.9.3.post1 (optimum_rbln-0.9.3.post1-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this version of optimum-rbln has been flagged as a potentially problematic release.
Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
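As a quick sanity check of the published wheel, the package can be installed and its reported version compared against the release shown above. This is a minimal sketch; it assumes the top-level package re-exports `__version__` from the `optimum/rbln/__version__.py` module listed above and that the RBLN SDK dependencies resolve in the environment.

```python
# Hypothetical smoke test for the published wheel:
#   pip install optimum-rbln==0.9.3.post1
import optimum.rbln as optimum_rbln

# Assumes optimum/rbln/__init__.py re-exports __version__ (see the file list above).
print(optimum_rbln.__version__)  # expected to print 0.9.3.post1
```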
optimum/rbln/transformers/models/blip_2/modeling_blip_2.py
@@ -0,0 +1,526 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
+
+ import torch
+ from transformers import (
+     AutoModelForVisualQuestionAnswering,
+     Blip2ForConditionalGeneration,
+     Blip2QFormerModel,
+     Blip2VisionModel,
+     PretrainedConfig,
+     PreTrainedModel,
+ )
+ from transformers.modeling_outputs import BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions
+ from transformers.utils import logging
+
+ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+ from ....modeling import RBLNModel
+ from ...utils.rbln_runtime_wrapper import LoopProcessor
+ from ..decoderonly.generation_decoderonly import RBLNDecoderOnlyGenerationMixin
+
+
+ logger = logging.get_logger(__name__)
+
+ if TYPE_CHECKING:
+     import rebel
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer
+
+
+ class LoopProjector(LoopProcessor):
+     def __init__(self, language_projection: Union[RBLNModel, "rebel.Runtime"]):
+         super().__init__(model=language_projection)
+
+     def _get_batch_size(self, query_output, **kwargs):
+         return query_output.shape[0]
+
+     def _prepare_inputs_for_iteration(self, index, common_inputs, query_output, **kwargs):
+         query_output_item = query_output[index : index + 1]
+         return ([query_output_item], {})
+
+     def _process_outputs(self, outputs: list, **kwargs):
+         output = torch.cat(outputs, dim=0)
+         return output
+
+
+ class RBLNBlip2VisionModel(RBLNModel):
+     """
+     RBLN optimized BLIP-2 vision encoder model.
+
+     This class provides hardware-accelerated inference for BLIP-2 vision encoders
+     on RBLN devices, supporting image encoding for multimodal vision-language tasks.
+     """
+
+     _tp_support = False
+
+     def get_input_embeddings(self):
+         return self.embeddings
+
+     @classmethod
+     def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+         class Blip2VisionModelWrapper(torch.nn.Module):
+             def __init__(self, model: "Blip2VisionModel") -> None:
+                 super().__init__()
+                 self.model = model
+
+             def forward(self, *args, **kwargs):
+                 kwargs.pop("return_dict", None)
+                 return self.model(*args, **kwargs, return_dict=False)
+
+         return Blip2VisionModelWrapper(model).eval()
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNModelConfig] = None,
+     ) -> RBLNModelConfig:
+         input_info = [
+             (
+                 "pixel_values",
+                 [
+                     rbln_config.batch_size,
+                     model_config.num_channels,
+                     model_config.image_size,
+                     model_config.image_size,
+                 ],
+                 "float32",
+             ),
+         ]
+
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+         rbln_config.set_compile_cfgs([rbln_compile_config])
+         return rbln_config
+
+     def forward(
+         self,
+         pixel_values: torch.FloatTensor,
+         interpolate_pos_encoding: bool = False,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         """
+         Forward pass for the RBLN-optimized Blip2VisionModel model.
+
+         Args:
+             pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)): The tensors corresponding to the input images.
+             interpolate_pos_encoding (bool, optional): Whether to interpolate the positional encoding of the image embeddings. Defaults to False.
+             return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.
+
+         Returns:
+             BaseModelOutputWithPooling or tuple(torch.FloatTensor): The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
+         """
+         batch_size = pixel_values.shape[0]
+         outputs = []
+         for i in range(batch_size):
+             outputs.append(self.model[0](pixel_values[i : i + 1]))
+
+         last_hidden_state = [output[0] for output in outputs]
+         pooler_output = [output[1] for output in outputs]
+
+         last_hidden_state = torch.cat(last_hidden_state, dim=0)
+         pooler_output = torch.cat(pooler_output, dim=0)
+
+         if not return_dict:
+             return (last_hidden_state, pooler_output)
+
+         return BaseModelOutputWithPooling(
+             last_hidden_state=last_hidden_state,
+             pooler_output=pooler_output,
+         )
+
+
+ class RBLNBlip2QFormerModel(RBLNModel):
+     """
+     RBLN optimized BLIP-2 Q-Former model.
+
+     This class provides hardware-accelerated inference for BLIP-2 Q-Former models
+     on RBLN devices, which bridge vision and language modalities through cross-attention
+     mechanisms for multimodal understanding tasks.
+     """
+
+     _tp_support = False
+
+     def get_input_embeddings(self):
+         return self.embeddings.word_embeddings
+
+     @classmethod
+     def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
+         class Blip2QFormerModelWrapper(torch.nn.Module):
+             def __init__(self, model: "Blip2QFormerModel"):
+                 super().__init__()
+                 self.model = model
+
+             def forward(
+                 self,
+                 query_embeds: torch.FloatTensor,
+                 encoder_hidden_states: Optional[torch.FloatTensor] = None,
+                 encoder_attention_mask: Optional[torch.FloatTensor] = None,
+             ) -> torch.Tensor:
+                 qformer_out = self.model(
+                     query_embeds=query_embeds,
+                     encoder_hidden_states=encoder_hidden_states,
+                     encoder_attention_mask=encoder_attention_mask,
+                     return_dict=False,
+                 )
+                 return qformer_out
+
+         return Blip2QFormerModelWrapper(model).eval()
+
+     @classmethod
+     def _update_submodule_config(
+         cls,
+         model: "PreTrainedModel",
+         rbln_config: RBLNModelConfig,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+     ):
+         if rbln_config.num_query_tokens is None:
+             rbln_config.num_query_tokens = model.config.num_query_tokens
+
+         if rbln_config.image_text_hidden_size is None:
+             rbln_config.image_text_hidden_size = model.config.image_text_hidden_size
+
+         return rbln_config
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNModelConfig] = None,
+     ) -> RBLNModelConfig:
+         input_info = [
+             (
+                 "query_embeds",
+                 [
+                     rbln_config.batch_size,
+                     rbln_config.num_query_tokens,
+                     model_config.hidden_size,
+                 ],
+                 "float32",
+             ),
+             (
+                 "encoder_hidden_states",
+                 [
+                     rbln_config.batch_size,
+                     # image_text_hidden_size + cls token
+                     rbln_config.image_text_hidden_size + 1,
+                     model_config.encoder_hidden_size,
+                 ],
+                 "float32",
+             ),
+             (
+                 "encoder_attention_mask",
+                 # image_text_hidden_size + cls token
+                 [rbln_config.batch_size, rbln_config.image_text_hidden_size + 1],
+                 "int64",
+             ),
+         ]
+
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+         rbln_config.set_compile_cfgs([rbln_compile_config])
+         return rbln_config
+
+     def forward(
+         self,
+         query_embeds: torch.FloatTensor,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+         """
+         The forward pass for the RBLN-optimized Blip2QFormerModel model.
+
+         Args:
+             query_embeds (torch.FloatTensor): Hidden states to be used in the attention computation.
+             encoder_hidden_states (torch.FloatTensor, optional): Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model is configured as a decoder.
+             encoder_attention_mask (torch.FloatTensor, optional): Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in the cross-attention if the model is configured as a decoder.
+             return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.
+
+         Returns:
+             BaseModelOutputWithPoolingAndCrossAttentions or tuple(torch.FloatTensor): The model outputs. If `return_dict=False` is passed, returns a tuple of tensors. Otherwise, returns a `BaseModelOutputWithPoolingAndCrossAttentions` object.
+         """
+         batch_size = query_embeds.shape[0]
+         outputs = []
+         for i in range(batch_size):
+             outputs.append(
+                 self.model[0](
+                     query_embeds[i : i + 1], encoder_hidden_states[i : i + 1], encoder_attention_mask[i : i + 1]
+                 )
+             )
+
+         sequence_output = [output[0] for output in outputs]
+         pooled_output = [output[1] for output in outputs]
+
+         sequence_output = torch.cat(sequence_output, dim=0)
+         pooled_output = torch.cat(pooled_output, dim=0)
+
+         if not return_dict:
+             return (sequence_output, pooled_output)
+
+         return BaseModelOutputWithPoolingAndCrossAttentions(
+             last_hidden_state=sequence_output,
+             pooler_output=pooled_output,
+         )
+
+
+ class RBLNBlip2ForConditionalGeneration(RBLNModel, RBLNDecoderOnlyGenerationMixin):
+     """
+     RBLNBlip2ForConditionalGeneration is a multi-modal model that integrates vision and language processing capabilities,
+     optimized for RBLN NPUs. It is designed for conditional generation tasks that involve both image and text inputs.
+
+     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+     Important Note:
+         This model includes a Large Language Model (LLM) as a submodule. For optimal performance, it is highly recommended to use
+         tensor parallelism for the language model. This can be achieved by using the `rbln_config` parameter in the
+         `from_pretrained` method. Refer to the `from_pretrained` documentation and the RBLNBlip2ForConditionalGeneration class for details.
+
+     Examples:
+         ```python
+         from optimum.rbln import RBLNBlip2ForConditionalGeneration
+
+         model = RBLNBlip2ForConditionalGeneration.from_pretrained(
+             "Salesforce/blip2-opt-2.7b",
+             export=True,
+             rbln_config={
+                 "language_model": {
+                     "batch_size": 1,
+                     "max_seq_len": 2048,
+                     "tensor_parallel_size": 1,
+                     "use_inputs_embeds": True,
+                 },
+             },
+         )
+
+         model.save_pretrained("compiled-blip2-opt-2.7b")
+         ```
+     """
+
+     auto_model_class = AutoModelForVisualQuestionAnswering
+     _rbln_submodules = [{"name": "vision_model"}, {"name": "qformer"}, {"name": "language_model"}]
+
+     def __getattr__(self, __name: str) -> Any:
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(Blip2ForConditionalGeneration, __name)
+
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+         return val
+
+     def can_generate(self):
+         return True
+
+     @classmethod
+     def save_torch_artifacts(
+         cls,
+         model: "Blip2ForConditionalGeneration",
+         save_dir_path: Path,
+         subfolder: str,
+         rbln_config: RBLNModelConfig,
+     ):
+         # If you are unavoidably running on a CPU rather than an RBLN device,
+         # store the torch tensor, weight, etc. in this function.
+
+         save_dict = {}
+         save_dict["query_tokens"] = model.query_tokens
+         torch.save(save_dict, save_dir_path / subfolder / "query_tokens.pth")
+
+     def __post_init__(self, **kwargs):
+         self.vision_model = self.rbln_submodules[0]
+         self.language_model = self.rbln_submodules[2]
+         self.qformer = self.rbln_submodules[1]
+         self.language_projection = LoopProjector(self.model[0])
+
+         artifacts = torch.load(self.model_save_dir / self.subfolder / "query_tokens.pth", weights_only=False)
+         self.query_tokens = artifacts["query_tokens"]
+
+     def get_attn_impl(self) -> str:
+         return self.rbln_config.language_model.attn_impl
+
+     def get_kvcache_num_blocks(self) -> int:
+         return self.rbln_config.language_model.kvcache_num_blocks
+
+     def get_input_embeddings(self):
+         return self.language_model.get_input_embeddings()
+
+     @classmethod
+     def _wrap_model_if_needed(cls, model, rbln_config):
+         return model.language_projection
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNModelConfig] = None,
+     ) -> RBLNModelConfig:
+         input_info = [
+             (
+                 "query_output",
+                 [
+                     1,
+                     model_config.num_query_tokens,
+                     model_config.qformer_config.hidden_size,
+                 ],
+                 "float32",
+             ),
+         ]
+
+         rbln_compile_config = RBLNCompileConfig(input_info=input_info)
+         rbln_config.set_compile_cfgs([rbln_compile_config])
+
+         return rbln_config
+
+     def _preprocess_prefill(
+         self,
+         pixel_values: torch.FloatTensor,
+         input_ids: torch.FloatTensor,
+         attention_mask: Optional[torch.LongTensor] = None,
+         return_dict: Optional[bool] = None,
+         interpolate_pos_encoding: bool = False,
+         **kwargs,
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         vision_outputs = self.vision_model(
+             pixel_values=pixel_values,
+             return_dict=return_dict,
+             interpolate_pos_encoding=interpolate_pos_encoding,
+         )
+         image_embeds = vision_outputs[0]
+
+         image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+         query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+
+         query_outputs = self.qformer(
+             query_embeds=query_tokens,
+             encoder_hidden_states=image_embeds,
+             encoder_attention_mask=image_attention_mask,
+             return_dict=return_dict,
+         )
+         query_output = query_outputs[0]
+
+         if query_output.dtype != image_embeds.dtype:
+             query_output = query_output.to(image_embeds.dtype)
+
+         language_model_inputs = self.language_projection(query_output)
+         language_model_attention_mask = torch.ones(
+             language_model_inputs.size()[:-1], dtype=torch.long, device=language_model_inputs.device
+         )
+         inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids)
+
+         if getattr(self.config, "image_token_index", None) is not None:
+             special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
+             language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+         else:
+             logger.warning_once(
+                 "Expanding inputs for image tokens in BLIP-2 should be done in processing. "
+                 "Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. "
+                 "Using processors without these attributes in the config is deprecated and will throw an error in v4.50."
+             )
+             inputs_embeds = torch.cat([language_model_inputs, inputs_embeds.to(language_model_inputs.device)], dim=1)
+             attention_mask = torch.cat(
+                 [language_model_attention_mask, attention_mask.to(language_model_attention_mask.device)], dim=1
+             )
+
+         return inputs_embeds
+
+     @torch.no_grad()
+     def generate(
+         self,
+         pixel_values: torch.FloatTensor,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         interpolate_pos_encoding: bool = False,
+         **generate_kwargs,
+     ) -> List[torch.LongTensor]:
+         """
+         The generate function is utilized in its standard form as in the HuggingFace transformers library. Users can use this function to generate text from the model.
+         Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/model_doc/blip-2#transformers.Blip2ForConditionalGeneration.generate) for more details.
+
+         Args:
+             pixel_values (torch.FloatTensor): Input images to be processed.
+             input_ids (torch.LongTensor, optional): The sequence used as a prompt for the generation.
+             attention_mask (torch.LongTensor, optional): Mask to avoid performing attention on padding token indices.
+             inputs_embeds (torch.FloatTensor, optional): Embedded representation of the inputs. Should be float, not int tokens.
+             interpolate_pos_encoding (bool, optional, defaults to False): Whether to interpolate the positional encoding of the image embeddings.
+
+         Returns:
+             A list of strings of length batch_size * num_captions.
+         """
+         batch_size = pixel_values.shape[0]
+         image_embeds = self.vision_model(
+             pixel_values,
+             return_dict=True,
+             interpolate_pos_encoding=interpolate_pos_encoding,
+         ).last_hidden_state
+         image_attention_mask = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
+
+         query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
+         query_outputs = self.qformer(
+             query_embeds=query_tokens,
+             encoder_hidden_states=image_embeds,
+             encoder_attention_mask=image_attention_mask,
+             return_dict=True,
+         )
+         query_output = query_outputs.last_hidden_state
+
+         if query_output.dtype != image_embeds.dtype:
+             query_output = query_output.to(image_embeds.dtype)
+
+         language_model_inputs = self.language_projection(query_output)
+
+         if inputs_embeds is None:
+             if input_ids is None:
+                 image_tokens = [self.config.image_token_index] * self.config.num_query_tokens
+                 start_tokens = image_tokens + [self.config.text_config.bos_token_id]
+                 input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
+                 input_ids = input_ids.repeat(batch_size, 1)
+             inputs_embeds = self.get_input_embeddings()(input_ids)
+
+         if attention_mask is None:
+             attention_mask = torch.ones_like(input_ids)
+
+         if input_ids is None:
+             special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                 torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
+             )
+             special_image_mask = special_image_mask.all(-1)
+         else:
+             special_image_mask = input_ids == self.config.image_token_id
+
+         special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
+         language_model_inputs = language_model_inputs.to(inputs_embeds.device, inputs_embeds.dtype)
+         inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, language_model_inputs)
+
+         inputs = {"inputs_embeds": inputs_embeds, "attention_mask": attention_mask}
+         if not self.language_model.config.is_encoder_decoder:
+             inputs["input_ids"] = input_ids
+
+         outputs = self.language_model.generate(**inputs, **generate_kwargs)
+
+         return outputs
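For reference, a minimal inference sketch against the artifacts produced by the docstring example above. This is not part of the diff; it assumes a recent `transformers` `Blip2Processor` that already expands image tokens, and reuses the `compiled-blip2-opt-2.7b` directory name from the example.

```python
# Hedged usage sketch for the compiled BLIP-2 model shown above.
import requests
from PIL import Image
from transformers import Blip2Processor

from optimum.rbln import RBLNBlip2ForConditionalGeneration

processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = RBLNBlip2ForConditionalGeneration.from_pretrained(
    "compiled-blip2-opt-2.7b",  # directory written by save_pretrained() in the example
    export=False,               # load already-compiled RBLN artifacts instead of recompiling
)

# Any RGB image works; this COCO sample is the one used in the upstream BLIP-2 docs.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(images=image, text="Question: how many cats are there? Answer:", return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip())
```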
optimum/rbln/transformers/models/clip/__init__.py
@@ -0,0 +1,26 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_clip import (
+     RBLNCLIPTextModelConfig,
+     RBLNCLIPTextModelWithProjectionConfig,
+     RBLNCLIPVisionModelConfig,
+     RBLNCLIPVisionModelWithProjectionConfig,
+ )
+ from .modeling_clip import (
+     RBLNCLIPTextModel,
+     RBLNCLIPTextModelWithProjection,
+     RBLNCLIPVisionModel,
+     RBLNCLIPVisionModelWithProjection,
+ )
optimum/rbln/transformers/models/clip/configuration_clip.py
@@ -0,0 +1,103 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Optional
+
+ from ....configuration_utils import RBLNModelConfig
+
+
+ class RBLNCLIPTextModelConfig(RBLNModelConfig):
+     def __init__(self, batch_size: Optional[int] = None, **kwargs: Any):
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for text processing. Defaults to 1.
+             kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If `batch_size` is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+
+ class RBLNCLIPTextModelWithProjectionConfig(RBLNCLIPTextModelConfig):
+     """
+     Configuration class for RBLNCLIPTextModelWithProjection.
+
+     This configuration inherits from RBLNCLIPTextModelConfig and stores
+     configuration parameters for CLIP text models with projection layers.
+     """
+
+
+ class RBLNCLIPVisionModelConfig(RBLNModelConfig):
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         image_size: Optional[int] = None,
+         interpolate_pos_encoding: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         **kwargs: Any,
+     ):
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for image processing. Defaults to 1.
+             image_size (Optional[int]): The size of input images. Can be an integer for square images,
+                 a tuple/list (height, width), or a dictionary with 'height' and 'width' keys.
+             interpolate_pos_encoding (Optional[bool]): Whether or not to interpolate pre-trained position encodings. Defaults to `False`.
+             output_hidden_states (Optional[bool]): Whether or not to return the hidden states of all layers.
+             output_attentions (Optional[bool]): Whether or not to return the attentions tensors of all attention layers.
+             kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If `batch_size` is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size < 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         self.image_size = image_size
+         self.interpolate_pos_encoding = interpolate_pos_encoding or False
+         self.output_hidden_states = output_hidden_states
+         self.output_attentions = output_attentions
+
+     @property
+     def image_width(self):
+         if isinstance(self.image_size, int):
+             return self.image_size
+         elif isinstance(self.image_size, (list, tuple)):
+             return self.image_size[1]
+         else:
+             return self.image_size["width"]
+
+     @property
+     def image_height(self):
+         if isinstance(self.image_size, int):
+             return self.image_size
+         elif isinstance(self.image_size, (list, tuple)):
+             return self.image_size[0]
+         else:
+             return self.image_size["height"]
+
+
+ class RBLNCLIPVisionModelWithProjectionConfig(RBLNCLIPVisionModelConfig):
+     """
+     Configuration class for RBLNCLIPVisionModelWithProjection.
+
+     This configuration inherits from RBLNCLIPVisionModelConfig and stores
+     configuration parameters for CLIP vision models with projection layers.
+     """