optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,636 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
+
+ import torch
+ from transformers import AutoModelForVision2Seq, PretrainedConfig, PreTrainedModel, Qwen2_5_VLForConditionalGeneration
+ from transformers.modeling_utils import no_init_weights
+ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
+     Qwen2_5_VisionPatchEmbed,
+     Qwen2_5_VisionRotaryEmbedding,
+     Qwen2_5_VisionTransformerPretrainedModel,
+     Qwen2_5_VLModel,
+     Qwen2_5_VLRotaryEmbedding,
+ )
+
+ from ....configuration_utils import RBLNCompileConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from ...modeling_outputs import RBLNDecoderOnlyOutput
+ from ..decoderonly.modeling_decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from .configuration_qwen2_5_vl import (
+     RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
+     RBLNQwen2_5_VLForConditionalGenerationConfig,
+ )
+ from .qwen2_5_vl_architecture import Qwen2_5_VisionTransformerWrapper, Qwen2_5_VL_LanguageModelWrapper
+
+
+ logger = get_logger(__name__)
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
+
+
+ class RBLNQwen2_5_VisionTransformerPretrainedModel(RBLNModel):
+     """
+     RBLN-optimized Qwen2.5-VL vision transformer model.
+
+     This class provides hardware-accelerated inference for the Qwen2.5-VL vision transformer
+     on RBLN devices, supporting image and video encoding for multimodal vision-language tasks
+     with window-based attention mechanisms.
+     """
+
+     auto_model_class = None
+
+     def __post_init__(self, **kwargs):
+         self.transformer = self.model[0]
+         self.max_seq_lens = torch.tensor(sorted(self.rbln_config.max_seq_lens, reverse=False))
+         config = self.config
+         self.window_size = config.window_size
+         self.patch_size = config.spatial_patch_size
+         self.spatial_merge_size = config.spatial_merge_size
+         self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size
+         self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding((config.hidden_size // config.num_heads) // 2)
+         with no_init_weights():
+             self.patch_embed = Qwen2_5_VisionPatchEmbed(
+                 patch_size=config.patch_size,
+                 temporal_patch_size=config.temporal_patch_size,
+                 in_channels=config.in_channels,
+                 embed_dim=config.hidden_size,
+             )
+         artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
+         self.patch_embed.load_state_dict(artifacts["patch_embed"])
+
+     @classmethod
+     def save_torch_artifacts(
+         cls,
+         model: "Qwen2_5_VLForConditionalGeneration",
+         save_dir_path: Path,
+         subfolder: str,
+         rbln_config: RBLNQwen2_5_VisionTransformerPretrainedModelConfig,
+     ):
+         save_dict = {}
+         save_dict["patch_embed"] = model.patch_embed.state_dict()
+         torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")
+
+     @classmethod
+     def _wrap_model_if_needed(
+         cls, model: "PreTrainedModel", rbln_config: RBLNQwen2_5_VisionTransformerPretrainedModelConfig
+     ):
+         return Qwen2_5_VisionTransformerWrapper(model).eval()
+
+     def __getattr__(self, __name: str) -> Any:
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(Qwen2_5_VisionTransformerPretrainedModel, __name)
+
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+         return val
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: "PretrainedConfig" = None,
+         rbln_config: Optional[RBLNQwen2_5_VisionTransformerPretrainedModelConfig] = None,
+     ) -> RBLNQwen2_5_VisionTransformerPretrainedModelConfig:
+         window_size = getattr(model_config, "window_size")
+         patch_size = getattr(model_config, "patch_size")
+         hidden_size = getattr(model_config, "hidden_size")
+         num_heads = getattr(model_config, "num_heads")
+         head_dim = hidden_size // num_heads
+         window_seq_len = (window_size // patch_size) ** 2
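+         # Worked example (illustrative; these numbers come from the stock
+         # Qwen2.5-VL vision config, not from this method): window_size=112 and
+         # patch_size=14 give window_seq_len = (112 // 14) ** 2 = 64, so each
+         # requested max_seq_len must be a multiple of 64 (e.g. 6400 = 64 * 100,
+         # as in the docstring example below).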
+
+         input_infos = []
+         for max_seq_len in rbln_config.max_seq_lens:
+             if max_seq_len % window_seq_len > 0:
+                 raise ValueError(
+                     f"max_seq_len ({max_seq_len}) must be a multiple of window_seq_len ({window_seq_len})."
+                 )
+
+             input_info = [
+                 ("hidden_states", [max_seq_len, hidden_size], "float32"),
+                 ("full_attn_masks", [1, 1, max_seq_len, max_seq_len], "float32"),
+                 (
+                     "window_attn_masks",
+                     [max_seq_len // window_seq_len, 1, window_seq_len, window_seq_len],
+                     "float32",
+                 ),
+                 (
+                     "cos",
+                     [1, 1, max_seq_len, head_dim],
+                     "float32",
+                 ),
+                 (
+                     "sin",
+                     [1, 1, max_seq_len, head_dim],
+                     "float32",
+                 ),
+             ]
+             input_infos.append(input_info)
+
+         rbln_compile_config = RBLNCompileConfig(input_info=input_infos)
+         rbln_config.set_compile_cfgs([rbln_compile_config])
+
+         return rbln_config
+
+     @staticmethod
+     def _pad_for_window_attn_layers(
+         window_indice: List[int],
+         hidden_states: torch.Tensor,
+         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+         window_seq_len: int,
+         max_seq_len: int,
+     ):
+         # Padding for Window Attention
+         padded_hidden_state = []
+         padded_cos = []
+         padded_sin = []
+         window_valid_lengths = []
+         for i in range(len(window_indice) - 1):
+             start, end = window_indice[i], window_indice[i + 1]
+             segment = hidden_states[start:end]
+             cos_segment = position_embeddings[0][start:end]
+             sin_segment = position_embeddings[1][start:end]
+             segment_len = end - start
+
+             if segment_len < window_seq_len:
+                 padding_size = window_seq_len - segment_len
+                 padding = torch.zeros(
+                     padding_size,
+                     segment.shape[-1],
+                     dtype=segment.dtype,
+                 )
+                 padding_pos = torch.zeros(
+                     padding_size,
+                     cos_segment.shape[-1],
+                     dtype=cos_segment.dtype,
+                 )
+                 padded_segment = torch.cat([segment, padding], dim=0)
+                 padded_cos_segment = torch.cat([cos_segment, padding_pos], dim=0)
+                 padded_sin_segment = torch.cat([sin_segment, padding_pos], dim=0)
+             else:
+                 padded_segment = segment
+                 padded_cos_segment = cos_segment
+                 padded_sin_segment = sin_segment
+             padded_hidden_state.append(padded_segment)
+             window_valid_lengths.append(segment_len)
+             padded_cos.append(padded_cos_segment)
+             padded_sin.append(padded_sin_segment)
+         hidden_state_padded = torch.cat(padded_hidden_state)
+         cos_padded = torch.cat(padded_cos, dim=0)
+         sin_padded = torch.cat(padded_sin, dim=0)
+
+         window_attn_masks = torch.ones(
+             max_seq_len // window_seq_len,
+             1,
+             window_seq_len,
+             window_seq_len,
+             dtype=torch.float32,
+         )
+         for i, valid_len in enumerate(window_valid_lengths):
+             if valid_len < window_seq_len:
+                 window_attn_masks[i, :, valid_len:, :] = 0
+                 window_attn_masks[i, :, :, valid_len:] = 0
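+                 # e.g. (hypothetical numbers) window_seq_len=64 with a final
+                 # segment holding 48 valid patches: rows and columns 48: of this
+                 # window's mask are zeroed so padded positions neither attend
+                 # nor are attended to.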
+
+         return hidden_state_padded, cos_padded, sin_padded, window_attn_masks, window_valid_lengths
+
+     @staticmethod
+     def _pad_for_full_attn_layers(
+         hidden_state_padded, cos_padded, sin_padded, max_seq_len, window_valid_lengths, window_seq_len
+     ):
+         if hidden_state_padded.shape[0] < max_seq_len:
+             full_padding_size = max_seq_len - hidden_state_padded.shape[0]
+             full_padding_hidden = torch.zeros(
+                 full_padding_size,
+                 hidden_state_padded.shape[-1],
+                 dtype=hidden_state_padded.dtype,
+             )
+             hidden_state_full_padded = torch.cat([hidden_state_padded, full_padding_hidden], dim=0)  # [5120, 1280]
+             full_padding_pos = torch.zeros(
+                 full_padding_size,
+                 cos_padded.shape[-1],
+                 dtype=cos_padded.dtype,
+             )
+             cos_full_padded = torch.cat([cos_padded, full_padding_pos], dim=0)
+             sin_full_padded = torch.cat([sin_padded, full_padding_pos], dim=0)
+             window_valid_lengths.extend([0] * (max_seq_len // window_seq_len - len(window_valid_lengths)))
+         else:
+             hidden_state_full_padded = hidden_state_padded
+             cos_full_padded = cos_padded
+             sin_full_padded = sin_padded
+
+         full_attn_masks = torch.ones(
+             1,
+             1,
+             max_seq_len,
+             max_seq_len,
+             dtype=torch.float32,
+         )
+         for i, valid_len in enumerate(window_valid_lengths):
+             start = i * window_seq_len
+             end = start + window_seq_len
+             full_attn_masks[:, :, start + valid_len : end, :] = 0
+             full_attn_masks[:, :, :, start + valid_len : end] = 0
+
+         return hidden_state_full_padded, cos_full_padded, sin_full_padded, full_attn_masks
+
+     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.patch_embed(hidden_states)
+         rotary_pos_emb = self.rot_pos_emb(grid_thw)
+         window_index, cu_window_seqlens = self.get_window_index(grid_thw)
+         cu_window_seqlens = torch.tensor(
+             cu_window_seqlens,
+             dtype=torch.int32,
+         )
+         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
+
+         seq_len, _ = hidden_states.size()
+         hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+         hidden_states = hidden_states[window_index, :, :]
+         hidden_states = hidden_states.reshape(seq_len, -1)
+         rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
+         rotary_pos_emb = rotary_pos_emb[window_index, :, :]
+         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
+         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
+         position_embeddings = (emb.cos(), emb.sin())
+
+         cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
+             dim=0,
+             dtype=torch.int32,
+         )
+         cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0)
+
+         num_images = len(cu_seqlens) - 1
+         cu_window_seqlens = cu_window_seqlens.tolist()
+         window_seq_len = (self.window_size // self.patch_size) ** 2
+
+         output_hidden_states = []
+
+         # Process each image in the sequence
+         for i in range(num_images):
+             image_s, image_e = cu_seqlens[i], cu_seqlens[i + 1]
+             window_indice = cu_window_seqlens[cu_window_seqlens.index(image_s) : cu_window_seqlens.index(image_e) + 1]
+
+             # Select the nearest higher max_seq_len from the available compiled models.
+             window_padded_len = len(window_indice) * window_seq_len
+             try:
+                 ws_index = torch.searchsorted(self.max_seq_lens, window_padded_len).item()
+                 max_seq_len = self.max_seq_lens[ws_index]
+             except Exception:
+                 raise ValueError(
+                     f"Required seq_len({window_padded_len}) is larger than available max_seq_lens({self.max_seq_lens.tolist()})."
+                 )
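+             # e.g. (hypothetical) with compiled max_seq_lens=[1600, 6400] and
+             # window_padded_len=2048, searchsorted returns index 1 and the
+             # 6400-token graph is selected; lengths beyond 6400 fall through
+             # to the ValueError above.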
+
+             # Padding for Window Attention Layers
+             hidden_state_padded, cos_padded, sin_padded, window_attn_masks, window_valid_lengths = (
+                 self._pad_for_window_attn_layers(
+                     window_indice, hidden_states, position_embeddings, window_seq_len, max_seq_len
+                 )
+             )
+
+             # Padding for Full Attention Layers
+             hidden_state_full_padded, cos_full_padded, sin_full_padded, full_attn_masks = (
+                 self._pad_for_full_attn_layers(
+                     hidden_state_padded, cos_padded, sin_padded, max_seq_len, window_valid_lengths, window_seq_len
+                 )
+             )
+
+             # Run the compiled model on the RBLN device
+             output = self.transformer(
+                 hidden_state_full_padded,
+                 full_attn_masks,
+                 window_attn_masks,
+                 cos_full_padded[None, None, :, :],
+                 sin_full_padded[None, None, :, :],
+             )
+
+             # Depadding
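+             # Each window contributes window_seq_len // spatial_merge_unit output
+             # rows (e.g. 64 // 4 = 16 with a 2x2 spatial merge; numbers
+             # illustrative): slice away the rows that came from padding.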
+             depadded_output = []
+             for i, valid_len in enumerate(window_valid_lengths):
+                 start = i * (window_seq_len // self.spatial_merge_unit)
+                 end = start + (valid_len // self.spatial_merge_unit)
+                 depadded_output.append(output[start:end])
+             output = torch.cat(depadded_output, dim=0)
+
+             output_hidden_states.append(output)
+         hidden_states = torch.cat(output_hidden_states)
+         reverse_indices = torch.argsort(window_index)
+         hidden_states = hidden_states[reverse_indices, :]
+
+         return hidden_states
+
+
+ class RBLNQwen2_5_VLForConditionalGeneration(RBLNDecoderOnlyModelForCausalLM):
+     """
+     RBLNQwen2_5_VLForConditionalGeneration is a multimodal model that integrates vision and language processing,
+     optimized for RBLN NPUs. It is designed for conditional generation tasks involving both image and text inputs.
+
+     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the
+     generic methods the library implements for all its models.
+
+     Important note:
+         This model includes a Large Language Model (LLM). For optimal performance, it is highly recommended to use
+         tensor parallelism for the language model. This can be achieved via the `rbln_config` argument of
+         `from_pretrained`. Refer to the `from_pretrained` documentation and the
+         RBLNQwen2_5_VLForConditionalGenerationConfig class for details.
+
+     Examples:
+         ```python
+         from optimum.rbln import RBLNQwen2_5_VLForConditionalGeneration
+
+         model = RBLNQwen2_5_VLForConditionalGeneration.from_pretrained(
+             "Qwen/Qwen2.5-VL-7B-Instruct",
+             export=True,
+             rbln_config={
+                 "visual": {
+                     "max_seq_lens": 6400,
+                     "device": 0,
+                 },
+                 "tensor_parallel_size": 8,
+                 "kvcache_partition_len": 16_384,
+                 "max_seq_len": 114_688,
+                 "device": [0, 1, 2, 3, 4, 5, 6, 7],
+             },
+         )
+
+         model.save_pretrained("compiled-qwen2.5-vl-7b-instruct")
+         ```
+     """
+
+     _supports_non_fp32 = False
+
+     auto_model_class = AutoModelForVision2Seq
+     _rbln_submodules = [
+         {"name": "visual"},
+     ]
+     _decoder_wrapper_cls = Qwen2_5_VL_LanguageModelWrapper
+     _use_rotary_emb = False
+
+     def __post_init__(self, **kwargs):
+         super().__post_init__(**kwargs)
+         self.visual = self.rbln_submodules[0]
+         self.mrope_section = self.config.rope_scaling["mrope_section"]
+         self.rotary_emb = Qwen2_5_VLRotaryEmbedding(self.config)
+         self.rope_deltas = torch.zeros(self.rbln_config.batch_size)
+
+     def can_generate(self):
+         return True
+
+     @classmethod
+     def _reconstruct_model_if_needed(cls, model: "PreTrainedModel"):
+         model.model.lm_head = model.lm_head
+         model.lm_head = None
+         del model.lm_head
+         return model
+
+     @classmethod
+     def get_input_info(
+         cls,
+         batch_size: int,
+         query_length: int,
+         rbln_config: RBLNQwen2_5_VLForConditionalGenerationConfig,
+         model_config: PretrainedConfig,
+     ):
+         input_info = super().get_input_info(batch_size, query_length, rbln_config, model_config)
+         pos_idx = 3
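+         # The leading dimension of size 2 stacks the cos and sin halves
+         # returned by _get_position_embeddings.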
+         input_info.insert(
+             pos_idx,
+             (
+                 "position_emb",
+                 [2, batch_size, 1, query_length, model_config.hidden_size // model_config.num_attention_heads],
+                 "float32",
+             ),
+         )
+
+         return input_info
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids: torch.LongTensor,
+         generate_idx: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         pixel_values=None,
+         pixel_values_videos=None,
+         image_grid_thw=None,
+         video_grid_thw=None,
+         second_per_grid_ts=None,
+         **kwargs,
+     ):
+         model_inputs = {}
+         is_prefill_phase = generate_idx is None
+
+         if is_prefill_phase:
+             generate_idx = attention_mask.sum(dim=-1, keepdim=True).int()
+             cache_position = None
+             model_inputs.update({"input_ids": input_ids})
+         else:
+             if inputs_embeds is not None:
+                 raise NotImplementedError("Specifying inputs_embeds in the decoder phase is not supported.")
+
+             input_ids = input_ids[:, -1:]
+             cache_position = generate_idx
+             generate_idx = generate_idx + 1
+             model_inputs.update({"input_ids": input_ids})
+
+         model_inputs.update(
+             {
+                 "attention_mask": attention_mask,
+                 "cache_position": cache_position,
+                 "generate_idx": generate_idx,
+                 "pixel_values": pixel_values,
+                 "pixel_values_videos": pixel_values_videos,
+                 "image_grid_thw": image_grid_thw,
+                 "video_grid_thw": video_grid_thw,
+                 "second_per_grid_ts": second_per_grid_ts,
+             }
+         )
+
+         return model_inputs
+
+     def _get_position_embeddings(self, hidden_states, position_ids):
+         cos, sin = self.rotary_emb(hidden_states, position_ids)
+         mrope_section = self.mrope_section * 2
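+         # Illustrative: the stock Qwen2.5-VL config uses mrope_section
+         # [16, 24, 24]; doubling gives [16, 24, 24, 16, 24, 24], and the i % 3
+         # cycle below maps each chunk back to the temporal/height/width
+         # rotary components.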
+         cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
+         sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(1)
+         return torch.stack([cos, sin])
+
+     def _preprocess_prefill(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: torch.Tensor = None,
+         pixel_values: torch.Tensor = None,
+         pixel_values_videos: torch.FloatTensor = None,
+         image_grid_thw: torch.LongTensor = None,
+         video_grid_thw: torch.LongTensor = None,
+         second_per_grid_ts: torch.Tensor = None,
+     ):
+         batch_size = input_ids.shape[0]
+         inputs_embeds = self.embed_tokens(input_ids)
+
+         if pixel_values is not None:
+             image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+             n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
+             n_image_features = image_embeds.shape[0]
+             if n_image_tokens != n_image_features:
+                 raise ValueError(
+                     f"Image features and image tokens do not match: tokens: {n_image_tokens}, features: {n_image_features}"
+                 )
+
+             mask = input_ids == self.config.image_token_id
+             mask_unsqueezed = mask.unsqueeze(-1)
+             mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+
+             image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
+             inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, image_embeds)
+
+         if pixel_values_videos is not None:
+             video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+             n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
+             n_video_features = video_embeds.shape[0]
+             if n_video_tokens != n_video_features:
+                 raise ValueError(
+                     f"Video features and video tokens do not match: tokens: {n_video_tokens}, features: {n_video_features}"
+                 )
+
+             mask = input_ids == self.config.video_token_id
+             mask_unsqueezed = mask.unsqueeze(-1)
+             mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
+             inputs_embeds = inputs_embeds.masked_scatter(mask_expanded, video_embeds)
+
+         max_inputs_len = input_ids.shape[1]
+
+         head_dim = getattr(self.config, "head_dim", None) or self.config.hidden_size // self.config.num_attention_heads
+         all_position_embeds = torch.zeros(2, batch_size, 1, max_inputs_len, head_dim)
+         all_rope_deltas = []
+
+         image_token_id = self.config.image_token_id
+         video_token_id = self.config.video_token_id
+         vision_start_token_id = self.config.vision_start_token_id
+         image_idx, video_idx = 0, 0
+
+         for b_idx in range(batch_size):
+             input_id = input_ids[b_idx : b_idx + 1][:, attention_mask[b_idx].bool()]
+             vision_start_indices = torch.argwhere(input_id == vision_start_token_id).squeeze(1)
+             vision_tokens = input_id[0][vision_start_indices + 1]
+             image_nums = (vision_tokens == image_token_id).sum()
+             video_nums = (vision_tokens == video_token_id).sum()
+             position_ids, rope_deltas = Qwen2_5_VLModel.get_rope_index(
+                 self,
+                 input_id,
+                 image_grid_thw[image_idx : image_idx + image_nums] if image_grid_thw is not None else None,
+                 video_grid_thw[video_idx : video_idx + video_nums] if video_grid_thw is not None else None,
+                 second_per_grid_ts[video_idx : video_idx + video_nums] if second_per_grid_ts is not None else None,
+             )
+             image_idx += image_nums
+             video_idx += video_nums
+
+             position_embed = self._get_position_embeddings(inputs_embeds, position_ids)
+             mask_indices = torch.nonzero(attention_mask[b_idx], as_tuple=True)[0]
+             all_position_embeds[:, b_idx : b_idx + 1].index_copy_(dim=-2, index=mask_indices, source=position_embed)
+             all_rope_deltas.append(rope_deltas)
+
+         rope_deltas = torch.stack(all_rope_deltas)
+
+         return inputs_embeds, all_position_embeds, rope_deltas
+
+     def _preprocess_decoder(
+         self,
+         input_ids: torch.LongTensor = None,
+         cache_position: torch.LongTensor = None,
+     ):
+         if self.rbln_config.batch_size != cache_position.shape[0]:
+             raise RuntimeError(
+                 f"Cache position size mismatch: got {cache_position.shape[0]}, expected {self.rbln_config.batch_size}."
+             )
+
+         inputs_embeds = self.embed_tokens(input_ids)
+         position_embeds = []
+         for b_idx in range(self.rbln_config.batch_size):
+             delta = cache_position[b_idx] + self.rope_deltas[b_idx]
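+             # rope_deltas, captured at prefill, offsets the raw token index to
+             # the multimodal rope position (vision tokens occupy a compressed
+             # 3D grid); adding it to cache_position recovers the next token's
+             # position id. (A sketch of the intent, per upstream Qwen2.5-VL.)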
+             position_ids = torch.arange(1).view(1, -1)
+             position_ids = position_ids.add(delta)
+             position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+             position_embed = self._get_position_embeddings(torch.zeros(1, dtype=torch.float32), position_ids)
+             position_embeds.append(position_embed)
+
+         position_embeds = torch.cat(position_embeds, dim=1)
+
+         return inputs_embeds, position_embeds
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         pixel_values: Optional[torch.Tensor] = None,
+         pixel_values_videos: Optional[torch.FloatTensor] = None,
+         image_grid_thw: Optional[torch.LongTensor] = None,
+         video_grid_thw: Optional[torch.LongTensor] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         second_per_grid_ts: Optional[torch.Tensor] = None,
+         generate_idx: Optional[torch.Tensor] = None,
+         return_dict: Optional[bool] = None,
+         **kwargs,
+     ) -> RBLNDecoderOnlyOutput:
+         # Prefill phase
+         if cache_position is None:
+             inputs_embeds, position_embed, rope_deltas = self._preprocess_prefill(
+                 input_ids,
+                 attention_mask,
+                 pixel_values,
+                 pixel_values_videos,
+                 image_grid_thw,
+                 video_grid_thw,
+                 second_per_grid_ts,
+             )
+
+             self.rope_deltas = rope_deltas
+             batch_size = inputs_embeds.shape[0]
+
+             logits = []
+             for b_idx in range(batch_size):
+                 cache_position = torch.arange(0, generate_idx[b_idx].item(), dtype=torch.int32).unsqueeze(0)
+
+                 output = self.prefill_decoder(
+                     inputs_embeds=inputs_embeds[b_idx : b_idx + 1],
+                     attention_mask=attention_mask[b_idx] if attention_mask is not None else None,
+                     cache_position=cache_position,
+                     batch_idx=b_idx,
+                     position_embed=position_embed[:, b_idx : b_idx + 1],
+                 )
+                 logits.append(output.logits)
+             logits = torch.cat(logits, dim=0)
+         # Decode phase
+         else:
+             inputs_embeds, position_embed = self._preprocess_decoder(input_ids, cache_position)
+             output = self.decoder(
+                 inputs_embeds=inputs_embeds,
+                 cache_position=cache_position,
+                 position_embed=position_embed,
+             )
+             logits = output.logits
+
+         if not return_dict:
+             return logits, generate_idx
+         else:
+             return RBLNDecoderOnlyOutput(
+                 logits=logits,
+                 generate_idx=generate_idx,
+             )