diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -317,6 +317,7 @@ class AutoencoderOobleck(ModelMixin, ConfigMixin):
317
317
  """
318
318
 
319
319
  _supports_gradient_checkpointing = False
320
+ _supports_group_offloading = False
320
321
 
321
322
  @register_to_config
322
323
  def __init__(
@@ -154,10 +154,6 @@ class AutoencoderTiny(ModelMixin, ConfigMixin):
154
154
  self.register_to_config(block_out_channels=decoder_block_out_channels)
155
155
  self.register_to_config(force_upcast=False)
156
156
 
157
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
158
- if isinstance(module, (EncoderTiny, DecoderTiny)):
159
- module.gradient_checkpointing = value
160
-
161
157
  def scale_latents(self, x: torch.Tensor) -> torch.Tensor:
162
158
  """raw latents -> [0, 1]"""
163
159
  return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)
@@ -60,7 +60,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
60
60
 
61
61
  >>> vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
62
62
  >>> pipe = StableDiffusionPipeline.from_pretrained(
63
- ... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
63
+ ... "stable-diffusion-v1-5/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
64
64
  ... ).to("cuda")
65
65
 
66
66
  >>> image = pipe("horse", generator=torch.manual_seed(0)).images[0]
@@ -68,6 +68,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
68
68
  ```
69
69
  """
70
70
 
71
+ _supports_group_offloading = False
72
+
71
73
  @register_to_config
72
74
  def __init__(
73
75
  self,
@@ -18,7 +18,7 @@ import numpy as np
18
18
  import torch
19
19
  import torch.nn as nn
20
20
 
21
- from ...utils import BaseOutput, is_torch_version
21
+ from ...utils import BaseOutput
22
22
  from ...utils.torch_utils import randn_tensor
23
23
  from ..activations import get_activation
24
24
  from ..attention_processor import SpatialNorm
@@ -156,28 +156,11 @@ class Encoder(nn.Module):
156
156
  sample = self.conv_in(sample)
157
157
 
158
158
  if torch.is_grad_enabled() and self.gradient_checkpointing:
159
-
160
- def create_custom_forward(module):
161
- def custom_forward(*inputs):
162
- return module(*inputs)
163
-
164
- return custom_forward
165
-
166
159
  # down
167
- if is_torch_version(">=", "1.11.0"):
168
- for down_block in self.down_blocks:
169
- sample = torch.utils.checkpoint.checkpoint(
170
- create_custom_forward(down_block), sample, use_reentrant=False
171
- )
172
- # middle
173
- sample = torch.utils.checkpoint.checkpoint(
174
- create_custom_forward(self.mid_block), sample, use_reentrant=False
175
- )
176
- else:
177
- for down_block in self.down_blocks:
178
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
179
- # middle
180
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
160
+ for down_block in self.down_blocks:
161
+ sample = self._gradient_checkpointing_func(down_block, sample)
162
+ # middle
163
+ sample = self._gradient_checkpointing_func(self.mid_block, sample)
181
164
 
182
165
  else:
183
166
  # down
@@ -305,41 +288,13 @@ class Decoder(nn.Module):
305
288
 
306
289
  upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
307
290
  if torch.is_grad_enabled() and self.gradient_checkpointing:
291
+ # middle
292
+ sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
293
+ sample = sample.to(upscale_dtype)
308
294
 
309
- def create_custom_forward(module):
310
- def custom_forward(*inputs):
311
- return module(*inputs)
312
-
313
- return custom_forward
314
-
315
- if is_torch_version(">=", "1.11.0"):
316
- # middle
317
- sample = torch.utils.checkpoint.checkpoint(
318
- create_custom_forward(self.mid_block),
319
- sample,
320
- latent_embeds,
321
- use_reentrant=False,
322
- )
323
- sample = sample.to(upscale_dtype)
324
-
325
- # up
326
- for up_block in self.up_blocks:
327
- sample = torch.utils.checkpoint.checkpoint(
328
- create_custom_forward(up_block),
329
- sample,
330
- latent_embeds,
331
- use_reentrant=False,
332
- )
333
- else:
334
- # middle
335
- sample = torch.utils.checkpoint.checkpoint(
336
- create_custom_forward(self.mid_block), sample, latent_embeds
337
- )
338
- sample = sample.to(upscale_dtype)
339
-
340
- # up
341
- for up_block in self.up_blocks:
342
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
295
+ # up
296
+ for up_block in self.up_blocks:
297
+ sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
343
298
  else:
344
299
  # middle
345
300
  sample = self.mid_block(sample, latent_embeds)
@@ -558,72 +513,28 @@ class MaskConditionDecoder(nn.Module):
558
513
 
559
514
  upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
560
515
  if torch.is_grad_enabled() and self.gradient_checkpointing:
516
+ # middle
517
+ sample = self._gradient_checkpointing_func(self.mid_block, sample, latent_embeds)
518
+ sample = sample.to(upscale_dtype)
561
519
 
562
- def create_custom_forward(module):
563
- def custom_forward(*inputs):
564
- return module(*inputs)
565
-
566
- return custom_forward
567
-
568
- if is_torch_version(">=", "1.11.0"):
569
- # middle
570
- sample = torch.utils.checkpoint.checkpoint(
571
- create_custom_forward(self.mid_block),
572
- sample,
573
- latent_embeds,
574
- use_reentrant=False,
575
- )
576
- sample = sample.to(upscale_dtype)
577
-
578
- # condition encoder
579
- if image is not None and mask is not None:
580
- masked_image = (1 - mask) * image
581
- im_x = torch.utils.checkpoint.checkpoint(
582
- create_custom_forward(self.condition_encoder),
583
- masked_image,
584
- mask,
585
- use_reentrant=False,
586
- )
587
-
588
- # up
589
- for up_block in self.up_blocks:
590
- if image is not None and mask is not None:
591
- sample_ = im_x[str(tuple(sample.shape))]
592
- mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
593
- sample = sample * mask_ + sample_ * (1 - mask_)
594
- sample = torch.utils.checkpoint.checkpoint(
595
- create_custom_forward(up_block),
596
- sample,
597
- latent_embeds,
598
- use_reentrant=False,
599
- )
600
- if image is not None and mask is not None:
601
- sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
602
- else:
603
- # middle
604
- sample = torch.utils.checkpoint.checkpoint(
605
- create_custom_forward(self.mid_block), sample, latent_embeds
520
+ # condition encoder
521
+ if image is not None and mask is not None:
522
+ masked_image = (1 - mask) * image
523
+ im_x = self._gradient_checkpointing_func(
524
+ self.condition_encoder,
525
+ masked_image,
526
+ mask,
606
527
  )
607
- sample = sample.to(upscale_dtype)
608
528
 
609
- # condition encoder
610
- if image is not None and mask is not None:
611
- masked_image = (1 - mask) * image
612
- im_x = torch.utils.checkpoint.checkpoint(
613
- create_custom_forward(self.condition_encoder),
614
- masked_image,
615
- mask,
616
- )
617
-
618
- # up
619
- for up_block in self.up_blocks:
620
- if image is not None and mask is not None:
621
- sample_ = im_x[str(tuple(sample.shape))]
622
- mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
623
- sample = sample * mask_ + sample_ * (1 - mask_)
624
- sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
529
+ # up
530
+ for up_block in self.up_blocks:
625
531
  if image is not None and mask is not None:
626
- sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
532
+ sample_ = im_x[str(tuple(sample.shape))]
533
+ mask_ = nn.functional.interpolate(mask, size=sample.shape[-2:], mode="nearest")
534
+ sample = sample * mask_ + sample_ * (1 - mask_)
535
+ sample = self._gradient_checkpointing_func(up_block, sample, latent_embeds)
536
+ if image is not None and mask is not None:
537
+ sample = sample * mask + im_x[str(tuple(sample.shape))] * (1 - mask)
627
538
  else:
628
539
  # middle
629
540
  sample = self.mid_block(sample, latent_embeds)
@@ -890,17 +801,7 @@ class EncoderTiny(nn.Module):
890
801
  def forward(self, x: torch.Tensor) -> torch.Tensor:
891
802
  r"""The forward method of the `EncoderTiny` class."""
892
803
  if torch.is_grad_enabled() and self.gradient_checkpointing:
893
-
894
- def create_custom_forward(module):
895
- def custom_forward(*inputs):
896
- return module(*inputs)
897
-
898
- return custom_forward
899
-
900
- if is_torch_version(">=", "1.11.0"):
901
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
902
- else:
903
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
804
+ x = self._gradient_checkpointing_func(self.layers, x)
904
805
 
905
806
  else:
906
807
  # scale image from [-1, 1] to [0, 1] to match TAESD convention
@@ -976,18 +877,7 @@ class DecoderTiny(nn.Module):
976
877
  x = torch.tanh(x / 3) * 3
977
878
 
978
879
  if torch.is_grad_enabled() and self.gradient_checkpointing:
979
-
980
- def create_custom_forward(module):
981
- def custom_forward(*inputs):
982
- return module(*inputs)
983
-
984
- return custom_forward
985
-
986
- if is_torch_version(">=", "1.11.0"):
987
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x, use_reentrant=False)
988
- else:
989
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.layers), x)
990
-
880
+ x = self._gradient_checkpointing_func(self.layers, x)
991
881
  else:
992
882
  x = self.layers(x)
993
883
 
@@ -71,6 +71,9 @@ class VQModel(ModelMixin, ConfigMixin):
71
71
  Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
72
72
  """
73
73
 
74
+ _skip_layerwise_casting_patterns = ["quantize"]
75
+ _supports_group_offloading = False
76
+
74
77
  @register_to_config
75
78
  def __init__(
76
79
  self,
@@ -0,0 +1,108 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from ..utils.logging import get_logger
16
+
17
+
18
+ logger = get_logger(__name__) # pylint: disable=invalid-name
19
+
20
+
21
+ class CacheMixin:
22
+ r"""
23
+ A class for enable/disabling caching techniques on diffusion models.
24
+
25
+ Supported caching techniques:
26
+ - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
27
+ - [FasterCache](https://huggingface.co/papers/2410.19355)
28
+ """
29
+
30
+ _cache_config = None
31
+
32
+ @property
33
+ def is_cache_enabled(self) -> bool:
34
+ return self._cache_config is not None
35
+
36
+ def enable_cache(self, config) -> None:
37
+ r"""
38
+ Enable caching techniques on the model.
39
+
40
+ Args:
41
+ config (`Union[PyramidAttentionBroadcastConfig]`):
42
+ The configuration for applying the caching technique. Currently supported caching techniques are:
43
+ - [`~hooks.PyramidAttentionBroadcastConfig`]
44
+
45
+ Example:
46
+
47
+ ```python
48
+ >>> import torch
49
+ >>> from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig
50
+
51
+ >>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
52
+ >>> pipe.to("cuda")
53
+
54
+ >>> config = PyramidAttentionBroadcastConfig(
55
+ ... spatial_attention_block_skip_range=2,
56
+ ... spatial_attention_timestep_skip_range=(100, 800),
57
+ ... current_timestep_callback=lambda: pipe.current_timestep,
58
+ ... )
59
+ >>> pipe.transformer.enable_cache(config)
60
+ ```
61
+ """
62
+
63
+ from ..hooks import (
64
+ FasterCacheConfig,
65
+ PyramidAttentionBroadcastConfig,
66
+ apply_faster_cache,
67
+ apply_pyramid_attention_broadcast,
68
+ )
69
+
70
+ if self.is_cache_enabled:
71
+ raise ValueError(
72
+ f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
73
+ )
74
+
75
+ if isinstance(config, PyramidAttentionBroadcastConfig):
76
+ apply_pyramid_attention_broadcast(self, config)
77
+ elif isinstance(config, FasterCacheConfig):
78
+ apply_faster_cache(self, config)
79
+ else:
80
+ raise ValueError(f"Cache config {type(config)} is not supported.")
81
+
82
+ self._cache_config = config
83
+
84
+ def disable_cache(self) -> None:
85
+ from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
86
+ from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
87
+ from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
88
+
89
+ if self._cache_config is None:
90
+ logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
91
+ return
92
+
93
+ if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
94
+ registry = HookRegistry.check_if_exists_or_initialize(self)
95
+ registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
96
+ elif isinstance(self._cache_config, FasterCacheConfig):
97
+ registry = HookRegistry.check_if_exists_or_initialize(self)
98
+ registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
99
+ registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
100
+ else:
101
+ raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
102
+
103
+ self._cache_config = None
104
+
105
+ def _reset_stateful_cache(self, recurse: bool = True) -> None:
106
+ from ..hooks import HookRegistry
107
+
108
+ HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
@@ -18,6 +18,7 @@ if is_torch_available():
18
18
  from .controlnet_union import ControlNetUnionModel
19
19
  from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
20
20
  from .multicontrolnet import MultiControlNetModel
21
+ from .multicontrolnet_union import MultiControlNetUnionModel
21
22
 
22
23
  if is_flax_available():
23
24
  from .controlnet_flax import FlaxControlNetModel
@@ -31,8 +31,6 @@ from ..attention_processor import (
31
31
  from ..embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
32
32
  from ..modeling_utils import ModelMixin
33
33
  from ..unets.unet_2d_blocks import (
34
- CrossAttnDownBlock2D,
35
- DownBlock2D,
36
34
  UNetMidBlock2D,
37
35
  UNetMidBlock2DCrossAttn,
38
36
  get_down_block,
@@ -659,10 +657,6 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
659
657
  for module in self.children():
660
658
  fn_recursive_set_attention_slice(module, reversed_slice_size)
661
659
 
662
- def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
663
- if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
664
- module.gradient_checkpointing = value
665
-
666
660
  def forward(
667
661
  self,
668
662
  sample: torch.Tensor,
@@ -740,10 +734,11 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
740
734
  # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
741
735
  # This would be a good case for the `match` statement (Python 3.10+)
742
736
  is_mps = sample.device.type == "mps"
737
+ is_npu = sample.device.type == "npu"
743
738
  if isinstance(timestep, float):
744
- dtype = torch.float32 if is_mps else torch.float64
739
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
745
740
  else:
746
- dtype = torch.int32 if is_mps else torch.int64
741
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
747
742
  timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
748
743
  elif len(timesteps.shape) == 0:
749
744
  timesteps = timesteps[None].to(sample.device)
@@ -22,7 +22,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
22
22
  from ...loaders import PeftAdapterMixin
23
23
  from ...models.attention_processor import AttentionProcessor
24
24
  from ...models.modeling_utils import ModelMixin
25
- from ...utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
25
+ from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
26
26
  from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
27
27
  from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
28
28
  from ..modeling_outputs import Transformer2DModelOutput
@@ -178,10 +178,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
178
178
  for name, module in self.named_children():
179
179
  fn_recursive_attn_processor(name, module, processor)
180
180
 
181
- def _set_gradient_checkpointing(self, module, value=False):
182
- if hasattr(module, "gradient_checkpointing"):
183
- module.gradient_checkpointing = value
184
-
185
181
  @classmethod
186
182
  def from_transformer(
187
183
  cls,
@@ -302,15 +298,6 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
302
298
  )
303
299
  encoder_hidden_states = self.context_embedder(encoder_hidden_states)
304
300
 
305
- if self.union:
306
- # union mode
307
- if controlnet_mode is None:
308
- raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
309
- # union mode emb
310
- controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
311
- encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
312
- txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
313
-
314
301
  if txt_ids.ndim == 3:
315
302
  logger.warning(
316
303
  "Passing `txt_ids` 3d torch.Tensor is deprecated."
@@ -324,30 +311,27 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
324
311
  )
325
312
  img_ids = img_ids[0]
326
313
 
314
+ if self.union:
315
+ # union mode
316
+ if controlnet_mode is None:
317
+ raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
318
+ # union mode emb
319
+ controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
320
+ encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
321
+ txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
322
+
327
323
  ids = torch.cat((txt_ids, img_ids), dim=0)
328
324
  image_rotary_emb = self.pos_embed(ids)
329
325
 
330
326
  block_samples = ()
331
327
  for index_block, block in enumerate(self.transformer_blocks):
332
328
  if torch.is_grad_enabled() and self.gradient_checkpointing:
333
-
334
- def create_custom_forward(module, return_dict=None):
335
- def custom_forward(*inputs):
336
- if return_dict is not None:
337
- return module(*inputs, return_dict=return_dict)
338
- else:
339
- return module(*inputs)
340
-
341
- return custom_forward
342
-
343
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
344
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
345
- create_custom_forward(block),
329
+ encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
330
+ block,
346
331
  hidden_states,
347
332
  encoder_hidden_states,
348
333
  temb,
349
334
  image_rotary_emb,
350
- **ckpt_kwargs,
351
335
  )
352
336
 
353
337
  else:
@@ -364,23 +348,11 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
364
348
  single_block_samples = ()
365
349
  for index_block, block in enumerate(self.single_transformer_blocks):
366
350
  if torch.is_grad_enabled() and self.gradient_checkpointing:
367
-
368
- def create_custom_forward(module, return_dict=None):
369
- def custom_forward(*inputs):
370
- if return_dict is not None:
371
- return module(*inputs, return_dict=return_dict)
372
- else:
373
- return module(*inputs)
374
-
375
- return custom_forward
376
-
377
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
378
- hidden_states = torch.utils.checkpoint.checkpoint(
379
- create_custom_forward(block),
351
+ hidden_states = self._gradient_checkpointing_func(
352
+ block,
380
353
  hidden_states,
381
354
  temb,
382
355
  image_rotary_emb,
383
- **ckpt_kwargs,
384
356
  )
385
357
 
386
358
  else: