optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/colpali/modeling_colpali.py
@@ -0,0 +1,361 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import bisect
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
+
+ import torch
+ from transformers import PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
+ from transformers.modeling_utils import no_init_weights
+ from transformers.models.colpali.modeling_colpali import ColPaliForRetrievalOutput
+ from transformers.models.paligemma.modeling_paligemma import PaliGemmaMultiModalProjector
+
+ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+ from ....modeling import RBLNModel
+ from ...utils.rbln_runtime_wrapper import LoopProcessor
+ from .colpali_architecture import RBLNColPaliForRetrievalWrapper
+
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
+
+
+ class LoopVisionTower(LoopProcessor):
+     def __init__(self, vision_tower: "RBLNModel"):
+         super().__init__(model=vision_tower.model[0])
+
+     def _get_batch_size(self, pixel_values, **kwargs):
+         return pixel_values.shape[0]
+
+     def _prepare_inputs_for_iteration(self, index, common_inputs, pixel_values, **kwargs):
+         pixel_values_item = pixel_values[index : index + 1]
+         out_buffer = kwargs["out"][index : index + 1]
+         return ([pixel_values_item], {"out": out_buffer})
+
+     def _process_outputs(self, outputs: list, **kwargs) -> "BaseModelOutputWithPooling":
+         return BaseModelOutputWithPooling(
+             last_hidden_state=kwargs["out"],
+         )
+
+
+ class LoopLanguageModel(LoopProcessor):
+     def __init__(self, language_model: RBLNModel, rbln_config: RBLNModelConfig):
+         super().__init__(model=language_model)
+         self.rbln_config = rbln_config
+
+     def _get_batch_size(self, inputs_embeds, **kwargs):
+         return inputs_embeds.shape[0]
+
+     def _prepare_inputs_before_loop(self, *, inputs_embeds: torch.Tensor, attention_mask: torch.Tensor, **kwargs):
+         input_len = inputs_embeds.shape[1]
+         idx = bisect.bisect_left(self.rbln_config.max_seq_lens, input_len)
+         if idx == len(self.rbln_config.max_seq_lens):
+             raise ValueError(
+                 f"Required seq_len({input_len}) is larger than available max_seq_lens({self.rbln_config.max_seq_lens})."
+             )
+         max_seq_len = self.rbln_config.max_seq_lens[idx]
+         padded_inputs_embed = torch.nn.functional.pad(inputs_embeds, (0, 0, 0, max_seq_len - input_len))
+         padded_attn_mask = torch.nn.functional.pad(attention_mask, (0, max_seq_len - input_len)).to(torch.float32)
+         padded_position_ids = torch.arange(max_seq_len, dtype=torch.int32).view(1, -1)
+
+         return {
+             "padded_inputs_embed": padded_inputs_embed,
+             "padded_attn_mask": padded_attn_mask,
+             "padded_position_ids": padded_position_ids,
+         }
+
+     def _prepare_inputs_for_iteration(self, index: int, common_inputs, *args, **kwargs):
+         item_kwargs = {
+             "inputs_embeds": common_inputs["padded_inputs_embed"][index : index + 1],
+             "attention_mask": common_inputs["padded_attn_mask"][index : index + 1],
+             "position_ids": common_inputs["padded_position_ids"],
+             "out": [tensor[index : index + 1] for tensor in kwargs["out"]],
+         }
+         return ([], item_kwargs)
+
+     def _process_outputs(self, outputs: list, **kwargs):
+         if self.rbln_config.output_hidden_states:
+             return kwargs["out"][0], tuple(kwargs["out"][1:])
+         else:
+             return kwargs["out"]
+
+
+ class RBLNColPaliForRetrieval(RBLNModel):
+     """
+     The ColPali Model transformer for document retrieval using vision-language models.
+     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+     A class to convert and run pre-trained transformers based `ColPaliForRetrieval` model on RBLN devices.
+     It implements the methods to convert a pre-trained transformers `ColPaliForRetrieval` model into a RBLN transformer model by:
+
+     - transferring the checkpoint weights of the original into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
+
+     **Configuration:**
+     This model uses [`RBLNColPaliForRetrievalConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+     the `rbln_config` parameter should be an instance of [`RBLNColPaliForRetrievalConfig`] or a dictionary conforming to its structure.
+
+     See the [`RBLNColPaliForRetrievalConfig`] class for all available configuration options.
+
+     Examples:
+     ```python
+     from optimum.rbln import RBLNColPaliForRetrieval
+
+     # Simple usage using rbln_* arguments
+     # `max_seq_lens` is automatically inferred from the model config
+     model = RBLNColPaliForRetrieval.from_pretrained(
+         "vidore/colpali-v1.3-hf",
+         export=True,
+         rbln_max_seq_lens=1152,
+     )
+
+     # Using a config dictionary
+     rbln_config = {
+         "max_seq_lens": 1152,
+         "output_hidden_states": False,
+     }
+     model = RBLNColPaliForRetrieval.from_pretrained(
+         "vidore/colpali-v1.3-hf",
+         export=True,
+         rbln_config=rbln_config
+     )
+
+     # Using a RBLNColPaliForRetrievalConfig instance (recommended for type checking)
+     from optimum.rbln import RBLNColPaliForRetrievalConfig
+
+     config = RBLNColPaliForRetrievalConfig(
+         max_seq_lens=1152,
+         output_hidden_states=False,
+         tensor_parallel_size=4
+     )
+     model = RBLNColPaliForRetrieval.from_pretrained(
+         "vidore/colpali-v1.3-hf",
+         export=True,
+         rbln_config=config
+     )
+     ```
+     """
+
+     auto_model_class = None
+     _rbln_submodules = [
+         {"name": "vision_tower"},
+     ]
+
+     def __post_init__(self, **kwargs):
+         self.vision_tower = LoopVisionTower(self.rbln_submodules[0])
+         self.language_model = LoopLanguageModel(self.model[0], self.rbln_config)
+
+         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+         self.embed_tokens = self._create_embedding_layer()
+         self.embed_tokens.load_state_dict(artifacts["embed_tokens"])
+         self.multi_modal_projector = self._create_multi_modal_projector()
+         self.multi_modal_projector.load_state_dict(artifacts["multi_modal_projector"])
+
+         return super().__post_init__(**kwargs)
+
+     def _create_embedding_layer(self):
+         with no_init_weights():
+             embed_tokens = torch.nn.Embedding(
+                 self.config.text_config.vocab_size,
+                 self.config.text_config.hidden_size,
+                 self.config.text_config.pad_token_id,
+             )
+         return embed_tokens
+
+     def _create_multi_modal_projector(self):
+         with no_init_weights():
+             multi_modal_projector = PaliGemmaMultiModalProjector(self.config.vlm_config)
+         return multi_modal_projector
+
+     @classmethod
+     def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNModelConfig):
+         return RBLNColPaliForRetrievalWrapper(
+             causal_lm=model.vlm,
+             embedding_proj_layer=model.embedding_proj_layer,
+             max_seq_len=max(rbln_config.max_seq_lens),
+             output_hidden_states=rbln_config.output_hidden_states,
+         )
+
+     @classmethod
+     def save_torch_artifacts(
+         cls,
+         model: "PreTrainedModel",
+         save_dir_path: Path,
+         subfolder: str,
+         rbln_config: RBLNModelConfig,
+     ):
+         save_dict = {}
+         save_dict["embed_tokens"] = model.vlm.get_input_embeddings().state_dict()
+         save_dict["multi_modal_projector"] = model.vlm.multi_modal_projector.state_dict()
+         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNModelConfig] = None,
+     ) -> RBLNModelConfig:
+         hidden_size = model_config.vlm_config.text_config.hidden_size
+         if rbln_config.max_seq_lens is None:
+             rbln_config.max_seq_lens = [model_config.vlm_config.text_config.max_position_embeddings]
+         if isinstance(rbln_config.max_seq_lens, int):
+             rbln_config.max_seq_lens = [rbln_config.max_seq_lens]
+         rbln_config.max_seq_lens = sorted(set(rbln_config.max_seq_lens))
+
+         if rbln_config.output_hidden_states is None:
+             rbln_config.output_hidden_states = model_config.vlm_config.text_config.output_hidden_states
+
+         input_infos = []
+         for max_seq_len in rbln_config.max_seq_lens:
+             input_info = [
+                 ("inputs_embeds", [rbln_config.vision_tower.batch_size, max_seq_len, hidden_size], "float32"),
+                 ("attention_mask", [rbln_config.vision_tower.batch_size, max_seq_len], "float32"),
+                 ("position_ids", [rbln_config.vision_tower.batch_size, max_seq_len], "int32"),
+             ]
+             input_infos.append(input_info)
+
+         rbln_compile_config = RBLNCompileConfig(input_info=input_infos)
+         rbln_config.set_compile_cfgs([rbln_compile_config])
+
+         return rbln_config
+
+     @classmethod
+     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+         if hasattr(model, "vlm"):
+             model.vision_tower = model.vlm.vision_tower
+             del model.vlm.model.vision_tower
+             return model
+         return model
+
+     def get_image_features(self, pixel_values: torch.Tensor):
+         # Projects the last hidden state from the vision model into language model space.
+         # Args:
+         #     pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+         #         The tensors corresponding to the input images.
+         # Returns:
+         #     image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+
+         vision_output_size = [
+             pixel_values.shape[0],
+             self.config.vlm_config.vision_config.num_image_tokens,
+             self.config.vlm_config.vision_config.hidden_size,
+         ]
+         vision_output = torch.empty(size=vision_output_size, dtype=torch.float32, device="cpu")
+         self.vision_tower(pixel_values, out=vision_output)
+         image_features = self.multi_modal_projector(vision_output)
+         image_features = image_features / (self.config.text_config.hidden_size**0.5)
+         return image_features
+
+     def _preprocess_inputs(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         **kwargs,
+     ):
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         # Replace image id with PAD if the image token is OOV, to avoid index-errors
+         if input_ids is not None and self.config.vlm_config.image_token_index >= self.config.text_config.vocab_size:
+             special_image_mask = input_ids == self.config.vlm_config.image_token_index
+             llm_input_ids = input_ids.clone()
+             llm_input_ids[special_image_mask] = 0
+         else:
+             llm_input_ids = input_ids
+
+         if inputs_embeds is None:
+             inputs_embeds = self.embed_tokens(llm_input_ids)
+
+         # Merge text and images
+         image_features = None
+         if pixel_values is not None:
+             image_features = self.get_image_features(pixel_values)
+             special_image_mask = (input_ids == self.config.vlm_config.image_token_index).unsqueeze(-1)
+             special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device)
+
+             image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
+             inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
+
+         return inputs_embeds, image_features
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         pixel_values: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ) -> Union[Tuple, ColPaliForRetrievalOutput]:
+         if pixel_values is not None:
+             pixel_values = pixel_values.to(dtype=self.dtype)
+
+         if output_attentions:
+             raise ValueError("output_attentions is not supported for RBLNColPaliForRetrieval")
+
+         if output_hidden_states is not None and output_hidden_states != self.rbln_config.output_hidden_states:
+             raise ValueError(
+                 f"Variable output_hidden_states {output_hidden_states} is not equal to rbln_config.output_hidden_states {self.rbln_config.output_hidden_states} "
+                 f"Please compile again with the correct argument."
+             )
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         inputs_embeds, image_features = self._preprocess_inputs(
+             input_ids=input_ids, inputs_embeds=inputs_embeds, pixel_values=pixel_values
+         )
+
+         outputs = []
+         language_model_out_size = [inputs_embeds.shape[0], self.rbln_config.max_seq_lens[0], self.config.embedding_dim]
+         language_model_hidden_states_size = [
+             inputs_embeds.shape[0],
+             self.rbln_config.max_seq_lens[0],
+             self.rbln_config.max_seq_lens[0],
+         ]
+         outputs.append(torch.empty(size=language_model_out_size, dtype=torch.float32, device="cpu"))
+         if self.rbln_config.output_hidden_states:
+             for i in range(self.config.vlm_config.text_config.num_hidden_layers + 1):
+                 outputs.append(torch.empty(size=language_model_hidden_states_size, dtype=torch.float32, device="cpu"))
+
+         # Embedding_proj_layer is fused on the bottom of the language model.
+         self.language_model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, out=outputs)
+
+         embeddings = outputs[0][:, : inputs_embeds.shape[1]]
+         hidden_states = (
+             None
+             if not self.rbln_config.output_hidden_states
+             else [tensor[0][:, : inputs_embeds.shape[1]] for tensor in outputs[1:]]
+         )
+
+         # L2 normalization
+         embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (batch_size, sequence_length, dim)
+
+         if attention_mask is not None:
+             embeddings = embeddings * attention_mask.unsqueeze(-1)  # (batch_size, sequence_length, dim)
+
+         if not return_dict:
+             return (embeddings, hidden_states, image_features)
+         else:
+             return ColPaliForRetrievalOutput(
+                 embeddings=embeddings,
+                 hidden_states=hidden_states,
+                 image_hidden_states=image_features,
+             )
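The `forward` above returns per-token, L2-normalized embeddings, so late-interaction (MaxSim) retrieval scoring happens outside the model. A minimal sketch of how the class might be consumed, assuming the `ColPaliProcessor` from `transformers`, an RBLN device for the `export=True` compile, and a hand-rolled MaxSim step that is illustrative rather than part of this package:

```python
import torch
from PIL import Image
from transformers import ColPaliProcessor

from optimum.rbln import RBLNColPaliForRetrieval

processor = ColPaliProcessor.from_pretrained("vidore/colpali-v1.3-hf")
model = RBLNColPaliForRetrieval.from_pretrained("vidore/colpali-v1.3-hf", export=True)

page_image = Image.new("RGB", (448, 448), "white")  # stand-in for a real document page

# Both calls return per-token embeddings of shape (batch, seq_len, dim).
query_emb = model(**processor(text=["where is the revenue table?"], return_tensors="pt")).embeddings
doc_emb = model(**processor(images=[page_image], return_tensors="pt")).embeddings

# MaxSim: best-matching document token per query token, summed over query tokens.
sim = torch.einsum("qd,pd->qp", query_emb[0], doc_emb[0])
score = sim.max(dim=1).values.sum()
```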
optimum/rbln/transformers/models/colqwen2/__init__.py
@@ -0,0 +1,2 @@
+ from .configuration_colqwen2 import RBLNColQwen2ForRetrievalConfig
+ from .modeling_colqwen2 import RBLNColQwen2ForRetrieval
optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py
@@ -0,0 +1,233 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from transformers import PreTrainedModel
+
+ from optimum.rbln.transformers.models.decoderonly.decoderonly_architecture import (
+     DecoderOnlyLayer,
+     DecoderOnlyModel,
+     DecoderOnlyWrapper,
+ )
+
+ from .configuration_colqwen2 import (
+     RBLNColQwen2ForRetrievalConfig,
+ )
+
+
+ def slice_and_unsqueeze_cos_sin(cos, sin, position_ids):
+     """Slice cos[cache_position], sin[cache_position] vector for the query."""
+     cos = cos[position_ids[0]][None, None, None, :, :]
+     sin = sin[position_ids[0]][None, None, None, :, :]
+
+     return cos, sin
+
+
+ class ColQwen2LanguageModelWrapper(DecoderOnlyWrapper):
+     def __init__(
+         self, model: PreTrainedModel, rbln_config: "RBLNColQwen2ForRetrievalConfig", use_rotary_emb: bool = True
+     ):
+         model.config = (
+             model.config.vlm_config.text_config if hasattr(model.config, "vlm_config") else model.config.text_config
+         )
+         super().__init__(model, rbln_config, use_rotary_emb)
+
+     def get_decoder_layers(self, model: PreTrainedModel):
+         return model.language_model.layers
+
+     def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
+         new_layers = []
+         for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
+             is_sliding = layer_idx in self.rbln_config.sliding_window_layers
+             new_self_attn = self.get_rbln_attn_class()(
+                 self.get_attn_layer(layer),
+                 self.rbln_config,
+                 is_sliding=is_sliding,
+             )
+             new_layer = self.get_rbln_layer_class()(layer, new_self_attn)
+             new_layers.append(new_layer)
+
+         new_model = self.get_rbln_model_class()(
+             model.language_model,
+             new_layers,
+             self.rbln_config,
+             use_learned_pos_emb=self.__class__._use_learned_pos_emb,
+         )
+
+         # text_projection layer from model
+         self.embedding_proj_layer = (
+             model.embedding_proj_layer if hasattr(model, "embedding_proj_layer") else model.custom_text_proj
+         )
+         return new_model
+
+     def get_rbln_model_class(self):
+         return RBLNColQwen2LanguageModel
+
+     def prepare_forward_args(self, *args):
+         args = list(args)
+         input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
+         inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
+         cache_position = args.pop(0)
+         global_block_tables = args.pop(0)
+         local_block_tables = None
+         position_embeds = args.pop(0)
+         position_ids = None
+         attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
+         past_key_values = args
+
+         if len(past_key_values) != 2 * self.num_hidden_layers:
+             raise ValueError(
+                 f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
+             )
+
+         _past_key_values = []
+         for i in range(self.config.num_hidden_layers):
+             key_states = past_key_values[i * 2]
+             value_states = past_key_values[i * 2 + 1]
+             past_key_value = [key_states, value_states]
+             _past_key_values.append(past_key_value)
+         past_key_values = _past_key_values
+
+         return (
+             input_ids,
+             inputs_embeds,
+             cache_position,
+             global_block_tables,
+             local_block_tables,
+             attention_mask,
+             position_ids,
+             past_key_values,
+             position_embeds,
+         )
+
+     def forward(self, *args):
+         (
+             input_ids,
+             inputs_embeds,
+             cache_position,
+             global_block_tables,
+             local_block_tables,
+             attention_mask,
+             position_ids,
+             past_key_values,
+             rotary_emb,
+         ) = self.prepare_forward_args(*args)
+
+         last_hidden_states = self.model(
+             input_ids=input_ids,
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             cache_position=cache_position,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             rotary_emb=rotary_emb,
+             global_block_tables=global_block_tables,
+             local_block_tables=local_block_tables,
+         )
+
+         proj = self.embedding_proj_layer(last_hidden_states[0])
+         all_hidden_states = last_hidden_states[1] if self.rbln_config.output_hidden_states else None
+
+         if self.rbln_config.output_hidden_states:
+             return proj, all_hidden_states
+         else:
+             return proj
+
+
+ class RBLNColQwen2LanguageModel(DecoderOnlyModel):
+     def __init__(
+         self,
+         model,
+         layers: List["DecoderOnlyLayer"],
+         rbln_config: "RBLNColQwen2ForRetrievalConfig",
+         use_learned_pos_emb=None,
+     ):
+         super().__init__(model, layers, rbln_config, use_learned_pos_emb)
+
+         self.output_hidden_states = rbln_config.output_hidden_states
+
+     def forward(
+         self,
+         input_ids: torch.Tensor = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         attention_mask: torch.Tensor = None,
+         cache_position: torch.Tensor = None,
+         position_ids: torch.Tensor = None,
+         query_position: torch.Tensor = None,
+         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
+         rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
+         global_block_tables: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
+         lora_int_id: Optional[torch.Tensor] = None,
+     ):
+         # retrieve input_ids and inputs_embeds
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError(
+                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+             )
+
+         # embed positions
+         if inputs_embeds is None:
+             inputs_embeds = self.get_embedding()(input_ids)
+
+         hidden_states = inputs_embeds * self.hidden_multiplier
+
+         # get cos,sin vector if needed
+         position_ids = position_ids if position_ids is not None else cache_position
+         if rotary_emb is not None:
+             if isinstance(rotary_emb, torch.Tensor):
+                 cos = rotary_emb[0]
+                 sin = rotary_emb[1]
+             else:
+                 cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
+             cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+
+         # Get sequence positions for flash attention
+         if self.attn_impl == "flash_attn":
+             seq_positions = cache_position[:, 0]
+             seq_positions = self.convert_sequence_positions_for_flash_attn(
+                 seq_positions=seq_positions, max_seq_len=self.max_seq_len
+             )
+         else:
+             seq_positions = cache_position[:, :1]
+
+         # Get local cache positions for sliding window layers
+         if len(self.sliding_window_layers) > 0:
+             sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
+
+         all_hidden_states = () if self.output_hidden_states else None
+         for layer_idx, layer in enumerate(self.layers):
+             if self.output_hidden_states:
+                 all_hidden_states += (hidden_states,)
+
+             is_sliding = True if layer_idx in self.sliding_window_layers else False
+             hidden_states = layer(
+                 hidden_states=hidden_states,
+                 attention_mask=attention_mask,
+                 seq_positions=sliding_cache_pos if is_sliding else seq_positions,
+                 past_key_values=past_key_values,
+                 cos=cos,
+                 sin=sin,
+                 block_tables=local_block_tables if is_sliding else global_block_tables,
+                 lora_int_id=lora_int_id,
+             )
+
+         hidden_states = self.get_last_layernorm()(hidden_states)
+         if self.output_hidden_states:
+             all_hidden_states += (hidden_states,)
+
+         return hidden_states, all_hidden_states
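For reference, `slice_and_unsqueeze_cos_sin` above gathers the rotary-table rows for the current positions and adds three leading singleton axes so the tables broadcast over batch and head dimensions. A self-contained check of that shape contract (the table sizes are assumed for illustration, not taken from this package):

```python
import torch

def slice_and_unsqueeze_cos_sin(cos, sin, position_ids):
    cos = cos[position_ids[0]][None, None, None, :, :]
    sin = sin[position_ids[0]][None, None, None, :, :]
    return cos, sin

cos_table = torch.randn(1024, 64)  # assumed (max_seq_len, head_dim) lookup table
sin_table = torch.randn(1024, 64)
position_ids = torch.arange(16).view(1, -1)  # (1, seq_len)

cos, sin = slice_and_unsqueeze_cos_sin(cos_table, sin_table, position_ids)
assert cos.shape == sin.shape == (1, 1, 1, 16, 64)
```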
optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py
@@ -0,0 +1,74 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ from optimum.rbln.configuration_utils import RBLNModelConfig
+
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig
+
+
+ class RBLNColQwen2ForRetrievalConfig(RBLNDecoderOnlyModelConfig):
+     """
+     Configuration class for RBLN ColQwen2 models for document retrieval.
+
+     This class extends RBLNModelConfig with specific configurations for ColQwen2 models,
+     including vision tower settings and multi-sequence length support.
+
+     Example usage:
+     ```python
+     from optimum.rbln import RBLNColQwen2ForRetrieval, RBLNColQwen2ForRetrievalConfig
+
+     # Create a configuration object
+     config = RBLNColQwen2ForRetrievalConfig(
+         visual={
+             "max_seq_lens": 6400,
+             "device": 0,
+         },
+         max_seq_len=32_768,
+         tensor_parallel_size=4,
+         device=[0, 1, 2, 3],
+         output_hidden_states=False,
+     )
+
+     # Use the configuration with from_pretrained
+     model = RBLNColQwen2ForRetrieval.from_pretrained(
+         "vidore/colqwen2-v1.0-hf",
+         export=True,
+         rbln_config=config
+     )
+     ```
+     """
+
+     submodules = ["visual"]
+
+     def __init__(
+         self,
+         visual: Optional[RBLNModelConfig] = None,
+         batch_size: Optional[int] = None,
+         use_inputs_embeds: bool = True,
+         output_hidden_states: Optional[bool] = False,
+         **kwargs,
+     ):
+         super().__init__(use_inputs_embeds=use_inputs_embeds, **kwargs)
+         if not self.use_inputs_embeds:
+             raise ValueError(
+                 "RBLNColQwen2ForRetrievalConfig does not allow `use_inputs_embeds` to be set to False, "
+                 "as RBLNColQwen2ForRetrieval accepts only `inputs_embeds` as input."
+             )
+         if batch_size is not None and batch_size != 1:
+             raise ValueError("batch_size is not supported for RBLNColQwen2ForRetrievalConfig")
+
+         self.visual = visual
+         self.output_hidden_states = output_hidden_states
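To make the guards in `__init__` above concrete, a short sketch of which configurations pass and which are rejected (the `max_seq_len` and `tensor_parallel_size` keywords follow the docstring example and are assumptions about the inherited config):

```python
from optimum.rbln import RBLNColQwen2ForRetrievalConfig

# Accepted: defaults keep use_inputs_embeds=True and leave batch_size unset.
cfg = RBLNColQwen2ForRetrievalConfig(max_seq_len=32_768, tensor_parallel_size=4)

# Rejected: the model consumes only inputs_embeds.
try:
    RBLNColQwen2ForRetrievalConfig(use_inputs_embeds=False)
except ValueError as err:
    print(err)

# Rejected: only batch_size=1 is supported.
try:
    RBLNColQwen2ForRetrievalConfig(batch_size=2)
except ValueError as err:
    print(err)
```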