diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +595 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
  384. diffusers-0.33.1.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0

diffusers/models/controlnets/controlnet_sd3.py

@@ -21,7 +21,7 @@ import torch.nn as nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import JointTransformerBlock
 from ..attention_processor import Attention, AttentionProcessor, FusedJointAttnProcessor2_0
 from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
@@ -40,6 +40,48 @@ class SD3ControlNetOutput(BaseOutput):
 
 
 class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+    r"""
+    ControlNet model for [Stable Diffusion 3](https://huggingface.co/papers/2403.03206).
+
+    Parameters:
+        sample_size (`int`, defaults to `128`):
+            The width/height of the latents. This is fixed during training since it is used to learn a number of
+            position embeddings.
+        patch_size (`int`, defaults to `2`):
+            Patch size to turn the input data into small patches.
+        in_channels (`int`, defaults to `16`):
+            The number of latent channels in the input.
+        num_layers (`int`, defaults to `18`):
+            The number of layers of transformer blocks to use.
+        attention_head_dim (`int`, defaults to `64`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `18`):
+            The number of heads to use for multi-head attention.
+        joint_attention_dim (`int`, defaults to `4096`):
+            The embedding dimension to use for joint text-image attention.
+        caption_projection_dim (`int`, defaults to `1152`):
+            The embedding dimension of caption embeddings.
+        pooled_projection_dim (`int`, defaults to `2048`):
+            The embedding dimension of pooled text projections.
+        out_channels (`int`, defaults to `16`):
+            The number of latent channels in the output.
+        pos_embed_max_size (`int`, defaults to `96`):
+            The maximum latent height/width of positional embeddings.
+        extra_conditioning_channels (`int`, defaults to `0`):
+            The number of extra channels to use for conditioning for patch embedding.
+        dual_attention_layers (`Tuple[int, ...]`, defaults to `()`):
+            The number of dual-stream transformer blocks to use.
+        qk_norm (`str`, *optional*, defaults to `None`):
+            The normalization to use for query and key in the attention layer. If `None`, no normalization is used.
+        pos_embed_type (`str`, defaults to `"sincos"`):
+            The type of positional embedding to use. Choose between `"sincos"` and `None`.
+        use_pos_embed (`bool`, defaults to `True`):
+            Whether to use positional embeddings.
+        force_zeros_for_pooled_projection (`bool`, defaults to `True`):
+            Whether to force zeros for pooled projection embeddings. This is handled in the pipelines by reading the
+            config value of the ControlNet model.
+    """
+
     _supports_gradient_checkpointing = True
 
     @register_to_config
@@ -93,7 +135,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
                 JointTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    attention_head_dim=attention_head_dim,
                     context_pre_only=False,
                     qk_norm=qk_norm,
                     use_dual_attention=True if i in dual_attention_layers else False,
@@ -108,7 +150,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
                 SD3SingleTransformerBlock(
                     dim=self.inner_dim,
                     num_attention_heads=num_attention_heads,
-                    attention_head_dim=self.config.attention_head_dim,
+                    attention_head_dim=attention_head_dim,
                 )
                 for _ in range(num_layers)
             ]
@@ -262,10 +304,6 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
         if self.original_attn_processors is not None:
             self.set_attn_processor(self.original_attn_processors)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     # Notes: This is for SD3.5 8b controlnet, which shares the pos_embed with the transformer
     # we should have handled this in conversion script
     def _get_pos_embed_from_transformer(self, transformer):
@@ -301,28 +339,28 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
+        hidden_states: torch.Tensor,
         controlnet_cond: torch.Tensor,
         conditioning_scale: float = 1.0,
-        encoder_hidden_states: torch.FloatTensor = None,
-        pooled_projections: torch.FloatTensor = None,
+        encoder_hidden_states: torch.Tensor = None,
+        pooled_projections: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`SD3Transformer2DModel`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+            hidden_states (`torch.Tensor` of shape `(batch size, channel, height, width)`):
                 Input `hidden_states`.
             controlnet_cond (`torch.Tensor`):
                 The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
             conditioning_scale (`float`, defaults to `1.0`):
                 The scale factor for ControlNet outputs.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence_len, embed_dims)`):
                 Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+            pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
                 from the embeddings of input conditions.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
@@ -382,30 +420,16 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
 
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                 if self.context_embedder is not None:
-                    encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block),
+                    encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                        block,
                         hidden_states,
                         encoder_hidden_states,
                         temb,
-                        **ckpt_kwargs,
                     )
                 else:
                     # SD3.5 8b controlnet use single transformer block, which does not use `encoder_hidden_states`
-                    hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block), hidden_states, temb, **ckpt_kwargs
-                    )
+                    hidden_states = self._gradient_checkpointing_func(block, hidden_states, temb)
 
             else:
                 if self.context_embedder is not None:
@@ -455,11 +479,11 @@ class SD3MultiControlNetModel(ModelMixin):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
+        hidden_states: torch.Tensor,
         controlnet_cond: List[torch.tensor],
         conditioning_scale: List[float],
-        pooled_projections: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor = None,
+        pooled_projections: torch.Tensor,
+        encoder_hidden_states: torch.Tensor = None,
         timestep: torch.LongTensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
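
Note: the hunks above add a full config docstring to `SD3ControlNetModel`, make the transformer blocks use the local `attention_head_dim` argument instead of `self.config.attention_head_dim`, and route gradient checkpointing through the shared `self._gradient_checkpointing_func` helper. A minimal usage sketch; the checkpoint id is only an example taken from the diffusers docs, substitute whichever SD3 ControlNet you actually use:

import torch
from diffusers import SD3ControlNetModel

# Example checkpoint id only.
controlnet = SD3ControlNetModel.from_pretrained(
    "InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16
)

# The config fields documented in the new docstring are readable as usual.
print(controlnet.config.num_layers, controlnet.config.attention_head_dim)

# `_supports_gradient_checkpointing = True`, so the refactored checkpointing
# path is enabled the standard way:
controlnet.enable_gradient_checkpointing()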

diffusers/models/controlnets/controlnet_sparsectrl.py

@@ -590,10 +590,6 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, UNetMidBlock2DCrossAttn)):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         sample: torch.Tensor,
@@ -671,10 +667,11 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
@@ -690,7 +687,7 @@ class SparseControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         t_emb = t_emb.to(dtype=sample.dtype)
 
         emb = self.time_embedding(t_emb, timestep_cond)
-        emb = emb.repeat_interleave(sample_num_frames, dim=0)
+        emb = emb.repeat_interleave(sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames)
 
         # 2. pre-process
         batch_size, channels, num_frames, height, width = sample.shape
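
Note: besides dropping the per-module `_set_gradient_checkpointing` override and extending the mps float64/int64 workaround to `npu` devices, this file now passes `output_size` to `repeat_interleave`. A standalone sketch of that call (not diffusers code; per the PyTorch docs, supplying `output_size` avoids the synchronization otherwise needed to compute the output shape):

import torch

emb = torch.randn(2, 1280)   # stand-in for the time embedding
sample_num_frames = 16

# Same call shape as the changed line above, with the output length given up front.
out = emb.repeat_interleave(
    sample_num_frames, dim=0, output_size=emb.shape[0] * sample_num_frames
)
print(out.shape)  # torch.Size([32, 1280])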

diffusers/models/controlnets/controlnet_union.py

@@ -29,8 +29,6 @@ from ..attention_processor import (
 from ..embeddings import TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
 from ..modeling_utils import ModelMixin
 from ..unets.unet_2d_blocks import (
-    CrossAttnDownBlock2D,
-    DownBlock2D,
     UNetMidBlock2DCrossAttn,
     get_down_block,
 )
@@ -599,10 +597,6 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for module in self.children():
             fn_recursive_set_attention_slice(module, reversed_slice_size)
 
-    def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
-        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         sample: torch.Tensor,
@@ -611,12 +605,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         controlnet_cond: List[torch.Tensor],
         control_type: torch.Tensor,
         control_type_idx: List[int],
-        conditioning_scale: float = 1.0,
+        conditioning_scale: Union[float, List[float]] = 1.0,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        from_multi: bool = False,
         guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
@@ -653,6 +648,8 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 Additional conditions for the Stable Diffusion XL UNet.
             cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+            from_multi (`bool`, defaults to `False`):
+                Use standard scaling when called from `MultiControlNetUnionModel`.
             guess_mode (`bool`, defaults to `False`):
                 In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if
                 you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
@@ -664,6 +661,9 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 If `return_dict` is `True`, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a tuple is
                 returned where the first element is the sample tensor.
         """
+        if isinstance(conditioning_scale, float):
+            conditioning_scale = [conditioning_scale] * len(controlnet_cond)
+
         # check channel order
         channel_order = self.config.controlnet_conditioning_channel_order
 
@@ -681,10 +681,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
@@ -747,12 +748,16 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         inputs = []
         condition_list = []
 
-        for cond, control_idx in zip(controlnet_cond, control_type_idx):
+        for cond, control_idx, scale in zip(controlnet_cond, control_type_idx, conditioning_scale):
            condition = self.controlnet_cond_embedding(cond)
            feat_seq = torch.mean(condition, dim=(2, 3))
            feat_seq = feat_seq + self.task_embedding[control_idx]
-            inputs.append(feat_seq.unsqueeze(1))
-            condition_list.append(condition)
+            if from_multi:
+                inputs.append(feat_seq.unsqueeze(1))
+                condition_list.append(condition)
+            else:
+                inputs.append(feat_seq.unsqueeze(1) * scale)
+                condition_list.append(condition * scale)
 
         condition = sample
         feat_seq = torch.mean(condition, dim=(2, 3))
@@ -764,10 +769,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             x = layer(x)
 
         controlnet_cond_fuser = sample * 0.0
-        for idx, condition in enumerate(condition_list[:-1]):
+        for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
             alpha = self.spatial_ch_projs(x[:, idx])
             alpha = alpha.unsqueeze(-1).unsqueeze(-1)
-            controlnet_cond_fuser += condition + alpha
+            if from_multi:
+                controlnet_cond_fuser += condition + alpha
+            else:
+                controlnet_cond_fuser += condition + alpha * scale
 
         sample = sample + controlnet_cond_fuser
 
@@ -811,12 +819,13 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
-            scales = scales * conditioning_scale
+            if from_multi:
+                scales = scales * conditioning_scale[0]
             down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
             mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
-        else:
-            down_block_res_samples = [sample * conditioning_scale for sample in down_block_res_samples]
-            mid_block_res_sample = mid_block_res_sample * conditioning_scale
+        elif from_multi:
+            down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
+            mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]
 
         if self.config.global_pool_conditions:
             down_block_res_samples = [
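
Note: `ControlNetUnionModel.forward` now accepts either a single float or one scale per condition image, and the new `from_multi` flag preserves the old global-scaling behaviour when the model is driven by `MultiControlNetUnionModel` (added below). A small standalone sketch of the broadcast performed by the added lines (illustration only, not a diffusers API):

import torch

# Two hypothetical condition images (e.g. depth and canny) for one union ControlNet.
controlnet_cond = [torch.randn(1, 3, 1024, 1024), torch.randn(1, 3, 1024, 1024)]

conditioning_scale = 0.5
if isinstance(conditioning_scale, float):
    conditioning_scale = [conditioning_scale] * len(controlnet_cond)

print(conditioning_scale)  # [0.5, 0.5]; pass e.g. [0.8, 0.4] to weight the conditions separately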

diffusers/models/controlnets/controlnet_xs.py

@@ -20,7 +20,7 @@ import torch.utils.checkpoint
 from torch import Tensor, nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import BaseOutput, is_torch_version, logging
+from ...utils import BaseOutput, logging
 from ...utils.torch_utils import apply_freeu
 from ..attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -864,10 +864,6 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
         for u in self.up_blocks:
             u.freeze_base_params()
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -1088,10 +1084,11 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
@@ -1449,15 +1446,6 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
         base_blocks = list(zip(self.base_resnets, self.base_attentions))
         ctrl_blocks = list(zip(self.ctrl_resnets, self.ctrl_attentions))
 
-        def create_custom_forward(module, return_dict=None):
-            def custom_forward(*inputs):
-                if return_dict is not None:
-                    return module(*inputs, return_dict=return_dict)
-                else:
-                    return module(*inputs)
-
-            return custom_forward
-
         for (b_res, b_attn), (c_res, c_attn), b2c, c2b in zip(
             base_blocks, ctrl_blocks, self.base_to_ctrl, self.ctrl_to_base
         ):
@@ -1467,13 +1455,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
 
             # apply base subblock
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                h_base = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(b_res),
-                    h_base,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                h_base = self._gradient_checkpointing_func(b_res, h_base, temb)
             else:
                 h_base = b_res(h_base, temb)
 
@@ -1490,13 +1472,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
             # apply ctrl subblock
             if apply_control:
                 if torch.is_grad_enabled() and self.gradient_checkpointing:
-                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                    h_ctrl = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(c_res),
-                        h_ctrl,
-                        temb,
-                        **ckpt_kwargs,
-                    )
+                    h_ctrl = self._gradient_checkpointing_func(c_res, h_ctrl, temb)
                 else:
                     h_ctrl = c_res(h_ctrl, temb)
                 if c_attn is not None:
@@ -1861,15 +1837,6 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
             and getattr(self, "b2", None)
         )
 
-        def create_custom_forward(module, return_dict=None):
-            def custom_forward(*inputs):
-                if return_dict is not None:
-                    return module(*inputs, return_dict=return_dict)
-                else:
-                    return module(*inputs)
-
-            return custom_forward
-
         def maybe_apply_freeu_to_subblock(hidden_states, res_h_base):
             # FreeU: Only operate on the first two stages
             if is_freeu_enabled:
@@ -1899,13 +1866,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
             hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(resnet),
-                    hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
             else:
                 hidden_states = resnet(hidden_states, temb)
 
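
Note: all four ControlNet-XS hunks above replace the local `create_custom_forward` + `torch.utils.checkpoint.checkpoint` boilerplate with the shared `self._gradient_checkpointing_func` helper. A rough standalone equivalent of what the removed code did, assuming the helper defaults to non-reentrant checkpointing (an assumption based on the removed `use_reentrant: False` kwargs, not a statement about the new internals):

from functools import partial

import torch
import torch.utils.checkpoint

# Non-reentrant activation checkpointing, as the removed boilerplate configured it.
gradient_checkpointing_func = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

block = torch.nn.Linear(16, 16)                       # stand-in for a resnet/attention subblock
hidden_states = torch.randn(2, 16, requires_grad=True)

out = gradient_checkpointing_func(block, hidden_states)  # activations recomputed in backward
out.sum().backward()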

diffusers/models/controlnets/multicontrolnet_union.py (new file)

@@ -0,0 +1,196 @@
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from ...models.controlnets.controlnet import ControlNetOutput
+from ...models.controlnets.controlnet_union import ControlNetUnionModel
+from ...models.modeling_utils import ModelMixin
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+class MultiControlNetUnionModel(ModelMixin):
+    r"""
+    Multiple `ControlNetUnionModel` wrapper class for Multi-ControlNet-Union.
+
+    This module is a wrapper for multiple instances of the `ControlNetUnionModel`. The `forward()` API is designed to
+    be compatible with `ControlNetUnionModel`.
+
+    Args:
+        controlnets (`List[ControlNetUnionModel]`):
+            Provides additional conditioning to the unet during the denoising process. You must set multiple
+            `ControlNetUnionModel` as a list.
+    """
+
+    def __init__(self, controlnets: Union[List[ControlNetUnionModel], Tuple[ControlNetUnionModel]]):
+        super().__init__()
+        self.nets = nn.ModuleList(controlnets)
+
+    def forward(
+        self,
+        sample: torch.Tensor,
+        timestep: Union[torch.Tensor, float, int],
+        encoder_hidden_states: torch.Tensor,
+        controlnet_cond: List[torch.tensor],
+        control_type: List[torch.Tensor],
+        control_type_idx: List[List[int]],
+        conditioning_scale: List[float],
+        class_labels: Optional[torch.Tensor] = None,
+        timestep_cond: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guess_mode: bool = False,
+        return_dict: bool = True,
+    ) -> Union[ControlNetOutput, Tuple]:
+        down_block_res_samples, mid_block_res_sample = None, None
+        for i, (image, ctype, ctype_idx, scale, controlnet) in enumerate(
+            zip(controlnet_cond, control_type, control_type_idx, conditioning_scale, self.nets)
+        ):
+            if scale == 0.0:
+                continue
+            down_samples, mid_sample = controlnet(
+                sample=sample,
+                timestep=timestep,
+                encoder_hidden_states=encoder_hidden_states,
+                controlnet_cond=image,
+                control_type=ctype,
+                control_type_idx=ctype_idx,
+                conditioning_scale=scale,
+                class_labels=class_labels,
+                timestep_cond=timestep_cond,
+                attention_mask=attention_mask,
+                added_cond_kwargs=added_cond_kwargs,
+                cross_attention_kwargs=cross_attention_kwargs,
+                from_multi=True,
+                guess_mode=guess_mode,
+                return_dict=return_dict,
+            )
+
+            # merge samples
+            if down_block_res_samples is None and mid_block_res_sample is None:
+                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
+            else:
+                down_block_res_samples = [
+                    samples_prev + samples_curr
+                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
+                ]
+                mid_block_res_sample += mid_sample
+
+        return down_block_res_samples, mid_block_res_sample
+
+    # Copied from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained with ControlNet->ControlNetUnion
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        is_main_process: bool = True,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+        variant: Optional[str] = None,
+    ):
+        """
+        Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        `[`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.from_pretrained`]` class method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to which to save. Will be created if it doesn't exist.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful when in distributed training like
+                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
+                the main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+                need to replace `torch.save` by another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+            variant (`str`, *optional*):
+                If specified, weights are saved in the format pytorch_model.<variant>.bin.
+        """
+        for idx, controlnet in enumerate(self.nets):
+            suffix = "" if idx == 0 else f"_{idx}"
+            controlnet.save_pretrained(
+                save_directory + suffix,
+                is_main_process=is_main_process,
+                save_function=save_function,
+                safe_serialization=safe_serialization,
+                variant=variant,
+            )
+
+    @classmethod
+    # Copied from diffusers.models.controlnets.multicontrolnet.MultiControlNetModel.from_pretrained with ControlNet->ControlNetUnion
+    def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
+        r"""
+        Instantiate a pretrained MultiControlNetUnion model from multiple pre-trained controlnet models.
+
+        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
+        the model, you should first set it back in training mode with `model.train()`.
+
+        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
+        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
+        task.
+
+        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
+        weights are discarded.
+
+        Parameters:
+            pretrained_model_path (`os.PathLike`):
+                A path to a *directory* containing model weights saved using
+                [`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.save_pretrained`], e.g.,
+                `./my_model_directory/controlnet`.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
+                will be automatically derived from the model's weights.
+            output_loading_info(`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be refined to each
+                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+                same device.
+
+                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
+                GPU and the available CPU RAM if unset.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
+                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
+                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
+                setting this argument to `True` will raise an error.
+            variant (`str`, *optional*):
+                If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
+                ignored when using `from_flax`.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
+                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
+                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
+        """
+        idx = 0
+        controlnets = []
+
+        # load controlnet and append to list until no controlnet directory exists anymore
+        # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_prertained`
+        # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
+        model_path_to_load = pretrained_model_path
+        while os.path.isdir(model_path_to_load):
+            controlnet = ControlNetUnionModel.from_pretrained(model_path_to_load, **kwargs)
+            controlnets.append(controlnet)
+
+            idx += 1
+            model_path_to_load = pretrained_model_path + f"_{idx}"
+
+        logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
+
+        if len(controlnets) == 0:
+            raise ValueError(
+                f"No ControlNetUnions found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
+            )
+
+        return cls(controlnets)
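
Note: the new `MultiControlNetUnionModel` mirrors `MultiControlNetModel` for union ControlNets: it runs each wrapped net with `from_multi=True` and sums their residuals. A minimal save/load sketch based on the code above (the checkpoint id is only a placeholder, and loading the same checkpoint twice is purely illustrative):

import torch
from diffusers import ControlNetUnionModel
from diffusers.models.controlnets.multicontrolnet_union import MultiControlNetUnionModel

# Placeholder checkpoint; substitute the union ControlNets you actually combine.
cn_a = ControlNetUnionModel.from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16)
cn_b = ControlNetUnionModel.from_pretrained("xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16)

multi = MultiControlNetUnionModel([cn_a, cn_b])

# save_pretrained writes the first net to <dir> and the rest to <dir>_1, <dir>_2, ...
# which is exactly the layout from_pretrained walks when loading them back.
multi.save_pretrained("./multi-controlnet-union")
reloaded = MultiControlNetUnionModel.from_pretrained("./multi-controlnet-union")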