optimum-rbln 0.9.3.post1__py3-none-any.whl

This version of optimum-rbln has been flagged as potentially problematic.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py
@@ -0,0 +1,275 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, Dict, List, Tuple, Union
+
+ import rebel
+ import torch  # noqa: I001
+ from diffusers import AutoencoderKLTemporalDecoder
+ from diffusers.models.autoencoders.vae import DecoderOutput
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
+ from transformers import PretrainedConfig
+
+ from ....configuration_utils import RBLNCompileConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from ...configurations import RBLNAutoencoderKLTemporalDecoderConfig
+ from ...modeling_diffusers import RBLNDiffusionMixin
+ from .vae import (
+     DiagonalGaussianDistribution,
+     RBLNRuntimeVAEDecoder,
+     RBLNRuntimeVAEEncoder,
+     _VAEEncoder,
+     _VAETemporalDecoder,
+ )
+
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
+
+     from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
+
+ logger = get_logger(__name__)
+
+
+ class RBLNAutoencoderKLTemporalDecoder(RBLNModel):
+     auto_model_class = AutoencoderKLTemporalDecoder
+     hf_library_name = "diffusers"
+     _rbln_config_class = RBLNAutoencoderKLTemporalDecoderConfig
+
+     def __post_init__(self, **kwargs):
+         super().__post_init__(**kwargs)
+
+         if self.rbln_config.uses_encoder:
+             self.encoder = RBLNRuntimeVAEEncoder(runtime=self.model[0], main_input_name="x")
+         self.decoder = RBLNRuntimeVAEDecoder(runtime=self.model[-1], main_input_name="z")
+         self.image_size = self.rbln_config.image_size
+
+     @classmethod
+     def _wrap_model_if_needed(
+         cls, model: torch.nn.Module, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
+     ) -> torch.nn.Module:
+         decoder_model = _VAETemporalDecoder(model)
+         decoder_model.num_frames = rbln_config.decode_chunk_size
+         decoder_model.eval()
+
+         if rbln_config.uses_encoder:
+             encoder_model = _VAEEncoder(model)
+             encoder_model.eval()
+             return encoder_model, decoder_model
+         else:
+             return decoder_model
+
+     @classmethod
+     def get_compiled_model(
+         cls, model, rbln_config: RBLNAutoencoderKLTemporalDecoderConfig
+     ) -> Dict[str, rebel.RBLNCompiledModel]:
+         compiled_models = {}
+         if rbln_config.uses_encoder:
+             encoder_model, decoder_model = cls._wrap_model_if_needed(model, rbln_config)
+             enc_compiled_model = cls.compile(
+                 encoder_model,
+                 rbln_compile_config=rbln_config.compile_cfgs[0],
+                 create_runtimes=rbln_config.create_runtimes,
+                 device=rbln_config.device_map["encoder"],
+             )
+             compiled_models["encoder"] = enc_compiled_model
+         else:
+             decoder_model = cls._wrap_model_if_needed(model, rbln_config)
+         dec_compiled_model = cls.compile(
+             decoder_model,
+             rbln_compile_config=rbln_config.compile_cfgs[-1],
+             create_runtimes=rbln_config.create_runtimes,
+             device=rbln_config.device_map["decoder"],
+         )
+         compiled_models["decoder"] = dec_compiled_model
+
+         return compiled_models
+
+     @classmethod
+     def get_vae_sample_size(
+         cls,
+         pipe: "RBLNDiffusionMixin",
+         rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
+         return_vae_scale_factor: bool = False,
+     ) -> Tuple[int, int]:
+         sample_size = rbln_config.sample_size
+         if hasattr(pipe, "vae_scale_factor"):
+             vae_scale_factor = pipe.vae_scale_factor
+         else:
+             if hasattr(pipe.vae.config, "block_out_channels"):
+                 vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
+             else:
+                 # default scale factor of the VAE image processor
+                 vae_scale_factor = 8
+
+         if sample_size is None:
+             sample_size = pipe.unet.config.sample_size
+             if isinstance(sample_size, int):
+                 sample_size = (sample_size, sample_size)
+             sample_size = (sample_size[0] * vae_scale_factor, sample_size[1] * vae_scale_factor)
+
+         if return_vae_scale_factor:
+             return sample_size, vae_scale_factor
+         else:
+             return sample_size
+
+     @classmethod
+     def update_rbln_config_using_pipe(
+         cls, pipe: "RBLNDiffusionMixin", rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+     ) -> "RBLNDiffusionMixinConfig":
+         rbln_config.vae.sample_size, rbln_config.vae.vae_scale_factor = cls.get_vae_sample_size(
+             pipe, rbln_config.vae, return_vae_scale_factor=True
+         )
+
+         if rbln_config.vae.num_frames is None:
+             if hasattr(pipe.unet.config, "num_frames"):
+                 rbln_config.vae.num_frames = pipe.unet.config.num_frames
+             else:
+                 raise ValueError("`num_frames` must be specified in the unet's config.json.")
+
+         if rbln_config.vae.decode_chunk_size is None:
+             rbln_config.vae.decode_chunk_size = rbln_config.vae.num_frames
+
+         def chunk_frame(num_frames, decode_chunk_size):
+             # Pick the divisor of num_frames closest to the requested chunk size.
+             divisors = [i for i in range(1, num_frames) if num_frames % i == 0]
+             closest = min(divisors, key=lambda x: abs(x - decode_chunk_size))
+             if decode_chunk_size != closest:
+                 logger.warning(
+                     f"To ensure successful model compilation and prevent device OOM, `decode_chunk_size` is adjusted from {decode_chunk_size} to {closest}."
+                 )
+             return closest
+
+         decode_chunk_size = chunk_frame(rbln_config.vae.num_frames, rbln_config.vae.decode_chunk_size)
+         rbln_config.vae.decode_chunk_size = decode_chunk_size
+         return rbln_config
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model: "PreTrainedModel",
+         model_config: "PretrainedConfig",
+         rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
+     ) -> RBLNAutoencoderKLTemporalDecoderConfig:
+         if rbln_config.sample_size is None:
+             rbln_config.sample_size = model_config.sample_size
+
+         if rbln_config.vae_scale_factor is None:
+             if hasattr(model_config, "block_out_channels"):
+                 rbln_config.vae_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
+             else:
+                 # default scale factor of the VAE image processor
+                 rbln_config.vae_scale_factor = 8
+
+         compile_cfgs = []
+         if rbln_config.uses_encoder:
+             vae_enc_input_info = [
+                 (
+                     "x",
+                     [
+                         rbln_config.batch_size,
+                         model_config.in_channels,
+                         rbln_config.sample_size[0],
+                         rbln_config.sample_size[1],
+                     ],
+                     "float32",
+                 )
+             ]
+             compile_cfgs.append(RBLNCompileConfig(compiled_model_name="encoder", input_info=vae_enc_input_info))
+
+         decode_batch_size = rbln_config.batch_size * rbln_config.decode_chunk_size
+         vae_dec_input_info = [
+             (
+                 "z",
+                 [
+                     decode_batch_size,
+                     model_config.latent_channels,
+                     rbln_config.latent_sample_size[0],
+                     rbln_config.latent_sample_size[1],
+                 ],
+                 "float32",
+             )
+         ]
+         compile_cfgs.append(RBLNCompileConfig(compiled_model_name="decoder", input_info=vae_dec_input_info))
+
+         rbln_config.set_compile_cfgs(compile_cfgs)
+         return rbln_config
+
+     @classmethod
+     def _create_runtimes(
+         cls,
+         compiled_models: List[rebel.RBLNCompiledModel],
+         rbln_config: RBLNAutoencoderKLTemporalDecoderConfig,
+     ) -> List[rebel.Runtime]:
+         if len(compiled_models) == 1:
+             # decoder only
+             expected_models = ["decoder"]
+         else:
+             expected_models = ["encoder", "decoder"]
+
+         if any(model_name not in rbln_config.device_map for model_name in expected_models):
+             cls._raise_missing_compiled_file_error(expected_models)
+
+         device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
+         return [
+             rebel.Runtime(
+                 compiled_model,
+                 tensor_type="pt",
+                 device=device_val,
+                 activate_profiler=rbln_config.activate_profiler,
+                 timeout=rbln_config.timeout,
+             )
+             for compiled_model, device_val in zip(compiled_models, device_vals)
+         ]
+
+     def encode(
+         self, x: torch.FloatTensor, return_dict: bool = True
+     ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
+         """
+         Encode an input image into a latent representation.
+
+         Args:
+             x: The input image to encode.
+             return_dict:
+                 Whether to return the output as a dictionary. Defaults to True.
+
+         Returns:
+             An AutoencoderKLOutput if return_dict=True, otherwise a tuple containing the latent distribution.
+         """
+         posterior = self.encoder.encode(x)
+
+         if not return_dict:
+             return (posterior,)
+
+         return AutoencoderKLOutput(latent_dist=posterior)
+
+     def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> torch.FloatTensor:
+         """
+         Decode a latent representation into a video.
+
+         Args:
+             z: The latent representation to decode.
+             return_dict:
+                 Whether to return the output as a dictionary. Defaults to True.
+
+         Returns:
+             A DecoderOutput if return_dict=True, otherwise a tuple containing the decoded video.
+         """
+         decoded = self.decoder.decode(z)
+
+         if not return_dict:
+             return (decoded,)
+
+         return DecoderOutput(sample=decoded)
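Editor's note: a minimal usage sketch for the RBLNAutoencoderKLTemporalDecoder class added above. The checkpoint id, the `export=True` compile-on-load flag, and the tensor shapes are illustrative assumptions, not taken from this release.

# Hypothetical sketch: compile the temporal-decoder VAE and decode latents on an RBLN NPU.
import torch
from optimum.rbln import RBLNAutoencoderKLTemporalDecoder  # assumed top-level re-export

vae = RBLNAutoencoderKLTemporalDecoder.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",  # illustrative checkpoint
    subfolder="vae",
    export=True,  # assumed flag: compile from HF weights instead of loading a precompiled artifact
)

# Per _update_rbln_config above, the decoder graph is compiled with a static input of shape
# (batch_size * decode_chunk_size, latent_channels, latent_h, latent_w).
z = torch.randn(7, 4, 72, 128)  # illustrative: decode_chunk_size=7, 576x1024 frames, scale factor 8
video = vae.decode(z).sample  # DecoderOutput.sample, as returned by decode() above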
optimum/rbln/diffusers/models/autoencoders/vae.py
@@ -0,0 +1,178 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, List, Union
+
+ import torch
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution, IdentityDistribution
+
+ from ....utils.runtime_utils import RBLNPytorchRuntime
+
+
+ if TYPE_CHECKING:
+     from diffusers import AutoencoderKL, AutoencoderKLCosmos, AutoencoderKLTemporalDecoder, VQModel
+
+
+ class RBLNRuntimeVAEEncoder(RBLNPytorchRuntime):
+     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+         moments = self.forward(x.contiguous())
+         posterior = DiagonalGaussianDistribution(moments)
+         return posterior
+
+
+ class RBLNRuntimeVAEDecoder(RBLNPytorchRuntime):
+     def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+         return self.forward(z)
+
+
+ class RBLNRuntimeCosmosVAEEncoder(RBLNPytorchRuntime):
+     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+         if self.use_slicing and x.shape[0] > 1:
+             encoded_slices = [self.forward(x_slice) for x_slice in x.split(1)]
+             h = torch.cat(encoded_slices)
+         else:
+             h = self.forward(x)
+         posterior = IdentityDistribution(h)
+         return posterior
+
+
+ class RBLNRuntimeCosmosVAEDecoder(RBLNPytorchRuntime):
+     def decode(self, z: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+         if self.use_slicing and z.shape[0] > 1:
+             decoded_slices = [self.forward(z_slice) for z_slice in z.split(1)]
+             decoded = torch.cat(decoded_slices)
+         else:
+             decoded = self.forward(z)
+         return decoded
+
+
+ class _VAEDecoder(torch.nn.Module):
+     def __init__(self, vae: "AutoencoderKL"):
+         super().__init__()
+         self.vae = vae
+
+     def forward(self, z):
+         vae_out = self.vae.decode(z, return_dict=False)
+         return vae_out
+
+
+ class _VAETemporalDecoder(torch.nn.Module):
+     def __init__(self, vae: "AutoencoderKLTemporalDecoder"):
+         super().__init__()
+         self.vae = vae
+         self.num_frames = None
+
+     def forward(self, z):
+         vae_out = self.vae.decode(z, num_frames=self.num_frames, return_dict=False)
+         return vae_out
+
+
+ class _VAEEncoder(torch.nn.Module):
+     def __init__(self, vae: Union["AutoencoderKL", "AutoencoderKLTemporalDecoder"]):
+         super().__init__()
+         self.vae = vae
+
+     def encode(self, x: torch.FloatTensor, return_dict: bool = True):
+         if hasattr(self, "use_tiling") and hasattr(self, "use_slicing"):
+             if self.use_tiling and (
+                 x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size
+             ):
+                 return self.tiled_encode(x, return_dict=return_dict)
+
+             if self.use_slicing and x.shape[0] > 1:
+                 encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
+                 h = torch.cat(encoded_slices)
+             else:
+                 h = self.encoder(x)
+             if self.quant_conv is not None:
+                 h = self.quant_conv(h)
+
+         else:
+             h = self.encoder(x)
+             if self.quant_conv is not None:
+                 h = self.quant_conv(h)
+         return h
+
+     def forward(self, x):
+         # Called unbound, so `self` inside encode() is the wrapped diffusers VAE,
+         # which owns `encoder`, `quant_conv`, and the tiling/slicing flags.
+         vae_out = _VAEEncoder.encode(self.vae, x, return_dict=False)
+         return vae_out
+
+
+ class _VAECosmosEncoder(torch.nn.Module):
+     def __init__(self, vae: "AutoencoderKLCosmos"):
+         super().__init__()
+         self.vae = vae
+
+     def forward(self, x):
+         vae_out = self.vae._encode(x)
+         return vae_out
+
+
+ class _VAECosmosDecoder(torch.nn.Module):
+     def __init__(self, vae: "AutoencoderKLCosmos"):
+         super().__init__()
+         self.vae = vae
+
+     def forward(self, z):
+         vae_out = self.vae._decode(z, return_dict=False)
+         return vae_out
+
+
+ class RBLNRuntimeVQEncoder(RBLNPytorchRuntime):
+     def encode(self, x: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
+         h = self.forward(x.contiguous())
+         return h
+
+
+ class RBLNRuntimeVQDecoder(RBLNPytorchRuntime):
+     def decode(self, h: torch.Tensor, force_not_quantize: bool = False, shape=None, **kwargs) -> List[torch.Tensor]:
+         if not (force_not_quantize and not self.lookup_from_codebook):
+             raise ValueError(
+                 "Currently, `RBLNVQModel.decode` succeeds only if `force_not_quantize` is True and `config.lookup_from_codebook` is False."
+             )
+         commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
+         dec = self.forward(h.contiguous())
+         return dec, commit_loss
+
+
+ class _VQEncoder(torch.nn.Module):
+     def __init__(self, vq_model: "VQModel"):
+         super().__init__()
+         self.vq_model = vq_model
+
+     def encode(self, x: torch.Tensor, return_dict: bool = True):
+         h = self.vq_model.encoder(x)
+         h = self.vq_model.quant_conv(h)
+         return h
+
+     def forward(self, x: torch.Tensor):
+         vq_out = self.encode(x)
+         return vq_out
+
+
+ class _VQDecoder(torch.nn.Module):
+     def __init__(self, vq_model: "VQModel"):
+         super().__init__()
+         self.vq_model = vq_model
+
+     def decode(self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None):
+         quant = h
+         quant2 = self.vq_model.post_quant_conv(quant)
+         # The decoder receives the pre-conv latents only when using spatial norm.
+         quant = quant if self.vq_model.config.norm_type == "spatial" else None
+         dec = self.vq_model.decoder(quant2, quant)
+         return dec
+
+     def forward(self, h: torch.Tensor):
+         vq_out = self.decode(h)
+         return vq_out
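Editor's note: the wrapper modules above exist so each half of a diffusers VAE can be traced as a tensor-in/tensor-out graph with every dynamic argument frozen before compilation. A small sketch of that pattern in isolation, assuming an illustrative checkpoint id and latent shape:

# Sketch: bake `num_frames` into the decode graph so tracing sees one static signature.
import torch
from diffusers import AutoencoderKLTemporalDecoder

class TemporalDecoderWrapper(torch.nn.Module):  # mirrors _VAETemporalDecoder above
    def __init__(self, vae, num_frames):
        super().__init__()
        self.vae = vae
        self.num_frames = num_frames  # fixed attribute, not a forward() argument

    def forward(self, z):
        # return_dict=False yields a plain tuple, which torch.jit.trace can handle
        return self.vae.decode(z, num_frames=self.num_frames, return_dict=False)

vae = AutoencoderKLTemporalDecoder.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid", subfolder="vae"  # illustrative checkpoint
)
wrapper = TemporalDecoderWrapper(vae, num_frames=7).eval()
z = torch.randn(7, 4, 72, 128)  # illustrative latent batch
with torch.no_grad():
    traced = torch.jit.trace(wrapper, z)  # a static graph a compiler backend can consume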
optimum/rbln/diffusers/models/autoencoders/vq_model.py
@@ -0,0 +1,211 @@
+ # Copyright 2024 Rebellions Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import TYPE_CHECKING, Any, List, Union
+
+ import rebel
+ import torch
+ from diffusers import VQModel
+ from diffusers.models.autoencoders.vae import DecoderOutput
+ from diffusers.models.autoencoders.vq_model import VQEncoderOutput
+
+ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
+ from ....modeling import RBLNModel
+ from ....utils.logging import get_logger
+ from ...configurations.models.configuration_vq_model import RBLNVQModelConfig
+ from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
+ from .vae import RBLNRuntimeVQDecoder, RBLNRuntimeVQEncoder, _VQDecoder, _VQEncoder
+
+
+ if TYPE_CHECKING:
+     from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
+
+ logger = get_logger(__name__)
+
+
+ class RBLNVQModel(RBLNModel):
+     """
+     RBLN implementation of VQModel for diffusion models.
+
+     This model is used to accelerate VQModel models from the diffusers library on RBLN NPUs.
+     It can be configured to include both the encoder and the decoder, or only the decoder for
+     latent-to-image conversion.
+
+     This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
+     the library implements for all its models.
+     """
+
+     auto_model_class = VQModel
+     config_name = "config.json"
+     hf_library_name = "diffusers"
+
+     def __post_init__(self, **kwargs):
+         super().__post_init__(**kwargs)
+
+         if self.rbln_config.uses_encoder:
+             self.encoder = RBLNRuntimeVQEncoder(runtime=self.model[0], main_input_name="x")
+         else:
+             self.encoder = None
+
+         self.decoder = RBLNRuntimeVQDecoder(runtime=self.model[-1], main_input_name="z")
+         self.decoder.lookup_from_codebook = self.config.lookup_from_codebook
+         self.image_size = self.rbln_config.image_size
+
+     @classmethod
+     def get_compiled_model(cls, model, rbln_config: RBLNModelConfig):
+         if rbln_config.uses_encoder:
+             expected_models = ["encoder", "decoder"]
+         else:
+             expected_models = ["decoder"]
+
+         compiled_models = {}
+         for i, model_name in enumerate(expected_models):
+             if model_name == "encoder":
+                 wrapped_model = _VQEncoder(model)
+             else:
+                 wrapped_model = _VQDecoder(model)
+
+             wrapped_model.eval()
+
+             compiled_models[model_name] = cls.compile(
+                 wrapped_model,
+                 rbln_compile_config=rbln_config.compile_cfgs[i],
+                 create_runtimes=rbln_config.create_runtimes,
+                 device=rbln_config.device_map[model_name],
+             )
+
+         return compiled_models
+
+     @classmethod
+     def update_rbln_config_using_pipe(
+         cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
+     ) -> "RBLNDiffusionMixinConfig":
+         return rbln_config
+
+     @classmethod
+     def _update_rbln_config(
+         cls,
+         preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
+         model: "PreTrainedModel",
+         model_config: "PretrainedConfig",
+         rbln_config: RBLNVQModelConfig,
+     ) -> RBLNVQModelConfig:
+         if hasattr(model_config, "block_out_channels"):
+             rbln_config.vqmodel_scale_factor = 2 ** (len(model_config.block_out_channels) - 1)
+         else:
+             # default scale factor of the image processor
+             rbln_config.vqmodel_scale_factor = 8
+
+         compile_cfgs = []
+         if rbln_config.uses_encoder:
+             enc_input_info = [
+                 (
+                     "x",
+                     [
+                         rbln_config.batch_size,
+                         model_config.in_channels,
+                         rbln_config.sample_size[0],
+                         rbln_config.sample_size[1],
+                     ],
+                     "float32",
+                 )
+             ]
+             enc_rbln_compile_config = RBLNCompileConfig(compiled_model_name="encoder", input_info=enc_input_info)
+             compile_cfgs.append(enc_rbln_compile_config)
+
+         dec_input_info = [
+             (
+                 "h",
+                 [
+                     rbln_config.batch_size,
+                     model_config.latent_channels,
+                     rbln_config.latent_sample_size[0],
+                     rbln_config.latent_sample_size[1],
+                 ],
+                 "float32",
+             )
+         ]
+         dec_rbln_compile_config = RBLNCompileConfig(compiled_model_name="decoder", input_info=dec_input_info)
+         compile_cfgs.append(dec_rbln_compile_config)
+
+         rbln_config.set_compile_cfgs(compile_cfgs)
+         return rbln_config
+
+     @classmethod
+     def _create_runtimes(
+         cls,
+         compiled_models: List[rebel.RBLNCompiledModel],
+         rbln_config: RBLNVQModelConfig,
+     ) -> List[rebel.Runtime]:
+         if len(compiled_models) == 1:
+             # decoder
+             expected_models = ["decoder"]
+         else:
+             # encoder, decoder
+             expected_models = ["encoder", "decoder"]
+
+         if any(model_name not in rbln_config.device_map for model_name in expected_models):
+             cls._raise_missing_compiled_file_error(expected_models)
+
+         device_vals = [rbln_config.device_map[model_name] for model_name in expected_models]
+         return [
+             rebel.Runtime(
+                 compiled_model,
+                 tensor_type="pt",
+                 device=device_val,
+                 activate_profiler=rbln_config.activate_profiler,
+                 timeout=rbln_config.timeout,
+             )
+             for compiled_model, device_val in zip(compiled_models, device_vals)
+         ]
+
+     def encode(
+         self, x: torch.FloatTensor, return_dict: bool = True, **kwargs: Any
+     ) -> Union[torch.FloatTensor, VQEncoderOutput]:
+         """
+         Encode an input image into a quantized latent representation.
+
+         Args:
+             x: The input image to encode.
+             return_dict:
+                 Whether to return the output as a dictionary. Defaults to True.
+             kwargs: Additional arguments to pass to the encoder/quantizer.
+
+         Returns:
+             A VQEncoderOutput if return_dict=True, otherwise a tuple containing the latents.
+         """
+         posterior = self.encoder.encode(x)
+         if not return_dict:
+             return (posterior,)
+         return VQEncoderOutput(latents=posterior)
+
+     def decode(
+         self, h: torch.FloatTensor, return_dict: bool = True, **kwargs: Any
+     ) -> Union[torch.FloatTensor, DecoderOutput]:
+         """
+         Decode a quantized latent representation back into an image.
+
+         Args:
+             h: The quantized latent representation to decode.
+             return_dict:
+                 Whether to return the output as a dictionary. Defaults to True.
+             kwargs: Additional arguments to pass to the decoder.
+
+         Returns:
+             A DecoderOutput if return_dict=True, otherwise a tuple of (decoded image, commit_loss).
+         """
+         dec, commit_loss = self.decoder.decode(h, **kwargs)
+         if not return_dict:
+             return (dec, commit_loss)
+         return DecoderOutput(sample=dec, commit_loss=commit_loss)
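Editor's note: a hedged sketch of driving the RBLNVQModel decode path. Per RBLNRuntimeVQDecoder in vae.py above, decoding only succeeds with `force_not_quantize=True` and `config.lookup_from_codebook` False; the checkpoint id, the `export=True` flag, and the shapes are illustrative assumptions.

# Hypothetical sketch: decode latents with the compiled VQModel (e.g., Kandinsky's MoVQ).
import torch
from optimum.rbln import RBLNVQModel  # assumed top-level re-export

movq = RBLNVQModel.from_pretrained(
    "kandinsky-community/kandinsky-2-2-decoder",  # illustrative MoVQ checkpoint
    subfolder="movq",
    export=True,  # assumed compile-on-load flag
)

h = torch.randn(1, 4, 64, 64)  # (batch, latent_channels, latent_h, latent_w), illustrative
out = movq.decode(h, force_not_quantize=True)  # required by the RBLN runtime, see vae.py
image, commit_loss = out.sample, out.commit_loss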