optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
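
To pin and verify this exact release locally, a minimal sketch (assuming the package is published on PyPI under the name `optimum-rbln`, as the wheel filename suggests):

```python
# Minimal sketch: install and verify this exact release.
# Run the pip line in a shell first:
#   pip install optimum-rbln==0.9.3.post1
from importlib.metadata import version

print(version("optimum-rbln"))  # expected: 0.9.3.post1
```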
--- /dev/null
+++ b/optimum/rbln/transformers/models/whisper/generation_whisper.py
@@ -0,0 +1,159 @@
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ Generation utilities for Whisper.
+ Modified from `transformers.models.whisper.generation_whisper.py`
+ """
+
+ from typing import Any, Dict, Optional, Union
+
+ import torch
+ import transformers
+ from packaging import version
+ from transformers import GenerationMixin
+ from transformers.generation.configuration_utils import GenerationConfig
+ from transformers.modeling_outputs import ModelOutput
+ from transformers.models.whisper.generation_whisper import WhisperGenerationMixin
+
+
+ class RBLNWhisperGenerationMixin(WhisperGenerationMixin, GenerationMixin):
+     def generate(
+         self,
+         input_features: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         generation_config: Optional[GenerationConfig] = None,
+         return_segments: Optional[bool] = None,
+         return_timestamps: Optional[bool] = None,
+         return_token_timestamps: Optional[bool] = None,
+         **kwargs,
+     ) -> Union[ModelOutput, Dict[str, Any], torch.LongTensor]:
+         """
+         The generate function is used in its standard form, as in the HuggingFace Transformers library. Users can call this function to generate text from the model.
+         Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate) for more details.
+
+         Args:
+             input_features (torch.Tensor, optional): The input features to the model.
+             attention_mask (torch.Tensor, optional): An attention mask must be passed when doing long-form transcription with a batch size > 1.
+             generation_config (GenerationConfig, optional): The generation configuration used as the base parametrization for the generation call. **kwargs passed to generate that match attributes of generation_config will override them.
+                 If generation_config is not provided, the default is used, with the following loading priority: 1) the generation_config.json model file, if it exists; 2) the model configuration.
+                 Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)’s default values.
+             return_segments (bool, optional): Whether to return segments.
+             return_timestamps (bool, optional): Whether to return timestamps with the text. For audio longer than 30 seconds, return_timestamps=True is required.
+             return_token_timestamps (bool, optional): Whether to return token timestamps.
+             kwargs (dict[str, Any], optional): Additional arguments passed to the generate function.
+
+         Returns:
+             A sequence of auto-regressively generated token ids, transcribed or translated from the log-mel input features.
+         """
+         if kwargs.get("num_beams", None) is not None:
+             if kwargs.get("num_beams") != 1:
+                 raise ValueError(
+                     "Beam search is not supported in RBLNWhisperGenerationMixin. "
+                     f"Received num_beams={kwargs['num_beams']}, but only num_beams=1 is allowed. "
+                     "Please set num_beams=1 for greedy search or adjust your configuration."
+                 )
+
+         return super().generate(
+             input_features,
+             attention_mask=attention_mask,
+             generation_config=generation_config,
+             return_segments=return_segments,
+             return_timestamps=return_timestamps,
+             return_token_timestamps=return_token_timestamps,
+             **kwargs,
+         )
+
+     def _postprocess_outputs(
+         self,
+         seek_outputs,
+         decoder_input_ids,
+         return_token_timestamps,
+         generation_config,
+         is_shortform,
+         seek,
+         batch_idx_map,
+     ):
+         # remove all previously passed decoder input ids
+         # should happen only if it is the first generated segment
+         start_idx = decoder_input_ids.shape[-1]
+
+         if isinstance(seek_outputs, torch.Tensor):
+             return seek_outputs[:, start_idx:], seek_outputs
+
+         if return_token_timestamps and not self.rbln_token_timestamps:
+             raise RuntimeError(
+                 "To use .generate() with return_token_timestamps=True, the model must be compiled with rbln_token_timestamps=True. "
+                 "You can compile the model by calling .from_pretrained() with export=True and rbln_token_timestamps=True as keyword arguments, "
+                 "or you can generate with return_token_timestamps=False."
+             )
+
+         if return_token_timestamps and hasattr(generation_config, "alignment_heads"):
+             num_frames = getattr(generation_config, "num_frames", None)
+
+             if num_frames is not None:
+                 num_frames = num_frames - seek
+                 num_frames = num_frames[batch_idx_map]
+
+             if version.parse(transformers.__version__) >= version.parse("4.46.0"):
+                 seek_outputs["token_timestamps"] = self._extract_token_timestamps(
+                     seek_outputs,
+                     generation_config.alignment_heads,
+                     num_frames=num_frames,
+                     num_input_ids=decoder_input_ids.shape[-1],
+                 )
+             else:
+                 seek_outputs["token_timestamps"] = self._extract_token_timestamps(
+                     seek_outputs,
+                     generation_config.alignment_heads,
+                     num_frames=num_frames,
+                 )
+         seek_outputs["sequences"] = seek_outputs["sequences"][:, start_idx:]
+
+         def split_by_batch_index(values, key, batch_idx):
+             if key in ["scores", "encoder_attentions", "encoder_hidden_states", "logits"]:
+                 return [v[batch_idx].cpu() for v in values]
+             if key in ["decoder_attentions", "decoder_hidden_states", "cross_attentions"]:
+                 return tuple(tuple(w[batch_idx][None].cpu() for w in v) for v in values)
+             elif key == "past_key_values":
+                 # we don't save `past_key_values` in RBLN
+                 return None
+
+             return values[batch_idx].cpu()
+
+         sequence_tokens = seek_outputs["sequences"]
+
+         valid_seekoutputs = []
+         for k, v in seek_outputs.items():
+             if v is not None and len(v) > 0 and v[0] is not None:
+                 valid_seekoutputs.append((k, v))
+         seek_outputs = [
+             {k: split_by_batch_index(v, k, i) for k, v in valid_seekoutputs} for i in range(sequence_tokens.shape[0])
+         ]
+
+         return sequence_tokens, seek_outputs
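
Taken together, the mixin above enforces two usage constraints: greedy decoding only (num_beams=1), and token timestamps only if the model was compiled with rbln_token_timestamps=True. A minimal sketch of a call that satisfies both, assembled from the error messages and docstrings above (the model id and the silent dummy audio are illustrative only, not part of the diff):

```python
# Minimal sketch of a generate() call consistent with the checks above.
import numpy as np
from transformers import AutoProcessor
from optimum.rbln import RBLNWhisperForConditionalGeneration

processor = AutoProcessor.from_pretrained("openai/whisper-tiny")
model = RBLNWhisperForConditionalGeneration.from_pretrained(
    model_id="openai/whisper-tiny",
    export=True,                 # compile for the RBLN NPU
    rbln_token_timestamps=True,  # required by _postprocess_outputs for token timestamps
)

# One second of silence stands in for real audio.
input_features = processor(
    np.zeros(16000, dtype=np.float32), sampling_rate=16000, return_tensors="pt"
).input_features

# num_beams defaults to 1; any other value raises ValueError in generate().
outputs = model.generate(input_features=input_features, return_token_timestamps=True)
```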
--- /dev/null
+++ b/optimum/rbln/transformers/models/whisper/modeling_whisper.py
@@ -0,0 +1,475 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+
+ import rebel
+ import torch
+ from rebel.compile_context import CompileContext
+ from transformers import AutoModelForSpeechSeq2Seq, WhisperForConditionalGeneration, WhisperModel
+ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+
+ from ....configuration_utils import RBLNCompileConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from ....utils.runtime_utils import RBLNPytorchRuntime
+ from .configuration_whisper import RBLNWhisperForConditionalGenerationConfig
+ from .generation_whisper import RBLNWhisperGenerationMixin
+ from .whisper_architecture import WhisperWrapper
+
+
+ logger = get_logger(__name__)
+
+ if TYPE_CHECKING:
+     from transformers import (
+         AutoFeatureExtractor,
+         AutoProcessor,
+         AutoTokenizer,
+         GenerationConfig,
+         PretrainedConfig,
+         PreTrainedModel,
+     )
+
+
+ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
+     mandatory_members = ["main_input_name"]
+
+     def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
+         output = super().forward(*args, **kwargs)
+         return BaseModelOutput(last_hidden_state=output)
+
+
+ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
+     mandatory_members = ["main_input_name"]
+
+     def __init__(
+         self,
+         runtime: rebel.Runtime,
+         batch_size: int,
+         dec_max_seq_len: int,
+         use_attention_mask: Optional[bool] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(runtime, **kwargs)
+         self.batch_size = batch_size
+         self.dec_max_seq_len = dec_max_seq_len
+         self.use_attention_mask = use_attention_mask
+         self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
+
+     def forward(
+         self,
+         decoder_input_ids: torch.Tensor = None,
+         decoder_attention_mask: torch.Tensor = None,
+         cache_position: torch.Tensor = None,
+         block_tables: torch.Tensor = None,
+     ):
+         inputs_bsz = decoder_input_ids.shape[0]
+         padded_bsz = self.batch_size - inputs_bsz
+
+         if padded_bsz > 0:
+             decoder_input_ids = torch.nn.functional.pad(decoder_input_ids, (0, 0, 0, padded_bsz))
+
+         if self.use_attention_mask:
+             for b_idx in range(self.batch_size):
+                 decoding_step = cache_position[b_idx].item()
+                 if not (0 <= decoding_step < self.dec_max_seq_len):
+                     raise ValueError(
+                         f"Decoding step {decoding_step} out of bounds for attention mask of length {self.dec_max_seq_len}."
+                     )
+                 decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
+         if block_tables is None:
+             block_tables = self.default_block_tables
+
+         outputs = super().forward(
+             decoder_input_ids,
+             decoder_attention_mask if self.use_attention_mask else None,
+             cache_position,
+             block_tables=block_tables,
+         )
+
+         if isinstance(outputs, torch.Tensor):
+             return Seq2SeqLMOutput(logits=outputs[:inputs_bsz], cross_attentions=None)
+         else:
+             return Seq2SeqLMOutput(logits=outputs[0][:inputs_bsz], cross_attentions=outputs[1][:, :inputs_bsz])
+
+
+ class RBLNWhisperForConditionalGeneration(RBLNModel, RBLNWhisperGenerationMixin):
+     """
+     Whisper model for speech recognition and transcription optimized for RBLN NPU.
+
+     This model inherits from [`RBLNModel`]. It implements the methods to convert and run
+     pre-trained Transformers-based Whisper models on RBLN devices by:
+
+     - transferring the checkpoint weights of the original into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
+
+     Example (Short form):
+     ```python
+     import torch
+     from transformers import AutoProcessor
+     from datasets import load_dataset
+     from optimum.rbln import RBLNWhisperForConditionalGeneration
+
+     # Load processor and dataset
+     model_id = "openai/whisper-tiny"
+     processor = AutoProcessor.from_pretrained(model_id)
+     ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+     # Prepare input features
+     input_features = processor(
+         ds[0]["audio"]["array"],
+         sampling_rate=ds[0]["audio"]["sampling_rate"],
+         return_tensors="pt"
+     ).input_features
+
+     # Load and compile model (or load pre-compiled model)
+     model = RBLNWhisperForConditionalGeneration.from_pretrained(
+         model_id=model_id,
+         export=True,
+         rbln_batch_size=1
+     )
+
+     # Generate transcription
+     outputs = model.generate(input_features=input_features, return_timestamps=True)
+     transcription = processor.batch_decode(outputs, skip_special_tokens=True)[0]
+     print(f"Transcription: {transcription}")
+     ```
+     """
+
+     auto_model_class = AutoModelForSpeechSeq2Seq
+     main_input_name = "input_features"
+     _is_stateful = False
+
+     def __post_init__(self, **kwargs):
+         super().__post_init__(**kwargs)
+
+         self.batch_size = self.rbln_config.batch_size
+         self.dec_max_seq_len = self.rbln_config.dec_max_seq_len
+         self.rbln_token_timestamps = self.rbln_config.token_timestamps
+         self.use_attention_mask = self.rbln_config.use_attention_mask
+
+         self.encoder = RBLNRuntimeEncoder(runtime=self.model[0], main_input_name="input_features")
+         self.decoder = RBLNRuntimeDecoder(
+             runtime=self.model[1],
+             main_input_name="input_ids",
+             batch_size=self.batch_size,
+             dec_max_seq_len=self.dec_max_seq_len,
+             use_attention_mask=self.use_attention_mask,
+         )
+
+         # skip encoder & first decoder when language detected
+         self.is_language_detected = False
+         self.language_cross = None
+
+         # Used in GenerationMixin.generate()
+         # transformers/models/whisper/generation_whisper.py, line 505, in generate
+         #     input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0]
+         self.model = WhisperModel(self.config)
+         self.pad_token_id = self.config.pad_token_id
+
+     def can_generate(self):
+         return True
+
+     def get_encoder(self):
+         return self.encoder
+
+     def get_decoder(self):
+         return self.decoder
+
+     def __getattr__(self, __name: str) -> Any:
+         def redirect(func):
+             return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)
+
+         val = getattr(WhisperForConditionalGeneration, __name)
+         if isinstance(val, Callable) and "self" in set(inspect.signature(val).parameters):
+             return redirect(val)
+         return val
+
+     def _reorder_cache(self, past_key_values, beam_idx):
+         # TODO(jongho): implement
+         raise NotImplementedError
+
+     @classmethod
+     def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNWhisperForConditionalGenerationConfig):
+         return WhisperWrapper(
+             model,
+             use_attention_mask=rbln_config.use_attention_mask,
+             rbln_token_timestamps=rbln_config.token_timestamps,
+         )
+
+     @classmethod
+     @torch.inference_mode()
+     def get_compiled_model(cls, model, rbln_config: RBLNWhisperForConditionalGenerationConfig):
+         wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
+
+         enc_compile_config = rbln_config.compile_cfgs[0]
+         dec_compile_config = rbln_config.compile_cfgs[1]
+
+         context = CompileContext(use_weight_sharing=False)
+
+         enc_example_inputs = enc_compile_config.get_dummy_inputs(fill=0)
+
+         # Mark encoder's static tensors (cross kv states)
+         static_tensors = {}
+         for (name, _, _), tensor in zip(enc_compile_config.input_info, enc_example_inputs):
+             if "key_value_states" in name:
+                 static_tensors[name] = tensor
+                 context.mark_static_address(tensor)
+
+         dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+
+         # Mark decoder's static tensors (self kv states)
+         for (name, _, _), tensor in zip(dec_compile_config.input_info, dec_example_inputs):
+             if "key_value_states" in name:
+                 context.mark_static_address(tensor)
+
+         compiled_encoder = cls.compile(
+             wrapped_model.encoder,
+             enc_compile_config,
+             create_runtimes=rbln_config.create_runtimes,
+             device=rbln_config.device,
+             example_inputs=enc_example_inputs,
+             compile_context=context,
+         )
+         compiled_decoder = cls.compile(
+             wrapped_model.decoder,
+             dec_compile_config,
+             create_runtimes=rbln_config.create_runtimes,
+             device=rbln_config.device,
+             example_inputs=dec_example_inputs,
+             compile_context=context,
+         )
+
+         return {"encoder": compiled_encoder, "decoder": compiled_decoder}
+
+     @classmethod
+     def _update_paged_attention_config(
+         cls, model_config: "PretrainedConfig", rbln_config: RBLNWhisperForConditionalGenerationConfig
+     ):
+         rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
+         rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
+
+         if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
+             raise NotImplementedError(
+                 f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
+             )
+
+         if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
+             raise NotImplementedError(
+                 f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
+             )
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Optional[Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"]] = None,
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNWhisperForConditionalGenerationConfig] = None,
+     ) -> RBLNWhisperForConditionalGenerationConfig:
+         expected_seq_len = model_config.max_source_positions * 2
+         num_mel_bins = model_config.num_mel_bins
+         rbln_config.enc_max_seq_len = model_config.max_source_positions
+
+         # 'whisper-large-v3-turbo' doesn't have 'max_length', but PretrainedConfig has a default value for the key 'max_length'
+         rbln_config.dec_max_seq_len = getattr(model_config, "max_target_positions", None)
+         if rbln_config.dec_max_seq_len is None:
+             rbln_config.dec_max_seq_len = model_config.max_length
+
+         cls._update_paged_attention_config(model_config, rbln_config)
+
+         enc_input_info = [
+             ("input_features", [1, num_mel_bins, expected_seq_len], "float32"),
+             ("block_tables", [1], "int16"),
+             (
+                 "cross_key_value_states",
+                 [
+                     model_config.decoder_layers * 2,
+                     rbln_config.batch_size,
+                     model_config.decoder_attention_heads,
+                     rbln_config.enc_max_seq_len,
+                     model_config.d_model // model_config.decoder_attention_heads,
+                 ],
+                 "float32",
+             ),
+         ]
+
+         dec_input_info = [
+             ("decoder_input_ids", [rbln_config.batch_size, 1], "int64"),
+             ("cache_position", [rbln_config.batch_size, 1], "int32"),
+             ("block_tables", [rbln_config.batch_size, 1], "int16"),
+         ]
+         dec_input_info.extend(
+             [
+                 (
+                     "cross_key_value_states",
+                     [
+                         model_config.decoder_layers * 2,
+                         rbln_config.batch_size,
+                         model_config.decoder_attention_heads,
+                         rbln_config.enc_max_seq_len,
+                         model_config.d_model // model_config.decoder_attention_heads,
+                     ],
+                     "float32",
+                 )
+             ]
+         )
+         dec_input_info.extend(
+             [
+                 (
+                     f"self_key_value_states_{i}",
+                     [
+                         rbln_config.batch_size,
+                         model_config.decoder_attention_heads,
+                         rbln_config.dec_max_seq_len,
+                         model_config.d_model // model_config.encoder_attention_heads,
+                     ],
+                     "float32",
+                 )
+                 for i in range(model_config.decoder_layers * 2)
+             ]
+         )
+
+         if rbln_config.use_attention_mask:
+             dec_input_info.insert(
+                 1, ("decoder_attention_mask", [rbln_config.batch_size, rbln_config.dec_max_seq_len], "float32")
+             )
+
+         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
+
+         rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
+
+         return rbln_config
+
+     @classmethod
+     def _create_runtimes(
+         cls,
+         compiled_models: List[rebel.RBLNCompiledModel],
+         rbln_config: RBLNWhisperForConditionalGenerationConfig,
+     ) -> List[rebel.Runtime]:
+         if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
+             cls._raise_missing_compiled_file_error(["encoder", "decoder"])
+
+         return [
+             rebel.Runtime(
+                 compiled_models[0],
+                 tensor_type="pt",
+                 device=rbln_config.device_map["encoder"],
+                 activate_profiler=rbln_config.activate_profiler,
+                 timeout=rbln_config.timeout,
+             ),
+             rebel.Runtime(
+                 compiled_models[1],
+                 tensor_type="pt",
+                 device=rbln_config.device_map["decoder"],
+                 activate_profiler=rbln_config.activate_profiler,
+                 timeout=rbln_config.timeout,
+             ),
+         ]
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         cache_position: Optional[torch.Tensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,  # needed to support transformers>=4.45.0
+         **kwargs,
+     ):
+         return {
+             "input_ids": input_ids,
+             "cache_position": cache_position,
+         }
+
+     # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/generation/utils.py#L512
+     def _prepare_encoder_decoder_kwargs_for_generation(
+         self,
+         inputs_tensor: torch.Tensor,
+         model_kwargs,
+         model_input_name: Optional[str] = None,
+         generation_config: Optional["GenerationConfig"] = None,
+         **kwargs,
+     ) -> Dict[str, Any]:
+         batch_size = inputs_tensor.shape[0]
+         n_pad_to_batch = self.batch_size - batch_size
+         if n_pad_to_batch > 0:
+             inputs_tensor = torch.nn.functional.pad(inputs_tensor, (0, 0, 0, 0, 0, n_pad_to_batch))
+
+         if not self.is_language_detected:
+             for b in range(inputs_tensor.shape[0]):
+                 block_tables = torch.tensor([b], dtype=torch.int16)
+                 model_kwargs["encoder_outputs"] = self.encoder(
+                     input_features=inputs_tensor[b].unsqueeze(0), block_tables=block_tables
+                 )
+             self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
+         else:
+             model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=torch.tensor([[-1.0]]))
+
+         return model_kwargs
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         input_features: Optional[torch.Tensor] = None,
+         decoder_input_ids: Optional[torch.Tensor] = None,
+         encoder_outputs: Optional[Seq2SeqLMOutput] = None,
+         **kwargs,
+     ) -> Seq2SeqLMOutput:
+         # default decoder pass
+         if input_features is None and encoder_outputs is None:
+             cross_attentions = []
+             for step in cache_position:
+                 # skip step 0 if language detection has already been processed
+                 if step == 0 and self.is_language_detected:
+                     cross_attentions.append(self.language_cross)
+                     self.is_language_detected = False
+                 else:
+                     self.decoder_attention_mask[:, step] = 1
+                     decoder_output = self.decoder(
+                         decoder_input_ids=input_ids[:, step : step + 1].contiguous(),
+                         decoder_attention_mask=self.decoder_attention_mask,
+                         cache_position=torch.full((self.batch_size, 1), step, dtype=torch.int32),
+                     )
+                     cross_attentions.append(decoder_output.cross_attentions)
+                     lm_logits = decoder_output.logits
+
+             if self.rbln_token_timestamps:
+                 cross_attentions = torch.cat(cross_attentions, dim=-2)
+             else:
+                 cross_attentions = None
+
+             return Seq2SeqLMOutput(logits=lm_logits, cross_attentions=cross_attentions)
+
+         # detect language pass
+         # https://github.com/huggingface/transformers/blob/174890280b340b89c5bfa092f6b4fb0e2dc2d7fc/src/transformers/models/whisper/generation_whisper.py#L1442
+         else:
+             # for language auto detection (generate with language=None)
+             if encoder_outputs is None:
+                 for b in range(input_features.shape[0]):
+                     block_tables = torch.tensor([b], dtype=torch.int16)
+                     self.encoder(input_features=input_features[b].unsqueeze(0), block_tables=block_tables)
+
+             self.decoder_attention_mask = torch.zeros(self.batch_size, self.dec_max_seq_len, dtype=torch.float32)
+             self.is_language_detected = True
+             self.decoder_attention_mask[:, 0] = 1
+             decoder_output = self.decoder(
+                 decoder_input_ids=decoder_input_ids.contiguous(),
+                 decoder_attention_mask=self.decoder_attention_mask,
+                 cache_position=torch.zeros([self.rbln_config.batch_size, 1], dtype=torch.int32),
+             )
+             lm_logits = decoder_output.logits
+             self.language_cross = decoder_output.cross_attentions
+             return Seq2SeqLMOutput(logits=lm_logits)
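
Because get_compiled_model produces separate encoder and decoder artifacts and _create_runtimes expects both in the device map, a model can be compiled once and reloaded without recompiling. A hedged sketch of that flow, following the "or load pre-compiled model" note in the class docstring (the local directory name is hypothetical):

```python
# Hedged sketch: compile once, save the encoder/decoder artifacts, reload later.
# "whisper-tiny-rbln/" is a hypothetical local path.
from optimum.rbln import RBLNWhisperForConditionalGeneration

model = RBLNWhisperForConditionalGeneration.from_pretrained(
    model_id="openai/whisper-tiny",
    export=True,        # trace and compile both the encoder and decoder graphs
    rbln_batch_size=1,
)
model.save_pretrained("whisper-tiny-rbln")

# Reload the pre-compiled runtimes; no export step this time.
model = RBLNWhisperForConditionalGeneration.from_pretrained(model_id="whisper-tiny-rbln")
```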