optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. Click here for more details.

Files changed (264) hide show
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,267 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import importlib
15
+ import inspect
16
+ import warnings
17
+ from pathlib import Path
18
+ from typing import Any, Dict, Optional, Type, Union
19
+
20
+ from transformers import AutoConfig, PretrainedConfig, PreTrainedModel
21
+ from transformers.dynamic_module_utils import get_class_from_dynamic_module
22
+ from transformers.models.auto.auto_factory import _get_model_class
23
+
24
+ from optimum.rbln.configuration_utils import RBLNAutoConfig, RBLNModelConfig
25
+ from optimum.rbln.modeling_base import RBLNBaseModel
26
+ from optimum.rbln.utils.model_utils import (
27
+ MODEL_MAPPING,
28
+ convert_hf_to_rbln_model_name,
29
+ convert_rbln_to_hf_model_name,
30
+ get_rbln_model_cls,
31
+ )
32
+
33
+
34
class _BaseAutoModelClass:
    """Base class for RBLN auto-model classes.

    Mirrors HuggingFace's ``_BaseAutoModelClass``: it is never instantiated
    directly. Class methods resolve the concrete ``RBLN*`` model class for a
    given checkpoint (either by inspecting the HF config for uncompiled models,
    or by reading ``rbln_config.json`` for already-compiled artifacts) and
    delegate loading to that class.
    """

    # Subclasses bind this to a transformers mapping (config class -> HF model class).
    # NOTE(review): `_model_mapping_names` is also read below but is only defined
    # on subclasses — confirm every concrete auto class sets both.
    _model_mapping = None

    def __init__(self, *args, **kwargs):
        # Auto classes are factories only; direct construction is always an error.
        raise EnvironmentError(
            f"{self.__class__.__name__} is designed to be instantiated "
            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)`"
        )

    @classmethod
    def get_rbln_cls(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        *args: Any,
        export: Optional[bool] = None,
        **kwargs: Any,
    ):
        """
        Determine the appropriate RBLN model class based on the given model ID and configuration.

        Args:
            pretrained_model_name_or_path (str): Identifier or path to the pretrained model.
            export (Optional[bool]): Whether to infer the class based on HuggingFace (HF)
                architecture. When `None`, it is inferred by checking whether compiled
                RBLN artifacts exist at the given location.
            kwargs: Additional arguments for configuration and loading.

        Returns:
            RBLNBaseModel: The corresponding RBLN model class.

        Raises:
            ValueError: If the resolved architecture is not handled by this auto class.
            AttributeError: If the resolved class name is not exported by `optimum.rbln`.
        """
        if isinstance(pretrained_model_name_or_path, Path):
            pretrained_model_name_or_path = pretrained_model_name_or_path.as_posix()

        # Auto-detect export mode: if no compiled artifacts are found, we must compile.
        if export is None:
            export = not RBLNBaseModel._is_compiled(
                model_id=pretrained_model_name_or_path,
                token=kwargs.get("token"),
                revision=kwargs.get("revision"),
                force_download=kwargs.get("force_download", False),
                cache_dir=kwargs.get("cache_dir"),
                subfolder=kwargs.get("subfolder", ""),
                local_files_only=kwargs.get("local_files_only", False),
            )

        if export:
            # Uncompiled checkpoint: derive the RBLN class from the HF architecture.
            hf_model_class = cls.infer_hf_model_class(pretrained_model_name_or_path, **kwargs)
            rbln_class_name = convert_hf_to_rbln_model_name(hf_model_class.__name__)
        else:
            # Compiled artifact: the class name is recorded in rbln_config.json.
            rbln_class_name = cls.get_rbln_model_cls_name(pretrained_model_name_or_path, **kwargs)

        if convert_rbln_to_hf_model_name(rbln_class_name) not in cls._model_mapping_names.values():
            raise ValueError(
                f"The architecture '{rbln_class_name}' is not supported by the `{cls.__name__}.from_pretrained()` method. "
                "Please use the `from_pretrained()` method of the appropriate class to load this model, "
                f"or directly use `{rbln_class_name}.from_pretrained()`."
            )

        try:
            rbln_cls = get_rbln_model_cls(rbln_class_name)
        except AttributeError as e:
            raise AttributeError(
                f"Class '{rbln_class_name}' not found in 'optimum.rbln' module for model ID '{pretrained_model_name_or_path}'. "
                "Ensure that the class name is correctly mapped and available in the 'optimum.rbln' module."
            ) from e

        return rbln_cls

    @classmethod
    def infer_hf_model_class(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        *args: Any,
        **kwargs: Any,
    ):
        """
        Infer the HuggingFace model class based on the configuration or model name.

        Args:
            pretrained_model_name_or_path (str): Identifier or path to the pretrained model.
            kwargs: Additional arguments for configuration and loading.

        Returns:
            PreTrainedModel: The inferred HuggingFace model class.

        Raises:
            ValueError: If the configuration class is not covered by `_model_mapping`.
        """

        # Try to load configuration if provided or retrieve it from the model ID
        config = kwargs.pop("config", None)
        # SECURITY NOTE(review): trust_remote_code is forced on unconditionally here,
        # which allows arbitrary code from the hub to execute during class resolution.
        # Confirm this is intentional; consider honoring a caller-supplied value.
        kwargs.update({"trust_remote_code": True})
        kwargs["_from_auto"] = True

        # Load configuration if not already provided
        if not isinstance(config, PretrainedConfig):
            config, kwargs = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                return_unused_kwargs=True,
                **kwargs,
            )

        # Prefer a remote (dynamic-module) class if the config maps this auto class to one.
        has_remote_code = (
            hasattr(config, "auto_map") and convert_rbln_to_hf_model_name(cls.__name__) in config.auto_map
        )
        if has_remote_code:
            class_ref = config.auto_map[convert_rbln_to_hf_model_name(cls.__name__)]
            model_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
        elif type(config) in cls._model_mapping.keys():
            model_class = _get_model_class(config, cls._model_mapping)
        else:
            raise ValueError(
                f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
                f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
            )

        # `architectures` may be absent or empty on some configs; only warn about a
        # mismatch when the config actually declares an expected architecture.
        architectures = getattr(config, "architectures", None)
        if architectures and model_class.__name__ != architectures[0]:
            warnings.warn(
                f"`{cls.__name__}.from_pretrained()` is invoking `{convert_hf_to_rbln_model_name(model_class.__name__)}.from_pretrained()`, which does not match the "
                f"expected architecture `RBLN{architectures[0]}` from config. This mismatch could cause some operations to not be properly loaded "
                f"from the checkpoint, leading to potential unintended behavior. If this is not intentional, consider calling the "
                f"`from_pretrained()` method directly from the `RBLN{architectures[0]}` class instead.",
                UserWarning,
            )

        return model_class

    @classmethod
    def get_rbln_model_cls_name(cls, pretrained_model_name_or_path: Union[str, Path], **kwargs) -> str:
        """
        Retrieve the RBLN model class name recorded in a compiled model directory.

        Args:
            pretrained_model_name_or_path (str): Identifier of the model.
            kwargs: Additional arguments that match the parameters of `_load_compiled_model_dir`.

        Returns:
            str: The `rbln_model_cls_name` stored in the compiled model's `rbln_config`.
        """
        # Only forward kwargs that `_load_compiled_model_dir` actually accepts.
        sig = inspect.signature(RBLNBaseModel._load_compiled_model_dir)
        valid_params = sig.parameters.keys()
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}

        model_path_subfolder = RBLNBaseModel._load_compiled_model_dir(
            model_id=pretrained_model_name_or_path, **filtered_kwargs
        )
        rbln_config = RBLNAutoConfig.load(model_path_subfolder)

        return rbln_config.rbln_model_cls_name

    @classmethod
    def from_pretrained(
        cls,
        model_id: Union[str, Path],
        export: Optional[bool] = None,
        rbln_config: Optional[Union[Dict, RBLNModelConfig]] = None,
        **kwargs,
    ):
        """
        Load an RBLN-accelerated model from a pretrained checkpoint or a compiled RBLN artifact.

        This convenience method determines the concrete `RBLN*` model class that matches the
        underlying HuggingFace architecture and dispatches to that class's
        `from_pretrained()` implementation. Depending on whether a compiled RBLN folder is
        detected (or if `export=True` is passed), it will either:

        - Compile from a HuggingFace checkpoint to an RBLN model
        - Or load an already-compiled RBLN model directory/repository

        Args:
            model_id:
                HF repo id or local path. For compiled models, this should point to a directory
                (optionally under `subfolder`) that contains `*.rbln` files and `rbln_config.json`.
            export:
                Force compilation from a HuggingFace checkpoint. When `None`, this is inferred by
                checking whether compiled artifacts exist at `model_id`.
            rbln_config:
                RBLN compilation/runtime configuration. May be provided as a dictionary or as an
                instance of the specific model's config class (e.g., `RBLNLlamaForCausalLMConfig`).
            kwargs: Additional keyword arguments.
                - Arguments prefixed with `rbln_` are forwarded to the RBLN config.
                - Remaining arguments are forwarded to the HuggingFace loader (e.g., `revision`,
                  `token`, `trust_remote_code`, `cache_dir`, `subfolder`, `local_files_only`).

        Returns:
            An instantiated RBLN model ready for inference on RBLN NPUs.
        """
        rbln_cls = cls.get_rbln_cls(model_id, export=export, **kwargs)
        return rbln_cls.from_pretrained(model_id, export=export, rbln_config=rbln_config, **kwargs)

    @classmethod
    def from_model(
        cls,
        model: PreTrainedModel,
        config: Optional[PretrainedConfig] = None,
        rbln_config: Optional[Union[RBLNModelConfig, Dict]] = None,
        **kwargs: Any,
    ) -> RBLNBaseModel:
        """
        Convert and compile an in-memory HuggingFace model into an RBLN model.

        This method resolves the appropriate concrete `RBLN*` class from the input model's class
        name (e.g., `LlamaForCausalLM` -> `RBLNLlamaForCausalLM`) and then delegates to that
        class's `from_model()` implementation.

        Args:
            model: A HuggingFace model instance to convert.
            config: The configuration object associated with the model.
            rbln_config:
                RBLN compilation/runtime configuration. May be provided as a dictionary or as an
                instance of the specific model's config class.
            kwargs: Additional keyword arguments.
                - Arguments prefixed with `rbln_` are forwarded to the RBLN config.

        Returns:
            An instantiated RBLN model ready for inference on RBLN NPUs.
        """
        rbln_cls = get_rbln_model_cls(f"RBLN{model.__class__.__name__}")
        return rbln_cls.from_model(model, config=config, rbln_config=rbln_config, **kwargs)

    @staticmethod
    def register(rbln_cls: Type[RBLNBaseModel], exist_ok: bool = False):
        """
        Register a new RBLN model class.

        Args:
            rbln_cls (Type[RBLNBaseModel]): The RBLN model class to register.
            exist_ok (bool): Whether to allow registering an already registered model.

        Raises:
            ValueError: If `rbln_cls` is not an RBLNBaseModel subclass, or if it is
                already registered (either in MODEL_MAPPING or as a native export of
                `optimum.rbln`) and `exist_ok` is False.
        """
        if not issubclass(rbln_cls, RBLNBaseModel):
            raise ValueError("`rbln_cls` must be a subclass of RBLNBaseModel.")

        # A name may already be taken either by a prior registration or by a
        # class shipped natively in the optimum.rbln package.
        native_cls = getattr(importlib.import_module("optimum.rbln"), rbln_cls.__name__, None)
        if rbln_cls.__name__ in MODEL_MAPPING or native_cls is not None:
            if not exist_ok:
                raise ValueError(f"Model for {rbln_cls.__name__} already registered.")

        MODEL_MAPPING[rbln_cls.__name__] = rbln_cls
@@ -0,0 +1,162 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from transformers.models.auto.modeling_auto import (
16
+ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING,
17
+ MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES,
18
+ MODEL_FOR_CAUSAL_LM_MAPPING,
19
+ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
20
+ MODEL_FOR_CTC_MAPPING,
21
+ MODEL_FOR_CTC_MAPPING_NAMES,
22
+ MODEL_FOR_DEPTH_ESTIMATION_MAPPING,
23
+ MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES,
24
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
25
+ MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES,
26
+ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING,
27
+ MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES,
28
+ MODEL_FOR_MASKED_LM_MAPPING,
29
+ MODEL_FOR_MASKED_LM_MAPPING_NAMES,
30
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING,
31
+ MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
32
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
33
+ MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
34
+ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
35
+ MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES,
36
+ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
37
+ MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES,
38
+ MODEL_FOR_TEXT_ENCODING_MAPPING,
39
+ MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES,
40
+ MODEL_FOR_VISION_2_SEQ_MAPPING,
41
+ MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
42
+ MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING,
43
+ MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES,
44
+ MODEL_MAPPING,
45
+ MODEL_MAPPING_NAMES,
46
+ )
47
+
48
+ from .auto_factory import _BaseAutoModelClass
49
+
50
+
51
# Expose architectures that transformers' auto-mapping does not know about,
# so the RBLN auto classes can resolve them by model type.
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES["midm"] = "MidmLMHeadModel"
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES["exaone"] = "ExaoneForCausalLM"
57
+
58
+
59
class RBLNAutoModel(_BaseAutoModelClass):
    """Automatically detect all supported transformers models."""

    # transformers' model-type -> architecture mappings used for resolution.
    _model_mapping = MODEL_MAPPING
    _model_mapping_names = MODEL_MAPPING_NAMES
64
+
65
+
66
class RBLNAutoModelForCTC(_BaseAutoModelClass):
    """Automatically detect Connectionist Temporal Classification (CTC) head Models."""

    # transformers' model-type -> architecture mappings used for resolution.
    _model_mapping = MODEL_FOR_CTC_MAPPING
    _model_mapping_names = MODEL_FOR_CTC_MAPPING_NAMES
71
+
72
+
73
class RBLNAutoModelForCausalLM(_BaseAutoModelClass):
    """Automatically detect Causal Language Models."""

    # Fixed: removed a stray empty string-literal statement ("""""") that sat
    # between the docstring and the class attributes, and corrected the
    # "Casual" -> "Causal" typo in the docstring.
    _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING
    _model_mapping_names = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
79
+
80
+
81
class RBLNAutoModelForSeq2SeqLM(_BaseAutoModelClass):
    """Automatically detect Sequence to Sequence Language Models."""

    # transformers' model-type -> architecture mappings used for resolution.
    _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
    _model_mapping_names = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
86
+
87
+
88
class RBLNAutoModelForSpeechSeq2Seq(_BaseAutoModelClass):
    """Automatically detect Speech Sequence to Sequence Generation Models."""

    # transformers' model-type -> architecture mappings used for resolution.
    _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING
    _model_mapping_names = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES
93
+
94
+
95
class RBLNAutoModelForDepthEstimation(_BaseAutoModelClass):
    """Automatically detect Depth Estimation Models."""

    # NOTE: docstring fixed — it previously described speech seq2seq models
    # (copy-paste error), while the mappings below are for depth estimation.
    _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING
    _model_mapping_names = MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES
100
+
101
+
102
class RBLNAutoModelForSequenceClassification(_BaseAutoModelClass):
    """Automatically detect Sequence Classification Models."""

    # transformers' model-type -> architecture mappings used for resolution.
    _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
    _model_mapping_names = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
107
+
108
+
109
class RBLNAutoModelForVision2Seq(_BaseAutoModelClass):
    """Automatically detect Vision to Sequence Generation Models."""

    # transformers' model-type -> architecture mappings used for resolution.
    _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING
    _model_mapping_names = MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
114
+
115
+
116
class RBLNAutoModelForImageTextToText(_BaseAutoModelClass):
    """Automatically detect Image and Text to Text Generation Models."""

    # transformers' model-type -> architecture mappings used for resolution.
    _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING
    _model_mapping_names = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
121
+
122
+
123
class RBLNAutoModelForMaskedLM(_BaseAutoModelClass):
    """Automatically detect Masked Language Models."""

    # transformers' model-type -> architecture mappings used for resolution.
    _model_mapping = MODEL_FOR_MASKED_LM_MAPPING
    _model_mapping_names = MODEL_FOR_MASKED_LM_MAPPING_NAMES
128
+
129
+
130
class RBLNAutoModelForAudioClassification(_BaseAutoModelClass):
    """Automatically detect Audio Classification Models."""

    # transformers' model-type -> architecture mappings used for resolution.
    _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
    _model_mapping_names = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES
135
+
136
+
137
class RBLNAutoModelForImageClassification(_BaseAutoModelClass):
    """Automatically detect Image Classification Models."""

    # transformers' model-type -> architecture mappings used for resolution.
    _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
    _model_mapping_names = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
142
+
143
+
144
class RBLNAutoModelForQuestionAnswering(_BaseAutoModelClass):
    """Automatically detect Question Answering Models."""

    # transformers' model-type -> architecture mappings used for resolution.
    _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
    _model_mapping_names = MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
149
+
150
+
151
class RBLNAutoModelForTextEncoding(_BaseAutoModelClass):
    """Automatically detect Text Encoding Models."""

    # transformers' model-type -> architecture mappings used for resolution.
    _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING
    _model_mapping_names = MODEL_FOR_TEXT_ENCODING_MAPPING_NAMES
156
+
157
+
158
class RBLNAutoModelForZeroShotObjectDetection(_BaseAutoModelClass):
    """Automatically detect Zero Shot Object Detection Models."""

    # transformers' model-type -> architecture mappings used for resolution.
    _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING
    _model_mapping_names = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES
@@ -0,0 +1,17 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ....ops import paged_attn_decode, paged_causal_attn_decode
16
+ from .configuration_bart import RBLNBartForConditionalGenerationConfig, RBLNBartModelConfig
17
+ from .modeling_bart import RBLNBartForConditionalGeneration, RBLNBartModel
@@ -0,0 +1,163 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Tuple
16
+
17
+ import torch
18
+ from torch import nn
19
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
20
+ from transformers.utils import logging
21
+
22
+ from ..seq2seq.seq2seq_architecture import (
23
+ Seq2SeqCrossAttention,
24
+ Seq2SeqDecoder,
25
+ Seq2SeqDecoderLayer,
26
+ Seq2SeqDecoderWrapper,
27
+ Seq2SeqEncoderWrapper,
28
+ Seq2SeqForConditionalGeneration,
29
+ Seq2SeqSelfAttention,
30
+ )
31
+
32
+
33
+ logger = logging.get_logger(__name__)
34
+
35
+
36
class BartWrapper:
    """Holds the encoder and decoder trace wrappers built from one BART model.

    `enc_max_seq_len` fixes the encoder's sequence length; `use_attention_mask`
    selects the decoder's attention decode kernel (see BartSelfAttention).
    """

    def __init__(self, model: nn.Module, enc_max_seq_len: int, use_attention_mask: bool):
        self.encoder = Seq2SeqEncoderWrapper(model, enc_max_seq_len)
        self.decoder = BartDecoderWrapper(model, use_attention_mask=use_attention_mask)
40
+
41
+
42
class BartDecoderWrapper(Seq2SeqDecoderWrapper):
    def convert_to_rbln_conditional_generation(self, model: nn.Module):
        """Rebuild the HF decoder with RBLN-friendly attention layers.

        Each original decoder layer is re-wrapped with the RBLN self/cross
        attention adapters; the result is a conditional-generation wrapper
        around the rebuilt decoder.
        """
        hf_decoder = model.get_decoder()
        rbln_layers = [
            BartDecoderLayer(
                layer,
                BartSelfAttention(layer.self_attn, use_attention_mask=self.use_attention_mask),
                BartCrossAttention(layer.encoder_attn),
            )
            for layer in hf_decoder.layers
        ]
        return BartForConditionalGeneration(model, BartDecoder(hf_decoder, rbln_layers))
54
+
55
+
56
class BartForConditionalGeneration(Seq2SeqForConditionalGeneration):
    """BART conditional generation wrapper; the generic seq2seq behavior suffices."""

    pass
58
+
59
+
60
class BartDecoder(Seq2SeqDecoder):
    """BART-specific decoder adapter: learned position embeddings + embedding layer-norm."""

    # BART uses learned positional embeddings.
    has_pos_emb = True

    def __post_init__(self):
        # Capture the submodules of the wrapped HF decoder used below.
        self.embed_positions = self._original_mod.embed_positions
        self.layernorm_embedding = self._original_mod.layernorm_embedding
        self.embed_scale = getattr(self._original_mod, "embed_scale", None)

    def prepare_attn_mask(self, attention_mask, encoder_attention_mask, **kwargs):
        """Expand the self/cross attention masks to the 4-D shapes used downstream."""
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :]
        encoder_attention_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=1)
        return attention_mask, encoder_attention_mask

    def apply_position_embedding(self, inputs_embeds, cache_position):
        """Add the learned position embedding for each batch row, then layer-norm."""
        # BART's position table reserves its first two rows, hence the offset slice
        # before indexing with the per-row cache positions.
        position_table = self.embed_positions.weight[2:]
        per_row = [
            position_table[cache_position[b]] + inputs_embeds[b]
            for b in range(inputs_embeds.shape[0])
        ]
        return self.layernorm_embedding(torch.stack(per_row, dim=0))

    def get_embedding(self):
        """Return the token-embedding callable, applying `embed_scale` when defined."""
        if self.embed_scale is None:
            return self.embed_tokens
        return lambda input_ids: self.embed_tokens(input_ids) * self.embed_scale
94
+
95
+
96
class BartLayerFF(nn.Module):
    """Feed-forward sublayer of a BART decoder layer.

    Applies fc1 -> activation -> fc2, adds the residual, and finishes with the
    original layer's final layer norm (post-norm, as in BART).
    """

    def __init__(self, decoder_layer):
        super().__init__()
        self.fc1 = decoder_layer.fc1
        self.fc2 = decoder_layer.fc2
        self.activation_fn = decoder_layer.activation_fn
        self.layer_norm = decoder_layer.final_layer_norm

    def forward(self, hidden_states):
        ff_out = self.fc2(self.activation_fn(self.fc1(hidden_states)))
        # Residual connection precedes the final norm.
        return self.layer_norm(hidden_states + ff_out)
112
+
113
+
114
class BartDecoderLayer(Seq2SeqDecoderLayer):
    """BART decoder layer adapter.

    BART normalizes *after* each attention block (post-norm), so the pre-*
    hooks are identity and the post-* hooks apply the original layer norms.
    """

    def __post_init__(self):
        # Reuse the original layer's norms and cross-attention module; the
        # feed-forward sublayer is re-wrapped so its residual add happens
        # before the final layer norm.
        self.self_attn_layer_norm = self._original_mod.self_attn_layer_norm
        self.encoder_attn = self._original_mod.encoder_attn
        self.encoder_attn_layer_norm = self._original_mod.encoder_attn_layer_norm
        self.ff_layer = BartLayerFF(self._original_mod)

    def pre_self_attn_layer_norm(self, hidden_states):
        # Post-norm architecture: nothing to do before self-attention.
        return hidden_states

    def post_self_attn_layer_norm(self, hidden_states):
        return self.self_attn_layer_norm(hidden_states)

    def pre_cross_attn_layer_norm(self, hidden_states):
        # Post-norm architecture: nothing to do before cross-attention.
        return hidden_states

    def post_cross_attn_layer_norm(self, hidden_states):
        return self.encoder_attn_layer_norm(hidden_states)
132
+
133
+
134
class BartSelfAttention(Seq2SeqSelfAttention):
    def __post_init__(self, use_attention_mask: bool = True):
        """Bind projection weights from the HF attention module and choose the decode kernel."""
        src = self._original_mod
        self.q_proj = src.q_proj
        self.k_proj = src.k_proj
        self.v_proj = src.v_proj
        self.out_proj = src.out_proj
        self.num_heads = src.num_heads
        self.head_dim = src.embed_dim // src.num_heads
        self.scaling = self.head_dim**-0.5
        # With an explicit attention mask, use the masked paged-attention kernel;
        # otherwise the kernel applies causal masking itself.
        self.attn_decode = (
            torch.ops.rbln_custom_ops.paged_attn_decode
            if use_attention_mask
            else torch.ops.rbln_custom_ops.paged_causal_attn_decode
        )

    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return the (pre-scaled) query plus key/value projections."""
        return (
            self.q_proj(hidden_states) * self.scaling,
            self.k_proj(hidden_states),
            self.v_proj(hidden_states),
        )
153
+
154
+
155
class BartCrossAttention(Seq2SeqCrossAttention):
    def __post_init__(self):
        """Bind the projection layers and head geometry of the HF cross-attention module."""
        src = self._original_mod
        self.q_proj = src.q_proj
        self.k_proj = src.k_proj
        self.v_proj = src.v_proj
        self.out_proj = src.out_proj
        self.embed_dim = src.embed_dim
        self.num_heads = src.num_heads
        self.head_dim = src.embed_dim // src.num_heads
@@ -0,0 +1,36 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ...configuration_generic import RBLNTransformerEncoderForFeatureExtractionConfig
16
+ from ..seq2seq import RBLNModelForSeq2SeqLMConfig
17
+
18
+
19
class RBLNBartModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
    """
    Configuration class for RBLNBartModel.

    Stores the configuration parameters specific to RBLN-optimized BART models
    for feature extraction tasks. All fields are inherited unchanged from
    RBLNTransformerEncoderForFeatureExtractionConfig.
    """
26
+
27
+
28
class RBLNBartForConditionalGenerationConfig(RBLNModelForSeq2SeqLMConfig):
    """
    Configuration class for RBLNBartForConditionalGeneration.

    This configuration class stores the configuration parameters specific to
    RBLN-optimized BART models for conditional text generation tasks.
    """

    # BART's decoder is wired to the paged-attention decode kernels
    # (see BartSelfAttention in the modeling module), so advertise support.
    support_paged_attention = True