diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (551)
  1. diffusers/__init__.py +145 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/__init__.py +1 -1
  4. diffusers/commands/custom_blocks.py +134 -0
  5. diffusers/commands/diffusers_cli.py +3 -1
  6. diffusers/commands/env.py +1 -1
  7. diffusers/commands/fp16_safetensors.py +2 -2
  8. diffusers/configuration_utils.py +11 -2
  9. diffusers/dependency_versions_check.py +1 -1
  10. diffusers/dependency_versions_table.py +3 -3
  11. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  12. diffusers/guiders/__init__.py +41 -0
  13. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  14. diffusers/guiders/auto_guidance.py +190 -0
  15. diffusers/guiders/classifier_free_guidance.py +141 -0
  16. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  17. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  18. diffusers/guiders/guider_utils.py +309 -0
  19. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  20. diffusers/guiders/skip_layer_guidance.py +262 -0
  21. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  22. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  23. diffusers/hooks/__init__.py +17 -0
  24. diffusers/hooks/_common.py +56 -0
  25. diffusers/hooks/_helpers.py +293 -0
  26. diffusers/hooks/faster_cache.py +9 -8
  27. diffusers/hooks/first_block_cache.py +259 -0
  28. diffusers/hooks/group_offloading.py +332 -227
  29. diffusers/hooks/hooks.py +58 -3
  30. diffusers/hooks/layer_skip.py +263 -0
  31. diffusers/hooks/layerwise_casting.py +5 -10
  32. diffusers/hooks/pyramid_attention_broadcast.py +15 -12
  33. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  34. diffusers/hooks/utils.py +43 -0
  35. diffusers/image_processor.py +7 -2
  36. diffusers/loaders/__init__.py +10 -0
  37. diffusers/loaders/ip_adapter.py +260 -18
  38. diffusers/loaders/lora_base.py +261 -127
  39. diffusers/loaders/lora_conversion_utils.py +657 -35
  40. diffusers/loaders/lora_pipeline.py +2778 -1246
  41. diffusers/loaders/peft.py +78 -112
  42. diffusers/loaders/single_file.py +2 -2
  43. diffusers/loaders/single_file_model.py +64 -15
  44. diffusers/loaders/single_file_utils.py +395 -7
  45. diffusers/loaders/textual_inversion.py +3 -2
  46. diffusers/loaders/transformer_flux.py +10 -11
  47. diffusers/loaders/transformer_sd3.py +8 -3
  48. diffusers/loaders/unet.py +24 -21
  49. diffusers/loaders/unet_loader_utils.py +6 -3
  50. diffusers/loaders/utils.py +1 -1
  51. diffusers/models/__init__.py +23 -1
  52. diffusers/models/activations.py +5 -5
  53. diffusers/models/adapter.py +2 -3
  54. diffusers/models/attention.py +488 -7
  55. diffusers/models/attention_dispatch.py +1218 -0
  56. diffusers/models/attention_flax.py +10 -10
  57. diffusers/models/attention_processor.py +113 -667
  58. diffusers/models/auto_model.py +49 -12
  59. diffusers/models/autoencoders/__init__.py +2 -0
  60. diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
  61. diffusers/models/autoencoders/autoencoder_dc.py +17 -4
  62. diffusers/models/autoencoders/autoencoder_kl.py +5 -5
  63. diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
  64. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
  65. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
  66. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
  67. diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
  68. diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
  69. diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
  70. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  71. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
  72. diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
  73. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
  74. diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
  75. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  76. diffusers/models/autoencoders/vae.py +13 -2
  77. diffusers/models/autoencoders/vq_model.py +2 -2
  78. diffusers/models/cache_utils.py +32 -10
  79. diffusers/models/controlnet.py +1 -1
  80. diffusers/models/controlnet_flux.py +1 -1
  81. diffusers/models/controlnet_sd3.py +1 -1
  82. diffusers/models/controlnet_sparsectrl.py +1 -1
  83. diffusers/models/controlnets/__init__.py +1 -0
  84. diffusers/models/controlnets/controlnet.py +3 -3
  85. diffusers/models/controlnets/controlnet_flax.py +1 -1
  86. diffusers/models/controlnets/controlnet_flux.py +21 -20
  87. diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
  88. diffusers/models/controlnets/controlnet_sana.py +290 -0
  89. diffusers/models/controlnets/controlnet_sd3.py +1 -1
  90. diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
  91. diffusers/models/controlnets/controlnet_union.py +5 -5
  92. diffusers/models/controlnets/controlnet_xs.py +7 -7
  93. diffusers/models/controlnets/multicontrolnet.py +4 -5
  94. diffusers/models/controlnets/multicontrolnet_union.py +5 -6
  95. diffusers/models/downsampling.py +2 -2
  96. diffusers/models/embeddings.py +36 -46
  97. diffusers/models/embeddings_flax.py +2 -2
  98. diffusers/models/lora.py +3 -3
  99. diffusers/models/model_loading_utils.py +233 -1
  100. diffusers/models/modeling_flax_utils.py +1 -2
  101. diffusers/models/modeling_utils.py +203 -108
  102. diffusers/models/normalization.py +4 -4
  103. diffusers/models/resnet.py +2 -2
  104. diffusers/models/resnet_flax.py +1 -1
  105. diffusers/models/transformers/__init__.py +7 -0
  106. diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
  107. diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
  108. diffusers/models/transformers/consisid_transformer_3d.py +1 -1
  109. diffusers/models/transformers/dit_transformer_2d.py +2 -2
  110. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  111. diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
  112. diffusers/models/transformers/latte_transformer_3d.py +4 -5
  113. diffusers/models/transformers/lumina_nextdit2d.py +2 -2
  114. diffusers/models/transformers/pixart_transformer_2d.py +3 -3
  115. diffusers/models/transformers/prior_transformer.py +1 -1
  116. diffusers/models/transformers/sana_transformer.py +8 -3
  117. diffusers/models/transformers/stable_audio_transformer.py +5 -9
  118. diffusers/models/transformers/t5_film_transformer.py +3 -3
  119. diffusers/models/transformers/transformer_2d.py +1 -1
  120. diffusers/models/transformers/transformer_allegro.py +1 -1
  121. diffusers/models/transformers/transformer_chroma.py +641 -0
  122. diffusers/models/transformers/transformer_cogview3plus.py +5 -10
  123. diffusers/models/transformers/transformer_cogview4.py +353 -27
  124. diffusers/models/transformers/transformer_cosmos.py +586 -0
  125. diffusers/models/transformers/transformer_flux.py +376 -138
  126. diffusers/models/transformers/transformer_hidream_image.py +942 -0
  127. diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
  128. diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
  129. diffusers/models/transformers/transformer_ltx.py +105 -24
  130. diffusers/models/transformers/transformer_lumina2.py +1 -1
  131. diffusers/models/transformers/transformer_mochi.py +1 -1
  132. diffusers/models/transformers/transformer_omnigen.py +2 -2
  133. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  134. diffusers/models/transformers/transformer_sd3.py +7 -7
  135. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  136. diffusers/models/transformers/transformer_temporal.py +1 -1
  137. diffusers/models/transformers/transformer_wan.py +316 -87
  138. diffusers/models/transformers/transformer_wan_vace.py +387 -0
  139. diffusers/models/unets/unet_1d.py +1 -1
  140. diffusers/models/unets/unet_1d_blocks.py +1 -1
  141. diffusers/models/unets/unet_2d.py +1 -1
  142. diffusers/models/unets/unet_2d_blocks.py +1 -1
  143. diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
  144. diffusers/models/unets/unet_2d_condition.py +4 -3
  145. diffusers/models/unets/unet_2d_condition_flax.py +2 -2
  146. diffusers/models/unets/unet_3d_blocks.py +1 -1
  147. diffusers/models/unets/unet_3d_condition.py +3 -3
  148. diffusers/models/unets/unet_i2vgen_xl.py +3 -3
  149. diffusers/models/unets/unet_kandinsky3.py +1 -1
  150. diffusers/models/unets/unet_motion_model.py +2 -2
  151. diffusers/models/unets/unet_stable_cascade.py +1 -1
  152. diffusers/models/upsampling.py +2 -2
  153. diffusers/models/vae_flax.py +2 -2
  154. diffusers/models/vq_model.py +1 -1
  155. diffusers/modular_pipelines/__init__.py +83 -0
  156. diffusers/modular_pipelines/components_manager.py +1068 -0
  157. diffusers/modular_pipelines/flux/__init__.py +66 -0
  158. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  159. diffusers/modular_pipelines/flux/decoders.py +109 -0
  160. diffusers/modular_pipelines/flux/denoise.py +227 -0
  161. diffusers/modular_pipelines/flux/encoders.py +412 -0
  162. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  163. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  164. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  165. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  166. diffusers/modular_pipelines/node_utils.py +665 -0
  167. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  168. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  169. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  170. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  171. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  172. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  173. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  174. diffusers/modular_pipelines/wan/__init__.py +66 -0
  175. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  176. diffusers/modular_pipelines/wan/decoders.py +105 -0
  177. diffusers/modular_pipelines/wan/denoise.py +261 -0
  178. diffusers/modular_pipelines/wan/encoders.py +242 -0
  179. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  180. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  181. diffusers/pipelines/__init__.py +68 -6
  182. diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
  183. diffusers/pipelines/amused/pipeline_amused.py +7 -6
  184. diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
  185. diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
  186. diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
  187. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
  188. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
  189. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
  190. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
  191. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
  192. diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
  193. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  194. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
  195. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
  196. diffusers/pipelines/auto_pipeline.py +23 -20
  197. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  198. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  199. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
  200. diffusers/pipelines/chroma/__init__.py +49 -0
  201. diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
  202. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
  203. diffusers/pipelines/chroma/pipeline_output.py +21 -0
  204. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
  205. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
  206. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
  207. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
  208. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
  209. diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
  210. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
  211. diffusers/pipelines/consisid/consisid_utils.py +2 -2
  212. diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
  213. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  214. diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
  215. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
  216. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
  217. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
  218. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
  219. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
  220. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
  221. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
  222. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
  223. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
  224. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  225. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
  226. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
  227. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  228. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
  229. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
  230. diffusers/pipelines/cosmos/__init__.py +54 -0
  231. diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
  232. diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
  233. diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
  234. diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
  235. diffusers/pipelines/cosmos/pipeline_output.py +40 -0
  236. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
  237. diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
  238. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  239. diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
  240. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
  241. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
  242. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
  243. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
  244. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
  245. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
  246. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
  247. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  248. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
  249. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  250. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
  251. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
  252. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  253. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  254. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  255. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  256. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  257. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
  258. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
  259. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
  260. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
  261. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
  262. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
  263. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  264. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
  265. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
  266. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
  267. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
  268. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
  269. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  270. diffusers/pipelines/dit/pipeline_dit.py +4 -2
  271. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
  272. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
  273. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
  274. diffusers/pipelines/flux/__init__.py +4 -0
  275. diffusers/pipelines/flux/modeling_flux.py +1 -1
  276. diffusers/pipelines/flux/pipeline_flux.py +37 -36
  277. diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
  278. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
  279. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
  280. diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
  281. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
  282. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
  283. diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
  284. diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
  285. diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
  286. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  287. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  288. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
  289. diffusers/pipelines/flux/pipeline_output.py +6 -4
  290. diffusers/pipelines/free_init_utils.py +2 -2
  291. diffusers/pipelines/free_noise_utils.py +3 -3
  292. diffusers/pipelines/hidream_image/__init__.py +47 -0
  293. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
  294. diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
  295. diffusers/pipelines/hunyuan_video/__init__.py +2 -0
  296. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
  297. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
  298. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
  299. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
  300. diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
  301. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
  302. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
  303. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
  304. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
  305. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
  306. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
  307. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
  308. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  309. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
  310. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
  311. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
  312. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
  313. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
  314. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
  315. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
  316. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
  317. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
  318. diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
  319. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
  320. diffusers/pipelines/kolors/text_encoder.py +3 -3
  321. diffusers/pipelines/kolors/tokenizer.py +1 -1
  322. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
  323. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
  324. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  325. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
  326. diffusers/pipelines/latte/pipeline_latte.py +12 -12
  327. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
  328. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
  329. diffusers/pipelines/ltx/__init__.py +4 -0
  330. diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
  331. diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
  332. diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
  333. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
  334. diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
  335. diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
  336. diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
  337. diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
  338. diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
  339. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
  340. diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
  341. diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
  342. diffusers/pipelines/onnx_utils.py +15 -2
  343. diffusers/pipelines/pag/pag_utils.py +2 -2
  344. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
  345. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
  346. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
  347. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
  348. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
  349. diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
  350. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
  351. diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
  352. diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
  353. diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
  354. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
  355. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
  356. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
  357. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
  358. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
  359. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
  360. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
  361. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  362. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
  363. diffusers/pipelines/pia/pipeline_pia.py +8 -6
  364. diffusers/pipelines/pipeline_flax_utils.py +5 -6
  365. diffusers/pipelines/pipeline_loading_utils.py +113 -15
  366. diffusers/pipelines/pipeline_utils.py +127 -48
  367. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
  368. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
  369. diffusers/pipelines/qwenimage/__init__.py +55 -0
  370. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  371. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  372. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  373. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  374. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  375. diffusers/pipelines/sana/__init__.py +4 -0
  376. diffusers/pipelines/sana/pipeline_sana.py +23 -21
  377. diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
  378. diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
  379. diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
  380. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
  381. diffusers/pipelines/shap_e/camera.py +1 -1
  382. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  383. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  384. diffusers/pipelines/shap_e/renderer.py +3 -3
  385. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  386. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  387. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  388. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  389. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  390. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  391. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  392. diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
  393. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
  394. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
  395. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
  396. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
  397. diffusers/pipelines/stable_diffusion/__init__.py +0 -7
  398. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  399. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
  400. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  401. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
  402. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  403. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
  404. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
  405. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
  406. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
  407. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
  408. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
  409. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
  410. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
  411. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
  412. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
  413. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
  414. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
  415. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
  416. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
  417. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  418. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  419. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  420. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
  421. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
  422. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
  423. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
  424. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
  425. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
  426. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
  427. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
  428. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
  429. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
  430. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
  431. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
  432. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  433. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
  434. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  435. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
  436. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
  437. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
  438. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
  439. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
  440. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
  441. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
  442. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
  443. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
  444. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
  445. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
  446. diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
  447. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
  448. diffusers/pipelines/unclip/text_proj.py +2 -2
  449. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
  450. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  451. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
  452. diffusers/pipelines/visualcloze/__init__.py +52 -0
  453. diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
  454. diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
  455. diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
  456. diffusers/pipelines/wan/__init__.py +2 -0
  457. diffusers/pipelines/wan/pipeline_wan.py +91 -30
  458. diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
  459. diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
  460. diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
  461. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  462. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
  463. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  464. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  465. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
  466. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
  467. diffusers/quantizers/__init__.py +3 -1
  468. diffusers/quantizers/base.py +17 -1
  469. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
  470. diffusers/quantizers/bitsandbytes/utils.py +10 -7
  471. diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
  472. diffusers/quantizers/gguf/utils.py +108 -16
  473. diffusers/quantizers/pipe_quant_config.py +202 -0
  474. diffusers/quantizers/quantization_config.py +18 -16
  475. diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
  476. diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
  477. diffusers/schedulers/__init__.py +3 -1
  478. diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
  479. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  480. diffusers/schedulers/scheduling_consistency_models.py +1 -1
  481. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
  482. diffusers/schedulers/scheduling_ddim.py +8 -8
  483. diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
  484. diffusers/schedulers/scheduling_ddim_flax.py +6 -6
  485. diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
  486. diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
  487. diffusers/schedulers/scheduling_ddpm.py +9 -9
  488. diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
  489. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
  490. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
  491. diffusers/schedulers/scheduling_deis_multistep.py +16 -9
  492. diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
  493. diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
  494. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
  495. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
  496. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  497. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
  498. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
  499. diffusers/schedulers/scheduling_edm_euler.py +20 -11
  500. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
  501. diffusers/schedulers/scheduling_euler_discrete.py +3 -3
  502. diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
  503. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
  504. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
  505. diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
  506. diffusers/schedulers/scheduling_heun_discrete.py +2 -2
  507. diffusers/schedulers/scheduling_ipndm.py +2 -2
  508. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
  509. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
  510. diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
  511. diffusers/schedulers/scheduling_lcm.py +3 -3
  512. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  513. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  514. diffusers/schedulers/scheduling_pndm.py +4 -4
  515. diffusers/schedulers/scheduling_pndm_flax.py +4 -4
  516. diffusers/schedulers/scheduling_repaint.py +9 -9
  517. diffusers/schedulers/scheduling_sasolver.py +15 -15
  518. diffusers/schedulers/scheduling_scm.py +1 -2
  519. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  520. diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
  521. diffusers/schedulers/scheduling_tcd.py +3 -3
  522. diffusers/schedulers/scheduling_unclip.py +5 -5
  523. diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
  524. diffusers/schedulers/scheduling_utils.py +3 -3
  525. diffusers/schedulers/scheduling_utils_flax.py +2 -2
  526. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  527. diffusers/training_utils.py +91 -5
  528. diffusers/utils/__init__.py +15 -0
  529. diffusers/utils/accelerate_utils.py +1 -1
  530. diffusers/utils/constants.py +4 -0
  531. diffusers/utils/doc_utils.py +1 -1
  532. diffusers/utils/dummy_pt_objects.py +432 -0
  533. diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
  534. diffusers/utils/dynamic_modules_utils.py +85 -8
  535. diffusers/utils/export_utils.py +1 -1
  536. diffusers/utils/hub_utils.py +33 -17
  537. diffusers/utils/import_utils.py +151 -18
  538. diffusers/utils/logging.py +1 -1
  539. diffusers/utils/outputs.py +2 -1
  540. diffusers/utils/peft_utils.py +96 -10
  541. diffusers/utils/state_dict_utils.py +20 -3
  542. diffusers/utils/testing_utils.py +195 -17
  543. diffusers/utils/torch_utils.py +43 -5
  544. diffusers/video_processor.py +2 -2
  545. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
  546. diffusers-0.35.0.dist-info/RECORD +703 -0
  547. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
  548. diffusers-0.33.1.dist-info/RECORD +0 -608
  549. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  550. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  551. {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2024 Harutatsu Akiyama, Jinbin Bai, and The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 Harutatsu Akiyama, Jinbin Bai, and The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -18,7 +18,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
18
  import numpy as np
19
19
  import PIL.Image
20
20
  import torch
21
- import torch.nn.functional as F
22
21
  from transformers import (
23
22
  CLIPImageProcessor,
24
23
  CLIPTextModel,
@@ -35,7 +34,13 @@ from ...loaders import (
35
34
  StableDiffusionXLLoraLoaderMixin,
36
35
  TextualInversionLoaderMixin,
37
36
  )
38
- from ...models import AutoencoderKL, ControlNetModel, ControlNetUnionModel, ImageProjection, UNet2DConditionModel
37
+ from ...models import (
38
+ AutoencoderKL,
39
+ ControlNetUnionModel,
40
+ ImageProjection,
41
+ MultiControlNetUnionModel,
42
+ UNet2DConditionModel,
43
+ )
39
44
  from ...models.attention_processor import (
40
45
  AttnProcessor2_0,
41
46
  XFormersAttnProcessor,
@@ -51,7 +56,7 @@ from ...utils import (
51
56
  scale_lora_layers,
52
57
  unscale_lora_layers,
53
58
  )
54
- from ...utils.torch_utils import is_compiled_module, randn_tensor
59
+ from ...utils.torch_utils import empty_device_cache, is_compiled_module, randn_tensor
55
60
  from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
56
61
  from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
57
62
 
@@ -134,7 +139,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
134
139
  r"""
135
140
  Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
136
141
  Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
137
- Flawed](https://arxiv.org/pdf/2305.08891.pdf).
142
+ Flawed](https://huggingface.co/papers/2305.08891).
138
143
 
139
144
  Args:
140
145
  noise_cfg (`torch.Tensor`):
@@ -230,7 +235,9 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
230
235
  tokenizer: CLIPTokenizer,
231
236
  tokenizer_2: CLIPTokenizer,
232
237
  unet: UNet2DConditionModel,
233
- controlnet: ControlNetUnionModel,
238
+ controlnet: Union[
239
+ ControlNetUnionModel, List[ControlNetUnionModel], Tuple[ControlNetUnionModel], MultiControlNetUnionModel
240
+ ],
234
241
  scheduler: KarrasDiffusionSchedulers,
235
242
  requires_aesthetics_score: bool = False,
236
243
  force_zeros_for_empty_prompt: bool = True,
@@ -240,8 +247,8 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
240
247
  ):
241
248
  super().__init__()
242
249
 
243
- if not isinstance(controlnet, ControlNetUnionModel):
244
- raise ValueError("Expected `controlnet` to be of type `ControlNetUnionModel`.")
250
+ if isinstance(controlnet, (list, tuple)):
251
+ controlnet = MultiControlNetUnionModel(controlnet)
245
252
 
246
253
  self.register_modules(
247
254
  vae=vae,
@@ -587,7 +594,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
587
594
  def prepare_extra_step_kwargs(self, generator, eta):
588
595
  # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
589
596
  # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
590
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
597
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
591
598
  # and should be between [0, 1]
592
599
 
593
600
  accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -660,6 +667,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
660
667
  controlnet_conditioning_scale=1.0,
661
668
  control_guidance_start=0.0,
662
669
  control_guidance_end=1.0,
670
+ control_mode=None,
663
671
  callback_on_step_end_tensor_inputs=None,
664
672
  padding_mask_crop=None,
665
673
  ):
@@ -747,25 +755,34 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
747
755
  "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
748
756
  )
749
757
 
758
+ # `prompt` needs more sophisticated handling when there are multiple
759
+ # conditionings.
760
+ if isinstance(self.controlnet, MultiControlNetUnionModel):
761
+ if isinstance(prompt, list):
762
+ logger.warning(
763
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
764
+ " prompts. The conditionings will be fixed across the prompts."
765
+ )
766
+
750
767
  # Check `image`
751
- is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
752
- self.controlnet, torch._dynamo.eval_frame.OptimizedModule
753
- )
754
- if (
755
- isinstance(self.controlnet, ControlNetModel)
756
- or is_compiled
757
- and isinstance(self.controlnet._orig_mod, ControlNetModel)
758
- ):
759
- self.check_image(image, prompt, prompt_embeds)
760
- elif (
761
- isinstance(self.controlnet, ControlNetUnionModel)
762
- or is_compiled
763
- and isinstance(self.controlnet._orig_mod, ControlNetUnionModel)
764
- ):
765
- self.check_image(image, prompt, prompt_embeds)
768
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
766
769
 
767
- else:
768
- assert False
770
+ if isinstance(controlnet, ControlNetUnionModel):
771
+ for image_ in image:
772
+ self.check_image(image_, prompt, prompt_embeds)
773
+ elif isinstance(controlnet, MultiControlNetUnionModel):
774
+ if not isinstance(image, list):
775
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
776
+ elif not all(isinstance(i, list) for i in image):
777
+ raise ValueError("For multiple controlnets: elements of `image` must be list of conditionings.")
778
+ elif len(image) != len(self.controlnet.nets):
779
+ raise ValueError(
780
+ f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
781
+ )
782
+
783
+ for images_ in image:
784
+ for image_ in images_:
785
+ self.check_image(image_, prompt, prompt_embeds)
769
786
 
770
787
  if not isinstance(control_guidance_start, (tuple, list)):
771
788
  control_guidance_start = [control_guidance_start]
@@ -778,6 +795,12 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
778
795
  f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
779
796
  )
780
797
 
798
+ if isinstance(controlnet, MultiControlNetUnionModel):
799
+ if len(control_guidance_start) != len(self.controlnet.nets):
800
+ raise ValueError(
801
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
802
+ )
803
+
781
804
  for start, end in zip(control_guidance_start, control_guidance_end):
782
805
  if start >= end:
783
806
  raise ValueError(
@@ -788,6 +811,28 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
788
811
  if end > 1.0:
789
812
  raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
790
813
 
814
+ # Check `control_mode`
815
+ if isinstance(controlnet, ControlNetUnionModel):
816
+ if max(control_mode) >= controlnet.config.num_control_type:
817
+ raise ValueError(f"control_mode: must be lower than {controlnet.config.num_control_type}.")
818
+ elif isinstance(controlnet, MultiControlNetUnionModel):
819
+ for _control_mode, _controlnet in zip(control_mode, self.controlnet.nets):
820
+ if max(_control_mode) >= _controlnet.config.num_control_type:
821
+ raise ValueError(f"control_mode: must be lower than {_controlnet.config.num_control_type}.")
822
+
823
+ # Equal number of `image` and `control_mode` elements
824
+ if isinstance(controlnet, ControlNetUnionModel):
825
+ if len(image) != len(control_mode):
826
+ raise ValueError("Expected len(control_image) == len(control_mode)")
827
+ elif isinstance(controlnet, MultiControlNetUnionModel):
828
+ if not all(isinstance(i, list) for i in control_mode):
829
+ raise ValueError(
830
+ "For multiple controlnets: elements of control_mode must be lists representing conditioning mode."
831
+ )
832
+
833
+ elif sum(len(x) for x in image) != sum(len(x) for x in control_mode):
834
+ raise ValueError("Expected len(control_image) == len(control_mode)")
835
+
791
836
  if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
792
837
  raise ValueError(
793
838
  "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
@@ -1091,7 +1136,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1091
1136
  return self._clip_skip
1092
1137
 
1093
1138
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1094
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1139
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
1095
1140
  # corresponds to doing no classifier free guidance.
1096
1141
  @property
1097
1142
  def do_classifier_free_guidance(self):
@@ -1117,7 +1162,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1117
1162
  prompt_2: Optional[Union[str, List[str]]] = None,
1118
1163
  image: PipelineImageInput = None,
1119
1164
  mask_image: PipelineImageInput = None,
1120
- control_image: PipelineImageInput = None,
1165
+ control_image: Union[PipelineImageInput, List[PipelineImageInput]] = None,
1121
1166
  height: Optional[int] = None,
1122
1167
  width: Optional[int] = None,
1123
1168
  padding_mask_crop: Optional[int] = None,
@@ -1145,7 +1190,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1145
1190
  guess_mode: bool = False,
1146
1191
  control_guidance_start: Union[float, List[float]] = 0.0,
1147
1192
  control_guidance_end: Union[float, List[float]] = 1.0,
1148
- control_mode: Optional[Union[int, List[int]]] = None,
1193
+ control_mode: Optional[Union[int, List[int], List[List[int]]]] = None,
1149
1194
  guidance_rescale: float = 0.0,
1150
1195
  original_size: Tuple[int, int] = None,
1151
1196
  crops_coords_top_left: Tuple[int, int] = (0, 0),
@@ -1177,6 +1222,13 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1177
1222
  repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
1178
1223
  to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
1179
1224
  instead of 3, so the expected shape would be `(B, H, W, 1)`.
1225
+ control_image (`PipelineImageInput` or `List[PipelineImageInput]`, *optional*):
1226
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
1227
+ specified as `torch.Tensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be accepted
1228
+ as an image. The dimensions of the output image defaults to `image`'s dimensions. If height and/or
1229
+ width are passed, `image` is resized accordingly. If multiple ControlNets are specified in `init`,
1230
+ images must be passed as a list such that each element of the list can be correctly batched for input
1231
+ to a single ControlNet.
1180
1232
  height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1181
1233
  The height in pixels of the generated image.
1182
1234
  width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -1215,11 +1267,11 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1215
1267
  forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1216
1268
  Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
1217
1269
  guidance_scale (`float`, *optional*, defaults to 7.5):
1218
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1219
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
1220
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1221
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1222
- usually at the expense of lower image quality.
1270
+ Guidance scale as defined in [Classifier-Free Diffusion
1271
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
1272
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
1273
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
1274
+ the text `prompt`, usually at the expense of lower image quality.
1223
1275
  negative_prompt (`str` or `List[str]`, *optional*):
1224
1276
  The prompt or prompts not to guide the image generation. If not defined, one has to pass
1225
1277
  `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
@@ -1250,8 +1302,8 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1250
1302
  num_images_per_prompt (`int`, *optional*, defaults to 1):
1251
1303
  The number of images to generate per prompt.
1252
1304
  eta (`float`, *optional*, defaults to 0.0):
1253
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1254
- [`schedulers.DDIMScheduler`], will be ignored for others.
1305
+ Corresponds to parameter eta (η) in the DDIM paper: https://huggingface.co/papers/2010.02502. Only
1306
+ applies to [`schedulers.DDIMScheduler`], will be ignored for others.
1255
1307
  generator (`torch.Generator`, *optional*):
1256
1308
  One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1257
1309
  to make generation deterministic.
@@ -1269,6 +1321,22 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1269
1321
  A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1270
1322
  `self.processor` in
1271
1323
  [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1324
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
1325
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
1326
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
1327
+ the corresponding scale as a list.
1328
+ guess_mode (`bool`, *optional*, defaults to `False`):
1329
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
1330
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
1331
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
1332
+ The percentage of total steps at which the ControlNet starts applying.
1333
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
1334
+ The percentage of total steps at which the ControlNet stops applying.
1335
+ control_mode (`int` or `List[int]` or `List[List[int]], *optional*):
1336
+ The control condition types for the ControlNet. See the ControlNet's model card forinformation on the
1337
+ available control modes. If multiple ControlNets are specified in `init`, control_mode should be a list
1338
+ where each ControlNet should have its corresponding control mode list. Should reflect the order of
1339
+ conditions in control_image.
1272
1340
  original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1273
1341
  If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1274
1342
  `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
@@ -1333,22 +1401,6 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1333
1401
 
1334
1402
  controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1335
1403
 
1336
- # align format for control guidance
1337
- if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1338
- control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1339
- elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1340
- control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1341
-
1342
- # # 0.0 Default height and width to unet
1343
- # height = height or self.unet.config.sample_size * self.vae_scale_factor
1344
- # width = width or self.unet.config.sample_size * self.vae_scale_factor
1345
-
1346
- # 0.1 align format for control guidance
1347
- if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1348
- control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1349
- elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1350
- control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1351
-
1352
1404
  if not isinstance(control_image, list):
1353
1405
  control_image = [control_image]
1354
1406
  else:
@@ -1357,40 +1409,59 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1357
1409
  if not isinstance(control_mode, list):
1358
1410
  control_mode = [control_mode]
1359
1411
 
1360
- if len(control_image) != len(control_mode):
1361
- raise ValueError("Expected len(control_image) == len(control_type)")
1412
+ if isinstance(controlnet, MultiControlNetUnionModel):
1413
+ control_image = [[item] for item in control_image]
1414
+ control_mode = [[item] for item in control_mode]
1362
1415
 
1363
- num_control_type = controlnet.config.num_control_type
1416
+ # align format for control guidance
1417
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
1418
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
1419
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
1420
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
1421
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
1422
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
1423
+ control_guidance_start, control_guidance_end = (
1424
+ mult * [control_guidance_start],
1425
+ mult * [control_guidance_end],
1426
+ )
1427
+
1428
+ if isinstance(controlnet_conditioning_scale, float):
1429
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetUnionModel) else len(control_mode)
1430
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * mult
1364
1431
 
1365
1432
  # 1. Check inputs
1366
- control_type = [0 for _ in range(num_control_type)]
1367
- for _image, control_idx in zip(control_image, control_mode):
1368
- control_type[control_idx] = 1
1369
- self.check_inputs(
1370
- prompt,
1371
- prompt_2,
1372
- _image,
1373
- mask_image,
1374
- strength,
1375
- num_inference_steps,
1376
- callback_steps,
1377
- output_type,
1378
- negative_prompt,
1379
- negative_prompt_2,
1380
- prompt_embeds,
1381
- negative_prompt_embeds,
1382
- ip_adapter_image,
1383
- ip_adapter_image_embeds,
1384
- pooled_prompt_embeds,
1385
- negative_pooled_prompt_embeds,
1386
- controlnet_conditioning_scale,
1387
- control_guidance_start,
1388
- control_guidance_end,
1389
- callback_on_step_end_tensor_inputs,
1390
- padding_mask_crop,
1391
- )
1433
+ self.check_inputs(
1434
+ prompt,
1435
+ prompt_2,
1436
+ control_image,
1437
+ mask_image,
1438
+ strength,
1439
+ num_inference_steps,
1440
+ callback_steps,
1441
+ output_type,
1442
+ negative_prompt,
1443
+ negative_prompt_2,
1444
+ prompt_embeds,
1445
+ negative_prompt_embeds,
1446
+ ip_adapter_image,
1447
+ ip_adapter_image_embeds,
1448
+ pooled_prompt_embeds,
1449
+ negative_pooled_prompt_embeds,
1450
+ controlnet_conditioning_scale,
1451
+ control_guidance_start,
1452
+ control_guidance_end,
1453
+ control_mode,
1454
+ callback_on_step_end_tensor_inputs,
1455
+ padding_mask_crop,
1456
+ )
1392
1457
 
1393
- control_type = torch.Tensor(control_type)
1458
+ if isinstance(controlnet, ControlNetUnionModel):
1459
+ control_type = torch.zeros(controlnet.config.num_control_type).scatter_(0, torch.tensor(control_mode), 1)
1460
+ elif isinstance(controlnet, MultiControlNetUnionModel):
1461
+ control_type = [
1462
+ torch.zeros(controlnet_.config.num_control_type).scatter_(0, torch.tensor(control_mode_), 1)
1463
+ for control_mode_, controlnet_ in zip(control_mode, self.controlnet.nets)
1464
+ ]
1394
1465
 
1395
1466
  self._guidance_scale = guidance_scale
1396
1467
  self._clip_skip = clip_skip
@@ -1483,21 +1554,55 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1483
1554
  init_image = init_image.to(dtype=torch.float32)
1484
1555
 
1485
1556
  # 5.2 Prepare control images
1486
- for idx, _ in enumerate(control_image):
1487
- control_image[idx] = self.prepare_control_image(
1488
- image=control_image[idx],
1489
- width=width,
1490
- height=height,
1491
- batch_size=batch_size * num_images_per_prompt,
1492
- num_images_per_prompt=num_images_per_prompt,
1493
- device=device,
1494
- dtype=controlnet.dtype,
1495
- crops_coords=crops_coords,
1496
- resize_mode=resize_mode,
1497
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1498
- guess_mode=guess_mode,
1499
- )
1500
- height, width = control_image[idx].shape[-2:]
1557
+ if isinstance(controlnet, ControlNetUnionModel):
1558
+ control_images = []
1559
+
1560
+ for image_ in control_image:
1561
+ image_ = self.prepare_control_image(
1562
+ image=image_,
1563
+ width=width,
1564
+ height=height,
1565
+ batch_size=batch_size * num_images_per_prompt,
1566
+ num_images_per_prompt=num_images_per_prompt,
1567
+ device=device,
1568
+ dtype=controlnet.dtype,
1569
+ crops_coords=crops_coords,
1570
+ resize_mode=resize_mode,
1571
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1572
+ guess_mode=guess_mode,
1573
+ )
1574
+
1575
+ control_images.append(image_)
1576
+
1577
+ control_image = control_images
1578
+ height, width = control_image[0].shape[-2:]
1579
+
1580
+ elif isinstance(controlnet, MultiControlNetUnionModel):
1581
+ control_images = []
1582
+
1583
+ for control_image_ in control_image:
1584
+ images = []
1585
+
1586
+ for image_ in control_image_:
1587
+ image_ = self.prepare_control_image(
1588
+ image=image_,
1589
+ width=width,
1590
+ height=height,
1591
+ batch_size=batch_size * num_images_per_prompt,
1592
+ num_images_per_prompt=num_images_per_prompt,
1593
+ device=device,
1594
+ dtype=controlnet.dtype,
1595
+ crops_coords=crops_coords,
1596
+ resize_mode=resize_mode,
1597
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1598
+ guess_mode=guess_mode,
1599
+ )
1600
+
1601
+ images.append(image_)
1602
+ control_images.append(images)
1603
+
1604
+ control_image = control_images
1605
+ height, width = control_image[0][0].shape[-2:]
1501
1606
 
1502
1607
  # 5.3 Prepare mask
1503
1608
  mask = self.mask_processor.preprocess(
@@ -1559,10 +1664,11 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1559
1664
  # 8.2 Create tensor stating which controlnets to keep
1560
1665
  controlnet_keep = []
1561
1666
  for i in range(len(timesteps)):
1562
- controlnet_keep.append(
1563
- 1.0
1564
- - float(i / len(timesteps) < control_guidance_start or (i + 1) / len(timesteps) > control_guidance_end)
1565
- )
1667
+ keeps = [
1668
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1669
+ for s, e in zip(control_guidance_start, control_guidance_end)
1670
+ ]
1671
+ controlnet_keep.append(keeps)
1566
1672
 
1567
1673
  # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1568
1674
  height, width = latents.shape[-2:]
@@ -1627,11 +1733,24 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1627
1733
  num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1628
1734
  timesteps = timesteps[:num_inference_steps]
1629
1735
 
1630
- control_type = (
1631
- control_type.reshape(1, -1)
1632
- .to(device, dtype=prompt_embeds.dtype)
1633
- .repeat(batch_size * num_images_per_prompt * 2, 1)
1736
+ control_type_repeat_factor = (
1737
+ batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
1634
1738
  )
1739
+
1740
+ if isinstance(controlnet, ControlNetUnionModel):
1741
+ control_type = (
1742
+ control_type.reshape(1, -1)
1743
+ .to(self._execution_device, dtype=prompt_embeds.dtype)
1744
+ .repeat(control_type_repeat_factor, 1)
1745
+ )
1746
+ elif isinstance(controlnet, MultiControlNetUnionModel):
1747
+ control_type = [
1748
+ _control_type.reshape(1, -1)
1749
+ .to(self._execution_device, dtype=prompt_embeds.dtype)
1750
+ .repeat(control_type_repeat_factor, 1)
1751
+ for _control_type in control_type
1752
+ ]
1753
+
1635
1754
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1636
1755
  for i, t in enumerate(timesteps):
1637
1756
  if self.interrupt:
@@ -1715,7 +1834,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1715
1834
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1716
1835
 
1717
1836
  if self.do_classifier_free_guidance and guidance_rescale > 0.0:
1718
- # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1837
+ # Based on 3.4. in https://huggingface.co/papers/2305.08891
1719
1838
  noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1720
1839
 
1721
1840
  # compute the previous noisy sample x_t -> x_t-1
@@ -1766,7 +1885,7 @@ class StableDiffusionXLControlNetUnionInpaintPipeline(
1766
1885
  if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1767
1886
  self.unet.to("cpu")
1768
1887
  self.controlnet.to("cpu")
1769
- torch.cuda.empty_cache()
1888
+ empty_device_cache()
1770
1889
 
1771
1890
  if not output_type == "latent":
1772
1891
  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -603,7 +603,7 @@ class StableDiffusionXLControlNetUnionPipeline(
603
603
  def prepare_extra_step_kwargs(self, generator, eta):
604
604
  # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
605
605
  # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
606
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
606
+ # eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
607
607
  # and should be between [0, 1]
608
608
 
609
609
  accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
@@ -960,7 +960,7 @@ class StableDiffusionXLControlNetUnionPipeline(
960
960
  return self._clip_skip
961
961
 
962
962
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
963
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
963
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
964
964
  # corresponds to doing no classifier free guidance.
965
965
  @property
966
966
  def do_classifier_free_guidance(self):
@@ -1082,8 +1082,8 @@ class StableDiffusionXLControlNetUnionPipeline(
1082
1082
  num_images_per_prompt (`int`, *optional*, defaults to 1):
1083
1083
  The number of images to generate per prompt.
1084
1084
  eta (`float`, *optional*, defaults to 0.0):
1085
- Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
1086
- to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
1085
+ Corresponds to parameter eta (η) from the [DDIM](https://huggingface.co/papers/2010.02502) paper. Only
1086
+ applies to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
1087
1087
  generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1088
1088
  A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
1089
1089
  generation deterministic.
@@ -1452,17 +1452,21 @@ class StableDiffusionXLControlNetUnionPipeline(
1452
1452
  is_controlnet_compiled = is_compiled_module(self.controlnet)
1453
1453
  is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
1454
1454
 
1455
+ control_type_repeat_factor = (
1456
+ batch_size * num_images_per_prompt * (2 if self.do_classifier_free_guidance else 1)
1457
+ )
1458
+
1455
1459
  if isinstance(controlnet, ControlNetUnionModel):
1456
1460
  control_type = (
1457
1461
  control_type.reshape(1, -1)
1458
1462
  .to(self._execution_device, dtype=prompt_embeds.dtype)
1459
- .repeat(batch_size * num_images_per_prompt * 2, 1)
1463
+ .repeat(control_type_repeat_factor, 1)
1460
1464
  )
1461
- if isinstance(controlnet, MultiControlNetUnionModel):
1465
+ elif isinstance(controlnet, MultiControlNetUnionModel):
1462
1466
  control_type = [
1463
1467
  _control_type.reshape(1, -1)
1464
1468
  .to(self._execution_device, dtype=prompt_embeds.dtype)
1465
- .repeat(batch_size * num_images_per_prompt * 2, 1)
1469
+ .repeat(control_type_repeat_factor, 1)
1466
1470
  for _control_type in control_type
1467
1471
  ]
1468
1472