optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of optimum-rbln might be problematic. Click here for more details.

Files changed (264) hide show
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,281 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING, Dict, Optional, Union
16
+
17
+ import torch
18
+ from diffusers import ControlNetModel
19
+ from diffusers.models.controlnets.controlnet import ControlNetOutput
20
+ from transformers import PretrainedConfig
21
+
22
+ from ...configuration_utils import RBLNCompileConfig, RBLNModelConfig
23
+ from ...modeling import RBLNModel
24
+ from ...utils.logging import get_logger
25
+ from ...utils.model_utils import get_rbln_model_cls
26
+ from ..configurations import RBLNControlNetModelConfig
27
+ from ..modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
28
+
29
+
30
+ if TYPE_CHECKING:
31
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PreTrainedModel
32
+
33
+
34
+ logger = get_logger(__name__)
35
+
36
+
37
+ class _ControlNetModel(torch.nn.Module):
38
+ def __init__(self, controlnet: "ControlNetModel"):
39
+ super().__init__()
40
+ self.controlnet = controlnet
41
+
42
+ def forward(
43
+ self,
44
+ sample: torch.Tensor,
45
+ timestep: torch.Tensor,
46
+ controlnet_cond: torch.Tensor,
47
+ conditioning_scale,
48
+ text_embeds: Optional[torch.Tensor] = None,
49
+ time_ids: Optional[torch.Tensor] = None,
50
+ ):
51
+ if text_embeds is not None and time_ids is not None:
52
+ added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
53
+ else:
54
+ added_cond_kwargs = {}
55
+
56
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
57
+ sample=sample,
58
+ timestep=timestep,
59
+ encoder_hidden_states=None,
60
+ controlnet_cond=controlnet_cond,
61
+ conditioning_scale=conditioning_scale,
62
+ added_cond_kwargs=added_cond_kwargs,
63
+ return_dict=False,
64
+ )
65
+ return down_block_res_samples, mid_block_res_sample
66
+
67
+
68
+ class _ControlNetModel_Cross_Attention(torch.nn.Module):
69
+ def __init__(self, controlnet: "ControlNetModel"):
70
+ super().__init__()
71
+ self.controlnet = controlnet
72
+
73
+ def forward(
74
+ self,
75
+ sample: torch.Tensor,
76
+ timestep: torch.Tensor,
77
+ encoder_hidden_states: torch.Tensor,
78
+ controlnet_cond: torch.Tensor,
79
+ conditioning_scale,
80
+ text_embeds: Optional[torch.Tensor] = None,
81
+ time_ids: Optional[torch.Tensor] = None,
82
+ ):
83
+ if text_embeds is not None and time_ids is not None:
84
+ added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
85
+ else:
86
+ added_cond_kwargs = {}
87
+
88
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
89
+ sample=sample,
90
+ timestep=timestep,
91
+ encoder_hidden_states=encoder_hidden_states,
92
+ controlnet_cond=controlnet_cond,
93
+ conditioning_scale=conditioning_scale,
94
+ added_cond_kwargs=added_cond_kwargs,
95
+ return_dict=False,
96
+ )
97
+ return down_block_res_samples, mid_block_res_sample
98
+
99
+
100
class RBLNControlNetModel(RBLNModel):
    """
    RBLN implementation of ControlNetModel for diffusion models.

    This model is used to accelerate ControlNetModel models from diffusers library on RBLN NPUs.

    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
    the library implements for all its models.
    """

    hf_library_name = "diffusers"
    auto_model_class = ControlNetModel
    output_class = ControlNetOutput

    def __post_init__(self, **kwargs):
        super().__post_init__(**kwargs)
        # Whether `encoder_hidden_states` was part of the compiled graph's inputs;
        # `forward` uses this flag to pick the matching call signature.
        self.use_encoder_hidden_states = any(
            item[0] == "encoder_hidden_states" for item in self.rbln_config.compile_cfgs[0].input_info
        )

    @classmethod
    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
        """Wrap the HF ControlNet in a trace-friendly module matching its attention type."""
        # The model consumes text-encoder states only if some down block has cross-attention.
        use_encoder_hidden_states = False
        for down_block in model.down_blocks:
            if use_encoder_hidden_states := getattr(down_block, "has_cross_attention", False):
                break

        if use_encoder_hidden_states:
            return _ControlNetModel_Cross_Attention(model).eval()
        else:
            return _ControlNetModel(model).eval()

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        """Fill ControlNet compile parameters from sibling pipeline components.

        Pulls `max_seq_len` from the text encoder, sample sizes from the VAE/UNet
        RBLN classes, and (for SDXL-style pipelines) the second text encoder's
        hidden size.
        """
        rbln_vae_cls = get_rbln_model_cls(f"RBLN{pipe.vae.__class__.__name__}")
        rbln_unet_cls = get_rbln_model_cls(f"RBLN{pipe.unet.__class__.__name__}")

        rbln_config.controlnet.max_seq_len = pipe.text_encoder.config.max_position_embeddings
        # `text_encoder_2` only exists on SDXL-style pipelines.
        text_model_hidden_size = pipe.text_encoder_2.config.hidden_size if hasattr(pipe, "text_encoder_2") else None
        rbln_config.controlnet.text_model_hidden_size = text_model_hidden_size
        rbln_config.controlnet.vae_sample_size = rbln_vae_cls.get_vae_sample_size(pipe, rbln_config.vae)
        rbln_config.controlnet.unet_sample_size = rbln_unet_cls.get_unet_sample_size(
            pipe, rbln_config.unet, image_size=rbln_config.image_size
        )

        return rbln_config

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNControlNetModelConfig,
    ) -> RBLNModelConfig:
        """Build the static `input_info` used to compile the ControlNet graph.

        Raises:
            ValueError: If `unet_sample_size`, `vae_sample_size`, or `max_seq_len`
                is missing from `rbln_config`.
        """
        if rbln_config.unet_sample_size is None:
            raise ValueError("`unet_sample_size` (latent height, width) must be specified (ex. unet's sample_size)")

        if rbln_config.vae_sample_size is None:
            raise ValueError("`vae_sample_size` (input image height, width) must be specified (ex. vae's sample_size)")

        if rbln_config.max_seq_len is None:
            raise ValueError("`max_seq_len` (ex. text_encoder's max_position_embeddings) must be specified")

        input_info = [
            (
                "sample",
                [
                    rbln_config.batch_size,
                    model_config.in_channels,
                    rbln_config.unet_sample_size[0],
                    rbln_config.unet_sample_size[1],
                ],
                "float32",
            ),
            ("timestep", [], "float32"),
        ]

        # Only ControlNets built purely of DownBlock2D have no cross-attention,
        # and thus no text-embedding input.
        use_encoder_hidden_states = any(element != "DownBlock2D" for element in model_config.down_block_types)
        if use_encoder_hidden_states:
            input_info.append(
                (
                    "encoder_hidden_states",
                    [rbln_config.batch_size, rbln_config.max_seq_len, model_config.cross_attention_dim],
                    "float32",
                )
            )

        # The conditioning image is RGB at VAE (pixel-space) resolution.
        input_info.append(
            (
                "controlnet_cond",
                [rbln_config.batch_size, 3, rbln_config.vae_sample_size[0], rbln_config.vae_sample_size[1]],
                "float32",
            )
        )
        input_info.append(("conditioning_scale", [], "float32"))

        # SDXL-style models additionally condition on pooled text embeddings and time ids.
        if hasattr(model_config, "addition_embed_type") and model_config.addition_embed_type == "text_time":
            input_info.append(("text_embeds", [rbln_config.batch_size, rbln_config.text_model_hidden_size], "float32"))
            input_info.append(("time_ids", [rbln_config.batch_size, 6], "float32"))

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config.set_compile_cfgs([rbln_compile_config])
        return rbln_config

    @property
    def compiled_batch_size(self) -> int:
        # First dimension of the compiled `sample` input.
        return self.rbln_config.compile_cfgs[0].input_info[0][1][0]

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.FloatTensor,
        conditioning_scale: Union[torch.Tensor, float] = 1.0,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Forward pass for the RBLN-optimized ControlNetModel.

        Args:
            sample (torch.FloatTensor): The noisy input tensor.
            timestep (Union[torch.Tensor, float, int]): The number of timesteps to denoise an input.
            encoder_hidden_states (torch.Tensor): The encoder hidden states.
            controlnet_cond (torch.FloatTensor): The conditional input image tensor of shape `(batch_size, 3, height, width)`.
            conditioning_scale (Union[torch.Tensor, float]): The scale factor for ControlNet outputs.
            added_cond_kwargs (Optional[Dict[str, torch.Tensor]]): Additional conditions for the Stable Diffusion XL UNet.
            return_dict (bool): Whether or not to return a [`~diffusers.models.controlnets.controlnet.ControlNetOutput`] instead of a plain tuple

        Returns:
            (Union[`~diffusers.models.controlnets.controlnet.ControlNetOutput`], Tuple)

        Raises:
            ValueError: If the runtime batch size differs from the compiled batch
                size by exactly a factor of two (the classifier-free-guidance case).
        """
        sample_batch_size = sample.size()[0]
        compiled_batch_size = self.compiled_batch_size
        # Only the exact 2x mismatch is caught here because it is the common,
        # diagnosable `guidance_scale` doubling; other mismatches surface at runtime.
        if sample_batch_size != compiled_batch_size and (
            sample_batch_size * 2 == compiled_batch_size or sample_batch_size == compiled_batch_size * 2
        ):
            raise ValueError(
                f"Mismatch between ControlNet's runtime batch size ({sample_batch_size}) and compiled batch size ({compiled_batch_size}). "
                "This may be caused by the 'guidance_scale' parameter, which doubles the runtime batch size of ControlNet in Stable Diffusion. "
                "Adjust the batch size of ControlNet during compilation to match the runtime batch size.\n\n"
                "For details, see: https://docs.rbln.ai/software/optimum/model_api/diffusers/pipelines/controlnet.html#important-batch-size-configuration-for-guidance-scale"
            )

        # Default is None (not a shared mutable dict); normalize to an empty mapping.
        added_cond_kwargs = {} if added_cond_kwargs is None else added_cond_kwargs
        if self.use_encoder_hidden_states:
            output = self.model[0](
                sample.contiguous(),
                timestep.float(),
                encoder_hidden_states,
                controlnet_cond,
                torch.tensor(conditioning_scale),
                **added_cond_kwargs,
            )
        else:
            output = self.model[0](
                sample.contiguous(),
                timestep.float(),
                controlnet_cond,
                torch.tensor(conditioning_scale),
                **added_cond_kwargs,
            )

        # The compiled runtime returns a flat tuple; the last element is the
        # mid-block residual, everything before it the down-block residuals.
        down_block_res_samples = output[:-1]
        mid_block_res_sample = output[-1]
        output = (down_block_res_samples, mid_block_res_sample)
        output = self._prepare_output(output, return_dict)
        return output

    def _prepare_output(self, output, return_dict):
        """Convert the flat runtime output into a tuple or `ControlNetOutput`."""
        if not return_dict:
            return (output,) if not isinstance(output, (tuple, list)) else output
        else:
            return ControlNetOutput(
                down_block_res_samples=output[:-1],
                mid_block_res_sample=output[-1],
            )
@@ -0,0 +1,17 @@
1
+ # Copyright 2025 Rebellions Inc. All rights reserved.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from .prior_transformer import RBLNPriorTransformer
16
+ from .transformer_cosmos import RBLNCosmosTransformer3DModel
17
+ from .transformer_sd3 import RBLNSD3Transformer2DModel
@@ -0,0 +1,160 @@
1
+ # Copyright 2024 Rebellions Inc.
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from pathlib import Path
16
+ from typing import TYPE_CHECKING, Optional, Union
17
+
18
+ import torch
19
+ from diffusers.models.transformers.prior_transformer import PriorTransformer, PriorTransformerOutput
20
+
21
+ from ....configuration_utils import RBLNCompileConfig, RBLNModelConfig
22
+ from ....modeling import RBLNModel
23
+ from ....utils.logging import get_logger
24
+ from ...configurations.models import RBLNPriorTransformerConfig
25
+ from ...modeling_diffusers import RBLNDiffusionMixin, RBLNDiffusionMixinConfig
26
+
27
+
28
+ if TYPE_CHECKING:
29
+ from transformers import AutoFeatureExtractor, AutoProcessor, AutoTokenizer, PretrainedConfig, PreTrainedModel
30
+
31
+ logger = get_logger(__name__)
32
+
33
+
34
class _PriorTransformer(torch.nn.Module):
    """Tracing wrapper that forces PriorTransformer.forward onto its
    tuple-returning path.

    The compiled graph must emit a flat tuple, so the wrapped module is always
    invoked with return_dict=False, whatever flag the caller supplies.
    """

    def __init__(self, prior: PriorTransformer):
        super().__init__()
        self._prior = prior

    def forward(
        self,
        hidden_states,
        timestep,
        proj_embedding,
        encoder_hidden_states,
        attention_mask,
        return_dict=True,
    ):
        # The caller's return_dict flag is deliberately ignored here: only the
        # tuple-returning path is traceable.
        inputs = (
            hidden_states,
            timestep,
            proj_embedding,
            encoder_hidden_states,
            attention_mask,
        )
        return self._prior.forward(*inputs, return_dict=False)
56
+
57
+
58
class RBLNPriorTransformer(RBLNModel):
    """
    RBLN implementation of PriorTransformer for diffusion models like Kandinsky V2.2.

    The PriorTransformer takes text and/or image embeddings from encoders (like CLIP) and
    maps them to a shared latent space that guides the diffusion process to generate the desired image.

    This class inherits from [`RBLNModel`]. Check the superclass documentation for the generic methods
    the library implements for all its models.
    """

    hf_library_name = "diffusers"
    auto_model_class = PriorTransformer
    _output_class = PriorTransformerOutput

    def __post_init__(self, **kwargs):
        """Restore clip_mean/clip_std buffers saved at export time.

        These statistics are not part of the compiled graph, so they are
        persisted separately by save_torch_artifacts() and reloaded here.
        """
        super().__post_init__(**kwargs)
        artifacts = torch.load(self.model_save_dir / self.subfolder / "torch_artifacts.pth", weights_only=False)
        self.clip_mean = artifacts["clip_mean"]
        self.clip_std = artifacts["clip_std"]

    @classmethod
    def _wrap_model_if_needed(cls, model: torch.nn.Module, rbln_config: RBLNModelConfig) -> torch.nn.Module:
        # Wrap so that tracing always follows the tuple-returning path.
        return _PriorTransformer(model).eval()

    @classmethod
    def update_rbln_config_using_pipe(
        cls, pipe: RBLNDiffusionMixin, rbln_config: "RBLNDiffusionMixinConfig", submodule_name: str
    ) -> "RBLNDiffusionMixinConfig":
        # The prior needs no pipeline-derived adjustments; return unchanged.
        return rbln_config

    @classmethod
    def save_torch_artifacts(
        cls, model: "PreTrainedModel", save_dir_path: Path, subfolder: str, rbln_config: RBLNModelConfig
    ):
        """Persist the clip_mean/clip_std buffers alongside the compiled model."""
        save_dict = {}
        save_dict["clip_mean"] = model.clip_mean
        save_dict["clip_std"] = model.clip_std
        torch.save(save_dict, save_dir_path / subfolder / "torch_artifacts.pth")

    @classmethod
    def _update_rbln_config(
        cls,
        preprocessors: Union["AutoFeatureExtractor", "AutoProcessor", "AutoTokenizer"],
        model: "PreTrainedModel",
        model_config: "PretrainedConfig",
        rbln_config: RBLNPriorTransformerConfig,
    ) -> RBLNPriorTransformerConfig:
        """Fill missing dims from the HF config and declare the static input layout."""
        rbln_config.embedding_dim = rbln_config.embedding_dim or model_config.embedding_dim
        rbln_config.num_embeddings = rbln_config.num_embeddings or model_config.num_embeddings

        # All inputs are traced as float32; timestep is a scalar (rank-0).
        input_info = [
            ("hidden_states", [rbln_config.batch_size, rbln_config.embedding_dim], "float32"),
            ("timestep", [], "float32"),
            ("proj_embedding", [rbln_config.batch_size, rbln_config.embedding_dim], "float32"),
            (
                "encoder_hidden_states",
                [rbln_config.batch_size, rbln_config.num_embeddings, rbln_config.embedding_dim],
                "float32",
            ),
            ("attention_mask", [rbln_config.batch_size, rbln_config.num_embeddings], "float32"),
        ]

        rbln_compile_config = RBLNCompileConfig(input_info=input_info)
        rbln_config.set_compile_cfgs([rbln_compile_config])
        return rbln_config

    def post_process_latents(self, prior_latents):
        """De-normalize prior latents using the stored CLIP statistics."""
        prior_latents = (prior_latents * self.clip_std) + self.clip_mean
        return prior_latents

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        proj_embedding: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        Forward pass for the RBLN-optimized PriorTransformer.

        Args:
            hidden_states (torch.Tensor): The currently predicted image embeddings.
            timestep (Union[torch.Tensor, float, int]): Current denoising step.
            proj_embedding (torch.Tensor): Projected embedding vector the denoising process is conditioned on.
            encoder_hidden_states (Optional[torch.Tensor]): Hidden states of the text embeddings the denoising process is conditioned on.
            attention_mask (Optional[torch.Tensor]): Text mask for the text embeddings.
            return_dict (bool): Whether or not to return a [`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`] instead of a plain tuple.

        Returns:
            (Union[`~diffusers.models.transformers.prior_transformer.PriorTransformerOutput`, Tuple])
        """
        # The compiled graph expects float32 inputs: timestep (long) and
        # attention_mask (bool) must be converted. The signature advertises
        # plain-scalar timesteps, so wrap those in a tensor first instead of
        # crashing on `.float()` (fixes AttributeError for float/int timesteps).
        if not isinstance(timestep, torch.Tensor):
            timestep = torch.tensor(timestep)
        # Only cast the mask when one was given; None propagates so an absent
        # mask surfaces as a runtime input error rather than AttributeError here.
        if attention_mask is not None:
            attention_mask = attention_mask.float()
        return super().forward(
            hidden_states,
            timestep.float(),
            proj_embedding,
            encoder_hidden_states,
            attention_mask,
            return_dict=return_dict,
        )