optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/clip/modeling_clip.py
@@ -0,0 +1,384 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
+
+ import torch
+ from transformers import CLIPTextConfig, CLIPTextModel, CLIPVisionConfig, CLIPVisionModel
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
+ from transformers.models.clip.modeling_clip import CLIPTextModelOutput, CLIPVisionModelOutput
+
+ from ....configuration_utils import RBLNCompileConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from .configuration_clip import RBLNCLIPTextModelConfig, RBLNCLIPVisionModelConfig
+
+
+ logger = get_logger(__name__)
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, CLIPTextModel, PreTrainedModel
+
+     from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
+
+
+ class _TextEncoder(torch.nn.Module):
+     def __init__(self, enc: "CLIPTextModel"):
+         super().__init__()
+         self.enc = enc
+
+     def forward(self, inp):
+         enc_out = self.enc(inp, output_hidden_states=True, return_dict=False)
+         return enc_out
+
+
+ class RBLNCLIPTextModel(RBLNModel):
+     """
+     RBLN optimized CLIP text encoder model.
+
+     This class provides hardware-accelerated inference for CLIP text encoders
+     on RBLN devices, supporting text encoding for multimodal tasks.
+     """
+
+     _tp_support = False
+
+     @classmethod
+     def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPTextModelConfig) -> torch.nn.Module:
+         return _TextEncoder(model).eval()
+
+     @classmethod
+     def update_rbln_config_using_pipe(
+         cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+     ) -> "RBLNDiffusionMixinConfig":
+         return rbln_config
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: "CLIPTextConfig" = None,
+         rbln_config: Optional[RBLNCLIPTextModelConfig] = None,
+     ) -> RBLNCLIPTextModelConfig:
+         input_info = [
+             (
+                 "input_ids",
+                 [
+                     rbln_config.batch_size,
+                     model_config.max_position_embeddings,
+                 ],
+                 "int64",
+             ),
+         ]
+
+         rbln_config.set_compile_cfgs([RBLNCompileConfig(input_info=input_info)])
+         return rbln_config
+
+     def forward(self, input_ids: torch.LongTensor, return_dict: Optional[bool] = None, **kwargs) -> torch.FloatTensor:
+         """
+         Forward pass for the RBLN-optimized CLIP text encoder model.
+
+         Args:
+             input_ids (torch.LongTensor): The input ids to the model.
+             return_dict (Optional[bool]): Whether to return a dictionary of outputs.
+
+         Returns:
+             The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a CLIPTextModelOutput object.
+         """
+
+         # We override forward so that the attention_mask is ignored.
+         output = super().forward(input_ids, return_dict=return_dict)
+         return output
+
+     def _prepare_output(self, output, return_dict):
+         # Prepare model output based on return_dict flag.
+         # This method can be overridden by subclasses to provide task-specific output handling.
+
+         if not return_dict:
+             return (output,) if not isinstance(output, (tuple, list)) else output
+         else:
+             return CLIPTextModelOutput(
+                 text_embeds=output[0],
+                 last_hidden_state=output[1],
+                 hidden_states=output[2:],
+             )
+
+
+ class RBLNCLIPTextModelWithProjection(RBLNCLIPTextModel):
+     """
+     RBLN optimized CLIP text encoder model with projection layer.
+
+     This class extends RBLNCLIPTextModel with a projection layer for
+     multimodal embedding alignment tasks.
+     """
+
+
+ class _VisionEncoder(torch.nn.Module):
+     def __init__(
+         self,
+         enc: CLIPVisionModel,
+         interpolate_pos_encoding: bool,
+         output_hidden_states: bool,
+         output_attentions: bool,
+     ):
+         super().__init__()
+         self.enc = enc
+         self.interpolate_pos_encoding = interpolate_pos_encoding
+         self.output_hidden_states = output_hidden_states
+         self.output_attentions = output_attentions
+
+     def forward(self, inp):
+         enc_out = self.enc(
+             inp,
+             output_hidden_states=self.output_hidden_states,
+             interpolate_pos_encoding=self.interpolate_pos_encoding,
+             output_attentions=self.output_attentions,
+             return_dict=False,
+         )
+         return enc_out
+
+
+ class RBLNCLIPVisionModel(RBLNModel):
+     """
+     RBLN optimized CLIP vision encoder model.
+
+     This class provides hardware-accelerated inference for CLIP vision encoders
+     on RBLN devices, supporting image encoding for multimodal tasks.
+     """
+
+     _tp_support = False
+
+     @classmethod
+     def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNCLIPVisionModelConfig) -> torch.nn.Module:
+         wrapper_cfg = {
+             "interpolate_pos_encoding": rbln_config.interpolate_pos_encoding,
+             "output_hidden_states": rbln_config.output_hidden_states,
+             "output_attentions": rbln_config.output_attentions,
+         }
+         return _VisionEncoder(model, **wrapper_cfg).eval()
+
+     @classmethod
+     def update_rbln_config_using_pipe(
+         cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+     ) -> "RBLNDiffusionMixinConfig":
+         return rbln_config
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: "CLIPVisionConfig" = None,
+         rbln_config: Optional[RBLNCLIPVisionModelConfig] = None,
+     ) -> RBLNCLIPVisionModelConfig:
+         if rbln_config.image_size is None:
+             rbln_config.image_size = getattr(model_config, "image_size", None)
+
+         if isinstance(rbln_config.image_size, int):
+             rbln_config.image_size = (rbln_config.image_size, rbln_config.image_size)
+
+         if rbln_config.image_size is None:
+             raise ValueError("`rbln_image_size` should be specified!")
+
+         if rbln_config.output_attentions is None:
+             rbln_config.output_attentions = getattr(model_config, "output_attentions", False)
+
+         if rbln_config.output_hidden_states is None:
+             rbln_config.output_hidden_states = getattr(model_config, "output_hidden_states", False)
+
+         rbln_compile_config = RBLNCompileConfig(
+             input_info=[
+                 (
+                     "pixel_values",
+                     [
+                         rbln_config.batch_size,
+                         3,
+                         rbln_config.image_height,
+                         rbln_config.image_width,
+                     ],
+                     "float32",
+                 )
+             ]
+         )
+
+         rbln_config.set_compile_cfgs([rbln_compile_config])
+         return rbln_config
+
+     def forward(
+         self,
+         pixel_values: torch.FloatTensor,
+         return_dict: bool = True,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         interpolate_pos_encoding: bool = False,
+         **kwargs,
+     ) -> Union[Tuple, BaseModelOutputWithPooling]:
+         """
+         Forward pass for the RBLN-optimized CLIP vision encoder model.
+
+         Args:
+             pixel_values (torch.Tensor): The pixel values to the model.
+             return_dict (bool): Whether to return a dictionary of outputs.
+             output_attentions (Optional[bool]): Whether to return attentions.
+             output_hidden_states (Optional[bool]): Whether to return hidden states.
+             interpolate_pos_encoding (bool): Whether to interpolate position encoding.
+
+         Returns:
+             The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a BaseModelOutputWithPooling object.
+         """
+
+         if len(kwargs) > 0 and any(value is not None for value in kwargs.values()):
+             logger.warning(
+                 f"Currently, optimum-rbln does not support kwargs {kwargs.keys()} for {self.__class__.__name__}."
+             )
+
+         output_attentions = output_attentions if output_attentions is not None else self.rbln_config.output_attentions
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.rbln_config.output_hidden_states
+         )
+
+         if output_attentions != self.rbln_config.output_attentions:
+             raise ValueError(
+                 f"Variable output_attentions {output_attentions} is not equal to rbln_config.output_attentions {self.rbln_config.output_attentions}. "
+                 f"Please compile again with the correct argument."
+             )
+
+         if output_hidden_states != self.rbln_config.output_hidden_states:
+             raise ValueError(
+                 f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states}. "
+                 f"Please compile again with the correct argument."
+             )
+
+         if interpolate_pos_encoding != self.rbln_config.interpolate_pos_encoding:
+             raise ValueError(
+                 f"Variable interpolate_pos_encoding {interpolate_pos_encoding} is not equal to rbln_config.interpolate_pos_encoding {self.rbln_config.interpolate_pos_encoding}. "
+                 f"Please compile again with the correct argument."
+             )
+
+         output = super().forward(pixel_values, return_dict=return_dict)
+         return output
+
+     def _prepare_output(self, output, return_dict):
+         # Prepare model output based on return_dict flag.
+         # This method can be overridden by subclasses to provide task-specific output handling.
+         last_hidden_state = output.pop(0)
+         pooler_output = output.pop(0)
+         vision_config = self.config.vision_config if hasattr(self.config, "vision_config") else self.config
+
+         if self.rbln_config.output_hidden_states:
+             hidden_states = ()
+             num_hidden_layers = vision_config.num_hidden_layers
+             for _ in range(num_hidden_layers + 1):
+                 hidden_states += (output.pop(0),)
+         else:
+             hidden_states = None
+
+         if self.rbln_config.output_attentions:
+             attentions = ()
+             num_hidden_layers = vision_config.num_hidden_layers
+             for _ in range(num_hidden_layers):
+                 attentions += (output.pop(0),)
+         else:
+             attentions = None
+
+         if not return_dict:
+             return tuple(
+                 item for item in (last_hidden_state, pooler_output, hidden_states, attentions) if item is not None
+             )
+         else:
+             return BaseModelOutputWithPooling(
+                 last_hidden_state=last_hidden_state,
+                 pooler_output=pooler_output,
+                 hidden_states=hidden_states,
+                 attentions=attentions,
+             )
+
+
+ class RBLNCLIPVisionModelWithProjection(RBLNCLIPVisionModel):
+     """
+     RBLN optimized CLIP vision encoder model with projection layer.
+
+     This class extends RBLNCLIPVisionModel with a projection layer for
+     multimodal embedding alignment tasks.
+     """
+
+     def forward(
+         self,
+         pixel_values: torch.FloatTensor,
+         return_dict: bool = True,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         interpolate_pos_encoding: bool = False,
+         **kwargs,
+     ) -> Union[Tuple, CLIPVisionModelOutput]:
+         """
+         Forward pass for the RBLN-optimized CLIP vision encoder model with projection.
+
+         Args:
+             pixel_values (torch.Tensor): The pixel values to the model.
+             return_dict (bool): Whether to return a dictionary of outputs.
+             output_attentions (Optional[bool]): Whether to return attentions.
+             output_hidden_states (Optional[bool]): Whether to return hidden states.
+             interpolate_pos_encoding (bool): Whether to interpolate position encoding.
+
+         Returns:
+             The model outputs. If return_dict=False is passed, returns a tuple of tensors. Otherwise, returns a CLIPVisionModelOutput object.
+         """
+
+         return super().forward(
+             pixel_values=pixel_values,
+             return_dict=return_dict,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             interpolate_pos_encoding=interpolate_pos_encoding,
+             **kwargs,
+         )
+
+     def _prepare_output(self, output, return_dict):
+         # Prepare model output based on return_dict flag.
+         # This method can be overridden by subclasses to provide task-specific output handling.
+
+         image_embeds = output.pop(0) if isinstance(output, (tuple, list)) else output
+         last_hidden_state = output.pop(0)
+
+         vision_config = self.config.vision_config if hasattr(self.config, "vision_config") else self.config
+
+         if self.rbln_config.output_hidden_states:
+             hidden_states = ()
+             num_hidden_layers = vision_config.num_hidden_layers
+             for _ in range(num_hidden_layers + 1):
+                 hidden_states += (output.pop(0),)
+         else:
+             hidden_states = None
+
+         if self.rbln_config.output_attentions:
+             attentions = ()
+             num_hidden_layers = vision_config.num_hidden_layers
+             for _ in range(num_hidden_layers):
+                 attentions += (output.pop(0),)
+         else:
+             attentions = None
+
+         if not return_dict:
+             return tuple(
+                 item for item in (image_embeds, last_hidden_state, hidden_states, attentions) if item is not None
+             )
+
+         else:
+             return CLIPVisionModelOutput(
+                 image_embeds=image_embeds,
+                 last_hidden_state=last_hidden_state,
+                 hidden_states=hidden_states,
+                 attentions=attentions,
+             )
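
For readers skimming this diff, a minimal usage sketch of the two wrappers above may help. It is an illustration only: the checkpoint name is an assumption, the shapes (77 and 224) are the stock CLIP max_position_embeddings and image_size, and the `from_pretrained(..., export=True)` pattern is borrowed from the RBLNColPaliForRetrievalConfig docstring later in this diff.

```python
import torch
from optimum.rbln import RBLNCLIPTextModel, RBLNCLIPVisionModel

# Compile the encoders for RBLN devices (checkpoint name is illustrative).
text_model = RBLNCLIPTextModel.from_pretrained("openai/clip-vit-base-patch32", export=True)
vision_model = RBLNCLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32", export=True)

# The compiled text graph expects fixed-shape int64 input_ids of
# [batch_size, max_position_embeddings]; attention_mask is intentionally ignored.
input_ids = torch.zeros(1, 77, dtype=torch.int64)
text_out = text_model(input_ids, return_dict=True)

# The compiled vision graph expects fixed-shape float32 pixel_values of
# [batch_size, 3, image_height, image_width].
pixel_values = torch.zeros(1, 3, 224, 224, dtype=torch.float32)
vision_out = vision_model(pixel_values, return_dict=True)
```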
optimum/rbln/transformers/models/colpali/__init__.py
@@ -0,0 +1,2 @@
+ from .configuration_colpali import RBLNColPaliForRetrievalConfig
+ from .modeling_colpali import RBLNColPaliForRetrieval
optimum/rbln/transformers/models/colpali/colpali_architecture.py
@@ -0,0 +1,218 @@
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from transformers import GemmaForCausalLM, GemmaModel
+
+ from ..decoderonly.decoderonly_architecture import RotaryEmbedding, apply_rotary_pos_emb
+
+
+ def slice_and_unsqueeze_cos_sin(cos, sin, position_ids):
+     """Slice cos[cache_position], sin[cache_position] vector for the query."""
+     cos = cos[position_ids[0]][None, None, None, :, :]
+     sin = sin[position_ids[0]][None, None, None, :, :]
+
+     return cos, sin
+
+
+ class RBLNColPaliForRetrievalWrapper(nn.Module):
+     def __init__(
+         self,
+         causal_lm: GemmaForCausalLM,
+         embedding_proj_layer: nn.Module,
+         max_seq_len: int,
+         output_hidden_states: bool = False,
+     ):
+         super().__init__()
+         self.text_config = causal_lm.config.text_config
+         self.rotary_emb = self.get_rotary_emb(max_seq_len=max_seq_len)
+
+         self.output_hidden_states = output_hidden_states
+         self.language_model = self.convert_to_rbln_language_model(causal_lm.model.language_model, max_seq_len)
+
+         self.num_hidden_layers = getattr(self.text_config, "num_hidden_layers", None)
+         self.embedding_proj_layer = embedding_proj_layer
+
+     def get_rotary_emb(self, max_seq_len):
+         return RotaryEmbedding(config=self.text_config, max_seq_len_cached=max_seq_len)
+
+     def convert_to_rbln_language_model(self, gemma_model: GemmaModel, max_seq_len: int):
+         new_layers = []
+         for layer in gemma_model.layers:
+             new_self_attn = ColPaliAttention(
+                 layer.self_attn,
+             )
+             new_layer = ColPaliLayer(layer, new_self_attn)
+             new_layers.append(new_layer)
+
+         new_model = ColPaliModel(
+             gemma_model,
+             new_layers,
+             output_hidden_states=self.output_hidden_states,
+             max_seq_len=max_seq_len,
+         )
+
+         return new_model
+
+     def forward(self, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor):
+         attention_mask = (1.0 - attention_mask) * torch.finfo(torch.float32).min
+         attention_mask = attention_mask[:, None, None, None, :]
+
+         hidden_states, all_hidden_states = self.language_model(
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             rotary_emb=self.rotary_emb,
+             position_ids=position_ids,
+         )
+         embeddings = self.embedding_proj_layer(hidden_states)
+
+         if self.output_hidden_states:
+             return embeddings, all_hidden_states
+         else:
+             return embeddings
+
+
+ class ColPaliModel(nn.Module):
+     def __init__(
+         self, model, layers: List["ColPaliLayer"], output_hidden_states: bool = False, max_seq_len: int = 2048
+     ):
+         super().__init__()
+         self._original_mod = model
+         self.layers = nn.ModuleList(layers)
+         self.output_hidden_states = output_hidden_states
+         self.norm = self._original_mod.norm
+         self.hidden_size = self._original_mod.config.hidden_size
+         self.max_seq_len = max_seq_len
+
+     def forward(
+         self,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         attention_mask: torch.Tensor = None,
+         rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
+         position_ids: Optional[torch.Tensor] = None,
+     ):
+         hidden_states = inputs_embeds * self.hidden_size**0.5
+
+         cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
+         cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+
+         all_hidden_states = () if self.output_hidden_states else None
+         for layer in self.layers:
+             if self.output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             hidden_states = layer(
+                 hidden_states=hidden_states,
+                 attention_mask=attention_mask,
+                 cos=cos,
+                 sin=sin,
+             )
+         hidden_states = self.norm(hidden_states)
+
+         if self.output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         return hidden_states, all_hidden_states
+
+
+ class ColPaliLayer(nn.Module):
+     def __init__(self, layer, self_attn: "ColPaliAttention"):
+         super().__init__()
+         self._original_mod = layer
+         self.self_attn = self_attn
+         self.mlp = layer.mlp
+         self.input_layernorm = layer.input_layernorm
+         self.post_attention_layernorm = layer.post_attention_layernorm
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.FloatTensor]:
+         residual = hidden_states
+         hidden_states = self.input_layernorm(hidden_states)
+
+         # Self Attention
+         hidden_states = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             cos=cos,
+             sin=sin,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.post_attention_layernorm(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+
+         return hidden_states
+
+
+ class ColPaliAttention(nn.Module):
+     def __init__(self, self_attn):
+         super().__init__()
+         self._original_mod = self_attn
+         self.num_heads = getattr(self._original_mod, "num_heads", None) or getattr(
+             self._original_mod.config, "num_attention_heads"
+         )
+         self.head_dim = self._original_mod.head_dim
+         self.scaling = self.head_dim**-0.5
+
+         if hasattr(self._original_mod, "num_key_value_heads"):
+             self.num_key_value_heads = self._original_mod.num_key_value_heads
+         elif hasattr(self._original_mod, "config") and hasattr(self._original_mod.config, "num_key_value_heads"):
+             self.num_key_value_heads = self._original_mod.config.num_key_value_heads
+         else:
+             self.num_key_value_heads = self.num_heads
+
+         self.__post_init__()
+
+     def __post_init__(self):
+         self.q_proj = self._original_mod.q_proj
+         self.k_proj = self._original_mod.k_proj
+         self.v_proj = self._original_mod.v_proj
+         self.o_proj = self._original_mod.o_proj
+
+     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         return query_states, key_states, value_states
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+     ):
+         batch_size, query_length, _ = hidden_states.size()
+
+         query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
+
+         query_states = query_states.view(batch_size, query_length, 1, self.num_heads, self.head_dim).transpose(1, 3)
+         key_states = key_states.view(batch_size, query_length, 1, self.num_key_value_heads, self.head_dim).transpose(
+             1, 3
+         )
+         value_states = value_states.view(
+             batch_size, query_length, 1, self.num_key_value_heads, self.head_dim
+         ).transpose(1, 3)
+
+         if cos is not None and sin is not None:
+             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+         attn_weights = torch.matmul(query_states, key_states.transpose(3, 4)) * self.scaling
+         attn_weights = attn_weights + attention_mask
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32)
+         attn_output = torch.matmul(attn_weights, value_states)
+         attn_output = attn_output.transpose(1, 3)
+
+         attn_output = attn_output.reshape(batch_size, query_length, -1)
+         attn_output = self.o_proj(attn_output)
+
+         return attn_output
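
The 5-D additive mask built in `RBLNColPaliForRetrievalWrapper.forward` is easier to follow in isolation. Below is a self-contained sketch of that transformation; the shapes are chosen to mirror the layout above, and the dummy logits are illustrative:

```python
import torch

# 0/1 padding mask for two sequences of length 4 (1 = real token, 0 = padding).
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.float32)

# Same transformation as the wrapper: padded positions become a huge negative
# bias, so softmax drives their attention weight to ~0.
bias = (1.0 - attention_mask) * torch.finfo(torch.float32).min
bias = bias[:, None, None, None, :]  # (batch, 1, 1, 1, key_len), broadcastable

scores = torch.zeros(2, 1, 1, 1, 4)  # stand-in for query/key attention logits
weights = torch.softmax(scores + bias, dim=-1)
print(weights[0, 0, 0, 0])  # tensor([0.3333, 0.3333, 0.3333, 0.0000])
```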
optimum/rbln/transformers/models/colpali/configuration_colpali.py
@@ -0,0 +1,84 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from typing import Any, List, Optional, Union
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....utils.logging import get_logger
+
+
+ logger = get_logger(__name__)
+
+
+ class RBLNColPaliForRetrievalConfig(RBLNModelConfig):
+     """
+     Configuration class for RBLN ColPali models for document retrieval.
+
+     This class extends RBLNModelConfig with specific configurations for ColPali models,
+     including vision tower settings and multi-sequence length support.
+
+     Example usage:
+     ```python
+     from optimum.rbln import RBLNColPaliForRetrieval, RBLNColPaliForRetrievalConfig
+
+     # Create a configuration object
+     config = RBLNColPaliForRetrievalConfig(
+         max_seq_lens=1152,
+         output_hidden_states=False,
+         tensor_parallel_size=4
+     )
+
+     # Use the configuration with from_pretrained
+     model = RBLNColPaliForRetrieval.from_pretrained(
+         "vidore/colpali-v1.3-hf",
+         export=True,
+         rbln_config=config
+     )
+     ```
+     """
+
+     submodules = ["vision_tower"]
+
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         max_seq_lens: Union[int, List[int]] = None,
+         output_hidden_states: Optional[bool] = None,
+         vision_tower: Optional[RBLNModelConfig] = None,
+         **kwargs: Any,
+     ):
+         """
+         Args:
+             batch_size (Optional[int]): The batch size for the model.
+             max_seq_lens (Union[int, List[int]]): The maximum sequence lengths for the language model.
+                 This can be multiple values, and the model will be compiled for each max_seq_len, allowing selection of the most appropriate max_seq_len at inference time.
+             output_hidden_states (Optional[bool]): Whether to output the hidden states of the language model.
+             vision_tower (Optional[RBLNModelConfig]): Configuration for the vision encoder component.
+             kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If batch_size is not a positive integer.
+         """
+         super().__init__(**kwargs)
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size <= 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         if self.batch_size != 1:
+             logger.warning("Ignoring batch_size for the ColPali vision tower; it will be set to 1.")
+
+         self.vision_tower = self.initialize_submodule_config(
+             submodule_config=vision_tower, batch_size=1, force_kwargs=True
+         )
+         self.max_seq_lens = max_seq_lens
+         self.output_hidden_states = output_hidden_states
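
The `max_seq_lens` docstring above says the model is compiled once per listed length, with "the most appropriate max_seq_len" selected at inference time. As a rough mental model of that selection (the helper below is hypothetical, not an optimum-rbln API), it amounts to picking the smallest compiled bucket that still fits the input:

```python
from typing import List

def pick_max_seq_len(max_seq_lens: List[int], seq_len: int) -> int:
    # Hypothetical helper: smallest compiled bucket that still fits seq_len.
    fitting = [m for m in sorted(max_seq_lens) if m >= seq_len]
    if not fitting:
        raise ValueError(f"sequence length {seq_len} exceeds all compiled buckets {max_seq_lens}")
    return fitting[0]

assert pick_max_seq_len([512, 1024, 2048], 700) == 1024  # pad up to the 1024 bucket
```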