optimum-rbln 0.9.3.post1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py
@@ -0,0 +1,337 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Portions of this software are licensed under the Apache License,
+ # Version 2.0. See the NOTICE file distributed with this work for
+ # additional information regarding copyright ownership.
+
+ # All other portions of this software, including proprietary code,
+ # are the intellectual property of Rebellions Inc. and may not be
+ # copied, modified, or distributed without prior written permission
+ # from Rebellions Inc.
+
+ from typing import Optional, Tuple, Union
+
+ import torch
+ from torch import nn
+ from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class TimeSeriesTransformersWrapper:
+     def __init__(self, model, num_parallel_samples):
+         self.encoder = TimeSeriesTransformersEncoderWrapper(model)
+         self.decoder = TimeSeriesTransformersDecoderWrapper(model, num_parallel_samples)
+
+
+ class TimeSeriesTransformersEncoderWrapper(torch.nn.Module):
+     def __init__(self, model):
+         super().__init__()
+         self.config = model.config
+         self.encoder = model.get_encoder()
+         self.num_heads = self.config.decoder_attention_heads
+         self.d_kv = self.config.d_model // self.num_heads
+         self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().layers)
+
+     def _extract_cross_kv_projects(self, decoder_layers: nn.Module):
+         return (
+             nn.ModuleList(layer.encoder_attn.k_proj for layer in decoder_layers),
+             nn.ModuleList(layer.encoder_attn.v_proj for layer in decoder_layers),
+         )
+
+     def forward(
+         self,
+         inputs_embeds: torch.Tensor,
+         cross_key_values: torch.Tensor,  # n_layers, batch_size, num_heads, context_length, d_kv
+     ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
+         # 1. Get the encoder's last hidden states.
+         encoder_outputs = self.encoder(inputs_embeds=inputs_embeds, attention_mask=None, return_dict=False)
+         last_hidden_states = encoder_outputs[0]
+
+         # 2. Pre-compute the cross-attention past_key_values used in the decoding phase.
+         cross_kv = []
+         batch_size = inputs_embeds.shape[0]
+         for k_proj, v_proj in zip(self.cross_k_projects, self.cross_v_projects):
+             past_k = k_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
+             past_v = v_proj(last_hidden_states).view(batch_size, -1, self.num_heads, self.d_kv).transpose(1, 2)
+
+             cross_kv.append(past_k)
+             cross_kv.append(past_v)
+
+         cross_kv = torch.stack(cross_kv, dim=0)
+
+         # 3. Write the cross-attention past_key_values to device DRAM as an optimization.
+         bidx = torch.tensor(0, dtype=torch.int16)
+         axis = torch.tensor(1, dtype=torch.int16)
+         enc_output = torch.ops.rbln_custom_ops.rbln_cache_update(cross_key_values, cross_kv, bidx, axis)
+
+         return enc_output
+
+
+ class TimeSeriesTransformersDecoderWrapper(torch.nn.Module):
+     def __init__(self, model, num_parallel_samples):
+         super().__init__()
+         self.config = model.config
+         self.num_layers = self.config.decoder_layers
+         self.decoder = self.convert_to_rbln_tst_decoder(model, num_parallel_samples)
+         self.parameter_projection = model.parameter_projection
+
+     def convert_to_rbln_tst_decoder(self, model: nn.Module, num_parallel_samples: int):
+         new_layers = []
+         for layer in model.get_decoder().layers:
+             self_attn = TimeSeriesTransformersSelfAttention(layer.self_attn, num_parallel_samples)
+             cross_attn = TimeSeriesTransformersCrossAttention(layer.encoder_attn, num_parallel_samples)
+             new_layers.append(TimeSeriesTransformersDecoderLayer(layer, self_attn, cross_attn))
+
+         decoder_model = TimeSeriesTransformersDecoder(model.get_decoder(), new_layers)
+
+         return decoder_model
+
+     def forward(
+         self,
+         inputs_embeds: torch.Tensor,
+         decoder_attention_mask: torch.Tensor,
+         cache_position: torch.Tensor,
+         block_tables: torch.Tensor,
+         cross_kv_cache: torch.Tensor,  # batch_size, num_heads, context_length, d_kv
+         *self_kv_cache: torch.Tensor,  # batch_size * num_parallel_samples, num_heads, prediction_length, d_kv
+     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+         # Prepare past_key_values: unpack the flat cache tuple into per-layer (key, value) pairs.
+         self_past_key_values = ()
+         cross_past_key_values = ()
+         for i in range(0, self.num_layers * 2, 2):
+             self_past_key_values = self_past_key_values + ((self_kv_cache[i], self_kv_cache[i + 1]),)
+             cross_past_key_values = cross_past_key_values + ((cross_kv_cache[i], cross_kv_cache[i + 1]),)
+
+         # Decode.
+         last_hidden_states = self.decoder(
+             inputs_embeds=inputs_embeds,
+             attention_mask=decoder_attention_mask,
+             cache_position=cache_position,
+             block_tables=block_tables,
+             self_past_key_values=self_past_key_values,
+             cross_past_key_values=cross_past_key_values,
+         )
+
+         params = self.parameter_projection(last_hidden_states[:, -1:])
+
+         outputs = ()
+         outputs += (params,)
+         outputs += (last_hidden_states,)
+
+         return outputs
+
+
+ class TimeSeriesTransformersDecoder(nn.Module):
+     def __init__(self, model, layers, **kwargs):
+         super().__init__()
+         self._original_mod = model
+         self.config = model.config
+         self.layers = nn.ModuleList(layers)
+         self.value_embedding = model.value_embedding
+         self.embed_positions = model.embed_positions
+         self.layernorm_embedding = model.layernorm_embedding
+
+     def forward(
+         self,
+         inputs_embeds: torch.Tensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         self_past_key_values: Optional[torch.Tensor] = None,
+         cross_past_key_values: Optional[torch.Tensor] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         block_tables: torch.Tensor = None,
+     ):
+         input_shape = inputs_embeds.size()[:-1]
+
+         # Prepare the causal attention mask.
+         attention_mask = _prepare_4d_causal_attention_mask(attention_mask, input_shape, inputs_embeds, cache_position)
+
+         hidden_states = self.value_embedding(inputs_embeds)
+         embed_idx = cache_position + self.config.context_length
+         if torch.compiler.is_exporting():
+             embed_idx = embed_idx.item()
+             torch._check_is_size(embed_idx)
+             torch._check(embed_idx >= 0)
+             torch._check(embed_idx < len(self.embed_positions.weight))
+         embed_pos = self.embed_positions.weight[embed_idx]
+         hidden_states = self.layernorm_embedding(hidden_states + embed_pos)
+
+         # Iterate over the decoder layers.
+         for self_past_key_value, cross_past_key_value, decoder_layer in zip(
+             self_past_key_values, cross_past_key_values, self.layers
+         ):
+             hidden_states = decoder_layer(
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 self_past_key_value=self_past_key_value,
+                 cross_past_key_value=cross_past_key_value,
+                 cache_position=cache_position,
+                 block_tables=block_tables,
+             )
+
+         return hidden_states
+
+
+ class TimeSeriesTransformersDecoderLayer(nn.Module):
+     def __init__(self, decoder_layer, self_attn, cross_attn):
+         super().__init__()
+         self._original_mod = decoder_layer
+         self.self_attn = self_attn
+         self.encoder_attn = cross_attn
+         self.embed_dim = decoder_layer.embed_dim
+         self.self_attn_layer_norm = decoder_layer.self_attn_layer_norm
+         self.encoder_attn_layer_norm = decoder_layer.encoder_attn_layer_norm
+         self.final_layer_norm = decoder_layer.final_layer_norm
+         self.activation_fn = decoder_layer.activation_fn
+         self.fc1 = decoder_layer.fc1
+         self.fc2 = decoder_layer.fc2
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         self_past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         cross_past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         block_tables: torch.Tensor = None,
+     ) -> torch.Tensor:
+         # Self-attention block
+         residual = hidden_states
+         hidden_states = self.self_attn(
+             hidden_states=hidden_states,
+             past_key_value=self_past_key_value,
+             attention_mask=attention_mask,
+             cache_position=cache_position,
+             block_tables=block_tables,
+         )
+         hidden_states = residual + hidden_states
+         hidden_states = self.self_attn_layer_norm(hidden_states)
+
+         # Cross-attention block
+         residual = hidden_states
+         hidden_states = self.encoder_attn(
+             hidden_states=hidden_states,
+             past_key_value=cross_past_key_value,
+             # attention_mask=encoder_attention_mask,
+         )
+         hidden_states = residual + hidden_states
+         hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+         # Fully connected block
+         residual = hidden_states
+         hidden_states = self.activation_fn(self.fc1(hidden_states))
+         hidden_states = self.fc2(hidden_states)
+         hidden_states = residual + hidden_states
+         hidden_states = self.final_layer_norm(hidden_states)
+
+         return hidden_states
+
+
+ class TimeSeriesTransformersAttention(nn.Module):
+     def __init__(self, attn, num_parallel_samples):
+         super().__init__()
+         self._original_mod = attn
+         self.q_proj = attn.q_proj
+         self.k_proj = attn.k_proj
+         self.v_proj = attn.v_proj
+         self.out_proj = attn.out_proj
+         self.num_heads = attn.num_heads
+         self.embed_dim = attn.embed_dim
+         self.head_dim = attn.head_dim
+         self.scaling = attn.scaling
+         self.num_parallel_samples = num_parallel_samples
+
+     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+         return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+
+ class TimeSeriesTransformersSelfAttention(TimeSeriesTransformersAttention):
+     def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int) -> torch.Tensor:
+         return tensor.view(1, seq_len, 1, bsz * self.num_heads, self.head_dim).transpose(1, 3)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         block_tables: Optional[torch.Tensor] = None,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         bsz, tgt_len, _ = hidden_states.size()
+         query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
+         query_states = query_states * self.scaling
+
+         key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+         value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+         block_size = past_key_value[0].shape[-2]
+         attn_output = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode(
+             q=query_states,
+             k=key_states,
+             v=value_states,
+             mask=attention_mask.unsqueeze(2),
+             kcache=past_key_value[0].view(1, bsz * self.num_heads, 1, -1, self.head_dim),
+             vcache=past_key_value[1].view(1, bsz * self.num_heads, 1, -1, self.head_dim),
+             seq=cache_position.expand(bsz, 1),
+             scale=torch.tensor(1.0, dtype=torch.float32),  # queries are pre-scaled above
+             block_table=block_tables,
+             block_size=block_size,
+         )
+
+         attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+         attn_output = attn_output.transpose(1, 2)
+         attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output
+
+
+ class TimeSeriesTransformersCrossAttention(TimeSeriesTransformersSelfAttention):
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         batch_size, query_len, _ = hidden_states.size()
+         query_states = (
+             self.q_proj(hidden_states)
+             .view(
+                 batch_size // self.num_parallel_samples,
+                 self.num_parallel_samples,
+                 query_len,
+                 self.num_heads,
+                 self.head_dim,
+             )
+             .transpose(2, 3)
+         )
+         query_states = query_states * self.scaling
+
+         key_states = past_key_value[0].unsqueeze(1)
+         value_states = past_key_value[1].unsqueeze(1)
+
+         attn_weights = torch.matmul(query_states, key_states.transpose(3, 4))
+         attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+         attn_output = torch.matmul(attn_weights, value_states)
+         attn_output = attn_output.view(batch_size, self.num_heads, query_len, self.head_dim)
+         attn_output = attn_output.transpose(1, 2)
+         attn_output = attn_output.reshape(batch_size, query_len, self.embed_dim)
+         attn_output = self.out_proj(attn_output)
+
+         return attn_output
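
The heart of the encoder wrapper above is hoisting the cross-attention K/V projections out of the autoregressive decode loop: they depend only on the encoder output, so they are computed once and cached. A standalone sketch of the same idea in plain PyTorch (all shapes and module names here are illustrative, not taken from the package):

import torch
from torch import nn

num_layers, batch, heads, ctx_len, d_kv = 2, 1, 2, 24, 16
d_model = heads * d_kv

# One K/V projection pair per decoder layer, as in _extract_cross_kv_projects.
k_projs = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_layers))
v_projs = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_layers))

enc_out = torch.randn(batch, ctx_len, d_model)  # encoder last_hidden_states

# Pre-compute the cross-attention K/V once; every decode step reuses them.
cross_kv = []
for k_proj, v_proj in zip(k_projs, v_projs):
    k = k_proj(enc_out).view(batch, ctx_len, heads, d_kv).transpose(1, 2)
    v = v_proj(enc_out).view(batch, ctx_len, heads, d_kv).transpose(1, 2)
    cross_kv += [k, v]
cross_kv = torch.stack(cross_kv, dim=0)

# Same layout the wrapper writes into the preallocated device cache.
assert cross_kv.shape == (num_layers * 2, batch, heads, ctx_len, d_kv)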
optimum/rbln/transformers/models/vit/__init__.py
@@ -0,0 +1,19 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_vit import RBLNViTForImageClassificationConfig
+ from .modeling_vit import RBLNViTForImageClassification
+
+
+ __all__ = ["RBLNViTForImageClassificationConfig", "RBLNViTForImageClassification"]
optimum/rbln/transformers/models/vit/configuration_vit.py
@@ -0,0 +1,24 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ...configuration_generic import RBLNModelForImageClassificationConfig
+
+
+ class RBLNViTForImageClassificationConfig(RBLNModelForImageClassificationConfig):
+     """
+     Configuration class for RBLNViTForImageClassification.
+
+     This configuration class stores the configuration parameters specific to
+     RBLN-optimized Vision Transformer (ViT) models for image classification tasks.
+     """
optimum/rbln/transformers/models/vit/modeling_vit.py
@@ -0,0 +1,44 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Tuple, Union
+
+ import torch
+ from transformers.modeling_outputs import ImageClassifierOutput
+
+ from ...modeling_generic import RBLNModelForImageClassification
+
+
+ class RBLNViTForImageClassification(RBLNModelForImageClassification):
+     """
+     RBLN-optimized Vision Transformer (ViT) model for image classification tasks.
+
+     This class provides hardware-accelerated inference for Vision Transformer models
+     on RBLN devices, supporting image classification with transformer-based architectures
+     that process images as sequences of patches.
+     """
+
+     def forward(self, pixel_values: torch.Tensor, **kwargs) -> Union[ImageClassifierOutput, Tuple]:
+         """
+         Forward pass for the RBLN-optimized Vision Transformer model for image classification.
+
+         Args:
+             pixel_values (torch.FloatTensor of shape (batch_size, channels, height, width)):
+                 The tensors corresponding to the input images.
+
+         Returns:
+             The model outputs. If return_dict=False is passed, returns a tuple of tensors; otherwise, returns an ImageClassifierOutput object.
+         """
+         return super().forward(pixel_values, **kwargs)
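
For orientation, a minimal usage sketch for the class above, following the usual optimum-rbln from_pretrained flow with export=True (the checkpoint name is illustrative, and actually running the compiled model requires an RBLN NPU and SDK):

from PIL import Image
from transformers import AutoImageProcessor
from optimum.rbln import RBLNViTForImageClassification

# Compile a ViT checkpoint for RBLN devices on first load.
model = RBLNViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",  # illustrative checkpoint
    export=True,
)
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

image = Image.new("RGB", (224, 224))  # stand-in for a real input image
inputs = processor(images=image, return_tensors="pt")
outputs = model(pixel_values=inputs["pixel_values"])
print(outputs.logits.argmax(-1))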
optimum/rbln/transformers/models/wav2vec2/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .configuration_wav2vec2 import RBLNWav2Vec2ForCTCConfig
+ from .modeling_wav2vec2 import RBLNWav2Vec2ForCTC
optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py
@@ -0,0 +1,38 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Optional
+
+ from ....configuration_utils import RBLNModelConfig
+
+
+ class RBLNWav2Vec2ForCTCConfig(RBLNModelConfig):
+     """
+     Configuration class for RBLNWav2Vec2ForCTC.
+
+     This configuration class stores the configuration parameters specific to
+     RBLN-optimized Wav2Vec2 models for Connectionist Temporal Classification (CTC) tasks.
+     """
+
+     def __init__(
+         self,
+         max_seq_len: Optional[int] = None,
+         batch_size: Optional[int] = None,
+         **kwargs: Any,
+     ):
+         super().__init__(**kwargs)
+         self.max_seq_len = max_seq_len
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size <= 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
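
A quick sketch of the defaulting and validation behavior defined above (assuming the config class is re-exported from the optimum.rbln package root, as the other RBLN classes are):

from optimum.rbln import RBLNWav2Vec2ForCTCConfig

cfg = RBLNWav2Vec2ForCTCConfig(max_seq_len=160000)
print(cfg.batch_size)  # 1 -- batch_size falls back to 1 when omitted

try:
    RBLNWav2Vec2ForCTCConfig(max_seq_len=160000, batch_size=-2)
except ValueError as err:
    print(err)  # batch_size must be a positive integer, got -2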
optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -0,0 +1,104 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import TYPE_CHECKING, Optional, Union
+
+ import torch
+ from transformers import AutoModelForCTC, Wav2Vec2Config, Wav2Vec2ForCTC
+ from transformers.modeling_outputs import CausalLMOutput
+
+ from ....configuration_utils import RBLNCompileConfig
+ from ....modeling import RBLNModel
+ from .configuration_wav2vec2 import RBLNWav2Vec2ForCTCConfig
+
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
+
+
+ class _Wav2Vec2(torch.nn.Module):
+     def __init__(self, model: "Wav2Vec2ForCTC"):
+         super().__init__()
+         self.model = model
+
+     def forward(self, input_values):
+         output = self.model.wav2vec2(input_values=input_values)
+         return self.model.lm_head(output[0])
+
+
+ class RBLNWav2Vec2ForCTC(RBLNModel):
+     """
+     Wav2Vec2 model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
+
+     It implements the methods needed to convert a pre-trained Wav2Vec2 model into an RBLN Wav2Vec2 model by:
+
+     - transferring the checkpoint weights of the original model into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
+     """
+
+     main_input_name = "input_values"
+     auto_model_class = AutoModelForCTC
+     rbln_dtype = "float32"
+
+     @classmethod
+     def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNWav2Vec2ForCTCConfig) -> torch.nn.Module:
+         return _Wav2Vec2(model).eval()
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["Wav2Vec2Config"] = None,
+         rbln_config: Optional[RBLNWav2Vec2ForCTCConfig] = None,
+     ) -> RBLNWav2Vec2ForCTCConfig:
+         if rbln_config.max_seq_len is None:
+             for tokenizer in preprocessors:
+                 if hasattr(tokenizer, "model_max_length"):
+                     rbln_config.max_seq_len = tokenizer.model_max_length
+                     break
+             if rbln_config.max_seq_len is None:
+                 raise ValueError("`rbln_max_seq_len` should be specified!")
+
+         rbln_compile_config = RBLNCompileConfig(
+             input_info=[
+                 (
+                     "input_values",
+                     [
+                         rbln_config.batch_size,
+                         rbln_config.max_seq_len,
+                     ],
+                     "float32",
+                 )
+             ]
+         )
+
+         rbln_config.set_compile_cfgs([rbln_compile_config])
+         return rbln_config
+
+     def forward(
+         self, input_values: torch.Tensor, return_dict: Optional[bool] = None, **kwargs
+     ) -> Union[CausalLMOutput, tuple]:
+         """
+         Forward pass for the RBLN-optimized Wav2Vec2 model for Connectionist Temporal Classification (CTC).
+
+         Args:
+             input_values (torch.FloatTensor of shape (batch_size, sequence_length)):
+                 Float values of the input raw speech waveform. Values can be obtained by loading a `.flac` or
+                 `.wav` audio file into an array of type List[float] or numpy.ndarray, e.g. via the soundfile
+                 library (`pip install soundfile`). To prepare the array into `input_values`, the AutoProcessor
+                 should be used for padding and conversion into a tensor of type torch.FloatTensor.
+             return_dict (bool, optional): Whether or not to return a ModelOutput instead of a plain tuple.
+
+         Returns:
+             The model outputs. If return_dict=False is passed, returns a tuple of tensors; otherwise, returns a CausalLMOutput object.
+         """
+         return super().forward(input_values=input_values, return_dict=return_dict, **kwargs)
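
Tying the pieces together, a hedged end-to-end sketch for the CTC model (checkpoint name and sequence length are illustrative; because the compiled graph is static, audio must be padded to the compiled max_seq_len, and running it requires RBLN hardware):

import torch
from transformers import AutoProcessor
from optimum.rbln import RBLNWav2Vec2ForCTC

# Compile with a fixed input length; rbln_max_seq_len maps to the config field above.
model = RBLNWav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-base-960h",  # illustrative checkpoint
    export=True,
    rbln_max_seq_len=160000,  # 10 s of 16 kHz audio
)
processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")

waveform = torch.zeros(160000)  # stand-in for real speech samples
inputs = processor(
    waveform.numpy(),
    sampling_rate=16000,
    padding="max_length",
    max_length=160000,
    return_tensors="pt",
)

logits = model(input_values=inputs["input_values"]).logits
print(processor.batch_decode(logits.argmax(dim=-1)))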
optimum/rbln/transformers/models/whisper/__init__.py
@@ -0,0 +1,17 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from ....ops import paged_add_softmax_attn_decode  # side-effect import: registers the RBLN custom attention op
+ from .configuration_whisper import RBLNWhisperForConditionalGenerationConfig
+ from .modeling_whisper import RBLNWhisperForConditionalGeneration
optimum/rbln/transformers/models/whisper/configuration_whisper.py
@@ -0,0 +1,72 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Optional
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....utils.logging import get_logger
+
+
+ logger = get_logger()
+
+
+ class RBLNWhisperForConditionalGenerationConfig(RBLNModelConfig):
+     """
+     Configuration class for RBLNWhisperForConditionalGeneration.
+
+     This configuration class stores the configuration parameters specific to
+     RBLN-optimized Whisper models for speech recognition and transcription tasks.
+     """
+
+     def __init__(
+         self,
+         batch_size: Optional[int] = None,
+         token_timestamps: Optional[bool] = None,
+         use_attention_mask: Optional[bool] = None,
+         enc_max_seq_len: Optional[int] = None,
+         dec_max_seq_len: Optional[int] = None,
+         kvcache_num_blocks: Optional[int] = None,
+         kvcache_block_size: Optional[int] = None,
+         **kwargs: Any,
+     ):
+         """
+         Args:
+             batch_size (int, optional): The batch size for inference. Defaults to 1.
+             token_timestamps (bool, optional): Whether to output token timestamps during generation. Defaults to False.
+             use_attention_mask (bool, optional): Whether to use attention masks during inference. Defaults to False.
+             enc_max_seq_len (int, optional): Maximum sequence length for the encoder.
+             dec_max_seq_len (int, optional): Maximum sequence length for the decoder.
+             kvcache_num_blocks (int, optional): The total number of blocks to allocate for the
+                 PagedAttention KV cache of the self-attention. Defaults to batch_size.
+             kvcache_block_size (int, optional): The size (in number of tokens) of each block
+                 in the PagedAttention KV cache of the self-attention. Defaults to dec_max_seq_len.
+             kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+         Raises:
+             ValueError: If batch_size is not a positive integer.
+         """
+         super().__init__(**kwargs)
+
+         self.batch_size = batch_size or 1
+         if not isinstance(self.batch_size, int) or self.batch_size <= 0:
+             raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
+
+         self.token_timestamps = token_timestamps or False
+         self.enc_max_seq_len = enc_max_seq_len
+         self.dec_max_seq_len = dec_max_seq_len
+
+         self.use_attention_mask = use_attention_mask or False
+         self.kvcache_num_blocks = kvcache_num_blocks
+         self.kvcache_block_size = kvcache_block_size
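
A short sketch of the defaulting behavior implemented above (again assuming the config class is re-exported from the optimum.rbln package root):

from optimum.rbln import RBLNWhisperForConditionalGenerationConfig

cfg = RBLNWhisperForConditionalGenerationConfig(token_timestamps=True)
print(cfg.batch_size)          # 1 -- defaults when omitted
print(cfg.token_timestamps)    # True
print(cfg.use_attention_mask)  # False -- falls back to False when unset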