diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +595 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
  384. diffusers-0.33.1.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
@@ -60,6 +60,16 @@ if is_invisible_watermark_available():
60
60
  from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
61
61
 
62
62
 
63
+ from ...utils import is_torch_xla_available
64
+
65
+
66
+ if is_torch_xla_available():
67
+ import torch_xla.core.xla_model as xm
68
+
69
+ XLA_AVAILABLE = True
70
+ else:
71
+ XLA_AVAILABLE = False
72
+
63
73
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
64
74
 
65
75
 
@@ -209,6 +219,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
209
219
  "add_time_ids",
210
220
  "mask",
211
221
  "masked_image_latents",
222
+ "control_image",
212
223
  ]
213
224
 
214
225
  def __init__(
@@ -246,7 +257,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
246
257
  )
247
258
  self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
248
259
  self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
249
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
260
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
250
261
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
251
262
  self.mask_processor = VaeImageProcessor(
252
263
  vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -388,7 +399,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
388
399
  prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
389
400
 
390
401
  # We are only ALWAYS interested in the pooled output of the final text encoder
391
- pooled_prompt_embeds = prompt_embeds[0]
402
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
403
+ pooled_prompt_embeds = prompt_embeds[0]
404
+
392
405
  if clip_skip is None:
393
406
  prompt_embeds = prompt_embeds.hidden_states[-2]
394
407
  else:
@@ -447,8 +460,10 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
447
460
  uncond_input.input_ids.to(device),
448
461
  output_hidden_states=True,
449
462
  )
463
+
450
464
  # We are only ALWAYS interested in the pooled output of the final text encoder
451
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
465
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
466
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
452
467
  negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
453
468
 
454
469
  negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -712,7 +727,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
712
727
  if padding_mask_crop is not None:
713
728
  if not isinstance(image, PIL.Image.Image):
714
729
  raise ValueError(
715
- f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
730
+ f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
716
731
  )
717
732
  if not isinstance(mask_image, PIL.Image.Image):
718
733
  raise ValueError(
@@ -720,7 +735,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
720
735
  f" {type(mask_image)}."
721
736
  )
722
737
  if output_type != "pil":
723
- raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
738
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")
724
739
 
725
740
  if prompt_embeds is not None and pooled_prompt_embeds is None:
726
741
  raise ValueError(
@@ -752,26 +767,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
752
767
  else:
753
768
  assert False
754
769
 
755
- # Check `controlnet_conditioning_scale`
756
- if (
757
- isinstance(self.controlnet, ControlNetModel)
758
- or is_compiled
759
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
760
- ):
761
- if not isinstance(controlnet_conditioning_scale, float):
762
- raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
763
-
764
- elif (
765
- isinstance(self.controlnet, ControlNetUnionModel)
766
- or is_compiled
767
- and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
768
- ):
769
- if not isinstance(controlnet_conditioning_scale, float):
770
- raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
771
-
772
- else:
773
- assert False
774
-
775
770
  if not isinstance(control_guidance_start, (tuple, list)):
776
771
  control_guidance_start = [control_guidance_start]
777
772
 
@@ -1356,6 +1351,8 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1356
1351
 
1357
1352
  if not isinstance(control_image, list):
1358
1353
  control_image = [control_image]
1354
+ else:
1355
+ control_image = control_image.copy()
1359
1356
 
1360
1357
  if not isinstance(control_mode, list):
1361
1358
  control_mode = [control_mode]
@@ -1747,6 +1744,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1747
1744
  latents = callback_outputs.pop("latents", latents)
1748
1745
  prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1749
1746
  negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1747
+ control_image = callback_outputs.pop("control_image", control_image)
1750
1748
 
1751
1749
  # call the callback, if provided
1752
1750
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
@@ -1755,6 +1753,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1755
1753
  step_idx = i // getattr(self.scheduler, "order", 1)
1756
1754
  callback(step_idx, t, latents)
1757
1755
 
1756
+ if XLA_AVAILABLE:
1757
+ xm.mark_step()
1758
+
1758
1759
  # make sure the VAE is in float32 mode, as it overflows in float16
1759
1760
  if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
1760
1761
  self.upcast_vae()
@@ -19,7 +19,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
19
  import numpy as np
20
20
  import PIL.Image
21
21
  import torch
22
- import torch.nn.functional as F
23
22
  from transformers import (
24
23
  CLIPImageProcessor,
25
24
  CLIPTextModel,
@@ -38,7 +37,13 @@ from ...loaders import (
38
37
  StableDiffusionXLLoraLoaderMixin,
39
38
  TextualInversionLoaderMixin,
40
39
  )
41
- from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
40
+ from ...models import (
41
+ AutoencoderKL,
42
+ ControlNetUnionModel,
43
+ ImageProjection,
44
+ MultiControlNetUnionModel,
45
+ UNet2DConditionModel,
46
+ )
42
47
  from ...models.attention_processor import (
43
48
  AttnProcessor2_0,
44
49
  XFormersAttnProcessor,
@@ -60,6 +65,17 @@ from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutpu
60
65
  if is_invisible_watermark_available():
61
66
  from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
62
67
 
68
+
69
+ from ...utils import is_torch_xla_available
70
+
71
+
72
+ if is_torch_xla_available():
73
+ import torch_xla.core.xla_model as xm
74
+
75
+ XLA_AVAILABLE = True
76
+ else:
77
+ XLA_AVAILABLE = False
78
+
63
79
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
64
80
 
65
81
 
@@ -233,7 +249,9 @@ class StableDiffusionXLControlNetUnionPipeline(
233
249
  tokenizer: CLIPTokenizer,
234
250
  tokenizer_2: CLIPTokenizer,
235
251
  unet: UNet2DConditionModel,
236
- controlnet: ControlNetUnionModel,
252
+ controlnet: Union[
253
+ ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
254
+ ],
237
255
  scheduler: KarrasDiffusionSchedulers,
238
256
  force_zeros_for_empty_prompt: bool = True,
239
257
  add_watermarker: Optional[bool] = None,
@@ -242,8 +260,8 @@ class StableDiffusionXLControlNetUnionPipeline(
242
260
  ):
243
261
  super().__init__()
244
262
 
245
- if not isinstance(controlnet, ControlNetUnionModel):
246
- raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
263
+ if isinstance(controlnet, (list, tuple)):
264
+ controlnet = MultiControlNetUnionModel(controlnet)
247
265
 
248
266
  self.register_modules(
249
267
  vae=vae,
@@ -257,7 +275,7 @@ class StableDiffusionXLControlNetUnionPipeline(
257
275
  feature_extractor=feature_extractor,
258
276
  image_encoder=image_encoder,
259
277
  )
260
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
278
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
261
279
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
262
280
  self.control_image_processor = VaeImageProcessor(
263
281
  vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -397,7 +415,9 @@ class StableDiffusionXLControlNetUnionPipeline(
397
415
  prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
398
416
 
399
417
  # We are only ALWAYS interested in the pooled output of the final text encoder
400
- pooled_prompt_embeds = prompt_embeds[0]
418
+ if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
419
+ pooled_prompt_embeds = prompt_embeds[0]
420
+
401
421
  if clip_skip is None:
402
422
  prompt_embeds = prompt_embeds.hidden_states[-2]
403
423
  else:
@@ -456,8 +476,10 @@ class StableDiffusionXLControlNetUnionPipeline(
456
476
  uncond_input.input_ids.to(device),
457
477
  output_hidden_states=True,
458
478
  )
479
+
459
480
  # We are only ALWAYS interested in the pooled output of the final text encoder
460
- negative_pooled_prompt_embeds = negative_prompt_embeds[0]
481
+ if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
482
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
461
483
  negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
462
484
 
463
485
  negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -649,6 +671,7 @@ class StableDiffusionXLControlNetUnionPipeline(
649
671
  controlnet_conditioning_scale=1.0,
650
672
  control_guidance_start=0.0,
651
673
  control_guidance_end=1.0,
674
+ control_mode=None,
652
675
  callback_on_step_end_tensor_inputs=None,
653
676
  ):
654
677
  if callback_on_step_end_tensor_inputs is not None and not all(
@@ -706,66 +729,90 @@ class StableDiffusionXLControlNetUnionPipeline(
706
729
  "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
707
730
  )
708
731
 
709
- # Check `image`
710
- is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
711
- self.controlnet, torch._dynamo.eval_frame.OptimizedModule
712
- )
713
- if (
714
- isinstance(self.controlnet, ControlNetModel)
715
- or is_compiled
716
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
717
- ):
718
- self.check_image(image, prompt, prompt_embeds)
719
- elif (
720
- isinstance(self.controlnet, ControlNetUnionModel)
721
- or is_compiled
722
- and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
723
- ):
724
- self.check_image(image, prompt, prompt_embeds)
725
-
726
- else:
727
- assert False
728
-
729
- # Check `controlnet_conditioning_scale`
730
- if (
731
- isinstance(self.controlnet, ControlNetModel)
732
- or is_compiled
733
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
734
- ):
735
- if not isinstance(controlnet_conditioning_scale, float):
736
- raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
732
+ # `prompt` needs more sophisticated handling when there are multiple
733
+ # conditionings.
734
+ if isinstance(self.controlnet, MultiControlNetUnionModel):
735
+ if isinstance(prompt, list):
736
+ logger.warning(
737
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
738
+ " prompts. The conditionings will be fixed across the prompts."
739
+ )
737
740
 
738
- elif (
739
- isinstance(self.controlnet, ControlNetUnionModel)
740
- or is_compiled
741
- and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
742
- ):
743
- if not isinstance(controlnet_conditioning_scale, float):
744
- raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
741
+ # Check `image`
742
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
745
743
 
746
- else:
747
- assert False
744
+ if isinstance(controlnet, ControlNetUnionModel):
745
+ for image_ in image:
746
+ self.check_image(image_, prompt, prompt_embeds)
747
+ elif isinstance(controlnet, MultiControlNetUnionModel):
748
+ if not isinstance(image, list):
749
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
750
+ elif not all(isinstance(i, list) for i in image):
751
+ raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
752
+ elif len(image) != len(self.controlnet.nets):
753
+ raise ValueError(
754
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
755
+ )
748
756
 
749
- if not isinstance(control_guidance_start, (tuple, list)):
750
- control_guidance_start = [control_guidance_start]
757
+ for images_ in image:
758
+ for image_ in images_:
759
+ self.check_image(image_, prompt, prompt_embeds)
751
760
 
752
- if not isinstance(control_guidance_end, (tuple, list)):
753
- control_guidance_end = [control_guidance_end]
761
+ # Check `controlnet_conditioning_scale`
762
+ if isinstance(controlnet, MultiControlNetUnionModel):
763
+ if isinstance(controlnet_conditioning_scale, list):
764
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
765
+ raise ValueError("A single batch of multiple conditionings is not supported at the moment.")
766
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
767
+ self.controlnet.nets
768
+ ):
769
+ raise ValueError(
770
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
771
+ " the same length as the number of controlnets"
772
+ )
754
773
 
755
774
  if len(control_guidance_start) != len(control_guidance_end):
756
775
  raise ValueError(
757
776
  f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
758
777
  )
759
778
 
779
+ if isinstance(controlnet, MultiControlNetUnionModel):
780
+ if len(control_guidance_start) != len(self.controlnet.nets):
781
+ raise ValueError(
782
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
783
+ )
784
+
760
785
  for start, end in zip(control_guidance_start, control_guidance_end):
761
786
  if start >= end:
762
787
  raise ValueError(
763
- f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
788
+ f"control_guidance_start: {start} cannot be larger or equal to control guidance end: {end}."
764
789
  )
765
790
  if start < 0.0:
766
- raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
791
+ raise ValueError(f"control_guidance_start: {start} can't be smaller than 0.")
767
792
  if end > 1.0:
768
- raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
793
+ raise ValueError(f"control_guidance_end: {end} can't be larger than 1.0.")
794
+
795
+ # Check `control_mode`
796
+ if isinstance(controlnet, ControlNetUnionModel):
797
+ if max(control_mode) >= controlnet.config.num_control_type:
798
+ raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
799
+ elif isinstance(controlnet, MultiControlNetUnionModel):
800
+ for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
801
+ if max(_control_mode) >= _controlnet.config.num_control_type:
802
+ raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
803
+
804
+ # Equal number of `image` and `control_mode` elements
805
+ if isinstance(controlnet, ControlNetUnionModel):
806
+ if len(image) != len(control_mode):
807
+ raise ValueError("Expected len(control_image) == len(control_mode)")
808
+ elif isinstance(controlnet, MultiControlNetUnionModel):
809
+ if not all(isinstance(i, list) for i in control_mode):
810
+ raise ValueError(
811
+ "For multiple controlnets: elements of control_mode must be lists representing conditioning mode."
812
+ )
813
+
814
+ elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
815
+ raise ValueError("Expected len(control_image) == len(control_mode)")
769
816
 
770
817
  if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
771
818
  raise ValueError(
@@ -941,7 +988,7 @@ class StableDiffusionXLControlNetUnionPipeline(
941
988
  self,
942
989
  prompt: Union[str, List[str]] = None,
943
990
  prompt_2: Optional[Union[str, List[str]]] = None,
944
- control_image: PipelineImageInput = None,
991
+ control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
945
992
  height: Optional[int] = None,
946
993
  width: Optional[int] = None,
947
994
  num_inference_steps: int = 50,
@@ -968,7 +1015,7 @@ class StableDiffusionXLControlNetUnionPipeline(
968
1015
  guess_mode: bool = False,
969
1016
  control_guidance_start: Union[float, List[float]] = 0.0,
970
1017
  control_guidance_end: Union[float, List[float]] = 1.0,
971
- control_mode: Optional[Union[int, List[int]]] = None,
1018
+ control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
972
1019
  original_size: Tuple[int, int] = None,
973
1020
  crops_coords_top_left: Tuple[int, int] = (0, 0),
974
1021
  target_size: Tuple[int, int] = None,
@@ -990,7 +1037,7 @@ class StableDiffusionXLControlNetUnionPipeline(
990
1037
  prompt_2 (`str` or `List[str]`, *optional*):
991
1038
  The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
992
1039
  used in both text-encoders.
993
- control_image (`PipelineImageInput`):
1040
+ control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
994
1041
  The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
995
1042
  specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
996
1043
  as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
@@ -1082,6 +1129,11 @@ class StableDiffusionXLControlNetUnionPipeline(
1082
1129
  The percentage of total steps at which the ControlNet starts applying.
1083
1130
  control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1084
1131
  The percentage of total steps at which the ControlNet stops applying.
1132
+ control_mode (`int` or `List[int]` or `List[List[int]], *optional*):
1133
+ The control condition types for the ControlNet. See the ControlNet's model card forinformation on the
1134
+ available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list
1135
+ where each ControlNet should have its corresponding control mode list. Should reflect the order of
1136
+ conditions in control_image.
1085
1137
  original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1086
1138
  If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1087
1139
  `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1137,47 +1189,61 @@ class StableDiffusionXLControlNetUnionPipeline(
1137
1189
 
1138
1190
  controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1139
1191
 
1140
- # align format for control guidance
1141
- if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1142
- control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1143
- elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1144
- control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1145
-
1146
1192
  if not isinstance(control_image, list):
1147
1193
  control_image = [control_image]
1194
+ else:
1195
+ control_image = control_image.copy()
1148
1196
 
1149
1197
  if not isinstance(control_mode, list):
1150
1198
  control_mode = [control_mode]
1151
1199
 
1152
- if len(control_image) != len(control_mode):
1153
- raise ValueError("Expected len(control_image) == len(control_type)")
1200
+ if isinstance(controlnet, MultiControlNetUnionModel):
1201
+ control_image = [[item] for item in control_image]
1202
+ control_mode = [[item] for item in control_mode]
1203
+
1204
+ # align format for control guidance
1205
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1206
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1207
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1208
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1209
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1210
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
1211
+ control_guidance_start, control_guidance_end = (
1212
+ mult * [control_guidance_start],
1213
+ mult * [control_guidance_end],
1214
+ )
1154
1215
 
1155
- num_control_type = controlnet.config.num_control_type
1216
+ if isinstance(controlnet_conditioning_scale, float):
1217
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
1218
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
1156
1219
 
1157
1220
  # 1. Check inputs
1158
- control_type = [0 for _ in range(num_control_type)]
1159
- # 1. Check inputs. Raise error if not correct
1160
- for _image, control_idx in zip(control_image, control_mode):
1161
- control_type[control_idx] = 1
1162
- self.check_inputs(
1163
- prompt,
1164
- prompt_2,
1165
- _image,
1166
- negative_prompt,
1167
- negative_prompt_2,
1168
- prompt_embeds,
1169
- negative_prompt_embeds,
1170
- pooled_prompt_embeds,
1171
- ip_adapter_image,
1172
- ip_adapter_image_embeds,
1173
- negative_pooled_prompt_embeds,
1174
- controlnet_conditioning_scale,
1175
- control_guidance_start,
1176
- control_guidance_end,
1177
- callback_on_step_end_tensor_inputs,
1178
- )
1221
+ self.check_inputs(
1222
+ prompt,
1223
+ prompt_2,
1224
+ control_image,
1225
+ negative_prompt,
1226
+ negative_prompt_2,
1227
+ prompt_embeds,
1228
+ negative_prompt_embeds,
1229
+ pooled_prompt_embeds,
1230
+ ip_adapter_image,
1231
+ ip_adapter_image_embeds,
1232
+ negative_pooled_prompt_embeds,
1233
+ controlnet_conditioning_scale,
1234
+ control_guidance_start,
1235
+ control_guidance_end,
1236
+ control_mode,
1237
+ callback_on_step_end_tensor_inputs,
1238
+ )
1179
1239
 
1180
- control_type = torch.Tensor(control_type)
1240
+ if isinstance(controlnet, ControlNetUnionModel):
1241
+ control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
1242
+ elif isinstance(controlnet, MultiControlNetUnionModel):
1243
+ control_type = [
1244
+ torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
1245
+ for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
1246
+ ]
1181
1247
 
1182
1248
  self._guidance_scale = guidance_scale
1183
1249
  self._clip_skip = clip_skip
@@ -1195,7 +1261,11 @@ class StableDiffusionXLControlNetUnionPipeline(
1195
1261
 
1196
1262
  device = self._execution_device
1197
1263
 
1198
- global_pool_conditions = controlnet.config.global_pool_conditions
1264
+ global_pool_conditions = (
1265
+ controlnet.config.global_pool_conditions
1266
+ if isinstance(controlnet, ControlNetUnionModel)
1267
+ else controlnet.nets[0].config.global_pool_conditions
1268
+ )
1199
1269
  guess_mode = guess_mode or global_pool_conditions
1200
1270
 
1201
1271
  # 3.1 Encode input prompt
@@ -1234,19 +1304,51 @@ class StableDiffusionXLControlNetUnionPipeline(
1234
1304
  )
1235
1305
 
1236
1306
  # 4. Prepare image
1237
- for idx, _ in enumerate(control_image):
1238
- control_image[idx] = self.prepare_image(
1239
- image=control_image[idx],
1240
- width=width,
1241
- height=height,
1242
- batch_size=batch_size * num_images_per_prompt,
1243
- num_images_per_prompt=num_images_per_prompt,
1244
- device=device,
1245
- dtype=controlnet.dtype,
1246
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1247
- guess_mode=guess_mode,
1248
- )
1249
- height, width = control_image[idx].shape[-2:]
1307
+ if isinstance(controlnet, ControlNetUnionModel):
1308
+ control_images = []
1309
+
1310
+ for image_ in control_image:
1311
+ image_ = self.prepare_image(
1312
+ image=image_,
1313
+ width=width,
1314
+ height=height,
1315
+ batch_size=batch_size * num_images_per_prompt,
1316
+ num_images_per_prompt=num_images_per_prompt,
1317
+ device=device,
1318
+ dtype=controlnet.dtype,
1319
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1320
+ guess_mode=guess_mode,
1321
+ )
1322
+
1323
+ control_images.append(image_)
1324
+
1325
+ control_image = control_images
1326
+ height, width = control_image[0].shape[-2:]
1327
+
1328
+ elif isinstance(controlnet, MultiControlNetUnionModel):
1329
+ control_images = []
1330
+
1331
+ for control_image_ in control_image:
1332
+ images = []
1333
+
1334
+ for image_ in control_image_:
1335
+ image_ = self.prepare_image(
1336
+ image=image_,
1337
+ width=width,
1338
+ height=height,
1339
+ batch_size=batch_size * num_images_per_prompt,
1340
+ num_images_per_prompt=num_images_per_prompt,
1341
+ device=device,
1342
+ dtype=controlnet.dtype,
1343
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1344
+ guess_mode=guess_mode,
1345
+ )
1346
+
1347
+ images.append(image_)
1348
+ control_images.append(images)
1349
+
1350
+ control_image = control_images
1351
+ height, width = control_image[0][0].shape[-2:]
1250
1352
 
1251
1353
  # 5. Prepare timesteps
1252
1354
  timesteps, num_inference_steps = retrieve_timesteps(
@@ -1281,10 +1383,11 @@ class StableDiffusionXLControlNetUnionPipeline(
1281
1383
  # 7.1 Create tensor stating which controlnets to keep
1282
1384
  controlnet_keep = []
1283
1385
  for i in range(len(timesteps)):
1284
- controlnet_keep.append(
1285
- 1.0
1286
- - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
1287
- )
1386
+ keeps = [
1387
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1388
+ for s, e in zip(control_guidance_start, control_guidance_end)
1389
+ ]
1390
+ controlnet_keep.append(keeps)
1288
1391
 
1289
1392
  # 7.2 Prepare added time ids & embeddings
1290
1393
  original_size = original_size or (height, width)
@@ -1349,11 +1452,20 @@ class StableDiffusionXLControlNetUnionPipeline(
1349
1452
  is_controlnet_compiled = is_compiled_module(self.controlnet)
1350
1453
  is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1351
1454
 
1352
- control_type = (
1353
- control_type.reshape(1, -1)
1354
- .to(device, dtype=prompt_embeds.dtype)
1355
- .repeat(batch_size * num_images_per_prompt * 2, 1)
1356
- )
1455
+ if isinstance(controlnet, ControlNetUnionModel):
1456
+ control_type = (
1457
+ control_type.reshape(1, -1)
1458
+ .to(self._execution_device, dtype=prompt_embeds.dtype)
1459
+ .repeat(batch_size * num_images_per_prompt * 2, 1)
1460
+ )
1461
+ if isinstance(controlnet, MultiControlNetUnionModel):
1462
+ control_type = [
1463
+ _control_type.reshape(1, -1)
1464
+ .to(self._execution_device, dtype=prompt_embeds.dtype)
1465
+ .repeat(batch_size * num_images_per_prompt * 2, 1)
1466
+ for _control_type in control_type
1467
+ ]
1468
+
1357
1469
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1358
1470
  for i, t in enumerate(timesteps):
1359
1471
  if self.interrupt:
@@ -1454,6 +1566,9 @@ class StableDiffusionXLControlNetUnionPipeline(
1454
1566
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1455
1567
  progress_bar.update()
1456
1568
 
1569
+ if XLA_AVAILABLE:
1570
+ xm.mark_step()
1571
+
1457
1572
  if not output_type == "latent":
1458
1573
  # make sure the VAE is in float32 mode, as it overflows in float16
1459
1574
  needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast