optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/opt/opt_architecture.py
@@ -0,0 +1,74 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING
+
+ import torch.nn as nn
+
+ from ...models.decoderonly.decoderonly_architecture import (
+     DecoderOnlyAttention,
+     DecoderOnlyLayer,
+     DecoderOnlyModel,
+     DecoderOnlyWrapper,
+ )
+
+
+ if TYPE_CHECKING:
+     from transformers import OPTForCausalLM
+
+
+ class OPTWrapper(DecoderOnlyWrapper):
+     _use_learned_pos_emb = True
+
+     def get_rbln_attn_class(self):
+         return OPTAttention
+
+     def get_rbln_layer_class(self):
+         return OPTDecoderLayer
+
+     def get_rbln_model_class(self):
+         return OPTModel
+
+     def get_model_layer(self, model: "OPTForCausalLM"):
+         return model.model.decoder if self.is_causal_lm else model.decoder
+
+     def get_decoder_layers(self, model: "OPTForCausalLM"):
+         return model.model.decoder.layers if self.is_causal_lm else model.decoder.layers
+
+
+ class OPTAttention(DecoderOnlyAttention):
+     def __post_init__(self):
+         self.k_proj = self._original_mod.k_proj
+         self.v_proj = self._original_mod.v_proj
+         self.q_proj = self._original_mod.q_proj
+         self.o_proj = self._original_mod.out_proj
+
+
+ class OPTModel(DecoderOnlyModel):
+     def get_embedding(self) -> nn.Embedding:
+         return self._original_mod.embed_tokens
+
+     def get_pos_embedding(self):
+         return self._original_mod.embed_positions
+
+     def get_last_layernorm(self) -> nn.LayerNorm:
+         return self._original_mod.final_layer_norm
+
+
+ class OPTDecoderLayer(DecoderOnlyLayer):
+     def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+         return self._original_mod.self_attn_layer_norm
+
+     def get_post_attention_layernorm(self) -> nn.LayerNorm:
+         return self._original_mod.final_layer_norm
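
The OPT wrapper above only remaps OPT's module names (`out_proj` → `o_proj`, learned position embeddings, final layer norm) onto the generic decoder-only graph; users go through the modeling class instead. A minimal sketch of the compile flow, assuming `RBLNOPTForCausalLM` is exported from `optimum.rbln` and following the `from_pretrained(..., export=True)` pattern shown in the Phi docstring further below; the checkpoint name is illustrative:

```python
from optimum.rbln import RBLNOPTForCausalLM  # assumed top-level export

# Compile a PyTorch OPT checkpoint into an RBLN graph; "facebook/opt-1.3b"
# is an illustrative checkpoint, and rbln_* kwargs mirror the config fields.
model = RBLNOPTForCausalLM.from_pretrained(
    "facebook/opt-1.3b",
    export=True,
    rbln_batch_size=1,
)
model.save_pretrained("opt-1.3b-rbln")  # reload later without recompiling
```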
optimum/rbln/transformers/models/pegasus/__init__.py
@@ -0,0 +1,17 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ....ops import paged_attn_decode, paged_causal_attn_decode
+ from .configuration_pegasus import RBLNPegasusForConditionalGenerationConfig, RBLNPegasusModelConfig
+ from .modeling_pegasus import RBLNPegasusForConditionalGeneration, RBLNPegasusModel
optimum/rbln/transformers/models/pegasus/configuration_pegasus.py
@@ -0,0 +1,38 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
+ from ..seq2seq import RBLNModelForSeq2SeqLMConfig
+
+
+ class RBLNPegasusModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
+     """
+     Configuration class for RBLNPegasusModel.
+
+     This configuration class stores the configuration parameters specific to
+     RBLN-optimized PEGASUS models for feature extraction tasks.
+     """
+
+     rbln_model_input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
+
+
+ class RBLNPegasusForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
+     """
+     Configuration class for RBLNPegasusForConditionalGeneration.
+
+     This configuration class stores the configuration parameters specific to
+     RBLN-optimized PEGASUS models for conditional text generation tasks.
+     """
+
+     support_paged_attention = True
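
A hedged sketch of how this configuration might be used with the model class defined in modeling_pegasus.py below; `batch_size` and `enc_max_seq_len` are assumed to be accepted by the `RBLNModelForSeq2SeqLMConfig` base (`enc_max_seq_len` is read back in `_wrap_model_if_needed`), and the checkpoint name is illustrative:

```python
from optimum.rbln import (  # assumed top-level exports
    RBLNPegasusForConditionalGeneration,
    RBLNPegasusForConditionalGenerationConfig,
)

# Assumption: the seq2seq base config accepts batch_size and enc_max_seq_len;
# the latter is consumed when building PegasusWrapper's encoder.
config = RBLNPegasusForConditionalGenerationConfig(batch_size=1, enc_max_seq_len=1024)
model = RBLNPegasusForConditionalGeneration.from_pretrained(
    "google/pegasus-xsum",  # illustrative checkpoint
    export=True,
    rbln_config=config,
)
```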
optimum/rbln/transformers/models/pegasus/modeling_pegasus.py
@@ -0,0 +1,71 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from typing import TYPE_CHECKING, Any, Callable
+
+ from transformers import PegasusForConditionalGeneration, PreTrainedModel
+
+ from ....utils.logging import get_logger
+ from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
+ from ...models.seq2seq import RBLNModelForSeq2SeqLM
+ from .configuration_pegasus import RBLNPegasusForConditionalGenerationConfig
+ from .pegasus_architecture import PegasusWrapper
+
+
+ logger = get_logger()
+
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedModel
+
+
+ class RBLNPegasusModel(RBLNTransformerEncoderForFeatureExtraction):
+     """
+     RBLN-optimized PEGASUS model for feature extraction tasks.
+
+     This class provides hardware-accelerated inference for PEGASUS encoder models
+     on RBLN devices, optimized for feature extraction use cases.
+     """
+
+     rbln_model_input_names = ["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask"]
+
+
+ class RBLNPegasusForConditionalGeneration(RBLNModelForSeq2SeqLM):
+     """
+     RBLN-optimized PEGASUS model for conditional text generation tasks.
+
+     This class provides hardware-accelerated inference for PEGASUS models
+     on RBLN devices, supporting sequence-to-sequence generation tasks
+     such as summarization and translation.
+     """
+
+     support_causal_attn = True
+
+     @classmethod
+     def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNPegasusForConditionalGenerationConfig):
+         return PegasusWrapper(
+             model, enc_max_seq_len=rbln_config.enc_max_seq_len, use_attention_mask=rbln_config.use_attention_mask
+         )
+
+     def __getattr__(self, __name: str) -> Any:
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(PegasusForConditionalGeneration, __name)
+
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+
+         return val
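
The `__getattr__` above lazily borrows any attribute the RBLN class does not define from `PegasusForConditionalGeneration`, rebinding `self` so the borrowed methods run against the RBLN instance. A self-contained sketch of the same redirect idiom; all names here are illustrative, not part of the package:

```python
import inspect
from typing import Any, Callable


class Original:
    def greet(self, name: str) -> str:
        return f"hello {name} from {type(self).__name__}"


class Proxy:
    # __getattr__ only fires when normal lookup fails, so attributes
    # defined on Proxy itself always take precedence over borrowed ones.
    def __getattr__(self, __name: str) -> Any:
        def redirect(func):
            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

        val = getattr(Original, __name)  # plain function off the class
        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
            return redirect(val)  # rebind it to this instance
        return val


print(Proxy().greet("rbln"))  # -> hello rbln from Proxy
```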
optimum/rbln/transformers/models/pegasus/pegasus_architecture.py
@@ -0,0 +1,161 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Tuple
+
+ import torch
+ from torch import nn
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
+ from transformers.utils import logging
+
+ from ..seq2seq.seq2seq_architecture import (
+     Seq2SeqCrossAttention,
+     Seq2SeqDecoder,
+     Seq2SeqDecoderLayer,
+     Seq2SeqDecoderWrapper,
+     Seq2SeqEncoderWrapper,
+     Seq2SeqForConditionalGeneration,
+     Seq2SeqSelfAttention,
+ )
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class PegasusWrapper:
+     def __init__(self, model: nn.Module, enc_max_seq_len: int, use_attention_mask: bool):
+         self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
+         self.decoder = PegasusDecoderWrapper(model, use_attention_mask=use_attention_mask)
+
+
+ class PegasusDecoderWrapper(Seq2SeqDecoderWrapper):
+     def convert_to_rbln_conditional_generation(self, model: nn.Module):
+         new_layers = []
+         for layer in model.get_decoder().layers:
+             self_attn = PegasusSelfAttention(layer.self_attn, use_attention_mask=self.use_attention_mask)
+             cross_attn = PegasusCrossAttention(layer.encoder_attn)
+             new_layers.append(PegasusDecoderLayer(layer, self_attn, cross_attn))
+
+         decoder_model = PegasusDecoder(model.get_decoder(), new_layers)
+         new_model = PegasusForConditionalGeneration(model, decoder_model)
+
+         return new_model
+
+
+ class PegasusForConditionalGeneration(Seq2SeqForConditionalGeneration):
+     pass
+
+
+ class PegasusDecoder(Seq2SeqDecoder):
+     has_pos_emb = True
+
+     def __post_init__(self):
+         self.embed_positions = self._original_mod.embed_positions
+         self.embed_scale = getattr(self._original_mod, "embed_scale", None)
+         self.final_layer_norm = getattr(self._original_mod, "layer_norm", None)
+
+     def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
+         if attention_mask is not None:
+             attention_mask = attention_mask[:, None, None, :]
+         encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)
+
+         return attention_mask, encoder_attention_mask
+
+     def apply_position_embedding(self, inputs_embeds, cache_position):
+         hidden_all = []
+         for i in range(inputs_embeds.shape[0]):
+             positions_idx = cache_position[i]
+             position_weight = self.embed_positions.weight
+             position = position_weight[positions_idx]
+             batch_hidden = position + inputs_embeds[i]
+             hidden_all.append(batch_hidden)
+         hidden_states = torch.stack(hidden_all, dim=0)
+
+         return hidden_states
+
+     def get_embedding(self):
+         if self.embed_scale is not None:
+             return lambda x: self.embed_tokens(x) * self.embed_scale
+         else:
+             return self.embed_tokens
+
+
+ class PegasusLayerFF(nn.Module):
+     def __init__(self, decoder_layer):
+         super().__init__()
+         self.fc1 = decoder_layer.fc1
+         self.fc2 = decoder_layer.fc2
+         self.activation_fn = decoder_layer.activation_fn
+         self.layer_norm = decoder_layer.final_layer_norm
+
+     def forward(self, hidden_states):
+         # Residual connection around the pre-LN feed-forward block
+         residual = hidden_states
+         hidden_states = self.layer_norm(hidden_states)
+         hidden_states = self.activation_fn(self.fc1(hidden_states))
+         hidden_states = self.fc2(hidden_states)
+         hidden_states = residual + hidden_states
+         return hidden_states
+
+
+ class PegasusDecoderLayer(Seq2SeqDecoderLayer):
+     def __post_init__(self):
+         self.self_attn_layer_norm = self._original_mod.self_attn_layer_norm
+         self.encoder_attn = self._original_mod.encoder_attn
+         self.encoder_attn_layer_norm = self._original_mod.encoder_attn_layer_norm
+         self.ff_layer = PegasusLayerFF(self._original_mod)
+
+     def pre_self_attn_layer_norm(self, hidden_states):
+         return self.self_attn_layer_norm(hidden_states)
+
+     def post_self_attn_layer_norm(self, hidden_states):
+         return hidden_states
+
+     def pre_cross_attn_layer_norm(self, hidden_states):
+         return self.encoder_attn_layer_norm(hidden_states)
+
+     def post_cross_attn_layer_norm(self, hidden_states):
+         return hidden_states
+
+
+ class PegasusSelfAttention(Seq2SeqSelfAttention):
+     def __post_init__(self, use_attention_mask: bool = True):
+         self.q_proj = self._original_mod.q_proj
+         self.k_proj = self._original_mod.k_proj
+         self.v_proj = self._original_mod.v_proj
+         self.out_proj = self._original_mod.out_proj
+         self.num_heads = self._original_mod.num_heads
+         self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+         self.scaling = self.head_dim**-0.5
+         if use_attention_mask:
+             self.attn_decode = torch.ops.rbln_custom_ops.paged_attn_decode
+         else:
+             self.attn_decode = torch.ops.rbln_custom_ops.paged_causal_attn_decode
+
+     def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         query_states = self.q_proj(hidden_states) * self.scaling
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+         return query_states, key_states, value_states
+
+
+ class PegasusCrossAttention(Seq2SeqCrossAttention):
+     def __post_init__(self):
+         self.q_proj = self._original_mod.q_proj
+         self.k_proj = self._original_mod.k_proj
+         self.v_proj = self._original_mod.v_proj
+         self.out_proj = self._original_mod.out_proj
+         self.num_heads = self._original_mod.num_heads
+         self.head_dim = self._original_mod.embed_dim // self._original_mod.num_heads
+         self.embed_dim = self._original_mod.embed_dim
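
`PegasusDecoder.apply_position_embedding` gathers each batch row's learned position vectors in a Python loop; since batched fancy indexing performs the same row gather in one shot, the loop is numerically identical to a single vectorized expression. A quick standalone check of that equivalence, using plain tensors rather than the RBLN runtime:

```python
import torch

batch, seq, dim, max_pos = 2, 3, 4, 16
weight = torch.randn(max_pos, dim)                    # stand-in for embed_positions.weight
inputs_embeds = torch.randn(batch, seq, dim)
cache_position = torch.randint(0, max_pos, (batch, seq))

# Per-row loop, as written in apply_position_embedding
looped = torch.stack(
    [weight[cache_position[i]] + inputs_embeds[i] for i in range(batch)], dim=0
)

# Batched indexing gathers the same rows at once
vectorized = weight[cache_position] + inputs_embeds

assert torch.equal(looped, vectorized)
```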
optimum/rbln/transformers/models/phi/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_phi import RBLNPhiForCausalLMConfig, RBLNPhiModelConfig
+ from .modeling_phi import RBLNPhiForCausalLM, RBLNPhiModel
optimum/rbln/transformers/models/phi/configuration_phi.py
@@ -0,0 +1,50 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+
+
+ class RBLNPhiForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+     """
+     Configuration class for RBLNPhiForCausalLM.
+
+     This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+     Example usage:
+     ```python
+     from optimum.rbln import RBLNPhiForCausalLM, RBLNPhiForCausalLMConfig
+
+     # Create a configuration object
+     config = RBLNPhiForCausalLMConfig(
+         batch_size=1,
+         max_seq_len=4096,
+         tensor_parallel_size=4
+     )
+
+     # Use the configuration with from_pretrained
+     model = RBLNPhiForCausalLM.from_pretrained(
+         "microsoft/phi-2",
+         export=True,
+         rbln_config=config
+     )
+     ```
+     """
+
+
+ class RBLNPhiModelConfig(RBLNDecoderOnlyModelConfig):
+     """
+     Configuration class for RBLNPhiModel.
+
+     This class is an alias of RBLNDecoderOnlyModelConfig.
+     """
optimum/rbln/transformers/models/phi/modeling_phi.py
@@ -0,0 +1,92 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ....utils import logging
+ from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
+ from .phi_architecture import PhiWrapper
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class RBLNPhiForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+     """
+     The Phi Model transformer with a language modeling head (linear layer) on top.
+     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+     A class to convert and run a pre-trained transformers-based PhiForCausalLM model on RBLN devices.
+     It implements the methods to convert a pre-trained transformers PhiForCausalLM model into an RBLN transformer model by:
+
+     - transferring the checkpoint weights of the original into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
+
+     **Configuration:**
+     This model uses [`RBLNPhiForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+     the `rbln_config` parameter should be an instance of [`RBLNPhiForCausalLMConfig`] or a dictionary conforming to its structure.
+
+     See the [`RBLNPhiForCausalLMConfig`] class for all available configuration options.
+
+     Examples:
+         ```python
+         from optimum.rbln import RBLNPhiForCausalLM
+
+         # Simple usage using rbln_* arguments
+         # `max_seq_len` is automatically inferred from the model config
+         model = RBLNPhiForCausalLM.from_pretrained(
+             "microsoft/phi-2",
+             export=True,
+             rbln_batch_size=1,
+             rbln_tensor_parallel_size=4,
+         )
+
+         # Using a config dictionary
+         rbln_config = {
+             "batch_size": 1,
+             "max_seq_len": 4096,
+             "tensor_parallel_size": 4,
+         }
+         model = RBLNPhiForCausalLM.from_pretrained(
+             "microsoft/phi-2",
+             export=True,
+             rbln_config=rbln_config
+         )
+
+         # Using an RBLNPhiForCausalLMConfig instance (recommended for type checking)
+         from optimum.rbln import RBLNPhiForCausalLMConfig
+
+         config = RBLNPhiForCausalLMConfig(
+             batch_size=1,
+             max_seq_len=4096,
+             tensor_parallel_size=4
+         )
+         model = RBLNPhiForCausalLM.from_pretrained(
+             "microsoft/phi-2",
+             export=True,
+             rbln_config=config
+         )
+         ```
+     """
+
+     _decoder_wrapper_cls = PhiWrapper
+
+
+ class RBLNPhiModel(RBLNDecoderOnlyModel):
+     """
+     The Phi Model transformer without a language modeling head.
+     This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     """
+
+     _decoder_wrapper_cls = PhiWrapper
optimum/rbln/transformers/models/phi/phi_architecture.py
@@ -0,0 +1,115 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, Optional, Tuple, Union
+
+ import torch
+ from transformers import PhiForCausalLM
+
+ from ..decoderonly.decoderonly_architecture import (
+     DecoderOnlyAttention,
+     DecoderOnlyLayer,
+     DecoderOnlyModel,
+     DecoderOnlyWrapper,
+     apply_rotary_pos_emb_partial,
+ )
+
+
+ if TYPE_CHECKING:
+     from transformers import PhiForCausalLM, PhiModel
+
+
+ class PhiWrapper(DecoderOnlyWrapper):
+     def get_rbln_attn_class(self):
+         return PhiAttention
+
+     def get_rbln_layer_class(self):
+         return PhiLayer
+
+     def get_rbln_model_class(self):
+         return PhiModel
+
+     def get_model_layer(self, model: Union["PhiForCausalLM", "PhiModel"]):
+         return model.model if self.is_causal_lm else model
+
+     def get_decoder_layers(self, model: Union["PhiForCausalLM", "PhiModel"]):
+         return model.model.layers if self.is_causal_lm else model.layers
+
+
+ class PhiAttention(DecoderOnlyAttention):
+     def __post_init__(self):
+         self.q_proj = self._original_mod.q_proj
+         self.k_proj = self._original_mod.k_proj
+         self.v_proj = self._original_mod.v_proj
+         self.o_proj = self._original_mod.dense
+         self.qk_layernorm = self._original_mod.qk_layernorm
+         self.rotary_ndims = self._original_mod.rotary_ndims
+
+     def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         if lora_int_id is not None:
+             raise NotImplementedError("LoRA is not supported for PhiAttention")
+
+         query_states = self.q_proj(hidden_states)
+         key_states = self.k_proj(hidden_states)
+         value_states = self.v_proj(hidden_states)
+
+         if self.qk_layernorm:
+             query_states = self._original_mod.q_layernorm(query_states)
+             key_states = self._original_mod.k_layernorm(key_states)
+
+         return query_states, key_states, value_states
+
+     def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+         return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=self.rotary_ndims)
+
+
+ class PhiLayer(DecoderOnlyLayer):
+     def get_post_attention_layernorm(self):
+         raise NotImplementedError
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         seq_positions: torch.LongTensor,
+         past_key_values: Tuple[Tuple[torch.Tensor]],
+         cos: Optional[torch.Tensor] = None,
+         sin: Optional[torch.Tensor] = None,
+         block_tables: Optional[torch.Tensor] = None,
+         lora_int_id: Optional[torch.Tensor] = None,
+     ):
+         residual = hidden_states
+
+         hidden_states = self.get_pre_attention_layernorm()(hidden_states)
+
+         attn_output = self.self_attn(
+             hidden_states=hidden_states,
+             attention_mask=attention_mask,
+             seq_positions=seq_positions,
+             past_key_values=past_key_values,
+             cos=cos,
+             sin=sin,
+             block_tables=block_tables,
+         )
+
+         feed_forward_hidden_states = self._original_mod.mlp(hidden_states)
+
+         hidden_states = attn_output + feed_forward_hidden_states + residual
+
+         return hidden_states
+
+
+ class PhiModel(DecoderOnlyModel):
+     def get_last_layernorm(self):
+         return self._original_mod.final_layernorm
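
`PhiLayer.forward` encodes Phi's parallel-residual layout: a single pre-attention layernorm feeds both the attention branch and the MLP, and the two branch outputs are summed with the original residual (which is also why `get_post_attention_layernorm` raises `NotImplementedError`). A schematic contrast with the conventional sequential pre-LN layer, using plain callables as stand-ins for the real modules:

```python
import torch


def sequential_block(x, ln1, attn, ln2, mlp):
    # Conventional pre-LN layer: the MLP sees the attention output
    x = x + attn(ln1(x))
    x = x + mlp(ln2(x))
    return x


def parallel_block(x, ln, attn, mlp):
    # Phi-style layer, as in PhiLayer.forward: one layernorm feeds
    # both branches, and their outputs join the residual in one sum
    h = ln(x)
    return x + attn(h) + mlp(h)


# Toy shape check with stand-in modules
x = torch.randn(2, 8)
ln, lin = torch.nn.LayerNorm(8), torch.nn.Linear(8, 8)
print(parallel_block(x, ln, lin, lin).shape)  # torch.Size([2, 8])
```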
optimum/rbln/transformers/models/pixtral/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_pixtral import RBLNPixtralVisionModelConfig
+ from .modeling_pixtral import RBLNPixtralVisionModel