optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic; see the advisory details on the package registry page.

Files changed (264):
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,255 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
16
+
17
+ import rebel
18
+ import torch
19
+ from diffusers import AutoencoderKL
20
+ from diffusers.models.autoencoders.vae import DecoderOutput
21
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
22
+ from transformers import PretrainedConfig
23
+
24
+ from ....configuration_utils import RBLNCompileConfig
25
+ from ....modeling import RBLNModel
26
+ from ....utils.logging import get_logger
27
+ from ...configurations import RBLNAutoencoderKLConfig
28
+ from .vae import RBLNRuntimeVAEDecoder, RBLNRuntimeVAEEncoder, _VAEDecoder, _VAEEncoder
29
+
30
+
31
+ if TYPE_CHECKING:
32
+ import torch
33
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
34
+
35
+ from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
36
+
37
+ logger = get_logger(__name__)
38
+
39
+
40
class RBLNAutoencoderKL(RBLNModel):
    """
    RBLN implementation of AutoencoderKL (VAE) for diffusion models.

    This model is used to accelerate AutoencoderKL (VAE) models from the diffusers library on RBLN
    NPUs. It can be configured to include both encoder and decoder, or just the decoder part for
    latent-to-image conversion.

    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic
    methods the library implements for all its models.
    """

    auto_model_class = AutoencoderKL
    hf_library_name = "diffusers"
    _rbln_config_class = RBLNAutoencoderKLConfig

    def __post_init__(self, **kwargs):
        """Bind the compiled runtimes to encoder/decoder wrappers after base initialization.

        When the model was compiled without an encoder (`uses_encoder=False`),
        `self.encoder` is left as `None` and only `decode()` is usable.
        """
        super().__post_init__(**kwargs)

        if self.rbln_config.uses_encoder:
            # Encoder, when present, is always the first compiled runtime.
            self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
        else:
            self.encoder = None

        # Decoder is always the last runtime (index -1 works for both 1- and 2-runtime layouts).
        self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[-1], main_input_name="z")
        self.image_size = self.rbln_config.image_size

    @classmethod
    def get_compiled_model(cls, model, rbln_config: RBLNAutoencoderKLConfig) -> Dict[str, rebel.RBLNCompiledModel]:
        """Wrap and compile the VAE submodules selected by `rbln_config`.

        Args:
            model: The source diffusers `AutoencoderKL` model.
            rbln_config: Compilation configuration; `uses_encoder` decides whether the
                encoder is compiled in addition to the decoder.

        Returns:
            Mapping of submodule name ("encoder"/"decoder") to its compiled model.
        """
        if rbln_config.uses_encoder:
            expected_models = ["encoder", "decoder"]
        else:
            expected_models = ["decoder"]

        compiled_models = {}
        # The order of expected_models must match rbln_config.compile_cfgs (index i).
        for i, model_name in enumerate(expected_models):
            if model_name == "encoder":
                wrapped_model = _VAEEncoder(model)
            else:
                wrapped_model = _VAEDecoder(model)

            wrapped_model.eval()

            compiled_models[model_name] = cls.compile(
                wrapped_model,
                rbln_compile_config=rbln_config.compile_cfgs[i],
                create_runtimes=rbln_config.create_runtimes,
                device=rbln_config.device_map[model_name],
            )

        return compiled_models

    @classmethod
    def get_vae_sample_size(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: RBLNAutoencoderKLConfig, return_vae_scale_factor: bool = False
    ) -> Tuple[int, int]:
        """Resolve the VAE sample size (H, W) from the pipeline's noise module.

        Args:
            pipe: The diffusion pipeline holding either a `unet` or a `transformer`.
            rbln_config: VAE configuration; an explicit `sample_size` here takes precedence.
            return_vae_scale_factor: If True, also return the derived VAE scale factor.

        Returns:
            `(height, width)` tuple, or `((height, width), vae_scale_factor)` when
            `return_vae_scale_factor` is True.

        Raises:
            AttributeError: If the pipeline exposes neither a `unet` nor a `transformer`.
        """
        sample_size = rbln_config.sample_size
        noise_module = getattr(pipe, "unet", None) or getattr(pipe, "transformer", None)
        vae_scale_factor = (
            pipe.vae_scale_factor
            if hasattr(pipe, "vae_scale_factor")
            else 2 ** (len(pipe.vae.config.block_out_channels) - 1)
        )

        if noise_module is None:
            raise AttributeError(
                "Cannot find noise processing or predicting module attributes. ex. U-Net, Transformer, ..."
            )

        if sample_size is None:
            # The noise module operates in latent space; scale up to pixel space.
            sample_size = noise_module.config.sample_size
            if isinstance(sample_size, int):
                sample_size = (sample_size, sample_size)
            sample_size = (sample_size[0] * vae_scale_factor, sample_size[1] * vae_scale_factor)

        if return_vae_scale_factor:
            return sample_size, vae_scale_factor
        else:
            return sample_size

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """Fill in the VAE sample size and scale factor on a pipeline-level config."""
        rbln_config.vae.sample_size, rbln_config.vae.vae_scale_factor = cls.get_vae_sample_size(
            pipe, rbln_config.vae, return_vae_scale_factor=True
        )
        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNAutoencoderKLConfig,
    ) -> RBLNAutoencoderKLConfig:
        """Complete the RBLN config from the HF model config and build compile configs.

        Defaults any unset fields (`sample_size`, `in_channels`, `latent_channels`,
        `vae_scale_factor`) from `model_config`, then registers static-shape compile
        configs for the encoder (optional) and decoder.
        """
        if rbln_config.sample_size is None:
            rbln_config.sample_size = model_config.sample_size

        if isinstance(rbln_config.sample_size, int):
            rbln_config.sample_size = (rbln_config.sample_size, rbln_config.sample_size)

        if rbln_config.in_channels is None:
            rbln_config.in_channels = model_config.in_channels

        if rbln_config.latent_channels is None:
            rbln_config.latent_channels = model_config.latent_channels

        if rbln_config.vae_scale_factor is None:
            if hasattr(model_config, "block_out_channels"):
                rbln_config.vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
            else:
                # vae image processor default value 8 (int)
                rbln_config.vae_scale_factor = 8

        compile_cfgs = []
        if rbln_config.uses_encoder:
            vae_enc_input_info = [
                (
                    "x",
                    [
                        rbln_config.batch_size,
                        rbln_config.in_channels,
                        rbln_config.sample_size[0],
                        rbln_config.sample_size[1],
                    ],
                    "float32",
                )
            ]
            compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))

        vae_dec_input_info = [
            (
                "z",
                [
                    rbln_config.batch_size,
                    rbln_config.latent_channels,
                    rbln_config.latent_sample_size[0],
                    rbln_config.latent_sample_size[1],
                ],
                "float32",
            )
        ]
        compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))

        rbln_config.set_compile_cfgs(compile_cfgs)
        return rbln_config

    @classmethod
    def _create_runtimes(
        cls,
        compiled_models: List[rebel.RBLNCompiledModel],
        rbln_config: RBLNAutoencoderKLConfig,
    ) -> List[rebel.Runtime]:
        """Instantiate NPU runtimes for the compiled encoder/decoder.

        The number of compiled models determines the layout: one model means
        decoder-only, two means encoder followed by decoder.

        Raises:
            An error (via `_raise_missing_compiled_file_error`) if the device map
            lacks an entry for an expected submodule.
        """
        if len(compiled_models) == 1:
            # decoder
            expected_models = ["decoder"]
        else:
            # encoder, decoder
            expected_models = ["encoder", "decoder"]

        if any(model_name not in rbln_config.device_map for model_name in expected_models):
            cls._raise_missing_compiled_file_error(expected_models)

        device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
        return [
            rebel.Runtime(
                compiled_model,
                tensor_type="pt",
                device=device_val,
                activate_profiler=rbln_config.activate_profiler,
                timeout=rbln_config.timeout,
            )
            for compiled_model, device_val in zip(compiled_models, device_vals)
        ]

    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Dict[str, Any]
    ) -> Union[torch.FloatTensor, AutoencoderKLOutput]:
        """
        Encode an input image into a latent representation.

        Args:
            x: The input image to encode.
            return_dict:
                Whether to return output as a dictionary. Defaults to True.
            kwargs: Additional arguments to pass to the encoder.

        Returns:
            The latent representation or AutoencoderKLOutput if return_dict=True

        Raises:
            RuntimeError: If this model was compiled without an encoder.
        """
        # Guard against decoder-only models: __post_init__ sets self.encoder to None
        # when uses_encoder is False, which would otherwise surface as an opaque
        # AttributeError on NoneType here.
        if self.encoder is None:
            raise RuntimeError(
                "This RBLNAutoencoderKL was compiled without an encoder (uses_encoder=False), "
                "so encode() is unavailable. Re-export the model with uses_encoder=True to enable encoding."
            )
        posterior = self.encoder.encode(x)
        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(
        self, z: torch.FloatTensor, return_dict: bool = True, **kwargs: Dict[str, Any]
    ) -> Union[torch.FloatTensor, DecoderOutput]:
        """
        Decode a latent representation into an image.

        Args:
            z: The latent representation to decode.
            return_dict:
                Whether to return output as a dictionary. Defaults to True.
            kwargs: Additional arguments to pass to the decoder.

        Returns:
            The decoded image or DecoderOutput if return_dict=True
        """
        dec = self.decoder.decode(z)
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
@@ -0,0 +1,245 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING, Any, Dict, List, Union
16
+
17
+ import rebel
18
+ import torch
19
+ from diffusers.models.autoencoders.autoencoder_kl_cosmos import AutoencoderKLCosmos, CosmosCausalConv3d
20
+ from diffusers.models.autoencoders.vae import DecoderOutput
21
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
22
+ from torch.nn import functional as F
23
+ from transformers import PretrainedConfig
24
+
25
+ from ....configuration_utils import RBLNCompileConfig
26
+ from ....modeling import RBLNModel
27
+ from ....utils.logging import get_logger
28
+ from ...configurations import RBLNAutoencoderKLCosmosConfig
29
+ from .vae import RBLNRuntimeCosmosVAEDecoder, RBLNRuntimeCosmosVAEEncoder, _VAECosmosDecoder, _VAECosmosEncoder
30
+
31
+
32
# Imports below are needed only for static type annotations; guarding them with
# TYPE_CHECKING avoids importing heavy modules (and an import cycle with
# modeling_diffusers) at runtime.
if TYPE_CHECKING:
    import torch
    from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel

    from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig

# Module-level logger, keyed by this module's dotted name.
logger = get_logger(__name__)
39
+
40
+
41
class RBLNAutoencoderKLCosmos(RBLNModel):
    """
    RBLN implementation of AutoencoderKLCosmos for diffusion models.

    This model is used to accelerate AutoencoderKLCosmos models from diffusers library on RBLN NPUs.
    It can be configured to include both encoder and decoder, or just the decoder part for latent-to-video
    conversion.

    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
    the library implements for all its models.
    """

    # HF auto-class this wrapper accelerates, and the library it is loaded from.
    auto_model_class = AutoencoderKLCosmos
    hf_library_name = "diffusers"
    # Config class used to parse/validate rbln_config for this model.
    _rbln_config_class = RBLNAutoencoderKLCosmosConfig

    def __post_init__(self, **kwargs):
        """Bind compiled runtimes to encoder/decoder wrappers after base init.

        When `uses_encoder` is set, `self.model` holds [encoder, decoder]
        runtimes; otherwise it holds only the decoder. Indexing with 0 / -1
        works for both layouts.
        """
        super().__post_init__(**kwargs)

        if self.rbln_config.uses_encoder:
            self.encoder = RBLNRuntimeCosmosVAEEncoder(
                runtime=self.model[0], main_input_name="x", use_slicing=self.rbln_config.use_slicing
            )

        # Decoder is always present; -1 selects it whether or not an encoder exists.
        self.decoder = RBLNRuntimeCosmosVAEDecoder(
            runtime=self.model[-1], main_input_name="z", use_slicing=self.rbln_config.use_slicing
        )
        self.image_size = self.rbln_config.image_size

    @classmethod
    def _wrap_model_if_needed(
        cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLCosmosConfig
    ) -> torch.nn.Module:
        """Wrap the HF VAE into compile-ready eval-mode submodules.

        Returns:
            `(encoder, decoder)` tuple when `rbln_config.uses_encoder` is True,
            otherwise the decoder wrapper alone. Callers must match this shape.
        """
        decoder_model = _VAECosmosDecoder(model)
        decoder_model.eval()

        if rbln_config.uses_encoder:
            encoder_model = _VAECosmosEncoder(model)
            encoder_model.eval()
            return encoder_model, decoder_model
        else:
            return decoder_model

    @classmethod
    def get_compiled_model(
        cls, model, rbln_config: RBLNAutoencoderKLCosmosConfig
    ) -> Dict[str, rebel.RBLNCompiledModel]:
        """Compile the VAE submodules for the RBLN NPU.

        Temporarily replaces `CosmosCausalConv3d.forward` with a version whose
        padding is expressed via static tensor ops (repeat/cat + F.pad), which
        is presumably required for tracing/compilation; the original forward is
        restored in the `finally` block.

        Returns:
            Dict mapping "encoder"/"decoder" to their compiled models
            ("encoder" only present when `rbln_config.uses_encoder`).
        """

        def replaced_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Causal temporal padding: replicate the first frame `temporal_pad`
            # times in front, then apply spatial padding only (no temporal pad
            # in F.pad, hence the trailing 0, 0).
            if self.temporal_pad != 0:
                hidden_states_prev = hidden_states[:, :, :1, ...].repeat(1, 1, self.temporal_pad, 1, 1)
                hidden_states = torch.cat([hidden_states_prev, hidden_states], dim=2)
            hidden_states = F.pad(hidden_states, (*self.spatial_pad, 0, 0), mode=self.pad_mode, value=0.0)
            # Skip CosmosCausalConv3d.forward and call the plain conv forward.
            return super(CosmosCausalConv3d, self).forward(hidden_states)

        try:
            # NOTE(review): the save happens as the first statement inside the
            # try, so `original_forward` is always bound before the finally
            # restore can run.
            original_forward = CosmosCausalConv3d.forward
            CosmosCausalConv3d.forward = replaced_forward

            compiled_models = {}
            if rbln_config.uses_encoder:
                # compile_cfgs[0] is the encoder config (see _update_rbln_config,
                # which appends encoder first).
                encoder_model, decoder_model = cls._wrap_model_if_needed(model, rbln_config)
                enc_compiled_model = cls.compile(
                    encoder_model,
                    rbln_compile_config=rbln_config.compile_cfgs[0],
                    create_runtimes=rbln_config.create_runtimes,
                    device=rbln_config.device_map["encoder"],
                )
                compiled_models["encoder"] = enc_compiled_model
            else:
                decoder_model = cls._wrap_model_if_needed(model, rbln_config)
            # Decoder is compiled in both branches; compile_cfgs[-1] is always
            # the decoder config.
            dec_compiled_model = cls.compile(
                decoder_model,
                rbln_compile_config=rbln_config.compile_cfgs[-1],
                create_runtimes=rbln_config.create_runtimes,
                device=rbln_config.device_map["decoder"],
            )
            compiled_models["decoder"] = dec_compiled_model

        finally:
            # Always undo the class-level monkey-patch, even if compile fails.
            CosmosCausalConv3d.forward = original_forward

        return compiled_models

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """Propagate pipeline-derived shapes into the VAE sub-config.

        Latent channel count comes from the transformer's output channels; the
        temporal/spatial scale factors come from the pipeline itself.
        """
        rbln_config.vae.num_channels_latents = pipe.transformer.config.out_channels
        rbln_config.vae.vae_scale_factor_temporal = pipe.vae_scale_factor_temporal
        rbln_config.vae.vae_scale_factor_spatial = pipe.vae_scale_factor_spatial
        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNAutoencoderKLCosmosConfig,
    ) -> RBLNAutoencoderKLCosmosConfig:
        """Build the static compile configs (input names/shapes/dtypes).

        Encoder input is (batch, in_channels, num_frames, height, width);
        decoder input is the latent with frame/height/width reduced by the
        temporal/spatial VAE scale factors. When slicing is used, compilation
        is done at batch size 1 (slices are fed one at a time at runtime).
        """
        batch_size = 1 if rbln_config.use_slicing else rbln_config.batch_size
        compile_cfgs = []
        if rbln_config.uses_encoder:
            vae_enc_input_info = [
                (
                    "x",
                    [
                        batch_size,
                        model_config.in_channels,
                        rbln_config.num_frames,
                        rbln_config.height,
                        rbln_config.width,
                    ],
                    "float32",
                ),
            ]
            compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))

        # Causal temporal downsampling keeps the first frame: (T-1)/s + 1.
        num_latent_frames = (rbln_config.num_frames - 1) // rbln_config.vae_scale_factor_temporal + 1
        latent_height = rbln_config.height // rbln_config.vae_scale_factor_spatial
        latent_width = rbln_config.width // rbln_config.vae_scale_factor_spatial

        vae_dec_input_info = [
            (
                "z",
                [
                    batch_size,
                    rbln_config.num_channels_latents,
                    num_latent_frames,
                    latent_height,
                    latent_width,
                ],
                "float32",
            ),
        ]
        compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))

        rbln_config.set_compile_cfgs(compile_cfgs)
        return rbln_config

    @classmethod
    def _create_runtimes(
        cls,
        compiled_models: List[rebel.RBLNCompiledModel],
        rbln_config: RBLNAutoencoderKLCosmosConfig,
    ) -> List[rebel.Runtime]:
        """Instantiate one NPU runtime per compiled model.

        The number of compiled models decides the expected layout: one model
        means decoder-only, two means [encoder, decoder]. Each expected model
        must have a device assigned in `rbln_config.device_map`; otherwise a
        missing-compiled-file error is raised.
        """
        if len(compiled_models) == 1:
            # decoder
            expected_models = ["decoder"]
        else:
            expected_models = ["encoder", "decoder"]

        if any(model_name not in rbln_config.device_map for model_name in expected_models):
            cls._raise_missing_compiled_file_error(expected_models)

        # Devices are looked up in expected_models order, which matches the
        # order of `compiled_models` (encoder first when present).
        device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
        return [
            rebel.Runtime(
                compiled_model,
                tensor_type="pt",
                device=device_val,
                activate_profiler=rbln_config.activate_profiler,
                timeout=rbln_config.timeout,
            )
            for compiled_model, device_val in zip(compiled_models, device_vals)
        ]

    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Dict[str, Any]
    ) -> Union[torch.FloatTensor, AutoencoderKLOutput]:
        """
        Encode an input video into a latent representation.

        Args:
            x: The input video to encode.
            return_dict:
                Whether to return output as a dictionary. Defaults to True.
            kwargs: Additional arguments, accepted for interface compatibility
                (not forwarded to the encoder runtime).

        Returns:
            The latent representation or AutoencoderKLOutput if return_dict=True
        """
        # NOTE(review): `self.encoder` only exists when the model was built
        # with `uses_encoder`; calling encode on a decoder-only model raises
        # AttributeError — confirm whether a clearer error is wanted.
        posterior = self.encoder.encode(x)
        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[torch.FloatTensor, DecoderOutput]:
        """
        Decode a latent representation into a video.

        Args:
            z: The latent representation to decode.
            return_dict:
                Whether to return output as a dictionary. Defaults to True.

        Returns:
            The decoded video or DecoderOutput if return_dict=True
        """
        decoded = self.decoder.decode(z)

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)