optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py
@@ -0,0 +1,1224 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from transformers import PretrainedConfig, PreTrainedModel
+
+ from ....utils import logging
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
+ from ...utils.rbln_quantization import RBLNQuantizationConfig
+ from .configuration_lora import RBLNLoRAConfig
+ from .lora_architecture import LoRALinear
+
+
+ if TYPE_CHECKING:
+     from .configuration_decoderonly import RBLNDecoderOnlyModelConfig
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class DecoderOnlyWrapper(nn.Module):
+     """A wrapper class for decoder-only language models that handles RBLN-specific optimizations and requirements.
+
+     This wrapper is designed to:
+     1. Convert Huggingface decoder models for RBLN compilation with static shapes
+     2. Handle input/model mapping and additional information supply (e.g., positional embeddings)
+     3. Manage different attention implementations (standard/flash attention)
+     4. Support both prefill and decode phases
+
+     Notes:
+         - Wrapper must only receive positional arguments in forward() due to torch.jit.trace dependency
+         - Wrapper should not contain neural network graph operations (including memory view handling)
+
+     Args:
+         model (PreTrainedModel): The Huggingface causal language model to wrap
+         rbln_config: The RBLN model configuration containing all necessary parameters
+         use_rotary_emb (bool): Whether to use rotary position embeddings
+     """
+
+     _use_learned_pos_emb = False
+
+     def __init__(self, model: PreTrainedModel, rbln_config: "RBLNDecoderOnlyModelConfig", use_rotary_emb: bool):
+         super().__init__()
+         self.quantization = rbln_config.quantization
+         self.config = model.config
+         self.is_causal_lm = getattr(model, "lm_head", None) is not None
+         self.rbln_config = rbln_config
+
+         if use_rotary_emb:
+             rotary_embs = self.get_rotary_emb(max_seq_len=rbln_config.max_seq_len)
+             if isinstance(rotary_embs, tuple):
+                 self.rotary_emb_global, self.rotary_emb_local = rotary_embs
+             else:
+                 self.rotary_emb = rotary_embs
+         else:
+             self.rotary_emb = None
+
+         if rbln_config.kvcache_partition_len and rbln_config.kvcache_partition_len > rbln_config.max_seq_len:
+             raise ValueError(
+                 f"kvcache_partition_len({rbln_config.kvcache_partition_len}) should be lower"
+                 f" or equal to max_seq_len({rbln_config.max_seq_len})!"
+             )
+
+         self.model = self.convert_to_rbln_class(model, rbln_config.max_seq_len)
+         self.num_hidden_layers = getattr(self.config, "num_hidden_layers", None) or getattr(self.config, "n_layer")
+         self._phase = "prefill"
+
+     def get_rotary_emb(self, max_seq_len):
+         return RotaryEmbedding(config=self.config, max_seq_len_cached=max_seq_len)
+
+     def get_decoder_layers(self, model: PreTrainedModel):
+         return model.model.layers if self.is_causal_lm else model.layers
+
+     def get_attn_layer(self, layer: nn.Module):
+         return layer.self_attn
+
+     def get_model_layer(self, model: PreTrainedModel):
+         return model.model if self.is_causal_lm else model
+
+     def get_rbln_attn_class(self):
+         return DecoderOnlyAttention
+
+     def get_rbln_layer_class(self):
+         return DecoderOnlyLayer
+
+     def get_rbln_model_class(self):
+         return DecoderOnlyModel
+
+     def get_rbln_causal_lm_class(self):
+         return DecoderOnlyForCausalLM
+
+     def convert_to_rbln_class(self, model: PreTrainedModel, max_seq_len: int):
+         new_layers = []
+         for layer_idx, layer in enumerate(self.get_decoder_layers(model)):
+             is_sliding = layer_idx in self.rbln_config.sliding_window_layers
+             new_self_attn = self.get_rbln_attn_class()(
+                 self.get_attn_layer(layer), self.rbln_config, is_sliding=is_sliding
+             )
+             new_layer = self.get_rbln_layer_class()(layer, new_self_attn, lora_config=self.rbln_config.lora_config)
+             new_layers.append(new_layer)
+
+         new_model = self.get_rbln_model_class()(
+             self.get_model_layer(model),
+             new_layers,
+             self.rbln_config,
+             use_learned_pos_emb=self.__class__._use_learned_pos_emb,
+         )
+
+         if self.is_causal_lm:
+             new_model = self.get_rbln_causal_lm_class()(model, new_model)
+             return new_model
+         else:
+             return new_model
+
+     @property
+     def phase(self) -> str:
+         return self._phase
+
+     @phase.setter
+     def phase(self, phase: str):
+         self._phase = phase
+         self.model.phase = phase
+
+     def prepare_forward_args(self, *args):
+         args = list(args)
+         input_ids = None if self.rbln_config.use_inputs_embeds else args.pop(0)
+         inputs_embeds = args.pop(0) if self.rbln_config.use_inputs_embeds else None
+         cache_position = args.pop(0)
+         global_block_tables = args.pop(0) if self.rbln_config.use_global_attention else None
+         local_block_tables = args.pop(0) if self.rbln_config.use_local_attention else None
+         query_position = (
+             args.pop(0)
+             # query_position usage: 1. causal_lm prefill or 2. sliding_window cache_position
+             if ("prefill" in self.phase and (self.is_causal_lm or self.rbln_config.use_local_attention))
+             else None
+         )
+         attention_mask = args.pop(0) if self.rbln_config.use_attention_mask else None
+         position_ids = args.pop(0) if self.rbln_config.use_position_ids else None
+         lora_int_id = args.pop(0) if self.rbln_config.lora_config else None
+         past_key_values = args
+
+         if len(past_key_values) != 2 * self.num_hidden_layers:
+             raise ValueError(
+                 f"Different past_key_values to model's config. {len(past_key_values)} != {2 * self.num_hidden_layers}"
+             )
+
+         # [key, value] * n_layer -> ( (key, value) ) * n_layer
+         # cache shape : batch, n_heads, 1, max_seq_len, head_dim
+         _past_key_values = []
+         for i in range(self.config.num_hidden_layers):
+             key_states = past_key_values[i * 2]
+             value_states = past_key_values[i * 2 + 1]
+             past_key_value = [key_states, value_states]
+             _past_key_values.append(past_key_value)
+         past_key_values = _past_key_values
+
+         if hasattr(self, "rotary_emb_global") and hasattr(self, "rotary_emb_local"):
+             rotary_emb = (self.rotary_emb_global, self.rotary_emb_local)
+         else:
+             rotary_emb = self.rotary_emb
+
+         return (
+             input_ids,
+             inputs_embeds,
+             cache_position,
+             global_block_tables,
+             local_block_tables,
+             query_position,
+             attention_mask,
+             position_ids,
+             lora_int_id,
+             past_key_values,
+             rotary_emb,
+         )
+
+     def forward(self, *args):
+         (
+             input_ids,
+             inputs_embeds,
+             cache_position,
+             global_block_tables,
+             local_block_tables,
+             query_position,
+             attention_mask,
+             position_ids,
+             lora_int_id,
+             past_key_values,
+             rotary_emb,
+         ) = self.prepare_forward_args(*args)
+
+         logit = self.model(
+             input_ids=input_ids,
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             cache_position=cache_position,
+             position_ids=position_ids,
+             query_position=query_position,
+             past_key_values=past_key_values,
+             rotary_emb=rotary_emb,
+             global_block_tables=global_block_tables,
+             local_block_tables=local_block_tables,
+             lora_int_id=lora_int_id,
+         )
+
+         return logit
+
+
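# --- Illustrative sketch (not part of decoderonly_architecture.py) ---------------------
# The docstring above notes that DecoderOnlyWrapper.forward() must take only positional
# arguments because the module is captured with torch.jit.trace. The toy module below
# shows that constraint in isolation: trace feeds an ordered tuple of example tensors, so
# any keyword-style plumbing has to be flattened up front, which is what
# prepare_forward_args() does. All names here ("ToyWrapper", the example shapes) are
# hypothetical and not from the package.
import torch
from torch import nn


class ToyWrapper(nn.Module):
    def forward(self, *args):
        # Arguments arrive in a fixed, pre-agreed order (e.g. input_ids, cache_position).
        input_ids, cache_position = args
        return input_ids.float() + cache_position.float()


traced = torch.jit.trace(
    ToyWrapper(),
    (torch.zeros(1, 8, dtype=torch.int64), torch.arange(8).unsqueeze(0)),
)
# ----------------------------------------------------------------------------------------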
+ class DecoderOnlyForCausalLM(nn.Module):
+     """A specialized wrapper for Causal Language Models optimized for RBLN compilation.
+
+     This class adapts Huggingface's CausalLM (or similar models) for RBLN deployment by:
+     1. Managing model phases (prefill/decode) throughout the computation graph
+     2. Handling output shape alignments for static compilation
+     3. Coordinating between the original model and RBLN-optimized components
+
+     The class serves as an intermediate layer between DecoderOnlyWrapper and the core model,
+     focusing on maintaining correct model behavior while enabling RBLN-specific optimizations.
+
+     Args:
+         causal_lm (PreTrainedModel): Original Huggingface causal language model
+         model (DecoderOnlyModel): RBLN-optimized model instance
+
+     Attributes:
+         config: Configuration from the original causal language model
+         _original_mod: Reference to the original model for components like lm_head
+         model: RBLN-optimized decoder model instance
+         _phase: Current processing phase ("prefill" or "decode")
+     """
+
+     def __init__(self, causal_lm: PreTrainedModel, model: nn.Module):
+         super().__init__()
+         self.config = causal_lm.config
+         self._original_mod = causal_lm
+         self.model = model
+         self._phase = "prefill"
+         self.lm_head = self._original_mod.lm_head
+
+     @property
+     def phase(self):
+         return self._phase
+
+     @phase.setter
+     def phase(self, phase: str):
+         self._phase = phase
+         self.model.phase = phase
+
+     def forward(
+         self,
+         input_ids: torch.Tensor = None,
+         inputs_embeds: torch.Tensor = None,
+         attention_mask: torch.Tensor = None,
+         cache_position: torch.Tensor = None,
+         position_ids: torch.Tensor = None,
+         query_position: torch.Tensor = None,
+         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
+         rotary_emb: nn.Module = None,
+         global_block_tables: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
+         lora_int_id: Optional[torch.Tensor] = None,
+     ):
+         # outputs
+         hidden_states = self.model(
+             input_ids=input_ids,
+             inputs_embeds=inputs_embeds,
+             attention_mask=attention_mask,
+             cache_position=cache_position,
+             position_ids=position_ids,
+             query_position=query_position,
+             past_key_values=past_key_values,
+             rotary_emb=rotary_emb,
+             global_block_tables=global_block_tables,
+             local_block_tables=local_block_tables,
+             lora_int_id=lora_int_id,
+         )
+
+         if "prefill" in self.phase:
+             hidden_states = hidden_states[:, query_position.to(torch.int).unsqueeze(0)]
+
+         logits = self.lm_head(hidden_states)
+
+         # Apply final logit softmaxing if configured, e.g. for Gemma2
+         if getattr(self.config, "final_logit_softcapping", None) is not None:
+             logits = logits / self.config.final_logit_softcapping
+             logits = torch.tanh(logits)
+             logits = logits * self.config.final_logit_softcapping
+
+         return logits
+
+
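# --- Illustrative sketch (not part of decoderonly_architecture.py) ---------------------
# Numeric illustration of the final-logit soft-capping branch above (used e.g. by Gemma2):
# tanh(logits / cap) * cap keeps logits inside (-cap, cap) while leaving small values
# almost untouched. The cap value 30.0 is hypothetical.
import torch

cap = 30.0  # stands in for config.final_logit_softcapping
logits = torch.tensor([1.0, 25.0, 100.0])
capped = torch.tanh(logits / cap) * cap
# capped ≈ tensor([ 1.0, 20.5, 29.9]) -- large logits saturate just below the cap
# ----------------------------------------------------------------------------------------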
+ class DecoderOnlyModel(nn.Module):
+     """A modified decoder-only model implementation optimized for RBLN compilation.
+
+     Args:
+         model: Original Huggingface model to adapt
+         layers (List[DecoderOnlyLayer]): Modified transformer layers optimized for RBLN
+         rbln_config: RBLN model configuration
+         use_learned_pos_emb: Whether to use learned position embeddings (class-specific override)
+
+     Attributes:
+         _original_mod: Reference to original Huggingface model
+         layers: ModuleList of RBLN-optimized transformer layers
+         _phase: Current processing phase ("prefill" or "decode")
+     """
+
+     def __init__(
+         self,
+         model,
+         layers: List["DecoderOnlyLayer"],
+         rbln_config: "RBLNDecoderOnlyModelConfig",
+         use_learned_pos_emb=None,
+     ):
+         super().__init__()
+         self._original_mod = model
+         self.layers = nn.ModuleList(layers)
+         self.rbln_config = rbln_config
+         self._phase = "prefill"
+         self.partition_len = rbln_config.kvcache_partition_len
+         self.kvcache_block_size = rbln_config.kvcache_block_size
+         self.max_seq_len = rbln_config.max_seq_len
+         self.use_learned_pos_emb = use_learned_pos_emb
+         self.sliding_window_layers = rbln_config.sliding_window_layers
+
+     @property
+     def phase(self):
+         return self._phase
+
+     @phase.setter
+     def phase(self, phase: str):
+         self._phase = phase
+         for layer in self.layers:
+             layer.phase = phase
+
+     @property
+     def attn_impl(self) -> str:
+         return "eager" if self.partition_len is None else "flash_attn"
+
+     @property
+     def hidden_multiplier(self):
+         return 1
+
+     def convert_sequence_positions_for_flash_attn(self, seq_positions, max_seq_len):
+         if self.attn_impl not in ["flash_attn"]:
+             raise NotImplementedError(f"Unknown attn_impl ({self.attn_impl}).")
+         partition_len = self.partition_len
+         num_partition = max_seq_len // partition_len
+
+         cs = seq_positions.repeat(num_partition, 1).transpose(0, 1)
+         pidx = torch.arange(num_partition)
+         cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
+         return cache_pos_for_partitions
+
+     def get_local_cache_positions(self, position_ids, query_position):
+         max_cache_len = self._original_mod.config.sliding_window
+         valid_input_len = 1 if query_position is None else query_position + 1
+         cache_seq_len = torch.clamp(position_ids, max=max_cache_len)[:, :1]  # past seen tokens
+         cache_offset = (
+             torch.clamp(position_ids, max=max_cache_len)[:, :1] + valid_input_len
+         )  # cache offset for next steps
+
+         return cache_seq_len, cache_offset
+
+     def get_last_layernorm(self) -> nn.LayerNorm:
+         return self._original_mod.norm
+
+     def get_embedding(self) -> nn.Embedding:
+         return self._original_mod.embed_tokens
+
+     def get_pos_embedding(self) -> nn.Embedding:
+         raise NotImplementedError(
+             "The 'get_pos_embedding' method is not implemented. Please define this method in a subclass."
+         )
+
+     def forward(
+         self,
+         input_ids: torch.Tensor = None,
+         inputs_embeds: Optional[torch.Tensor] = None,
+         attention_mask: torch.Tensor = None,
+         cache_position: torch.Tensor = None,
+         position_ids: torch.Tensor = None,
+         query_position: torch.Tensor = None,
+         past_key_values: Tuple[Tuple[torch.Tensor]] = None,
+         rotary_emb: Optional[Union[nn.Module, torch.Tensor]] = None,
+         global_block_tables: Optional[torch.Tensor] = None,
+         local_block_tables: Optional[torch.Tensor] = None,
+         lora_int_id: Optional[torch.Tensor] = None,
+     ):
+         # retrieve input_ids and inputs_embeds
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError(
+                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+             )
+
+         # embed positions
+         if inputs_embeds is None:
+             inputs_embeds = self.get_embedding()(input_ids)
+
+         hidden_states = inputs_embeds * self.hidden_multiplier
+
+         # get cos,sin vector if needed
+         position_ids = position_ids if position_ids is not None else cache_position
+         if rotary_emb is not None:
+             if isinstance(rotary_emb, torch.Tensor):
+                 cos = rotary_emb[0]
+                 sin = rotary_emb[1]
+             else:
+                 cos, sin = rotary_emb(hidden_states, self.max_seq_len)  # dtype carrier, max_seq_len
+                 cos, sin = slice_and_unsqueeze_cos_sin(cos, sin, position_ids)
+
+         elif self.use_learned_pos_emb:
+             batch_size = inputs_embeds.shape[0]
+             hidden_all = []
+             for i in range(batch_size):
+                 positions_idx = position_ids[i]
+                 position_weight = self.get_pos_embedding().weight[2:]
+                 position = position_weight[positions_idx]
+                 batch_hidden = position + inputs_embeds[i]
+                 hidden_all.append(batch_hidden)
+             hidden_states = torch.stack(hidden_all, dim=0)
+             cos, sin = None, None
+
+         else:
+             batch_size = inputs_embeds.shape[0]
+             if position_ids.shape[0] > 1:
+                 position_embeds = []
+                 for b_idx in range(batch_size):
+                     position_embed = self.get_pos_embedding()(position_ids[b_idx])
+                     position_embeds.append(position_embed)
+
+                 position_embeds = torch.cat(position_embeds, dim=0).unsqueeze(1)
+             else:
+                 position_embeds = self.get_pos_embedding()(position_ids)
+             hidden_states = hidden_states + position_embeds
+             cos, sin = None, None
+
+         # Get sequence positions for flash attention
+         if self.attn_impl == "flash_attn":
+             seq_positions = cache_position[:, 0]
+             seq_positions = self.convert_sequence_positions_for_flash_attn(
+                 seq_positions=seq_positions, max_seq_len=self.max_seq_len
+             )
+         else:
+             seq_positions = cache_position[:, :1]
+
+         # Get local cache positions for sliding window layers
+         if len(self.sliding_window_layers) > 0:
+             sliding_cache_pos = self.get_local_cache_positions(position_ids, query_position)
+
+         for layer_idx, layer in enumerate(self.layers):
+             is_sliding = True if layer_idx in self.sliding_window_layers else False
+             hidden_states = layer(
+                 hidden_states=hidden_states,
+                 attention_mask=attention_mask,
+                 seq_positions=sliding_cache_pos if is_sliding else seq_positions,
+                 past_key_values=past_key_values,
+                 cos=cos,
+                 sin=sin,
+                 block_tables=local_block_tables if is_sliding else global_block_tables,
+                 lora_int_id=lora_int_id,
+             )
+
+         hidden_states = self.get_last_layernorm()(hidden_states)
+         return hidden_states
+
+
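# --- Illustrative sketch (not part of decoderonly_architecture.py) ---------------------
# What convert_sequence_positions_for_flash_attn() above computes, on concrete numbers.
# With a hypothetical max_seq_len of 8192 and kvcache_partition_len of 4096 there are two
# KV-cache partitions; each batch entry's past length is turned into a per-partition
# position clamped to [0, partition_len].
import torch

partition_len, max_seq_len = 4096, 8192
seq_positions = torch.tensor([1000, 5000])      # past lengths for a batch of 2
num_partition = max_seq_len // partition_len    # -> 2
cs = seq_positions.repeat(num_partition, 1).transpose(0, 1)
pidx = torch.arange(num_partition)
cache_pos_for_partitions = torch.clamp(cs - pidx * partition_len, 0, partition_len)
# tensor([[1000,    0],
#         [4096,  904]])
# ----------------------------------------------------------------------------------------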
+ class DecoderOnlyLayer(nn.Module):
+     """A single transformer layer adapted for RBLN compilation with static shapes.
+
+     This layer implements a modified transformer block that includes:
+     1. Self-attention mechanism (either standard or flash attention)
+     2. Feed-forward network (FFN)
+     3. Layer normalization
+     4. Residual connections
+
+     The layer is specifically designed to:
+     - Support compilation to RBLN custom ops
+     - Maintain static tensor shapes throughout computations
+     - Handle both prefill and decode phases efficiently
+     - Manage attention state transitions properly
+
+     Args:
+         layer: Original transformer layer module to wrap
+         self_attn (DecoderOnlyAttention): Modified attention module optimized for RBLN
+
+     Attributes:
+         _original_mod: Reference to original layer for accessing components
+         self_attn: Modified attention mechanism mapped to RBLN ops at compile time
+         phase: Current operation phase ("prefill" or "decode")
+     """
+
+     def __init__(self, layer, self_attn: "DecoderOnlyAttention", lora_config: Optional[RBLNLoRAConfig] = None):
+         super().__init__()
+         self._original_mod = layer
+         self.self_attn = self_attn
+         self._phase = "prefill"
+         self.lora_config = lora_config
+
+         # Replace target Linear modules in MLP with LoRALinear if configured
+         if self.lora_config:
+             mlp = self.get_mlp()
+             for proj_name in ["gate_proj", "up_proj", "down_proj"]:
+                 if hasattr(mlp, proj_name):
+                     original_linear = getattr(mlp, proj_name)
+                     if isinstance(original_linear, nn.Linear):
+                         lora_linear = LoRALinear(
+                             original_linear=original_linear,
+                             lora_config=self.lora_config,
+                             projection_name=proj_name,
+                             layer_idx=self.self_attn.layer_idx,
+                         )
+                         setattr(mlp, proj_name, lora_linear)
+
+     @property
+     def phase(self):
+         return self._phase
+
+     @phase.setter
+     def phase(self, phase: str):
+         self._phase = phase
+         self.self_attn.phase = phase
+
+     def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+         return self._original_mod.input_layernorm
+
+     def get_post_attention_layernorm(self) -> nn.LayerNorm:
+         return self._original_mod.post_attention_layernorm
+
+     def get_mlp(self) -> nn.Module:
+         return self._original_mod.mlp
+
+     def forward_mlp(self, hidden_states: torch.Tensor, lora_int_id: Optional[torch.Tensor] = None) -> torch.Tensor:
+         mlp = self.get_mlp()
+         if self.lora_config and lora_int_id is not None:
+             gate = mlp.gate_proj(hidden_states, lora_int_id)
+             up = mlp.up_proj(hidden_states, lora_int_id)
+             act_fn = getattr(mlp, "act_fn", None) or getattr(mlp, "activation_fn", None)
+             if act_fn is None:
+                 gate = torch.nn.functional.silu(gate)
+             else:
+                 gate = act_fn(gate)
+             fused = gate * up
+             hidden_states = mlp.down_proj(fused, lora_int_id)
+         else:
+             hidden_states = mlp(hidden_states)
+         return hidden_states
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         seq_positions: torch.LongTensor,
+         past_key_values: Tuple[Tuple[torch.Tensor]],
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         block_tables: Optional[torch.Tensor] = None,
+         lora_int_id: Optional[torch.Tensor] = None,
+     ):
+         residual = hidden_states
+         hidden_states = self.get_pre_attention_layernorm()(hidden_states)
+
+         hidden_states = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             seq_positions=seq_positions,
+             past_key_values=past_key_values,
+             cos=cos,
+             sin=sin,
+             block_tables=block_tables,
+             lora_int_id=lora_int_id,
+         )
+         hidden_states = residual + hidden_states
+
+         # Fully Connected
+         residual = hidden_states
+         hidden_states = self.get_post_attention_layernorm()(hidden_states)
+         hidden_states = self.forward_mlp(hidden_states, lora_int_id)
+         hidden_states = residual + hidden_states
+
+         return hidden_states
+
+
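# --- Illustrative sketch (not part of decoderonly_architecture.py) ---------------------
# DecoderOnlyLayer.forward() above applies the usual pre-norm residual pattern twice:
# normalize, transform (attention or MLP), then add the untouched input back. A stripped
# down version with hypothetical sizes:
import torch
from torch import nn

norm = nn.LayerNorm(16)
block = nn.Linear(16, 16)   # stands in for self-attention or the MLP
x = torch.randn(1, 4, 16)
out = x + block(norm(x))    # residual add: the block only learns a correction to x
# ----------------------------------------------------------------------------------------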
+ class DecoderOnlyAttention(nn.Module):
+     """Attention implementation for decoder-only models optimized for RBLN compilation.
+
+     This class implements a modified version of the standard attention mechanism that:
+     1. Supports static shape requirements for RBLN compilation
+     2. Handles explicit batch and position management
+
+     Args:
+         self_attn: Original attention module from the base model
+         rbln_config: RBLN model configuration containing attention parameters
+         is_sliding: Whether this is sliding window attention
+     """
+
+     def __init__(
+         self,
+         self_attn,
+         rbln_config: "RBLNDecoderOnlyModelConfig",
+         is_sliding=False,
+     ):
+         super().__init__()
+         self._original_mod = self_attn
+         self.rbln_config = rbln_config
+         self.layer_idx = self_attn.layer_idx
+         self.num_heads = getattr(self._original_mod, "num_heads", None) or getattr(
+             self._original_mod.config, "num_attention_heads"
+         )
+         self.head_dim = self._original_mod.head_dim
+         self._phase = "prefill"
+         self.scale = torch.nn.Parameter(torch.tensor(self.get_attn_scale()))
+         self.quantization = rbln_config.quantization
+
+         if hasattr(self._original_mod, "num_key_value_heads"):
+             self.num_key_value_heads = self._original_mod.num_key_value_heads
+         elif hasattr(self._original_mod, "config") and hasattr(self._original_mod.config, "num_key_value_heads"):
+             self.num_key_value_heads = self._original_mod.config.num_key_value_heads
+         else:
+             self.num_key_value_heads = self.num_heads
+
+         self.use_attention_mask = rbln_config.use_attention_mask if not is_sliding else True
+         self.use_position_ids = rbln_config.use_position_ids
+         self.is_sliding = is_sliding
+         self.attn_impl = rbln_config.attn_impl if not is_sliding else "eager"
+         self.kvcache_partition_len = getattr(rbln_config, "kvcache_partition_len", None)
+         self.kvcache_block_size = rbln_config.sliding_window if is_sliding else rbln_config.kvcache_block_size
+         self.lora_config = rbln_config.lora_config
+
+         setattr(self, self.get_attention_name(), self.create_attention_op())
+         self.__post_init__()
+
+     def _init_lora_weights(self):
+         """Initialize LoRA adapter weights by replacing linear layers with LoRALinear."""
+         for proj_name in ["q_proj", "k_proj", "v_proj", "o_proj"]:
+             original_linear = getattr(self._original_mod, proj_name)
+             lora_linear = LoRALinear(
+                 original_linear=original_linear,
+                 lora_config=self.lora_config,
+                 projection_name=proj_name,
+                 layer_idx=self.layer_idx,
+             )
+             setattr(self, proj_name, lora_linear)
+
+     def get_attention_name(self):
+         if self.is_sliding:
+             return "sliding_window_attention"
+         elif self.attn_impl == "flash_attn":
+             return "flash_attention"
+         else:
+             return "attention"
+
+     def get_attention_op(self):
+         return getattr(self, self.get_attention_name())
+
+     @property
+     def phase(self):
+         return self._phase
+
+     @phase.setter
+     def phase(self, phase: str):
+         self._phase = phase
+         getattr(self, self.get_attention_name()).phase = phase
+
+     def create_attention_op(self):
+         if self.is_sliding:
+             return SlidingWindowAttentionOp(
+                 self.num_heads,
+                 self.head_dim,
+                 self.num_key_value_heads,
+                 self.use_attention_mask,
+                 self.use_position_ids,
+             )
+         elif self.attn_impl == "flash_attn":
+             return FlashAttentionOp(
+                 self.num_heads,
+                 self.head_dim,
+                 self.num_key_value_heads,
+                 self.kvcache_partition_len,
+                 self.use_attention_mask,
+                 self.use_position_ids,
+                 self.quantization,
+             )
+         elif self.attn_impl == "eager":
+             return AttentionOp(
+                 self.num_heads,
+                 self.head_dim,
+                 self.num_key_value_heads,
+                 self.use_attention_mask,
+                 self.use_position_ids,
+                 self.quantization,
+             )
+         else:
+             raise NotImplementedError(f"Unknown attention implementation: {self.attn_impl}")
+
+     def __post_init__(self):
+         # Initialize LoRA weights if configured, which will replace linear layers
+         if self.lora_config:
+             self._init_lora_weights()
+         else:
+             # Use original linear layers if no LoRA
+             self.q_proj = self._original_mod.q_proj
+             self.k_proj = self._original_mod.k_proj
+             self.v_proj = self._original_mod.v_proj
+             self.o_proj = self._original_mod.o_proj
+
+     def projection(
+         self, hidden_states, lora_int_id: Optional[torch.Tensor] = None
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Projects input hidden states into query, key, and value representations.
+
+         Args:
+             hidden_states: Input tensor of shape [batch_size, seq_len, hidden_dim]
+             lora_int_id: Adapter ID tensor for LoRA selection [batch_size]
+
+         Returns:
+             Tuple of (query_states, key_states, value_states)
+         """
+         # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+         if self.lora_config:
+             # LoRALinear handles both base projection and LoRA in one forward pass
+             query_states = self.q_proj(hidden_states, lora_int_id)
+             key_states = self.k_proj(hidden_states, lora_int_id)
+             value_states = self.v_proj(hidden_states, lora_int_id)
+         else:
+             # Standard linear projection without LoRA
+             query_states = self.q_proj(hidden_states)
+             key_states = self.k_proj(hidden_states)
+             value_states = self.v_proj(hidden_states)
+
+         return query_states, key_states, value_states
+
+     def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+         return apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+     def get_attn_scale(self):
+         return 1 / math.sqrt(self.head_dim)
+
+     def maybe_get_kvcache_scale(self) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+         if hasattr(self, "k_proj") and hasattr(self, "v_proj"):
+             k_scale = getattr(self.k_proj, "k_scale", None)
+             v_scale = getattr(self.v_proj, "v_scale", None)
+         else:
+             k_scale = None
+             v_scale = None
+
+         return k_scale, v_scale
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         seq_positions: torch.LongTensor,
+         past_key_values: Tuple[Tuple[torch.Tensor]],
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         block_tables: Optional[torch.Tensor] = None,
+         lora_int_id: Optional[torch.Tensor] = None,
+     ):
+         batch_size, query_length, _ = hidden_states.size()
+
+         query_states, key_states, value_states = self.projection(hidden_states=hidden_states, lora_int_id=lora_int_id)
+
+         query_states = query_states.view(batch_size, query_length, self.num_heads, self.head_dim).transpose(1, 2)
+         key_states = key_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(batch_size, query_length, self.num_key_value_heads, self.head_dim).transpose(
+             1, 2
+         )
+         if hasattr(self, "q_norm") and hasattr(self, "k_norm"):
+             query_states = self.q_norm(query_states)
+             key_states = self.k_norm(key_states)
+
+         if cos is not None and sin is not None:
+             query_states, key_states = self.apply_rotary_pos_embed(query_states, key_states, cos, sin)
+
+         if batch_size > 1 and "prefill" in self.phase:
+             raise NotImplementedError(f"batch size should be 1 if prefill phase, but got {batch_size}.")
+
+         k_scale, v_scale = self.maybe_get_kvcache_scale()
+
+         attn_output = self.get_attention_op()(
+             query_states,
+             key_states,
+             value_states,
+             attention_mask,
+             past_key_state=past_key_values[self.layer_idx][0],
+             past_value_state=past_key_values[self.layer_idx][1],
+             seq_position=seq_positions,
+             scale=self.scale,
+             block_tables=block_tables,
+             block_size=self.kvcache_block_size,
+             k_scale=k_scale,
+             v_scale=v_scale,
+         )
+
+         # Check if using LoRALinear (which accepts lora_int_id) or standard linear layers
+         if self.lora_config:
+             # LoRALinear handles both base projection and LoRA in one forward pass
+             attn_outputs = self.o_proj(attn_output, lora_int_id)
+         else:
+             # Standard linear projection without LoRA
+             attn_outputs = self.o_proj(attn_output)
+
+         return attn_outputs
+
+
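# --- Illustrative sketch (not part of decoderonly_architecture.py) ---------------------
# The attention ops below reshape queries for grouped-query attention: queries are viewed
# as (batch, kv_heads, heads_per_kv, seq, head_dim) so each KV head serves a group of
# query heads without materializing repeat_kv. The 32/8/128 sizes are hypothetical.
import torch

batch, seq, num_heads, num_kv_heads, head_dim = 1, 1, 32, 8, 128
q = torch.randn(batch, num_heads, seq, head_dim)
q_grouped = q.view(batch, num_kv_heads, num_heads // num_kv_heads, seq, head_dim)
# q_grouped.shape == (1, 8, 4, 1, 128)
# ----------------------------------------------------------------------------------------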
+ class DecoderOnlyFlashAttention(DecoderOnlyAttention):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         logger.warning(
+             "DecoderOnlyFlashAttention is deprecated and may not work as expected. Use DecoderOnlyAttention instead."
+         )
+
+
+ class AttentionOp(nn.Module):
+     def __init__(
+         self,
+         num_heads: int,
+         head_dim: int,
+         num_key_value_heads: int,
+         use_attention_mask: bool,
+         use_position_ids: bool,
+         quantization: Optional[RBLNQuantizationConfig] = None,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         self.head_dim = head_dim
+         self.num_key_value_heads = num_key_value_heads
+         self.phase = "prefill"
+         self.use_attention_mask = use_attention_mask
+         self.use_position_ids = use_position_ids
+         self.quantization = quantization
+
+     def get_attn_op_name(self):
+         phase = "decode" if self.phase == "decode" else "prefill"
+         if self.use_attention_mask and not self.use_position_ids:
+             attn_op_name = "paged_attn_"
+         else:
+             attn_op_name = "paged_causal_attn_"
+
+         attn_op_name += phase
+
+         if self.quantization and self.quantization.kv_caches == "fp8":
+             attn_op_name += "_kv_fp8"
+
+         return attn_op_name
+
+     def forward(
+         self,
+         query_state: torch.Tensor,
+         key_state: torch.Tensor,
+         value_state: torch.Tensor,
+         attn_mask: torch.Tensor,
+         past_key_state: torch.Tensor,
+         past_value_state: torch.Tensor,
+         seq_position: torch.Tensor,
+         scale: torch.Tensor,
+         block_tables: torch.Tensor,
+         block_size: int,
+         k_scale: Optional[torch.Tensor] = None,
+         v_scale: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """Compute attention with static shapes and explicit cache management.
+
+         Args:
+             query_state: Query tensor [1, num_heads, 1, head_dim]
+             key_state: Key tensor [1, num_heads, seq_len, head_dim]
+             value_state: Value tensor [1, num_heads, seq_len, head_dim]
+             attn_mask: Attention mask tensor ∈ {0, 1}
+             past_key_state: Previous key cache states
+             past_value_state: Previous value cache states
+             seq_position: Current position in sequence
+             scale: Scale applied to attn weights
+             block_tables: Block tables for paged attention
+             block_size: Block size for paged attention
+             k_scale: Scale applied to key
+             v_scale: Scale applied to value
+
+         Returns:
+             Tensor: attention_output: [batch, num_heads, seq_len, head_dim]
+         """
+         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+         key_state = key_state.unsqueeze(2)  # 1, 32, 1, 128, 128
+         value_state = value_state.unsqueeze(2)
+
+         if self.use_attention_mask and not self.use_position_ids:
+             attn_mask = attn_mask.unsqueeze(2)
+
+         if self.phase == "decode":
+             batch_size = key_state.shape[0]
+         else:
+             batch_size = 1
+
+         query_state = query_state.view(
+             batch_size,
+             self.num_key_value_heads,
+             self.num_heads // self.num_key_value_heads,
+             -1,  # seq len
+             self.head_dim,
+         )
+
+         op_args = {
+             "q": query_state,
+             "k": key_state,
+             "v": value_state,
+             "kcache": past_key_state.unsqueeze(2),
+             "vcache": past_value_state.unsqueeze(2),
+             "seq": seq_position,
+             "scale": scale,
+             "block_table": block_tables,
+             "block_size": block_size,
+         }
+
+         if self.use_attention_mask:
+             op_args["mask"] = attn_mask
+
+         if self.phase == "prefill" or self.phase == "image_prefill":
+             if not self.use_attention_mask or self.use_position_ids:
+                 op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+         if self.quantization and self.quantization.kv_caches == "fp8":
+             if past_key_state.dtype != torch.float8_e4m3fn:
+                 raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
+             op_args["k_scale"] = k_scale
+             op_args["v_scale"] = v_scale
+
+         attn_op_name = self.get_attn_op_name()
+         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+         if attn_op is None:
+             raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+         attn_output = attn_op(**op_args)
+         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
+
+         return attn_output
+
+
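# --- Illustrative note (not part of decoderonly_architecture.py) -----------------------
# The custom-op names AttentionOp.get_attn_op_name() assembles, traced by hand from the
# branches above (the ops themselves are presumably registered by optimum/rbln/ops/attn.py):
#
#   use_attention_mask  use_position_ids  phase    fp8 KV cache  -> op name
#   True                False             decode   no               paged_attn_decode
#   True                False             prefill  yes              paged_attn_prefill_kv_fp8
#   False               -                 prefill  no               paged_causal_attn_prefill
#   True                True              decode   no               paged_causal_attn_decode
#
# FlashAttentionOp (below) builds the same suffixes on a "paged_flash_" prefix.
# ----------------------------------------------------------------------------------------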
+ class FlashAttentionOp(AttentionOp):
+     def __init__(
+         self,
+         num_heads: int,
+         head_dim: int,
+         num_key_value_heads: int,
+         kvcache_partition_len: int,
+         use_attention_mask: bool,
+         use_position_ids: bool,
+         quantization: Optional[RBLNQuantizationConfig] = None,
+     ):
+         super().__init__(
+             num_heads=num_heads,
+             head_dim=head_dim,
+             num_key_value_heads=num_key_value_heads,
+             use_attention_mask=use_attention_mask,
+             use_position_ids=use_position_ids,
+             quantization=quantization,
+         )
+         self.kvcache_partition_size = kvcache_partition_len
+
+     def get_attn_op_name(self):
+         phase = "decode" if self.phase == "decode" else "prefill"
+         if self.use_attention_mask and not self.use_position_ids:
+             attn_op_name = "paged_flash_attn_"
+         else:
+             attn_op_name = "paged_flash_causal_attn_"
+
+         attn_op_name += phase
+
+         if self.quantization and self.quantization.kv_caches == "fp8":
+             attn_op_name += "_kv_fp8"
+
+         return attn_op_name
+
+     def forward(
+         self,
+         query_state,
+         key_state,
+         value_state,
+         attn_mask,
+         past_key_state,
+         past_value_state,
+         seq_position,
+         scale,
+         block_tables,
+         block_size,
+         k_scale=None,
+         v_scale=None,
+     ):
+         # reshape for removing repeat_kv (batch=1 , num_head, 1, q_len=1, head_dim)
+         key_state = key_state.unsqueeze(2)
+         value_state = value_state.unsqueeze(2)
+         if self.use_attention_mask and not self.use_position_ids:
+             attn_mask = attn_mask.unsqueeze(2)
+
+         if self.phase == "decode":
+             batch_size = key_state.shape[0]
+         else:
+             batch_size = 1
+
+         query_state = query_state.view(
+             batch_size,
+             self.num_key_value_heads,
+             self.num_heads // self.num_key_value_heads,
+             -1,  # seq len
+             self.head_dim,
+         )
+
+         op_args = {
+             "q": query_state,
+             "k": key_state,
+             "v": value_state,
+             "kcache": past_key_state.unsqueeze(2),
+             "vcache": past_value_state.unsqueeze(2),
+             "seq": seq_position,
+             "scale": scale,
+             "block_table": block_tables,
+             "block_size": block_size,
+             "partition": self.kvcache_partition_size,
+         }
+
+         if self.use_attention_mask:
+             op_args["mask"] = attn_mask
+
+         if self.phase == "prefill" or self.phase == "image_prefill":
+             if not self.use_attention_mask or self.use_position_ids:
+                 op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME, Hard-coded for Gemma3.
+
+         if self.quantization and self.quantization.kv_caches == "fp8":
+             if past_key_state.dtype != torch.float8_e4m3fn:
+                 raise ValueError(f"Unsupported KVCaches type: {past_key_state.dtype}")
+             op_args["k_scale"] = k_scale
+             op_args["v_scale"] = v_scale
+
+         attn_op_name = self.get_attn_op_name()
+         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+         if attn_op is None:
+             raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+         attn_output = attn_op(**op_args)
+         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
+
+         return attn_output
+
+
1060
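The operator name assembled by get_attn_op_name follows a fixed scheme: models that pass a mask without position ids get the plain flash op, everything else the causal variant, followed by the phase suffix and an optional _kv_fp8 tag. A small standalone sketch that reproduces that naming logic (the helper below is illustrative, not part of the package):

    def flash_attn_op_name(phase, use_attention_mask, use_position_ids, kv_fp8):
        # Mirrors FlashAttentionOp.get_attn_op_name above.
        name = "paged_flash_attn_" if (use_attention_mask and not use_position_ids) else "paged_flash_causal_attn_"
        name += "decode" if phase == "decode" else "prefill"
        if kv_fp8:
            name += "_kv_fp8"
        return name

    print(flash_attn_op_name("decode", True, False, False))       # paged_flash_attn_decode
    print(flash_attn_op_name("prefill", False, False, False))     # paged_flash_causal_attn_prefill
    print(flash_attn_op_name("image_prefill", True, True, True))  # paged_flash_causal_attn_prefill_kv_fp8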
+ class SlidingWindowAttentionOp(AttentionOp):
+     def get_attn_op_name(self):
+         phase = "decode" if self.phase == "decode" else "prefill"
+         if not self.use_attention_mask:
+             raise NotImplementedError("Attention mask is needed for sliding window attention.")
+
+         attn_op_name = "paged_sliding_window_attn_" + phase
+         return attn_op_name
+
+     def forward(
+         self,
+         query_state: torch.Tensor,
+         key_state: torch.Tensor,
+         value_state: torch.Tensor,
+         attn_mask: Optional[torch.Tensor],
+         past_key_state: torch.Tensor,
+         past_value_state: torch.Tensor,
+         seq_position: Tuple[torch.Tensor],
+         scale: torch.Tensor,
+         block_tables: torch.Tensor,
+         block_size: int,
+         k_scale: Optional[torch.Tensor] = None,
+         v_scale: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         assert self.quantization is None, "Sliding window attention does not support quantization"
+         assert k_scale is None and v_scale is None, "Sliding window attention does not support k_scale/v_scale"
+
+         # Give K/V a broadcast axis instead of repeat_kv: (batch, num_key_value_heads, 1, seq_len, head_dim)
+         key_state = key_state.unsqueeze(2)
+         value_state = value_state.unsqueeze(2)
+
+         if self.phase == "decode":
+             batch_size = key_state.shape[0]
+         else:
+             batch_size = 1
+
+         query_state = query_state.view(
+             batch_size,
+             self.num_key_value_heads,
+             self.num_heads // self.num_key_value_heads,
+             -1,  # seq len
+             self.head_dim,
+         )
+
+         op_args = {
+             "q": query_state,
+             "k": key_state,
+             "v": value_state,
+             "kcache": past_key_state.unsqueeze(2),
+             "vcache": past_value_state.unsqueeze(2),
+             "cache_seq_len": seq_position[0],
+             "cache_offset": seq_position[1],
+             "scale": scale,
+             "block_table": block_tables,
+             "block_size": block_size,
+         }
+
+         if self.phase == "prefill" or self.phase == "image_prefill":
+             op_args["is_bidirectional"] = self.phase == "image_prefill"  # FIXME: hard-coded for Gemma3.
+
+         attn_op_name = self.get_attn_op_name()
+         attn_op = getattr(torch.ops.rbln_custom_ops, attn_op_name, None)
+         if attn_op is None:
+             raise ValueError(f"Attention operator {attn_op_name} not found.")
+
+         attn_output = attn_op(**op_args)
+         attn_output = attn_output.view(batch_size, self.num_heads, -1, self.head_dim)
+         attn_output = attn_output.transpose(1, 2).contiguous()
+         attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim)
+
+         return attn_output
+
+
+ class RotaryEmbedding(nn.Module):
+     """RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
+
+     def __init__(
+         self,
+         config: PretrainedConfig,
+         max_seq_len_cached: int,
+     ):
+         super().__init__()
+
+         if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+             rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+         else:
+             rope_type = "default"
+
+         inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS[rope_type](config, max_seq_len_cached)
+         cache_position = torch.arange(0, max_seq_len_cached)
+         cache_position_expanded = cache_position[:, None]
+
+         if rope_type == "dynamic":
+             freqs = cache_position_expanded.float() * inv_freq.float()
+         else:
+             inv_freq_expanded = inv_freq[None, :]
+             freqs = cache_position_expanded.float() @ inv_freq_expanded.float()
+
+         emb = torch.cat((freqs, freqs), dim=-1)
+
+         cos = emb.cos() * attention_scaling
+         sin = emb.sin() * attention_scaling
+
+         self.register_buffer("_cos_cached", cos, persistent=False)
+         self.register_buffer("_sin_cached", sin, persistent=False)
+
+     def forward(self, x, seq_len):
+         return (
+             self._cos_cached[:seq_len].to(dtype=torch.float32),
+             self._sin_cached[:seq_len].to(dtype=torch.float32),
+         )
+
+
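For the "default" rope_type, the initializer returns the standard inverse-frequency vector with an attention_scaling of 1.0, so the cached tables reduce to cos/sin of position times frequency. A standalone sketch of that precomputation with illustrative sizes (head_dim, base, and max_seq_len below are assumptions):

    import torch

    head_dim, base, max_seq_len = 8, 10000.0, 16

    # Standard RoPE inverse frequencies, one per feature pair.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))  # (head_dim / 2,)
    positions = torch.arange(max_seq_len).float()[:, None]                        # (max_seq_len, 1)
    freqs = positions @ inv_freq[None, :]                                         # (max_seq_len, head_dim / 2)
    emb = torch.cat((freqs, freqs), dim=-1)                                       # (max_seq_len, head_dim)

    cos_cached, sin_cached = emb.cos(), emb.sin()  # attention_scaling == 1.0 for the default rope
    print(cos_cached.shape)  # torch.Size([16, 8])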
+ def slice_and_unsqueeze_cos_sin(cos, sin, cache_position, unsqueeze_dim=1):
+     """Slice cos[cache_position], sin[cache_position] vectors for the query."""
+     if cache_position.shape[0] > 1:
+         cos_all = []
+         sin_all = []
+         for i in range(cache_position.shape[0]):
+             cos_all.append(cos[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+             sin_all.append(sin[cache_position[i : i + 1]].unsqueeze(unsqueeze_dim))
+         cos = torch.cat(cos_all, dim=0)
+         sin = torch.cat(sin_all, dim=0)
+     else:
+         cos = cos[cache_position].unsqueeze(unsqueeze_dim)
+         sin = sin[cache_position].unsqueeze(unsqueeze_dim)
+
+     return cos, sin
+
+
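During decode, cache_position typically arrives as one next-token index per sequence, i.e. shape (batch, 1), and the per-batch branch above then yields cos/sin of shape (batch, 1, 1, dim) that broadcast over the head and sequence axes. A standalone sketch mirroring that branch (all sizes are illustrative assumptions):

    import torch

    max_seq_len, head_dim, batch = 16, 8, 3
    cos = torch.randn(max_seq_len, head_dim)
    sin = torch.randn(max_seq_len, head_dim)
    cache_position = torch.tensor([[4], [7], [9]])  # one next-token position per sequence

    # Mirrors the cache_position.shape[0] > 1 branch of slice_and_unsqueeze_cos_sin.
    cos_q = torch.cat([cos[cache_position[i : i + 1]].unsqueeze(1) for i in range(batch)], dim=0)
    sin_q = torch.cat([sin[cache_position[i : i + 1]].unsqueeze(1) for i in range(batch)], dim=0)
    print(cos_q.shape)  # torch.Size([3, 1, 1, 8]) -> broadcasts against (batch, heads, seq, head_dim)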
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin):
+     """Applies Rotary Position Embedding to the query and key tensors."""
+     dtype = q.dtype
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     q_embed = q_embed.to(dtype)
+     k_embed = k_embed.to(dtype)
+     return q_embed, k_embed
+
+
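With the rotate_half layout, apply_rotary_pos_emb is a plain 2-D rotation: for each frequency theta_j, the feature pair (x_j, x_{j + d/2}) is rotated by the position-dependent angle m * theta_j. A quick numerical check with toy values (all numbers are illustrative assumptions; rotate_half is restated so the snippet runs on its own):

    import torch

    def rotate_half(x):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    d = 4
    m, theta = 3.0, torch.tensor([0.5, 0.1])      # angles m * theta_j for the two feature pairs
    cos = torch.cat((torch.cos(m * theta),) * 2)  # (d,), duplicated like torch.cat((freqs, freqs))
    sin = torch.cat((torch.sin(m * theta),) * 2)

    q = torch.randn(d)
    q_embed = q * cos + rotate_half(q) * sin

    # Pair (q[0], q[2]) rotated by m * theta_0:
    expected_0 = q[0] * torch.cos(m * theta[0]) - q[2] * torch.sin(m * theta[0])
    print(torch.allclose(q_embed[0], expected_0))  # True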
+ def apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim) -> Tuple[torch.Tensor, torch.Tensor]:
+     # Partial rotary embedding: rotate only the first `ndim` features of each head
+     # (ndim = head_dim * partial_rotary_factor); the rest passes through unchanged.
+     query_rot, query_pass = (
+         query_states[..., :ndim],
+         query_states[..., ndim:],
+     )
+     key_rot, key_pass = (
+         key_states[..., :ndim],
+         key_states[..., ndim:],
+     )
+
+     query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+     # Concatenate the rotated slice with the pass-through features, back to the full head_dim.
+     query_states = torch.cat((query_rot, query_pass), dim=-1)
+     key_states = torch.cat((key_rot, key_pass), dim=-1)
+     return query_states, key_states
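A short usage sketch for the partial case, with illustrative sizes and assuming apply_rotary_pos_emb_partial and its helpers above are in scope:

    import torch

    batch, heads, seq, head_dim, ndim = 1, 2, 5, 8, 4  # e.g. partial_rotary_factor = 0.5

    q = torch.randn(batch, heads, seq, head_dim)
    k = torch.randn(batch, heads, seq, head_dim)
    cos = torch.randn(batch, 1, seq, ndim)  # cos/sin cover only the rotary slice
    sin = torch.randn(batch, 1, seq, ndim)

    q_out, k_out = apply_rotary_pos_emb_partial(q, k, cos, sin, ndim)
    print(torch.equal(q_out[..., ndim:], q[..., ndim:]))  # True: the pass-through tail is untouched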