optimum-rbln 0.9.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of optimum-rbln might be problematic.

Files changed (264)
  1. optimum/rbln/__init__.py +505 -0
  2. optimum/rbln/__version__.py +34 -0
  3. optimum/rbln/cli.py +660 -0
  4. optimum/rbln/configuration_utils.py +968 -0
  5. optimum/rbln/diffusers/__init__.py +198 -0
  6. optimum/rbln/diffusers/configurations/__init__.py +37 -0
  7. optimum/rbln/diffusers/configurations/models/__init__.py +10 -0
  8. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl.py +73 -0
  9. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_cosmos.py +84 -0
  10. optimum/rbln/diffusers/configurations/models/configuration_autoencoder_kl_temporal_decoder.py +67 -0
  11. optimum/rbln/diffusers/configurations/models/configuration_controlnet.py +64 -0
  12. optimum/rbln/diffusers/configurations/models/configuration_prior_transformer.py +59 -0
  13. optimum/rbln/diffusers/configurations/models/configuration_transformer_cosmos.py +78 -0
  14. optimum/rbln/diffusers/configurations/models/configuration_transformer_sd3.py +63 -0
  15. optimum/rbln/diffusers/configurations/models/configuration_unet_2d_condition.py +81 -0
  16. optimum/rbln/diffusers/configurations/models/configuration_unet_spatio_temporal_condition.py +59 -0
  17. optimum/rbln/diffusers/configurations/models/configuration_vq_model.py +74 -0
  18. optimum/rbln/diffusers/configurations/pipelines/__init__.py +34 -0
  19. optimum/rbln/diffusers/configurations/pipelines/configuration_controlnet.py +316 -0
  20. optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py +117 -0
  21. optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py +363 -0
  22. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py +156 -0
  23. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_3.py +176 -0
  24. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion_xl.py +159 -0
  25. optimum/rbln/diffusers/configurations/pipelines/configuration_stable_video_diffusion.py +114 -0
  26. optimum/rbln/diffusers/modeling_diffusers.py +451 -0
  27. optimum/rbln/diffusers/models/__init__.py +64 -0
  28. optimum/rbln/diffusers/models/autoencoders/__init__.py +18 -0
  29. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl.py +255 -0
  30. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_cosmos.py +245 -0
  31. optimum/rbln/diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +275 -0
  32. optimum/rbln/diffusers/models/autoencoders/vae.py +178 -0
  33. optimum/rbln/diffusers/models/autoencoders/vq_model.py +211 -0
  34. optimum/rbln/diffusers/models/controlnet.py +281 -0
  35. optimum/rbln/diffusers/models/transformers/__init__.py +17 -0
  36. optimum/rbln/diffusers/models/transformers/prior_transformer.py +160 -0
  37. optimum/rbln/diffusers/models/transformers/transformer_cosmos.py +344 -0
  38. optimum/rbln/diffusers/models/transformers/transformer_sd3.py +191 -0
  39. optimum/rbln/diffusers/models/unets/__init__.py +16 -0
  40. optimum/rbln/diffusers/models/unets/unet_2d_condition.py +408 -0
  41. optimum/rbln/diffusers/models/unets/unet_spatio_temporal_condition.py +201 -0
  42. optimum/rbln/diffusers/pipelines/__init__.py +113 -0
  43. optimum/rbln/diffusers/pipelines/auto_pipeline.py +307 -0
  44. optimum/rbln/diffusers/pipelines/controlnet/__init__.py +19 -0
  45. optimum/rbln/diffusers/pipelines/controlnet/multicontrolnet.py +139 -0
  46. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet.py +669 -0
  47. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +640 -0
  48. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +825 -0
  49. optimum/rbln/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +837 -0
  50. optimum/rbln/diffusers/pipelines/cosmos/__init__.py +17 -0
  51. optimum/rbln/diffusers/pipelines/cosmos/configuration_cosmos_guardrail.py +113 -0
  52. optimum/rbln/diffusers/pipelines/cosmos/cosmos_guardrail.py +425 -0
  53. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +128 -0
  54. optimum/rbln/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +128 -0
  55. optimum/rbln/diffusers/pipelines/kandinsky2_2/__init__.py +23 -0
  56. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +34 -0
  57. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +207 -0
  58. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +34 -0
  59. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpaint.py +34 -0
  60. optimum/rbln/diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +31 -0
  61. optimum/rbln/diffusers/pipelines/stable_diffusion/__init__.py +17 -0
  62. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +32 -0
  63. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +31 -0
  64. optimum/rbln/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +31 -0
  65. optimum/rbln/diffusers/pipelines/stable_diffusion_3/__init__.py +17 -0
  66. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +31 -0
  67. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +31 -0
  68. optimum/rbln/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +31 -0
  69. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/__init__.py +17 -0
  70. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +31 -0
  71. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +31 -0
  72. optimum/rbln/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +31 -0
  73. optimum/rbln/diffusers/pipelines/stable_video_diffusion/__init__.py +15 -0
  74. optimum/rbln/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +46 -0
  75. optimum/rbln/modeling.py +364 -0
  76. optimum/rbln/modeling_base.py +637 -0
  77. optimum/rbln/ops/__init__.py +19 -0
  78. optimum/rbln/ops/attn.py +455 -0
  79. optimum/rbln/ops/flash_attn.py +350 -0
  80. optimum/rbln/ops/kv_cache_update.py +29 -0
  81. optimum/rbln/ops/linear.py +32 -0
  82. optimum/rbln/ops/sliding_window_attn.py +111 -0
  83. optimum/rbln/transformers/__init__.py +340 -0
  84. optimum/rbln/transformers/configuration_generic.py +120 -0
  85. optimum/rbln/transformers/modeling_attention_utils.py +385 -0
  86. optimum/rbln/transformers/modeling_generic.py +280 -0
  87. optimum/rbln/transformers/modeling_outputs.py +37 -0
  88. optimum/rbln/transformers/modeling_rope_utils.py +314 -0
  89. optimum/rbln/transformers/models/__init__.py +343 -0
  90. optimum/rbln/transformers/models/audio_spectrogram_transformer/__init__.py +17 -0
  91. optimum/rbln/transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +47 -0
  92. optimum/rbln/transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +91 -0
  93. optimum/rbln/transformers/models/auto/__init__.py +31 -0
  94. optimum/rbln/transformers/models/auto/auto_factory.py +267 -0
  95. optimum/rbln/transformers/models/auto/modeling_auto.py +162 -0
  96. optimum/rbln/transformers/models/bart/__init__.py +17 -0
  97. optimum/rbln/transformers/models/bart/bart_architecture.py +163 -0
  98. optimum/rbln/transformers/models/bart/configuration_bart.py +36 -0
  99. optimum/rbln/transformers/models/bart/modeling_bart.py +86 -0
  100. optimum/rbln/transformers/models/bert/__init__.py +16 -0
  101. optimum/rbln/transformers/models/bert/bert_architecture.py +16 -0
  102. optimum/rbln/transformers/models/bert/configuration_bert.py +46 -0
  103. optimum/rbln/transformers/models/bert/modeling_bert.py +148 -0
  104. optimum/rbln/transformers/models/blip_2/__init__.py +20 -0
  105. optimum/rbln/transformers/models/blip_2/configuration_blip_2.py +115 -0
  106. optimum/rbln/transformers/models/blip_2/modeling_blip_2.py +526 -0
  107. optimum/rbln/transformers/models/clip/__init__.py +26 -0
  108. optimum/rbln/transformers/models/clip/configuration_clip.py +103 -0
  109. optimum/rbln/transformers/models/clip/modeling_clip.py +384 -0
  110. optimum/rbln/transformers/models/colpali/__init__.py +2 -0
  111. optimum/rbln/transformers/models/colpali/colpali_architecture.py +218 -0
  112. optimum/rbln/transformers/models/colpali/configuration_colpali.py +84 -0
  113. optimum/rbln/transformers/models/colpali/modeling_colpali.py +361 -0
  114. optimum/rbln/transformers/models/colqwen2/__init__.py +2 -0
  115. optimum/rbln/transformers/models/colqwen2/colqwen2_architecture.py +233 -0
  116. optimum/rbln/transformers/models/colqwen2/configuration_colqwen2.py +74 -0
  117. optimum/rbln/transformers/models/colqwen2/modeling_colqwen2.py +446 -0
  118. optimum/rbln/transformers/models/decoderonly/__init__.py +27 -0
  119. optimum/rbln/transformers/models/decoderonly/configuration_decoderonly.py +300 -0
  120. optimum/rbln/transformers/models/decoderonly/configuration_lora.py +411 -0
  121. optimum/rbln/transformers/models/decoderonly/decoderonly_architecture.py +1224 -0
  122. optimum/rbln/transformers/models/decoderonly/decoderonly_runtime_utils.py +508 -0
  123. optimum/rbln/transformers/models/decoderonly/generation_decoderonly.py +119 -0
  124. optimum/rbln/transformers/models/decoderonly/lora_architecture.py +204 -0
  125. optimum/rbln/transformers/models/decoderonly/modeling_decoderonly.py +823 -0
  126. optimum/rbln/transformers/models/depth_anything/__init__.py +16 -0
  127. optimum/rbln/transformers/models/depth_anything/configuration_depth_anything.py +24 -0
  128. optimum/rbln/transformers/models/depth_anything/modeling_depth_anything.py +42 -0
  129. optimum/rbln/transformers/models/distilbert/__init__.py +19 -0
  130. optimum/rbln/transformers/models/distilbert/configuration_distilbert.py +24 -0
  131. optimum/rbln/transformers/models/distilbert/modeling_distilbert.py +51 -0
  132. optimum/rbln/transformers/models/dpt/__init__.py +16 -0
  133. optimum/rbln/transformers/models/dpt/configuration_dpt.py +24 -0
  134. optimum/rbln/transformers/models/dpt/modeling_dpt.py +42 -0
  135. optimum/rbln/transformers/models/exaone/__init__.py +24 -0
  136. optimum/rbln/transformers/models/exaone/configuration_exaone.py +42 -0
  137. optimum/rbln/transformers/models/exaone/exaone_architecture.py +77 -0
  138. optimum/rbln/transformers/models/exaone/modeling_exaone.py +145 -0
  139. optimum/rbln/transformers/models/gemma/__init__.py +16 -0
  140. optimum/rbln/transformers/models/gemma/configuration_gemma.py +50 -0
  141. optimum/rbln/transformers/models/gemma/gemma_architecture.py +27 -0
  142. optimum/rbln/transformers/models/gemma/modeling_gemma.py +104 -0
  143. optimum/rbln/transformers/models/gemma3/__init__.py +16 -0
  144. optimum/rbln/transformers/models/gemma3/configuration_gemma3.py +109 -0
  145. optimum/rbln/transformers/models/gemma3/gemma3_architecture.py +170 -0
  146. optimum/rbln/transformers/models/gemma3/gemma3_runtime_utils.py +245 -0
  147. optimum/rbln/transformers/models/gemma3/modeling_gemma3.py +611 -0
  148. optimum/rbln/transformers/models/gpt2/__init__.py +16 -0
  149. optimum/rbln/transformers/models/gpt2/configuration_gpt2.py +50 -0
  150. optimum/rbln/transformers/models/gpt2/gpt2_architecture.py +93 -0
  151. optimum/rbln/transformers/models/gpt2/modeling_gpt2.py +55 -0
  152. optimum/rbln/transformers/models/grounding_dino/__init__.py +10 -0
  153. optimum/rbln/transformers/models/grounding_dino/configuration_grounding_dino.py +92 -0
  154. optimum/rbln/transformers/models/grounding_dino/grounding_dino_architecture.py +599 -0
  155. optimum/rbln/transformers/models/grounding_dino/modeling_grounding_dino.py +1048 -0
  156. optimum/rbln/transformers/models/idefics3/__init__.py +16 -0
  157. optimum/rbln/transformers/models/idefics3/configuration_idefics3.py +89 -0
  158. optimum/rbln/transformers/models/idefics3/modeling_idefics3.py +497 -0
  159. optimum/rbln/transformers/models/llama/__init__.py +16 -0
  160. optimum/rbln/transformers/models/llama/configuration_llama.py +50 -0
  161. optimum/rbln/transformers/models/llama/llama_architecture.py +19 -0
  162. optimum/rbln/transformers/models/llama/modeling_llama.py +104 -0
  163. optimum/rbln/transformers/models/llava/__init__.py +16 -0
  164. optimum/rbln/transformers/models/llava/configuration_llava.py +72 -0
  165. optimum/rbln/transformers/models/llava/modeling_llava.py +490 -0
  166. optimum/rbln/transformers/models/llava_next/__init__.py +16 -0
  167. optimum/rbln/transformers/models/llava_next/configuration_llava_next.py +69 -0
  168. optimum/rbln/transformers/models/llava_next/modeling_llava_next.py +493 -0
  169. optimum/rbln/transformers/models/midm/__init__.py +24 -0
  170. optimum/rbln/transformers/models/midm/configuration_midm.py +42 -0
  171. optimum/rbln/transformers/models/midm/midm_architecture.py +144 -0
  172. optimum/rbln/transformers/models/midm/modeling_midm.py +144 -0
  173. optimum/rbln/transformers/models/mistral/__init__.py +16 -0
  174. optimum/rbln/transformers/models/mistral/configuration_mistral.py +50 -0
  175. optimum/rbln/transformers/models/mistral/mistral_architecture.py +19 -0
  176. optimum/rbln/transformers/models/mistral/modeling_mistral.py +115 -0
  177. optimum/rbln/transformers/models/opt/__init__.py +16 -0
  178. optimum/rbln/transformers/models/opt/configuration_opt.py +29 -0
  179. optimum/rbln/transformers/models/opt/modeling_opt.py +102 -0
  180. optimum/rbln/transformers/models/opt/opt_architecture.py +74 -0
  181. optimum/rbln/transformers/models/pegasus/__init__.py +17 -0
  182. optimum/rbln/transformers/models/pegasus/configuration_pegasus.py +38 -0
  183. optimum/rbln/transformers/models/pegasus/modeling_pegasus.py +71 -0
  184. optimum/rbln/transformers/models/pegasus/pegasus_architecture.py +161 -0
  185. optimum/rbln/transformers/models/phi/__init__.py +16 -0
  186. optimum/rbln/transformers/models/phi/configuration_phi.py +50 -0
  187. optimum/rbln/transformers/models/phi/modeling_phi.py +92 -0
  188. optimum/rbln/transformers/models/phi/phi_architecture.py +115 -0
  189. optimum/rbln/transformers/models/pixtral/__init__.py +16 -0
  190. optimum/rbln/transformers/models/pixtral/configuration_pixtral.py +43 -0
  191. optimum/rbln/transformers/models/pixtral/modeling_pixtral.py +322 -0
  192. optimum/rbln/transformers/models/pixtral/pixtral_architecture.py +73 -0
  193. optimum/rbln/transformers/models/qwen2/__init__.py +16 -0
  194. optimum/rbln/transformers/models/qwen2/configuration_qwen2.py +50 -0
  195. optimum/rbln/transformers/models/qwen2/modeling_qwen2.py +123 -0
  196. optimum/rbln/transformers/models/qwen2/qwen2_architecture.py +19 -0
  197. optimum/rbln/transformers/models/qwen2_5_vl/__init__.py +19 -0
  198. optimum/rbln/transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +111 -0
  199. optimum/rbln/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +636 -0
  200. optimum/rbln/transformers/models/qwen2_5_vl/qwen2_5_vl_architecture.py +220 -0
  201. optimum/rbln/transformers/models/qwen2_vl/__init__.py +19 -0
  202. optimum/rbln/transformers/models/qwen2_vl/configuration_qwen2_vl.py +88 -0
  203. optimum/rbln/transformers/models/qwen2_vl/modeling_qwen2_vl.py +513 -0
  204. optimum/rbln/transformers/models/qwen2_vl/qwen2_vl_architecture.py +165 -0
  205. optimum/rbln/transformers/models/qwen3/__init__.py +16 -0
  206. optimum/rbln/transformers/models/qwen3/configuration_qwen3.py +71 -0
  207. optimum/rbln/transformers/models/qwen3/modeling_qwen3.py +133 -0
  208. optimum/rbln/transformers/models/qwen3/qwen3_architecture.py +31 -0
  209. optimum/rbln/transformers/models/resnet/__init__.py +23 -0
  210. optimum/rbln/transformers/models/resnet/configuration_resnet.py +42 -0
  211. optimum/rbln/transformers/models/resnet/modeling_resnet.py +99 -0
  212. optimum/rbln/transformers/models/roberta/__init__.py +24 -0
  213. optimum/rbln/transformers/models/roberta/configuration_roberta.py +33 -0
  214. optimum/rbln/transformers/models/roberta/modeling_roberta.py +72 -0
  215. optimum/rbln/transformers/models/seq2seq/__init__.py +16 -0
  216. optimum/rbln/transformers/models/seq2seq/configuration_seq2seq.py +71 -0
  217. optimum/rbln/transformers/models/seq2seq/modeling_seq2seq.py +477 -0
  218. optimum/rbln/transformers/models/seq2seq/seq2seq_architecture.py +527 -0
  219. optimum/rbln/transformers/models/siglip/__init__.py +16 -0
  220. optimum/rbln/transformers/models/siglip/configuration_siglip.py +76 -0
  221. optimum/rbln/transformers/models/siglip/modeling_siglip.py +199 -0
  222. optimum/rbln/transformers/models/swin/__init__.py +16 -0
  223. optimum/rbln/transformers/models/swin/configuration_swin.py +42 -0
  224. optimum/rbln/transformers/models/swin/modeling_swin.py +354 -0
  225. optimum/rbln/transformers/models/t5/__init__.py +17 -0
  226. optimum/rbln/transformers/models/t5/configuration_t5.py +36 -0
  227. optimum/rbln/transformers/models/t5/modeling_t5.py +130 -0
  228. optimum/rbln/transformers/models/t5/t5_architecture.py +264 -0
  229. optimum/rbln/transformers/models/time_series_transformer/__init__.py +26 -0
  230. optimum/rbln/transformers/models/time_series_transformer/configuration_time_series_transformer.py +41 -0
  231. optimum/rbln/transformers/models/time_series_transformer/modeling_time_series_transformer.py +435 -0
  232. optimum/rbln/transformers/models/time_series_transformer/time_series_transformers_architecture.py +337 -0
  233. optimum/rbln/transformers/models/vit/__init__.py +19 -0
  234. optimum/rbln/transformers/models/vit/configuration_vit.py +24 -0
  235. optimum/rbln/transformers/models/vit/modeling_vit.py +44 -0
  236. optimum/rbln/transformers/models/wav2vec2/__init__.py +16 -0
  237. optimum/rbln/transformers/models/wav2vec2/configuration_wav2vec2.py +38 -0
  238. optimum/rbln/transformers/models/wav2vec2/modeling_wav2vec2.py +104 -0
  239. optimum/rbln/transformers/models/whisper/__init__.py +17 -0
  240. optimum/rbln/transformers/models/whisper/configuration_whisper.py +72 -0
  241. optimum/rbln/transformers/models/whisper/generation_whisper.py +159 -0
  242. optimum/rbln/transformers/models/whisper/modeling_whisper.py +475 -0
  243. optimum/rbln/transformers/models/whisper/whisper_architecture.py +349 -0
  244. optimum/rbln/transformers/models/xlm_roberta/__init__.py +24 -0
  245. optimum/rbln/transformers/models/xlm_roberta/configuration_xlm_roberta.py +32 -0
  246. optimum/rbln/transformers/models/xlm_roberta/modeling_xlm_roberta.py +82 -0
  247. optimum/rbln/transformers/utils/__init__.py +0 -0
  248. optimum/rbln/transformers/utils/rbln_quantization.py +589 -0
  249. optimum/rbln/transformers/utils/rbln_runtime_wrapper.py +79 -0
  250. optimum/rbln/utils/__init__.py +16 -0
  251. optimum/rbln/utils/decorator_utils.py +86 -0
  252. optimum/rbln/utils/deprecation.py +213 -0
  253. optimum/rbln/utils/hub.py +94 -0
  254. optimum/rbln/utils/import_utils.py +170 -0
  255. optimum/rbln/utils/logging.py +110 -0
  256. optimum/rbln/utils/model_utils.py +63 -0
  257. optimum/rbln/utils/runtime_utils.py +249 -0
  258. optimum/rbln/utils/save_utils.py +102 -0
  259. optimum/rbln/utils/submodule.py +152 -0
  260. optimum_rbln-0.9.3.post1.dist-info/METADATA +124 -0
  261. optimum_rbln-0.9.3.post1.dist-info/RECORD +264 -0
  262. optimum_rbln-0.9.3.post1.dist-info/WHEEL +4 -0
  263. optimum_rbln-0.9.3.post1.dist-info/entry_points.txt +2 -0
  264. optimum_rbln-0.9.3.post1.dist-info/licenses/LICENSE +201 -0
optimum/rbln/diffusers/configurations/pipelines/configuration_cosmos.py
@@ -0,0 +1,117 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Optional
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....transformers import RBLNT5EncoderModelConfig
+ from ....utils.logging import get_logger
+ from ...pipelines.cosmos.cosmos_guardrail import RBLNCosmosSafetyCheckerConfig
+ from ..models import RBLNAutoencoderKLCosmosConfig, RBLNCosmosTransformer3DModelConfig
+
+
+ logger = get_logger(__name__)
+
+
+ class RBLNCosmosPipelineBaseConfig(RBLNModelConfig):
+ submodules = ["text_encoder", "transformer", "vae", "safety_checker"]
+ _vae_uses_encoder = False
+
+ def __init__(
+ self,
+ text_encoder: Optional[RBLNT5EncoderModelConfig] = None,
+ transformer: Optional[RBLNCosmosTransformer3DModelConfig] = None,
+ vae: Optional[RBLNAutoencoderKLCosmosConfig] = None,
+ safety_checker: Optional[RBLNCosmosSafetyCheckerConfig] = None,
+ *,
+ batch_size: Optional[int] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_frames: Optional[int] = None,
+ fps: Optional[int] = None,
+ max_seq_len: Optional[int] = None,
+ **kwargs: Any,
+ ):
+ """
+ Args:
+ text_encoder (Optional[RBLNT5EncoderModelConfig]): Configuration for the text encoder component.
+ Initialized as RBLNT5EncoderModelConfig if not provided.
+ transformer (Optional[RBLNCosmosTransformer3DModelConfig]): Configuration for the Transformer model component.
+ Initialized as RBLNCosmosTransformer3DModelConfig if not provided.
+ vae (Optional[RBLNAutoencoderKLCosmosConfig]): Configuration for the VAE model component.
+ Initialized as RBLNAutoencoderKLCosmosConfig if not provided.
+ safety_checker (Optional[RBLNCosmosSafetyCheckerConfig]): Configuration for the safety checker component.
+ Initialized as RBLNCosmosSafetyCheckerConfig if not provided.
+ batch_size (Optional[int]): Batch size for inference, applied to all submodules.
+ height (Optional[int]): Height of the generated videos.
+ width (Optional[int]): Width of the generated videos.
+ num_frames (Optional[int]): The number of frames in the generated video.
+ fps (Optional[int]): The frames per second of the generated video.
+ max_seq_len (Optional[int]): Maximum sequence length supported by the model.
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
+ """
+ super().__init__(**kwargs)
+
+ self.text_encoder = self.initialize_submodule_config(
+ text_encoder,
+ cls_name="RBLNT5EncoderModelConfig",
+ batch_size=batch_size,
+ max_seq_len=max_seq_len,
+ )
+ self.transformer = self.initialize_submodule_config(
+ transformer,
+ cls_name="RBLNCosmosTransformer3DModelConfig",
+ batch_size=batch_size,
+ max_seq_len=max_seq_len,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ fps=fps,
+ )
+ self.vae = self.initialize_submodule_config(
+ vae,
+ cls_name="RBLNAutoencoderKLCosmosConfig",
+ batch_size=batch_size,
+ uses_encoder=self.__class__._vae_uses_encoder,
+ height=height,
+ width=width,
+ num_frames=num_frames,
+ )
+ self.safety_checker = self.initialize_submodule_config(
+ safety_checker,
+ cls_name="RBLNCosmosSafetyCheckerConfig",
+ batch_size=batch_size,
+ height=height,
+ width=width,
+ )
+
+ @property
+ def batch_size(self):
+ return self.vae.batch_size
+
+ @property
+ def max_seq_len(self):
+ return self.text_encoder.max_seq_len
+
+
+ class RBLNCosmosTextToWorldPipelineConfig(RBLNCosmosPipelineBaseConfig):
+ """Config for Cosmos Text2World Pipeline"""
+
+ _vae_uses_encoder = False
+
+
+ class RBLNCosmosVideoToWorldPipelineConfig(RBLNCosmosPipelineBaseConfig):
+ """Config for Cosmos Video2World Pipeline"""
+
+ _vae_uses_encoder = True
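
The Cosmos config above fans shared arguments (batch_size, height, width, num_frames, fps, max_seq_len) out to its submodule configs and then reads batch_size and max_seq_len back through properties. A minimal usage sketch, assuming the class is importable from the module path listed in the file table, that plain config construction needs no RBLN device, and that the submodule configs keep the forwarded attributes (values are illustrative, not taken from the package):

from optimum.rbln.diffusers.configurations.pipelines.configuration_cosmos import (
    RBLNCosmosTextToWorldPipelineConfig,
)

# Shared values are forwarded into each submodule config (illustrative values).
config = RBLNCosmosTextToWorldPipelineConfig(batch_size=1, max_seq_len=512, height=704, width=1280)
assert config.text_encoder.max_seq_len == 512      # forwarded to RBLNT5EncoderModelConfig (assumed attribute)
assert config.max_seq_len == 512                   # property reads it back from the text encoder
assert config.batch_size == config.vae.batch_size  # property delegates to the VAE config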
optimum/rbln/diffusers/configurations/pipelines/configuration_kandinsky2_2.py
@@ -0,0 +1,363 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Optional, Tuple
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....transformers import RBLNCLIPTextModelWithProjectionConfig, RBLNCLIPVisionModelWithProjectionConfig
+ from ..models import RBLNUNet2DConditionModelConfig, RBLNVQModelConfig
+ from ..models.configuration_prior_transformer import RBLNPriorTransformerConfig
+
+
+ class RBLNKandinskyV22PipelineBaseConfig(RBLNModelConfig):
+ submodules = ["unet", "movq"]
+ _movq_uses_encoder = False
+
+ def __init__(
+ self,
+ unet: Optional[RBLNUNet2DConditionModelConfig] = None,
+ movq: Optional[RBLNVQModelConfig] = None,
+ *,
+ sample_size: Optional[Tuple[int, int]] = None,
+ batch_size: Optional[int] = None,
+ guidance_scale: Optional[float] = None,
+ image_size: Optional[Tuple[int, int]] = None,
+ img_height: Optional[int] = None,
+ img_width: Optional[int] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ **kwargs: Any,
+ ):
+ """
+ Args:
+ unet (Optional[RBLNUNet2DConditionModelConfig]): Configuration for the UNet model component.
+ Initialized as RBLNUNet2DConditionModelConfig if not provided.
+ movq (Optional[RBLNVQModelConfig]): Configuration for the MoVQ (VQ-GAN) model component.
+ Initialized as RBLNVQModelConfig if not provided.
+ sample_size (Optional[Tuple[int, int]]): Spatial dimensions for the UNet model.
+ batch_size (Optional[int]): Batch size for inference, applied to all submodules.
+ guidance_scale (Optional[float]): Scale for classifier-free guidance.
+ image_size (Optional[Tuple[int, int]]): Dimensions for the generated images.
+ Cannot be used together with img_height/img_width.
+ img_height (Optional[int]): Height of the generated images.
+ img_width (Optional[int]): Width of the generated images.
+ height (Optional[int]): Height of the generated images.
+ width (Optional[int]): Width of the generated images.
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+ Raises:
+ ValueError: If both image_size and img_height/img_width are provided.
+
+ Note:
+ When guidance_scale > 1.0, the UNet batch size is automatically doubled to
+ accommodate classifier-free guidance.
+ """
+ super().__init__(**kwargs)
+
+ # Initial check for image_size conflict remains as is
+ if image_size is not None and (
+ img_height is not None or img_width is not None or height is not None or width is not None
+ ):
+ raise ValueError("image_size cannot be provided alongside img_height/img_width or height/width")
+
+ # Prioritize height/width (HF-aligned)
+ if height is not None and width is not None:
+ if img_height is not None or img_width is not None:
+ # Raise error if both sets of arguments are provided
+ raise ValueError(
+ "Cannot provide both 'height'/'width' and 'img_height'/'img_width' simultaneously. "
+ "Please use one set of arguments for image dimensions, preferring 'height'/'width'."
+ )
+ image_size = (height, width)
+ elif (height is not None and width is None) or (height is None and width is not None):
+ raise ValueError("Both height and width must be provided together if used")
+ # Fallback to img_height/img_width for backward compatibility
+ elif img_height is not None and img_width is not None:
+ image_size = (img_height, img_width)
+ elif (img_height is not None and img_width is None) or (img_height is None and img_width is not None):
+ raise ValueError("Both img_height and img_width must be provided together if used")
+
+ self.unet = self.initialize_submodule_config(
+ unet,
+ cls_name="RBLNUNet2DConditionModelConfig",
+ sample_size=sample_size,
+ )
+ self.movq = self.initialize_submodule_config(
+ movq,
+ cls_name="RBLNVQModelConfig",
+ batch_size=batch_size,
+ sample_size=image_size, # image size is equal to sample size in vae
+ uses_encoder=self._movq_uses_encoder,
+ )
+
+ # Get default guidance scale from original class to set UNet batch size
+ if guidance_scale is None:
+ guidance_scale = self.get_default_values_for_original_cls("__call__", ["guidance_scale"])["guidance_scale"]
+
+ if not self.unet.batch_size_is_specified:
+ do_classifier_free_guidance = guidance_scale > 1.0
+ if do_classifier_free_guidance:
+ self.unet.batch_size = self.movq.batch_size * 2
+ else:
+ self.unet.batch_size = self.movq.batch_size
+
+ @property
+ def batch_size(self):
+ return self.movq.batch_size
+
+ @property
+ def image_size(self):
+ return self.movq.sample_size
+
+
+ class RBLNKandinskyV22PipelineConfig(RBLNKandinskyV22PipelineBaseConfig):
+ """Configuration class for the Kandinsky V2.2 text-to-image decoder pipeline."""
+
+ _movq_uses_encoder = False
+
+
+ class RBLNKandinskyV22Img2ImgPipelineConfig(RBLNKandinskyV22PipelineBaseConfig):
+ """Configuration class for the Kandinsky V2.2 image-to-image decoder pipeline."""
+
+ _movq_uses_encoder = True
+
+
+ class RBLNKandinskyV22InpaintPipelineConfig(RBLNKandinskyV22PipelineBaseConfig):
+ """Configuration class for the Kandinsky V2.2 inpainting decoder pipeline."""
+
+ _movq_uses_encoder = True
+
+
+ class RBLNKandinskyV22PriorPipelineConfig(RBLNModelConfig):
+ """Configuration class for the Kandinsky V2.2 Prior pipeline."""
+
+ submodules = ["text_encoder", "image_encoder", "prior"]
+
+ def __init__(
+ self,
+ text_encoder: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
+ image_encoder: Optional[RBLNCLIPVisionModelWithProjectionConfig] = None,
+ prior: Optional[RBLNPriorTransformerConfig] = None,
+ *,
+ batch_size: Optional[int] = None,
+ guidance_scale: Optional[float] = None,
+ **kwargs: Any,
+ ):
+ """
+ Initialize a configuration for Kandinsky 2.2 prior pipeline optimized for RBLN NPU.
+
+ This configuration sets up the prior components of the Kandinsky 2.2 architecture, which includes
+ text and image encoders along with a prior transformer that maps text/image embeddings to
+ latent representations used to condition the diffusion process.
+
+ Args:
+ text_encoder (Optional[RBLNCLIPTextModelWithProjectionConfig]): Configuration for the text encoder component.
+ Initialized as RBLNCLIPTextModelWithProjectionConfig if not provided.
+ image_encoder (Optional[RBLNCLIPVisionModelWithProjectionConfig]): Configuration for the image encoder component.
+ Initialized as RBLNCLIPVisionModelWithProjectionConfig if not provided.
+ prior (Optional[RBLNPriorTransformerConfig]): Configuration for the prior transformer component.
+ Initialized as RBLNPriorTransformerConfig if not provided.
+ batch_size (Optional[int]): Batch size for inference, applied to all submodules.
+ guidance_scale (Optional[float]): Scale for classifier-free guidance.
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+ Note:
+ When guidance_scale > 1.0, the prior batch size is automatically doubled to
+ accommodate classifier-free guidance.
+ """
+ super().__init__(**kwargs)
+ self.text_encoder = self.initialize_submodule_config(
+ text_encoder,
+ cls_name="RBLNCLIPTextModelWithProjectionConfig",
+ batch_size=batch_size,
+ )
+ self.image_encoder = self.initialize_submodule_config(
+ image_encoder,
+ cls_name="RBLNCLIPVisionModelWithProjectionConfig",
+ batch_size=batch_size,
+ )
+ self.prior = self.initialize_submodule_config(
+ prior,
+ cls_name="RBLNPriorTransformerConfig",
+ )
+
+ # Get default guidance scale from original class to set UNet batch size
+ if guidance_scale is None:
+ guidance_scale = self.get_default_values_for_original_cls("__call__", ["guidance_scale"])["guidance_scale"]
+
+ if not self.prior.batch_size_is_specified:
+ do_classifier_free_guidance = guidance_scale > 1.0
+ if do_classifier_free_guidance:
+ self.prior.batch_size = self.text_encoder.batch_size * 2
+ else:
+ self.prior.batch_size = self.text_encoder.batch_size
+
+ @property
+ def batch_size(self):
+ return self.text_encoder.batch_size
+
+ @property
+ def image_size(self):
+ return self.image_encoder.image_size
+
+
+ class RBLNKandinskyV22CombinedPipelineBaseConfig(RBLNModelConfig):
+ """Base configuration class for Kandinsky V2.2 combined pipelines."""
+
+ submodules = ["prior_pipe", "decoder_pipe"]
+ _decoder_pipe_cls = RBLNKandinskyV22PipelineConfig
+
+ def __init__(
+ self,
+ prior_pipe: Optional[RBLNKandinskyV22PriorPipelineConfig] = None,
+ decoder_pipe: Optional[RBLNKandinskyV22PipelineConfig] = None,
+ *,
+ sample_size: Optional[Tuple[int, int]] = None,
+ image_size: Optional[Tuple[int, int]] = None,
+ batch_size: Optional[int] = None,
+ img_height: Optional[int] = None,
+ img_width: Optional[int] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ guidance_scale: Optional[float] = None,
+ prior_prior: Optional[RBLNPriorTransformerConfig] = None,
+ prior_image_encoder: Optional[RBLNCLIPVisionModelWithProjectionConfig] = None,
+ prior_text_encoder: Optional[RBLNCLIPTextModelWithProjectionConfig] = None,
+ unet: Optional[RBLNUNet2DConditionModelConfig] = None,
+ movq: Optional[RBLNVQModelConfig] = None,
+ **kwargs: Any,
+ ):
+ """
+ Initialize a configuration for combined Kandinsky 2.2 pipelines optimized for RBLN NPU.
+
+ This configuration integrates both the prior and decoder components of Kandinsky 2.2 into
+ a unified pipeline, allowing for end-to-end text-to-image generation in a single model.
+ It combines the text/image encoding, prior mapping, and diffusion steps together.
+
+ Args:
+ prior_pipe (Optional[RBLNKandinskyV22PriorPipelineConfig]): Configuration for the prior pipeline.
+ Initialized as RBLNKandinskyV22PriorPipelineConfig if not provided.
+ decoder_pipe (Optional[RBLNKandinskyV22PipelineConfig]): Configuration for the decoder pipeline.
+ Initialized as RBLNKandinskyV22PipelineConfig if not provided.
+ sample_size (Optional[Tuple[int, int]]): Spatial dimensions for the UNet model.
+ image_size (Optional[Tuple[int, int]]): Dimensions for the generated images.
+ Cannot be used together with img_height/img_width.
+ batch_size (Optional[int]): Batch size for inference, applied to all submodules.
+ img_height (Optional[int]): Height of the generated images.
+ img_width (Optional[int]): Width of the generated images.
+ height (Optional[int]): Height of the generated images.
+ width (Optional[int]): Width of the generated images.
+ guidance_scale (Optional[float]): Scale for classifier-free guidance.
+ prior_prior (Optional[RBLNPriorTransformerConfig]): Direct configuration for the prior transformer.
+ Used if prior_pipe is not provided.
+ prior_image_encoder (Optional[RBLNCLIPVisionModelWithProjectionConfig]): Direct configuration for the image encoder.
+ Used if prior_pipe is not provided.
+ prior_text_encoder (Optional[RBLNCLIPTextModelWithProjectionConfig]): Direct configuration for the text encoder.
+ Used if prior_pipe is not provided.
+ unet (Optional[RBLNUNet2DConditionModelConfig]): Direct configuration for the UNet.
+ Used if decoder_pipe is not provided.
+ movq (Optional[RBLNVQModelConfig]): Direct configuration for the MoVQ (VQ-GAN) model.
+ Used if decoder_pipe is not provided.
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
+ """
+ super().__init__(**kwargs)
+
+ # Initial check for image_size conflict remains as is
+ if image_size is not None and (
+ img_height is not None or img_width is not None or height is not None or width is not None
+ ):
+ raise ValueError("image_size cannot be provided alongside img_height/img_width or height/width")
+
+ # Prioritize height/width (HF-aligned)
+ if height is not None and width is not None:
+ if img_height is not None or img_width is not None:
+ # Raise error if both sets of arguments are provided
+ raise ValueError(
+ "Cannot provide both 'height'/'width' and 'img_height'/'img_width' simultaneously. "
+ "Please use one set of arguments for image dimensions, preferring 'height'/'width'."
+ )
+ image_size = (height, width)
+ elif (height is not None and width is None) or (height is None and width is not None):
+ raise ValueError("Both height and width must be provided together if used")
+ # Fallback to img_height/img_width for backward compatibility
+ elif img_height is not None and img_width is not None:
+ image_size = (img_height, img_width)
+ elif (img_height is not None and img_width is None) or (img_height is None and img_width is not None):
+ raise ValueError("Both img_height and img_width must be provided together if used")
+
+ self.prior_pipe = self.initialize_submodule_config(
+ prior_pipe,
+ cls_name="RBLNKandinskyV22PriorPipelineConfig",
+ prior=prior_prior,
+ image_encoder=prior_image_encoder,
+ text_encoder=prior_text_encoder,
+ batch_size=batch_size,
+ guidance_scale=guidance_scale,
+ )
+ self.decoder_pipe = self.initialize_submodule_config(
+ decoder_pipe,
+ cls_name=self._decoder_pipe_cls.__name__,
+ unet=unet,
+ movq=movq,
+ batch_size=batch_size,
+ sample_size=sample_size,
+ image_size=image_size,
+ guidance_scale=guidance_scale,
+ )
+
+ @property
+ def batch_size(self):
+ return self.prior_pipe.batch_size
+
+ @property
+ def image_size(self):
+ return self.prior_pipe.image_size
+
+ @property
+ def prior_prior(self):
+ return self.prior_pipe.prior
+
+ @property
+ def prior_image_encoder(self):
+ return self.prior_pipe.image_encoder
+
+ @property
+ def prior_text_encoder(self):
+ return self.prior_pipe.text_encoder
+
+ @property
+ def unet(self):
+ return self.decoder_pipe.unet
+
+ @property
+ def movq(self):
+ return self.decoder_pipe.movq
+
+
+ class RBLNKandinskyV22CombinedPipelineConfig(RBLNKandinskyV22CombinedPipelineBaseConfig):
+ """Configuration class for the Kandinsky V2.2 combined text-to-image pipeline."""
+
+ _decoder_pipe_cls = RBLNKandinskyV22PipelineConfig
+
+
+ class RBLNKandinskyV22InpaintCombinedPipelineConfig(RBLNKandinskyV22CombinedPipelineBaseConfig):
+ """Configuration class for the Kandinsky V2.2 combined inpainting pipeline."""
+
+ _decoder_pipe_cls = RBLNKandinskyV22InpaintPipelineConfig
+
+
+ class RBLNKandinskyV22Img2ImgCombinedPipelineConfig(RBLNKandinskyV22CombinedPipelineBaseConfig):
+ """Configuration class for the Kandinsky V2.2 combined image-to-image pipeline."""
+
+ _decoder_pipe_cls = RBLNKandinskyV22Img2ImgPipelineConfig
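
As the docstrings above note, classifier-free guidance doubles the decoder UNet (and the prior) batch size whenever guidance_scale > 1.0 and no explicit batch size was given for that submodule. A rough sketch of how that plays out for the combined pipeline config, assuming it can be constructed standalone without an RBLN device and that the submodule configs expose the attributes used here (illustrative values, not taken from the package):

from optimum.rbln.diffusers.configurations.pipelines.configuration_kandinsky2_2 import (
    RBLNKandinskyV22CombinedPipelineConfig,
)

# batch_size and height/width are shared; guidance_scale is forwarded to both sub-pipelines.
cfg = RBLNKandinskyV22CombinedPipelineConfig(batch_size=1, height=512, width=512, guidance_scale=4.0)
# guidance_scale > 1.0, so the decoder UNet and the prior are set to twice the base batch size.
assert cfg.unet.batch_size == 2 * cfg.movq.batch_size
assert cfg.prior_prior.batch_size == 2 * cfg.prior_pipe.text_encoder.batch_size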
optimum/rbln/diffusers/configurations/pipelines/configuration_stable_diffusion.py
@@ -0,0 +1,156 @@
+ # Copyright 2025 Rebellions Inc. All rights reserved.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at:
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Any, Optional, Tuple
+
+ from ....configuration_utils import RBLNModelConfig
+ from ....transformers import RBLNCLIPTextModelConfig
+ from ..models import RBLNAutoencoderKLConfig, RBLNUNet2DConditionModelConfig
+
+
+ class RBLNStableDiffusionPipelineBaseConfig(RBLNModelConfig):
+ submodules = ["text_encoder", "unet", "vae"]
+ _vae_uses_encoder = False
+
+ def __init__(
+ self,
+ text_encoder: Optional[RBLNCLIPTextModelConfig] = None,
+ unet: Optional[RBLNUNet2DConditionModelConfig] = None,
+ vae: Optional[RBLNAutoencoderKLConfig] = None,
+ *,
+ batch_size: Optional[int] = None,
+ img_height: Optional[int] = None,
+ img_width: Optional[int] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ sample_size: Optional[Tuple[int, int]] = None,
+ image_size: Optional[Tuple[int, int]] = None,
+ guidance_scale: Optional[float] = None,
+ **kwargs: Any,
+ ):
+ """
+ Args:
+ text_encoder (Optional[RBLNCLIPTextModelConfig]): Configuration for the text encoder component.
+ Initialized as RBLNCLIPTextModelConfig if not provided.
+ unet (Optional[RBLNUNet2DConditionModelConfig]): Configuration for the UNet model component.
+ Initialized as RBLNUNet2DConditionModelConfig if not provided.
+ vae (Optional[RBLNAutoencoderKLConfig]): Configuration for the VAE model component.
+ Initialized as RBLNAutoencoderKLConfig if not provided.
+ batch_size (Optional[int]): Batch size for inference, applied to all submodules.
+ img_height (Optional[int]): Height of the generated images.
+ img_width (Optional[int]): Width of the generated images.
+ height (Optional[int]): Height of the generated images.
+ width (Optional[int]): Width of the generated images.
+ sample_size (Optional[Tuple[int, int]]): Spatial dimensions for the UNet model.
+ image_size (Optional[Tuple[int, int]]): Alternative way to specify image dimensions.
+ Cannot be used together with img_height/img_width.
+ guidance_scale (Optional[float]): Scale for classifier-free guidance.
+ kwargs: Additional arguments passed to the parent RBLNModelConfig.
+
+ Raises:
+ ValueError: If both image_size and img_height/img_width are provided.
+
+ Note:
+ When guidance_scale > 1.0, the UNet batch size is automatically doubled to
+ accommodate classifier-free guidance.
+ """
+ super().__init__(**kwargs)
+
+ # Initial check for image_size conflict remains as is
+ if image_size is not None and (
+ img_height is not None or img_width is not None or height is not None or width is not None
+ ):
+ raise ValueError("image_size cannot be provided alongside img_height/img_width or height/width")
+
+ # Prioritize height/width (HF-aligned)
+ if height is not None and width is not None:
+ if img_height is not None or img_width is not None:
+ # Raise error if both sets of arguments are provided
+ raise ValueError(
+ "Cannot provide both 'height'/'width' and 'img_height'/'img_width' simultaneously. "
+ "Please use one set of arguments for image dimensions, preferring 'height'/'width'."
+ )
+ image_size = (height, width)
+ elif (height is not None and width is None) or (height is None and width is not None):
+ raise ValueError("Both height and width must be provided together if used")
+ # Fallback to img_height/img_width for backward compatibility
+ elif img_height is not None and img_width is not None:
+ image_size = (img_height, img_width)
+ elif (img_height is not None and img_width is None) or (img_height is None and img_width is not None):
+ raise ValueError("Both img_height and img_width must be provided together if used")
+
+ self.text_encoder = self.initialize_submodule_config(
+ text_encoder,
+ cls_name="RBLNCLIPTextModelConfig",
+ batch_size=batch_size,
+ )
+ self.unet = self.initialize_submodule_config(
+ unet,
+ cls_name="RBLNUNet2DConditionModelConfig",
+ sample_size=sample_size,
+ )
+ self.vae = self.initialize_submodule_config(
+ vae,
+ cls_name="RBLNAutoencoderKLConfig",
+ batch_size=batch_size,
+ uses_encoder=self.__class__._vae_uses_encoder,
+ sample_size=image_size,
+ )
+
+ # Get default guidance scale from original class to set UNet batch size
+ if guidance_scale is None:
+ guidance_scale = self.get_default_values_for_original_cls("__call__", ["guidance_scale"])["guidance_scale"]
+
+ if not self.unet.batch_size_is_specified:
+ do_classifier_free_guidance = guidance_scale > 1.0
+ if do_classifier_free_guidance:
+ self.unet.batch_size = self.text_encoder.batch_size * 2
+ else:
+ self.unet.batch_size = self.text_encoder.batch_size
+
+ @property
+ def batch_size(self):
+ return self.vae.batch_size
+
+ @property
+ def sample_size(self):
+ return self.unet.sample_size
+
+ @property
+ def image_size(self):
+ return self.vae.sample_size
+
+
+ class RBLNStableDiffusionPipelineConfig(RBLNStableDiffusionPipelineBaseConfig):
+ """
+ Configuration for Stable Diffusion pipeline.
+ """
+
+ _vae_uses_encoder = False
+
+
+ class RBLNStableDiffusionImg2ImgPipelineConfig(RBLNStableDiffusionPipelineBaseConfig):
+ """
+ Configuration for Stable Diffusion image-to-image pipeline.
+ """
+
+ _vae_uses_encoder = True
+
+
+ class RBLNStableDiffusionInpaintPipelineConfig(RBLNStableDiffusionPipelineBaseConfig):
+ """
+ Configuration for Stable Diffusion inpainting pipeline.
+ """
+
+ _vae_uses_encoder = True
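
The same pattern applies to the Stable Diffusion configs: height/width are folded into image_size, the VAE encoder is compiled only for the img2img and inpaint variants, and the UNet batch size follows the guidance setting. A hedged sketch under the same assumptions as above (module path taken from the file table; illustrative values; attribute names on the submodule configs are assumed):

from optimum.rbln.diffusers.configurations.pipelines.configuration_stable_diffusion import (
    RBLNStableDiffusionImg2ImgPipelineConfig,
)

cfg = RBLNStableDiffusionImg2ImgPipelineConfig(batch_size=2, height=768, width=768, guidance_scale=7.5)
assert cfg.text_encoder.batch_size == 2
assert cfg.unet.batch_size == 4          # doubled for classifier-free guidance
assert cfg.image_size == (768, 768)      # exposed via vae.sample_size (assumed to store the tuple as passed)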