optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. Click here for more details.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,130 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import TYPE_CHECKING, Any, Callable
17
+
18
+ import torch
19
+ from transformers import AutoModelForTextEncoding, T5EncoderModel, T5ForConditionalGeneration
20
+ from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
21
+
22
+ from ...modeling_generic import RBLNTransformerEncoderForFeatureExtraction
23
+ from ...models.seq2seq import RBLNModelForSeq2SeqLM
24
+ from .configuration_t5 import RBLNT5EncoderModelConfig, RBLNT5ForConditionalGenerationConfig
25
+ from .t5_architecture import T5Wrapper
26
+
27
+
28
+ if TYPE_CHECKING:
29
+ from transformers import PreTrainedModel
30
+
31
+ from ....diffusers.modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
32
+
33
+
34
class T5EncoderWrapper(torch.nn.Module):
    """Thin wrapper that forces tuple (non-dict) outputs from a T5 encoder.

    Any caller-supplied ``return_dict`` is discarded and the wrapped model is
    always invoked with ``return_dict=False`` so the traced graph returns
    plain tensors.
    """

    def __init__(self, model: "T5EncoderModel") -> None:
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        # Drop a caller-provided flag so it cannot conflict with the explicit
        # return_dict=False passed below.
        if "return_dict" in kwargs:
            del kwargs["return_dict"]
        return self.model(*args, **kwargs, return_dict=False)
42
+
43
+
44
class RBLNT5EncoderModel(RBLNTransformerEncoderForFeatureExtraction):
    """
    The T5 Model transformer with an encoder-only architecture for feature extraction.
    This model inherits from [`RBLNTransformerEncoderForFeatureExtraction`]. Check the superclass documentation for the generic methods the library implements for all its models.

    Important Note:
        This model supports various sizes of the T5EncoderModel. For optimal performance, it is highly recommended to adjust the tensor parallelism setting
        based on the model size. Please refer to the [Optimum RBLN Overview](../../../optimum_rbln.md) for guidance on choosing the appropriate tensor parallelism size for your model.

    Examples:
        ```python
        from optimum.rbln import RBLNT5EncoderModel

        model = RBLNT5EncoderModel.from_pretrained(
            "sentence-transformers/sentence-t5-xxl",
            export=True,
            rbln_tensor_parallel_size=4,
        )

        model.save_pretrained("compiled-sentence-t5-xxl")
        ```
    """

    auto_model_class = AutoModelForTextEncoding
    output_class = BaseModelOutputWithPastAndCrossAttentions

    @classmethod
    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNT5EncoderModelConfig):
        # Always wrap so the compiled graph returns tuples instead of ModelOutput dicts.
        return T5EncoderWrapper(model)

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        # The T5 text encoder needs no pipeline-specific config adjustments.
        return rbln_config

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Run the compiled encoder.

        Args:
            input_ids: Token ids; required even though the signature defaults to
                ``None`` (kept for transformers API compatibility).
            attention_mask: Optional padding mask; cast to int64 when provided.

        Raises:
            ValueError: If ``input_ids`` is not provided.
        """
        if input_ids is None:
            raise ValueError("`input_ids` must be provided to RBLNT5EncoderModel.forward.")

        # The compiled model expects int64 (long) inputs.
        input_dict = {"input_ids": input_ids.long()}
        if attention_mask is not None:
            input_dict["attention_mask"] = attention_mask.long()

        output = super().forward(**input_dict, **kwargs)
        return output
87
+
88
+
89
class RBLNT5ForConditionalGeneration(RBLNModelForSeq2SeqLM):
    """
    The T5 Model transformer with a language modeling head for conditional generation.
    This model inherits from [`RBLNModelForSeq2SeqLM`]. Check the superclass documentation for the generic methods the library implements for all its models.

    Important Note:
        This model supports various sizes of the T5ForConditionalGeneration. For optimal performance, it is highly recommended to adjust the tensor parallelism setting
        based on the model size. Please refer to the [Optimum RBLN Overview](../../../optimum_rbln.md) for guidance on choosing the appropriate tensor parallelism size for your model.


    Examples:
        ```python
        from optimum.rbln import RBLNT5ForConditionalGeneration

        model = RBLNT5ForConditionalGeneration.from_pretrained(
            "google-t5/t5-11b",
            export=True,
            rbln_tensor_parallel_size=4,
        )

        model.save_pretrained("compiled-t5-11b")
        ```
    """

    # T5 decoder attention uses a precomputed position bias instead of a
    # causal-attention kernel path.
    support_causal_attn = False

    @classmethod
    def _wrap_model_if_needed(cls, model: "PreTrainedModel", rbln_config: RBLNT5ForConditionalGenerationConfig):
        # Wrap encoder and decoder with the RBLN-compilable T5 architecture.
        return T5Wrapper(
            model, enc_max_seq_len=rbln_config.enc_max_seq_len, dec_max_seq_len=rbln_config.dec_max_seq_len
        )

    def __getattr__(self, __name: str) -> Any:
        """Delegate attributes missing on this wrapper to ``T5ForConditionalGeneration``.

        Instance methods of the HF class are rebound so that they receive this
        wrapper as ``self``; plain attributes are returned as-is. Raises
        ``AttributeError`` (from ``getattr``) for unknown names.
        """

        def redirect(func):
            return lambda *pargs, **kwargs: func(self, *pargs, **kwargs)

        val = getattr(T5ForConditionalGeneration, __name)

        # Rebind unbound instance methods (those taking an explicit `self`).
        if callable(val) and "self" in set(inspect.signature(val).parameters):
            return redirect(val)

        return val
@@ -0,0 +1,264 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Tuple
16
+
17
+ import torch
18
+ from torch import nn
19
+ from transformers.utils import logging
20
+
21
+ from ..seq2seq.seq2seq_architecture import (
22
+ Seq2SeqDecoder,
23
+ Seq2SeqDecoderLayer,
24
+ Seq2SeqDecoderWrapper,
25
+ Seq2SeqEncoderWrapper,
26
+ Seq2SeqForConditionalGeneration,
27
+ Seq2SeqSelfAttention,
28
+ )
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
class T5Wrapper:
    """Bundles the RBLN-compilable encoder and decoder wrappers for one T5 model."""

    def __init__(self, model: nn.Module, enc_max_seq_len: int, dec_max_seq_len: int = None):
        # Both wrappers share the same underlying HF model instance.
        self.encoder = T5EncoderWrapper(model, enc_max_seq_len)
        self.decoder = T5DecoderWrapper(model, dec_max_seq_len=dec_max_seq_len)
38
+
39
+
40
class T5EncoderWrapper(Seq2SeqEncoderWrapper):
    """Encoder wrapper that also extracts the decoder's cross-attention K/V projections."""

    def __post_init__(self, model: nn.Module):
        # Direct attribute access; `getattr` with no default added nothing.
        self.n_layer = self.config.num_layers
        self.cross_k_projects, self.cross_v_projects = self._extract_cross_kv_projects(model.get_decoder().block)
        self.num_heads = self.config.num_heads
        self.d_kv = self.config.d_kv

    def _extract_cross_kv_projects(self, t5_block: nn.Module):
        """Collect per-layer cross-attention key/value projections.

        T5 keeps cross-attention in ``layer[1]`` of each decoder block
        (different from BART's layout).
        """
        return (
            nn.ModuleList(t5_block[i].layer[1].EncDecAttention.k for i in range(self.n_layer)),
            nn.ModuleList(t5_block[i].layer[1].EncDecAttention.v for i in range(self.n_layer)),
        )
+ )
53
+
54
+
55
class T5DecoderWrapper(Seq2SeqDecoderWrapper):
    """Decoder wrapper that rebuilds HF T5 decoder blocks into RBLN-compilable ones."""

    def __post_init__(self, model, dec_max_seq_len: int = None):
        self.num_layers = self.config.num_layers
        self.conditional_generation = self.convert_to_rbln_conditional_generation(model, dec_max_seq_len)

    def convert_to_rbln_conditional_generation(self, model: nn.Module, dec_max_seq_len: int):
        # Rewrap every HF decoder block with the RBLN self-attention variant.
        rebuilt_blocks = [
            T5Block(hf_block, T5LayerSelfAttention(hf_block.layer[0].SelfAttention))
            for hf_block in model.get_decoder().block
        ]

        decoder_model = T5Decoder(model.get_decoder(), rebuilt_blocks, dec_max_seq_len=dec_max_seq_len)
        return T5ForConditionalGeneration(model, decoder_model)

    def forward(
        self,
        input_ids,
        attention_mask,
        encoder_attention_mask,
        cache_position,
        block_tables,
        *kv_cache,
    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor]]:
        """Run one decode step.

        ``kv_cache`` is a flat sequence: the first ``2 * num_layers`` tensors
        are the cross-attention K/V caches, the remainder the self-attention
        K/V caches, interleaved as (key, value) per layer.
        """
        split = self.num_layers * 2
        cross_kv_cache = kv_cache[:split]
        self_kv_cache = kv_cache[split:]

        self_past_key_values = tuple(
            (self_kv_cache[i], self_kv_cache[i + 1]) for i in range(0, split, 2)
        )
        cross_past_key_values = tuple(
            (cross_kv_cache[i], cross_kv_cache[i + 1]) for i in range(0, split, 2)
        )

        # Decode a single step and return the language-model logits.
        return self.conditional_generation(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            self_past_key_values=self_past_key_values,
            cross_past_key_values=cross_past_key_values,
            cache_position=cache_position,
            block_tables=block_tables,
        )
102
+
103
+
104
class T5ForConditionalGeneration(Seq2SeqForConditionalGeneration):
    """Conditional-generation head with T5's output rescaling enabled."""

    has_rescaling = True

    def __post_init__(self):
        # T5 divides decoder output by sqrt(d_model) before the LM head.
        d_model = self.config.d_model
        self.scaling = d_model**-0.5
109
+
110
+
111
class T5Decoder(Seq2SeqDecoder):
    """RBLN-side T5 decoder.

    T5 has no absolute positional embeddings; a relative position bias is
    precomputed up to ``dec_max_seq_len`` once and a single row of it is
    sliced per decoding step in ``prepare_attn_mask``.
    """

    # T5 uses relative position bias instead of positional embeddings.
    has_pos_emb = False

    def __post_init__(self, dec_max_seq_len: int = None):
        # Reuse HF's helper that turns a 0/1 padding mask into an additive mask.
        self.invert_attention_mask = self._original_mod.invert_attention_mask
        # Precompute the full (dec_max, dec_max) relative bias at wrap time.
        self._dec_position_bias = self.precompute_dec_position_bias(self._original_mod, dec_max_seq_len)

    def precompute_dec_position_bias(self, model, dec_max_length):
        """Compute the relative position bias table from the first block's
        self-attention (in T5 only the first layer holds the bias weights)."""
        attn_layer = model.block[0].layer[0].SelfAttention
        return attn_layer.compute_bias(dec_max_length, dec_max_length)

    def prepare_attn_mask(self, attention_mask, encoder_attention_mask, cache_position):
        """Build the additive decoder mask (with position bias folded in) and
        the inverted encoder mask for a single-token decode step.

        ``cache_position`` is indexed as ``cache_position[i][0]`` — assumes
        shape (batch, 1) with one position per sequence; TODO confirm at callers.
        """
        attention_mask = self.invert_attention_mask(attention_mask)
        encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        b_size = attention_mask.shape[0]
        batch_decoder_position_bias = []
        for i in range(b_size):
            if torch.compiler.is_exporting():
                # Under export the index must be a Python int with explicit
                # bounds facts so the trace can validate the dynamic select.
                cache_pos = cache_position[i][0].item()
                torch._check_is_size(cache_pos)
                torch._check(cache_pos >= 0)
                torch._check(cache_pos < self._dec_position_bias.shape[2])
            else:
                cache_pos = cache_position[i][0]
            # Select the bias row for this sequence's current position and
            # restore the query-length axis of size 1.
            batch_position_bias = torch.select(self._dec_position_bias, dim=2, index=cache_pos).unsqueeze(2)
            batch_decoder_position_bias.append(batch_position_bias)
        position_bias = torch.cat(batch_decoder_position_bias, dim=0)

        # Fold the bias into the additive mask so downstream attention only
        # needs a single additive term.
        attention_mask = position_bias + attention_mask

        return attention_mask, encoder_attention_mask
143
+
144
+
145
class T5Block(Seq2SeqDecoderLayer):
    """Adapter mapping a HF T5 decoder block onto the generic Seq2SeqDecoderLayer."""

    def __init__(self, decoder_layer, self_attn):
        super().__init__(decoder_layer, self_attn, cross_attn=None)
        self.__post_init__()

    def __post_init__(self):
        # HF T5 blocks store sublayers positionally:
        # layer[0] = self-attn, layer[1] = cross-attn, layer[2] = feed-forward.
        sublayers = self._original_mod.layer
        self.self_attn_layer_norm = sublayers[0].layer_norm
        self.encoder_attn_layer_norm = sublayers[1].layer_norm
        self.cross_attn = T5CrossAttention(sublayers[1].EncDecAttention)
        self.ff_layer = sublayers[2]

    # T5 is pre-LN: normalization runs before each attention; the post hooks
    # are identity.
    def pre_self_attn_layer_norm(self, hidden_states):
        return self.self_attn_layer_norm(hidden_states)

    def post_self_attn_layer_norm(self, hidden_states):
        return hidden_states

    def pre_cross_attn_layer_norm(self, hidden_states):
        return self.encoder_attn_layer_norm(hidden_states)

    def post_cross_attn_layer_norm(self, hidden_states):
        return hidden_states
167
+
168
+
169
class T5LayerSelfAttention(Seq2SeqSelfAttention):
    """T5 self-attention routed through the RBLN paged add-softmax decode kernel."""

    def __post_init__(self):
        # Map T5's q/k/v/o projection names onto the generic attention slots.
        self.q_proj = self._original_mod.q
        self.k_proj = self._original_mod.k
        self.v_proj = self._original_mod.v
        self.out_proj = self._original_mod.o
        self.num_heads = self._original_mod.n_heads
        self.head_dim = self._original_mod.key_value_proj_dim
        # Custom kernel that adds the (bias-folded) mask before softmax —
        # matches T5's additive relative position bias.
        self.attn_decode = torch.ops.rbln_custom_ops.paged_add_softmax_attn_decode

    def projection(self, hidden_states) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Raw Q/K/V projections; T5 applies no per-head scaling at this point."""
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        return query_states, key_states, value_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Tuple[torch.Tensor],
        attention_mask: torch.Tensor,
        cache_position: torch.Tensor,
        block_tables: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor]]:
        """Single-step paged self-attention for decoding.

        NOTE(review): the custom kernel presumably updates the KV cache in
        place at ``cache_position`` — confirm against the op definition in
        ops/attn.py.
        """
        bsz, tgt_len, _ = hidden_states.size()

        query_states, key_states, value_states = self.projection(hidden_states=hidden_states)
        query_states = self._shape(query_states, tgt_len, bsz)
        key_states = self._shape(key_states, -1, bsz)
        value_states = self._shape(value_states, -1, bsz)

        # Cache layout: last-but-one axis is the paged block size.
        block_size = past_key_value[0].shape[-2]
        attn_output = self.attn_decode(
            query_states,
            key_states,
            value_states,
            attention_mask.unsqueeze(
                2
            ),  # Unsqueeze group axis since CustomKernel expects it for group query attention
            past_key_value[0].view(bsz, self.num_heads, 1, -1, self.head_dim),
            past_key_value[1].view(bsz, self.num_heads, 1, -1, self.head_dim),
            cache_position,
            torch.tensor(1.0, dtype=torch.float32),  # scale: T5 does not scale attention scores
            block_tables,
            block_size,
        )

        # (bsz, heads, len, head_dim) -> (bsz, len, heads * head_dim)
        attn_output = attn_output.view(bsz, self.num_heads, -1, self.head_dim).transpose(1, 2)
        attn_output = attn_output.reshape(bsz, -1, self.num_heads * self.head_dim)

        attn_output = self.out_proj(attn_output)
        return attn_output
222
+
223
+
224
+ class T5CrossAttention(nn.Module):
225
+ def __init__(self, attn):
226
+ super().__init__()
227
+ self.attn = attn
228
+ self.q = attn.q
229
+ self.o = attn.o
230
+ self.n_heads = attn.n_heads
231
+ self.key_value_proj_dim = attn.key_value_proj_dim
232
+ self.inner_dim = attn.inner_dim
233
+
234
+ def forward(
235
+ self,
236
+ hidden_states: torch.Tensor = None,
237
+ past_key_value: torch.Tensor = None,
238
+ attention_mask: torch.Tensor = None,
239
+ key_value_states: torch.Tensor = None,
240
+ ):
241
+ batch_size = hidden_states.shape[0]
242
+
243
+ query_states = self.q(hidden_states)
244
+ query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
245
+
246
+ # reuse k,v, cross_attentions
247
+ key_states = past_key_value[0]
248
+ value_states = past_key_value[1]
249
+
250
+ # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
251
+ scores = torch.matmul(query_states, key_states.transpose(3, 2))
252
+ scores += attention_mask
253
+
254
+ # (batch_size, n_heads, seq_length, key_length)
255
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
256
+ attn_output = torch.matmul(attn_weights, value_states)
257
+
258
+ attn_output = attn_output.transpose(1, 2).contiguous()
259
+ attn_output = attn_output.view(batch_size, -1, self.inner_dim)
260
+ attn_output = self.o(attn_output)
261
+
262
+ outputs = (attn_output, past_key_value)
263
+
264
+ return outputs
@@ -0,0 +1,26 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Portions of this software are licensed under the Apache License,
16
+ # Version 2.0. See the NOTICE file distributed with this work for
17
+ # additional information regarding copyright ownership.
18
+
19
+ # All other portions of this software, including proprietary code,
20
+ # are the intellectual property of Rebellions Inc. and may not be
21
+ # copied, modified, or distributed without prior written permission
22
+ # from Rebellions Inc.
23
+
24
+ from ....ops import paged_add_softmax_attn_decode, rbln_cache_update
25
+ from .configuration_time_series_transformer import RBLNTimeSeriesTransformerForPredictionConfig
26
+ from .modeling_time_series_transformer import RBLNTimeSeriesTransformerForPrediction
@@ -0,0 +1,41 @@
1
+ from typing import Any, Optional
2
+
3
+ from ....configuration_utils import RBLNModelConfig
4
+
5
+
6
class RBLNTimeSeriesTransformerForPredictionConfig(RBLNModelConfig):
    """
    Configuration class for RBLNTimeSeriesTransformerForPrediction.

    This configuration class stores the configuration parameters specific to
    RBLN-optimized Time Series Transformer models for time series forecasting tasks.
    """

    def __init__(
        self,
        batch_size: Optional[int] = None,
        enc_max_seq_len: Optional[int] = None,
        dec_max_seq_len: Optional[int] = None,
        num_parallel_samples: Optional[int] = None,
        **kwargs: Any,
    ):
        """
        Args:
            batch_size (Optional[int]): The batch size for inference. Defaults to 1.
            enc_max_seq_len (Optional[int]): Maximum sequence length for the encoder.
            dec_max_seq_len (Optional[int]): Maximum sequence length for the decoder.
            num_parallel_samples (Optional[int]): Number of samples to generate in parallel during prediction.
            kwargs: Additional arguments passed to the parent RBLNModelConfig.

        Raises:
            ValueError: If batch_size is not a positive integer.
        """
        super().__init__(**kwargs)

        # Default only when batch_size was omitted. The previous
        # `batch_size or 1` idiom silently coerced an explicit 0 to 1,
        # so the documented ValueError never fired for batch_size=0.
        self.batch_size = 1 if batch_size is None else batch_size
        if not isinstance(self.batch_size, int) or self.batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")

        # Sequence-length and sampling knobs are optional here; None is
        # presumably resolved downstream from the model config — not
        # validated at this level.
        self.enc_max_seq_len = enc_max_seq_len
        self.dec_max_seq_len = dec_max_seq_len
        self.num_parallel_samples = num_parallel_samples