optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/midm/midm_architecture.py
@@ -0,0 +1,144 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import TYPE_CHECKING, Tuple
+
+ import torch
+ import torch.nn as nn
+
+ from ..decoderonly.decoderonly_architecture import (
+     DecoderOnlyAttention,
+     DecoderOnlyLayer,
+     DecoderOnlyModel,
+     DecoderOnlyWrapper,
+     apply_rotary_pos_emb_partial,
+     rotate_half,
+ )
+
+
+ if TYPE_CHECKING:
+     from transformers import PreTrainedModel as MidmLMHeadModel
+
+
+ def apply_rotary_to_tensor(tensor, cos, sin, rot_dim):
+     """Applies rotary position embedding to the specified dimension of the tensor."""
+     tensor_, tensor_pass = tensor[..., :rot_dim], tensor[..., rot_dim:]
+     tensor_embed = (tensor_ * cos) + (rotate_half(tensor_) * sin)
+     return torch.cat((tensor_embed, tensor_pass), dim=-1)
+
+
+ def apply_rotary_pos_emb(q, k, cos, sin):
+     """Applies Rotary Position Embedding to the query and key tensors."""
+     rot_dim = cos.shape[-1]
+     q_embed = apply_rotary_to_tensor(q, cos, sin, rot_dim)
+     k_embed = apply_rotary_to_tensor(k, cos, sin, rot_dim)
+     return q_embed, k_embed
+
+
+ class MidmLMHeadModelWrapper(DecoderOnlyWrapper):
+     def get_rotary_emb(self, max_seq_len):
+         self.config.rope_theta = 10000
+         self.config.head_dim = self.config.n_embd // self.config.n_head
+         self.config.partial_rotary_factor = self.config.rotary_percentage
+         return super().get_rotary_emb(max_seq_len=max_seq_len)
+
+     def get_rbln_attn_class(self):
+         return MidmAttention
+
+     def get_rbln_layer_class(self):
+         return MidmLayer
+
+     def get_rbln_model_class(self):
+         return MidmModel
+
+     def get_model_layer(self, causal_lm: "MidmLMHeadModel"):
+         return causal_lm.transformer
+
+     def get_decoder_layers(self, causal_lm: "MidmLMHeadModel"):
+         return causal_lm.transformer.h
+
+
+ class MidmModel(DecoderOnlyModel):
+     def get_layernorm1p(self, module: nn.LayerNorm):
+         def layernorm1p(input: torch.Tensor):
+             """Applies Layer Normalization with a slight modification on the weights."""
+             return torch.nn.functional.layer_norm(
+                 input, module.normalized_shape, module.weight + 1, module.bias, module.eps
+             )
+
+         return layernorm1p
+
+     def get_last_layernorm(self) -> nn.LayerNorm:
+         if self._original_mod.use_layernorm1p:
+             return self.get_layernorm1p(self._original_mod.ln_f)
+         else:
+             return self._original_mod.ln_f
+
+     def get_embedding(self) -> nn.Embedding:
+         return self._original_mod.wte
+
+     def get_pos_embedding(self) -> nn.Embedding:
+         return self._original_mod.wpe
+
+
+ class MidmLayer(DecoderOnlyLayer):
+     def get_layernorm1p(self, module: nn.LayerNorm):
+         def layernorm1p(input: torch.Tensor):
+             """Applies Layer Normalization with a slight modification on the weights."""
+             return torch.nn.functional.layer_norm(
+                 input, module.normalized_shape, module.weight + 1, module.bias, module.eps
+             )
+
+         return layernorm1p
+
+     def get_pre_attention_layernorm(self) -> nn.LayerNorm:
+         if self._original_mod.use_layernorm1p:
+             return self.get_layernorm1p(self._original_mod.ln_1)
+         else:
+             return self._original_mod.ln_1
+
+     def get_post_attention_layernorm(self) -> nn.LayerNorm:
+         if self._original_mod.use_layernorm1p:
+             return self.get_layernorm1p(self._original_mod.ln_2)
+         else:
+             return self._original_mod.ln_2
+
+
+ class MidmAttention(DecoderOnlyAttention):
+     def __post_init__(self):
+         self.c_attn = self._original_mod.c_attn
+         self.o_proj = self._original_mod.c_proj
+         self.split_size = self._original_mod.split_size
+         self.num_key_value_heads = self._original_mod.num_heads
+
+     def projection(self, hidden_states, lora_int_id) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         if lora_int_id is not None:
+             raise NotImplementedError("LoRA is not supported for MidmAttention")
+
+         query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
+         return query_states, key_states, value_states
+
+     def get_attn_scale(self):
+         scale = 1.0
+         if self._original_mod.scale_attn_weights:
+             scale /= math.sqrt(self.head_dim)
+
+         if self._original_mod.scale_attn_by_inverse_layer_idx and not self._original_mod.scale_qk_by_inverse_layer_idx:
+             scale /= 1 + self.layer_idx
+
+         return scale
+
+     def apply_rotary_pos_embed(self, query_states, key_states, cos, sin):
+         return apply_rotary_pos_emb_partial(query_states, key_states, cos, sin, ndim=cos.shape[-1])
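
The Mi:dm-specific pieces in this file are the partial rotary embedding (only the first `rot_dim` channels of each head are rotated, controlled by `rotary_percentage`) and the `layernorm1p` variant, which evaluates LayerNorm with `weight + 1` so the stored weights are offsets from an identity gain. The following self-contained sketch illustrates both tricks numerically; it reimplements `rotate_half` locally rather than importing from the package, and all names and shapes here are illustrative:

```python
import torch
import torch.nn.functional as F

def rotate_half(x):
    # Standard RoPE helper: swap and negate the two halves of the last dim.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_partial(tensor, cos, sin, rot_dim):
    # Rotate only the first rot_dim channels; pass the rest through unchanged.
    t_rot, t_pass = tensor[..., :rot_dim], tensor[..., rot_dim:]
    t_rot = (t_rot * cos) + (rotate_half(t_rot) * sin)
    return torch.cat((t_rot, t_pass), dim=-1)

batch, heads, seq, head_dim = 1, 2, 4, 8
rotary_percentage = 0.5                      # plays the role of partial_rotary_factor
rot_dim = int(head_dim * rotary_percentage)  # -> 4

q = torch.randn(batch, heads, seq, head_dim)
pos = torch.arange(seq).float()
inv_freq = 1.0 / (10000 ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
freqs = torch.outer(pos, inv_freq)           # (seq, rot_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)      # (seq, rot_dim)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]

q_embed = apply_rotary_partial(q, cos, sin, rot_dim)
# The non-rotated tail is untouched:
assert torch.equal(q_embed[..., rot_dim:], q[..., rot_dim:])

# layernorm1p: LayerNorm evaluated with (weight + 1), i.e. the weights store an
# offset from 1, so a zero-initialized weight behaves like an identity gain.
ln = torch.nn.LayerNorm(head_dim)
torch.nn.init.zeros_(ln.weight)
x = torch.randn(seq, head_dim)
out = F.layer_norm(x, ln.normalized_shape, ln.weight + 1, ln.bias, ln.eps)
assert torch.allclose(out, F.layer_norm(x, ln.normalized_shape, torch.ones(head_dim), ln.bias, ln.eps))
```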
optimum/rbln/transformers/models/midm/modeling_midm.py
@@ -0,0 +1,144 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from pathlib import Path
+ from typing import Any, Callable, Dict, Optional, Union
+
+ from transformers import AutoModelForCausalLM
+ from transformers.generation.utils import GenerationMixin
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....utils import logging
+ from ..decoderonly import RBLNDecoderOnlyModelForCausalLM
+ from .midm_architecture import MidmLMHeadModelWrapper
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class RBLNMidmLMHeadModel(RBLNDecoderOnlyModelForCausalLM):
+     """
+     The MIDM Model transformer with a language modeling head (linear layer) on top.
+     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+     A class to convert and run pre-trained transformers based MidmForCausalLM model on RBLN devices.
+     It implements the methods to convert a pre-trained transformers MidmForCausalLM model into a RBLN transformer model by:
+
+     - transferring the checkpoint weights of the original into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
+
+     **Configuration:**
+     This model uses [`RBLNMidmLMHeadModelConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+     the `rbln_config` parameter should be an instance of [`RBLNMidmLMHeadModelConfig`] or a dictionary conforming to its structure.
+
+     See the [`RBLNMidmLMHeadModelConfig`] class for all available configuration options.
+
+     Examples:
+         ```python
+         from optimum.rbln import RBLNMidmLMHeadModel
+
+         # Simple usage using rbln_* arguments
+         # `max_seq_len` is automatically inferred from the model config
+         model = RBLNMidmLMHeadModel.from_pretrained(
+             "KT-AI/midm-bitext-S-7B-inst-v1",
+             export=True,
+             rbln_batch_size=1,
+             rbln_tensor_parallel_size=4,
+         )
+
+
+         # Using a config dictionary
+         rbln_config = {
+             "batch_size": 1,
+             "max_seq_len": 4096,
+             "tensor_parallel_size": 4,
+         }
+         model = RBLNMidmLMHeadModel.from_pretrained(
+             "KT-AI/midm-bitext-S-7B-inst-v1",
+             export=True,
+             rbln_config=rbln_config
+         )
+
+
+         # Using a RBLNMidmLMHeadModelConfig instance (recommended for type checking)
+         from optimum.rbln import RBLNMidmLMHeadModelConfig
+
+         config = RBLNMidmLMHeadModelConfig(
+             batch_size=1,
+             max_seq_len=4096,
+             tensor_parallel_size=4
+         )
+         model = RBLNMidmLMHeadModel.from_pretrained(
+             "KT-AI/midm-bitext-S-7B-inst-v1",
+             export=True,
+             rbln_config=config
+         )
+         ```
+     """
+
+     _decoder_wrapper_cls = MidmLMHeadModelWrapper
+     _hf_class = AutoModelForCausalLM
+     _supports_cache_class = True
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_id: Union[str, Path],
+         *,
+         export: Optional[bool] = None,
+         rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
+         trust_remote_code: Optional[bool] = None,
+         **kwargs: Any,
+     ):
+         """
+         The `from_pretrained()` function is utilized in its standard form as in the HuggingFace transformers library.
+         Users can use this function to load a pre-trained model from the HuggingFace library and convert it to an RBLN model to be run on RBLN NPUs.
+
+         Args:
+             model_id (Union[str, Path]): The model id of the pre-trained model to be loaded.
+                 It can be downloaded from the HuggingFace model hub or a local path, or a model id of a compiled model using the RBLN Compiler.
+             export (Optional[bool]): A boolean flag to indicate whether the model should be compiled.
+                 If None, it will be determined based on the existence of the compiled model files in the model_id.
+             rbln_config (Optional[Union[Dict, RBLNModelConfig]]): Configuration for RBLN model compilation and runtime.
+                 This can be provided as a dictionary or an instance of the model's configuration class (e.g., `RBLNMidmLMHeadModelConfig` for Mi:dm models).
+                 For detailed configuration options, see the specific model's configuration class documentation.
+             trust_remote_code (Optional[bool]): Whether or not to trust the remote code when loading a model from the Hub.
+             kwargs: Additional keyword arguments. Arguments with the prefix `rbln_` are passed to rbln_config, while the remaining arguments are passed to the HuggingFace library.
+
+         Returns:
+             (RBLNModel): An RBLN model instance ready for inference on RBLN NPU devices.
+         """
+
+         if trust_remote_code is not None:
+             kwargs["trust_remote_code"] = trust_remote_code
+         elif "trust_remote_code" not in kwargs:
+             kwargs["trust_remote_code"] = True
+
+         return super().from_pretrained(
+             model_id=model_id,
+             export=export,
+             rbln_config=rbln_config,
+             **kwargs,
+         )
+
+     def __getattr__(self, __name: str) -> Any:
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(GenerationMixin, __name)
+
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+         return val
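
The `__getattr__` override above is what lets the compiled model expose `generate()` and related methods: any attribute not found on the instance is looked up on `GenerationMixin`, and methods expecting `self` are rebound to the RBLN model. A minimal standalone sketch of that delegation pattern, using a stand-in mixin rather than the real transformers class:

```python
import inspect
from typing import Any, Callable

class GenerationLike:
    # Stand-in for transformers' GenerationMixin in this sketch.
    version = "0.1"

    def greet(self, name: str) -> str:
        return f"{self.prefix} {name}"

class Delegator:
    prefix = "hello"

    def __getattr__(self, __name: str) -> Any:
        # Called only when normal lookup fails; fetch the attribute from the
        # mixin and, if it is a method expecting `self`, bind this instance.
        def redirect(func):
            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

        val = getattr(GenerationLike, __name)
        if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
            return redirect(val)
        return val

d = Delegator()
print(d.greet("world"))    # "hello world" -- method resolved via GenerationLike
print(d.version)           # "0.1" -- plain attributes pass through unchanged
```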
optimum/rbln/transformers/models/mistral/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_mistral import RBLNMistralForCausalLMConfig, RBLNMistralModelConfig
+ from .modeling_mistral import RBLNMistralForCausalLM, RBLNMistralModel
optimum/rbln/transformers/models/mistral/configuration_mistral.py
@@ -0,0 +1,50 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+
+
+ class RBLNMistralForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+     """
+     Configuration class for RBLN Mistral models.
+
+     This class is an alias of RBLNDecoderOnlyModelForCausalLMConfig.
+
+     Example usage:
+     ```python
+     from optimum.rbln import RBLNMistralForCausalLM, RBLNMistralForCausalLMConfig
+
+     # Create a configuration object
+     config = RBLNMistralForCausalLMConfig(
+         batch_size=1,
+         max_seq_len=4096,
+         tensor_parallel_size=4
+     )
+
+     # Use the configuration with from_pretrained
+     model = RBLNMistralForCausalLM.from_pretrained(
+         "mistralai/Mistral-7B-v0.1",
+         export=True,
+         rbln_config=config
+     )
+     ```
+     """
+
+
+ class RBLNMistralModelConfig(RBLNDecoderOnlyModelConfig):
+     """
+     Configuration class for RBLN Mistral models.
+
+     This class is an alias of RBLNDecoderOnlyModelConfig.
+     """
optimum/rbln/transformers/models/mistral/mistral_architecture.py
@@ -0,0 +1,19 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ..decoderonly.decoderonly_architecture import DecoderOnlyWrapper
+
+
+ class MistralWrapper(DecoderOnlyWrapper):
+     pass
optimum/rbln/transformers/models/mistral/modeling_mistral.py
@@ -0,0 +1,115 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from transformers import PretrainedConfig
+
+ from ....utils import logging
+ from ...models.decoderonly import (
+     RBLNDecoderOnlyModel,
+     RBLNDecoderOnlyModelForCausalLM,
+     RBLNDecoderOnlyModelForCausalLMConfig,
+ )
+ from .mistral_architecture import MistralWrapper
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class RBLNMistralForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+     """
+     The Mistral Model transformer with a language modeling head (linear layer) on top.
+     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+     A class to convert and run pre-trained transformers based MistralForCausalLM model on RBLN devices.
+     It implements the methods to convert a pre-trained transformers MistralForCausalLM model into a RBLN transformer model by:
+     - transferring the checkpoint weights of the original into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
+
+     **Configuration:**
+     This model uses [`RBLNMistralForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+     the `rbln_config` parameter should be an instance of [`RBLNMistralForCausalLMConfig`] or a dictionary conforming to its structure.
+
+     See the [`RBLNMistralForCausalLMConfig`] class for all available configuration options.
+
+     Examples:
+         ```python
+         from optimum.rbln import RBLNMistralForCausalLM
+
+         # Simple usage using rbln_* arguments
+         # `max_seq_len` is automatically inferred from the model config
+         model = RBLNMistralForCausalLM.from_pretrained(
+             "mistralai/Mistral-7B-v0.1",
+             export=True,
+             rbln_batch_size=1,
+             rbln_tensor_parallel_size=4,
+         )
+
+         # Using a config dictionary
+         rbln_config = {
+             "batch_size": 1,
+             "max_seq_len": 4096,
+             "tensor_parallel_size": 4,
+         }
+         model = RBLNMistralForCausalLM.from_pretrained(
+             "mistralai/Mistral-7B-v0.1",
+             export=True,
+             rbln_config=rbln_config
+         )
+
+         # Using a RBLNMistralForCausalLMConfig instance (recommended for type checking)
+         from optimum.rbln import RBLNMistralForCausalLMConfig
+
+         config = RBLNMistralForCausalLMConfig(
+             batch_size=1,
+             max_seq_len=4096,
+             tensor_parallel_size=4
+         )
+         model = RBLNMistralForCausalLM.from_pretrained(
+             "mistralai/Mistral-7B-v0.1",
+             export=True,
+             rbln_config=config
+         )
+         ```
+     """
+
+     _decoder_wrapper_cls = MistralWrapper
+
+     @classmethod
+     def _update_sliding_window_config(
+         cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+     ):
+         rbln_config.cache_impl = "sliding_window"
+         rbln_config.sliding_window = model_config.sliding_window
+         rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
+
+         return rbln_config
+
+
+ class RBLNMistralModel(RBLNDecoderOnlyModel):
+     """
+     The Mistral Model transformer without a language modeling head.
+     This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     """
+
+     _decoder_wrapper_cls = MistralWrapper
+
+     @classmethod
+     def _update_sliding_window_config(
+         cls, model_config: PretrainedConfig, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig
+     ):
+         rbln_config.cache_impl = "sliding_window"
+         rbln_config.sliding_window = model_config.sliding_window
+         rbln_config.sliding_window_layers = list(range(model_config.num_hidden_layers))
+
+         return rbln_config
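
`_update_sliding_window_config` switches the KV cache to a sliding-window layout sized by the checkpoint's `sliding_window` value and applies it to every layer, which caps cache memory regardless of generated length. The toy ring-buffer cache below illustrates that idea; it is a conceptual sketch only, not the RBLN runtime's actual cache implementation:

```python
import torch

class SlidingWindowKVCache:
    """Toy per-layer KV ring buffer: keeps only the last `window` positions."""

    def __init__(self, window: int, num_heads: int, head_dim: int):
        self.window = window
        self.k = torch.zeros(num_heads, window, head_dim)
        self.v = torch.zeros(num_heads, window, head_dim)
        self.pos = 0  # total tokens seen so far

    def update(self, k_new: torch.Tensor, v_new: torch.Tensor):
        # k_new/v_new: (num_heads, 1, head_dim) for one decoded token.
        slot = self.pos % self.window
        self.k[:, slot] = k_new[:, 0]
        self.v[:, slot] = v_new[:, 0]
        self.pos += 1

    def get(self):
        # Return the cached keys/values in chronological order.
        n = min(self.pos, self.window)
        if self.pos <= self.window:
            return self.k[:, :n], self.v[:, :n]
        start = self.pos % self.window
        idx = (torch.arange(self.window) + start) % self.window
        return self.k[:, idx], self.v[:, idx]

cache = SlidingWindowKVCache(window=4, num_heads=1, head_dim=2)
for t in range(6):
    kv = torch.full((1, 1, 2), float(t))
    cache.update(kv, kv)
k, _ = cache.get()
print(k[0, :, 0])  # tensor([2., 3., 4., 5.]) -- only the last 4 positions survive
```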
optimum/rbln/transformers/models/opt/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_opt import RBLNOPTForCausalLMConfig, RBLNOPTModelConfig
+ from .modeling_opt import RBLNOPTForCausalLM, RBLNOPTModel
optimum/rbln/transformers/models/opt/configuration_opt.py
@@ -0,0 +1,29 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ..decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelConfig, RBLNDecoderOnlyModelForCausalLMConfig
+
+
+ class RBLNOPTForCausalLMConfig(RBLNDecoderOnlyModelForCausalLMConfig):
+     """
+     Configuration class for OPT causal language model.
+     Inherits from RBLNDecoderOnlyModelForCausalLMConfig with no additional parameters.
+     """
+
+
+ class RBLNOPTModelConfig(RBLNDecoderOnlyModelConfig):
+     """
+     Configuration class for OPT model.
+     Inherits from RBLNDecoderOnlyModelConfig with no additional parameters.
+     """
optimum/rbln/transformers/models/opt/modeling_opt.py
@@ -0,0 +1,102 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch.nn as nn
+ from transformers import PreTrainedModel
+
+ from ....utils import logging
+ from ...models.decoderonly import RBLNDecoderOnlyModel, RBLNDecoderOnlyModelForCausalLM
+ from ...models.decoderonly.configuration_decoderonly import RBLNDecoderOnlyModelForCausalLMConfig
+ from .opt_architecture import OPTWrapper
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class MLP(nn.Module):
+     def __init__(self, fc1, fc2, activation_fn):
+         super(MLP, self).__init__()
+         self.fc1 = fc1
+         self.fc2 = fc2
+         self.activation_fn = activation_fn
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.activation_fn(x)
+         x = self.fc2(x)
+         return x
+
+
+ class RBLNOPTForCausalLM(RBLNDecoderOnlyModelForCausalLM):
+     """
+     The OPT Model transformer with a language modeling head (linear layer) on top.
+     This model inherits from [`RBLNDecoderOnlyModelForCausalLM`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+     A class to convert and run pre-trained transformers based OPTForCausalLM model on RBLN devices.
+     It implements the methods to convert a pre-trained transformers OPTForCausalLM model into a RBLN transformer model by:
+
+     - transferring the checkpoint weights of the original into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
+
+     **Configuration:**
+     This model uses [`RBLNOPTForCausalLMConfig`] for configuration. When calling methods like `from_pretrained` or `from_model`,
+     the `rbln_config` parameter should be an instance of [`RBLNOPTForCausalLMConfig`] or a dictionary conforming to its structure.
+
+     See the [`RBLNOPTForCausalLMConfig`] class for all available configuration options.
+     """
+
+     _decoder_wrapper_cls = OPTWrapper
+     _use_rotary_emb = False
+
+     def modify_opt_decoder_layer(layer):
+         mlp = MLP(layer.fc1, layer.fc2, layer.activation_fn)
+         layer.mlp = mlp
+         del layer.fc1
+         del layer.fc2
+         del layer.activation_fn
+
+         return layer
+
+     @classmethod
+     def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+         for i in range(len(model.model.decoder.layers)):
+             model.model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.model.decoder.layers[i])
+
+         return cls._decoder_wrapper_cls(model, rbln_config=rbln_config, use_rotary_emb=cls._use_rotary_emb).eval()
+
+
+ class RBLNOPTModel(RBLNDecoderOnlyModel):
+     """
+     The OPT Model transformer without a language modeling head.
+     This model inherits from [`RBLNDecoderOnlyModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+     """
+
+     _decoder_wrapper_cls = OPTWrapper
+     _use_rotary_emb = False
+
+     def modify_opt_decoder_layer(layer):
+         mlp = MLP(layer.fc1, layer.fc2, layer.activation_fn)
+         layer.mlp = mlp
+         del layer.fc1
+         del layer.fc2
+         del layer.activation_fn
+
+         return layer
+
+     @classmethod
+     def _wrap_model_if_needed(cls, model: PreTrainedModel, rbln_config: RBLNDecoderOnlyModelForCausalLMConfig):
+         for i in range(len(model.decoder.layers)):
+             model.decoder.layers[i] = cls.modify_opt_decoder_layer(model.decoder.layers[i])
+
+         return cls._decoder_wrapper_cls(model, rbln_config=rbln_config, use_rotary_emb=cls._use_rotary_emb).eval()
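
Since OPT stores its feed-forward block as flat `fc1`/`activation_fn`/`fc2` attributes, `modify_opt_decoder_layer` regroups them into a single `MLP` submodule before wrapping, so the feed-forward path is traced as one unit. A standalone sketch of that in-place rewrite on a toy layer (the attribute names mirror transformers' `OPTDecoderLayer`, but the class itself is a stand-in):

```python
import torch
import torch.nn as nn

class ToyDecoderLayer(nn.Module):
    # Mimics the flat fc1/activation_fn/fc2 layout of an OPT decoder layer.
    def __init__(self, d_model=8, d_ff=16):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.activation_fn = nn.ReLU()

class MLP(nn.Module):
    def __init__(self, fc1, fc2, activation_fn):
        super().__init__()
        self.fc1 = fc1
        self.fc2 = fc2
        self.activation_fn = activation_fn

    def forward(self, x):
        return self.fc2(self.activation_fn(self.fc1(x)))

def modify_decoder_layer(layer):
    # Group the three attributes under one submodule so the feed-forward
    # block appears as a single unit in the traced graph.
    layer.mlp = MLP(layer.fc1, layer.fc2, layer.activation_fn)
    del layer.fc1, layer.fc2, layer.activation_fn
    return layer

layer = modify_decoder_layer(ToyDecoderLayer())
out = layer.mlp(torch.randn(2, 8))
print(out.shape)              # torch.Size([2, 8])
print(hasattr(layer, "fc1"))  # False -- the flat attributes are gone
```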