diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +198 -28
  13. diffusers/loaders/lora_conversion_utils.py +679 -44
  14. diffusers/loaders/lora_pipeline.py +1963 -801
  15. diffusers/loaders/peft.py +169 -84
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +653 -75
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +22 -32
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +10 -2
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +14 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.1.dist-info/RECORD +0 -550
  387. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/models/attention_processor.py

@@ -213,7 +213,9 @@ class Attention(nn.Module):
             self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
             self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
         else:
-            raise ValueError(f"unknown qk_norm: {qk_norm}. Should be None,'layer_norm','fp32_layer_norm','rms_norm'")
+            raise ValueError(
+                f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
+            )

         if cross_attention_norm is None:
             self.norm_cross = None
@@ -272,12 +274,20 @@ class Attention(nn.Module):
             self.to_add_out = None

         if qk_norm is not None and added_kv_proj_dim is not None:
-            if qk_norm == "fp32_layer_norm":
+            if qk_norm == "layer_norm":
+                self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+                self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            elif qk_norm == "fp32_layer_norm":
                 self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
                 self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
             elif qk_norm == "rms_norm":
                 self.norm_added_q = RMSNorm(dim_head, eps=eps)
                 self.norm_added_k = RMSNorm(dim_head, eps=eps)
+            elif qk_norm == "rms_norm_across_heads":
+                # Wan applies qk norm across all heads
+                # Wan also doesn't apply a q norm
+                self.norm_added_q = None
+                self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
             else:
                 raise ValueError(
                     f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
@@ -297,7 +307,10 @@ class Attention(nn.Module):
         self.set_processor(processor)

     def set_use_xla_flash_attention(
-        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None
+        self,
+        use_xla_flash_attention: bool,
+        partition_spec: Optional[Tuple[Optional[str], ...]] = None,
+        is_flux=False,
     ) -> None:
         r"""
         Set whether to use xla flash attention from `torch_xla` or not.
@@ -316,7 +329,10 @@
             elif is_spmd() and is_torch_xla_version("<", "2.4"):
                 raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
             else:
-                processor = XLAFlashAttnProcessor2_0(partition_spec)
+                if is_flux:
+                    processor = XLAFluxFlashAttnProcessor2_0(partition_spec)
+                else:
+                    processor = XLAFlashAttnProcessor2_0(partition_spec)
         else:
             processor = (
                 AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
@@ -399,11 +415,12 @@ class Attention(nn.Module):
             else:
                 try:
                     # Make sure we can run the memory efficient attention
-                    _ = xformers.ops.memory_efficient_attention(
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                    )
+                    dtype = None
+                    if attention_op is not None:
+                        op_fw, op_bw = attention_op
+                        dtype, *_ = op_fw.SUPPORTED_DTYPES
+                    q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                    _ = xformers.ops.memory_efficient_attention(q, q, q)
                 except Exception as e:
                     raise e

@@ -724,10 +741,14 @@ class Attention(nn.Module):

         if out_dim == 3:
             if attention_mask.shape[0] < batch_size * head_size:
-                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+                attention_mask = attention_mask.repeat_interleave(
+                    head_size, dim=0, output_size=attention_mask.shape[0] * head_size
+                )
         elif out_dim == 4:
             attention_mask = attention_mask.unsqueeze(1)
-            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+            attention_mask = attention_mask.repeat_interleave(
+                head_size, dim=1, output_size=attention_mask.shape[1] * head_size
+            )

         return attention_mask

@@ -899,7 +920,7 @@ class SanaMultiscaleLinearAttention(nn.Module):
         scores = torch.matmul(key.transpose(-1, -2), query)
         scores = scores.to(dtype=torch.float32)
         scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
-        hidden_states = torch.matmul(value, scores)
+        hidden_states = torch.matmul(value, scores.to(value.dtype))
         return hidden_states

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -1401,7 +1422,7 @@ class JointAttnProcessor2_0:

     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+            raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

     def __call__(
         self,
@@ -2321,6 +2342,7 @@ class FluxAttnProcessor2_0:
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

@@ -2522,6 +2544,7 @@ class FusedFluxAttnProcessor2_0:
             key = apply_rotary_emb(key, image_rotary_emb)

         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

@@ -2771,9 +2794,8 @@ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):

             # IP-adapter
             ip_query = hidden_states_query_proj
-            ip_attn_output = None
-            # for ip-adapter
-            # TODO: support for multiple adapters
+            ip_attn_output = torch.zeros_like(hidden_states)
+
             for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
                 ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
             ):
@@ -2784,12 +2806,14 @@ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
                 ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                 # the output of sdp = (batch, num_heads, seq_len, head_dim)
                 # TODO: add support for attn.scale when we move to Torch 2.1
-                ip_attn_output = F.scaled_dot_product_attention(
+                current_ip_hidden_states = F.scaled_dot_product_attention(
                     ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                 )
-                ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-                ip_attn_output = scale * ip_attn_output
-                ip_attn_output = ip_attn_output.to(ip_query.dtype)
+                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                    batch_size, -1, attn.heads * head_dim
+                )
+                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+                ip_attn_output += scale * current_ip_hidden_states

             return hidden_states, encoder_hidden_states, ip_attn_output
         else:
@@ -2818,9 +2842,7 @@ class CogVideoXAttnProcessor2_0:

         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
+        batch_size, sequence_length, _ = hidden_states.shape

         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -3148,6 +3170,11 @@ class AttnProcessorNPU:
             # scaled_dot_product_attention expects attention_mask shape to be
             # (batch, heads, source_length, target_length)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+            attention_mask = attention_mask.repeat(1, 1, hidden_states.shape[1], 1)
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.logical_not(attention_mask.bool())
+            else:
+                attention_mask = attention_mask.bool()

         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -3422,6 +3449,106 @@ class XLAFlashAttnProcessor2_0:
         return hidden_states


+class XLAFluxFlashAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+    """
+
+    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+        if is_torch_xla_version("<", "2.3"):
+            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+        if is_spmd() and is_torch_xla_version("<", "2.4"):
+            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+        self.partition_spec = partition_spec
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        query /= math.sqrt(head_dim)
+        hidden_states = flash_attention(query, key, value, causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class MochiVaeAttnProcessor2_0:
     r"""
     Attention processor used in Mochi VAE.
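The new processor is selected through the `is_flux` flag on `set_use_xla_flash_attention` shown above. A rough sketch of opting a Flux transformer into it, assuming a torch_xla >= 2.3 environment; the checkpoint id is illustrative:

```py
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell")
for module in pipe.transformer.modules():
    # Only Attention blocks expose this switch; is_flux routes them to
    # XLAFluxFlashAttnProcessor2_0 instead of XLAFlashAttnProcessor2_0.
    if hasattr(module, "set_use_xla_flash_attention"):
        module.set_use_xla_flash_attention(True, partition_spec=None, is_flux=True)
```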
@@ -3583,8 +3710,10 @@ class StableAudioAttnProcessor2_0:
         if kv_heads != attn.heads:
             # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
             heads_per_kv_head = attn.heads // kv_heads
-            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
-            value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
+            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
+            value = torch.repeat_interleave(
+                value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head
+            )

         if attn.norm_q is not None:
             query = attn.norm_q(query)
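Several hunks here (this one and `prepare_attention_mask` above) change `torch.repeat_interleave` calls to pass `output_size`. Supplying the repeated dimension's length up front lets PyTorch skip inferring the output shape, which removes a synchronization point on accelerator backends. A minimal illustration with made-up shapes:

```py
import torch

key = torch.randn(2, 4, 64)  # (batch, kv_heads, head_dim) -- illustrative shapes
heads_per_kv_head = 3

# Passing output_size tells repeat_interleave the size of the repeated dim,
# so it does not have to compute it (a sync point on CUDA/XLA backends).
expanded = torch.repeat_interleave(
    key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head
)
print(expanded.shape)  # torch.Size([2, 12, 64])
```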
@@ -4839,6 +4968,8 @@ class IPAdapterAttnProcessor(nn.Module):
                 )
             else:
                 for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if mask is None:
+                        continue
                     if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
                         raise ValueError(
                             "Each element of the ip_adapter_masks array should be a tensor with shape "
@@ -5056,6 +5187,8 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
                 )
             else:
                 for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if mask is None:
+                        continue
                     if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
                         raise ValueError(
                             "Each element of the ip_adapter_masks array should be a tensor with shape "
@@ -5887,6 +6020,11 @@ class SanaLinearAttnProcessor2_0:
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
         query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
         key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
         value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
diffusers/models/auto_model.py (new file)

@@ -0,0 +1,169 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import os
+from typing import Optional, Union
+
+from huggingface_hub.utils import validate_hf_hub_args
+
+from ..configuration_utils import ConfigMixin
+
+
+class AutoModel(ConfigMixin):
+    config_name = "config.json"
+
+    def __init__(self, *args, **kwargs):
+        raise EnvironmentError(
+            f"{self.__class__.__name__} is designed to be instantiated "
+            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
+            f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
+        )
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLike]] = None, **kwargs):
+        r"""
+        Instantiate a pretrained PyTorch model from a pretrained model configuration.
+
+        The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
+        train the model, set it back in training mode with `model.train()`.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`~ModelMixin.save_pretrained`].
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+                dtype is automatically derived from the model's weights.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only(`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            from_flax (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a Flax checkpoint save file.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            mirror (`str`, *optional*):
+                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+                information.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be defined for each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+                same device. Defaults to `None`, meaning that the model will be loaded on CPU.
+
+                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
+                each GPU and the available CPU RAM if unset.
+            offload_folder (`str` or `os.PathLike`, *optional*):
+                The path to offload weights if `device_map` contains the value `"disk"`.
+            offload_state_dict (`bool`, *optional*):
+                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+                when there is some disk offload.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+            variant (`str`, *optional*):
+                Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
+                loading `from_flax`.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
+                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
+                weights. If set to `False`, `safetensors` weights are not loaded.
+            disable_mmap ('bool', *optional*, defaults to 'False'):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
+
+        <Tip>
+
+        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
+        `huggingface-cli login`. You can also activate the special
+        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+        firewalled environment.
+
+        </Tip>
+
+        Example:
+
+        ```py
+        from diffusers import AutoModel
+
+        unet = AutoModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+        ```
+
+        If you get the error message below, you need to finetune the weights for your downstream task:
+
+        ```bash
+        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+        ```
+        """
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        token = kwargs.pop("token", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+
+        load_config_kwargs = {
+            "cache_dir": cache_dir,
+            "force_download": force_download,
+            "proxies": proxies,
+            "token": token,
+            "local_files_only": local_files_only,
+            "revision": revision,
+            "subfolder": subfolder,
+        }
+
+        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
+        orig_class_name = config["_class_name"]
+
+        library = importlib.import_module("diffusers")
+
+        model_cls = getattr(library, orig_class_name, None)
+        if model_cls is None:
+            raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
+
+        kwargs = {**load_config_kwargs, **kwargs}
+        return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
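Beyond the docstring example, `AutoModel` simply resolves `_class_name` from the checkpoint's `config.json` and delegates to that class's `from_pretrained`, so the usual loading kwargs pass straight through. A small sketch, reusing the docstring's model id:

```py
import torch
from diffusers import AutoModel

# Resolves to UNet2DConditionModel via the subfolder's config.json.
unet = AutoModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)
print(type(unet).__name__)  # UNet2DConditionModel
```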
diffusers/models/autoencoders/__init__.py

@@ -5,8 +5,10 @@ from .autoencoder_kl_allegro import AutoencoderKLAllegro
 from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
 from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
+from .autoencoder_kl_magvit import AutoencoderKLMagvit
 from .autoencoder_kl_mochi import AutoencoderKLMochi
 from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
+from .autoencoder_kl_wan import AutoencoderKLWan
 from .autoencoder_oobleck import AutoencoderOobleck
 from .autoencoder_tiny import AutoencoderTiny
 from .consistency_decoder_vae import ConsistencyDecoderVAE
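Both new autoencoders back the EasyAnimate and Wan pipelines added in this release and are re-exported at the package top level. A brief sketch, assuming the checkpoint id below (illustrative, not taken from the diff):

```py
import torch
from diffusers import AutoencoderKLMagvit, AutoencoderKLWan

# Illustrative repo id; any Wan 2.1 Diffusers-format checkpoint with a "vae" subfolder should work.
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
```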
diffusers/models/autoencoders/autoencoder_asym_kl.py

@@ -60,6 +60,8 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
     Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
     """

+    _skip_layerwise_casting_patterns = ["decoder"]
+
     @register_to_config
     def __init__(
         self,

diffusers/models/autoencoders/autoencoder_dc.py

@@ -190,7 +190,7 @@ class DCUpBlock2d(nn.Module):
         x = F.pixel_shuffle(x, self.factor)

         if self.shortcut:
-            y = hidden_states.repeat_interleave(self.repeats, dim=1)
+            y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
             y = F.pixel_shuffle(y, self.factor)
             hidden_states = x + y
         else:
@@ -361,7 +361,9 @@ class Decoder(nn.Module):

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.in_shortcut:
-            x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1)
+            x = hidden_states.repeat_interleave(
+                self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
+            )
             hidden_states = self.conv_in(hidden_states) + x
         else:
             hidden_states = self.conv_in(hidden_states)
@@ -486,6 +488,9 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.tile_sample_stride_height = 448
         self.tile_sample_stride_width = 448

+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
@@ -515,6 +520,8 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio

     def disable_tiling(self) -> None:
         r"""
@@ -606,11 +613,106 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             return (decoded,)
         return DecoderOutput(sample=decoded)

+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
-        raise NotImplementedError("`tiled_encode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, x.shape[2], self.tile_sample_stride_height):
+            row = []
+            for j in range(0, x.shape[3], self.tile_sample_stride_width):
+                tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                if (
+                    tile.shape[2] % self.spatial_compression_ratio != 0
+                    or tile.shape[3] % self.spatial_compression_ratio != 0
+                ):
+                    pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio
+                    pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio
+                    tile = F.pad(tile, (0, pad_w, 0, pad_h))
+                tile = self.encoder(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]
+
+        if not return_dict:
+            return (encoded,)
+        return EncoderOutput(latent=encoded)

     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-        raise NotImplementedError("`tiled_decode` has not been implemented for AutoencoderDC.")
+        batch_size, num_channels, height, width = z.shape
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                decoded = self.decoder(tile)
+                row.append(decoded)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        decoded = torch.cat(result_rows, dim=2)
+
+        if not return_dict:
+            return (decoded,)
+        return DecoderOutput(sample=decoded)

     def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
         encoded = self.encode(sample, return_dict=False)[0]
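With `tiled_encode`/`tiled_decode` filled in, `enable_tiling` on `AutoencoderDC` becomes usable end to end for inputs larger than the tile size. A rough sketch; the checkpoint id is illustrative, and any `AutoencoderDC` checkpoint should behave the same:

```py
import torch
from diffusers import AutoencoderDC

vae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers", torch_dtype=torch.float32
)
vae.enable_tiling(tile_sample_min_height=512, tile_sample_min_width=512)

image = torch.randn(1, 3, 1024, 1024)  # larger than the tile size, so the tiled path is taken
with torch.no_grad():
    latent = vae.encode(image).latent
    reconstruction = vae.decode(latent).sample
print(latent.shape, reconstruction.shape)
```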