optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. Click here for more details.

Files changed (264) hide show
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,408 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel, UNet2DConditionOutput
20
+ from transformers import PretrainedConfig
21
+
22
+ from ....configuration_utils import RBLNCompileConfig
23
+ from ....modeling import RBLNModel
24
+ from ....utils.logging import get_logger
25
+ from ...configurations import RBLNUNet2DConditionModelConfig
26
+ from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
27
+
28
+
29
+ if TYPE_CHECKING:
30
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
31
+
32
+ logger = get_logger(__name__)
33
+
34
+
35
+ class _UNet_SD(torch.nn.Module):
36
+ def __init__(self, unet: "UNet2DConditionModel"):
37
+ super().__init__()
38
+ self.unet = unet
39
+
40
+ def forward(
41
+ self,
42
+ sample: torch.Tensor,
43
+ timestep: Union[torch.Tensor, float, int],
44
+ encoder_hidden_states: torch.Tensor,
45
+ *down_and_mid_block_additional_residuals: Optional[Tuple[torch.Tensor]],
46
+ text_embeds: Optional[torch.Tensor] = None,
47
+ time_ids: Optional[torch.Tensor] = None,
48
+ ) -> torch.Tensor:
49
+ if text_embeds is not None and time_ids is not None:
50
+ added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
51
+ else:
52
+ added_cond_kwargs = {}
53
+
54
+ if len(down_and_mid_block_additional_residuals) != 0:
55
+ down_block_additional_residuals, mid_block_additional_residual = (
56
+ down_and_mid_block_additional_residuals[:-1],
57
+ down_and_mid_block_additional_residuals[-1],
58
+ )
59
+ else:
60
+ down_block_additional_residuals, mid_block_additional_residual = None, None
61
+
62
+ unet_out = self.unet(
63
+ sample=sample,
64
+ timestep=timestep,
65
+ encoder_hidden_states=encoder_hidden_states,
66
+ down_block_additional_residuals=down_block_additional_residuals,
67
+ mid_block_additional_residual=mid_block_additional_residual,
68
+ added_cond_kwargs=added_cond_kwargs,
69
+ return_dict=False,
70
+ )
71
+ return unet_out
72
+
73
+
74
+ class _UNet_SDXL(torch.nn.Module):
75
+ def __init__(self, unet: "UNet2DConditionModel"):
76
+ super().__init__()
77
+ self.unet = unet
78
+
79
+ def forward(
80
+ self,
81
+ sample: torch.Tensor,
82
+ timestep: Union[torch.Tensor, float, int],
83
+ encoder_hidden_states: torch.Tensor,
84
+ *down_and_mid_block_additional_residuals: Optional[Tuple[torch.Tensor]],
85
+ ) -> torch.Tensor:
86
+ if len(down_and_mid_block_additional_residuals) == 2:
87
+ added_cond_kwargs = {
88
+ "text_embeds": down_and_mid_block_additional_residuals[0],
89
+ "time_ids": down_and_mid_block_additional_residuals[1],
90
+ }
91
+ down_block_additional_residuals = None
92
+ mid_block_additional_residual = None
93
+ elif len(down_and_mid_block_additional_residuals) > 2:
94
+ added_cond_kwargs = {
95
+ "text_embeds": down_and_mid_block_additional_residuals[-2],
96
+ "time_ids": down_and_mid_block_additional_residuals[-1],
97
+ }
98
+ down_block_additional_residuals, mid_block_additional_residual = (
99
+ down_and_mid_block_additional_residuals[:-3],
100
+ down_and_mid_block_additional_residuals[-3],
101
+ )
102
+ else:
103
+ added_cond_kwargs = {}
104
+ down_block_additional_residuals = None
105
+ mid_block_additional_residual = None
106
+
107
+ unet_out = self.unet(
108
+ sample=sample,
109
+ timestep=timestep,
110
+ encoder_hidden_states=encoder_hidden_states,
111
+ down_block_additional_residuals=down_block_additional_residuals,
112
+ mid_block_additional_residual=mid_block_additional_residual,
113
+ added_cond_kwargs=added_cond_kwargs,
114
+ return_dict=False,
115
+ )
116
+ return unet_out
117
+
118
+
119
+ class _UNet_Kandinsky(torch.nn.Module):
120
+ def __init__(self, unet: "UNet2DConditionModel"):
121
+ super().__init__()
122
+ self.unet = unet
123
+
124
+ def forward(
125
+ self,
126
+ sample: torch.Tensor,
127
+ timestep: Union[torch.Tensor, float, int],
128
+ image_embeds: torch.Tensor,
129
+ ) -> torch.Tensor:
130
+ added_cond_kwargs = {"image_embeds": image_embeds}
131
+
132
+ unet_out = self.unet(
133
+ sample=sample,
134
+ timestep=timestep,
135
+ encoder_hidden_states=None,
136
+ added_cond_kwargs=added_cond_kwargs,
137
+ return_dict=False,
138
+ )
139
+ return unet_out
140
+
141
+
142
+ class RBLNUNet2DConditionModel(RBLNModel):
143
+ """
144
+ RBLN implementation of UNet2DConditionModel for diffusion models.
145
+
146
+ This model is used to accelerate UNet2DCondition models from diffusers library on RBLN NPUs.
147
+ It is a key component in diffusion-based image generation models like Stable Diffusion.
148
+
149
+ This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
150
+ the library implements for all its models.
151
+ """
152
+
153
+ hf_library_name = "diffusers"
154
+ auto_model_class = UNet2DConditionModel
155
+ _rbln_config_class = RBLNUNet2DConditionModelConfig
156
+ _output_class = UNet2DConditionOutput
157
+
158
+ def __post_init__(self, **kwargs):
159
+ super().__post_init__(**kwargs)
160
+ self.in_features = self.rbln_config.in_features
161
+ if self.in_features is not None:
162
+
163
+ @dataclass
164
+ class LINEAR1:
165
+ in_features: int
166
+
167
+ @dataclass
168
+ class ADDEMBEDDING:
169
+ linear_1: LINEAR1
170
+
171
+ self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))
172
+
173
+ @classmethod
174
+ def _wrap_model_if_needed(
175
+ cls, model: torch.nn.Module, rbln_config: RBLNUNet2DConditionModelConfig
176
+ ) -> torch.nn.Module:
177
+ if model.config.addition_embed_type == "text_time":
178
+ return _UNet_SDXL(model).eval()
179
+ elif model.config.addition_embed_type == "image":
180
+ return _UNet_Kandinsky(model).eval()
181
+ else:
182
+ return _UNet_SD(model).eval()
183
+
184
+ @classmethod
185
+ def get_unet_sample_size(
186
+ cls,
187
+ pipe: RBLNDiffusionMixin,
188
+ rbln_config: RBLNUNet2DConditionModelConfig,
189
+ image_size: Optional[Tuple[int, int]] = None,
190
+ ) -> Tuple[int, int]:
191
+ if hasattr(pipe, "movq"):
192
+ scale_factor = 2 ** (len(pipe.movq.config.block_out_channels) - 1)
193
+ else:
194
+ scale_factor = pipe.vae_scale_factor
195
+
196
+ if image_size is None:
197
+ if "Img2Img" in pipe.__class__.__name__:
198
+ if hasattr(pipe, "vae"):
199
+ # In case of img2img, sample size of unet is determined by vae encoder.
200
+ vae_sample_size = pipe.vae.config.sample_size
201
+ if isinstance(vae_sample_size, int):
202
+ vae_sample_size = (vae_sample_size, vae_sample_size)
203
+
204
+ sample_size = (
205
+ vae_sample_size[0] // scale_factor,
206
+ vae_sample_size[1] // scale_factor,
207
+ )
208
+ elif hasattr(pipe, "movq"):
209
+ logger.warning(
210
+ "RBLN config 'image_size' should have been provided for this pipeline. "
211
+ "Both variable will be set 512 by default."
212
+ )
213
+ sample_size = (512 // scale_factor, 512 // scale_factor)
214
+ else:
215
+ sample_size = pipe.unet.config.sample_size
216
+ if isinstance(sample_size, int):
217
+ sample_size = (sample_size, sample_size)
218
+ else:
219
+ sample_size = (image_size[0] // scale_factor, image_size[1] // scale_factor)
220
+
221
+ return sample_size
222
+
223
+ @classmethod
224
+ def update_rbln_config_using_pipe(
225
+ cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
226
+ ) -> "RBLNDiffusionMixinConfig":
227
+ rbln_config.unet.text_model_hidden_size = (
228
+ pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
229
+ )
230
+ rbln_config.unet.image_model_hidden_size = pipe.unet.config.encoder_hid_dim if hasattr(pipe, "unet") else None
231
+
232
+ rbln_config.unet.max_seq_len = (
233
+ pipe.text_encoder.config.max_position_embeddings if hasattr(pipe, "text_encoder") else None
234
+ )
235
+
236
+ rbln_config.unet.sample_size = cls.get_unet_sample_size(
237
+ pipe, rbln_config.unet, image_size=rbln_config.image_size
238
+ )
239
+ rbln_config.unet.use_additional_residuals = "controlnet" in pipe.config.keys()
240
+
241
+ return rbln_config
242
+
243
    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNUNet2DConditionModelConfig,
    ) -> RBLNUNet2DConditionModelConfig:
        """Build the static compile-time input specification for the UNet.

        Resolves the latent sample size, then assembles ``input_info`` — a list
        of ``(name, shape, dtype)`` triples describing every tensor the compiled
        graph accepts — and stores it on ``rbln_config`` as one compile config.

        Args:
            preprocessors: Unused here; part of the generic `_update_rbln_config` signature.
            model: Unused here; part of the generic `_update_rbln_config` signature.
            model_config: The diffusers UNet config (channels, block layout, ...).
            rbln_config: Config being populated; mutated in place and returned.
        """
        if rbln_config.sample_size is None:
            rbln_config.sample_size = model_config.sample_size

        # Normalize to a (height, width) pair.
        if isinstance(rbln_config.sample_size, int):
            rbln_config.sample_size = (rbln_config.sample_size, rbln_config.sample_size)

        input_info = [
            (
                "sample",
                [
                    rbln_config.batch_size,
                    model_config.in_channels,
                    rbln_config.sample_size[0],
                    rbln_config.sample_size[1],
                ],
                "float32",
            ),
            # Scalar (0-d) denoising timestep.
            ("timestep", [], "float32"),
        ]

        # Text conditioning input; absent for pipelines without a text encoder
        # (e.g. the Kandinsky decoder, which uses image embeddings instead).
        if rbln_config.max_seq_len is not None:
            input_info.append(
                (
                    "encoder_hidden_states",
                    [rbln_config.batch_size, rbln_config.max_seq_len, model_config.cross_attention_dim],
                    "float32",
                ),
            )

        # ControlNet residual inputs: one per down-block layer output (plus one
        # per downsampler), and one for the mid block.
        if rbln_config.use_additional_residuals:
            # down block additional residuals
            first_shape = [
                rbln_config.batch_size,
                model_config.block_out_channels[0],
                rbln_config.sample_size[0],
                rbln_config.sample_size[1],
            ]
            height, width = rbln_config.sample_size[0], rbln_config.sample_size[1]
            input_info.append(("down_block_additional_residuals_0", first_shape, "float32"))
            name_idx = 1
            for idx, _ in enumerate(model_config.down_block_types):
                shape = [rbln_config.batch_size, model_config.block_out_channels[idx], height, width]
                # One residual per resnet layer inside this down block.
                for _ in range(model_config.layers_per_block):
                    input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
                    name_idx += 1
                # Every down block except the last ends with a downsampler that
                # halves the spatial resolution and emits one extra residual.
                if idx != len(model_config.down_block_types) - 1:
                    height = height // 2
                    width = width // 2
                    shape = [rbln_config.batch_size, model_config.block_out_channels[idx], height, width]
                    input_info.append((f"down_block_additional_residuals_{name_idx}", shape, "float32"))
                    name_idx += 1

            # mid block additional residual
            # NOTE(review): the mid-block spatial size is derived from the count
            # of CrossAttnDownBlock2D entries (used as the number of halvings) —
            # confirm this holds for configs mixing other down-block types.
            num_cross_attn_blocks = model_config.down_block_types.count("CrossAttnDownBlock2D")
            out_channels = model_config.block_out_channels[-1]
            shape = [
                rbln_config.batch_size,
                out_channels,
                rbln_config.sample_size[0] // 2**num_cross_attn_blocks,
                rbln_config.sample_size[1] // 2**num_cross_attn_blocks,
            ]
            input_info.append(("mid_block_additional_residual", shape, "float32"))

        # Extra conditioning embeddings keyed by the UNet's addition_embed_type:
        # "text_time" (SDXL) or "image" (Kandinsky).
        if hasattr(model_config, "addition_embed_type"):
            if model_config.addition_embed_type == "text_time":
                # Remember the add-embedding width so __post_init__ can expose
                # add_embedding.linear_1.in_features to pipeline code.
                rbln_config.in_features = model_config.projection_class_embeddings_input_dim
                input_info.append(
                    ("text_embeds", [rbln_config.batch_size, rbln_config.text_model_hidden_size], "float32")
                )
                # SDXL packs 6 micro-conditioning values into time_ids.
                input_info.append(("time_ids", [rbln_config.batch_size, 6], "float32"))
            elif model_config.addition_embed_type == "image":
                input_info.append(
                    ("image_embeds", [rbln_config.batch_size, rbln_config.image_model_hidden_size], "float32")
                )

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config.set_compile_cfgs([rbln_compile_config])

        return rbln_config
330
+
331
+ @property
332
+ def compiled_batch_size(self):
333
+ return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
334
+
335
+ def forward(
336
+ self,
337
+ sample: torch.Tensor,
338
+ timestep: Union[torch.Tensor, float, int],
339
+ encoder_hidden_states: torch.Tensor,
340
+ class_labels: Optional[torch.Tensor] = None,
341
+ timestep_cond: Optional[torch.Tensor] = None,
342
+ attention_mask: Optional[torch.Tensor] = None,
343
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
344
+ added_cond_kwargs: Dict[str, torch.Tensor] = {},
345
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
346
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
347
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
348
+ encoder_attention_mask: Optional[torch.Tensor] = None,
349
+ return_dict: bool = True,
350
+ **kwargs,
351
+ ) -> Union[UNet2DConditionOutput, Tuple]:
352
+ """
353
+ Forward pass for the RBLN-optimized UNet2DConditionModel.
354
+
355
+ Args:
356
+ sample (torch.Tensor): The noisy input tensor with the following shape `(batch, channel, height, width)`.
357
+ timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
358
+ encoder_hidden_states (torch.Tensor): The encoder hidden states.
359
+ added_cond_kwargs (Dict[str, torch.Tensor]): A kwargs dictionary containing additional embeddings that
360
+ if specified are added to the embeddings that are passed along to the UNet blocks.
361
+ down_block_additional_residuals (Optional[Tuple[torch.Tensor]]): A tuple of tensors that if specified are added to the residuals of down unet blocks.
362
+ mid_block_additional_residual (Optional[torch.Tensor]): A tensor that if specified is added to the residual of the middle unet block.
363
+ return_dict (bool): Whether or not to return a [`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
364
+
365
+ Returns:
366
+ (Union[`~diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`], Tuple)
367
+ """
368
+ sample_batch_size = sample.size()[0]
369
+ compiled_batch_size = self.compiled_batch_size
370
+ if sample_batch_size != compiled_batch_size and (
371
+ sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
372
+ ):
373
+ raise ValueError(
374
+ f"Mismatch between UNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
375
+ "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size of UNet in Stable Diffusion. "
376
+ "Adjust the batch size of UNet during compilation to match the runtime batch size.\n\n"
377
+ "For details, see: https://docs.rbln.ai/software/optimum/model_api/diffusers/pipelines/stable_diffusion.html#important-batch-size-configuration-for-guidance-scale"
378
+ )
379
+
380
+ added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
381
+
382
+ if down_block_additional_residuals is not None:
383
+ down_block_additional_residuals = [t.contiguous() for t in down_block_additional_residuals]
384
+ return super().forward(
385
+ sample.contiguous(),
386
+ timestep.float(),
387
+ encoder_hidden_states,
388
+ *down_block_additional_residuals,
389
+ mid_block_additional_residual,
390
+ **added_cond_kwargs,
391
+ return_dict=return_dict,
392
+ )
393
+
394
+ if "image_embeds" in added_cond_kwargs:
395
+ return super().forward(
396
+ sample.contiguous(),
397
+ timestep.float(),
398
+ **added_cond_kwargs,
399
+ return_dict=return_dict,
400
+ )
401
+
402
+ return super().forward(
403
+ sample.contiguous(),
404
+ timestep.float(),
405
+ encoder_hidden_states,
406
+ **added_cond_kwargs,
407
+ return_dict=return_dict,
408
+ )
@@ -0,0 +1,201 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
17
+
18
+ import torch
19
+ from diffusers.models.unets.unet_spatio_temporal_condition import (
20
+ UNetSpatioTemporalConditionModel,
21
+ UNetSpatioTemporalConditionOutput,
22
+ )
23
+ from transformers import PretrainedConfig
24
+
25
+ from ....configuration_utils import RBLNCompileConfig
26
+ from ....modeling import RBLNModel
27
+ from ....utils.logging import get_logger
28
+ from ...configurations import RBLNUNetSpatioTemporalConditionModelConfig
29
+ from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
30
+
31
+
32
+ if TYPE_CHECKING:
33
+ from transformers import AutoFeatureExtractor, AutoProcessor, PreTrainedModel
34
+
35
+ logger = get_logger(__name__)
36
+
37
+
38
+ class _UNet_STCM(torch.nn.Module):
39
+ def __init__(self, unet: "UNetSpatioTemporalConditionModel"):
40
+ super().__init__()
41
+ self.unet = unet
42
+
43
+ def forward(
44
+ self,
45
+ sample: torch.Tensor,
46
+ timestep: Union[torch.Tensor, float, int],
47
+ encoder_hidden_states: torch.Tensor,
48
+ added_time_ids: torch.Tensor,
49
+ ) -> torch.Tensor:
50
+ unet_out = self.unet(
51
+ sample=sample,
52
+ timestep=timestep,
53
+ encoder_hidden_states=encoder_hidden_states,
54
+ added_time_ids=added_time_ids,
55
+ return_dict=False,
56
+ )
57
+ return unet_out
58
+
59
+
60
+ class RBLNUNetSpatioTemporalConditionModel(RBLNModel):
61
+ hf_library_name = "diffusers"
62
+ auto_model_class = UNetSpatioTemporalConditionModel
63
+ _rbln_config_class = RBLNUNetSpatioTemporalConditionModelConfig
64
+ output_class = UNetSpatioTemporalConditionOutput
65
+ output_key = "sample"
66
+
67
+ def __post_init__(self, **kwargs):
68
+ super().__post_init__(**kwargs)
69
+ self.in_features = self.rbln_config.in_features
70
+ if self.in_features is not None:
71
+
72
+ @dataclass
73
+ class LINEAR1:
74
+ in_features: int
75
+
76
+ @dataclass
77
+ class ADDEMBEDDING:
78
+ linear_1: LINEAR1
79
+
80
+ self.add_embedding = ADDEMBEDDING(LINEAR1(self.in_features))
81
+
82
+ @classmethod
83
+ def _wrap_model_if_needed(
84
+ cls, model: torch.nn.Module, rbln_config: RBLNUNetSpatioTemporalConditionModelConfig
85
+ ) -> torch.nn.Module:
86
+ return _UNet_STCM(model).eval()
87
+
88
+ @classmethod
89
+ def get_unet_sample_size(
90
+ cls,
91
+ pipe: RBLNDiffusionMixin,
92
+ rbln_config: RBLNUNetSpatioTemporalConditionModelConfig,
93
+ image_size: Optional[Tuple[int, int]] = None,
94
+ ) -> Union[int, Tuple[int, int]]:
95
+ scale_factor = pipe.vae_scale_factor
96
+
97
+ if image_size is None:
98
+ vae_sample_size = pipe.vae.config.sample_size
99
+ if isinstance(vae_sample_size, int):
100
+ vae_sample_size = (vae_sample_size, vae_sample_size)
101
+
102
+ sample_size = (
103
+ vae_sample_size[0] // scale_factor,
104
+ vae_sample_size[1] // scale_factor,
105
+ )
106
+ else:
107
+ sample_size = (image_size[0] // scale_factor, image_size[1] // scale_factor)
108
+ return sample_size
109
+
110
+ @classmethod
111
+ def update_rbln_config_using_pipe(
112
+ cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
113
+ ) -> Dict[str, Any]:
114
+ rbln_config.unet.sample_size = cls.get_unet_sample_size(
115
+ pipe, rbln_config.unet, image_size=rbln_config.image_size
116
+ )
117
+ return rbln_config
118
+
119
+ @classmethod
120
+ def _update_rbln_config(
121
+ cls,
122
+ preprocessors: Union["AutoFeatureExtractor", "AutoProcessor"],
123
+ model: "PreTrainedModel",
124
+ model_config: "PretrainedConfig",
125
+ rbln_config: RBLNUNetSpatioTemporalConditionModelConfig,
126
+ ) -> RBLNUNetSpatioTemporalConditionModelConfig:
127
+ if rbln_config.num_frames is None:
128
+ rbln_config.num_frames = model_config.num_frames
129
+
130
+ if rbln_config.sample_size is None:
131
+ rbln_config.sample_size = model_config.sample_size
132
+
133
+ input_info = [
134
+ (
135
+ "sample",
136
+ [
137
+ rbln_config.batch_size,
138
+ rbln_config.num_frames,
139
+ model_config.in_channels,
140
+ rbln_config.sample_size[0],
141
+ rbln_config.sample_size[1],
142
+ ],
143
+ "float32",
144
+ ),
145
+ ("timestep", [], "float32"),
146
+ ("encoder_hidden_states", [rbln_config.batch_size, 1, model_config.cross_attention_dim], "float32"),
147
+ ("added_time_ids", [rbln_config.batch_size, 3], "float32"),
148
+ ]
149
+
150
+ if hasattr(model_config, "addition_time_embed_dim"):
151
+ rbln_config.in_features = model_config.projection_class_embeddings_input_dim
152
+
153
+ rbln_compile_config = RBLNCompileConfig(input_info=input_info)
154
+ rbln_config.set_compile_cfgs([rbln_compile_config])
155
+
156
+ return rbln_config
157
+
158
+ @property
159
+ def compiled_batch_size(self):
160
+ return self.rbln_config.compile_cfgs[0].input_info[0][1][0]
161
+
162
+ def forward(
163
+ self,
164
+ sample: torch.Tensor,
165
+ timestep: Union[torch.Tensor, float, int],
166
+ encoder_hidden_states: torch.Tensor,
167
+ added_time_ids: torch.Tensor,
168
+ return_dict: bool = True,
169
+ **kwargs,
170
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
171
+ """
172
+ Forward pass for the RBLN-optimized UNetSpatioTemporalConditionModel.
173
+
174
+ Args:
175
+ sample (torch.Tensor): The noisy input tensor with the following shape `(batch, channel, height, width)`.
176
+ timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
177
+ encoder_hidden_states (torch.Tensor): The encoder hidden states.
178
+ added_time_ids (torch.Tensor): A tensor containing additional sinusoidal embeddings and added to the time embeddings.
179
+ return_dict (bool): Whether or not to return a [`~diffusers.models.unets.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`] instead of a plain tuple.
180
+
181
+ Returns:
182
+ (Union[`~diffusers.models.unets.unet_spatio_temporal_condition.UNetSpatioTemporalConditionOutput`], Tuple)
183
+ """
184
+ sample_batch_size = sample.size()[0]
185
+ compiled_batch_size = self.compiled_batch_size
186
+ if sample_batch_size != compiled_batch_size and (
187
+ sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
188
+ ):
189
+ raise ValueError(
190
+ f"Mismatch between UNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
191
+ "This may be caused by the 'guidance scale' parameter, which doubles the runtime batch size in Stable Diffusion. "
192
+ "Adjust the batch size during compilation or modify the 'guidance scale' to match the compiled batch size.\n\n"
193
+ "For details, see: https://docs.rbln.ai/software/optimum/model_api.html#stable-diffusion"
194
+ )
195
+ return super().forward(
196
+ sample.contiguous(),
197
+ timestep.float(),
198
+ encoder_hidden_states,
199
+ added_time_ids,
200
+ return_dict=return_dict,
201
+ )