optimum_rbln-0.9.3.post1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln has been flagged as possibly problematic.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py
@@ -0,0 +1,477 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ #     http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import inspect
+ from abc import ABC
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+
+ import rebel
+ import torch
+ from rebel.compile_context import CompileContext
+ from transformers import AutoModelForSeq2SeqLM, PretrainedConfig, PreTrainedModel
+ from transformers.generation.configuration_utils import GenerationConfig
+ from transformers.generation.utils import GenerationMixin
+ from transformers.modeling_outputs import BaseModelOutput, ModelOutput, Seq2SeqLMOutput
+
+ from ....configuration_utils import RBLNCompileConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from ....utils.runtime_utils import RBLNPytorchRuntime
+ from .configuration_seq2seq import RBLNModelForSeq2SeqLMConfig
+
+
+ logger = get_logger(__name__)
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig
+
+
+ class RBLNRuntimeEncoder(RBLNPytorchRuntime):
+     mandatory_members = ["main_input_name"]
+
+     def forward(self, *args: List[torch.Tensor], **kwargs: torch.Tensor):
+         output = super().forward(*args, **kwargs)
+         return BaseModelOutput(last_hidden_state=output)
+
+
+ class RBLNRuntimeDecoder(RBLNPytorchRuntime):
+     mandatory_members = ["main_input_name"]
+
+     def __init__(
+         self,
+         runtime: rebel.Runtime,
+         batch_size: int,
+         dec_max_seq_len: int,
+         use_attention_mask: Optional[bool] = None,
+         **kwargs: Any,
+     ) -> None:
+         super().__init__(runtime, **kwargs)
+         self.batch_size = batch_size
+         self.dec_max_seq_len = dec_max_seq_len
+         self.use_attention_mask = use_attention_mask
+         self.default_block_tables = torch.arange(0, self.batch_size, dtype=torch.int16).view(self.batch_size, 1)
+
+     def forward(
+         self,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         decoder_attention_mask: Optional[torch.BoolTensor] = None,
+         cache_position: Optional[torch.Tensor] = None,
+         block_tables: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor]:
+         batch_size = decoder_input_ids.shape[0]
+         if batch_size != self.batch_size:
+             raise RuntimeError(
+                 f"Batch size mismatch: got {batch_size}, expected {self.batch_size} (compiled batch size)."
+             )
+
+         if batch_size != cache_position.shape[0]:
+             raise RuntimeError(f"Cache position size mismatch: got {cache_position.shape[0]}, expected {batch_size}.")
+
+         if self.use_attention_mask:
+             for b_idx in range(self.batch_size):
+                 decoding_step = cache_position[b_idx].item()
+                 if not (0 <= decoding_step < self.dec_max_seq_len):
+                     raise ValueError(
+                         f"Decoding step {decoding_step} out of bounds for decoder_max_seq_len ({self.dec_max_seq_len})."
+                     )
+                 decoder_attention_mask[b_idx, : decoding_step + 1] = 1
+
+         if block_tables is None:
+             block_tables = self.default_block_tables
+
+         lm_logits = super().forward(
+             decoder_input_ids,
+             decoder_attention_mask if self.use_attention_mask else None,
+             attention_mask,
+             cache_position,
+             block_tables=block_tables,
+         )
+
+         return Seq2SeqLMOutput(logits=lm_logits)
+
+
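To make the decoder wrapper's bookkeeping concrete, here is a small pure-PyTorch illustration (not part of the package; the sizes are made up) of the default block table built in __init__ and the step-by-step attention-mask fill performed in forward:

import torch

batch_size, dec_max_seq_len = 2, 8

# One KV-cache block per batch element, shaped [batch_size, 1],
# mirroring RBLNRuntimeDecoder.__init__ above.
default_block_tables = torch.arange(0, batch_size, dtype=torch.int16).view(batch_size, 1)
print(default_block_tables.tolist())  # [[0], [1]]

# In forward, positions 0..cache_position[b] are unmasked for each sequence b,
# so the mask grows by one slot per decoding step.
cache_position = torch.tensor([[3], [5]], dtype=torch.int32)
decoder_attention_mask = torch.zeros(batch_size, dec_max_seq_len, dtype=torch.float32)
for b_idx in range(batch_size):
    decoder_attention_mask[b_idx, : cache_position[b_idx].item() + 1] = 1
print(decoder_attention_mask[0].tolist())  # [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
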
+ class RBLNModelForSeq2SeqLM(RBLNModel, GenerationMixin, ABC):
+     """
+     This is a generic model class that will be instantiated as one of the model classes of the library (with a sequence-to-sequence language modeling head) when created with the from_pretrained() class method.
+     This model inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods the library implements for all its models.
+
+     A class to convert and run pre-trained Transformers-based Seq2SeqLM models on RBLN devices.
+     It implements the methods to convert a pre-trained Transformers Seq2SeqLM model into an RBLN transformer model by:
+     - transferring the checkpoint weights of the original into an optimized RBLN graph,
+     - compiling the resulting graph using the RBLN compiler.
+
+     Currently, this model class only supports the 'bart' and 't5' models from the transformers library. Future updates may include support for additional model types.
+     """
+
+     main_input_name = "input_ids"
+     auto_model_class = AutoModelForSeq2SeqLM
+     support_causal_attn = None
+     _is_stateful = False
+
+     def __post_init__(self, **kwargs):
+         batch_size = self.rbln_config.batch_size
+         dec_max_seq_len = self.rbln_config.dec_max_seq_len
+         self.use_attention_mask = self.rbln_config.use_attention_mask
+
+         self.encoder = RBLNRuntimeEncoder(
+             runtime=self.model[0],
+             main_input_name="input_ids",
+         )
+         self.decoder = RBLNRuntimeDecoder(
+             runtime=self.model[1],
+             main_input_name="input_ids",
+             batch_size=batch_size,
+             dec_max_seq_len=dec_max_seq_len,
+             use_attention_mask=self.use_attention_mask,
+         )
+
+     @classmethod
+     @torch.inference_mode()
+     def get_compiled_model(cls, model: PreTrainedModel, rbln_config: RBLNModelForSeq2SeqLMConfig):
+         wrapped_model = cls._wrap_model_if_needed(model, rbln_config)
+
+         enc_compile_config = rbln_config.compile_cfgs[0]
+         dec_compile_config = rbln_config.compile_cfgs[1]
+
+         context = CompileContext(use_weight_sharing=False)
+
+         enc_example_inputs = enc_compile_config.get_dummy_inputs(fill=0)
+
+         # Mark encoder's static tensors (cross kv states)
+         static_tensors = {}
+         for (name, _, _), tensor in zip(enc_compile_config.input_info, enc_example_inputs):
+             if "key_value_states" in name:
+                 static_tensors[name] = tensor
+                 context.mark_static_address(tensor)
+
+         dec_example_inputs = dec_compile_config.get_dummy_inputs(fill=0, static_tensors=static_tensors)
+
+         # Mark decoder's static tensors (self kv states)
+         for (name, _, _), tensor in zip(dec_compile_config.input_info, dec_example_inputs):
+             if "key_value_states" in name:
+                 context.mark_static_address(tensor)
+
+         compiled_encoder = cls.compile(
+             wrapped_model.encoder,
+             enc_compile_config,
+             create_runtimes=rbln_config.create_runtimes,
+             device=rbln_config.device,
+             example_inputs=enc_example_inputs,
+             compile_context=context,
+         )
+
+         compiled_decoder = cls.compile(
+             wrapped_model.decoder,
+             dec_compile_config,
+             create_runtimes=rbln_config.create_runtimes,
+             device=rbln_config.device,
+             example_inputs=dec_example_inputs,
+             compile_context=context,
+         )
+
+         return {"encoder": compiled_encoder, "decoder": compiled_decoder}
+
+     @classmethod
+     def _update_paged_attention_config(cls, model_config: PretrainedConfig, rbln_config: RBLNModelForSeq2SeqLMConfig):
+         rbln_config.kvcache_num_blocks = rbln_config.kvcache_num_blocks or rbln_config.batch_size
+         rbln_config.kvcache_block_size = rbln_config.kvcache_block_size or rbln_config.dec_max_seq_len
+
+         if rbln_config.kvcache_num_blocks != rbln_config.batch_size:
+             raise NotImplementedError(
+                 f"kvcache_num_blocks ({rbln_config.kvcache_num_blocks}) must be equal to batch_size ({rbln_config.batch_size}) as flash attention is not supported yet."
+             )
+
+         if rbln_config.kvcache_block_size != rbln_config.dec_max_seq_len:
+             raise NotImplementedError(
+                 f"kvcache_block_size ({rbln_config.kvcache_block_size}) must be equal to dec_max_seq_len ({rbln_config.dec_max_seq_len}) as flash attention is not supported yet."
+             )
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model: Optional["PreTrainedModel"] = None,
+         model_config: Optional["PretrainedConfig"] = None,
+         rbln_config: Optional[RBLNModelForSeq2SeqLMConfig] = None,
+     ) -> RBLNModelForSeq2SeqLMConfig:
+         if not cls.support_causal_attn:
+             rbln_config.use_attention_mask = True
+
+         n_layer = getattr(model_config, "decoder_layers", None) or getattr(model_config, "num_layers")
+         n_head = getattr(model_config, "decoder_attention_heads", None) or getattr(model_config, "num_heads")
+         d_kv = (
+             model_config.d_kv
+             if hasattr(model_config, "d_kv")
+             else model_config.d_model // model_config.encoder_attention_heads
+         )
+
+         max_position_embeddings = getattr(model_config, "n_positions", None) or getattr(
+             model_config, "max_position_embeddings", None
+         )
+
+         if rbln_config.enc_max_seq_len is None:
+             enc_max_seq_len = max_position_embeddings
+             for tokenizer in preprocessors:
+                 if hasattr(tokenizer, "model_max_length"):
+                     enc_max_seq_len = enc_max_seq_len or tokenizer.model_max_length
+                     break
+
+             if enc_max_seq_len is None:
+                 raise ValueError("`enc_max_seq_len` should be specified!")
+             rbln_config.enc_max_seq_len = enc_max_seq_len
+
+         if max_position_embeddings is not None and rbln_config.enc_max_seq_len > max_position_embeddings:
+             raise ValueError("`enc_max_seq_len` should be less than or equal to max_position_embeddings!")
+
+         if rbln_config.dec_max_seq_len is None:
+             dec_max_seq_len = max_position_embeddings
+             for tokenizer in preprocessors:
+                 if hasattr(tokenizer, "model_max_length"):
+                     dec_max_seq_len = dec_max_seq_len or tokenizer.model_max_length
+                     break
+
+             if dec_max_seq_len is None:
+                 raise ValueError("`dec_max_seq_len` should be specified!")
+             rbln_config.dec_max_seq_len = dec_max_seq_len
+
+         if max_position_embeddings is not None and rbln_config.dec_max_seq_len > max_position_embeddings:
+             raise ValueError("`dec_max_seq_len` should be less than or equal to max_position_embeddings!")
+
+         if rbln_config.support_paged_attention:
+             cls._update_paged_attention_config(model_config, rbln_config)
+
+         # model input info
+         enc_input_info = [
+             ("input_ids", [1, rbln_config.enc_max_seq_len], "int64"),
+             ("attention_mask", [1, rbln_config.enc_max_seq_len], "float32"),
+             ("block_tables", [1], "int16"),
+         ]
+         enc_input_info.extend(
+             [
+                 (
+                     f"cross_key_value_states_{i}",
+                     [
+                         rbln_config.batch_size,
+                         n_head,
+                         rbln_config.enc_max_seq_len,
+                         d_kv,
+                     ],
+                     "float32",
+                 )
+                 for i in range(n_layer * 2)
+             ]
+         )
+
+         dec_input_info = [
+             ("input_ids", [rbln_config.batch_size, 1], "int64"),
+             ("encoder_attention_mask", [rbln_config.batch_size, rbln_config.enc_max_seq_len], "float32"),
+             (
+                 "cache_position",
+                 [rbln_config.batch_size, 1],
+                 "int32",
+             ),
+             ("block_tables", [rbln_config.batch_size, 1], "int16"),
+         ]
+         dec_input_info.extend(
+             [
+                 (
+                     f"cross_key_value_states_{i}",
+                     [
+                         rbln_config.batch_size,
+                         n_head,
+                         rbln_config.enc_max_seq_len,
+                         d_kv,
+                     ],
+                     "float32",
+                 )
+                 for i in range(n_layer * 2)
+             ]
+         )
+         dec_input_info.extend(
+             [
+                 (
+                     f"self_key_value_states_{i}",
+                     [
+                         rbln_config.batch_size,
+                         n_head,
+                         rbln_config.dec_max_seq_len,
+                         d_kv,
+                     ],
+                     "float32",
+                 )
+                 for i in range(n_layer * 2)
+             ]
+         )
+
+         if rbln_config.use_attention_mask:
+             dec_input_info.insert(
+                 1, ("attention_mask", [rbln_config.batch_size, rbln_config.dec_max_seq_len], "float32")
+             )
+
+         enc_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+         dec_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
+
+         rbln_config.set_compile_cfgs([enc_compile_config, dec_compile_config])
+
+         return rbln_config
+
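As a concrete example of the cache layout this method declares: with hypothetical t5-small-like values (6 decoder layers, 8 heads, d_kv=64; these numbers are illustrative, not read from a real checkpoint), the decoder receives 12 cross-attention and 12 self-attention KV-cache inputs, one key and one value tensor per layer:

# Hypothetical shapes only; mirrors the list comprehensions in _update_rbln_config.
batch_size, n_layer, n_head, d_kv = 1, 6, 8, 64
enc_max_seq_len = dec_max_seq_len = 512

self_kv = [
    (f"self_key_value_states_{i}", [batch_size, n_head, dec_max_seq_len, d_kv], "float32")
    for i in range(n_layer * 2)  # 2 tensors (key, value) per decoder layer
]
print(len(self_kv))  # 12
print(self_kv[0])    # ('self_key_value_states_0', [1, 8, 512, 64], 'float32')
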
+     @classmethod
+     def _create_runtimes(
+         cls,
+         compiled_models: List[rebel.RBLNCompiledModel],
+         rbln_config: RBLNModelForSeq2SeqLMConfig,
+     ) -> List[rebel.Runtime]:
+         if any(model_name not in rbln_config.device_map for model_name in ["encoder", "decoder"]):
+             cls._raise_missing_compiled_file_error(["encoder", "decoder"])
+
+         return [
+             rebel.Runtime(
+                 compiled_models[0],
+                 tensor_type="pt",
+                 device=rbln_config.device_map["encoder"],
+                 activate_profiler=rbln_config.activate_profiler,
+                 timeout=rbln_config.timeout,
+             ),
+             rebel.Runtime(
+                 compiled_models[1],
+                 tensor_type="pt",
+                 device=rbln_config.device_map["decoder"],
+                 activate_profiler=rbln_config.activate_profiler,
+                 timeout=rbln_config.timeout,
+             ),
+         ]
+
+     def can_generate(self):
+         return True
+
+     def get_encoder(self):
+         return self.encoder
+
+     def get_decoder(self):
+         return self.decoder
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         attention_mask=None,
+         decoder_attention_mask=None,
+         **kwargs,
+     ):
+         cur_seq_len = input_ids.shape[-1]
+         cache_position = cur_seq_len - 1
+         max_seq_len = self.rbln_config.dec_max_seq_len
+         decoder_batch_size = input_ids.shape[0]
+         input_ids = input_ids[:, cur_seq_len - 1 : cur_seq_len].contiguous()
+         decoder_attention_mask = torch.zeros(decoder_batch_size, max_seq_len, dtype=torch.float32)
+         decoder_attention_mask[:, :cur_seq_len] = 1
+
+         return {
+             "decoder_input_ids": input_ids,
+             "attention_mask": attention_mask.to(torch.float32),
+             "decoder_attention_mask": decoder_attention_mask,
+             "cache_position": cache_position,
+         }
+
+     def forward(
+         self,
+         decoder_input_ids: torch.LongTensor = None,
+         cache_position: Union[List[torch.Tensor], torch.Tensor] = None,
+         **kwargs,
+     ) -> Tuple[torch.FloatTensor]:
+         # common decoder
+         cache_position = torch.full((self.rbln_config.batch_size, 1), cache_position, dtype=torch.int32)
+         logits = self.decoder(decoder_input_ids=decoder_input_ids, cache_position=cache_position, **kwargs).logits
+
+         return Seq2SeqLMOutput(
+             logits=logits,
+         )
+
+     def _prepare_encoder_decoder_kwargs_for_generation(
+         self,
+         inputs_tensor: torch.Tensor,
+         model_kwargs,
+         model_input_name: Optional[str] = None,
+         generation_config: Optional["GenerationConfig"] = None,
+     ) -> Dict[str, Any]:
+         # 1. get encoder
+         encoder = self.get_encoder()
+
+         # 2. Prepare encoder args and encoder kwargs from model kwargs.
+         irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
+         encoder_kwargs = {
+             argument: value
+             for argument, value in model_kwargs.items()
+             if not any(argument.startswith(p) for p in irrelevant_prefix)
+         }
+         encoder_signature = set(inspect.signature(encoder.forward).parameters)
+         encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
+         if not encoder_accepts_wildcard:
+             encoder_kwargs = {
+                 argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
+             }
+
+         batch_size, input_len = inputs_tensor.shape
+         inputs_tensor = torch.nn.functional.pad(
+             inputs_tensor,
+             (0, self.rbln_config.enc_max_seq_len - input_len),
+             value=self.config.pad_token_id,
+         )
+         model_kwargs["attention_mask"] = torch.nn.functional.pad(
+             model_kwargs["attention_mask"], (0, self.rbln_config.enc_max_seq_len - input_len)
+         )
+
+         # 3. make sure that encoder returns `ModelOutput`
+         model_input_name = model_input_name if model_input_name is not None else self.main_input_name
+         encoder_kwargs["return_dict"] = True
+         encoder_kwargs["output_hidden_states"] = False
+         encoder_kwargs["output_attentions"] = False
+
+         for b in range(batch_size):
+             block_tables = torch.tensor([b], dtype=torch.int16)
+             encoder_kwargs["input_ids"] = inputs_tensor[b].unsqueeze(0)
+             encoder_kwargs["attention_mask"] = model_kwargs["attention_mask"][b].unsqueeze(0).to(torch.float32)
+             model_kwargs["encoder_outputs"] = encoder(**encoder_kwargs, block_tables=block_tables)
+
+         return model_kwargs
+
+     def generate(
+         self,
+         input_ids: torch.LongTensor,
+         attention_mask: Optional[torch.LongTensor] = None,
+         generation_config: Optional[GenerationConfig] = None,
+         **kwargs,
+     ) -> Union[ModelOutput, torch.LongTensor]:
+         """
+         The generate function is used in its standard form, as in the HuggingFace transformers library. You can use this function to generate text from the model.
+         Check the [HuggingFace transformers documentation](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationMixin.generate) for more details.
+
+         Args:
+             input_ids (torch.LongTensor): The input ids to the model.
+             attention_mask (torch.LongTensor, optional): The attention mask to the model.
+             generation_config (GenerationConfig, optional): The generation configuration to be used as the base parametrization for the generation call. **kwargs passed to generate that match the attributes of generation_config will override them.
+                 If generation_config is not provided, the default will be used, which has the following loading priority: 1) from the generation_config.json model file, if it exists; 2) from the model configuration.
+                 Please note that unspecified parameters will inherit [GenerationConfig](https://huggingface.co/docs/transformers/v4.57.1/en/main_classes/text_generation#transformers.GenerationConfig)'s default values.
+             kwargs (dict[str, Any], optional): Additional arguments passed to the generate function. See the HuggingFace transformers documentation for more details.
+
+         Returns:
+             The generated sequences of token ids for models with a language modeling head.
+         """
+         if generation_config is not None:
+             kwargs["generation_config"] = generation_config
+         if attention_mask is not None:
+             kwargs["attention_mask"] = attention_mask
+
+         return super().generate(input_ids, **kwargs)
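Taken together, compile-time shape inference (_update_rbln_config), ahead-of-time compilation (get_compiled_model), and the padded per-sequence encoder pass (_prepare_encoder_decoder_kwargs_for_generation) let a compiled model drive the stock HuggingFace generation loop. A minimal usage sketch follows, assuming the usual optimum-rbln from_pretrained convention of export=True for on-the-fly compilation and a rbln_config dict whose keys mirror RBLNModelForSeq2SeqLMConfig (batch_size, enc_max_seq_len, dec_max_seq_len); consult the package documentation for the exact signature:

from transformers import AutoTokenizer
from optimum.rbln import RBLNAutoModelForSeq2SeqLM

model_id = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Compile the encoder and decoder for fixed shapes on an RBLN device.
model = RBLNAutoModelForSeq2SeqLM.from_pretrained(
    model_id,
    export=True,
    rbln_config={"batch_size": 1, "enc_max_seq_len": 512, "dec_max_seq_len": 512},
)

inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")
outputs = model.generate(inputs.input_ids, attention_mask=inputs.attention_mask, max_length=64)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Note that inputs are padded to enc_max_seq_len and the decoder runs at the compiled batch size, so these values bound the sequence lengths and batch shape available at inference time.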