optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. Click here for more details.

Files changed (264) hide show
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,349 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Tuple, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
20
+ from transformers.utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
class WhisperWrapper:
    """Bundles the RBLN-compatible encoder and decoder wrappers for one Whisper model."""

    def __init__(self, model, use_attention_mask, rbln_token_timestamps):
        # The decoder emits cross-attention weights only when token timestamps are requested.
        self.encoder = WhisperEncoderWrapper(model)
        self.decoder = WhisperDecoderWrapper(
            model,
            use_attention_mask=use_attention_mask,
            output_attentions=rbln_token_timestamps,
        )
32
+
33
+
34
+ class WhisperEncoderWrapper(torch.nn.Module):
35
+ def __init__(self, model):
36
+ super().__init__()
37
+ self.config = model.config
38
+ self.encoder = model.get_encoder()
39
+ self.num_heads = self.config.decoder_attention_heads
40
+ self.d_kv = self.config.d_model // self.num_heads
41
+ self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().layers)
42
+
43
+ def _extract_cross_kv_projects(self, decoder_layers: nn.Module):
44
+ return (
45
+ nn.ModuleList(layer.encoder_attn.k_proj for layer in decoder_layers),
46
+ nn.ModuleList(layer.encoder_attn.v_proj for layer in decoder_layers),
47
+ )
48
+
49
+ def forward(
50
+ self,
51
+ input_features: Optional[torch.LongTensor],
52
+ b_idx: torch.Tensor,
53
+ cross_key_values: torch.Tensor,
54
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
55
+ # 1. get encoder last_hidden_states
56
+ encoder_outputs = self.encoder(input_features=input_features)
57
+ last_hidden_states = encoder_outputs[0]
58
+
59
+ # 2. pre-compute cross_attention's past_key_value which used in decoder phase.
60
+ cross_kv = []
61
+ batch_size = input_features.shape[0]
62
+ for k_proj, v_proj in zip(self.cross_k_projects, self.cross_v_projects):
63
+ past_k = k_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
64
+ past_v = v_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
65
+
66
+ cross_kv.append(past_k)
67
+ cross_kv.append(past_v)
68
+
69
+ cross_kv = torch.stack(cross_kv, dim=0)
70
+
71
+ # 3. update cross_attention's past_key_value to the device-dram for optimization.
72
+ batch_axis = torch.tensor(1, dtype=torch.int16)
73
+ cross_key_values = torch.ops.rbln_custom_ops.rbln_cache_update(
74
+ cross_key_values, cross_kv, b_idx[0], batch_axis
75
+ )
76
+
77
+ return cross_key_values
78
+
79
+
80
class WhisperDecoderWrapper(torch.nn.Module):
    """Wraps the Whisper decoder for RBLN compilation.

    The forward signature is a flat ``*args`` tuple because the compiled graph
    passes a flat list of tensors; the layout depends on whether an explicit
    attention mask is used.
    """

    def __init__(self, model, use_attention_mask: bool = True, output_attentions: bool = False, **kwargs):
        super().__init__()
        self.config = model.config
        self.proj_out = model.proj_out
        self.use_attention_mask = use_attention_mask
        self.output_attentions = output_attentions
        self.__post_init__(model, **kwargs)

    def __post_init__(self, model: nn.Module, **kwargs):
        """
        Post-initialization to extract and configure encoder-related attributes.
        It is inspired by the BART architecture, but it is designed to be flexible and can be overridden
        by subclasses to modify or add custom attributes as necessary.
        """
        self.num_layers = self.config.decoder_layers
        self.decoder = self.convert_to_rbln_conditional_generation(model)

    def convert_to_rbln_conditional_generation(self, model: nn.Module):
        # Re-wrap every decoder layer with RBLN-compatible attention modules.
        wrapped_layers = [
            WhisperDecoderLayer(
                layer,
                WhisperSelfAttention(layer.self_attn),
                WhisperCrossAttention(layer.encoder_attn),
            )
            for layer in model.get_decoder().layers
        ]
        return WhisperDecoder(model.get_decoder(), wrapped_layers)

    def forward(
        self,
        *args,
    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
        # Unpack the flat tensor list; the mask slot exists only when enabled.
        if self.use_attention_mask:
            (
                decoder_input_ids,
                decoder_attention_mask,
                cache_position,
                block_tables,
                cross_kv_cache,
                *self_kv_cache,
            ) = args
        else:
            decoder_attention_mask = None
            (decoder_input_ids, cache_position, block_tables, cross_kv_cache, *self_kv_cache) = args

        # Regroup the flat KV-cache tensors into per-layer (key, value) pairs.
        self_past_key_values = tuple(
            (self_kv_cache[i], self_kv_cache[i + 1]) for i in range(0, self.num_layers * 2, 2)
        )
        cross_past_key_values = tuple(
            (cross_kv_cache[i], cross_kv_cache[i + 1]) for i in range(0, self.num_layers * 2, 2)
        )

        # Decode
        sequence_output, cross_attentions = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            cache_position=cache_position,
            self_past_key_values=self_past_key_values,
            cross_past_key_values=cross_past_key_values,
            block_tables=block_tables,
        )

        lm_logits = self.proj_out(sequence_output)
        outputs = (lm_logits,)

        if self.output_attentions:
            # The decoder's cross-attention weights drive token-timestamp extraction.
            outputs += (torch.stack(cross_attentions, dim=0),)

        return outputs
152
+
153
+
154
class WhisperDecoder(nn.Module):
    """Whisper decoder body with externally managed self/cross KV caches."""

    def __init__(self, model, layers, **kwargs):
        super().__init__()
        self._original_mod = model
        self.layers = nn.ModuleList(layers)
        # Reuse the original decoder's embeddings and final norm.
        self.embed_tokens = model.embed_tokens
        self.layer_norm = model.layer_norm
        self.embed_positions = model.embed_positions

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        self_past_key_values: Optional[torch.Tensor] = None,
        cross_past_key_values: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
    ):
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        # Token embeddings plus a per-sequence positional embedding looked up
        # at each sequence's own cache position.
        inputs_embeds = self.embed_tokens(input_ids)
        per_batch = []
        for idx in range(inputs_embeds.shape[0]):
            pos_embed = self.embed_positions.weight[cache_position[idx]]
            per_batch.append(pos_embed + inputs_embeds[idx])
        hidden_states = torch.cat(per_batch, dim=0).unsqueeze(1)

        # Broadcast the (batch, seq) mask to (batch, 1, 1, seq) for attention.
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :]

        cross_attentions = ()
        # Run every decoder layer with its own self/cross KV cache slice.
        for decoder_layer, self_past_key_value, cross_past_key_value in zip(
            self.layers, self_past_key_values, cross_past_key_values
        ):
            hidden_states, cross_attn_weights = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                self_past_key_value=self_past_key_value,
                cross_past_key_value=cross_past_key_value,
                cache_position=cache_position,
                block_tables=block_tables,
            )
            cross_attentions += (cross_attn_weights,)

        return self.layer_norm(hidden_states), cross_attentions
208
+
209
+
210
class WhisperDecoderLayer(nn.Module):
    """One decoder layer: pre-norm self-attention, cross-attention, then the MLP."""

    def __init__(self, decoder_layer, self_attn, cross_attn):
        super().__init__()
        self._original_mod = decoder_layer
        self.self_attn = self_attn
        self.encoder_attn = cross_attn
        # Reuse the original layer's norms and feed-forward sub-modules.
        self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
        self.encoder_attn_layer_norm = decoder_layer.encoder_attn_layer_norm
        self.final_layer_norm = decoder_layer.final_layer_norm
        self.activation_fn = decoder_layer.activation_fn
        self.fc1 = decoder_layer.fc1
        self.fc2 = decoder_layer.fc2

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        self_past_key_value: Optional[Tuple[torch.Tensor]] = None,
        cross_past_key_value: Optional[Tuple[torch.Tensor]] = None,
        cache_position: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Self-attention with pre-norm and residual connection.
        attn_input = self.self_attn_layer_norm(hidden_states)
        attn_out = self.self_attn(
            hidden_states=attn_input,
            past_key_value=self_past_key_value,
            attention_mask=attention_mask,
            cache_position=cache_position,
            block_tables=block_tables,
        )
        hidden_states = hidden_states + attn_out

        # Cross-attention over the pre-computed encoder KV cache.
        cross_input = self.encoder_attn_layer_norm(hidden_states)
        cross_out, cross_attn_weights = self.encoder_attn(
            hidden_states=cross_input,
            past_key_value=cross_past_key_value,
        )
        hidden_states = hidden_states + cross_out

        # Position-wise feed-forward block with residual connection.
        ff_input = self.final_layer_norm(hidden_states)
        ff_out = self.fc2(self.activation_fn(self.fc1(ff_input)))
        hidden_states = hidden_states + ff_out

        return hidden_states, cross_attn_weights
261
+
262
+
263
class WhisperAttention(nn.Module):
    """Shared projection/shape utilities for the Whisper attention variants."""

    def __init__(self, attn):
        super().__init__()
        self._original_mod = attn
        # Borrow the original attention's projection layers and geometry.
        self.q_proj = attn.q_proj
        self.k_proj = attn.k_proj
        self.v_proj = attn.v_proj
        self.out_proj = attn.out_proj
        self.num_heads = attn.num_heads
        self.embed_dim = attn.embed_dim
        self.head_dim = attn.embed_dim // attn.num_heads
        self.scaling = self.head_dim**-0.5

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        # (bsz, seq, embed) -> (bsz, heads, seq, head_dim)
        reshaped = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return reshaped.transpose(1, 2)
278
+
279
+
280
class WhisperSelfAttention(WhisperAttention):
    """Decode-phase self-attention backed by RBLN paged-attention custom ops."""

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
        # (bsz, seq, embed) -> (bsz, heads, 1, seq, head_dim), the layout the paged ops expect.
        return tensor.view(bsz, seq_len, 1, self.num_heads, self.head_dim).transpose(1, 3)

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        block_tables: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, tgt_len, _ = hidden_states.size()

        # Queries are pre-scaled so the op itself runs with scale=1.0.
        query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz) * self.scaling
        key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
        value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
        block_size = past_key_value[0].shape[-2]

        op_kwargs = {
            "q": query_states,
            "k": key_states,
            "v": value_states,
            "kcache": past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
            "vcache": past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
            "seq": cache_position.expand(bsz, 1),
            "scale": torch.tensor(1.0, dtype=torch.float32),
            "block_table": block_tables,
            "block_size": block_size,
        }

        # With an explicit mask use the masked op; otherwise the causal variant.
        if attention_mask is None:
            op_kwargs["mask"] = None
            attn_output = torch.ops.rbln_custom_ops.paged_causal_attn_decode(**op_kwargs)
        else:
            op_kwargs["mask"] = attention_mask.unsqueeze(2)
            attn_output = torch.ops.rbln_custom_ops.paged_attn_decode(**op_kwargs)

        # Back to (bsz, tgt_len, embed_dim) before the output projection.
        attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, self.embed_dim)
        return self.out_proj(attn_output)
325
+
326
+
327
class WhisperCrossAttention(WhisperAttention):
    """Cross-attention over the pre-computed encoder key/value cache."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        batch_size, query_len, _ = hidden_states.size()

        # Queries are pre-scaled; keys/values come straight from the encoder-side cache.
        query_states = self._shape(self.q_proj(hidden_states), query_len, batch_size)
        query_states = query_states * self.scaling
        key_states, value_states = past_key_value[0], past_key_value[1]

        # Plain softmax attention; the weights are returned for token timestamps.
        scores = torch.matmul(query_states, key_states.transpose(2, 3))
        attn_weights = nn.functional.softmax(scores, dim=-1)

        context = torch.matmul(attn_weights, value_states)
        context = context.view(batch_size, self.num_heads, query_len, self.head_dim)
        context = context.transpose(1, 2).reshape(batch_size, query_len, self.embed_dim)

        return self.out_proj(context), attn_weights
@@ -0,0 +1,24 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .configuration_xlm_roberta import RBLNXLMRobertaForSequenceClassificationConfig, RBLNXLMRobertaModelConfig
16
+ from .modeling_xlm_roberta import RBLNXLMRobertaForSequenceClassification, RBLNXLMRobertaModel
17
+
18
+
19
+ __all__ = [
20
+ "RBLNXLMRobertaModelConfig",
21
+ "RBLNXLMRobertaForSequenceClassificationConfig",
22
+ "RBLNXLMRobertaModel",
23
+ "RBLNXLMRobertaForSequenceClassification",
24
+ ]
@@ -0,0 +1,32 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ...configuration_generic import (
16
+ RBLNModelForSequenceClassificationConfig,
17
+ RBLNTransformerEncoderForFeatureExtractionConfig,
18
+ )
19
+
20
+
21
class RBLNXLMRobertaModelConfig(RBLNTransformerEncoderForFeatureExtractionConfig):
    """RBLN compilation configuration for the XLM-RoBERTa base model.

    A thin alias over ``RBLNTransformerEncoderForFeatureExtractionConfig``;
    it introduces no parameters of its own.
    """
26
+
27
+
28
class RBLNXLMRobertaForSequenceClassificationConfig(RBLNModelForSequenceClassificationConfig):
    """RBLN compilation configuration for XLM-RoBERTa sequence classification.

    A thin alias over ``RBLNModelForSequenceClassificationConfig``; it
    introduces no parameters of its own.
    """
@@ -0,0 +1,82 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Union
16
+
17
+ import torch
18
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions, SequenceClassifierOutput
19
+
20
+ from ...modeling_generic import RBLNModelForSequenceClassification, RBLNTransformerEncoderForFeatureExtraction
21
+
22
+
23
class RBLNXLMRobertaModel(RBLNTransformerEncoderForFeatureExtraction):
    """XLM-RoBERTa base (feature-extraction) model optimized for RBLN NPU."""

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, tuple]:
        """Run the RBLN-optimized XLM-RoBERTa base model.

        Args:
            input_ids (torch.Tensor of shape (batch_size, sequence_length), optional):
                Indices of input sequence tokens in the vocabulary.
            attention_mask (torch.Tensor of shape (batch_size, sequence_length), optional):
                Mask to avoid performing attention on padding token indices.
            token_type_ids (torch.Tensor of shape (batch_size, sequence_length), optional):
                Segment token indices to indicate different portions of the inputs.

        Returns:
            The model outputs. If return_dict=False is passed, returns a tuple of
            tensors; otherwise a BaseModelOutputWithPoolingAndCrossAttentions object.
        """
        # Forward token_type_ids through kwargs only when it was given, without
        # clobbering a value already present there (same effect as setdefault).
        if token_type_ids is not None and "token_type_ids" not in kwargs:
            kwargs["token_type_ids"] = token_type_ids

        return super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
51
+
52
+
53
class RBLNXLMRobertaForSequenceClassification(RBLNModelForSequenceClassification):
    """XLM-RoBERTa sequence-classification model optimized for RBLN NPU."""

    # Inputs the compiled RBLN runtime expects, in order.
    rbln_model_input_names = ["input_ids", "attention_mask"]

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Union[SequenceClassifierOutput, tuple]:
        """Run the RBLN-optimized XLM-RoBERTa classification model.

        Args:
            input_ids (torch.LongTensor of shape (batch_size, sequence_length), optional):
                Indices of input sequence tokens in the vocabulary.
            attention_mask (torch.FloatTensor of shape (batch_size, sequence_length), optional):
                Mask to avoid performing attention on padding token indices.
            token_type_ids (torch.LongTensor of shape (batch_size, sequence_length), optional):
                Segment token indices to indicate first and second portions of the inputs.

        Returns:
            The model outputs. If return_dict=False is passed, returns a tuple of
            tensors; otherwise a SequenceClassifierOutput object.
        """
        # Forward token_type_ids through kwargs only when it was given, without
        # clobbering a value already present there (same effect as setdefault).
        if token_type_ids is not None and "token_type_ids" not in kwargs:
            kwargs["token_type_ids"] = token_type_ids

        return super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
File without changes