diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/transformer_wan.py
@@ -0,0 +1,469 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import FeedForward
+from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
+from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import FP32LayerNorm
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class WanAttnProcessor2_0:
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        encoder_hidden_states_img = None
+        if attn.add_k_proj is not None:
+            encoder_hidden_states_img = encoder_hidden_states[:, :257]
+            encoder_hidden_states = encoder_hidden_states[:, 257:]
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        if rotary_emb is not None:
+
+            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
+                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
+                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
+                return x_out.type_as(hidden_states)
+
+            query = apply_rotary_emb(query, rotary_emb)
+            key = apply_rotary_emb(key, rotary_emb)
+
+        # I2V task
+        hidden_states_img = None
+        if encoder_hidden_states_img is not None:
+            key_img = attn.add_k_proj(encoder_hidden_states_img)
+            key_img = attn.norm_added_k(key_img)
+            value_img = attn.add_v_proj(encoder_hidden_states_img)
+
+            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+            hidden_states_img = F.scaled_dot_product_attention(
+                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
+            )
+            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
+            hidden_states_img = hidden_states_img.type_as(query)
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.type_as(query)
+
+        if hidden_states_img is not None:
+            hidden_states = hidden_states + hidden_states_img
+
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
+class WanImageEmbedding(torch.nn.Module):
+    def __init__(self, in_features: int, out_features: int):
+        super().__init__()
+
+        self.norm1 = FP32LayerNorm(in_features)
+        self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
+        self.norm2 = FP32LayerNorm(out_features)
+
+    def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.norm1(encoder_hidden_states_image)
+        hidden_states = self.ff(hidden_states)
+        hidden_states = self.norm2(hidden_states)
+        return hidden_states
+
+
+class WanTimeTextImageEmbedding(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        time_freq_dim: int,
+        time_proj_dim: int,
+        text_embed_dim: int,
+        image_embed_dim: Optional[int] = None,
+    ):
+        super().__init__()
+
+        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
+        self.act_fn = nn.SiLU()
+        self.time_proj = nn.Linear(dim, time_proj_dim)
+        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
+
+        self.image_embedder = None
+        if image_embed_dim is not None:
+            self.image_embedder = WanImageEmbedding(image_embed_dim, dim)
+
+    def forward(
+        self,
+        timestep: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_hidden_states_image: Optional[torch.Tensor] = None,
+    ):
+        timestep = self.timesteps_proj(timestep)
+
+        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
+        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
+            timestep = timestep.to(time_embedder_dtype)
+        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+        timestep_proj = self.time_proj(self.act_fn(temb))
+
+        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+        if encoder_hidden_states_image is not None:
+            encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+        return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
+
+
+class WanRotaryPosEmbed(nn.Module):
+    def __init__(
+        self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
+    ):
+        super().__init__()
+
+        self.attention_head_dim = attention_head_dim
+        self.patch_size = patch_size
+        self.max_seq_len = max_seq_len
+
+        h_dim = w_dim = 2 * (attention_head_dim // 6)
+        t_dim = attention_head_dim - h_dim - w_dim
+
+        freqs = []
+        for dim in [t_dim, h_dim, w_dim]:
+            freq = get_1d_rotary_pos_embed(
+                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
+            )
+            freqs.append(freq)
+        self.freqs = torch.cat(freqs, dim=1)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p_t, p_h, p_w = self.patch_size
+        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+
+        self.freqs = self.freqs.to(hidden_states.device)
+        freqs = self.freqs.split_with_sizes(
+            [
+                self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
+                self.attention_head_dim // 6,
+                self.attention_head_dim // 6,
+            ],
+            dim=1,
+        )
+
+        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
+        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
+        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
+        return freqs
+
+
+class WanTransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        ffn_dim: int,
+        num_heads: int,
+        qk_norm: str = "rms_norm_across_heads",
+        cross_attn_norm: bool = False,
+        eps: float = 1e-6,
+        added_kv_proj_dim: Optional[int] = None,
+    ):
+        super().__init__()
+
+        # 1. Self-attention
+        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_heads,
+            kv_heads=num_heads,
+            dim_head=dim // num_heads,
+            qk_norm=qk_norm,
+            eps=eps,
+            bias=True,
+            cross_attention_dim=None,
+            out_bias=True,
+            processor=WanAttnProcessor2_0(),
+        )
+
+        # 2. Cross-attention
+        self.attn2 = Attention(
+            query_dim=dim,
+            heads=num_heads,
+            kv_heads=num_heads,
+            dim_head=dim // num_heads,
+            qk_norm=qk_norm,
+            eps=eps,
+            bias=True,
+            cross_attention_dim=None,
+            out_bias=True,
+            added_kv_proj_dim=added_kv_proj_dim,
+            added_proj_bias=True,
+            processor=WanAttnProcessor2_0(),
+        )
+        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+        # 3. Feed-forward
+        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        rotary_emb: torch.Tensor,
+    ) -> torch.Tensor:
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+            self.scale_shift_table + temb.float()
+        ).chunk(6, dim=1)
+
+        # 1. Self-attention
+        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
+        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
+
+        # 2. Cross-attention
+        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
+        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        hidden_states = hidden_states + attn_output
+
+        # 3. Feed-forward
+        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+            hidden_states
+        )
+        ff_output = self.ffn(norm_hidden_states)
+        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
+
+        return hidden_states
+
+
+class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+    r"""
+    A Transformer model for video-like data used in the Wan model.
+
+    Args:
+        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+        num_attention_heads (`int`, defaults to `40`):
+            The number of attention heads.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        text_dim (`int`, defaults to `4096`):
+            Input dimension for text embeddings.
+        freq_dim (`int`, defaults to `256`):
+            Dimension for sinusoidal time embeddings.
+        ffn_dim (`int`, defaults to `13824`):
+            Intermediate dimension in feed-forward network.
+        num_layers (`int`, defaults to `40`):
+            The number of layers of transformer blocks to use.
+        window_size (`Tuple[int]`, defaults to `(-1, -1)`):
+            Window size for local attention (-1 indicates global attention).
+        cross_attn_norm (`bool`, defaults to `True`):
+            Enable cross-attention normalization.
+        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
+            Query/key normalization to apply.
+        eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        image_dim (`int`, *optional*, defaults to `None`):
+            Input dimension of image embeddings (set for the I2V variant).
+        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the added key and value projections. If `None`, no projection is used.
+    """
+
+    _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
+    _no_split_modules = ["WanTransformerBlock"]
+    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+
+    @register_to_config
+    def __init__(
+        self,
+        patch_size: Tuple[int] = (1, 2, 2),
+        num_attention_heads: int = 40,
+        attention_head_dim: int = 128,
+        in_channels: int = 16,
+        out_channels: int = 16,
+        text_dim: int = 4096,
+        freq_dim: int = 256,
+        ffn_dim: int = 13824,
+        num_layers: int = 40,
+        cross_attn_norm: bool = True,
+        qk_norm: Optional[str] = "rms_norm_across_heads",
+        eps: float = 1e-6,
+        image_dim: Optional[int] = None,
+        added_kv_proj_dim: Optional[int] = None,
+        rope_max_seq_len: int = 1024,
+    ) -> None:
+        super().__init__()
+
+        inner_dim = num_attention_heads * attention_head_dim
+        out_channels = out_channels or in_channels
+
+        # 1. Patch & position embedding
+        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+        # 2. Condition embeddings
+        # image_embedding_dim=1280 for I2V model
+        self.condition_embedder = WanTimeTextImageEmbedding(
+            dim=inner_dim,
+            time_freq_dim=freq_dim,
+            time_proj_dim=inner_dim * 6,
+            text_embed_dim=text_dim,
+            image_embed_dim=image_dim,
+        )
+
+        # 3. Transformer blocks
+        self.blocks = nn.ModuleList(
+            [
+                WanTransformerBlock(
+                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        # 4. Output norm & projection
+        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_hidden_states_image: Optional[torch.Tensor] = None,
+        return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p_t, p_h, p_w = self.config.patch_size
+        post_patch_num_frames = num_frames // p_t
+        post_patch_height = height // p_h
+        post_patch_width = width // p_w
+
+        rotary_emb = self.rope(hidden_states)
+
+        hidden_states = self.patch_embedding(hidden_states)
+        hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+            timestep, encoder_hidden_states, encoder_hidden_states_image
+        )
+        timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+        if encoder_hidden_states_image is not None:
+            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+        # 4. Transformer blocks
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            for block in self.blocks:
+                hidden_states = self._gradient_checkpointing_func(
+                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+                )
+        else:
+            for block in self.blocks:
+                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+
+        # 5. Output norm, projection & unpatchify
+        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+
+        # Move the shift and scale tensors to the same device as hidden_states.
+        # When using multi-GPU inference via accelerate these will be on the
+        # first device rather than the last device, which hidden_states ends up
+        # on.
+        shift = shift.to(hidden_states.device)
+        scale = scale.to(hidden_states.device)
+
+        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+        hidden_states = self.proj_out(hidden_states)
+
+        hidden_states = hidden_states.reshape(
+            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+        )
+        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
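
The new WanTransformer3DModel above is the denoiser behind the WanPipeline family added in this release. For orientation, here is a minimal smoke-test sketch of its forward interface; it assumes the top-level WanTransformer3DModel export added to diffusers/__init__.py in 0.33.0, and the tiny config values are illustrative only (released Wan checkpoints use the 40-head, 40-layer defaults shown above).

import torch
from diffusers import WanTransformer3DModel

# Tiny illustrative config; real checkpoints use the much larger defaults.
model = WanTransformer3DModel(
    patch_size=(1, 2, 2),
    num_attention_heads=2,
    attention_head_dim=12,
    in_channels=4,
    out_channels=4,
    text_dim=16,
    freq_dim=32,
    ffn_dim=48,
    num_layers=2,
)

latents = torch.randn(1, 4, 2, 8, 8)   # (batch, channels, frames, height, width)
text_embeds = torch.randn(1, 77, 16)   # (batch, sequence_length, text_dim)
timestep = torch.tensor([1])

with torch.no_grad():
    sample = model(hidden_states=latents, timestep=timestep, encoder_hidden_states=text_embeds).sample

print(sample.shape)  # torch.Size([1, 4, 2, 8, 8])

The patch embedding divides frames/height/width by patch_size, the 3D RoPE splits each head's channels across the temporal and two spatial axes, and proj_out plus the final reshape fold the patches back, so the output matches the input latent grid.
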
diffusers/models/unets/unet_1d.py
@@ -71,6 +71,8 @@ class UNet1DModel(ModelMixin, ConfigMixin):
             Experimental feature for using a UNet without upsampling.
     """
 
+    _skip_layerwise_casting_patterns = ["norm"]
+
     @register_to_config
     def __init__(
         self,
@@ -223,7 +225,7 @@ class UNet1DModel(ModelMixin, ConfigMixin):
 
         timestep_embed = self.time_proj(timesteps)
         if self.config.use_timestep_embedding:
-            timestep_embed = self.time_mlp(timestep_embed)
+            timestep_embed = self.time_mlp(timestep_embed.to(sample.dtype))
        else:
            timestep_embed = timestep_embed[..., None]
            timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype)
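
Both unet_1d.py changes relate to running the model in reduced precision: _skip_layerwise_casting_patterns keeps norm layers out of low-precision storage for the new layerwise-casting hooks (see diffusers/hooks/layerwise_casting.py in the file list), and the explicit .to(sample.dtype) keeps the timestep MLP input consistent with the sample dtype. A minimal sketch of switching the feature on, assuming the enable_layerwise_casting helper that ModelMixin gains in 0.33.0; the checkpoint id is hypothetical:

import torch
from diffusers import UNet1DModel

# Hypothetical checkpoint id, for illustration only.
model = UNet1DModel.from_pretrained("your-org/your-unet1d-checkpoint")

# Store weights in fp8 and upcast each submodule to bf16 on the fly during
# forward; submodules matching _skip_layerwise_casting_patterns ("norm" here)
# keep their original dtype.
model.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
)
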
diffusers/models/unets/unet_2d.py
@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
             Tuple of downsample block types.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
-            Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
+            Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None`.
         up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
             Tuple of upsample block types.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
@@ -90,6 +90,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["norm"]
 
     @register_to_config
     def __init__(
@@ -103,6 +104,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         freq_shift: int = 0,
         flip_sin_to_cos: bool = True,
         down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+        mid_block_type: Optional[str] = "UNetMidBlock2D",
         up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
         block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
         layers_per_block: int = 2,
@@ -194,19 +196,22 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             self.down_blocks.append(down_block)
 
         # mid
-        self.mid_block = UNetMidBlock2D(
-            in_channels=block_out_channels[-1],
-            temb_channels=time_embed_dim,
-            dropout=dropout,
-            resnet_eps=norm_eps,
-            resnet_act_fn=act_fn,
-            output_scale_factor=mid_block_scale_factor,
-            resnet_time_scale_shift=resnet_time_scale_shift,
-            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
-            resnet_groups=norm_num_groups,
-            attn_groups=attn_norm_num_groups,
-            add_attention=add_attention,
-        )
+        if mid_block_type is None:
+            self.mid_block = None
+        else:
+            self.mid_block = UNetMidBlock2D(
+                in_channels=block_out_channels[-1],
+                temb_channels=time_embed_dim,
+                dropout=dropout,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
+                resnet_groups=norm_num_groups,
+                attn_groups=attn_norm_num_groups,
+                add_attention=add_attention,
+            )
 
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
@@ -235,7 +240,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
                 dropout=dropout,
             )
             self.up_blocks.append(up_block)
-            prev_output_channel = output_channel
 
         # out
         num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 32)
@@ -243,10 +247,6 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         sample: torch.Tensor,
@@ -322,7 +322,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             down_block_res_samples += res_samples
 
         # 4. mid
-        sample = self.mid_block(sample, emb)
+        if self.mid_block is not None:
+            sample = self.mid_block(sample, emb)
 
         # 5. up
         skip_sample = None
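
A short sketch of what the new mid_block_type escape hatch enables; it uses only constructor arguments visible in the hunks above plus the standard UNet2DModel block/channel arguments, so treat the exact values as illustrative:

from diffusers import UNet2DModel

# mid_block_type=None skips building the mid block entirely; forward() now
# guards the call with `if self.mid_block is not None`.
unet = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    block_out_channels=(32, 64),
    mid_block_type=None,
)
assert unet.mid_block is None
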